]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Work on generating loop declarations.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Wed, 25 Jan 2012 10:41:14 +0000 (10:41 +0000)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Wed, 25 Jan 2012 10:41:14 +0000 (10:41 +0000)
src/ofc/generators/onetep/CodeGenerator.scala
src/ofc/generators/onetep/Tree.scala

index 3d2f9dbfae2eaf29bc88cd9f1073cf8dae20c5cd..da3dae7e4d8e3eca95776e942b693f7adc21ff0d 100644 (file)
@@ -5,13 +5,18 @@ class IndexNames {
   var nextIndexID = 0
   val names = new HashMap[Index, String]()
 
-  def addIndex(index: Index) {
+  def addIndex(index: Index) {
     val name = index.getName + "_" + nextIndexID
     nextIndexID += 1
     names += (index -> name)
+    name
   }
 
-  def apply(index: Index) = names(index)
+  def apply(index: Index) = 
+   if (names.contains(index))
+     names(index)
+   else
+     addIndex(index)
 }
 
 class CodeGenerator {
@@ -25,7 +30,17 @@ class CodeGenerator {
     name
   }
 
+  def collectDeclarations(term: IterationSpace) : Set[String] = {
+    val declarations = for(index <- term.getSpatialIndices ++ term.getDiscreteIndices; 
+                           declaration <- index.getDeclarations(indexNames)) yield declaration 
+
+    var declarationsSet = declarations.toSet
+    for (op <- term.getOperands) declarationsSet ++= collectDeclarations(op)
+    declarationsSet
+  }
+
   def apply(assignment: Assignment) {
+    collectDeclarations(assignment)
     generateCode(assignment)
   }
 
@@ -53,10 +68,10 @@ class CodeGenerator {
       code append "real(kind=DP), allocatable, dimension" + (":"*concreteIndexList.size).mkString("(",",",")") + " :: " + 
         storageName + "\n"
       code append "allocate("+ storageName +
-        (concreteIndexList map ((x : Index) => x.getDenseWidth)).mkString("(",",",")") + ", stat=ierr)"
+        (concreteIndexList map ((x : Index) => x.getDenseWidth)).mkString("(",",",")") + ", stat=ierr)\n"
 
       // We've declared temporary storage, now create the loops to populate it
-      //for (index <- concreteIndexList) code append index.generateIterationHeader(indexNames)
+      for (index <- concreteIndexList) code append index.generateIterationHeader(indexNames) + "\n"
 
       println(code.mkString)
       System.exit(0)
index cb758a298ef826afb0c8b70c306c7725c6684d7a..15792fc91b9aab3bb5ec26ff9d20ca6971b3cb5d 100644 (file)
@@ -8,7 +8,8 @@ trait Index {
   def getName : String
   def getDependencies : Set[Index]
   def getDenseWidth : String
-  //def generateIterationHeader(names: IndexNames) : String
+  def generateIterationHeader(names: IndexNames) : String
+  def getDeclarations(names: IndexNames) : List[String]
 }
 
 trait SpatialIndex extends Index
@@ -50,12 +51,16 @@ class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[In
     def getDependencies = Set()
     def getName = "dense_spatial_index"
     def getDenseWidth = original.getDenseWidth
+    def generateIterationHeader(names: IndexNames) = "do "+names(this)+"=1,"+getDenseWidth
+    def getDeclarations(names: IndexNames) = List("integer :: "+names(this))
   }
 
   class DenseDiscreteIndex(parent: GeneralInnerProduct, original: DiscreteIndex) extends DiscreteIndex {
     def getDependencies = Set()
     def getName = "dense_discrete_index"
     def getDenseWidth = original.getDenseWidth
+    def generateIterationHeader(names: IndexNames) = "do "+names(this)+"=1,"+getDenseWidth
+    def getDeclarations(names: IndexNames) = List("integer :: "+names(this))
   }
 
   val spatialIndices =
@@ -83,6 +88,8 @@ class Reciprocal(op: IterationSpace) extends IterationSpace {
     def getName = "reciprocal_index_" + dimension
     def getDependencies = Set()
     def getDenseWidth = original.getDenseWidth
+    def generateIterationHeader(names: IndexNames) = "do "+names(this)+"=1,"+getDenseWidth
+    def getDeclarations(names: IndexNames) = List("integer :: "+names(this))
   }
 
   val spatialIndices = for (dimension <- 0 until op.getSpatialIndices.size) yield 
@@ -106,6 +113,9 @@ class SpatialRestriction(op: IterationSpace) extends IterationSpace {
     def getName = "restriction_index_" + dimension
     def getDependencies = Set()
     def getDenseWidth = throw new UnimplementedException("Restriction unimplemnted")
+
+    def generateIterationHeader(names: IndexNames) = throw new UnimplementedException("how the hell does this work?")
+    def getDeclarations(names: IndexNames) = Nil
   }
 
   val spatialIndices = for (dimension <- 0 until op.getSpatialIndices.size) yield new RestrictedIndex(this, dimension)
@@ -124,12 +134,27 @@ class SPAM3(name : String) extends Matrix {
     def getName = "row_index"
     def getDependencies = Set()
     def getDenseWidth = throw new UnimplementedException("Matrix size unimplemented")
+
+    def generateIterationHeader(names: IndexNames) = {
+      val indexName = names(this)
+      "do "+indexName+"=1,"+getDenseWidth
+    }
+
+    def getDeclarations(names: IndexNames) = List("integer :: "+names(this))
   }
+
   class ColIndex(parent: SPAM3) extends DiscreteIndex {
     override def toString = parent + ".col"
     def getName = "row_index"
     def getDependencies = Set()
     def getDenseWidth = throw new UnimplementedException("Matrix size unimplemented")
+
+    def generateIterationHeader(names: IndexNames) = {
+      val indexName = names(this)
+      "do "+indexName+"=1,"+getDenseWidth
+    }
+
+    def getDeclarations(names: IndexNames) = List("integer :: "+names(this))
   }
 
   val rowIndex = new RowIndex(this)
@@ -145,7 +170,14 @@ class PPDFunctionSet(val basis : String, data : String) extends FunctionSet {
     def getName = "sphere_index"
     def getDependencies = Set()
     def getDenseWidth = throw new UnimplementedException("Sphere count unimplemented")
-    //TODO: def getSphere = parent.basis + "%spheres(????)"
+    def getSphere(names: IndexNames) = parent.basis + "%spheres("+names(this)+")"
+
+    def generateIterationHeader(names: IndexNames) = {
+      val indexName = names(this)
+      "do "+indexName+"=1,"+getDenseWidth
+    }
+
+    def getDeclarations(names: IndexNames) = List("integer :: "+names(this))
   }
 
   class PPDIndex(parent: PPDFunctionSet) extends DiscreteIndex {
@@ -153,12 +185,19 @@ class PPDFunctionSet(val basis : String, data : String) extends FunctionSet {
     def getDependencies = Set[Index](parent.getSphereIndex())
     //TODO: def getDenseWidth = parent.getSphereIndex.getSphere + "%n_ppds_sphere"
     def getDenseWidth = parent.basis+"%max_n_ppds_sphere"
+
+    def generateIterationHeader(names: IndexNames) = 
+      "do "+names(this)+"=1,"+parent.getSphereIndex.getSphere(names)+"%n_ppds_sphere"
+
+    def getDeclarations(names: IndexNames) = List("integer :: "+names(this))
   }
 
   class IntraPPDIndex(parent: PPDFunctionSet, dimension: Int) extends SpatialIndex {
     def getName = "intra_ppd_index_" + dimension
     def getDependencies = Set[Index](parent.getPPDIndex)
     def getDenseWidth = "pub_cell%total_pt"+(dimension+1)
+    def generateIterationHeader(names: IndexNames) = "do "+names(this)+"=1,"+"pub_cell%n_pt"+(dimension+1)
+    def getDeclarations(names: IndexNames) = List("integer :: "+names(this))
   }
 
   val ppdIndex = new PPDIndex(this)