]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Build back-end independent DSL representation.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Thu, 26 Apr 2012 17:35:18 +0000 (18:35 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Thu, 26 Apr 2012 17:35:18 +0000 (18:35 +0100)
examples/integrals_kinetic.ofl
src/ofc/OFC.scala
src/ofc/expression/Assignment.scala [new file with mode: 0644]
src/ofc/expression/Dictionary.scala [new file with mode: 0644]
src/ofc/expression/Expression.scala [new file with mode: 0644]
src/ofc/expression/TreeBuilder.scala [new file with mode: 0644]
src/ofc/generators/Generator.scala
src/ofc/generators/Onetep.scala
src/ofc/parser/Statement.scala

index 3357a13102bd9a8760d77460f3ef1d68d66aad58..c4ad090af33614887b725c3096cfc76c77d3ec6c 100644 (file)
@@ -4,7 +4,7 @@ FunctionSet bra, ket
 Index alpha, beta
 
 # Computation
-kinet[alpha, beta] = inner(bra[alpha], reciprocal(laplacian(reciprocal(fftbox(ket[beta])))*-0.5))
+kinet[alpha, beta] = inner(bra[alpha], laplacian(ket[beta])*-0.5)
 
 # Implementation specific
 target ONETEP
index b5b303d198923be50fcf951ee477a9fc960abdb1..688c0692bf6d45bf143bc7a045b7788e121958ef 100644 (file)
@@ -1,7 +1,10 @@
 package ofc
 
+import scala.reflect.Manifest
+import scala.reflect.Manifest.singleType
 import java.io.FileReader
-import parser.{Parser,Statement,Target,Identifier,ParseException}
+import parser.{Parser,Statement,Target,TargetAssignment,Identifier,ParseException,Definition}
+import expression.{Dictionary,TreeBuilder}
 import generators.Generator
 
 class InvalidInputException(s: String) extends Exception(s)
@@ -35,9 +38,50 @@ object OFC extends Parser {
     }
   }
 
-  def processAST(statements : Seq[Statement]) = {
-    val targetStatements = 
-      statements.filter(_ match { case _ : Target => true; case _ => false }).asInstanceOf[Seq[Target]]
+  private def filterStatements[T <: parser.Statement](statements : Seq[parser.Statement])(implicit m: Manifest[T]) =
+    statements.foldLeft(List[T]())((list, item) => item match {
+      case s if (singleType(s) <:< m) =>  s.asInstanceOf[T] :: list
+      case _ => list
+    })
+
+  private def getDeclarations(statements : Seq[parser.Statement]) : Map[parser.Identifier, parser.OFLType] = {
+    def getMappings(dl : parser.DeclarationList) =
+      for (name <- dl.names) yield
+        (name, dl.oflType)
+
+    filterStatements[parser.DeclarationList](statements).flatMap(getMappings(_)).toMap
+  }
+
+  private def buildDictionary(declarations : Map[parser.Identifier, parser.OFLType]) : Dictionary = {
+    import expression.{Matrix,FunctionSet,Index}
+    val dictionary = new Dictionary
+
+    for(d <- declarations) {
+      // Find corresponding target-specific declaration if it exists.
+      d match {
+        case (id, parser.Matrix()) => dictionary.add(new Matrix(id))
+        case (id, parser.FunctionSet()) => dictionary.add(new FunctionSet(id))
+        case (id, parser.Index()) => dictionary.add(new Index(id))
+      }
+    }
+
+    dictionary
+  }
+
+  private def processAST(statements : Seq[Statement]) = {
+    val declarations = getDeclarations(statements)
+    val dictionary = buildDictionary(declarations)
+    val treeBuilder = new TreeBuilder(dictionary)
+
+    val definitions = filterStatements[Definition](statements)
+
+    val definition = definitions match {
+      case Seq(singleDef) => singleDef
+      case _ => throw new InvalidInputException("OFL file should only have one definition.")
+    }
+
+    val expressionTree = treeBuilder(definition)
+    val targetStatements = filterStatements[Target](statements)
 
     val generator : Generator = targetStatements match {
       case Seq(Target(Identifier("ONETEP"))) => new generators.Onetep
@@ -45,6 +89,7 @@ object OFC extends Parser {
       case _ => throw new InvalidInputException("OFL file should have single target statement.")
     }
 
-    generator.acceptInput(statements)
+    val targetAssignments = filterStatements[TargetAssignment](statements)
+    generator.acceptInput(dictionary, expressionTree, targetAssignments)
   }
 }
diff --git a/src/ofc/expression/Assignment.scala b/src/ofc/expression/Assignment.scala
new file mode 100644 (file)
index 0000000..fe44991
--- /dev/null
@@ -0,0 +1,5 @@
+package ofc.expression
+
+case class Assignment(lhs: ScalarExpression, rhs: ScalarExpression) {
+  override def toString = lhs.toString + " = " + rhs.toString
+}
diff --git a/src/ofc/expression/Dictionary.scala b/src/ofc/expression/Dictionary.scala
new file mode 100644 (file)
index 0000000..74d60e6
--- /dev/null
@@ -0,0 +1,28 @@
+package ofc.expression
+import ofc.parser.Identifier
+
+class Dictionary {
+  import scala.collection.mutable.HashMap
+
+  var matrices = new HashMap[Identifier, Matrix]
+  var functionSets = new HashMap[Identifier, FunctionSet]
+  var indices = new HashMap[Identifier, Index]
+
+  def add(matrix: Matrix) {
+    matrices += matrix.getIdentifier -> matrix
+  }
+
+  def add(functionSet: FunctionSet) {
+    functionSets += functionSet.getIdentifier -> functionSet
+  }
+
+  def add(index: Index) {
+    indices += index.getIdentifier -> index
+  }
+
+  def getMatrix(id: Identifier) : Option[Matrix] = matrices.get(id)
+
+  def getFunctionSet(id: Identifier) : Option[FunctionSet] = functionSets.get(id)
+
+  def getIndex(id: Identifier) : Option[Index] = indices.get(id)
+}
diff --git a/src/ofc/expression/Expression.scala b/src/ofc/expression/Expression.scala
new file mode 100644 (file)
index 0000000..4b5d260
--- /dev/null
@@ -0,0 +1,78 @@
+package ofc.expression
+import ofc.parser.Identifier
+
+case class Index(id: Identifier) {
+  def getIdentifier = id
+  def getName = id.getName
+  override def toString() = id.getName
+}
+
+trait Expression {
+  def isAssignable : Boolean
+  def numIndices : Int
+  def getDependentIndices : Set[Index]
+}
+
+trait ScalarExpression extends Expression
+trait FieldExpression extends Expression
+
+trait IndexingOperation {
+  val op: Expression
+  def getIndices : List[Index]
+  def isAssignable = op.isAssignable
+  def numIndices = 0
+  def getDependentIndices = op.getDependentIndices ++ getIndices
+  override def toString = op.toString + getIndices.map(_.getName).mkString("[",",","]")
+}
+
+trait NamedOperand {
+  val id: Identifier
+  def getIdentifier = id
+  def isAssignable = true
+  def getDependentIndices : Set[Index] = Set.empty
+  override def toString = id.getName
+}
+
+case class ScalarIndexingOperation(val op: ScalarExpression, indices: List[Index]) extends IndexingOperation with ScalarExpression {
+  def getIndices = indices
+}
+
+case class FieldIndexingOperation(val op: FieldExpression, indices: List[Index]) extends IndexingOperation with FieldExpression {
+  def getIndices = indices
+}
+
+case class InnerProduct(left: FieldExpression, right: FieldExpression) extends ScalarExpression {
+  override def toString = "inner(" + left.toString + ", " + right.toString+")"
+  def isAssignable = false
+  def numIndices = 0
+  def getDependentIndices = left.getDependentIndices ++ right.getDependentIndices
+}
+
+case class Laplacian(op: FieldExpression) extends FieldExpression {
+  override def toString = "laplacian("+op.toString+")"
+  def isAssignable = false
+  def numIndices = 0
+  def getDependentIndices = op.getDependentIndices
+}
+
+case class FieldScaling(op: FieldExpression, scale: ScalarExpression) extends FieldExpression {
+  override def toString = "(" + op.toString + "*" + scale.toString + ")"
+  def isAssignable = false
+  def numIndices = 0
+  def getDependentIndices = op.getDependentIndices ++ scale.getDependentIndices
+}
+
+class FunctionSet(val id: Identifier) extends FieldExpression with NamedOperand {
+  def numIndices = 1
+}
+
+class Matrix(val id: Identifier) extends ScalarExpression with NamedOperand {
+  def numIndices = 2
+}
+
+class ScalarLiteral(literal: Double) extends ScalarExpression {
+  override def toString = literal.toString
+  def isAssignable = false
+  def numIndices = 0
+  def getDependentIndices = Set.empty
+}
diff --git a/src/ofc/expression/TreeBuilder.scala b/src/ofc/expression/TreeBuilder.scala
new file mode 100644 (file)
index 0000000..6d17772
--- /dev/null
@@ -0,0 +1,69 @@
+package ofc.expression
+
+import ofc.parser
+import ofc.parser.Identifier
+import ofc.{InvalidInputException,UnimplementedException}
+
+class TreeBuilder(dictionary : Dictionary) {
+  def apply(definition: parser.Definition) : Assignment = {
+    val lhsTree = buildExpression(definition.term)
+    val rhsTree = buildExpression(definition.expr)
+
+    if (!lhsTree.isAssignable)
+      throw new InvalidInputException("Non-assignable expression on LHS of assignment.")
+    else
+      new Assignment(lhsTree, rhsTree)
+  }
+  
+  private def buildIndexedOperand(term: parser.IndexedIdentifier) : Expression = {
+    val indices = for(id <- term.indices) yield buildIndex(id)
+
+    dictionary.getMatrix(term.id) match {
+      case Some(matrix) => new ScalarIndexingOperation(matrix, indices)
+      case None => dictionary.getFunctionSet(term.id) match {
+        case Some(functionSet) => new FieldIndexingOperation(functionSet, indices)
+        case None => throw new UnimplementedException("No idea how to index "+term.id)
+      }
+    }
+  }
+
+  private def buildIndex(id: parser.Identifier) : Index = dictionary.getIndex(id) match {
+    case Some(index) => index
+    case None => throw new InvalidInputException("Unknown index "+id)
+  }
+
+  private def buildIndex(term: parser.Expression) : Index = term match {
+    case (indexedID: parser.IndexedIdentifier) => {
+      if (indexedID.indices.nonEmpty)
+        throw new InvalidInputException("Tried to parse expression "+term+" as index but it is indexed.")
+      else
+        buildIndex(indexedID.id)
+    }
+    case other => throw new InvalidInputException("Cannot parse expression "+other+" as index.")
+  }
+
+  private def buildExpression(term: parser.Expression) : Expression = {
+    import parser._
+
+    term match {
+      case (t: IndexedIdentifier) => buildIndexedOperand(t)
+      case ScalarConstant(s) => new ScalarLiteral(s)
+      case Division(a, b) => 
+        throw new UnimplementedException("Semantics of division not yet defined, or implemented.")
+      case Multiplication(left, right) => (buildExpression(left), buildExpression(right)) match {
+        case (field: FieldExpression, factor: ScalarExpression) => new FieldScaling(field, factor)
+        case (factor: ScalarExpression, field: FieldExpression) => new FieldScaling(field, factor)
+        case _ => throw new InvalidInputException("Cannot multiply "+left+" and "+right+".")
+      }
+      case Operator(Identifier("inner"), List(a,b)) => (buildExpression(a), buildExpression(b)) match {
+        case (left: FieldExpression, right: FieldExpression) => new InnerProduct(left, right)
+        case _ => throw new InvalidInputException("inner requires both operands to be fields.")
+      }
+      case Operator(Identifier("laplacian"), List(op)) => buildExpression(op) match {
+        case (field: FieldExpression) => new Laplacian(field)
+        case _ => throw new InvalidInputException("laplacian can only be applied to a field.")
+      }
+      case Operator(Identifier(name), _) => throw new UnimplementedException("Unknown or incorrectly called operator: "+name)
+    }
+  }
+}
index cf938eb979c82365ce12395dbf3fdb1b89ba438d..f5ee970a13da73e30b1026de241ce87f4ddb629f 100644 (file)
@@ -1,6 +1,7 @@
 package ofc.generators
-import ofc.parser.Statement
+import ofc.parser.TargetAssignment
+import ofc.expression.{Dictionary,Assignment}
 
 trait Generator {
-  def acceptInput(program : Seq[Statement]) : Unit
+  def acceptInput(dictionary: Dictionary, expression: Assignment, targetSpecific : Seq[TargetAssignment]) : Unit
 }
index e7f10f3f9f1d8f153f22b27222184e865784866b..dc33598b376ca49ef98cdae8650aeee7153b4408 100644 (file)
@@ -1,34 +1,14 @@
 package ofc.generators
 
-import ofc.parser
 import ofc.generators.onetep._
-import ofc.InvalidInputException
-import scala.reflect.Manifest
-import scala.reflect.Manifest.singleType
+import ofc.parser
+import ofc.expression.{Assignment,Dictionary}
 
 class Onetep extends Generator {
-
-  var dictionary = new Dictionary
-
-  def acceptInput(program : Seq[parser.Statement]) = {
-    println("Parsed input:\n"+program.mkString("\n") + "\n")
-    buildDictionary(program)
-    buildDefinitions(program)
-  }
-
-  def filterStatements[T <: parser.Statement](statements : Seq[parser.Statement])(implicit m: Manifest[T]) =
-    statements.foldLeft(List[T]())((list, item) => item match {
-      case s if (singleType(s) <:< m) =>  s.asInstanceOf[T] :: list
-      case _ => list
-    })
-
-  def getDeclarations(statements : Seq[parser.Statement]) : Map[parser.Identifier, parser.OFLType] = {
-    def getMappings(dl : parser.DeclarationList) =
-      for (name <- dl.names) yield
-        (name, dl.oflType)
-
-    filterStatements[parser.DeclarationList](statements).flatMap(getMappings(_)).toMap
+  def acceptInput(dictionary: Dictionary, assignment: Assignment, targetSpecific : Seq[parser.TargetAssignment]) = {
+    println(assignment)
   }
+  /*
 
   def buildMatrix(id: parser.Identifier, call : Option[parser.FunctionCall]) {
     import parser._
@@ -97,4 +77,5 @@ class Onetep extends Generator {
     else 
       buildDefinition(definitions.head)
   }
+  */
 }
index d4a735f9b5a86bb1eb4477cb359a2aa5cea9b5aa..94e6487bf9215eab063f422d8b4747da89edbf0c 100644 (file)
@@ -2,6 +2,7 @@ package ofc.parser
 
 case class Identifier(name: String) {
   override def toString : String = "id(\""+name+"\")"
+  def getName = name
 }
 
 sealed abstract class Statement