]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Work on SpatialRestriction consumer/producer.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Tue, 10 Apr 2012 09:57:03 +0000 (10:57 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Tue, 10 Apr 2012 09:57:03 +0000 (10:57 +0100)
src/ofc/codegen/ProducerStatement.scala
src/ofc/generators/onetep/SpatialRestriction.scala

index e564ca7d16edc4e191cb8b52df417be6d13f6a77..b651c967e794962ff5fbc293cb1c3ecf04e82e47 100644 (file)
@@ -96,6 +96,10 @@ class ProducerStatement extends Statement {
     symbol
   }
 
+  def addPredicate(condition: Expression[BoolType]) {
+    predicates +:= new Predicate(condition)
+  }
+
   def toConcrete : Statement = {
     val contexts = ranges ++ predicates ++ expressions
     val sortedContexts = Context.sort(contexts)
index 27903a690f75e864f8a00bf82365cdf53ee313e8..98a7c38491c831e4f6377552caf00c900fa03612 100644 (file)
@@ -1,4 +1,5 @@
 package ofc.generators.onetep
+import ofc.InvalidInputException
 import ofc.codegen._
 
 object SpatialRestriction {
@@ -14,8 +15,32 @@ class SpatialRestriction(op: IterationSpace) extends IterationSpace {
   def getSuffixFragment = new NullStatement
   def getDataValue = op.getDataValue
   def getReaderFragment = {
-    //TODO: Implement me!
-    new NullStatement
+    import OnetepTypes._
+    import SpatialRestriction._
+
+    val inputIndices = for(index <- op.getSpatialIndices) yield
+      index match {
+        case (f: FunctionSpatialIndex) => f
+        case _ => throw new InvalidInputException("Input to SpatialRestriction must be a function")
+      }
+
+    val ppdWidths = for(dim <- 0 to 2) yield CellInfo.public % CellInfo.ppdWidth(dim)
+    val cellWidthPPDs = for(dim <- 0 to 2) yield CellInfo.public % CellInfo.numPPDs(dim)
+    val cellWidthPts = for(dim <- 0 to 2) yield cellWidthPPDs(dim) * ppdWidths(dim)
+    val functionCentre = inputIndices.map(_.getFunctionCentre)
+      
+    val producer = new ProducerStatement
+    val origin = for (dim <- 0 to 2) 
+      yield producer.addExpression("fftbox_origin_pt"+(dim+1), functionCentre(dim) - pubFFTBoxWidth(dim)/2)
+
+    val offset = for (dim <- 0 to 2)
+      yield producer.addExpression("fftbox_offset_pt"+(dim+1), 
+        (inputIndices(dim).getValue - origin(dim) + cellWidthPts(dim)) % cellWidthPts(dim))
+
+    for (dim <- 0 to 2)
+      producer.addPredicate(offset(dim) |<| pubFFTBoxWidth(dim))
+
+    producer
   }
 
   def getSpatialIndices = {