From: Francis Russell Date: Fri, 20 Jan 2012 19:06:38 +0000 (+0000) Subject: More work on ONETEP-specific expression tree. X-Git-Url: https://git.unchartedbackwaters.co.uk/w/?a=commitdiff_plain;h=c976a61f2a142db4c1e3edb485e465a774cdc721;p=francis%2Fofc.git More work on ONETEP-specific expression tree. --- diff --git a/src/ofc/OFC.scala b/src/ofc/OFC.scala index 8f22440..96bc906 100644 --- a/src/ofc/OFC.scala +++ b/src/ofc/OFC.scala @@ -5,6 +5,7 @@ import parser.{Parser,Statement,Target,Identifier,ParseException} import generators.Generator class InvalidInputException(s: String) extends Exception(s) +class UnimplementedException(s: String) extends Exception(s) object OFC extends Parser { diff --git a/src/ofc/generators/onetep/Tree.scala b/src/ofc/generators/onetep/Tree.scala index 9f836ab..0c0bb0f 100644 --- a/src/ofc/generators/onetep/Tree.scala +++ b/src/ofc/generators/onetep/Tree.scala @@ -3,21 +3,33 @@ package ofc.generators.onetep import scala.collection.mutable.{HashMap,HashSet,Set} import ofc.parser import ofc.parser.Identifier -import ofc.InvalidInputException +import ofc.{InvalidInputException,UnimplementedException} trait Index trait SpatialIndex extends Index trait DiscreteIndex extends Index -trait DataSpace -{ +trait IterationSpace { def getSpatialIndices() : List[SpatialIndex] def getDiscreteIndices() : List[DiscreteIndex] } +trait DataSpace extends IterationSpace { +} + trait Matrix extends DataSpace trait FunctionSet extends DataSpace +class Scalar(value: Double) extends IterationSpace { + def getSpatialIndices() = Nil + def getDiscreteIndices() = Nil +} + +class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: List[Index]) extends IterationSpace { + def getSpatialIndices() = Nil + def getDiscreteIndices() = Nil +} + class SPAM3(name : String) extends Matrix { override def toString = name @@ -28,8 +40,11 @@ class SPAM3(name : String) extends Matrix { override def toString = parent + ".col" } + val rowIndex = new RowIndex(this) + val colIndex = new ColIndex(this) + def getSpatialIndices() = Nil - def getDiscreteIndices() = List(new RowIndex(this), new ColIndex(this)) + def getDiscreteIndices() = List(rowIndex, colIndex) } class PPDFunctionSet(basis : String, data : String) extends FunctionSet { @@ -37,14 +52,21 @@ class PPDFunctionSet(basis : String, data : String) extends FunctionSet { class PPDIndex(parent: PPDFunctionSet) extends DiscreteIndex class IntraPPDIndex(parent: PPDFunctionSet, dimension: Int) extends SpatialIndex - def getSpatialIndices() = (for (dimension <- 0 to 2) yield new IntraPPDIndex(this, dimension)).toList - def getDiscreteIndices() = List(new SphereIndex(this), new PPDIndex(this)) + val ppdIndex = new PPDIndex(this) + val sphereIndex = new SphereIndex(this) + val spatialIndices = for (dimension <- 0 to 2) yield new IntraPPDIndex(this, dimension) + + def getPPDIndex() = ppdIndex + def getSphereIndex() = sphereIndex + + def getSpatialIndices() = spatialIndices.toList + def getDiscreteIndices() = List(getSphereIndex(), getPPDIndex()) } -class Restriction -class Reciprocal -class Pointwise -class Summation +//class Restriction +//class Reciprocal +//class Summation + class BindingIndex(name : String) { override def toString() = name @@ -85,12 +107,22 @@ class IndexBindings { class TreeBuilder(dictionary : Dictionary) { val indexBindings = new IndexBindings + var nextBindingIndexID = 0 + + def newBindingIndex() = { + val index = new BindingIndex("synthetic_"+nextBindingIndexID) + nextBindingIndexID += 1 + index + } def apply(lhs: parser.IndexedTerm, rhs: parser.Expression) { buildIndexedTerm(lhs) + buildExpression(rhs) + + print(indexBindings) } - def buildIndexedTerm(term: parser.IndexedTerm) { + def buildIndexedTerm(term: parser.IndexedTerm) : DataSpace = { val dataSpace = dictionary.getData(term.id) val indices = for(bindingID <- term.indices) yield dictionary.getIndex(bindingID) @@ -100,6 +132,32 @@ class TreeBuilder(dictionary : Dictionary) { for(i <- indices zip dataSpace.getDiscreteIndices) indexBindings.add(i._1, i._2) - print(indexBindings) + dataSpace + } + + def buildExpression(term: parser.Expression) : IterationSpace = { + import parser._ + + term match { + case (t: IndexedTerm) => buildIndexedTerm(t) + case ScalarConstant(s) => new Scalar(s) + case Multiplication(a, b) => + new GeneralInnerProduct(List(buildExpression(a), buildExpression(b)), Nil) + case Division(a, b) => + throw new UnimplementedException("Semantics of division not yet defined, or implemented.") + case Operator(Identifier("inner"), List(a,b)) => { + val aExpression = buildExpression(a) + val bExpression = buildExpression(b) + + for ((left,right) <- aExpression.getSpatialIndices zip bExpression.getSpatialIndices) { + val bindingIndex = newBindingIndex() + indexBindings.add(bindingIndex, left) + indexBindings.add(bindingIndex, right) + } + + new GeneralInnerProduct(List(aExpression, bExpression), aExpression.getSpatialIndices) + } + case Operator(Identifier(name), _) => throw new UnimplementedException("Unknown or unimplemented operator: "+name) + } } }