/* sbt -- Simple Build Tool
 * Copyright 2009 Mark Harrah
 */
package sbt
package classfile

import java.io.{ DataInputStream, File, InputStream }
import scala.annotation.switch

// Translation of jdepend.framework.ClassFileParser by Mike Clark, Clarkware Consulting, Inc.
// BSD Licensed
//
// Note that, unlike the rest of sbt, some values here may be null (for example, unused constant pool slots).

import Constants._

/**
 * Parser for JVM class files, producing [[ClassFile]] values.
 *
 * Decodes the constant pool, access flags, class/superclass/interface names, fields, methods and attributes.
 * Method handle, method type, dynamic, invokedynamic, module and package constants are consumed so that parsing
 * can continue past them, but their contents are not decoded.
 */
private[sbt] object Parser {
  /** Parses the class file at `file`, reading it through `Using.fileInputStream`. */
  def apply(file: File): ClassFile = Using.fileInputStream(file)(parse(file.getAbsolutePath)).right.get

  // Always produces a Right; malformed input surfaces as an exception thrown from parseImpl rather than a Left.
  private def parse(fileName: String)(is: InputStream): Either[String, ClassFile] = Right(parseImpl(fileName, is))

  /**
   * Reads one class file from `is`.
   *
   * The anonymous [[ClassFile]] pulls each `val` from the stream while it is being constructed, so the
   * declaration order below must follow the class file layout exactly.
   *
   * @param filename path used in error messages and exposed as `fileName`
   * @param is stream positioned at the start of the class file; it is not closed here
   */
  private def parseImpl(filename: String, is: InputStream): ClassFile =
    {
      val in = new DataInputStream(is)
      new ClassFile {
        // Refer to the `filename` parameter, not the `fileName` member: that member is declared further down and
        // is still null while this check runs, which made the message read "Invalid class file: null".
        assume(in.readInt() == JavaMagic, "Invalid class file: " + filename)

        val fileName = filename
        val minorVersion: Int = in.readUnsignedShort()
        val majorVersion: Int = in.readUnsignedShort()

        val constantPool = parseConstantPool(in)
        val accessFlags: Int = in.readUnsignedShort()

        val className = getClassConstantName(in.readUnsignedShort())
        // An index of 0 hits the never-filled pool slot 0, which getClassConstantName maps to "".
        val superClassName = getClassConstantName(in.readUnsignedShort())
        val interfaceNames = array(in.readUnsignedShort())(getClassConstantName(in.readUnsignedShort()))

        val fields = readFieldsOrMethods()
        val methods = readFieldsOrMethods()

        val attributes = array(in.readUnsignedShort())(parseAttribute())

        /** UTF8 value referenced by the first attribute for which `isSourceFile` holds, if any. */
        lazy val sourceFile =
          for (sourceFileAttribute <- attributes.find(_.isSourceFile)) yield toUTF8(entryIndex(sourceFileAttribute))

        /** Interprets the two-byte value of `a` as a constant pool index and returns that UTF8 constant. */
        def stringValue(a: AttributeInfo) = toUTF8(entryIndex(a))

        /** Reads a u2 count followed by that many field_info / method_info structures. */
        private def readFieldsOrMethods() = array(in.readUnsignedShort())(parseFieldOrMethodInfo())

        /** Returns the string of the UTF8 constant at `entryIndex`, failing if the entry has a different tag. */
        private def toUTF8(entryIndex: Int) =
          {
            val entry = constantPool(entryIndex)
            assume(entry.tag == ConstantUTF8, "Constant pool entry is not a UTF8 type: " + entryIndex)
            entry.value.get.asInstanceOf[String]
          }

        /** Resolves a class constant index to a dotted class name; an empty (null) pool slot yields "". */
        private def getClassConstantName(entryIndex: Int) =
          {
            val entry = constantPool(entryIndex)
            if (entry == null) ""
            else slashesToDots(toUTF8(entry.nameIndex))
          }

        /** Like toUTF8, but treats a non-positive index as "no entry" and returns None for it. */
        private def utf8OrNone(index: Int) =
          if (index <= 0) None else Some(toUTF8(index))

        /** Reads access flags, name index, descriptor index and attributes, in that order. */
        private def parseFieldOrMethodInfo() =
          new FieldOrMethodInfo(
            in.readUnsignedShort(),
            utf8OrNone(in.readUnsignedShort()),
            utf8OrNone(in.readUnsignedShort()),
            array(in.readUnsignedShort())(parseAttribute()))

        /** Reads one attribute: a u2 name index, a u4 length, then that many raw bytes. */
        private def parseAttribute() =
          {
            // readUnsignedShort yields 0..65535, so "no entry" is index 0 (never -1); pool slot 0 is null and
            // must not be handed to toUTF8.
            val name = utf8OrNone(in.readUnsignedShort())
            val value = array(in.readInt())(in.readByte())
            new AttributeInfo(name, value)
          }

        /** Class names referenced by class constants and by field/method descriptors, without duplicates. */
        def types = (classConstantReferences ++ fieldTypes ++ methodTypes).toSet

        private def getTypes(fieldsOrMethods: Array[FieldOrMethodInfo]) =
          fieldsOrMethods.flatMap { fieldOrMethod =>
            descriptorToTypes(fieldOrMethod.descriptor)
          }

        private def fieldTypes = getTypes(fields)
        private def methodTypes = getTypes(methods)

        /** Names from class constants; array class constants ("[...") are decoded as descriptors. */
        private def classConstantReferences =
          constants.flatMap { constant =>
            constant.tag match {
              case ConstantClass =>
                val name = toUTF8(constant.nameIndex)
                if (name.startsWith("["))
                  descriptorToTypes(Some(name))
                else
                  slashesToDots(name) :: Nil
              case _ => Nil
            }
          }

        /** Populated pool entries, skipping slot 0 and the slot that follows each `wide` entry. */
        private def constants =
          {
            @scala.annotation.tailrec
            def next(i: Int, list: List[Constant]): List[Constant] =
              if (i < constantPool.length) {
                val constant = constantPool(i)
                next(if (constant.wide) i + 2 else i + 1, constant :: list)
              } else
                list
            next(1, Nil)
          }
      }
    }

  /** Builds an array of `size` elements by evaluating `f` once per element, in index order. */
  private def array[T: scala.reflect.Manifest](size: Int)(f: => T) = Array.tabulate(size)(_ => f)

  /**
   * Reads the u2 pool count and then entries 1 until that count. Slot 0, and the slot following each `wide`
   * entry, are left null so that array indices match constant pool indices.
   */
  private def parseConstantPool(in: DataInputStream) =
    {
      val constantPoolSize = in.readUnsignedShort()
      val pool = new Array[Constant](constantPoolSize)

      @scala.annotation.tailrec
      def parse(i: Int): Unit =
        if (i < constantPoolSize) {
          val constant = getConstant(in)
          pool(i) = constant
          parse(if (constant.wide) i + 2 else i + 1)
        }

      parse(1)
      pool
    }

  /** Reads one constant pool entry: a one-byte tag followed by tag-specific data. */
  private def getConstant(in: DataInputStream): Constant =
    {
      val tag = in.readByte()

      // The tag is widened to Int because @switch is not honoured for Byte scrutinees.
      ((tag: Int): @switch) match {
        case ConstantClass | ConstantString => new Constant(tag, in.readUnsignedShort())
        case ConstantField | ConstantMethod | ConstantInterfaceMethod | ConstantNameAndType =>
          new Constant(tag, in.readUnsignedShort(), in.readUnsignedShort())
        // valueOf rather than the boxing constructors, which are deprecated since Java 9.
        case ConstantInteger => new Constant(tag, java.lang.Integer.valueOf(in.readInt()))
        case ConstantFloat   => new Constant(tag, java.lang.Float.valueOf(in.readFloat()))
        case ConstantLong    => new Constant(tag, java.lang.Long.valueOf(in.readLong()))
        case ConstantDouble  => new Constant(tag, java.lang.Double.valueOf(in.readDouble()))
        case ConstantUTF8    => new Constant(tag, in.readUTF())
        // TODO: proper support. The entries below are read only so the stream stays aligned.
        case ConstantMethodHandle =>
          val kind = in.readByte()
          val ref = in.readUnsignedShort()
          new Constant(tag, -1, -1, None)
        case ConstantMethodType =>
          val descriptorIndex = in.readUnsignedShort()
          new Constant(tag, -1, -1, None)
        case ConstantInvokeDynamic =>
          val bootstrapIndex = in.readUnsignedShort()
          val nameAndTypeIndex = in.readUnsignedShort()
          new Constant(tag, -1, -1, None)
        // CONSTANT_Dynamic (JVMS tag 17): same layout as InvokeDynamic.
        case 17 =>
          val bootstrapIndex = in.readUnsignedShort()
          val nameAndTypeIndex = in.readUnsignedShort()
          new Constant(tag, -1, -1, None)
        // CONSTANT_Module (19) and CONSTANT_Package (20): a single u2 name index.
        case 19 | 20 =>
          val nameIndex = in.readUnsignedShort()
          new Constant(tag, -1, -1, None)
        case _ => sys.error("Unknown constant: " + tag)
      }
    }

  /** Reinterprets a signed byte as its unsigned value (0..255). */
  private def toInt(v: Byte) = v & 0xff

  /**
   * Decodes the value of `a` as a big-endian unsigned 16-bit index.
   * Throws a MatchError if the value is not exactly two bytes long.
   */
  private def entryIndex(a: AttributeInfo) =
    {
      val Array(v0, v1) = a.value
      toInt(v0) * 256 + toInt(v1)
    }

  /** Converts an internal (slash-separated) class name to dotted form. */
  private def slashesToDots(s: String) = s.replace('/', '.')

  /**
   * Extracts the dotted names of every `L...;` object type in a descriptor.
   * A missing descriptor yields Nil. Names are returned in reverse order of appearance.
   */
  private def descriptorToTypes(descriptor: Option[String]) =
    {
      // Scans forward from `from` instead of re-slicing the descriptor after every match.
      @scala.annotation.tailrec
      def toTypes(descriptor: String, from: Int, types: List[String]): List[String] =
        {
          val startIndex = descriptor.indexOf(ClassDescriptor, from)
          if (startIndex < 0)
            types
          else {
            val endIndex = descriptor.indexOf(';', startIndex + 1)
            val tpe = slashesToDots(descriptor.substring(startIndex + 1, endIndex))
            toTypes(descriptor, endIndex + 1, tpe :: types)
          }
        }
      toTypes(descriptor.getOrElse(""), 0, Nil)
    }
}