From: Francis Russell Date: Thu, 26 Apr 2012 17:35:18 +0000 (+0100) Subject: Build back-end independent DSL representation. X-Git-Url: https://git.unchartedbackwaters.co.uk/w/?a=commitdiff_plain;h=f456a564970e13442e154b0a40d7fb7ccdd699a9;p=francis%2Fofc.git Build back-end independent DSL representation. --- diff --git a/examples/integrals_kinetic.ofl b/examples/integrals_kinetic.ofl index 3357a13..c4ad090 100644 --- a/examples/integrals_kinetic.ofl +++ b/examples/integrals_kinetic.ofl @@ -4,7 +4,7 @@ FunctionSet bra, ket Index alpha, beta # Computation -kinet[alpha, beta] = inner(bra[alpha], reciprocal(laplacian(reciprocal(fftbox(ket[beta])))*-0.5)) +kinet[alpha, beta] = inner(bra[alpha], laplacian(ket[beta])*-0.5) # Implementation specific target ONETEP diff --git a/src/ofc/OFC.scala b/src/ofc/OFC.scala index b5b303d..688c069 100644 --- a/src/ofc/OFC.scala +++ b/src/ofc/OFC.scala @@ -1,7 +1,10 @@ package ofc +import scala.reflect.Manifest +import scala.reflect.Manifest.singleType import java.io.FileReader -import parser.{Parser,Statement,Target,Identifier,ParseException} +import parser.{Parser,Statement,Target,TargetAssignment,Identifier,ParseException,Definition} +import expression.{Dictionary,TreeBuilder} import generators.Generator class InvalidInputException(s: String) extends Exception(s) @@ -35,9 +38,50 @@ object OFC extends Parser { } } - def processAST(statements : Seq[Statement]) = { - val targetStatements = - statements.filter(_ match { case _ : Target => true; case _ => false }).asInstanceOf[Seq[Target]] + private def filterStatements[T <: parser.Statement](statements : Seq[parser.Statement])(implicit m: Manifest[T]) = + statements.foldLeft(List[T]())((list, item) => item match { + case s if (singleType(s) <:< m) => s.asInstanceOf[T] :: list + case _ => list + }) + + private def getDeclarations(statements : Seq[parser.Statement]) : Map[parser.Identifier, parser.OFLType] = { + def getMappings(dl : parser.DeclarationList) = + for (name <- dl.names) yield + (name, dl.oflType) + + filterStatements[parser.DeclarationList](statements).flatMap(getMappings(_)).toMap + } + + private def buildDictionary(declarations : Map[parser.Identifier, parser.OFLType]) : Dictionary = { + import expression.{Matrix,FunctionSet,Index} + val dictionary = new Dictionary + + for(d <- declarations) { + // Find corresponding target-specific declaration if it exists. + d match { + case (id, parser.Matrix()) => dictionary.add(new Matrix(id)) + case (id, parser.FunctionSet()) => dictionary.add(new FunctionSet(id)) + case (id, parser.Index()) => dictionary.add(new Index(id)) + } + } + + dictionary + } + + private def processAST(statements : Seq[Statement]) = { + val declarations = getDeclarations(statements) + val dictionary = buildDictionary(declarations) + val treeBuilder = new TreeBuilder(dictionary) + + val definitions = filterStatements[Definition](statements) + + val definition = definitions match { + case Seq(singleDef) => singleDef + case _ => throw new InvalidInputException("OFL file should only have one definition.") + } + + val expressionTree = treeBuilder(definition) + val targetStatements = filterStatements[Target](statements) val generator : Generator = targetStatements match { case Seq(Target(Identifier("ONETEP"))) => new generators.Onetep @@ -45,6 +89,7 @@ object OFC extends Parser { case _ => throw new InvalidInputException("OFL file should have single target statement.") } - generator.acceptInput(statements) + val targetAssignments = filterStatements[TargetAssignment](statements) + generator.acceptInput(dictionary, expressionTree, targetAssignments) } } diff --git a/src/ofc/expression/Assignment.scala b/src/ofc/expression/Assignment.scala new file mode 100644 index 0000000..fe44991 --- /dev/null +++ b/src/ofc/expression/Assignment.scala @@ -0,0 +1,5 @@ +package ofc.expression + +case class Assignment(lhs: ScalarExpression, rhs: ScalarExpression) { + override def toString = lhs.toString + " = " + rhs.toString +} diff --git a/src/ofc/expression/Dictionary.scala b/src/ofc/expression/Dictionary.scala new file mode 100644 index 0000000..74d60e6 --- /dev/null +++ b/src/ofc/expression/Dictionary.scala @@ -0,0 +1,28 @@ +package ofc.expression +import ofc.parser.Identifier + +class Dictionary { + import scala.collection.mutable.HashMap + + var matrices = new HashMap[Identifier, Matrix] + var functionSets = new HashMap[Identifier, FunctionSet] + var indices = new HashMap[Identifier, Index] + + def add(matrix: Matrix) { + matrices += matrix.getIdentifier -> matrix + } + + def add(functionSet: FunctionSet) { + functionSets += functionSet.getIdentifier -> functionSet + } + + def add(index: Index) { + indices += index.getIdentifier -> index + } + + def getMatrix(id: Identifier) : Option[Matrix] = matrices.get(id) + + def getFunctionSet(id: Identifier) : Option[FunctionSet] = functionSets.get(id) + + def getIndex(id: Identifier) : Option[Index] = indices.get(id) +} diff --git a/src/ofc/expression/Expression.scala b/src/ofc/expression/Expression.scala new file mode 100644 index 0000000..4b5d260 --- /dev/null +++ b/src/ofc/expression/Expression.scala @@ -0,0 +1,78 @@ +package ofc.expression +import ofc.parser.Identifier + +case class Index(id: Identifier) { + def getIdentifier = id + def getName = id.getName + override def toString() = id.getName +} + +trait Expression { + def isAssignable : Boolean + def numIndices : Int + def getDependentIndices : Set[Index] +} + +trait ScalarExpression extends Expression +trait FieldExpression extends Expression + +trait IndexingOperation { + val op: Expression + def getIndices : List[Index] + def isAssignable = op.isAssignable + def numIndices = 0 + def getDependentIndices = op.getDependentIndices ++ getIndices + override def toString = op.toString + getIndices.map(_.getName).mkString("[",",","]") +} + +trait NamedOperand { + val id: Identifier + def getIdentifier = id + def isAssignable = true + def getDependentIndices : Set[Index] = Set.empty + override def toString = id.getName +} + +case class ScalarIndexingOperation(val op: ScalarExpression, indices: List[Index]) extends IndexingOperation with ScalarExpression { + def getIndices = indices +} + +case class FieldIndexingOperation(val op: FieldExpression, indices: List[Index]) extends IndexingOperation with FieldExpression { + def getIndices = indices +} + +case class InnerProduct(left: FieldExpression, right: FieldExpression) extends ScalarExpression { + override def toString = "inner(" + left.toString + ", " + right.toString+")" + def isAssignable = false + def numIndices = 0 + def getDependentIndices = left.getDependentIndices ++ right.getDependentIndices +} + +case class Laplacian(op: FieldExpression) extends FieldExpression { + override def toString = "laplacian("+op.toString+")" + def isAssignable = false + def numIndices = 0 + def getDependentIndices = op.getDependentIndices +} + +case class FieldScaling(op: FieldExpression, scale: ScalarExpression) extends FieldExpression { + override def toString = "(" + op.toString + "*" + scale.toString + ")" + def isAssignable = false + def numIndices = 0 + def getDependentIndices = op.getDependentIndices ++ scale.getDependentIndices +} + +class FunctionSet(val id: Identifier) extends FieldExpression with NamedOperand { + def numIndices = 1 +} + +class Matrix(val id: Identifier) extends ScalarExpression with NamedOperand { + def numIndices = 2 +} + +class ScalarLiteral(literal: Double) extends ScalarExpression { + override def toString = literal.toString + def isAssignable = false + def numIndices = 0 + def getDependentIndices = Set.empty +} diff --git a/src/ofc/expression/TreeBuilder.scala b/src/ofc/expression/TreeBuilder.scala new file mode 100644 index 0000000..6d17772 --- /dev/null +++ b/src/ofc/expression/TreeBuilder.scala @@ -0,0 +1,69 @@ +package ofc.expression + +import ofc.parser +import ofc.parser.Identifier +import ofc.{InvalidInputException,UnimplementedException} + +class TreeBuilder(dictionary : Dictionary) { + def apply(definition: parser.Definition) : Assignment = { + val lhsTree = buildExpression(definition.term) + val rhsTree = buildExpression(definition.expr) + + if (!lhsTree.isAssignable) + throw new InvalidInputException("Non-assignable expression on LHS of assignment.") + else + new Assignment(lhsTree, rhsTree) + } + + private def buildIndexedOperand(term: parser.IndexedIdentifier) : Expression = { + val indices = for(id <- term.indices) yield buildIndex(id) + + dictionary.getMatrix(term.id) match { + case Some(matrix) => new ScalarIndexingOperation(matrix, indices) + case None => dictionary.getFunctionSet(term.id) match { + case Some(functionSet) => new FieldIndexingOperation(functionSet, indices) + case None => throw new UnimplementedException("No idea how to index "+term.id) + } + } + } + + private def buildIndex(id: parser.Identifier) : Index = dictionary.getIndex(id) match { + case Some(index) => index + case None => throw new InvalidInputException("Unknown index "+id) + } + + private def buildIndex(term: parser.Expression) : Index = term match { + case (indexedID: parser.IndexedIdentifier) => { + if (indexedID.indices.nonEmpty) + throw new InvalidInputException("Tried to parse expression "+term+" as index but it is indexed.") + else + buildIndex(indexedID.id) + } + case other => throw new InvalidInputException("Cannot parse expression "+other+" as index.") + } + + private def buildExpression(term: parser.Expression) : Expression = { + import parser._ + + term match { + case (t: IndexedIdentifier) => buildIndexedOperand(t) + case ScalarConstant(s) => new ScalarLiteral(s) + case Division(a, b) => + throw new UnimplementedException("Semantics of division not yet defined, or implemented.") + case Multiplication(left, right) => (buildExpression(left), buildExpression(right)) match { + case (field: FieldExpression, factor: ScalarExpression) => new FieldScaling(field, factor) + case (factor: ScalarExpression, field: FieldExpression) => new FieldScaling(field, factor) + case _ => throw new InvalidInputException("Cannot multiply "+left+" and "+right+".") + } + case Operator(Identifier("inner"), List(a,b)) => (buildExpression(a), buildExpression(b)) match { + case (left: FieldExpression, right: FieldExpression) => new InnerProduct(left, right) + case _ => throw new InvalidInputException("inner requires both operands to be fields.") + } + case Operator(Identifier("laplacian"), List(op)) => buildExpression(op) match { + case (field: FieldExpression) => new Laplacian(field) + case _ => throw new InvalidInputException("laplacian can only be applied to a field.") + } + case Operator(Identifier(name), _) => throw new UnimplementedException("Unknown or incorrectly called operator: "+name) + } + } +} diff --git a/src/ofc/generators/Generator.scala b/src/ofc/generators/Generator.scala index cf938eb..f5ee970 100644 --- a/src/ofc/generators/Generator.scala +++ b/src/ofc/generators/Generator.scala @@ -1,6 +1,7 @@ package ofc.generators -import ofc.parser.Statement +import ofc.parser.TargetAssignment +import ofc.expression.{Dictionary,Assignment} trait Generator { - def acceptInput(program : Seq[Statement]) : Unit + def acceptInput(dictionary: Dictionary, expression: Assignment, targetSpecific : Seq[TargetAssignment]) : Unit } diff --git a/src/ofc/generators/Onetep.scala b/src/ofc/generators/Onetep.scala index e7f10f3..dc33598 100644 --- a/src/ofc/generators/Onetep.scala +++ b/src/ofc/generators/Onetep.scala @@ -1,34 +1,14 @@ package ofc.generators -import ofc.parser import ofc.generators.onetep._ -import ofc.InvalidInputException -import scala.reflect.Manifest -import scala.reflect.Manifest.singleType +import ofc.parser +import ofc.expression.{Assignment,Dictionary} class Onetep extends Generator { - - var dictionary = new Dictionary - - def acceptInput(program : Seq[parser.Statement]) = { - println("Parsed input:\n"+program.mkString("\n") + "\n") - buildDictionary(program) - buildDefinitions(program) - } - - def filterStatements[T <: parser.Statement](statements : Seq[parser.Statement])(implicit m: Manifest[T]) = - statements.foldLeft(List[T]())((list, item) => item match { - case s if (singleType(s) <:< m) => s.asInstanceOf[T] :: list - case _ => list - }) - - def getDeclarations(statements : Seq[parser.Statement]) : Map[parser.Identifier, parser.OFLType] = { - def getMappings(dl : parser.DeclarationList) = - for (name <- dl.names) yield - (name, dl.oflType) - - filterStatements[parser.DeclarationList](statements).flatMap(getMappings(_)).toMap + def acceptInput(dictionary: Dictionary, assignment: Assignment, targetSpecific : Seq[parser.TargetAssignment]) = { + println(assignment) } + /* def buildMatrix(id: parser.Identifier, call : Option[parser.FunctionCall]) { import parser._ @@ -97,4 +77,5 @@ class Onetep extends Generator { else buildDefinition(definitions.head) } + */ } diff --git a/src/ofc/parser/Statement.scala b/src/ofc/parser/Statement.scala index d4a735f..94e6487 100644 --- a/src/ofc/parser/Statement.scala +++ b/src/ofc/parser/Statement.scala @@ -2,6 +2,7 @@ package ofc.parser case class Identifier(name: String) { override def toString : String = "id(\""+name+"\")" + def getName = name } sealed abstract class Statement