]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
More work on ONETEP-specific expression tree.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Fri, 20 Jan 2012 19:06:38 +0000 (19:06 +0000)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Fri, 20 Jan 2012 19:06:38 +0000 (19:06 +0000)
src/ofc/OFC.scala
src/ofc/generators/onetep/Tree.scala

index 8f22440eea51f1ac4acd4ad383179f31cd07bf17..96bc906af9f80e23c5404d2118d7b9e83dc5111b 100644 (file)
@@ -5,6 +5,7 @@ import parser.{Parser,Statement,Target,Identifier,ParseException}
 import generators.Generator
 
 class InvalidInputException(s: String) extends Exception(s)
+class UnimplementedException(s: String) extends Exception(s)
 
 object OFC extends Parser {
 
index 9f836ab7976bdf85f95e49581337b3d6c15362f4..0c0bb0f49defab4232f44843fb842e02f8773db5 100644 (file)
@@ -3,21 +3,33 @@ package ofc.generators.onetep
 import scala.collection.mutable.{HashMap,HashSet,Set}
 import ofc.parser
 import ofc.parser.Identifier
-import ofc.InvalidInputException
+import ofc.{InvalidInputException,UnimplementedException}
 
 trait Index
 trait SpatialIndex extends Index
 trait DiscreteIndex extends Index
 
-trait DataSpace
-{
+trait IterationSpace {
   def getSpatialIndices() : List[SpatialIndex]
   def getDiscreteIndices() : List[DiscreteIndex]
 }
 
+trait DataSpace extends IterationSpace {
+}
+
 trait Matrix extends DataSpace
 trait FunctionSet extends DataSpace
 
+class Scalar(value: Double) extends IterationSpace {
+  def getSpatialIndices() = Nil
+  def getDiscreteIndices() = Nil
+}
+
+class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: List[Index]) extends IterationSpace {
+  def getSpatialIndices() = Nil
+  def getDiscreteIndices() = Nil
+}
+
 class SPAM3(name : String) extends Matrix {
   override def toString = name
 
@@ -28,8 +40,11 @@ class SPAM3(name : String) extends Matrix {
     override def toString = parent + ".col"
   }
 
+  val rowIndex = new RowIndex(this)
+  val colIndex = new ColIndex(this)
+
   def getSpatialIndices() = Nil
-  def getDiscreteIndices() = List(new RowIndex(this), new ColIndex(this))
+  def getDiscreteIndices() = List(rowIndex, colIndex)
 }
 
 class PPDFunctionSet(basis : String, data : String) extends FunctionSet {
@@ -37,14 +52,21 @@ class PPDFunctionSet(basis : String, data : String) extends FunctionSet {
   class PPDIndex(parent: PPDFunctionSet) extends DiscreteIndex
   class IntraPPDIndex(parent: PPDFunctionSet, dimension: Int) extends SpatialIndex
 
-  def getSpatialIndices() = (for (dimension <- 0 to 2) yield new IntraPPDIndex(this, dimension)).toList
-  def getDiscreteIndices() = List(new SphereIndex(this), new PPDIndex(this))
+  val ppdIndex = new PPDIndex(this)
+  val sphereIndex = new SphereIndex(this)
+  val spatialIndices = for (dimension <- 0 to 2) yield new IntraPPDIndex(this, dimension)
+
+  def getPPDIndex() = ppdIndex
+  def getSphereIndex() = sphereIndex
+
+  def getSpatialIndices() = spatialIndices.toList
+  def getDiscreteIndices() = List(getSphereIndex(), getPPDIndex())
 }
 
-class Restriction
-class Reciprocal
-class Pointwise
-class Summation
+//class Restriction
+//class Reciprocal
+//class Summation
+
 
 class BindingIndex(name : String) {
   override def toString()  = name
@@ -85,12 +107,22 @@ class IndexBindings {
 
 class TreeBuilder(dictionary : Dictionary) {
   val indexBindings = new IndexBindings
+  var nextBindingIndexID = 0
+
+  def newBindingIndex() = {
+    val index = new BindingIndex("synthetic_"+nextBindingIndexID)
+    nextBindingIndexID += 1
+    index
+  }
 
   def apply(lhs: parser.IndexedTerm, rhs: parser.Expression) {
     buildIndexedTerm(lhs)
+    buildExpression(rhs)
+
+    print(indexBindings)
   }
 
-  def buildIndexedTerm(term: parser.IndexedTerm) {
+  def buildIndexedTerm(term: parser.IndexedTerm) : DataSpace = {
     val dataSpace = dictionary.getData(term.id)
     val indices = for(bindingID <- term.indices) yield dictionary.getIndex(bindingID)
 
@@ -100,6 +132,32 @@ class TreeBuilder(dictionary : Dictionary) {
     for(i <- indices zip dataSpace.getDiscreteIndices)
       indexBindings.add(i._1, i._2)
 
-    print(indexBindings)
+    dataSpace
+  }
+
+  def buildExpression(term: parser.Expression) : IterationSpace = {
+    import parser._
+
+    term match {
+      case (t: IndexedTerm) => buildIndexedTerm(t)
+      case ScalarConstant(s) => new Scalar(s)
+      case Multiplication(a, b) => 
+        new GeneralInnerProduct(List(buildExpression(a), buildExpression(b)), Nil)
+      case Division(a, b) => 
+        throw new UnimplementedException("Semantics of division not yet defined, or implemented.")
+      case Operator(Identifier("inner"), List(a,b)) => {
+        val aExpression = buildExpression(a)
+        val bExpression = buildExpression(b)
+
+        for ((left,right) <- aExpression.getSpatialIndices zip bExpression.getSpatialIndices) {
+          val bindingIndex = newBindingIndex()
+          indexBindings.add(bindingIndex, left)
+          indexBindings.add(bindingIndex, right)
+        }
+
+        new GeneralInnerProduct(List(aExpression, bExpression), aExpression.getSpatialIndices)
+      }
+      case Operator(Identifier(name), _) => throw new UnimplementedException("Unknown or unimplemented operator: "+name)
+    }
   }
 }