From f3f574a9a319c380e7387580aa45abd94c6b2452 Mon Sep 17 00:00:00 2001 From: Francis Russell Date: Mon, 9 Apr 2012 19:51:50 +0100 Subject: [PATCH] Re-enable construction of SpatialRestriction node. The SpatialRestriction node corresponds to a FFT-box. We need to specify where the FFT-box is to be constructed, so we have to pass an index variable to "fftbox" as well. This is ugly, but we can just about handle it in our current parser. --- examples/test.ofl | 2 +- src/ofc/generators/onetep/OnetepTypes.scala | 5 +++ .../onetep/SpatialRestriction.scala | 26 ++++++++++++-- src/ofc/generators/onetep/TreeBuilder.scala | 36 +++++++++++++------ src/ofc/parser/Parser.scala | 4 +-- src/ofc/parser/Statement.scala | 4 +-- 6 files changed, 58 insertions(+), 19 deletions(-) diff --git a/examples/test.ofl b/examples/test.ofl index a248f18..72267b7 100644 --- a/examples/test.ofl +++ b/examples/test.ofl @@ -4,7 +4,7 @@ FunctionSet ket Index beta # Computation -kinet = ket[beta] +kinet = fftbox(beta, ket[beta]) # Implementation specific target ONETEP diff --git a/src/ofc/generators/onetep/OnetepTypes.scala b/src/ofc/generators/onetep/OnetepTypes.scala index 6c19b83..6e5907c 100644 --- a/src/ofc/generators/onetep/OnetepTypes.scala +++ b/src/ofc/generators/onetep/OnetepTypes.scala @@ -48,4 +48,9 @@ object OnetepTypes { val finishPPD = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("finish_ppds"+dim)}.toSeq def getFortranAttributes = Set("type(FUNCTION_TIGHT_BOX)") } + + object FFTBoxInfo extends StructType { + val totalPts = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("total_pt"+dim)}.toSeq + def getFortranAttributes = Set("type(FFTBOX_INFO)") + } } diff --git a/src/ofc/generators/onetep/SpatialRestriction.scala b/src/ofc/generators/onetep/SpatialRestriction.scala index 4f54676..fb469c9 100644 --- a/src/ofc/generators/onetep/SpatialRestriction.scala +++ b/src/ofc/generators/onetep/SpatialRestriction.scala @@ -1,6 +1,26 @@ package ofc.generators.onetep +import ofc.codegen._ + +object SpatialRestriction { + private val pubFFTBox = new NamedUnboundVarSymbol[StructType]("pub_fftbox", OnetepTypes.FFTBoxInfo) +} + +class SpatialRestriction(op: IterationSpace, function: BindingIndex) extends IterationSpace { + def getOperands = List(op) + def getDiscreteIndices = Nil + def getSuffixFragment = new NullStatement + def getDataValue = op.getDataValue + def getReaderFragment = { + //TODO: Implement me! + new NullStatement + } + + def getSpatialIndices = { + //TODO: Implement me! + Nil + } + /* -class SpatialRestriction(op: IterationSpace) extends IterationSpace { class RestrictedIndex(parent: SpatialRestriction, dimension: Int) extends SpatialIndex { def getName = "restriction_index_" + dimension def getDependencies = Set() @@ -13,7 +33,6 @@ class SpatialRestriction(op: IterationSpace) extends IterationSpace { val spatialIndices = for (dimension <- 0 until op.getSpatialIndices.size) yield new RestrictedIndex(this, dimension) - def getOperands = List(op) def getSpatialIndices = spatialIndices.toList def getDiscreteIndices = op.getDiscreteIndices def getExternalIndices = Set() @@ -25,5 +44,6 @@ class SpatialRestriction(op: IterationSpace) extends IterationSpace { }) def getTransformGenerator = None def getProducerGenerator = None -} */ +} + diff --git a/src/ofc/generators/onetep/TreeBuilder.scala b/src/ofc/generators/onetep/TreeBuilder.scala index 8e4a270..643ab76 100644 --- a/src/ofc/generators/onetep/TreeBuilder.scala +++ b/src/ofc/generators/onetep/TreeBuilder.scala @@ -4,7 +4,7 @@ import ofc.parser import ofc.parser.Identifier import ofc.{InvalidInputException,UnimplementedException} -class BindingIndex(name : String) { +case class BindingIndex(name : String) { override def toString() = name } @@ -47,14 +47,14 @@ class TreeBuilder(dictionary : Dictionary) { val indexBindings = new IndexBindings var nextBindingIndexID = 0 - def newBindingIndex() = { + private def newBindingIndex() = { val index = new BindingIndex("synthetic_"+nextBindingIndexID) nextBindingIndexID += 1 index } - def apply(lhs: parser.IndexedTerm, rhs: parser.Expression) = { - val lhsTree = buildIndexedTerm(lhs) + def apply(lhs: parser.IndexedIdentifier, rhs: parser.Expression) = { + val lhsTree = buildIndexedSpace(lhs) val rhsTree = buildExpression(rhs) lhsTree match { @@ -63,12 +63,12 @@ class TreeBuilder(dictionary : Dictionary) { } } - def buildIndexedTerm(term: parser.IndexedTerm) : IterationSpace = { + private def buildIndexedSpace(term: parser.IndexedIdentifier) : IterationSpace = { val dataSpace = dictionary.getData(term.id) val indices = for(bindingID <- term.indices) yield dictionary.getIndex(bindingID) if (indices.size != dataSpace.getDiscreteIndices.size) - throw new InvalidInputException("Incorrect number of indices for object "+term.id.name); + throw new InvalidInputException("Incorrect number of indices for object "+term.id.name) for(i <- indices zip dataSpace.getDiscreteIndices) indexBindings.add(i._1, i._2) @@ -79,12 +79,27 @@ class TreeBuilder(dictionary : Dictionary) { } } - def buildExpression(term: parser.Expression) : IterationSpace = { + private def buildIndex(term: parser.Expression) : BindingIndex = { + term match { + case (indexedID: parser.IndexedIdentifier) => { + if (indexedID.indices.nonEmpty) + throw new InvalidInputException("Tried to parse expression "+term+" as index but it is indexed.") + else + dictionary.getIndex(indexedID.id) + } + case other => throw new InvalidInputException("Cannot parse expression "+other+" as index.") + } + } + + private def buildExpression(term: parser.Expression) : IterationSpace = { import parser._ term match { - case (t: IndexedTerm) => buildIndexedTerm(t) -/* + case (t: IndexedIdentifier) => buildIndexedSpace(t) + case Operator(Identifier("fftbox"), List(indexID, op)) => + new SpatialRestriction(buildExpression(op), buildIndex(indexID)) + + /* case ScalarConstant(s) => new Scalar(s) case Multiplication(a, b) => new GeneralInnerProduct(List(buildExpression(a), buildExpression(b)), Set()) @@ -104,9 +119,8 @@ class TreeBuilder(dictionary : Dictionary) { } case Operator(Identifier("reciprocal"), List(op)) => new Reciprocal(buildExpression(op)) case Operator(Identifier("laplacian"), List(op)) => new Laplacian(buildExpression(op)) - case Operator(Identifier("fftbox"), List(op)) => new SpatialRestriction(buildExpression(op)) case Operator(Identifier(name), _) => throw new UnimplementedException("Unknown or incorrectly called operator: "+name) -*/ + */ } } } diff --git a/src/ofc/parser/Parser.scala b/src/ofc/parser/Parser.scala index 2a7b2bd..8e9cf7b 100644 --- a/src/ofc/parser/Parser.scala +++ b/src/ofc/parser/Parser.scala @@ -41,8 +41,8 @@ class Parser extends JavaTokenParsers { def term: Parser[Expression] = scalarConstant ||| indexedIdentifier ||| operator def scalarConstant : Parser[ScalarConstant] = floatingPointNumber ^^ (x => new ScalarConstant(x.toDouble)) - def indexedIdentifier: Parser[IndexedTerm] = identifier~opt("["~>repsep(identifier, ",")<~"]") ^^ - (x => new IndexedTerm(x._1, x._2 match { + def indexedIdentifier: Parser[IndexedIdentifier] = identifier~opt("["~>repsep(identifier, ",")<~"]") ^^ + (x => new IndexedIdentifier(x._1, x._2 match { case Some(list) => list case None => Nil })) diff --git a/src/ofc/parser/Statement.scala b/src/ofc/parser/Statement.scala index d794c1f..d4a735f 100644 --- a/src/ofc/parser/Statement.scala +++ b/src/ofc/parser/Statement.scala @@ -11,7 +11,7 @@ class Comment(value: String) extends Statement { case class DeclarationList(oflType: OFLType, names: List[Identifier]) extends Statement { override def toString : String = "decl("+oflType+", "+names+")" } -case class Definition(term: IndexedTerm, expr: Expression) extends Statement { +case class Definition(term: IndexedIdentifier, expr: Expression) extends Statement { override def toString : String = "define("+term+", "+expr+")" } case class Target(name: Identifier) extends Statement { @@ -36,7 +36,7 @@ sealed abstract class Expression case class ScalarConstant(s: Double) extends Expression { override def toString : String = s.toString } -case class IndexedTerm(id: Identifier, indices : List[Identifier]) extends Expression { +case class IndexedIdentifier(id: Identifier, indices : List[Identifier]) extends Expression { override def toString : String = id+indices.mkString("[", ", ", "]") } case class Operator(id: Identifier, operands : List[Expression]) extends Expression { -- 2.47.3