package sbt
package compiler

import scala.reflect.Manifest
import scala.tools.nsc.{ast, interpreter, io, reporters, util, CompilerCommand, Global, Phase, Settings}
import interpreter.AbstractFileClassLoader
import io.{AbstractFile, PlainFile, VirtualDirectory}
import ast.parser.Tokens
import reporters.{ConsoleReporter, Reporter}
import scala.reflect.internal.util.BatchSourceFile
import Tokens.{EOF, NEWLINE, NEWLINES, SEMI}
import java.io.File
import java.nio.ByteBuffer
import java.net.URLClassLoader
import Eval.{getModule, getValue, WrapValName}

// TODO: provide a way to clean up the backing directory

/** The imports to bring into scope when evaluating an expression or a group of definitions.
* Each element of `strings` pairs the source text of an import clause with the first line it occupies.
* `srcName` is the source name used for the compilation units the import clauses are parsed from. */
final class EvalImports(
	val strings: Seq[(String,Int)],
	val srcName: String
)

/** The outcome of evaluating a single Scala expression.
* `tpe` is the inferred type of the expression.
* `getValue` produces the value of the expression when given a parent class loader that provides the classes
* from the classpath the expression was compiled against; every call builds a new class loader and loads
* the module from it.
* `generated` lists the compiled classes and cache files related to the expression.
* `enclosingModule` is the name of the auto-generated module that wraps the expression. */
final class EvalResult(
	val tpe: String,
	val getValue: ClassLoader => Any,
	val generated: Seq[File],
	val enclosingModule: String
)

/** The outcome of evaluating a group of Scala definitions.
* The definitions are wrapped in an auto-generated, top-level module named `enclosingModule`.
* `generated` lists the compiled classes and cache files related to the definitions.
* `loader` builds a new class loader containing the module when given the parent class loader that provides
* the classes from the classpath the definitions were compiled against.
* `valNames` lists the vals that have one of the requested types; their values are obtained from `values`
* by passing the same kind of parent class loader that `loader` expects. */
final class EvalDefinitions(val loader: ClassLoader => ClassLoader, val generated: Seq[File], val enclosingModule: String, val valNames: Seq[String])
{
	/** Loads the enclosing module through `loader(parent)` and reflectively invokes the no-argument method
	* named after each entry of `valNames`, returning the results in the same order. */
	def values(parent: ClassLoader): Seq[Any] = {
		val module = getModule(enclosingModule, loader(parent))
		val moduleClass = module.getClass
		valNames.map(name => moduleClass.getMethod(name).invoke(module))
	}
}

/** Signals that parsing or compiling evaluated code failed.
* `msg` is a short label describing the failing step; the individual errors go to the compiler's reporter. */
final class EvalException(msg: String)
	extends RuntimeException(msg)
/** Evaluates Scala expressions and groups of definitions by compiling them with an embedded compiler.
* Not thread safe, since it reuses a Global instance.
*
* @param optionsNoncp compiler options other than the classpath; `-cp` and the classpath string are prepended to form `options`
* @param classpath the classpath to compile against; the `backing` directory, when defined, is appended to it
* @param mkReporter constructs the compiler's Reporter from the compiler Settings
* @param backing optional directory that receives compiled classes and cache files.  It is created on construction.
*  When it is `None`, classes are compiled to an in-memory directory and nothing is cached. */
final class Eval(optionsNoncp: Seq[String], classpath: Seq[File], mkReporter: Settings => Reporter, backing: Option[File])
{
	def this(mkReporter: Settings => Reporter, backing: Option[File]) = this(Nil, IO.classLocationFile[Product] :: Nil, mkReporter, backing)
	def this() = this(s => new ConsoleReporter(s), None)

	backing.foreach(IO.createDirectory)
	val classpathString = Path.makeString(classpath ++ backing.toList)
	val options = "-cp" +: classpathString +: optionsNoncp

	lazy val settings =
	{
		val s = new Settings(println)
		// NOTE(review): `command` is never read; constructing it is presumably what applies `options` to `s` -- confirm before removing
		val command = new CompilerCommand(options.toList, s)
		s
	}
	lazy val reporter = mkReporter(settings)
	lazy val global: Global = new Global(settings, reporter)
	import global._
	import definitions._

	/** Unlinks the symbols that were recorded instead of being unlinked right after their evaluation
	* (evaluations whose `EvalType.unlink` is false) and clears the record. */
	private[sbt] def unlinkDeferred(): Unit = {
		toUnlinkLater foreach unlink
		toUnlinkLater = Nil
	}

	private[this] var toUnlinkLater = List[Symbol]()
	private[this] def unlink(sym: Symbol) = sym.owner.info.decls.unlink(sym)

	/** Evaluates the single Scala `expression`.
	* @param expression the source of the expression
	* @param imports the imports to put in scope of the expression
	* @param tpeName the source of the expected type of the expression.  When `None`, the type is left to be inferred.
	* @param srcName the source name given to the compilation unit of the expression
	* @param line the line number that the first line of `expression` is reported as
	* @return the inferred type, a function for obtaining the value, and the generated files
	* @throws EvalException if the reporter has errors after parsing or compiling */
	def eval(expression: String, imports: EvalImports = noImports, tpeName: Option[String] = None, srcName: String = "<setting>", line: Int = DefaultStartLine): EvalResult =
	{
		val ev = new EvalType[String] {
			def makeUnit = mkUnit(srcName, line, expression)
			def unlink = true
			def unitBody(unit: CompilationUnit, importTrees: Seq[Tree], moduleName: String): Tree = {
				val (parser, tree) = parse(unit, settingErrorStrings, _.expr())
				val tpt: Tree = expectedType(tpeName)
				augment(parser, importTrees, tree, tpt, moduleName)
			}
			// the extra information for an expression is the string form of its type
			def extra(run: Run, unit: CompilationUnit) = atPhase(run.typerPhase.next) { (new TypeExtractor).getType(unit.body) }
			def read(file: File) = IO.read(file)
			def write(value: String, f: File) = IO.write(f, value)
		}
		val i = evalCommon(expression :: Nil, imports, tpeName, ev)
		val value = (cl: ClassLoader) => getValue[Any](i.enclosingModule, i.loader(cl))
		new EvalResult(i.extra, value, i.generated, i.enclosingModule)
	}

	/** Evaluates a group of definitions, wrapping them in a single auto-generated module.
	* @param definitions pairs of the source of a definition and the line numbers that its lines are reported as.  Must be non-empty.
	* @param imports the imports to put in scope of the definitions
	* @param srcName the source name given to the compilation units
	* @param valTypes the names of the classes that select which vals are listed in the result's `valNames`
	* @throws IllegalArgumentException if `definitions` is empty
	* @throws EvalException if the reporter has errors after parsing or compiling */
	def evalDefinitions(definitions: Seq[(String,Range)], imports: EvalImports, srcName: String, valTypes: Seq[String]): EvalDefinitions =
	{
		require(definitions.nonEmpty, "Definitions to evaluate cannot be empty.")
		val ev = new EvalType[Seq[String]] {
			lazy val (fullUnit, defUnits) = mkDefsUnit(srcName, definitions)
			def makeUnit = fullUnit
			def unlink = false
			def unitBody(unit: CompilationUnit, importTrees: Seq[Tree], moduleName: String): Tree = {
				// each definition is parsed from its own unit; the parser on the full unit is only used to build the enclosing module
				val fullParser = new syntaxAnalyzer.UnitParser(unit)
				val trees = defUnits flatMap parseDefinitions
				syntheticModule(fullParser, importTrees, trees.toList, moduleName)
			}
			// the extra information for definitions is the list of names of the vals of the requested types
			def extra(run: Run, unit: CompilationUnit) = {
				val tpes = valTypes.map(tpe => rootMirror.getRequiredClass(tpe).tpe)
				atPhase(run.typerPhase.next) { (new ValExtractor(tpes)).getVals(unit.body) }
			}
			def read(file: File) = IO.readLines(file)
			def write(value: Seq[String], file: File) = IO.writeLines(file, value)
		}
		// NOTE(review): `valTypes` is not part of the hash computed by `evalCommon`, so cached val names do not depend on it
		val i = evalCommon(definitions.map(_._1), imports, Some(""), ev)
		new EvalDefinitions(i.loader, i.generated, i.enclosingModule, i.extra)
	}

	/** Shared implementation of `eval` and `evalDefinitions`.
	* The module name is derived from a hash of the content, the backing directory, the options, the classpath,
	* the import strings, and the type name.  If a class for that module already exists in the backing directory,
	* the extra information is read from the cache file; otherwise the content is compiled. */
	private[this] def evalCommon[T](content: Seq[String], imports: EvalImports, tpeName: Option[String], ev: EvalType[T]): EvalIntermediate[T] =
	{
		import Eval._
		val hash = Hash.toHex(Hash(bytes( stringSeqBytes(content) :: optBytes(backing)(fileExistsBytes) :: stringSeqBytes(options) ::
			seqBytes(classpath)(fileModifiedBytes) :: stringSeqBytes(imports.strings.map(_._1)) :: optBytes(tpeName)(bytes) :: Nil)))
		val moduleName = makeModuleName(hash)
		
		// `unit` and `run` are lazy so that nothing is constructed when the result is already cached
		lazy val unit = {
			reporter.reset
			ev.makeUnit
		}
		lazy val run = new Run {
			override def units = (unit :: Nil).iterator
		}
		def unlinkAll(): Unit = for( (sym, _) <- run.symSource ) if(ev.unlink) unlink(sym) else toUnlinkLater ::= sym

		val (extra, loader) = backing match {
			case Some(back) if classExists(back, moduleName) =>
				val loader = (parent: ClassLoader) => new URLClassLoader(Array(back.toURI.toURL), parent)
				val extra = ev.read(cacheFile(back,moduleName))
				(extra, loader)
			case _ =>
				try { compileAndLoad(run, unit, imports, backing, moduleName, ev) }
				finally { unlinkAll() }
		}

		val classFiles = getClassFiles(backing, moduleName)
		new EvalIntermediate(extra, loader, classFiles, moduleName)
	}
	// location of the cached type or definition information
	private[this] def cacheFile(base: File, moduleName: String): File = new File(base, moduleName + ".cache")

	/** Builds the tree for `unit`, runs the compiler phases from namer onward, and extracts the extra information.
	* The extra information is written to the cache file when a backing directory is defined.
	* @return the extra information and a function that constructs a class loader for the output directory
	* @throws EvalException if the reporter has errors */
	private[this] def compileAndLoad[T](run: Run, unit: CompilationUnit, imports: EvalImports, backing: Option[File], moduleName: String, ev: EvalType[T]): (T, ClassLoader => ClassLoader) =
	{
		val dir = outputDirectory(backing)
		settings.outputDirs setSingleOutput dir

		val importTrees = parseImports(imports)
		unit.body = ev.unitBody(unit, importTrees, moduleName)

		// runs `phase` and each following phase, stopping at the end of the phase chain or as soon as errors are reported
		def compile(phase: Phase): Unit =
		{
			globalPhase = phase
			if(phase == null || phase == phase.next || reporter.hasErrors)
				()
			else
			{
				atPhase(phase) { phase.run }
				compile(phase.next)
			}
		}

		// the body was constructed by `unitBody`, so the parser phase is skipped
		compile(run.namerPhase)
		checkError("Type error in expression")

		val extra = ev.extra(run, unit)
		for(f <- backing) ev.write(extra, cacheFile(f, moduleName))
		val loader = (parent: ClassLoader) => new AbstractFileClassLoader(dir, parent)
		(extra, loader)
	}

	/** Parses the expected type from `tpeName`, or provides an empty type tree when no type is given. */
	private[this] def expectedType(tpeName: Option[String]): Tree = 
		tpeName match {
			case Some(tpe) => parseType(tpe)
			case None => TypeTree(NoType)
		}

	/** The directory that classes are compiled to: the backing directory if defined, otherwise an in-memory directory. */
	private[this] def outputDirectory(backing: Option[File]): AbstractFile = 
		backing match { case None => new VirtualDirectory("<virtual>", None); case Some(dir) => new PlainFile(dir) }

	/** Provides a function that gets the wrapped value of module `moduleName` from the classes in `dir`, given a parent class loader. */
	def load(dir: AbstractFile, moduleName: String): ClassLoader => Any  = parent => getValue[Any](moduleName, new AbstractFileClassLoader(dir, parent))
	/** Like `load`, but for classes in the file system directory `dir`. */
	def loadPlain(dir: File, moduleName: String): ClassLoader => Any  = parent => getValue[Any](moduleName, new URLClassLoader(Array(dir.toURI.toURL), parent))

	//wrap tree in object objectName { def WrapValName = <tree> }
	def augment(parser: global.syntaxAnalyzer.UnitParser, imports: Seq[Tree], tree: Tree, tpt: Tree, objectName: String): Tree =
	{
		val method = DefDef(NoMods, newTermName(WrapValName), Nil, Nil, tpt, tree)
		syntheticModule(parser, imports, method :: Nil, objectName)
	}
	/** Constructs the tree for `imports` followed by a module named `objectName` that contains `definitions`,
	* all enclosed in the empty package. */
	private[this] def syntheticModule(parser: global.syntaxAnalyzer.UnitParser, imports: Seq[Tree], definitions: List[Tree], objectName: String): Tree =
	{
		val emptyTypeName = nme.EMPTY.toTypeName
		def emptyPkg = parser.atPos(0, 0, 0) { Ident(nme.EMPTY_PACKAGE_NAME) }
		// constructor that only calls the super constructor
		def emptyInit = DefDef(
			NoMods,
			nme.CONSTRUCTOR,
			Nil,
			List(Nil),
			TypeTree(),
			Block(List(Apply(Select(Super(This(emptyTypeName), emptyTypeName), nme.CONSTRUCTOR), Nil)), Literal(Constant(())))
		)

		def moduleBody = Template(List(gen.scalaAnyRefConstr), emptyValDef, emptyInit :: definitions)
		def moduleDef = ModuleDef(NoMods, newTermName(objectName), moduleBody)
		parser.makePackaging(0, emptyPkg, (imports :+ moduleDef).toList)
	}

	/** Tree traverser that obtains the string form of the result type of the method named `WrapValName`.
	* The result is the empty string if no such method is found. */
	private[this] final class TypeExtractor extends Traverser {
		private[this] var result = ""
		def getType(t: Tree) = { result = ""; traverse(t); result }
		override def traverse(tree: Tree): Unit = tree match {
			case d: DefDef if  d.symbol.nameString == WrapValName => result = d.symbol.tpe.finalResultType.toString
			case _ => super.traverse(tree)
		}
	}
	/** Tree traverser that obtains the names of vals in a top-level module whose type is a subtype of one of `types`.
	* A val qualifies when the class of one of `types` is among the base classes of the val's type.
	* The names are collected by prepending, so they are in reverse order of traversal. */
	private[this] final class ValExtractor(types: Seq[Type]) extends Traverser {
		private[this] var vals = List[String]()
		def getVals(t: Tree): List[String] = { vals = Nil; traverse(t); vals }
		// The previous test, `types.exists(_ <:< tpe)`, was the reverse of the documented contract:
		// it accepted vals declared with a supertype (such as AnyRef) and rejected vals of a subclass.
		private[this] def isAcceptableType(tpe: Type): Boolean =
			tpe.baseClasses.exists(base => types.exists(_.typeSymbol == base))
		override def traverse(tree: Tree): Unit = tree match {
			case ValDef(_, n, actualTpe, _) if isTopLevelModule(tree.symbol.owner) && isAcceptableType(actualTpe.tpe) =>
				vals ::= nme.localToGetter(n).encoded
			case _ => super.traverse(tree)
		}
	}
	// inlined implementation of Symbol.isTopLevelModule that was removed in e5b050814deb2e7e1d6d05511d3a6cb6b013b549
	private[this] def isTopLevelModule(s: Symbol): Boolean = s.hasFlag(reflect.internal.Flags.MODULE) && s.owner.isPackageClass

	/** The intermediate result shared by expression and definition evaluation.  `extra` is the information specific to the kind of evaluation. */
	private[this] final class EvalIntermediate[T](val extra: T, val loader: ClassLoader => ClassLoader, val generated: Seq[File], val enclosingModule: String)

	private[this] def classExists(dir: File, name: String) = (new File(dir, name + ".class")).exists
	// TODO: use the code from Analyzer
	/** Lists the class files in the backing directory whose names contain `moduleName`.
	* Empty when there is no backing directory or when the directory cannot be listed. */
	private[this] def getClassFiles(backing: Option[File], moduleName: String): Seq[File] =
		backing match {
			case None => Nil
			case Some(dir) =>
				// File.listFiles returns null when the directory cannot be listed
				val files = dir listFiles moduleClassFilter(moduleName)
				if(files eq null) Nil else files.toSeq
		}
	private[this] def moduleClassFilter(moduleName: String) = new java.io.FilenameFilter { def accept(dir: File, s: String) =
		(s contains moduleName) && (s endsWith ".class")
	}

	/** The messages for parse errors: `base` is always used and the others are appended according to the parser state. */
	private[this] class ParseErrorStrings(val base: String, val extraBlank: String, val missingBlank: String, val extraSemi: String)
	private[this] def definitionErrorStrings = new ParseErrorStrings(
		base = "Error parsing definition.",
		extraBlank = "  Ensure that there are no blank lines within a definition.",
		missingBlank = "  Ensure that definitions are separated by blank lines.",
		extraSemi = "  A trailing semicolon is not permitted for standalone definitions."
	)
	private[this] def settingErrorStrings = new ParseErrorStrings(
		base = "Error parsing expression.",
		extraBlank = "  Ensure that there are no blank lines within a setting.",
		missingBlank = "  Ensure that settings are separated by blank lines.",
		extraSemi = "  Note that settings are expressions and do not end with semicolons.  (Semicolons are fine within {} blocks, however.)"
	)

	/** Parses the provided compilation `unit` according to `f` and then performs checks on the final parser state
	* to catch errors that are common when the content is embedded in a blank-line-delimited format. */
	private[this] def parse[T](unit: CompilationUnit, errors: ParseErrorStrings, f: syntaxAnalyzer.UnitParser => T): (syntaxAnalyzer.UnitParser, T) =
	{
		val parser = new syntaxAnalyzer.UnitParser(unit)

		val tree = f(parser)
		// an error with the parser already at the end of the input suggests the content was cut short by a blank line
		val extra = parser.in.token match {
			case EOF => errors.extraBlank
			case _ => ""
		}
		checkError(errors.base + extra)

		// the whole input must have been consumed; the token that remains selects the hint for the error message
		parser.accept(EOF)
		val extra2 = parser.in.token match {
			case SEMI => errors.extraSemi
			case NEWLINE | NEWLINES => errors.missingBlank
			case _ => ""
		}
		checkError(errors.base + extra2)

		(parser, tree)
	}
	/** Parses the source of a type, requiring that all of `tpe` is consumed. */
	private[this] def parseType(tpe: String): Tree =
	{
		val tpeParser = new syntaxAnalyzer.UnitParser(mkUnit("<expected-type>", DefaultStartLine, tpe))
		val tpt0: Tree = tpeParser.typ()
		tpeParser.accept(EOF)
		checkError("Error parsing expression type.")
		tpt0
	}
	/** Parses each import string in its own compilation unit that starts at the line paired with the string. */
	private[this] def parseImports(imports: EvalImports): Seq[Tree] =
		imports.strings flatMap { case (s, line) => parseImport(mkUnit(imports.srcName, line, s)) }
	private[this] def parseImport(importUnit: CompilationUnit): Seq[Tree] =
	{
		val parser = new syntaxAnalyzer.UnitParser(importUnit)
		val trees: Seq[Tree] = parser.importClause()
		parser.accept(EOF)
		checkError("Error parsing imports for expression.")
		trees
	}
	/** Parses the definitions in the compilation unit `du`, checking the final parser state as described for `parse`. */
	private[this] def parseDefinitions(du: CompilationUnit): Seq[Tree] =
		parse(du, definitionErrorStrings, parseDefinitions)._2

	/** Parses one or more definitions (defs, vals, lazy vals, classes, traits, modules). */
	private[this] def parseDefinitions(parser: syntaxAnalyzer.UnitParser): Seq[Tree] =
	{
		var defs = parser.nonLocalDefOrDcl
		parser.acceptStatSepOpt()
		while(!parser.isStatSeqEnd) {
			val next = parser.nonLocalDefOrDcl
			defs ++= next
			parser.acceptStatSepOpt()
		}
		defs
	}

	/** The operations that differ between evaluating an expression and evaluating definitions.
	* `T` is the type of the extra information extracted from a compilation and stored in the cache file. */
	private[this] trait EvalType[T]
	{
		/** Extracts additional information after the compilation unit is evaluated.*/
		def extra(run: Run, unit: CompilationUnit): T

		/** Deserializes the extra information for unchanged inputs from a cache file.*/
		def read(file: File): T

		/** Serializes the extra information to a cache file, where it can be `read` back if inputs haven't changed.*/
		def write(value: T, file: File): Unit
	
		/** Constructs the full compilation unit for this evaluation.
		* This is used for error reporting during compilation.
		* The `unitBody` method actually does the parsing and may parse the Tree from another source. */
		def makeUnit: CompilationUnit
	
		/** If true, all top-level symbols from this evaluation will be unlinked.*/
		def unlink: Boolean

		/** Constructs the Tree to be compiled.  The full compilation `unit` from `makeUnit` is provided along with the
		* parsed imports `importTrees` to be used.  `moduleName` should be name of the enclosing module.
		* The Tree doesn't need to be parsed from the contents of `unit`. */
		def unitBody(unit: CompilationUnit, importTrees: Seq[Tree], moduleName: String): Tree
	}

	val DefaultStartLine = 0
	private[this] def makeModuleName(hash: String): String  =  "$" + Hash.halve(hash)
	private[this] def noImports = new EvalImports(Nil, "")
	private[this] def mkUnit(srcName: String, firstLine: Int, s: String) = new CompilationUnit(new EvalSourceFile(srcName, firstLine, s))
	/** Fails with an EvalException carrying `label` if the reporter has recorded errors. */
	private[this] def checkError(label: String) = if(reporter.hasErrors) throw new EvalException(label)

	/** Source file whose line numbers are shifted by `startLine`. */
	private[this] final class EvalSourceFile(name: String, startLine: Int, contents: String) extends BatchSourceFile(name, contents)
	{
		override def lineToOffset(line: Int): Int = super.lineToOffset((line - startLine) max 0)
		override def offsetToLine(offset: Int): Int = super.offsetToLine(offset) + startLine
	}
	/** Constructs a CompilationUnit for each definition, which can be used to independently parse the definition into a Tree.
	* Additionally, a CompilationUnit for the combined definitions is constructed for use by combined compilation after parsing. */
	private[this] def mkDefsUnit(srcName: String, definitions: Seq[(String,Range)]): (CompilationUnit, Seq[CompilationUnit]) =
	{
		def fragmentUnit(content: String, lineMap: Array[Int]) = new CompilationUnit(fragmentSourceFile(srcName, content, lineMap))

		import collection.mutable.ListBuffer
		val lines = new ListBuffer[Int]()
		val defs = new ListBuffer[CompilationUnit]()
		val fullContent = new java.lang.StringBuilder()
		for( (defString, range) <- definitions )
		{
			defs += fragmentUnit(defString, range.toArray)
			fullContent.append(defString)
			lines ++= range
			// definitions are separated by a blank line in the combined content; the two added lines are mapped to `range.end`
			fullContent.append("\n\n")
			lines ++= (range.end :: range.end :: Nil)
		}
		val fullUnit = fragmentUnit(fullContent.toString, lines.toArray)
		(fullUnit, defs.toSeq)
	}

	/** Source file that can map the offset in the file to and from line numbers that may be discontinuous.
	* The values in `lineMap` must be ordered, but need not be consecutive. */
	private[this] def fragmentSourceFile(srcName: String, content: String, lineMap: Array[Int]) = new BatchSourceFile(srcName, content) {
		// a line that is not in `lineMap` is treated as the first line of the content
		override def lineToOffset(line: Int): Int = super.lineToOffset(lineMap.indexWhere(_ == line) max 0)
		override def offsetToLine(offset: Int): Int = index(lineMap, super.offsetToLine(offset))
		// the SourceFile attribute is populated from this method, so we are required to only return the name
		override def toString = new File(srcName).getName
		// out of range indices are mapped to line 0
		private[this] def index(a: Array[Int], i: Int): Int = if(i < 0 || i >= a.length) 0 else a(i)
	}
}
/** Helpers for computing the bytes that are hashed to name an evaluation and for reflectively
* accessing the module generated for an evaluation. */
private object Eval
{
	/** Encodes an optional value like a sequence of zero or one elements. */
	def optBytes[T](o: Option[T])(f: T => Array[Byte]): Array[Byte] = seqBytes(o.toSeq)(f)
	/** Encodes a sequence of strings, each as UTF-8. */
	def stringSeqBytes(s: Seq[String]): Array[Byte] = seqBytes(s)(bytes)
	/** Encodes each element with `f` and then combines the encoded elements. */
	def seqBytes[T](s: Seq[T])(f: T => Array[Byte]): Array[Byte] = bytes(s.map(f))
	/** Combines byte arrays: the number of arrays as 4 bytes followed by the contents of the arrays in order. */
	def bytes(b: Seq[Array[Byte]]): Array[Byte] =
	{
		val count = bytes(b.length)
		val contents = b.flatten.toArray[Byte]
		count ++ contents
	}
	/** Encodes a boolean as a single byte: 1 for true and 0 for false. */
	def bytes(b: Boolean): Array[Byte] = if(b) Array[Byte](1) else Array[Byte](0)
	/** Encodes the modification information of each file.  A null array is treated as an empty one. */
	def filesModifiedBytes(fs: Array[File]): Array[Byte] =
	{
		val files = Option(fs) getOrElse Array[File]()
		seqBytes(files)(fileModifiedBytes)
	}
	/** Encodes the last modified time of a file, or for a directory the encoding of the entries accepted by
	* `classDirFilter`, followed by the absolute path. */
	def fileModifiedBytes(f: File): Array[Byte] =
	{
		val stamp =
			if(f.isDirectory) filesModifiedBytes(f listFiles classDirFilter)
			else bytes(f.lastModified)
		stamp ++ bytes(f.getAbsolutePath)
	}
	/** Encodes whether a file exists followed by its absolute path. */
	def fileExistsBytes(f: File): Array[Byte] =
	{
		val exists = bytes(f.exists)
		exists ++ bytes(f.getAbsolutePath)
	}
 
	/** Encodes a string as UTF-8. */
	def bytes(s: String): Array[Byte] = s.getBytes("UTF-8")
	/** Encodes a long as 8 bytes. */
	def bytes(l: Long): Array[Byte] = ByteBuffer.allocate(8).putLong(l).array
	/** Encodes an int as 4 bytes. */
	def bytes(i: Int): Array[Byte] = ByteBuffer.allocate(4).putInt(i).array

	/** The name of the synthetic member of the synthetic module that an expression is assigned to. */
	final val WrapValName = "$sbtdef"

	/** Gets the value of the expression wrapped in module `objectName`, which is accessible via `loader`.
	* The member named `WrapValName` is invoked reflectively on the module and the result is cast to `T`.
	* The module name should not include the trailing `$`. */
	def getValue[T](objectName: String, loader: ClassLoader): T =
	{
		val module = getModule(objectName, loader)
		module.getClass.getMethod(WrapValName).invoke(module).asInstanceOf[T]
	}

	/** Gets the top-level module `moduleName` from the provided class `loader` by reading the static `MODULE$` field
	* of the class named `moduleName` followed by `$`.  The module name should not include the trailing `$`.*/
	def getModule(moduleName: String, loader: ClassLoader): Any =
		Class.forName(moduleName + "$", true, loader).getField("MODULE$").get(null)

	// selects the directories and class files that contribute to the modification information of a directory
	private val classDirFilter: FileFilter = DirectoryFilter || GlobFilter("*.class")
}