From 8edf46284b6b7623f79d2152b33a6f39a739a66f Mon Sep 17 00:00:00 2001 From: Francis Russell Date: Mon, 23 Jan 2012 17:42:16 +0000 Subject: [PATCH] Generate ONETEP-specific expression tree. --- src/ofc/generators/Onetep.scala | 5 ++- src/ofc/generators/onetep/Tree.scala | 62 ++++++++++++++++++++-------- 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/src/ofc/generators/Onetep.scala b/src/ofc/generators/Onetep.scala index 0e649ee..6f5e19b 100644 --- a/src/ofc/generators/Onetep.scala +++ b/src/ofc/generators/Onetep.scala @@ -49,7 +49,7 @@ class Onetep extends Generator { call match { case Some(FunctionCall(fSetType, params)) => (fSetType, params) match { case (Identifier("PPDFunctionSet"), ParameterList(StringParameter(basis), StringParameter(data))) => - dictionary.functionSets += (id -> new PPDFunctionSet(basis, data)) + dictionary.functionSets += id -> new PPDFunctionSet(basis, data) case _ => throw new InvalidInputException("Unknown usage of type: "+fSetType.name) } case _ => throw new InvalidInputException("Undefined concrete type for function set: "+id.name) @@ -86,7 +86,8 @@ class Onetep extends Generator { def buildDefinition(definition : parser.Definition) { println(definition) val builder = new TreeBuilder(dictionary) - builder(definition.term, definition.expr) + val assignment = builder(definition.term, definition.expr) + println(assignment) } def buildDefinitions(statements : List[parser.Statement]) { diff --git a/src/ofc/generators/onetep/Tree.scala b/src/ofc/generators/onetep/Tree.scala index 0c0bb0f..c07e001 100644 --- a/src/ofc/generators/onetep/Tree.scala +++ b/src/ofc/generators/onetep/Tree.scala @@ -26,8 +26,29 @@ class Scalar(value: Double) extends IterationSpace { } class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: List[Index]) extends IterationSpace { - def getSpatialIndices() = Nil - def getDiscreteIndices() = Nil + def getSpatialIndices() = operands flatMap (op => op.getSpatialIndices filterNot (index => removedIndices.contains(index))) + def getDiscreteIndices() = operands flatMap (op => op.getDiscreteIndices filterNot (index => removedIndices.contains(index))) +} + +class Reciprocal(op: IterationSpace) extends IterationSpace { + class BlockIndex(parent: Reciprocal, dimension: Int) extends SpatialIndex + val spatialIndices = for (dimension <- 0 to op.getSpatialIndices.size) yield new BlockIndex(this, dimension) + + def getSpatialIndices() = spatialIndices.toList + def getDiscreteIndices() = op.getDiscreteIndices +} + +class Laplacian(op: IterationSpace) extends IterationSpace { + def getSpatialIndices() = op.getSpatialIndices + def getDiscreteIndices() = op.getDiscreteIndices +} + +class SpatialRestriction(op: IterationSpace) extends IterationSpace { + class RestrictedIndex(parent: SpatialRestriction, dimension: Int) extends SpatialIndex + val spatialIndices = for (dimension <- 0 to op.getSpatialIndices.size) yield new RestrictedIndex(this, dimension) + + def getSpatialIndices() = spatialIndices.toList + def getDiscreteIndices() = op.getDiscreteIndices } class SPAM3(name : String) extends Matrix { @@ -63,11 +84,6 @@ class PPDFunctionSet(basis : String, data : String) extends FunctionSet { def getDiscreteIndices() = List(getSphereIndex(), getPPDIndex()) } -//class Restriction -//class Reciprocal -//class Summation - - class BindingIndex(name : String) { override def toString() = name } @@ -93,7 +109,9 @@ class Dictionary { } } -class Definition(lhs: DataSpace, rhs: DataSpace) +class Assignment(indexBindings: IndexBindings, lhs: DataSpace, rhs: IterationSpace) { + override def toString = indexBindings.toString +} class IndexBindings { val spatial = new HashMap[BindingIndex, Set[SpatialIndex]] @@ -102,7 +120,7 @@ class IndexBindings { def add(binding: BindingIndex, index: SpatialIndex) = spatial.getOrElseUpdate(binding, new HashSet()) += index def add(binding: BindingIndex, index: DiscreteIndex) = discrete.getOrElseUpdate(binding, new HashSet()) += index - override def toString = spatial.toString + discrete.toString + override def toString = spatial.mkString("\n") + "\n" + discrete.mkString("\n") } class TreeBuilder(dictionary : Dictionary) { @@ -115,15 +133,22 @@ class TreeBuilder(dictionary : Dictionary) { index } - def apply(lhs: parser.IndexedTerm, rhs: parser.Expression) { - buildIndexedTerm(lhs) - buildExpression(rhs) + def apply(lhs: parser.IndexedTerm, rhs: parser.Expression) = { + val lhsTree = buildIndexedTerm(lhs) + val rhsTree = buildExpression(rhs) - print(indexBindings) + lhsTree match { + case (lhsTree: DataSpace) => new Assignment(indexBindings, lhsTree, rhsTree) + case _ => new InvalidInputException("Non-assignable expression on LHS of assignment.") + } } - def buildIndexedTerm(term: parser.IndexedTerm) : DataSpace = { - val dataSpace = dictionary.getData(term.id) + def buildIndexedTerm(term: parser.IndexedTerm) : IterationSpace = { + val dataSpace = dictionary.getData(term.id) match { + case (functionSet : PPDFunctionSet) => new GeneralInnerProduct(List(functionSet), List(functionSet.getPPDIndex)) + case v => v + } + val indices = for(bindingID <- term.indices) yield dictionary.getIndex(bindingID) if (indices.size != dataSpace.getDiscreteIndices.size) @@ -155,9 +180,12 @@ class TreeBuilder(dictionary : Dictionary) { indexBindings.add(bindingIndex, right) } - new GeneralInnerProduct(List(aExpression, bExpression), aExpression.getSpatialIndices) + new GeneralInnerProduct(List(aExpression, bExpression), aExpression.getSpatialIndices ++ bExpression.getSpatialIndices) } - case Operator(Identifier(name), _) => throw new UnimplementedException("Unknown or unimplemented operator: "+name) + case Operator(Identifier("reciprocal"), List(op)) => new Reciprocal(buildExpression(op)) + case Operator(Identifier("laplacian"), List(op)) => new Laplacian(buildExpression(op)) + case Operator(Identifier("fftbox"), List(op)) => new SpatialRestriction(buildExpression(op)) + case Operator(Identifier(name), _) => throw new UnimplementedException("Unknown or incorrectly called operator: "+name) } } } -- 2.47.3