]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Re-enable construction of SpatialRestriction node.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Mon, 9 Apr 2012 18:51:50 +0000 (19:51 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Mon, 9 Apr 2012 18:51:50 +0000 (19:51 +0100)
The SpatialRestriction node corresponds to a FFT-box. We need to
specify where the FFT-box is to be constructed, so we have to pass an
index variable to "fftbox" as well. This is ugly, but we can just about
handle it in our current parser.

examples/test.ofl
src/ofc/generators/onetep/OnetepTypes.scala
src/ofc/generators/onetep/SpatialRestriction.scala
src/ofc/generators/onetep/TreeBuilder.scala
src/ofc/parser/Parser.scala
src/ofc/parser/Statement.scala

index a248f186f34f1d3d2e8ea7de33f3cf147a4f3741..72267b73a2216df332bf33f4e8478b5734cb4195 100644 (file)
@@ -4,7 +4,7 @@ FunctionSet ket
 Index beta
 
 # Computation
-kinet = ket[beta]
+kinet = fftbox(beta, ket[beta])
 
 # Implementation specific
 target ONETEP
index 6c19b83121a1c5880a7ad58e23478b80732672be..6e5907c7775c2e09473cd487e1aa447c03f127fd 100644 (file)
@@ -48,4 +48,9 @@ object OnetepTypes {
     val finishPPD = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("finish_ppds"+dim)}.toSeq
     def getFortranAttributes = Set("type(FUNCTION_TIGHT_BOX)")
   }
+
+  object FFTBoxInfo extends StructType {
+    val totalPts = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("total_pt"+dim)}.toSeq
+    def getFortranAttributes = Set("type(FFTBOX_INFO)")
+  }
 }
index 4f546766128e34cf4ded941c42b7799757982869..fb469c9d2a09c2259b843b76031be686b7c6302a 100644 (file)
@@ -1,6 +1,26 @@
 package ofc.generators.onetep
+import ofc.codegen._
+
+object SpatialRestriction {
+  private val pubFFTBox = new NamedUnboundVarSymbol[StructType]("pub_fftbox", OnetepTypes.FFTBoxInfo)
+}
+
+class SpatialRestriction(op: IterationSpace, function: BindingIndex) extends IterationSpace {
+  def getOperands = List(op)
+  def getDiscreteIndices = Nil
+  def getSuffixFragment = new NullStatement
+  def getDataValue = op.getDataValue
+  def getReaderFragment = {
+    //TODO: Implement me!
+    new NullStatement
+  }
+
+  def getSpatialIndices = {
+    //TODO: Implement me!
+    Nil
+  }
+
 /*
-class SpatialRestriction(op: IterationSpace) extends IterationSpace {
   class RestrictedIndex(parent: SpatialRestriction, dimension: Int)  extends SpatialIndex {
     def getName = "restriction_index_" + dimension
     def getDependencies = Set()
@@ -13,7 +33,6 @@ class SpatialRestriction(op: IterationSpace) extends IterationSpace {
 
   val spatialIndices = for (dimension <- 0 until op.getSpatialIndices.size) yield new RestrictedIndex(this, dimension)
 
-  def getOperands = List(op)
   def getSpatialIndices = spatialIndices.toList
   def getDiscreteIndices = op.getDiscreteIndices
   def getExternalIndices = Set()
@@ -25,5 +44,6 @@ class SpatialRestriction(op: IterationSpace) extends IterationSpace {
   })
   def getTransformGenerator = None
   def getProducerGenerator = None
-}
 */
+}
+
index 8e4a2708d803d6f7351ee4d8e7c0513581446351..643ab769bd4f39674d2e112355f6666322e9ec59 100644 (file)
@@ -4,7 +4,7 @@ import ofc.parser
 import ofc.parser.Identifier
 import ofc.{InvalidInputException,UnimplementedException}
 
-class BindingIndex(name : String) {
+case class BindingIndex(name : String) {
   override def toString()  = name
 }
 
@@ -47,14 +47,14 @@ class TreeBuilder(dictionary : Dictionary) {
   val indexBindings = new IndexBindings
   var nextBindingIndexID = 0
 
-  def newBindingIndex() = {
+  private def newBindingIndex() = {
     val index = new BindingIndex("synthetic_"+nextBindingIndexID)
     nextBindingIndexID += 1
     index
   }
 
-  def apply(lhs: parser.IndexedTerm, rhs: parser.Expression) = {
-    val lhsTree = buildIndexedTerm(lhs)
+  def apply(lhs: parser.IndexedIdentifier, rhs: parser.Expression) = {
+    val lhsTree = buildIndexedSpace(lhs)
     val rhsTree = buildExpression(rhs)
 
     lhsTree match {
@@ -63,12 +63,12 @@ class TreeBuilder(dictionary : Dictionary) {
     }
   }
 
-  def buildIndexedTerm(term: parser.IndexedTerm) : IterationSpace = {
+  private def buildIndexedSpace(term: parser.IndexedIdentifier) : IterationSpace = {
     val dataSpace = dictionary.getData(term.id)
     val indices = for(bindingID <- term.indices) yield dictionary.getIndex(bindingID)
 
     if (indices.size != dataSpace.getDiscreteIndices.size)
-      throw new InvalidInputException("Incorrect number of indices for object "+term.id.name);
+      throw new InvalidInputException("Incorrect number of indices for object "+term.id.name)
 
     for(i <- indices zip dataSpace.getDiscreteIndices)
       indexBindings.add(i._1, i._2)
@@ -79,12 +79,27 @@ class TreeBuilder(dictionary : Dictionary) {
     }
   }
 
-  def buildExpression(term: parser.Expression) : IterationSpace = {
+  private def buildIndex(term: parser.Expression) : BindingIndex = {
+    term match {
+      case (indexedID: parser.IndexedIdentifier) => {
+        if (indexedID.indices.nonEmpty)
+          throw new InvalidInputException("Tried to parse expression "+term+" as index but it is indexed.")
+        else
+          dictionary.getIndex(indexedID.id)
+      }
+      case other => throw new InvalidInputException("Cannot parse expression "+other+" as index.")
+    }
+  }
+
+  private def buildExpression(term: parser.Expression) : IterationSpace = {
     import parser._
 
     term match {
-      case (t: IndexedTerm) => buildIndexedTerm(t)
-/*
+      case (t: IndexedIdentifier) => buildIndexedSpace(t)
+      case Operator(Identifier("fftbox"), List(indexID, op)) => 
+        new SpatialRestriction(buildExpression(op), buildIndex(indexID))
+
+      /*
       case ScalarConstant(s) => new Scalar(s)
       case Multiplication(a, b) => 
         new GeneralInnerProduct(List(buildExpression(a), buildExpression(b)), Set())
@@ -104,9 +119,8 @@ class TreeBuilder(dictionary : Dictionary) {
       }
       case Operator(Identifier("reciprocal"), List(op)) => new Reciprocal(buildExpression(op))
       case Operator(Identifier("laplacian"), List(op)) => new Laplacian(buildExpression(op))
-      case Operator(Identifier("fftbox"), List(op)) => new SpatialRestriction(buildExpression(op))
       case Operator(Identifier(name), _) => throw new UnimplementedException("Unknown or incorrectly called operator: "+name)
-*/
+      */
     }
   }
 }
index 2a7b2bd4e81268dcce5c271ed0889d8fab956022..8e9cf7bd7e86f536e8ee132edac1d18704e031d5 100644 (file)
@@ -41,8 +41,8 @@ class Parser extends JavaTokenParsers {
 
   def term: Parser[Expression] = scalarConstant ||| indexedIdentifier ||| operator
   def scalarConstant : Parser[ScalarConstant] = floatingPointNumber ^^ (x => new ScalarConstant(x.toDouble))
-  def indexedIdentifier: Parser[IndexedTerm]  = identifier~opt("["~>repsep(identifier, ",")<~"]") ^^ 
-    (x => new IndexedTerm(x._1, x._2 match { 
+  def indexedIdentifier: Parser[IndexedIdentifier]  = identifier~opt("["~>repsep(identifier, ",")<~"]") ^^ 
+    (x => new IndexedIdentifier(x._1, x._2 match { 
       case Some(list) => list 
       case None => Nil 
     }))
index d794c1f0a80d7ac15cced390c0284db4523c4759..d4a735f9b5a86bb1eb4477cb359a2aa5cea9b5aa 100644 (file)
@@ -11,7 +11,7 @@ class Comment(value: String) extends Statement {
 case class DeclarationList(oflType: OFLType, names: List[Identifier]) extends Statement {
   override def toString : String = "decl("+oflType+", "+names+")"
 }
-case class Definition(term: IndexedTerm, expr: Expression) extends Statement {
+case class Definition(term: IndexedIdentifier, expr: Expression) extends Statement {
   override def toString : String = "define("+term+", "+expr+")"
 }
 case class Target(name: Identifier) extends Statement {
@@ -36,7 +36,7 @@ sealed abstract class Expression
 case class ScalarConstant(s: Double) extends Expression {
   override def toString : String = s.toString
 }
-case class IndexedTerm(id: Identifier, indices : List[Identifier]) extends Expression {
+case class IndexedIdentifier(id: Identifier, indices : List[Identifier]) extends Expression {
   override def toString : String = id+indices.mkString("[", ", ", "]")
 }
 case class Operator(id: Identifier, operands : List[Expression]) extends Expression {