]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Generate ONETEP-specific expression tree.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Mon, 23 Jan 2012 17:42:16 +0000 (17:42 +0000)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Mon, 23 Jan 2012 17:42:16 +0000 (17:42 +0000)
src/ofc/generators/Onetep.scala
src/ofc/generators/onetep/Tree.scala

index 0e649eeebeed334d9ea622d3f6ad839c57e35b4e..6f5e19b38365f16d6312e19b44c7517b3d5adeb5 100644 (file)
@@ -49,7 +49,7 @@ class Onetep extends Generator {
     call match {
       case Some(FunctionCall(fSetType, params)) => (fSetType, params) match {
         case (Identifier("PPDFunctionSet"), ParameterList(StringParameter(basis), StringParameter(data))) => 
-          dictionary.functionSets += (id -> new PPDFunctionSet(basis, data))
+          dictionary.functionSets += id -> new PPDFunctionSet(basis, data)
         case _ => throw new InvalidInputException("Unknown usage of type: "+fSetType.name)
       }
       case _ => throw new InvalidInputException("Undefined concrete type for function set: "+id.name)
@@ -86,7 +86,8 @@ class Onetep extends Generator {
   def buildDefinition(definition : parser.Definition) {
     println(definition)
     val builder = new TreeBuilder(dictionary)
-    builder(definition.term, definition.expr)
+    val assignment = builder(definition.term, definition.expr)
+    println(assignment)
   }
 
   def buildDefinitions(statements : List[parser.Statement]) {
index 0c0bb0f49defab4232f44843fb842e02f8773db5..c07e001c92c9d927d35e57db35a8e5cff915d441 100644 (file)
@@ -26,8 +26,29 @@ class Scalar(value: Double) extends IterationSpace {
 }
 
 class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: List[Index]) extends IterationSpace {
-  def getSpatialIndices() = Nil
-  def getDiscreteIndices() = Nil
+  def getSpatialIndices() = operands flatMap (op => op.getSpatialIndices filterNot (index => removedIndices.contains(index)))
+  def getDiscreteIndices() = operands flatMap (op => op.getDiscreteIndices filterNot (index => removedIndices.contains(index)))
+}
+
+class Reciprocal(op: IterationSpace) extends IterationSpace {
+  class BlockIndex(parent: Reciprocal, dimension: Int)  extends SpatialIndex
+  val spatialIndices = for (dimension <- 0 to op.getSpatialIndices.size) yield new BlockIndex(this, dimension)
+
+  def getSpatialIndices() = spatialIndices.toList
+  def getDiscreteIndices() = op.getDiscreteIndices
+}
+
+class Laplacian(op: IterationSpace) extends IterationSpace {
+  def getSpatialIndices() = op.getSpatialIndices
+  def getDiscreteIndices() = op.getDiscreteIndices
+}
+
+class SpatialRestriction(op: IterationSpace) extends IterationSpace {
+  class RestrictedIndex(parent: SpatialRestriction, dimension: Int)  extends SpatialIndex
+  val spatialIndices = for (dimension <- 0 to op.getSpatialIndices.size) yield new RestrictedIndex(this, dimension)
+
+  def getSpatialIndices() = spatialIndices.toList
+  def getDiscreteIndices() = op.getDiscreteIndices
 }
 
 class SPAM3(name : String) extends Matrix {
@@ -63,11 +84,6 @@ class PPDFunctionSet(basis : String, data : String) extends FunctionSet {
   def getDiscreteIndices() = List(getSphereIndex(), getPPDIndex())
 }
 
-//class Restriction
-//class Reciprocal
-//class Summation
-
-
 class BindingIndex(name : String) {
   override def toString()  = name
 }
@@ -93,7 +109,9 @@ class Dictionary {
     }
 }
 
-class Definition(lhs: DataSpace, rhs: DataSpace)
+class Assignment(indexBindings: IndexBindings, lhs: DataSpace, rhs: IterationSpace) {
+  override def toString = indexBindings.toString
+}
 
 class IndexBindings {
   val spatial = new HashMap[BindingIndex, Set[SpatialIndex]]
@@ -102,7 +120,7 @@ class IndexBindings {
   def add(binding: BindingIndex, index: SpatialIndex) = spatial.getOrElseUpdate(binding, new HashSet()) += index
   def add(binding: BindingIndex, index: DiscreteIndex) = discrete.getOrElseUpdate(binding, new HashSet()) += index
   
-  override def toString = spatial.toString + discrete.toString
+  override def toString = spatial.mkString("\n") + "\n" + discrete.mkString("\n")
 }
 
 class TreeBuilder(dictionary : Dictionary) {
@@ -115,15 +133,22 @@ class TreeBuilder(dictionary : Dictionary) {
     index
   }
 
-  def apply(lhs: parser.IndexedTerm, rhs: parser.Expression) {
-    buildIndexedTerm(lhs)
-    buildExpression(rhs)
+  def apply(lhs: parser.IndexedTerm, rhs: parser.Expression) {
+    val lhsTree = buildIndexedTerm(lhs)
+    val rhsTree = buildExpression(rhs)
 
-    print(indexBindings)
+    lhsTree match {
+      case (lhsTree: DataSpace) => new Assignment(indexBindings, lhsTree, rhsTree)
+      case _ => new InvalidInputException("Non-assignable expression on LHS of assignment.")
+    }
   }
 
-  def buildIndexedTerm(term: parser.IndexedTerm) : DataSpace = {
-    val dataSpace = dictionary.getData(term.id)
+  def buildIndexedTerm(term: parser.IndexedTerm) : IterationSpace = {
+    val dataSpace = dictionary.getData(term.id) match {
+      case (functionSet : PPDFunctionSet) => new GeneralInnerProduct(List(functionSet), List(functionSet.getPPDIndex))
+      case v => v
+    }
+
     val indices = for(bindingID <- term.indices) yield dictionary.getIndex(bindingID)
 
     if (indices.size != dataSpace.getDiscreteIndices.size)
@@ -155,9 +180,12 @@ class TreeBuilder(dictionary : Dictionary) {
           indexBindings.add(bindingIndex, right)
         }
 
-        new GeneralInnerProduct(List(aExpression, bExpression), aExpression.getSpatialIndices)
+        new GeneralInnerProduct(List(aExpression, bExpression), aExpression.getSpatialIndices ++ bExpression.getSpatialIndices)
       }
-      case Operator(Identifier(name), _) => throw new UnimplementedException("Unknown or unimplemented operator: "+name)
+      case Operator(Identifier("reciprocal"), List(op)) => new Reciprocal(buildExpression(op))
+      case Operator(Identifier("laplacian"), List(op)) => new Laplacian(buildExpression(op))
+      case Operator(Identifier("fftbox"), List(op)) => new SpatialRestriction(buildExpression(op))
+      case Operator(Identifier(name), _) => throw new UnimplementedException("Unknown or incorrectly called operator: "+name)
     }
   }
 }