]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
More work on iteration generation.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Wed, 25 Jan 2012 19:19:44 +0000 (19:19 +0000)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Wed, 25 Jan 2012 19:19:44 +0000 (19:19 +0000)
src/ofc/generators/onetep/CodeGenerator.scala
src/ofc/generators/onetep/Tree.scala

index da3dae7e4d8e3eca95776e942b693f7adc21ff0d..a62876162d4b2614b8f761d8b3ce779fefdfe26a 100644 (file)
@@ -1,7 +1,7 @@
 package ofc.generators.onetep
 import scala.collection.mutable.HashMap
 
-class IndexNames {
+class NameManager {
   var nextIndexID = 0
   val names = new HashMap[Index, String]()
 
@@ -17,22 +17,21 @@ class IndexNames {
      names(index)
    else
      addIndex(index)
+
+  def newIdentifier(prefix: String) = {
+    val name = prefix + "_" + nextIndexID
+    nextIndexID +=1 
+    name
+  }
 }
 
 class CodeGenerator {
   val code = new StringBuilder()
-  val indexNames = new IndexNames()
-  var nextVarIndex = 0
+  val nameManager = new NameManager()
   
-  def newLocalVar : String = {
-    val name = "var_" + nextVarIndex
-    nextVarIndex += 1
-    name
-  }
-
   def collectDeclarations(term: IterationSpace) : Set[String] = {
     val declarations = for(index <- term.getSpatialIndices ++ term.getDiscreteIndices; 
-                           declaration <- index.getDeclarations(indexNames)) yield declaration 
+                           declaration <- index.getDeclarations(nameManager)) yield declaration 
 
     var declarationsSet = declarations.toSet
     for (op <- term.getOperands) declarationsSet ++= collectDeclarations(op)
@@ -40,7 +39,8 @@ class CodeGenerator {
   }
 
   def apply(assignment: Assignment) {
-    collectDeclarations(assignment)
+    val declarations = collectDeclarations(assignment)
+    for(declaration <- declarations) code append declaration+"\n"
     generateCode(assignment)
   }
 
@@ -53,30 +53,40 @@ class CodeGenerator {
     val lowerIndices = operands flatMap (x => x.getDiscreteIndices ++ x.getSpatialIndices) toSet
     val upperIndices = space.getDiscreteIndices ++ space.getSpatialIndices toSet
 
-    val createdIndices = upperIndices -- lowerIndices
     val destroyedIndices = lowerIndices -- upperIndices
-
-    println("created: "+createdIndices.mkString(","))
     println("destroyed: "+destroyedIndices.mkString(","))
 
-    if (!destroyedIndices.isEmpty) {
-      // We search for all indices bound to the one being destroyed
-      // We generate a composite iteration over those loops
-      // If GeneralInnerProduct rebuilds derived indices, we need to be able to construct a valid size
-      val concreteIndexList = destroyedIndices.toList
-      val storageName = newLocalVar
-      code append "real(kind=DP), allocatable, dimension" + (":"*concreteIndexList.size).mkString("(",",",")") + " :: " + 
-        storageName + "\n"
-      code append "allocate("+ storageName +
-        (concreteIndexList map ((x : Index) => x.getDenseWidth)).mkString("(",",",")") + ", stat=ierr)\n"
-
-      // We've declared temporary storage, now create the loops to populate it
-      for (index <- concreteIndexList) code append index.generateIterationHeader(indexNames) + "\n"
-
-      println(code.mkString)
-      System.exit(0)
+    for (op <- operands) {
+      val opDestroyedIndices = (op.getSpatialIndices ++ op.getDiscreteIndices).toSet & destroyedIndices
+
+      if (!opDestroyedIndices.isEmpty) {
+        // We search for all indices bound to the one being destroyed
+        // We generate a composite iteration over those loops
+        // If GeneralInnerProduct rebuilds derived indices, we need to be able to construct a valid size
+        val concreteIndexList = opDestroyedIndices.toList
+        val storageName = nameManager.newIdentifier("dense")
+        code append "real(kind=DP), allocatable, dimension" + (":"*concreteIndexList.size).mkString("(",", &\n",")") + " :: " + 
+          storageName + "\n"
+        code append "allocate("+ storageName +
+          (concreteIndexList map ((x : Index) => x.getDenseWidth)).mkString("(",",",")") + ", stat=ierr)\n"
+  
+        // We've declared temporary storage, now create the loops to populate it
+        for (index <- concreteIndexList) code append index.generateIterationHeader(nameManager) + "\n"
+        val lhs = storageName + (concreteIndexList map ((x: Index) => x.getDensePosition(nameManager))).mkString("(",", &\n",")")
+        val rhs = op.getAccessExpression(nameManager)
+        code append lhs + " = &\n" + rhs + "\n"
+        for (index <- concreteIndexList) code append index.generateIterationFooter(nameManager) + "\n"
+  
+        println(code.mkString)
+        System.exit(0)
+      }
     }
 
-    // When an index is destroyed -> generate a possibly composite loop over the index
+    val createdIndices = upperIndices -- lowerIndices
+    println("created: "+createdIndices.mkString(","))
+
+  
+    // We've now moved al necessary destroyed indices into dense buffers
+    // We now generate the actual loop for space. This may involve a composite iteration construction
   }
 }
index 15792fc91b9aab3bb5ec26ff9d20ca6971b3cb5d..1d5b22e1bff07cde722649e457cebded7683c86f 100644 (file)
@@ -8,15 +8,17 @@ trait Index {
   def getName : String
   def getDependencies : Set[Index]
   def getDenseWidth : String
-  def generateIterationHeader(names: IndexNames) : String
-  def getDeclarations(names: IndexNames) : List[String]
+  def getDensePosition(names: NameManager) : String = names(this)
+  def generateIterationHeader(names: NameManager) : String
+  def generateIterationFooter(names: NameManager) : String
+  def getDeclarations(names: NameManager) : List[String]
 }
 
 trait SpatialIndex extends Index
 trait DiscreteIndex extends Index
 
 trait IterationSpace {
-  def getAccessExpression(indexNames: IndexNames) : String
+  def getAccessExpression(indexNames: NameManager) : String
   def getOperands() : List[IterationSpace]
   def getSpatialIndices() : List[SpatialIndex]
   def getDiscreteIndices() : List[DiscreteIndex]
@@ -35,14 +37,14 @@ class Assignment(indexBindings: IndexBindings, lhs: DataSpace, rhs: IterationSpa
   def getOperands = List(lhs,rhs)
   def getSpatialIndices = Nil
   def getDiscreteIndices = Nil
-  def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed")
+  def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
 }
 
 class Scalar(value: Double) extends IterationSpace {
   def getOperands() = Nil
   def getSpatialIndices() = Nil
   def getDiscreteIndices() = Nil
-  def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed")
+  def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
 }
 
 class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[Index]) extends IterationSpace {
@@ -51,16 +53,18 @@ class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[In
     def getDependencies = Set()
     def getName = "dense_spatial_index"
     def getDenseWidth = original.getDenseWidth
-    def generateIterationHeader(names: IndexNames) = "do "+names(this)+"=1,"+getDenseWidth
-    def getDeclarations(names: IndexNames) = List("integer :: "+names(this))
+    def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth
+    def generateIterationFooter(names: NameManager) = "end do"
+    def getDeclarations(names: NameManager) = List("integer :: "+names(this))
   }
 
   class DenseDiscreteIndex(parent: GeneralInnerProduct, original: DiscreteIndex) extends DiscreteIndex {
     def getDependencies = Set()
     def getName = "dense_discrete_index"
     def getDenseWidth = original.getDenseWidth
-    def generateIterationHeader(names: IndexNames) = "do "+names(this)+"=1,"+getDenseWidth
-    def getDeclarations(names: IndexNames) = List("integer :: "+names(this))
+    def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth
+    def generateIterationFooter(names: NameManager) = "end do"
+    def getDeclarations(names: NameManager) = List("integer :: "+names(this))
   }
 
   val spatialIndices =
@@ -80,7 +84,7 @@ class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[In
   def getOperands = operands
   def getSpatialIndices() = spatialIndices
   def getDiscreteIndices() = discreteIndices
-  def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed")
+  def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
 }
 
 class Reciprocal(op: IterationSpace) extends IterationSpace {
@@ -88,8 +92,9 @@ class Reciprocal(op: IterationSpace) extends IterationSpace {
     def getName = "reciprocal_index_" + dimension
     def getDependencies = Set()
     def getDenseWidth = original.getDenseWidth
-    def generateIterationHeader(names: IndexNames) = "do "+names(this)+"=1,"+getDenseWidth
-    def getDeclarations(names: IndexNames) = List("integer :: "+names(this))
+    def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth
+    def generateIterationFooter(names: NameManager) = "end do"
+    def getDeclarations(names: NameManager) = List("integer :: "+names(this))
   }
 
   val spatialIndices = for (dimension <- 0 until op.getSpatialIndices.size) yield 
@@ -98,14 +103,14 @@ class Reciprocal(op: IterationSpace) extends IterationSpace {
   def getOperands = List(op)
   def getSpatialIndices() = spatialIndices.toList
   def getDiscreteIndices() = op.getDiscreteIndices
-  def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed")
+  def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
 }
 
 class Laplacian(op: IterationSpace) extends IterationSpace {
   def getOperands() = List(op)
   def getSpatialIndices() = op.getSpatialIndices
   def getDiscreteIndices() = op.getDiscreteIndices
-  def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed")
+  def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
 }
 
 class SpatialRestriction(op: IterationSpace) extends IterationSpace {
@@ -114,8 +119,9 @@ class SpatialRestriction(op: IterationSpace) extends IterationSpace {
     def getDependencies = Set()
     def getDenseWidth = throw new UnimplementedException("Restriction unimplemnted")
 
-    def generateIterationHeader(names: IndexNames) = throw new UnimplementedException("how the hell does this work?")
-    def getDeclarations(names: IndexNames) = Nil
+    def generateIterationHeader(names: NameManager) = throw new UnimplementedException("how the hell does this work?")
+    def generateIterationFooter(names: NameManager) = throw new UnimplementedException("how does this work either?")
+    def getDeclarations(names: NameManager) = Nil
   }
 
   val spatialIndices = for (dimension <- 0 until op.getSpatialIndices.size) yield new RestrictedIndex(this, dimension)
@@ -123,7 +129,7 @@ class SpatialRestriction(op: IterationSpace) extends IterationSpace {
   def getOperands() = List(op)
   def getSpatialIndices() = spatialIndices.toList
   def getDiscreteIndices() = op.getDiscreteIndices
-  def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed")
+  def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
 }
 
 class SPAM3(name : String) extends Matrix {
@@ -135,12 +141,13 @@ class SPAM3(name : String) extends Matrix {
     def getDependencies = Set()
     def getDenseWidth = throw new UnimplementedException("Matrix size unimplemented")
 
-    def generateIterationHeader(names: IndexNames) = {
+    def generateIterationHeader(names: NameManager) = {
       val indexName = names(this)
       "do "+indexName+"=1,"+getDenseWidth
     }
 
-    def getDeclarations(names: IndexNames) = List("integer :: "+names(this))
+    def generateIterationFooter(names: NameManager) = "end do"
+    def getDeclarations(names: NameManager) = List("integer :: "+names(this))
   }
 
   class ColIndex(parent: SPAM3) extends DiscreteIndex {
@@ -149,12 +156,14 @@ class SPAM3(name : String) extends Matrix {
     def getDependencies = Set()
     def getDenseWidth = throw new UnimplementedException("Matrix size unimplemented")
 
-    def generateIterationHeader(names: IndexNames) = {
+    def generateIterationHeader(names: NameManager) = {
       val indexName = names(this)
       "do "+indexName+"=1,"+getDenseWidth
     }
 
-    def getDeclarations(names: IndexNames) = List("integer :: "+names(this))
+
+    def generateIterationFooter(names: NameManager) = "end do"
+    def getDeclarations(names: NameManager) = List("integer :: "+names(this))
   }
 
   val rowIndex = new RowIndex(this)
@@ -162,7 +171,7 @@ class SPAM3(name : String) extends Matrix {
 
   def getSpatialIndices() = Nil
   def getDiscreteIndices() = List(rowIndex, colIndex)
-  def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed")
+  def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
 }
 
 class PPDFunctionSet(val basis : String, data : String) extends FunctionSet {
@@ -170,34 +179,56 @@ class PPDFunctionSet(val basis : String, data : String) extends FunctionSet {
     def getName = "sphere_index"
     def getDependencies = Set()
     def getDenseWidth = throw new UnimplementedException("Sphere count unimplemented")
-    def getSphere(names: IndexNames) = parent.basis + "%spheres("+names(this)+")"
 
-    def generateIterationHeader(names: IndexNames) = {
+    def generateIterationHeader(names: NameManager) = {
       val indexName = names(this)
       "do "+indexName+"=1,"+getDenseWidth
     }
 
-    def getDeclarations(names: IndexNames) = List("integer :: "+names(this))
+    def generateIterationFooter(names: NameManager) = "end do"
+    def getDeclarations(names: NameManager) = List("integer :: "+names(this))
   }
 
   class PPDIndex(parent: PPDFunctionSet) extends DiscreteIndex {
+    var denseIndexNames : List[String] = Nil
+
     def getName = "ppd_index"
     def getDependencies = Set[Index](parent.getSphereIndex())
+    def getDensePPDIndices = denseIndexNames
+
     //TODO: def getDenseWidth = parent.getSphereIndex.getSphere + "%n_ppds_sphere"
     def getDenseWidth = parent.basis+"%max_n_ppds_sphere"
 
-    def generateIterationHeader(names: IndexNames) = 
-      "do "+names(this)+"=1,"+parent.getSphereIndex.getSphere(names)+"%n_ppds_sphere"
+    def generateIterationHeader(names: NameManager) = {
+      
+      val initDense = "call basis_find_ppd_in_neighbour(" + denseIndexNames.mkString(",") + ", &\n" +
+        parent.getSphere(names) + "%ppd_list(1," + names(this) + "), &\n" +
+        parent.getSphere(names) + "%ppd_list(2," + names(this) + "), &\n" +
+        "pub_cell%n_ppds_a1, pub_cell%n_ppds_a2, pub_cell%n_ppds_a3)"
+
+      val loopDeclaration = "do "+names(this)+"=1,"+parent.getSphere(names)+"%n_ppds_sphere"
+
+      initDense + "\n" + loopDeclaration
+    }
+
+    def generateIterationFooter(names: NameManager) = "end do"
 
-    def getDeclarations(names: IndexNames) = List("integer :: "+names(this))
+    def getDeclarations(names: NameManager) = {
+      denseIndexNames = (for (dim <- 0 to 2) yield names.newIdentifier("derived_ppd_position_"+dim)).toList
+      denseIndexNames.map(x => "integer :: " + x) ++ List("integer :: "+names(this))
+    }
   }
 
   class IntraPPDIndex(parent: PPDFunctionSet, dimension: Int) extends SpatialIndex {
     def getName = "intra_ppd_index_" + dimension
     def getDependencies = Set[Index](parent.getPPDIndex)
     def getDenseWidth = "pub_cell%total_pt"+(dimension+1)
-    def generateIterationHeader(names: IndexNames) = "do "+names(this)+"=1,"+"pub_cell%n_pt"+(dimension+1)
-    def getDeclarations(names: IndexNames) = List("integer :: "+names(this))
+    def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+"pub_cell%n_pt"+(dimension+1)
+    def generateIterationFooter(names: NameManager) = "end do"
+    def getDeclarations(names: NameManager) = List("integer :: "+names(this))
+
+    override def getDensePosition(names: NameManager) = 
+      parent.getPPDIndex.getDensePPDIndices(dimension) + "*pub_cell%n_pt"+(dimension+1) + " + " + names(this)
   }
 
   val ppdIndex = new PPDIndex(this)
@@ -206,10 +237,20 @@ class PPDFunctionSet(val basis : String, data : String) extends FunctionSet {
 
   def getPPDIndex() = ppdIndex
   def getSphereIndex() = sphereIndex
+  def getSphere(names: NameManager) = basis + "%spheres("+names(getSphereIndex)+")"
 
   def getSpatialIndices() = spatialIndices.toList
   def getDiscreteIndices() = List(getSphereIndex(), getPPDIndex())
-  def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed")
+
+  def getAccessExpression(indexNames: NameManager) =  {
+    val index = getSphere(indexNames)+"%offset + &\n" + 
+      "("+indexNames(getPPDIndex)+"-1)*pub_cell%n_pts - 1 + &\n" +
+      "(" + indexNames(spatialIndices(2)) + "-1)*pub_cell%n_pt2*pub_cell%n_pt1 + &\n" +
+      "(" + indexNames(spatialIndices(1)) + "-1)*pub_cell%n_pt1 + &\n" +
+      indexNames(spatialIndices(0))
+
+    data+"("+index+")"
+  }
 }
 
 class BindingIndex(name : String) {