]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Ignore composition issues for now.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Sun, 15 Apr 2012 17:27:00 +0000 (18:27 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Sun, 15 Apr 2012 17:27:00 +0000 (18:27 +0100)
It's unclear how to correctly compose the code fragments generated by
the various operands. Instead, let's just have everything create
ProducerStatements and merge them.

src/ofc/codegen/ProducerStatement.scala
src/ofc/generators/onetep/Assignment.scala
src/ofc/generators/onetep/CodeGenerator.scala
src/ofc/generators/onetep/IterationSpace.scala
src/ofc/generators/onetep/IterationSpaceIndexBinding.scala
src/ofc/generators/onetep/PPDFunctionSet.scala
src/ofc/generators/onetep/SPAM3.scala
src/ofc/generators/onetep/SpatialRestriction.scala
src/ofc/generators/onetep/TreeBuilder.scala

index b651c967e794962ff5fbc293cb1c3ecf04e82e47..9bd9bb472f36fd24faa238a265314df1b52ac1b1 100644 (file)
@@ -2,7 +2,7 @@ package ofc.codegen
 import ofc.LogicError
 import ofc.util.DirectedGraph
 
-class ProducerStatement extends Statement {
+object ProducerStatement {
   object Context {
     private def priority(context: Context) : Int = {
       // This ensures that the nesting ordering is Predicate, DerivedExpression, VariableRange
@@ -78,6 +78,10 @@ class ProducerStatement extends Statement {
     def defines = Set(symbol)
     def depends = Expression.findReferencedVariables(expression)
   }
+}
+
+class ProducerStatement extends Statement {
+  import ProducerStatement._
 
   var statement = new Comment("Placeholder statement for consumer.")
   var ranges : Seq[VariableRange] = Nil
@@ -100,6 +104,12 @@ class ProducerStatement extends Statement {
     predicates +:= new Predicate(condition)
   }
 
+  def merge(statement: ProducerStatement) {
+    ranges ++= statement.ranges
+    predicates ++= statement.predicates
+    expressions ++= statement.expressions
+  }
+
   def toConcrete : Statement = {
     val contexts = ranges ++ predicates ++ expressions
     val sortedContexts = Context.sort(contexts)
index b45c879ae4f926075989ff8a9fc0ae094daa8f04..90d432b6947fa4ed1e6b202787d4dd32520f1149 100644 (file)
@@ -1,11 +1,13 @@
 package ofc.generators.onetep
-import ofc.codegen.{NullStatement,FloatLiteral}
+import ofc.codegen.{ProducerStatement,NullStatement,FloatLiteral}
 
 class Assignment(indexBindings: IndexBindings, lhs: DataSpace, rhs: IterationSpace) extends IterationSpace {
-  def getOperands = List(lhs, rhs)
+  // TODO: Implement assignment
+  def getOperands = List(rhs)
   def getSpatialIndices = Nil
   def getDiscreteIndices = Nil
-  def getReaderFragment = new NullStatement
-  def getSuffixFragment = new NullStatement
+  def getPrefixFragment = new ProducerStatement
+  def getSuffixFragment = new ProducerStatement
+  def getBodyFragment = new ProducerStatement
   def getDataValue = new FloatLiteral(0.0)
 }
index 262ac1750d5c7f6a30d1e40eb062e0fb0f3ee46b..17c15a305d1e8550d658d7552a8e5597bbb914fe 100644 (file)
@@ -50,6 +50,22 @@ class CodeGenerator {
   }
   */
 
+  private def buildStatement(space: IterationSpace) : ProducerStatement  = {
+    // TODO: Until we can handle multi-operand nodes
+    assert(space.getOperands.size < 2)
+
+    val result = new ProducerStatement
+    result.merge(space.getBodyFragment)
+    result.merge(space.getPrefixFragment)
+    result.merge(space.getSuffixFragment)
+
+    for(operand <- space.getOperands) {
+      val opStatement = buildStatement(operand)
+      result.merge(opStatement)
+    }
+    result
+  }
+
   def generateCode(space: IterationSpace) {
     val allSpaces = IterationSpace.flattenPostorder(space)
     val allIndices = allSpaces flatMap (_.getIndices)
@@ -62,14 +78,9 @@ class CodeGenerator {
       println(i)
     println("")
 
-    val statements = new BlockStatement
-    for (op <- IterationSpace.sort(allSpaces)) {
-      statements += op.getReaderFragment
-      statements += op.getSuffixFragment
-    }
-
+    val statement = buildStatement(space)
     val fortranGenerator = new FortranGenerator
-    val code = fortranGenerator(statements)
+    val code = fortranGenerator(statement)
     println(code)
   }
 }
index ff5de9cdeb89805f1950cc43ac98c8e8b9a4d016..1770e75ee0de0b9a57ec8e8ff36ca56f1da73a84 100644 (file)
@@ -1,5 +1,5 @@
 package ofc.generators.onetep
-import ofc.codegen.{Statement,NullStatement,Expression,FloatType}
+import ofc.codegen.{Statement,ProducerStatement,NullStatement,Expression,FloatType}
 
 object IterationSpace {
   def sort(spaces : Traversable[IterationSpace]) : Seq[IterationSpace] = {
@@ -27,13 +27,16 @@ trait IterationSpace {
     val operands = getOperands
     operands.toSet ++ operands.flatMap(_.getDependencies)
   }
-  def getReaderFragment : Statement
-  def getSuffixFragment : Statement
+
+  def getBodyFragment : ProducerStatement
+  def getPrefixFragment : ProducerStatement
+  def getSuffixFragment : ProducerStatement
 }
 
 trait DataSpace extends IterationSpace {
   def getOperands = Nil
-  def getReaderFragment = new NullStatement
+  def getPrefixFragment = new ProducerStatement
+  def getBodyFragment = new ProducerStatement
 }
 
 trait Matrix extends DataSpace
index c4a58575f712ad521f210cba0ad577eedf0b3907..94e11a15543c2fe032632dbdea5d0293ae687091 100644 (file)
@@ -6,6 +6,7 @@ class IterationSpaceIndexBinding(operand: IterationSpace) extends IterationSpace
   def getSpatialIndices = operand.getSpatialIndices
   def getDiscreteIndices = Nil
   def getDataValue = operand.getDataValue
-  def getReaderFragment = operand.getReaderFragment
+  def getBodyFragment = operand.getBodyFragment
   def getSuffixFragment = operand.getSuffixFragment
+  def getPrefixFragment = operand.getPrefixFragment
 }
index 6c4051460fe52923f5da353e80d50de9fd9cb6e5..1ecd583dd011b6c83e3caa74c99a628e95c040bc 100644 (file)
@@ -98,7 +98,7 @@ object PPDFunctionSet {
 
 class PPDFunctionSet private(discreteIndices: Seq[DiscreteIndex], 
   spatialIndices: Seq[SpatialIndex], data: Expression[FloatType], 
-  producer: Statement) extends FunctionSet {
+  producer: ProducerStatement) extends FunctionSet {
 
   def getSuffixFragment = producer
   def getDiscreteIndices = discreteIndices
index 6320d6a8b8bd1bfa58e4a45a49ec3de69d3ffd9c..c22b35324e89105d912aad57e556b74307394073 100644 (file)
@@ -1,5 +1,5 @@
 package ofc.generators.onetep
-import ofc.codegen.{NullStatement,Comment, FloatLiteral}
+import ofc.codegen.{ProducerStatement,NullStatement,Comment, FloatLiteral}
 
 class SPAM3(name : String) extends Matrix {
   override def toString = name
@@ -8,5 +8,5 @@ class SPAM3(name : String) extends Matrix {
   def getSpatialIndices = Nil
   def getDiscreteIndices = Nil
   def getDataValue = new FloatLiteral(0.0)
-  def getSuffixFragment = new Comment("Suffix of "+toString+".")
+  def getSuffixFragment = new ProducerStatement
 }
index 98a7c38491c831e4f6377552caf00c900fa03612..65e74f49537d3b1e0666fc50dffec594d3a9d1a7 100644 (file)
@@ -7,16 +7,14 @@ object SpatialRestriction {
 
   private val pubFFTBoxWidth = for (dim <- 0 to 2) yield FFTBoxInfo.public % FFTBoxInfo.totalPts(dim)
   private val ppdWidth = for(dim <- 0 to 2) yield CellInfo.public % CellInfo.ppdWidth(dim)
-}
 
-class SpatialRestriction(op: IterationSpace) extends IterationSpace {
-  def getOperands = List(op)
-  def getDiscreteIndices = Nil
-  def getSuffixFragment = new NullStatement
-  def getDataValue = op.getDataValue
-  def getReaderFragment = {
+  private class RestrictionIndex(name: String, value: Expression[IntType]) extends SpatialIndex {
+    def getName = name
+    def getValue = value
+  }
+
+  def apply(op: IterationSpace) : SpatialRestriction = {
     import OnetepTypes._
-    import SpatialRestriction._
 
     val inputIndices = for(index <- op.getSpatialIndices) yield
       index match {
@@ -31,7 +29,7 @@ class SpatialRestriction(op: IterationSpace) extends IterationSpace {
       
     val producer = new ProducerStatement
     val origin = for (dim <- 0 to 2) 
-      yield producer.addExpression("fftbox_origin_pt"+(dim+1), functionCentre(dim) - pubFFTBoxWidth(dim)/2)
+      yield producer.addExpression("fftbox_origin_pt"+(dim+1), (cellWidthPts(dim) + functionCentre(dim) - pubFFTBoxWidth(dim)/2) % cellWidthPts(dim))
 
     val offset = for (dim <- 0 to 2)
       yield producer.addExpression("fftbox_offset_pt"+(dim+1), 
@@ -40,38 +38,19 @@ class SpatialRestriction(op: IterationSpace) extends IterationSpace {
     for (dim <- 0 to 2)
       producer.addPredicate(offset(dim) |<| pubFFTBoxWidth(dim))
 
-    producer
-  }
-
-  def getSpatialIndices = {
-    //TODO: Implement me!
-    Nil
+    val indices = for(dim <- 0 to 2) yield new RestrictionIndex("restriction_pos"+(dim+1), offset(dim))
+    new SpatialRestriction(op, indices, producer)
   }
+}
 
-/*
-  class RestrictedIndex(parent: SpatialRestriction, dimension: Int)  extends SpatialIndex {
-    def getName = "restriction_index_" + dimension
-    def getDependencies = Set()
-    def getDenseWidth(names: NameManager) = "pub_fftbox%total_pt"+(dimension+1)
-
-    def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth(names)
-    def generateIterationFooter(names: NameManager) = "end do"
-    def getDeclarations(names: NameManager) = Nil
-  }
-
-  val spatialIndices = for (dimension <- 0 until op.getSpatialIndices.size) yield new RestrictedIndex(this, dimension)
-
-  def getSpatialIndices = spatialIndices.toList
-  def getDiscreteIndices = op.getDiscreteIndices
-  def getExternalIndices = Set()
-
-  def getConsumerGenerator = Some(new ConsumerGenerator {
-    def generate(names: NameManager, indices: Map[Index,String], values : Map[IterationSpace, String]) : String = {
-      "!SpatialRestriction consumer."
-    }
-  })
-  def getTransformGenerator = None
-  def getProducerGenerator = None
-*/
+class SpatialRestriction private(op: IterationSpace, 
+  spatialIndices: Seq[SpatialIndex], producer: ProducerStatement) extends IterationSpace {
+  def getOperands = List(op)
+  def getDiscreteIndices = Nil
+  def getPrefixFragment = new ProducerStatement
+  def getSuffixFragment = new ProducerStatement
+  def getDataValue = op.getDataValue
+  def getBodyFragment = producer
+  def getSpatialIndices = spatialIndices
 }
 
index 777cb6d84cb89bb01cb7a97dcfde80078b122f57..00b6a36036a029c7d3d4c4d4dd315465657d8219 100644 (file)
@@ -96,8 +96,7 @@ class TreeBuilder(dictionary : Dictionary) {
 
     term match {
       case (t: IndexedIdentifier) => buildIndexedSpace(t)
-      case Operator(Identifier("fftbox"), List(op)) => 
-        new SpatialRestriction(buildExpression(op))
+      case Operator(Identifier("fftbox"), List(op)) => SpatialRestriction(buildExpression(op))
 
       /*
       case ScalarConstant(s) => new Scalar(s)