]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Allocate storage for operations that remove indices.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Wed, 25 Jan 2012 05:24:29 +0000 (05:24 +0000)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Wed, 25 Jan 2012 05:24:29 +0000 (05:24 +0000)
src/ofc/generators/onetep/CodeGenerator.scala
src/ofc/generators/onetep/Tree.scala

index 934090dc578f29f73ef5f7d0e1f0e6f0fa8cda08..3d2f9dbfae2eaf29bc88cd9f1073cf8dae20c5cd 100644 (file)
@@ -15,7 +15,15 @@ class IndexNames {
 }
 
 class CodeGenerator {
+  val code = new StringBuilder()
   val indexNames = new IndexNames()
+  var nextVarIndex = 0
+  
+  def newLocalVar : String = {
+    val name = "var_" + nextVarIndex
+    nextVarIndex += 1
+    name
+  }
 
   def apply(assignment: Assignment) {
     generateCode(assignment)
@@ -40,6 +48,17 @@ class CodeGenerator {
       // We search for all indices bound to the one being destroyed
       // We generate a composite iteration over those loops
       // If GeneralInnerProduct rebuilds derived indices, we need to be able to construct a valid size
+      val concreteIndexList = destroyedIndices.toList
+      val storageName = newLocalVar
+      code append "real(kind=DP), allocatable, dimension" + (":"*concreteIndexList.size).mkString("(",",",")") + " :: " + 
+        storageName + "\n"
+      code append "allocate("+ storageName +
+        (concreteIndexList map ((x : Index) => x.getDenseWidth)).mkString("(",",",")") + ", stat=ierr)"
+
+      // We've declared temporary storage, now create the loops to populate it
+      //for (index <- concreteIndexList) code append index.generateIterationHeader(indexNames)
+
+      println(code.mkString)
       System.exit(0)
     }
 
index 02c51457ca061797c477643972228df9eab944a5..cb758a298ef826afb0c8b70c306c7725c6684d7a 100644 (file)
@@ -7,7 +7,10 @@ import ofc.{InvalidInputException,UnimplementedException}
 trait Index {
   def getName : String
   def getDependencies : Set[Index]
+  def getDenseWidth : String
+  //def generateIterationHeader(names: IndexNames) : String
 }
+
 trait SpatialIndex extends Index
 trait DiscreteIndex extends Index
 
@@ -43,14 +46,16 @@ class Scalar(value: Double) extends IterationSpace {
 
 class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[Index]) extends IterationSpace {
 
-  class DenseSpatialIndex(parent: GeneralInnerProduct) extends SpatialIndex{
+  class DenseSpatialIndex(parent: GeneralInnerProduct, original: SpatialIndex) extends SpatialIndex{
     def getDependencies = Set()
     def getName = "dense_spatial_index"
+    def getDenseWidth = original.getDenseWidth
   }
 
-  class DenseDiscreteIndex(parent: GeneralInnerProduct) extends DiscreteIndex {
+  class DenseDiscreteIndex(parent: GeneralInnerProduct, original: DiscreteIndex) extends DiscreteIndex {
     def getDependencies = Set()
     def getName = "dense_discrete_index"
+    def getDenseWidth = original.getDenseWidth
   }
 
   val spatialIndices =
@@ -58,14 +63,14 @@ class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[In
       if (index.getDependencies.intersect(removedIndices).isEmpty)
         index
       else
-        new DenseSpatialIndex(this)
+        new DenseSpatialIndex(this, index)
 
   val discreteIndices =
     for(op <- operands; index <- op.getDiscreteIndices; if (!removedIndices.contains(index))) yield
       if (index.getDependencies.intersect(removedIndices).isEmpty)
         index
       else
-        new DenseDiscreteIndex(this)
+        new DenseDiscreteIndex(this, index)
 
   def getOperands = operands
   def getSpatialIndices() = spatialIndices
@@ -74,11 +79,14 @@ class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[In
 }
 
 class Reciprocal(op: IterationSpace) extends IterationSpace {
-  class BlockIndex(parent: Reciprocal, dimension: Int)  extends SpatialIndex {
+  class BlockIndex(parent: Reciprocal, dimension: Int, original: SpatialIndex)  extends SpatialIndex {
     def getName = "reciprocal_index_" + dimension
     def getDependencies = Set()
+    def getDenseWidth = original.getDenseWidth
   }
-  val spatialIndices = for (dimension <- 0 to op.getSpatialIndices.size) yield new BlockIndex(this, dimension)
+
+  val spatialIndices = for (dimension <- 0 until op.getSpatialIndices.size) yield 
+    new BlockIndex(this, dimension, op.getSpatialIndices()(dimension))
 
   def getOperands = List(op)
   def getSpatialIndices() = spatialIndices.toList
@@ -97,9 +105,10 @@ class SpatialRestriction(op: IterationSpace) extends IterationSpace {
   class RestrictedIndex(parent: SpatialRestriction, dimension: Int)  extends SpatialIndex {
     def getName = "restriction_index_" + dimension
     def getDependencies = Set()
+    def getDenseWidth = throw new UnimplementedException("Restriction unimplemnted")
   }
 
-  val spatialIndices = for (dimension <- 0 to op.getSpatialIndices.size) yield new RestrictedIndex(this, dimension)
+  val spatialIndices = for (dimension <- 0 until op.getSpatialIndices.size) yield new RestrictedIndex(this, dimension)
 
   def getOperands() = List(op)
   def getSpatialIndices() = spatialIndices.toList
@@ -114,11 +123,13 @@ class SPAM3(name : String) extends Matrix {
     override def toString = parent + ".row"
     def getName = "row_index"
     def getDependencies = Set()
+    def getDenseWidth = throw new UnimplementedException("Matrix size unimplemented")
   }
   class ColIndex(parent: SPAM3) extends DiscreteIndex {
     override def toString = parent + ".col"
     def getName = "row_index"
     def getDependencies = Set()
+    def getDenseWidth = throw new UnimplementedException("Matrix size unimplemented")
   }
 
   val rowIndex = new RowIndex(this)
@@ -129,20 +140,25 @@ class SPAM3(name : String) extends Matrix {
   def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed")
 }
 
-class PPDFunctionSet(basis : String, data : String) extends FunctionSet {
+class PPDFunctionSet(val basis : String, data : String) extends FunctionSet {
   class SphereIndex(parent: PPDFunctionSet) extends DiscreteIndex {
     def getName = "sphere_index"
     def getDependencies = Set()
+    def getDenseWidth = throw new UnimplementedException("Sphere count unimplemented")
+    //TODO: def getSphere = parent.basis + "%spheres(????)"
   }
 
   class PPDIndex(parent: PPDFunctionSet) extends DiscreteIndex {
     def getName = "ppd_index"
     def getDependencies = Set[Index](parent.getSphereIndex())
+    //TODO: def getDenseWidth = parent.getSphereIndex.getSphere + "%n_ppds_sphere"
+    def getDenseWidth = parent.basis+"%max_n_ppds_sphere"
   }
 
   class IntraPPDIndex(parent: PPDFunctionSet, dimension: Int) extends SpatialIndex {
     def getName = "intra_ppd_index_" + dimension
     def getDependencies = Set[Index](parent.getPPDIndex)
+    def getDenseWidth = "pub_cell%total_pt"+(dimension+1)
   }
 
   val ppdIndex = new PPDIndex(this)