package sbt

import java.util.concurrent.ConcurrentHashMap
import TaskName._

/**
 * An `ExecuteProgress` listener that measures how long each task's work takes and,
 * once everything has completed, prints the total wall-clock time followed by the
 * per-name timings in descending order.
 *
 * Timings are keyed by task and aggregated by the name computed in `mappedName`.
 * A task without a name of its own borrows the name of the task that registered it
 * as a pending dependency (`anonOwners`) or of the task whose work returned it
 * (`calledBy`).
 */
private[sbt] final class TaskTimings extends ExecuteProgress[Task] {
  // task returned as the `Left` result of another task's work -> that other task
  private[this] val calledBy = new ConcurrentHashMap[Task[_], Task[_]]
  // pending dependency for which `transformNode` is empty -> the task that registered it
  private[this] val anonOwners = new ConcurrentHashMap[Task[_], Task[_]]
  // Holds the start timestamp (ns) between workStarting and workFinished,
  // and the elapsed time (ns) afterwards.
  private[this] val timings = new ConcurrentHashMap[Task[_], Long]
  // System.nanoTime captured by `initial`; 0 until then.
  private[this] var start = 0L

  type S = Unit

  /** Records the reference timestamp used to compute the total time in `allCompleted`. */
  def initial: Unit = { start = System.nanoTime }

  /**
   * Remembers `task` as the owner of every pending dependency for which
   * `transformNode` yields nothing, so that such a dependency can later be
   * reported under its owner's name.
   */
  def registered(state: Unit, task: Task[_], allDeps: Iterable[Task[_]], pendingDeps: Iterable[Task[_]]): Unit =
    pendingDeps foreach { t => if (transformNode(t).isEmpty) anonOwners.put(t, task) }

  def ready(state: Unit, task: Task[_]): Unit = ()

  /** Stores the start timestamp for `task`. */
  def workStarting(task: Task[_]): Unit = {
    timings.put(task, System.nanoTime)
    ()
  }

  /**
   * Replaces the start timestamp of `task` with its elapsed time and, when the
   * work produced another task (`Left`), records that this task called it.
   */
  def workFinished[T](task: Task[T], result: Either[Task[T], Result[T]]): Unit = {
    // NOTE(review): presumably workStarting has always run for `task` first; if it
    // has not, `get` returns null, which unboxes to 0 and yields a bogus duration.
    timings.put(task, System.nanoTime - timings.get(task))
    result.left.foreach { t => calledBy.put(t, task) }
  }

  def completed[T](state: Unit, task: Task[T], result: Result[T]): Unit = ()

  /** Prints the total elapsed time and the summed time per task name, largest first. */
  def allCompleted(state: Unit, results: RMap[Task, Result]): Unit =
    {
      val total = System.nanoTime - start
      println("Total time: " + (total * 1e-6) + " ms")
      // Explicit `.asScala` instead of the deprecated implicit JavaConversions.
      import collection.JavaConverters._
      def sumTimes(in: Seq[(Task[_], Long)]) = in.map(_._2).sum
      // Several tasks can map to the same name; their times are added together.
      val timingsByName = timings.asScala.toSeq.groupBy { case (t, _) => mappedName(t) } mapValues (sumTimes)
      for ((taskName, time) <- timingsByName.toSeq.sortBy(_._2).reverse)
        println("  " + taskName + ": " + (time * 1e-6) + " ms")
    }

  // Name of the task that `t` is attributed to, if there is one.
  private[this] def inferredName(t: Task[_]): Option[String] = nameDelegate(t) map mappedName

  // The registering owner takes precedence over the calling task.
  private[this] def nameDelegate(t: Task[_]): Option[Task[_]] = Option(anonOwners.get(t)) orElse Option(calledBy.get(t))

  // Resolution order: the task's defined name, then the name inferred from its
  // delegate, then `anonymousName`.
  private[this] def mappedName(t: Task[_]): String = definedName(t) orElse inferredName(t) getOrElse anonymousName(t)
}