package sbt
package logic

import scala.annotation.tailrec
import Formula.{ And, True }

/*
Defines a propositional logic with negation as failure and only allows stratified rule sets (negation must be acyclic) in order to have a unique minimal model.

For example, this is not allowed:
  + p :- not q
  + q :- not p
but this is:
  + p :- q
  + q :- p
as is this:
  + p :- q
  + q :- not r


 Some useful links:
  + https://en.wikipedia.org/wiki/Nonmonotonic_logic
  + https://en.wikipedia.org/wiki/Negation_as_failure
  + https://en.wikipedia.org/wiki/Propositional_logic
  + https://en.wikipedia.org/wiki/Stable_model_semantics
  + http://www.w3.org/2005/rules/wg/wiki/negation
*/

/**
 * A disjunction (or) of `clauses`.
 * The list must be non-empty: constructing a `Clauses` from `Nil` fails the assertion below.
 */
final case class Clauses(clauses: List[Clause]) {
  assert(!clauses.isEmpty, "At least one clause is required.")
}

/**
 * A rule stating that every atom in `head` is true whenever `body` succeeds.
 *
 * @param body the condition of the rule
 * @param head the atoms proved when `body` succeeds
 */
final case class Clause(
  body: Formula,
  head: Set[Atom]
)

/** A literal is either an [[Atom]] or its [[Negated negation]]. */
sealed abstract class Literal extends Formula {
  /** The positive atom underlying this literal. */
  def atom: Atom

  /** The complementary literal: an atom becomes its negation and a negation becomes its atom. */
  def unary_! : Literal
}
/** A propositional variable identified by `label`. */
final case class Atom(label: String) extends Literal {
  /** An atom is its own underlying atom. */
  def atom: Atom = this

  /** The negation-as-failure of this atom. */
  def unary_! : Negated = new Negated(this)
}
/**
 * The negation of `atom` in the negation-as-failure sense, not classical logical negation:
 * it holds when `atom` is not known/defined.
 */
final case class Negated(atom: Atom) extends Literal {
  /** Dropping the negation yields the underlying atom. */
  override def unary_! : Atom = this.atom
}

/**
 * A formula built from variables, negation, and conjunction (and).
 * Disjunction is deliberately absent; it is expressed as a sequence of clauses instead,
 * which is less convenient when writing clauses but no less powerful.
 */
sealed abstract class Formula {
  /** Builds a [[Clause]] that proves `atom` and every atom in `atoms` whenever this formula holds. */
  def proves(atom: Atom, atoms: Atom*): Clause = Clause(this, Set(atom) ++ atoms)

  /**
   * The conjunction of this formula and `f`.
   * `True` is the identity element; otherwise the literals of both sides are merged into one `And`.
   */
  def &&(f: Formula): Formula = this match {
    case True => f
    case And(left) =>
      f match {
        case True         => this
        case And(right)   => And(left ++ right)
        case lit: Literal => And(left + lit)
      }
    case lit: Literal =>
      f match {
        case True           => this
        case And(right)     => And(right + lit)
        case other: Literal => And(Set(lit, other))
      }
  }
}

/** Holds the non-literal cases of [[Formula]]. */
object Formula {
  /** A conjunction of `literals`; an empty set fails the assertion. */
  final case class And(literals: Set[Literal]) extends Formula {
    assert(!literals.isEmpty, "'And' requires at least one literal.")
  }

  /** The formula that always holds. */
  final case object True extends Formula
}

object Logic {
  /**
   * Same as [[reduce]], but takes the clauses as a plain list.
   * The list is wrapped in [[Clauses]], whose constructor asserts that it is non-empty.
   */
  def reduceAll(clauses: List[Clause], initialFacts: Set[Literal]): Either[LogicException, Matched] =
    reduce(Clauses(clauses), initialFacts)

  /**
   * Computes the variables in the unique stable model for the program represented by `clauses` and `initialFacts`.
   * `clauses` may not have any negative feedback (that is, negation is acyclic)
   * and the positive atoms in `initialFacts` cannot be in the head of any clause in `clauses`.
   * These restrictions ensure that the logic program has a unique minimal model.
   *
   * @return `Left` with the first problem found, or `Right` with the proven atoms.
   *         The checks run in this order: initial facts that are both true and false,
   *         positive initial facts that are in a clause head, and then cyclic negation.
   */
  def reduce(clauses: Clauses, initialFacts: Set[Literal]): Either[LogicException, Matched] = {
    val (posSeq, negSeq) = separate(initialFacts.toSeq)
    val (pos, neg) = (posSeq.toSet, negSeq.toSet)

    val problem =
      checkContradictions(pos, neg) orElse
        checkOverlap(clauses, pos) orElse
        checkAcyclic(clauses)

    // `toLeft` takes its argument by name, so the reduction only runs when no problem was found.
    problem.toLeft(
      reduce0(clauses, initialFacts, Matched.empty)
    )
  }

  /**
   * Verifies the positive `initialFacts` are not in the head of any `clauses`.
   * This avoids the situation where an atom is proved but no clauses prove it.
   * This isn't necessarily a problem, but the main sbt use cases expect
   * a proven atom to have at least one clause satisfied.
   */
  private[this] def checkOverlap(clauses: Clauses, initialFacts: Set[Atom]): Option[InitialOverlap] = {
    val as = atoms(clauses)
    val initialOverlap = initialFacts.filter(as.inHead)
    if (initialOverlap.nonEmpty) Some(new InitialOverlap(initialOverlap)) else None
  }

  /** Reports the atoms that occur in both `pos` and `neg`, if there are any. */
  private[this] def checkContradictions(pos: Set[Atom], neg: Set[Atom]): Option[InitialContradictions] = {
    val contradictions = pos intersect neg
    if (contradictions.nonEmpty) Some(new InitialContradictions(contradictions)) else None
  }

  /**
   * Builds the head-to-body dependency graph of `clauses`, in which negated body literals are the
   * negative arrows. It then reports the cycle returned by `Dag.findNegativeCycle`, if that cycle is non-empty.
   */
  private[this] def checkAcyclic(clauses: Clauses): Option[CyclicNegation] = {
    val deps = dependencyMap(clauses)
    val cycle = Dag.findNegativeCycle(graph(deps))
    if (cycle.nonEmpty) Some(new CyclicNegation(cycle)) else None
  }

  /**
   * Views `deps` as a signed graph.
   * - Nodes are the keys of `deps`.
   * - Arrows are the body literals.
   * - An arrow is negative exactly when its literal is [[Negated]].
   */
  private[this] def graph(deps: Map[Atom, Set[Literal]]) = new Dag.DirectedSignedGraph[Atom] {
    type Arrow = Literal
    def nodes = deps.keys.toList
    def dependencies(a: Atom) = deps.getOrElse(a, Set.empty).toList
    def isNegative(b: Literal) = b match {
      case Negated(_) => true
      case Atom(_)    => false
    }
    def head(b: Literal) = b.atom
  }

  /** Maps each head atom to the union of the body literals of every clause that has it in its head. */
  private[this] def dependencyMap(clauses: Clauses): Map[Atom, Set[Literal]] =
    clauses.clauses.foldLeft(Map.empty[Atom, Set[Literal]]) {
      case (m, Clause(formula, heads)) =>
        val deps = literals(formula)
        heads.foldLeft(m) { (n, head) => n.updated(head, n.getOrElse(head, Set.empty) ++ deps) }
    }

  /** A problem that prevents [[reduce]] from computing a model; `toString` is the human-readable message. */
  sealed abstract class LogicException(override val toString: String)
  /** Atoms that the initial facts assert as both true and false. */
  final class InitialContradictions(val literals: Set[Atom]) extends LogicException("Initial facts cannot be both true and false:\n\t" + literals.mkString("\n\t"))
  /** Positive initial facts that also appear in the head of some clause. */
  final class InitialOverlap(val literals: Set[Atom]) extends LogicException("Initial positive facts cannot be implied by any clauses:\n\t" + literals.mkString("\n\t"))
  /** A dependency cycle that involves a negated literal. */
  final class CyclicNegation(val cycle: List[Literal]) extends LogicException("Negation may not be involved in a cycle:\n\t" + cycle.mkString("\n\t"))

  /**
   * Tracks proven atoms in the reverse order they were proved.
   * `provenSet` holds the same atoms for membership tests.
   */
  final class Matched private (val provenSet: Set[Atom], reverseOrdered: List[Atom]) {
    /** Records `atoms` as proven, skipping any already in `provenSet`. */
    def add(atoms: Set[Atom]): Matched = add(atoms.toList)
    /** Records `atoms` as proven, skipping any already in `provenSet`. */
    def add(atoms: List[Atom]): Matched = {
      val newOnly = atoms.filterNot(provenSet)
      new Matched(provenSet ++ newOnly, newOnly ::: reverseOrdered)
    }
    /** The proven atoms; atoms recorded by earlier `add` calls come before later ones. */
    def ordered: List[Atom] = reverseOrdered.reverse
    override def toString = ordered.map(_.label).mkString("Matched(", ",", ")")
  }
  object Matched {
    /** The state in which nothing has been proven yet. */
    val empty = new Matched(Set.empty, Nil)
  }

  /** Separates a sequence of literals into `(pos, neg)` atom sequences. */
  private[this] def separate(lits: Seq[Literal]): (Seq[Atom], Seq[Atom]) = Util.separate(lits) {
    case a: Atom    => Left(a)
    case Negated(n) => Right(n)
  }

  /**
   * Finds clauses whose body has been reduced to `True` and therefore prove their head.
   * Returns `(<proven atoms>, <remaining unproven clauses>)`.
   */
  private[this] def findProven(c: Clauses): (Set[Atom], List[Clause]) = {
    val (proven, unproven) = c.clauses.partition(_.body == True)
    (proven.flatMap(_.head).toSet, unproven)
  }

  /** The positive atoms among `lits`. */
  private[this] def keepPositive(lits: Set[Literal]): Set[Atom] =
    lits.collect { case a: Atom => a }

  /**
   * Repeatedly substitutes facts into the clauses until nothing more can be derived.
   * Each round does the following:
   * - applies `factsToProcess` to the clauses;
   * - records the positive facts, and the heads of clauses whose body became `True`, as proven;
   * - continues with the newly proven atoms or, if there are none, with the literals
   *   inferred false by [[inferFailure]].
   *
   * Precondition: `factsToProcess` contains no contradictions.
   *
   * NOTE(review): when `applyAll` returns `None` (every remaining clause failed or lost its head),
   * `state` is returned unchanged. In that case, positive atoms in `factsToProcess` that are not already
   * in `state` are not recorded. This can only affect the first call from [[reduce]]: later rounds pass
   * either atoms already added to the state or only negated literals.
   */
  @tailrec
  private[this] def reduce0(clauses: Clauses, factsToProcess: Set[Literal], state: Matched): Matched =
    applyAll(clauses, factsToProcess) match {
      case None => // all of the remaining clauses failed on the new facts
        state
      case Some(applied) =>
        val (proven, unprovenClauses) = findProven(applied)
        val processedFacts = state add keepPositive(factsToProcess)
        val newlyProven = proven -- processedFacts.provenSet
        val newState = processedFacts add newlyProven
        if (unprovenClauses.isEmpty)
          newState // no remaining clauses, done.
        else {
          val unproven = Clauses(unprovenClauses)
          val nextFacts: Set[Literal] = if (newlyProven.nonEmpty) newlyProven.toSet else inferFailure(unproven)
          reduce0(unproven, nextFacts, newState)
        }
    }

  /**
   * Finds negated atoms under the negation as failure rule and returns them.
   * This should be called only after there are no more known atoms to be substituted.
   * Atoms that occur only in clause bodies are preferred. Failing that, it returns head atoms that do not
   * depend on a negated atom, as computed by [[hasNegatedDependency]].
   *
   * @throws RuntimeException if neither kind of atom exists, which the acyclic negation check should rule out
   */
  private[this] def inferFailure(clauses: Clauses): Set[Literal] = {
    /* At this point, there is at least one clause and one of the following is the case as the result of the acyclic negation rule:
         i. there is at least one variable that occurs in a clause body but not in the head of a clause
         ii. there is at least one variable that occurs in the head of a clause and does not transitively depend on a negated variable
       In either case, each such variable x cannot be proven true and therefore proves 'not x' (negation as failure, !x in the code).
     */
    val allAtoms = atoms(clauses)
    val newFacts: Set[Literal] = negated(allAtoms.triviallyFalse)
    if (newFacts.nonEmpty)
      newFacts
    else {
      val possiblyTrue = hasNegatedDependency(clauses.clauses, Relation.empty, Relation.empty)
      val newlyFalse: Set[Literal] = negated(allAtoms.inHead -- possiblyTrue)
      if (newlyFalse.nonEmpty)
        newlyFalse
      else // should never happen due to the acyclic negation rule
        sys.error(s"No progress:\n\tclauses: $clauses\n\tpossibly true: $possiblyTrue")
    }
  }

  /** Wraps each atom in [[Negated]]. */
  private[this] def negated(atoms: Set[Atom]): Set[Literal] = atoms.map(a => Negated(a))

  /**
   * Computes the atoms in `clauses` that directly or transitively take a negated atom as input.
   * The result is a list in the order produced by `Dag.topologicalSortUnchecked`.
   * For example, for the following clauses the result contains `a` and `d`:
   *  a :- b, not c
   *  d :- a
   *
   * @param clauses the clauses still to be scanned
   * @param posDeps head -> positive body atoms collected so far (callers start from `Relation.empty`)
   * @param negDeps head -> negated body atoms collected so far (callers start from `Relation.empty`)
   */
  @tailrec
  def hasNegatedDependency(clauses: Seq[Clause], posDeps: Relation[Atom, Atom], negDeps: Relation[Atom, Atom]): List[Atom] =
    clauses match {
      case Seq() =>
        // because cycles between positive literals are allowed, this isn't strictly a topological sort
        Dag.topologicalSortUnchecked(negDeps._1s)(posDeps.reverse)
      case Clause(formula, head) +: tail =>
        // collect direct positive and negative literals and track them in separate graphs
        val (pos, neg) = directDeps(formula)
        val (newPos, newNeg) = head.foldLeft((posDeps, negDeps)) {
          case ((pdeps, ndeps), d) =>
            (pdeps + (d, pos), ndeps + (d, neg))
        }
        hasNegatedDependency(tail, newPos, newNeg)
    }

  /** Computes the `(positive, negative)` atoms of the literals in `formula`. */
  private[this] def directDeps(formula: Formula): (Seq[Atom], Seq[Atom]) =
    Util.separate(literals(formula).toSeq) {
      case Negated(a) => Right(a)
      case a: Atom    => Left(a)
    }

  /** The literals that make up `formula`; `True` has none. */
  private[this] def literals(formula: Formula): Set[Literal] = formula match {
    case And(lits)  => lits
    case l: Literal => Set(l)
    case True       => Set.empty
  }

  /**
   * Computes the atoms in the heads and bodies of the clauses in `cs`.
   * `reduce` is safe here because [[Clauses]] asserts at least one clause.
   */
  def atoms(cs: Clauses): Atoms = cs.clauses.map(c => Atoms(c.head, atoms(c.body))).reduce(_ ++ _)

  /** Computes the set of all atoms in `formula`, ignoring negation. */
  def atoms(formula: Formula): Set[Atom] = formula match {
    case And(lits)    => lits.map(_.atom)
    case Negated(lit) => Set(lit)
    case a: Atom      => Set(a)
    case True         => Set()
  }

  /** Represents the set of atoms in the heads of clauses and in the bodies (formulas) of clauses. */
  final case class Atoms(inHead: Set[Atom], inFormula: Set[Atom]) {
    /** Concatenates this with `as`. */
    def ++(as: Atoms): Atoms = Atoms(inHead ++ as.inHead, inFormula ++ as.inFormula)
    /** Atoms that cannot be true because they do not occur in a head. */
    def triviallyFalse: Set[Atom] = inFormula -- inHead
  }

  /**
   * Applies known facts to each clause in `cs`, deriving a new, possibly empty list of clauses.
   * 1. If a fact is in the body of a clause, the derived clause has that fact removed from the body.
   * 2. If the negation of a fact is in a body of a clause, that clause fails and is removed.
   * 3. If a fact or its negation is in the head of a clause, the derived clause has that fact (or its negation) removed from the head.
   * 4. If a head is empty, the clause proves nothing and is removed.
   *
   * NOTE: empty bodies do not cause a clause to succeed yet.
   *       All known facts must be applied before this can be done in order to avoid inconsistencies.
   * Precondition: no contradictions in `facts`
   * Postcondition: no atom in `facts` is present in the result
   * Postcondition: No clauses have an empty head
   *
   * @return `None` when no clause remains, since [[Clauses]] cannot be empty
   */
  def applyAll(cs: Clauses, facts: Set[Literal]): Option[Clauses] = {
    val newClauses =
      if (facts.isEmpty)
        cs.clauses.filter(_.head.nonEmpty) // still need to drop clauses with an empty head
      else
        cs.clauses.flatMap(c => applyAll(c, facts).toList)
    if (newClauses.isEmpty) None else Some(Clauses(newClauses))
  }

  /** Applies `facts` to a single clause using rules 1-4 of the `Clauses` overload; `None` means the clause is dropped. */
  def applyAll(c: Clause, facts: Set[Literal]): Option[Clause] = {
    val atoms = facts.map(_.atom)
    val newHead = c.head -- atoms // 3.
    if (newHead.isEmpty) // 4. empty head
      None
    else
      substitute(c.body, facts).map(f => Clause(f, newHead)) // 1, 2
  }

  /**
   * Derives the formula that results from substituting `facts` into `formula`.
   * Returns `None` if some literal of `formula` is contradicted by a fact, and `Some(True)` once every literal is satisfied.
   */
  @tailrec
  def substitute(formula: Formula, facts: Set[Literal]): Option[Formula] = formula match {
    case And(lits) =>
      def negated(lits: Set[Literal]): Set[Literal] = lits.map(a => !a)
      if (lits.exists(negated(facts))) // 2.
        None
      else {
        val newLits = lits -- facts
        val newF = if (newLits.isEmpty) True else And(newLits)
        Some(newF) // 1.
      }
    case True => Some(True)
    case lit: Literal => // define in terms of And
      substitute(And(Set(lit)), facts)
  }
}