From: Francis Russell Date: Sat, 7 Apr 2012 09:05:09 +0000 (+0100) Subject: Use better control-flow nesting in ProduderStatement. X-Git-Url: https://git.unchartedbackwaters.co.uk/w/?a=commitdiff_plain;h=534bb8e4d288471bb1178da86adfa55a8bf0bacf;p=francis%2Fofc.git Use better control-flow nesting in ProduderStatement. When we do not have any dependencies between a loops, a predicates and expressions, place expressions outside loops, but both loops and expressions inside predicates. --- diff --git a/src/ofc/codegen/Assignment.scala b/src/ofc/codegen/Assignment.scala index f729988..e932eba 100644 --- a/src/ofc/codegen/Assignment.scala +++ b/src/ofc/codegen/Assignment.scala @@ -1,5 +1,7 @@ package ofc.codegen -class Assignment(symbol: VarSymbol[_ <: Type], expression: Expression[_ <: Type]) extends Statement { +class Assignment(lhs: Expression[_ <: Type], rhs: Expression[_ <: Type]) extends Statement { + def getLHS : Expression[_] = lhs + def getRHS : Expression[_] = rhs // TODO: type check assignment } diff --git a/src/ofc/codegen/FortranGenerator.scala b/src/ofc/codegen/FortranGenerator.scala index fdc6e89..e852c21 100644 --- a/src/ofc/codegen/FortranGenerator.scala +++ b/src/ofc/codegen/FortranGenerator.scala @@ -31,14 +31,17 @@ class SymbolManager { val name = createNewName(s) names += name symbols += s -> new SymbolInfo(name) + } else { + throw new LogicError("Attempted to add symbol more than once (multiple declarations?): "+sym.toString) } + case (_: NamedUnboundVarSymbol[_]) => throw new LogicError("Attempted to add unbound symbol to SymbolManager.") } } def getName(sym: VarSymbol[_]) = symbols.get(sym) match { - case None => throw new LogicError("Unknown symbol "+sym.toString) + case None => throw new LogicError("Unknown symbol (missing declaration?): "+sym.toString) case Some(info) => info.getName } } @@ -55,6 +58,7 @@ class FortranGenerator { case (x : BlockStatement) => processScope(x) case (x : ProducerStatement) => processStatement(x.toConcrete) case (x : ForLoop) => processForLoop(x) + case (a : Assignment) => processAssignment(a) case x => throw new UnimplementedException("Unknown statement type in FORTRAN generator: " + x.toString) } @@ -136,7 +140,6 @@ class FortranGenerator { private def processForLoop(stat: ForLoop) { val index = stat.getIndex - symbolManager.addSymbol(index) val name = symbolManager.getName(index) val begin = buildExpression(stat.getBegin) val end = buildExpression(stat.getEnd) @@ -151,11 +154,18 @@ class FortranGenerator { } private def processScope(scope: ScopeStatement) { + for (sym <- scope.getDeclarations) { + symbolManager.addSymbol(sym) + } for(stat <- scope.getStatements) { processStatement(stat) } } + private def processAssignment(assignment: Assignment) { + addLine("%s = %s".format(buildExpression(assignment.getLHS), buildExpression(assignment.getRHS))) + } + private def addLine(line: String) { buffer += " "*indentLevel + line } diff --git a/src/ofc/codegen/ProducerStatement.scala b/src/ofc/codegen/ProducerStatement.scala index c9d3acf..c734264 100644 --- a/src/ofc/codegen/ProducerStatement.scala +++ b/src/ofc/codegen/ProducerStatement.scala @@ -3,14 +3,30 @@ import ofc.util.Ordering class ProducerStatement extends Statement { object Context { + def preferenceOrdering(ordering: (Context, Context) => Boolean) : (Context, Context) => Boolean = { + // This ensures that the nesting ordering is Predicate, DerivedExpression, VariableRange + // when no other dependencies exist. + (left, right) => if (ordering(left, right)) + true + else if (ordering(right, left)) + false + else (left, right) match { + case (_: Predicate, _: DerivedExpression) => true + case (_: Predicate, _: VariableRange) => true + case (_: DerivedExpression, _: VariableRange) => true + case _ => false + } + } + def sort(contexts: Seq[Context]) : Seq[Context] = { def pathFunction(c1: Context, c2: Context) = c1.tryCompare(c2) match { case Some(x) if x<0 => true case _ => false } - val totalOrdering = Ordering.transitiveClosure(contexts, pathFunction(_: Context, _: Context)) - contexts.sortWith((a,b) => totalOrdering.contains(a,b)) + val partialOrdering = Ordering.transitiveClosure(contexts, pathFunction(_: Context, _: Context)) + val augmentedOrdering = preferenceOrdering((a,b) => partialOrdering.contains(a,b)) + contexts.sortWith(augmentedOrdering) } } @@ -79,6 +95,7 @@ class ProducerStatement extends Statement { context match { case VariableRange(sym, first, last) => { val loop = new ForLoop(sym, first, last) + scope.addDeclaration(sym) scope += loop scope = loop } @@ -89,6 +106,7 @@ class ProducerStatement extends Statement { } case DerivedExpression(sym, expression) => { val assignment = new Assignment(sym, expression) + scope.addDeclaration(sym) scope += assignment } } diff --git a/src/ofc/codegen/ScopeStatement.scala b/src/ofc/codegen/ScopeStatement.scala index b44b8a4..3c1162c 100644 --- a/src/ofc/codegen/ScopeStatement.scala +++ b/src/ofc/codegen/ScopeStatement.scala @@ -1,7 +1,7 @@ package ofc.codegen -import scala.collection.mutable.ArrayBuffer abstract class ScopeStatement(initialStatements: Seq[Statement] = Nil) extends Statement { + val declarations = scala.collection.mutable.Set[VarSymbol[_ <: Type]]() val statements = initialStatements.toBuffer def +=(stat: Statement) { @@ -9,6 +9,12 @@ abstract class ScopeStatement(initialStatements: Seq[Statement] = Nil) extends S } def getStatements : Seq[Statement] = statements.toSeq + + def addDeclaration(sym: VarSymbol[_ <: Type]) { + declarations += sym + } + + def getDeclarations : Seq[VarSymbol[_ <: Type]] = declarations.toSeq } class BlockStatement(initialStatements: Seq[Statement] = Nil) extends ScopeStatement(initialStatements) { diff --git a/src/ofc/generators/onetep/PPDFunctionSet.scala b/src/ofc/generators/onetep/PPDFunctionSet.scala index e4a8f50..9d21ce7 100644 --- a/src/ofc/generators/onetep/PPDFunctionSet.scala +++ b/src/ofc/generators/onetep/PPDFunctionSet.scala @@ -21,19 +21,28 @@ class PPDFunctionSet(val basisName: String, dataName: String) extends FunctionSe val a3pos = ppdGlobalCount / (cellWidthInPPDs(0)*cellWidthInPPDs(1)) val a2pos = (ppdGlobalCount % (cellWidthInPPDs(0)*cellWidthInPPDs(1)))/cellWidthInPPDs(0) val a1pos = ppdGlobalCount % cellWidthInPPDs(0) - val ppdPos = List(a1pos, a2pos, a3pos) val tightbox = (~(basis % FieldSymbol[PointerType[ArrayType[StructType]]]("tight_boxes"))).readAt(sphereIndex) + + // The offsets into the PPDs for the edges of the tightbox val ppdStartOffsets = for(dim <- 1 to 3) yield tightbox % FieldSymbol[IntType]("start_pts"+dim) val ppdFinishOffsets = for(dim <- 1 to 3) yield tightbox % FieldSymbol[IntType]("finish_pts"+dim) - val startPPDs = for(dim <- 0 to 2) yield (tightbox % FieldSymbol[IntType]("start_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim) - val finishPPDs = for(dim <- 0 to 2) yield (tightbox % FieldSymbol[IntType]("finish_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim) + // The first and last PPDs in PPD co-ordinates (inside simulation cell) + val startPPDs = for(dim <- 0 to 2) yield + producer.addExpression("start_ppd"+(dim+1), (tightbox % FieldSymbol[IntType]("start_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim)) + val finishPPDs = for(dim <- 0 to 2) yield + producer.addExpression("finish_ppd"+(dim+1),(tightbox % FieldSymbol[IntType]("finish_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim)) + + // Offsets for the current PPD being iterated over + val loopStarts = for(dim <- 0 to 2) yield + producer.addExpression("start_pt"+(dim+1), new ConditionalValue[IntType](startPPDs(dim) |==| ppdPos(dim), ppdStartOffsets(dim), 1)) - val loopStarts = for(dim <- 0 to 2) yield new ConditionalValue[IntType](startPPDs(dim) |==| ppdPos(dim), ppdStartOffsets(dim), 1) - val loopEnds = for(dim <- 0 to 2) yield new ConditionalValue[IntType](finishPPDs(dim) |==| ppdPos(dim), ppdFinishOffsets(dim), ppdWidths(dim)) + val loopEnds = for(dim <- 0 to 2) yield + producer.addExpression("end_pt"+(dim+1), new ConditionalValue[IntType](finishPPDs(dim) |==| ppdPos(dim), ppdFinishOffsets(dim), ppdWidths(dim))) + // Loops for iterating over the PPD itself val ppdIndices = for(dim <- 0 to 2) yield producer.addIteration("point"+(dim+1), loopStarts(dim), loopEnds(dim)) producer