]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Handle external indices differently.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Tue, 31 Jan 2012 20:52:56 +0000 (20:52 +0000)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Tue, 31 Jan 2012 20:52:56 +0000 (20:52 +0000)
src/ofc/generators/onetep/CodeGenerator.scala
src/ofc/generators/onetep/Tree.scala

index a62876162d4b2614b8f761d8b3ce779fefdfe26a..737c44fc0af380d5e4b6b24084536d48280d5466 100644 (file)
@@ -25,12 +25,17 @@ class NameManager {
   }
 }
 
+object CodeGenerator {
+  def getAllSpaces(term: IterationSpace) : Set[IterationSpace] =
+    term.getOperands.toSet.flatMap(getAllSpaces(_: IterationSpace)) + term
+}
+
 class CodeGenerator {
   val code = new StringBuilder()
   val nameManager = new NameManager()
   
   def collectDeclarations(term: IterationSpace) : Set[String] = {
-    val declarations = for(index <- term.getSpatialIndices ++ term.getDiscreteIndices; 
+    val declarations = for(index <- term.getIndices;
                            declaration <- index.getDeclarations(nameManager)) yield declaration 
 
     var declarationsSet = declarations.toSet
@@ -45,6 +50,20 @@ class CodeGenerator {
   }
 
   def generateCode(space: IterationSpace) {
+    val allSpaces = CodeGenerator.getAllSpaces(space)
+    val allIndices = allSpaces flatMap (_.getIndices)
+
+    println("dumping operations")
+    for(op <- allSpaces)
+      println(op)
+    println("finished dumping operations\n\ndumping indices")
+    for (i <- allIndices)
+      println(i)
+    println("finished dumping indices")
+
+    // Next: we dump all these things into a prefix map
+    System.exit(0)
+
     val operands = space.getOperands
 
     for(operand <- operands)
index 93ba521170eea3802efea087fa7f0a7f86cf7a0b..6d16e5f4dcdd0891ce39391c84ea23c67f386078 100644 (file)
@@ -22,6 +22,9 @@ trait IterationSpace {
   def getOperands : List[IterationSpace]
   def getSpatialIndices : List[SpatialIndex]
   def getDiscreteIndices : List[DiscreteIndex]
+  def getExternalIndices : Set[Index]
+  def getInternalIndices : Set[Index] = (getSpatialIndices ++ getDiscreteIndices).toSet
+  def getIndices : Set[Index] = getInternalIndices ++ getExternalIndices
 }
 
 trait DataSpace extends IterationSpace {
@@ -37,6 +40,7 @@ class Assignment(indexBindings: IndexBindings, lhs: DataSpace, rhs: IterationSpa
   def getOperands = List(lhs,rhs)
   def getSpatialIndices = Nil
   def getDiscreteIndices = Nil
+  def getExternalIndices = Set()
   def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
 }
 
@@ -44,6 +48,7 @@ class Scalar(value: Double) extends IterationSpace {
   def getOperands = Nil
   def getSpatialIndices = Nil
   def getDiscreteIndices = Nil
+  def getExternalIndices = Set()
   def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
 }
 
@@ -84,6 +89,7 @@ class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[In
   def getOperands = operands
   def getSpatialIndices = spatialIndices
   def getDiscreteIndices = discreteIndices
+  def getExternalIndices = Set()
   def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
 }
 
@@ -103,6 +109,7 @@ class Reciprocal(op: IterationSpace) extends IterationSpace {
   def getOperands = List(op)
   def getSpatialIndices = spatialIndices.toList
   def getDiscreteIndices = op.getDiscreteIndices
+  def getExternalIndices = Set()
   def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
 }
 
@@ -110,6 +117,7 @@ class Laplacian(op: IterationSpace) extends IterationSpace {
   def getOperands = List(op)
   def getSpatialIndices = op.getSpatialIndices
   def getDiscreteIndices = op.getDiscreteIndices
+  def getExternalIndices = Set()
   def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
 }
 
@@ -129,6 +137,7 @@ class SpatialRestriction(op: IterationSpace) extends IterationSpace {
   def getOperands = List(op)
   def getSpatialIndices = spatialIndices.toList
   def getDiscreteIndices = op.getDiscreteIndices
+  def getExternalIndices = Set()
   def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
 }
 
@@ -171,6 +180,7 @@ class SPAM3(name : String) extends Matrix {
 
   def getSpatialIndices = Nil
   def getDiscreteIndices = List(rowIndex, colIndex)
+  def getExternalIndices = Set()
   def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
 }
 
@@ -241,6 +251,7 @@ class PPDFunctionSet(val basis : String, data : String) extends FunctionSet {
 
   def getSpatialIndices = spatialIndices.toList
   def getDiscreteIndices = List(getSphereIndex)
+  def getExternalIndices = Set(getPPDIndex)
 
   def getAccessExpression(indexNames: NameManager) =  {
     val index = getSphere(indexNames)+"%offset + &\n" +