]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Use better control-flow nesting in ProduderStatement.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Sat, 7 Apr 2012 09:05:09 +0000 (10:05 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Sat, 7 Apr 2012 09:19:48 +0000 (10:19 +0100)
When we do not have any dependencies between a loops, a predicates and
expressions, place expressions outside loops, but both loops and
expressions inside predicates.

src/ofc/codegen/Assignment.scala
src/ofc/codegen/FortranGenerator.scala
src/ofc/codegen/ProducerStatement.scala
src/ofc/codegen/ScopeStatement.scala
src/ofc/generators/onetep/PPDFunctionSet.scala

index f729988c58e626dd4b78da79995fdd25cbeab03c..e932eba90f349831c6b533a315fdd1fdb70fc592 100644 (file)
@@ -1,5 +1,7 @@
 package ofc.codegen
 
-class Assignment(symbol: VarSymbol[_ <: Type], expression: Expression[_ <: Type]) extends Statement {
+class Assignment(lhs: Expression[_ <: Type], rhs: Expression[_ <: Type]) extends Statement {
+  def getLHS : Expression[_] = lhs
+  def getRHS : Expression[_] = rhs
   // TODO: type check assignment
 }
index fdc6e89bf065e44736f72aa0f92746f039b02f3a..e852c2198c9d9a24c7fdd3841c8d1e55f816efd5 100644 (file)
@@ -31,14 +31,17 @@ class SymbolManager {
         val name = createNewName(s)
         names += name
         symbols += s -> new SymbolInfo(name)
+      } else {
+        throw new LogicError("Attempted to add symbol more than once (multiple declarations?): "+sym.toString)
       }
+
       case (_: NamedUnboundVarSymbol[_]) => throw new LogicError("Attempted to add unbound symbol to SymbolManager.")
     }
   }
 
   def getName(sym: VarSymbol[_]) =
     symbols.get(sym) match {
-      case None => throw new LogicError("Unknown symbol "+sym.toString)
+      case None => throw new LogicError("Unknown symbol (missing declaration?): "+sym.toString)
       case Some(info) => info.getName
     }
 }
@@ -55,6 +58,7 @@ class FortranGenerator {
       case (x : BlockStatement) => processScope(x)
       case (x : ProducerStatement) => processStatement(x.toConcrete)
       case (x : ForLoop) => processForLoop(x)
+      case (a : Assignment) => processAssignment(a)
       case x => throw new UnimplementedException("Unknown statement type in FORTRAN generator: " + x.toString)
     }
 
@@ -136,7 +140,6 @@ class FortranGenerator {
 
   private def processForLoop(stat: ForLoop) {
     val index = stat.getIndex
-    symbolManager.addSymbol(index)
     val name = symbolManager.getName(index)
     val begin = buildExpression(stat.getBegin)
     val end = buildExpression(stat.getEnd)
@@ -151,11 +154,18 @@ class FortranGenerator {
   }
 
   private def processScope(scope: ScopeStatement) {
+    for (sym <- scope.getDeclarations) {
+      symbolManager.addSymbol(sym)
+    }
     for(stat <- scope.getStatements) {
       processStatement(stat)
     }
   }
 
+  private def processAssignment(assignment: Assignment) {
+    addLine("%s = %s".format(buildExpression(assignment.getLHS), buildExpression(assignment.getRHS)))
+  }
+
   private def addLine(line: String) {
     buffer += "  "*indentLevel + line
   }
index c9d3acf0eb06390c6854678531fc63d11ef805ac..c734264b2ae5c90554ff3189c76c44021c15bbfb 100644 (file)
@@ -3,14 +3,30 @@ import ofc.util.Ordering
 
 class ProducerStatement extends Statement {
   object Context {
+    def preferenceOrdering(ordering: (Context, Context) => Boolean) : (Context, Context) => Boolean = {
+      // This ensures that the nesting ordering is Predicate, DerivedExpression, VariableRange
+      // when no other dependencies exist.
+      (left, right) => if (ordering(left, right)) 
+        true 
+      else if (ordering(right, left)) 
+        false
+      else (left, right) match {
+        case (_: Predicate, _: DerivedExpression) => true
+        case (_: Predicate, _: VariableRange) => true
+        case (_: DerivedExpression, _: VariableRange) => true
+        case _ => false
+      }
+    }
+
     def sort(contexts: Seq[Context]) : Seq[Context] = {
       def pathFunction(c1: Context, c2: Context) = c1.tryCompare(c2) match {
         case Some(x) if x<0 => true
         case _ => false
       }
 
-      val totalOrdering = Ordering.transitiveClosure(contexts, pathFunction(_: Context, _: Context))
-      contexts.sortWith((a,b) => totalOrdering.contains(a,b))
+      val partialOrdering = Ordering.transitiveClosure(contexts, pathFunction(_: Context, _: Context))
+      val augmentedOrdering = preferenceOrdering((a,b) => partialOrdering.contains(a,b))
+      contexts.sortWith(augmentedOrdering)
     }
   }
 
@@ -79,6 +95,7 @@ class ProducerStatement extends Statement {
       context match {
         case VariableRange(sym, first, last) => {
           val loop = new ForLoop(sym, first, last)
+          scope.addDeclaration(sym)
           scope += loop
           scope = loop
         }
@@ -89,6 +106,7 @@ class ProducerStatement extends Statement {
         }
         case DerivedExpression(sym, expression) => {
           val assignment = new Assignment(sym, expression)
+          scope.addDeclaration(sym)
           scope += assignment
         }
       }
index b44b8a4f735a08bef1f6124e9f5debb62d763e3b..3c1162cf2c7937483dc8da92240651029c1de7b1 100644 (file)
@@ -1,7 +1,7 @@
 package ofc.codegen
-import scala.collection.mutable.ArrayBuffer
 
 abstract class ScopeStatement(initialStatements: Seq[Statement] = Nil) extends Statement {
+  val declarations = scala.collection.mutable.Set[VarSymbol[_ <: Type]]()
   val statements = initialStatements.toBuffer
 
   def +=(stat: Statement) {
@@ -9,6 +9,12 @@ abstract class ScopeStatement(initialStatements: Seq[Statement] = Nil) extends S
   }
 
   def getStatements : Seq[Statement] = statements.toSeq
+
+  def addDeclaration(sym: VarSymbol[_ <: Type]) {
+    declarations += sym
+  }
+
+  def getDeclarations : Seq[VarSymbol[_ <: Type]] = declarations.toSeq
 }
 
 class BlockStatement(initialStatements: Seq[Statement] = Nil) extends ScopeStatement(initialStatements) {
index e4a8f50e01d1a1fa1502ad68b9f86e350e0d15e6..9d21ce70efde8e1ca858dca7f3bd7d3e8b6c3126 100644 (file)
@@ -21,19 +21,28 @@ class PPDFunctionSet(val basisName: String, dataName: String) extends FunctionSe
     val a3pos = ppdGlobalCount / (cellWidthInPPDs(0)*cellWidthInPPDs(1))
     val a2pos = (ppdGlobalCount % (cellWidthInPPDs(0)*cellWidthInPPDs(1)))/cellWidthInPPDs(0)
     val a1pos = ppdGlobalCount % cellWidthInPPDs(0)
-
     val ppdPos = List(a1pos, a2pos, a3pos)
 
     val tightbox = (~(basis % FieldSymbol[PointerType[ArrayType[StructType]]]("tight_boxes"))).readAt(sphereIndex)
+
+    // The offsets into the PPDs for the edges of the tightbox
     val ppdStartOffsets = for(dim <- 1 to 3) yield tightbox % FieldSymbol[IntType]("start_pts"+dim)
     val ppdFinishOffsets = for(dim <- 1 to 3) yield tightbox % FieldSymbol[IntType]("finish_pts"+dim)
 
-    val startPPDs = for(dim <- 0 to 2) yield (tightbox % FieldSymbol[IntType]("start_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim)
-    val finishPPDs = for(dim <- 0 to 2) yield (tightbox % FieldSymbol[IntType]("finish_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim)
+    // The first and last PPDs in PPD co-ordinates (inside simulation cell)
+    val startPPDs = for(dim <- 0 to 2) yield 
+      producer.addExpression("start_ppd"+(dim+1), (tightbox % FieldSymbol[IntType]("start_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim))
+    val finishPPDs = for(dim <- 0 to 2) yield 
+      producer.addExpression("finish_ppd"+(dim+1),(tightbox % FieldSymbol[IntType]("finish_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim))
+
+    // Offsets for the current PPD being iterated over
+    val loopStarts = for(dim <- 0 to 2) yield 
+      producer.addExpression("start_pt"+(dim+1), new ConditionalValue[IntType](startPPDs(dim) |==| ppdPos(dim), ppdStartOffsets(dim), 1))
 
-    val loopStarts = for(dim <- 0 to 2) yield new ConditionalValue[IntType](startPPDs(dim) |==| ppdPos(dim), ppdStartOffsets(dim), 1)
-    val loopEnds = for(dim <- 0 to 2) yield new ConditionalValue[IntType](finishPPDs(dim) |==| ppdPos(dim), ppdFinishOffsets(dim), ppdWidths(dim))
+    val loopEnds = for(dim <- 0 to 2) yield 
+      producer.addExpression("end_pt"+(dim+1), new ConditionalValue[IntType](finishPPDs(dim) |==| ppdPos(dim), ppdFinishOffsets(dim), ppdWidths(dim)))
 
+    // Loops for iterating over the PPD itself
     val ppdIndices = for(dim <- 0 to 2) yield producer.addIteration("point"+(dim+1), loopStarts(dim), loopEnds(dim)) 
     
     producer