]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Generate loop headers and footers.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Fri, 3 Feb 2012 08:41:45 +0000 (08:41 +0000)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Fri, 3 Feb 2012 08:41:45 +0000 (08:41 +0000)
src/ofc/generators/onetep/CodeGenerator.scala
src/ofc/generators/onetep/LoopTree.scala
src/ofc/generators/onetep/Tree.scala

index 26bb45903f0535138ce48906986eb1ff365af2b2..807061edffb9c95fe4c76eaacbb51ca28c78b177 100644 (file)
@@ -56,7 +56,9 @@ class CodeGenerator {
       println(i)
     println("finished dumping indices")
 
-    val loopTree = LoopTree(space)
+    val loopNest = LoopNest(space)
+    println(loopNest.getTree)
+    println(loopNest.generateCode)
     // Next: we dump all these things into a prefix map
     System.exit(0)
 
index c711a5ab675101fd2fea786b75d5430758d493b8..dcd4bebdf2c2ce811a6a4a74e1cb264816972727 100644 (file)
@@ -6,24 +6,84 @@ import scala.collection.mutable.ArrayBuffer
 Stores the configuration of indices we will use for code generation.
 */
 
-object LoopTree {
-  def apply(root: IterationSpace) = {
-    val base = new LoopTree(None)
+object LoopNest {
+  def apply(root: IterationSpace) : LoopNest = {
+    val nest = new LoopNest
     val sortedSpaces = IterationSpace.flattenPostorder(root)
     val sortedIndices = Index.sort(sortedSpaces flatMap (_.getIndices))
 
     for(space <- sortedSpaces) {
       val indices = space.getIndices
       val localSortedIndices = sortedIndices filter (indices.contains(_))
-      base.addIterationSpace(localSortedIndices, space)
+      nest.addIterationSpace(localSortedIndices, space)
+    }
+
+    nest
+  }
+}
+
+class LoopNest {
+  val base = new LoopTree(None)
+  val nameManager = new NameManager()
+  var declarations = Set[String]()
+
+  def addIterationSpace(indices: List[Index], space: IterationSpace) = base.addIterationSpace(indices, space)
+  def getTree = base
+
+  def generateCode : String = {
+    val code = new StringBuilder()
+    declarations = base.collectDeclarations(nameManager)
+    code append declarations.mkString("\n")
+
+    val generationVisitor = new GenerationVisitor
+    base.accept(generationVisitor)
+    code append generationVisitor.getCode
+
+    code.mkString
+  }
+
+  class GenerationVisitor extends LoopTreeVisitor {
+    val code = new StringBuilder()
+
+    def enterTree(tree: LoopTree) {
+      tree.getLocalIndex match {
+        case None =>
+        case Some(index) => code append index.generateIterationHeader(nameManager)+"\n"
+      }
+    }
+
+    def exitTree(tree: LoopTree) {
+      tree.getLocalIndex match {
+        case None =>
+        case Some(index) => code append index.generateIterationFooter(nameManager)+"\n"
+      }
+    }
+
+    def visitSpace(space: IterationSpace) {
     }
 
-    println(base)
-    base
+    def getCode = code.mkString
   }
 }
 
-case class LoopTree private(localIndex: Option[Index]) {
+trait LoopTreeVisitor {
+  def enterTree(tree: LoopTree)
+  def exitTree(tree: LoopTree)
+  def visitSpace(space: IterationSpace)
+}
+
+object LoopTree {
+  def collectSpaceDeclarations(term: IterationSpace, nameManager: NameManager) : Set[String] = {
+    val declarations = for(index <- term.getIndices;
+                           declaration <- index.getDeclarations(nameManager)) yield declaration 
+
+    var declarationsSet = declarations.toSet
+    for (op <- term.getOperands) declarationsSet ++= collectSpaceDeclarations(op, nameManager)
+    declarationsSet
+  }
+}
+
+case class LoopTree private[onetep](localIndex: Option[Index]) {
   var subItems = ArrayBuffer[Either[IterationSpace, LoopTree]]()
 
   def contains(space: IterationSpace) : Boolean = {
@@ -37,7 +97,27 @@ case class LoopTree private(localIndex: Option[Index]) {
     found
   }
 
-  private def addIterationSpace(indices: List[Index], space: IterationSpace) {
+  def accept(visitor: LoopTreeVisitor) {
+    for (item <- subItems)
+      item match {
+        case Left(space) => visitor.visitSpace(space)
+        case Right(tree) => {visitor.enterTree(tree); tree.accept(visitor); visitor.exitTree(tree)}
+      }
+  }
+
+  def collectDeclarations(nameManager: NameManager) : Set[String] = {
+    val result = collection.mutable.Set[String]()
+
+    for(item <- subItems)
+      item match {
+        case Left(space) => result ++= LoopTree.collectSpaceDeclarations(space, nameManager)
+        case Right(tree) => result ++= tree.collectDeclarations(nameManager)
+      }
+
+    result.toSet
+  }
+
+  def addIterationSpace(indices: List[Index], space: IterationSpace) {
     indices match {
       case Nil => subItems += Left(space)
       case (head :: tail) => getEndLoop(head).addIterationSpace(tail, space)
@@ -60,7 +140,7 @@ case class LoopTree private(localIndex: Option[Index]) {
     }
   }
 
-  private def getLocalIndex = localIndex
+  def getLocalIndex = localIndex
 
   private def getLoopIndices(space: IterationSpace) : List[Index] = space.getIndices.toList
 
@@ -74,15 +154,18 @@ case class LoopTree private(localIndex: Option[Index]) {
   private def toStrings : List[String] = {
     val result = ArrayBuffer[String]("Index: " + localIndex)
     
-    for (entry <- subItems) {
-      val subList = (entry match {
+    for (entryID <- 0 until subItems.size) {
+      val subList = (subItems(entryID) match {
         case Left(space) => List(space.toString)
         case Right(tree) => tree.toStrings
       })
 
-      result ++= "|--"+subList.head :: (subList.tail.map("|  "+_))
+      val prefix = if (entryID < subItems.size-1) "|  " else "   "
+      result ++= "|--"+subList.head :: (subList.tail.map(prefix+_))
     }
 
     result.toList
   }
 }
+
+
index 1a624c1cb2478b61ee13cf582bb51f05de6f670e..0b7d3abf13fdf0515d5d0a940d6dc449d4663a10 100644 (file)
@@ -20,7 +20,7 @@ object Index {
 trait Index {
   def getName : String
   def getDependencies : Set[Index]
-  def getDenseWidth : String
+  def getDenseWidth(names: NameManager) : String
   def getDensePosition(names: NameManager) : String = names(this)
   def generateIterationHeader(names: NameManager) : String
   def generateIterationFooter(names: NameManager) : String
@@ -64,7 +64,6 @@ trait Matrix extends DataSpace
 trait FunctionSet extends DataSpace
 
 class Assignment(indexBindings: IndexBindings, lhs: DataSpace, rhs: IterationSpace) extends IterationSpace {
-  override def toString = indexBindings.toString
   def getIndexBindings = indexBindings
   def getOperands = List(lhs,rhs)
   def getSpatialIndices = Nil
@@ -86,8 +85,8 @@ class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[In
   class DenseSpatialIndex(parent: GeneralInnerProduct, original: SpatialIndex) extends SpatialIndex{
     def getDependencies = Set()
     def getName = "dense_spatial_index"
-    def getDenseWidth = original.getDenseWidth
-    def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth
+    def getDenseWidth(names: NameManager) = original.getDenseWidth(names)
+    def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth(names)
     def generateIterationFooter(names: NameManager) = "end do"
     def getDeclarations(names: NameManager) = List("integer :: "+names(this))
   }
@@ -95,8 +94,8 @@ class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[In
   class DenseDiscreteIndex(parent: GeneralInnerProduct, original: DiscreteIndex) extends DiscreteIndex {
     def getDependencies = Set()
     def getName = "dense_discrete_index"
-    def getDenseWidth = original.getDenseWidth
-    def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth
+    def getDenseWidth(names: NameManager) = original.getDenseWidth(names)
+    def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth(names)
     def generateIterationFooter(names: NameManager) = "end do"
     def getDeclarations(names: NameManager) = List("integer :: "+names(this))
   }
@@ -126,8 +125,8 @@ class Reciprocal(op: IterationSpace) extends IterationSpace {
   class BlockIndex(parent: Reciprocal, dimension: Int, original: SpatialIndex)  extends SpatialIndex {
     def getName = "reciprocal_index_" + dimension
     def getDependencies = Set()
-    def getDenseWidth = original.getDenseWidth
-    def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth
+    def getDenseWidth(names: NameManager) = original.getDenseWidth(names)
+    def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth(names)
     def generateIterationFooter(names: NameManager) = "end do"
     def getDeclarations(names: NameManager) = List("integer :: "+names(this))
   }
@@ -154,10 +153,10 @@ class SpatialRestriction(op: IterationSpace) extends IterationSpace {
   class RestrictedIndex(parent: SpatialRestriction, dimension: Int)  extends SpatialIndex {
     def getName = "restriction_index_" + dimension
     def getDependencies = Set()
-    def getDenseWidth = throw new UnimplementedException("Restriction unimplemnted")
+    def getDenseWidth(names: NameManager) = "pub_fftbox%total_pt"+(dimension+1)
 
-    def generateIterationHeader(names: NameManager) = throw new UnimplementedException("how the hell does this work?")
-    def generateIterationFooter(names: NameManager) = throw new UnimplementedException("how does this work either?")
+    def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth(names)
+    def generateIterationFooter(names: NameManager) = "end do"
     def getDeclarations(names: NameManager) = Nil
   }
 
@@ -172,16 +171,17 @@ class SpatialRestriction(op: IterationSpace) extends IterationSpace {
 
 class SPAM3(name : String) extends Matrix {
   override def toString = name
+  def getName = name
 
   class RowIndex(parent: SPAM3) extends DiscreteIndex {
     override def toString = parent + ".row"
     def getName = "row_index"
     def getDependencies = Set()
-    def getDenseWidth = throw new UnimplementedException("Matrix size unimplemented")
+    def getDenseWidth(names: NameManager) = "sparse_num_rows("+parent.getName+")"
 
     def generateIterationHeader(names: NameManager) = {
       val indexName = names(this)
-      "do "+indexName+"=1,"+getDenseWidth
+      "do "+indexName+"=1,"+getDenseWidth(names)
     }
 
     def generateIterationFooter(names: NameManager) = "end do"
@@ -192,11 +192,11 @@ class SPAM3(name : String) extends Matrix {
     override def toString = parent + ".col"
     def getName = "row_index"
     def getDependencies = Set()
-    def getDenseWidth = throw new UnimplementedException("Matrix size unimplemented")
+    def getDenseWidth(names: NameManager) = "sparse_num_cols("+parent.getName+")"
 
     def generateIterationHeader(names: NameManager) = {
       val indexName = names(this)
-      "do "+indexName+"=1,"+getDenseWidth
+      "do "+indexName+"=1,"+getDenseWidth(names)
     }
 
 
@@ -217,11 +217,11 @@ class PPDFunctionSet(val basis : String, data : String) extends FunctionSet {
   class SphereIndex(parent: PPDFunctionSet) extends DiscreteIndex {
     def getName = "sphere_index"
     def getDependencies = Set()
-    def getDenseWidth = throw new UnimplementedException("Sphere count unimplemented")
+    def getDenseWidth(names: NameManager) = parent.getNumSpheres(names)
 
     def generateIterationHeader(names: NameManager) = {
       val indexName = names(this)
-      "do "+indexName+"=1,"+getDenseWidth
+      "do "+indexName+"=1," + parent.getSphere(names) + "%n_ppds_sphere"
     }
 
     def generateIterationFooter(names: NameManager) = "end do"
@@ -236,7 +236,7 @@ class PPDFunctionSet(val basis : String, data : String) extends FunctionSet {
     def getDensePPDIndices = denseIndexNames
 
     //TODO: def getDenseWidth = parent.getSphereIndex.getSphere + "%n_ppds_sphere"
-    def getDenseWidth = parent.basis+"%max_n_ppds_sphere"
+    def getDenseWidth(names: NameManager) = parent.basis+"%max_n_ppds_sphere"
 
     def generateIterationHeader(names: NameManager) = {
       
@@ -261,7 +261,7 @@ class PPDFunctionSet(val basis : String, data : String) extends FunctionSet {
   class IntraPPDIndex(parent: PPDFunctionSet, dimension: Int) extends SpatialIndex {
     def getName = "intra_ppd_index_" + dimension
     def getDependencies = Set[Index](parent.getPPDIndex)
-    def getDenseWidth = "pub_cell%total_pt"+(dimension+1)
+    def getDenseWidth(names: NameManager) = "pub_cell%total_pt"+(dimension+1)
     def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+"pub_cell%n_pt"+(dimension+1)
     def generateIterationFooter(names: NameManager) = "end do"
     def getDeclarations(names: NameManager) = List("integer :: "+names(this))
@@ -278,6 +278,11 @@ class PPDFunctionSet(val basis : String, data : String) extends FunctionSet {
   def getSphereIndex = sphereIndex
   def getSphere(names: NameManager) = basis + "%spheres("+names(getSphereIndex)+")"
 
+  def getNumSpheres(names: NameManager) = {
+    // TODO: This number is dependent on the parallel distribution
+    basis + "%node_num"
+  }
+
   def getSpatialIndices = spatialIndices.toList
   def getDiscreteIndices = List(getSphereIndex)
   def getExternalIndices = Set(getPPDIndex)