]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Work on producer/consumer model of transforms.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Mon, 6 Feb 2012 20:59:11 +0000 (20:59 +0000)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Mon, 6 Feb 2012 20:59:11 +0000 (20:59 +0000)
src/ofc/generators/onetep/CodeGenerator.scala
src/ofc/generators/onetep/LoopTree.scala
src/ofc/generators/onetep/Tree.scala

index f83b6f3de8f6e46274331cdebd2ef75f4a810242..fd6fe9fe34822f036948e97cc784ae451c5f685e 100644 (file)
@@ -89,9 +89,9 @@ class CodeGenerator {
   
         // We've declared temporary storage, now create the loops to populate it
         for (index <- concreteIndexList) code append index.generateIterationHeader(nameManager) + "\n"
-        val lhs = storageName + (concreteIndexList map ((x: Index) => x.getDensePosition(nameManager))).mkString("(",", &\n",")")
-        val rhs = op.getAccessExpression(nameManager)
-        code append lhs + " = &\n" + rhs + "\n"
+        //val lhs = storageName + (concreteIndexList map ((x: Index) => x.getDensePosition(nameManager))).mkString("(",", &\n",")")
+        //val rhs = op.getAccessExpression(nameManager)
+        //code append lhs + " = &\n" + rhs + "\n"
         for (index <- concreteIndexList) code append index.generateIterationFooter(nameManager) + "\n"
   
         println(code.mkString)
index e4edd88e7b446989c740ddd7524baa04b9fa3f9c..6ce34cf88b5c2204faa849fa78963a6978984833 100644 (file)
@@ -9,9 +9,9 @@ Stores the configuration of indices we will use for code generation.
 
 object LoopNest {
   def apply(root: IterationSpace) : LoopNest = {
-    val nest = new LoopNest
     val sortedSpaces = IterationSpace.flattenPostorder(root)
     val sortedIndices = Index.sort(sortedSpaces flatMap (_.getIndices))
+    val nest = new LoopNest(sortedIndices.toList)
 
     for(space <- sortedSpaces) {
       val indices = space.getIndices
@@ -23,16 +23,59 @@ object LoopNest {
   }
 }
 
-class LoopNest {
+class LoopNest(sortedIndices: List[Index]) {
   val base = new LoopTree(None)
   val nameManager = new NameManager()
   var declarations = Set[String]()
+  var spaceFragmentsInfo = Map[IterationSpace, FragmentsInfo]()
+
+  class FragmentsInfo(val consumer: Option[Fragment], val transform: Option[Fragment], val producer: Option[Fragment])
 
   def addIterationSpace(indices: List[Index], space: IterationSpace) {
-    base.addIterationSpace(indices, space)
+    val operands = space.getOperands
+    val fragmentDepends = for(operand <- operands; fragment <- spaceFragmentsInfo(operand).producer) yield fragment
+
+    val consumer = space.getConsumerGenerator match {
+      case Some(consumerGenerator) => Some(new ConsumerFragment(space, fragmentDepends.toSet))
+      case None => None
+    }
+
+    val transform = space.getTransformGenerator match {
+      case Some(transformGenerator) => Some(new TransformFragment(space, consumer.toSet))
+      case None => None
+    }
+
+    val producer = space.getProducerGenerator match {
+      case Some(producerGenerator) => Some(new ProducerFragment(space, transform.toSet ++ consumer.toSet))
+      case None => None
+    }
+
+    val fragmentsInfo = new FragmentsInfo(consumer, transform, producer)
+    spaceFragmentsInfo += (space -> fragmentsInfo)
+
+    val consumerIndices = getSortedIndices((for (op <- operands; index <- op.getIndices) yield index).toSet)
+    consumer match {
+      case Some(fragment) => base.addFragment(consumerIndices, fragment)
+      case None =>
+    }
+
+    val transformIndices = getSortedIndices(consumerIndices.toSet & space.getIndices.toSet)
+    transform match {
+      case Some(fragment) => base.addFragment(transformIndices, fragment)
+      case None =>
+    }
+
+    val producerIndices = getSortedIndices(space.getIndices.toSet)
+    producer match {
+      case Some(fragment) => base.addFragment(producerIndices, fragment)
+      case None =>
+    }
+
     base.fuse()
   }
 
+  private def getSortedIndices(indices: Set[Index]) = sortedIndices filter (indices.contains(_))
+
   def getTree = base
 
   def generateCode : String = {
@@ -64,17 +107,48 @@ class LoopNest {
       }
     }
 
-    def visitSpace(space: IterationSpace) {
+    def visitFragment(fragment: Fragment) {
     }
 
     def getCode = code.mkString
   }
 }
 
+trait Fragment {
+  def getAllFragments : Set[Fragment]
+  def getDependencies : Set[Fragment]
+  def collectDeclarations(nameManager: NameManager) : Set[String]
+  def accept(visitor: LoopTreeVisitor) : Unit
+}
+
+class ConsumerFragment(parent: IterationSpace, dependencies: Set[Fragment])  extends Fragment {
+  def getAllFragments = Set(this)
+  def getDependencies = dependencies
+  def collectDeclarations(nameManager: NameManager) = Set[String]()
+  def accept(visitor: LoopTreeVisitor) = visitor.visitFragment(this)
+  override def toString = "Consumer: " + parent.toString
+}
+
+class ProducerFragment(parent: IterationSpace, dependencies: Set[Fragment])  extends Fragment {
+  def getAllFragments = Set(this)
+  def getDependencies = dependencies
+  def collectDeclarations(nameManager: NameManager) = Set[String]()
+  def accept(visitor: LoopTreeVisitor) = visitor.visitFragment(this)
+  override def toString = "Producer: " + parent.toString
+}
+
+class TransformFragment(parent: IterationSpace, dependencies: Set[Fragment])  extends Fragment {
+  def getAllFragments = Set(this)
+  def getDependencies = dependencies
+  def collectDeclarations(nameManager: NameManager) = Set[String]()
+  def accept(visitor: LoopTreeVisitor) = visitor.visitFragment(this)
+  override def toString = "Transform: " + parent.toString
+}
+
 trait LoopTreeVisitor {
   def enterTree(tree: LoopTree)
   def exitTree(tree: LoopTree)
-  def visitSpace(space: IterationSpace)
+  def visitFragment(space: Fragment)
 }
 
 object LoopTree {
@@ -97,29 +171,16 @@ object LoopTree {
   }
 }
 
-class LoopTree private[onetep](localIndex: Option[Index]) {
-  var subItems = ArrayBuffer[Either[IterationSpace, LoopTree]]()
-
-  private def contains(space: IterationSpace, deep: Boolean) : Boolean = {
-    var found = false
-    for (item <- subItems)
-      found |= (item match {
-        case (item: LoopTree) => deep && item.contains(space, deep)
-        case (item: IterationSpace) => item == space
-      })
-    
-    found
-  }
-
-  def containsShallow(space: IterationSpace) : Boolean = contains(space, false)
-  def containsDeep(space: IterationSpace) : Boolean = contains(space, true)
+class LoopTree private[onetep](localIndex: Option[Index]) extends Fragment {
+  var subItems = ArrayBuffer[Fragment]()
 
   def accept(visitor: LoopTreeVisitor) {
+    visitor.enterTree(this)
+
     for (item <- subItems)
-      item match {
-        case Left(space) => visitor.visitSpace(space)
-        case Right(tree) => {visitor.enterTree(tree); tree.accept(visitor); visitor.exitTree(tree)}
-      }
+      item.accept(visitor)
+
+    visitor.exitTree(this)
   }
 
   def +(b: LoopTree) : LoopTree =
@@ -135,45 +196,37 @@ class LoopTree private[onetep](localIndex: Option[Index]) {
     val result = collection.mutable.Set[String]()
 
     for(item <- subItems)
-      item match {
-        case Left(space) => result ++= LoopTree.collectSpaceDeclarations(space, nameManager)
-        case Right(tree) => result ++= tree.collectDeclarations(nameManager)
-      }
+      result ++= item.collectDeclarations(nameManager)
 
     result.toSet
   }
 
-  def getDependencies : Set[IterationSpace] = {
-    val dependencies = collection.mutable.Set[IterationSpace]()
+  def getDependencies : Set[Fragment] = {
+    val dependencies = collection.mutable.Set[Fragment]()
 
     for (item <- subItems)
-      item match {
-        case Left(space) => dependencies ++= space.getDependencies
-        case Right(tree) => dependencies ++= tree.getDependencies
-      }
-
+      dependencies ++= item.getDependencies
+    
     dependencies.toSet
   }
 
-  def getSpaces : Set[IterationSpace] = {
-    val spaces = collection.mutable.Set[IterationSpace]()
+  def getAllFragments : Set[Fragment] = {
+    val fragments = collection.mutable.Set[Fragment]()
+    fragments += this
 
     for (item <- subItems)
-      item match {
-        case Left(space) => spaces += space
-        case Right(tree) => spaces ++= tree.getSpaces
-      }
+      fragments ++= item.getAllFragments
 
-    spaces.toSet
+    fragments.toSet
   }
 
-  def addIterationSpace(indices: List[Index], space: IterationSpace) {
+  def addFragment(indices: List[Index], fragment: Fragment) {
     indices match {
-      case Nil => subItems += Left(space)
+      case Nil => subItems += fragment
       case (head :: tail) => {
         val tree = new LoopTree(Some(head))
-        subItems += Right(tree)
-        tree.addIterationSpace(tail, space)
+        subItems += tree
+        tree.addFragment(tail, fragment)
       }
     }
   }
@@ -182,11 +235,11 @@ class LoopTree private[onetep](localIndex: Option[Index]) {
     val trees = collection.mutable.Set[LoopTree]()
     for (item <- subItems)
       item match {
-        case Right(tree) => trees += tree
+        case (tree: LoopTree) => trees += tree
         case _ =>
       }
 
-    subItems --= trees map (Right(_))
+    subItems --= trees
 
     def attemptFusion(loops: collection.mutable.Set[LoopTree]) : Boolean = {
       for(a <- loops) 
@@ -200,19 +253,14 @@ class LoopTree private[onetep](localIndex: Option[Index]) {
 
     while(attemptFusion(trees)) {}
     trees map (_.fuse())
-    subItems ++= trees map (Right(_))
+    subItems ++= trees
 
     sort()
   }
 
   def sort() {
-    def compareItems(before: Either[IterationSpace, LoopTree], after: Either[IterationSpace, LoopTree]) : Boolean =
-      (before, after) match {
-        case (Left(space1), Left(space2)) => space2.getDependencies.contains(space1)
-        case (Right(tree1), Right(tree2)) => (tree2.getDependencies & tree1.getSpaces).nonEmpty
-        case (Left(space),  Right(tree))  => tree.getDependencies.contains(space)
-        case (Right(tree),  Left(space))  => (space.getDependencies & tree.getSpaces).nonEmpty 
-      }
+    def compareItems(before: Fragment, after: Fragment) : Boolean =
+      (after.getDependencies & before.getAllFragments).nonEmpty
 
     subItems = subItems.sortWith(compareItems(_, _))
   }
@@ -232,8 +280,8 @@ class LoopTree private[onetep](localIndex: Option[Index]) {
     
     for (entryID <- 0 until subItems.size) {
       val subList = (subItems(entryID) match {
-        case Left(space) => List(space.toString)
-        case Right(tree) => tree.toStrings
+        case (tree: LoopTree) => tree.toStrings
+        case x => List(x.toString)
       })
 
       val subListHeadPrefix = if (entryID < subItems.size-1) "|--" else "`--"
index 218fc6688f3faf0c2159f9cff8ce66e5c97a8de7..5dff2b05fe89aa636f06a84c628d207e722a2435 100644 (file)
@@ -46,8 +46,19 @@ object IterationSpace {
     term.getOperands.toTraversable.flatMap(flattenPostorder(_)) ++ List(term)
 }
 
+trait ConsumerGenerator {
+  def generate(names: NameManager, indices: Map[Index,String], values : Map[IterationSpace, String]) : String
+}
+
+trait ProducerGenerator {
+  def generate(names: NameManager) : String
+}
+
+trait TransformGenerator {
+  def generate(names: NameManager) : String
+}
+
 trait IterationSpace {
-  def getAccessExpression(indexNames: NameManager) : String
   def getOperands : List[IterationSpace]
   def getSpatialIndices : List[SpatialIndex]
   def getDiscreteIndices : List[DiscreteIndex]
@@ -58,10 +69,17 @@ trait IterationSpace {
     val operands = getOperands
     operands.toSet ++ operands.flatMap(_.getDependencies)
   }
+
+  // Code generation
+  def getConsumerGenerator : Option[ConsumerGenerator]
+  def getTransformGenerator : Option[TransformGenerator]
+  def getProducerGenerator : Option[ProducerGenerator]
 }
 
 trait DataSpace extends IterationSpace {
   def getOperands = Nil
+  def getConsumerGenerator = None
+  def getTransformGenerator = None
 }
 
 trait Matrix extends DataSpace
@@ -69,19 +87,23 @@ trait FunctionSet extends DataSpace
 
 class Assignment(indexBindings: IndexBindings, lhs: DataSpace, rhs: IterationSpace) extends IterationSpace {
   def getIndexBindings = indexBindings
-  def getOperands = List(lhs,rhs)
-  def getSpatialIndices = Nil
-  def getDiscreteIndices = Nil
+  def getOperands = List(rhs)
+  def getSpatialIndices = lhs.getSpatialIndices
+  def getDiscreteIndices = lhs.getDiscreteIndices
   def getExternalIndices = Set()
-  def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
+
+  def getConsumerGenerator = None
+  def getTransformGenerator = None
+  def getProducerGenerator = None
 }
 
-class Scalar(value: Double) extends IterationSpace {
-  def getOperands = Nil
+class Scalar(value: Double) extends DataSpace {
   def getSpatialIndices = Nil
   def getDiscreteIndices = Nil
   def getExternalIndices = Set()
-  def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
+  def getProducerGenerator = Some(new ProducerGenerator {
+    def generate(names: NameManager) = value.toString
+  })
 }
 
 class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[Index]) extends IterationSpace {
@@ -122,7 +144,10 @@ class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[In
   def getSpatialIndices = spatialIndices
   def getDiscreteIndices = discreteIndices
   def getExternalIndices = Set()
-  def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
+
+  def getConsumerGenerator = None
+  def getTransformGenerator = None
+  def getProducerGenerator = None
 }
 
 class Reciprocal(op: IterationSpace) extends IterationSpace {
@@ -142,7 +167,10 @@ class Reciprocal(op: IterationSpace) extends IterationSpace {
   def getSpatialIndices = spatialIndices.toList
   def getDiscreteIndices = op.getDiscreteIndices
   def getExternalIndices = Set()
-  def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
+
+  def getConsumerGenerator = None
+  def getTransformGenerator = None
+  def getProducerGenerator = None
 }
 
 class Laplacian(op: IterationSpace) extends IterationSpace {
@@ -150,7 +178,10 @@ class Laplacian(op: IterationSpace) extends IterationSpace {
   def getSpatialIndices = op.getSpatialIndices
   def getDiscreteIndices = op.getDiscreteIndices
   def getExternalIndices = Set()
-  def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
+
+  def getConsumerGenerator = None
+  def getTransformGenerator = None
+  def getProducerGenerator = None
 }
 
 class SpatialRestriction(op: IterationSpace) extends IterationSpace {
@@ -170,7 +201,10 @@ class SpatialRestriction(op: IterationSpace) extends IterationSpace {
   def getSpatialIndices = spatialIndices.toList
   def getDiscreteIndices = op.getDiscreteIndices
   def getExternalIndices = Set()
-  def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
+
+  def getConsumerGenerator = None
+  def getTransformGenerator = None
+  def getProducerGenerator = None
 }
 
 class SPAM3(name : String) extends Matrix {
@@ -214,7 +248,8 @@ class SPAM3(name : String) extends Matrix {
   def getSpatialIndices = Nil
   def getDiscreteIndices = List(rowIndex, colIndex)
   def getExternalIndices = Set()
-  def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
+
+  def getProducerGenerator = None
 }
 
 class PPDFunctionSet(val basis : String, data : String) extends FunctionSet {
@@ -243,7 +278,6 @@ class PPDFunctionSet(val basis : String, data : String) extends FunctionSet {
     def getDenseWidth(names: NameManager) = parent.basis+"%max_n_ppds_sphere"
 
     def generateIterationHeader(names: NameManager) = {
-      
       val initDense = "call basis_find_ppd_in_neighbour(" + denseIndexNames.mkString(",") + ", &\n" +
         parent.getSphere(names) + "%ppd_list(1," + names(this) + "), &\n" +
         parent.getSphere(names) + "%ppd_list(2," + names(this) + "), &\n" +
@@ -291,15 +325,17 @@ class PPDFunctionSet(val basis : String, data : String) extends FunctionSet {
   def getDiscreteIndices = List(getSphereIndex)
   def getExternalIndices = Set(getPPDIndex)
 
-  def getAccessExpression(indexNames: NameManager) =  {
-    val index = getSphere(indexNames)+"%offset + &\n" + 
-      "("+indexNames(getPPDIndex)+"-1)*pub_cell%n_pts - 1 + &\n" +
-      "(" + indexNames(spatialIndices(2)) + "-1)*pub_cell%n_pt2*pub_cell%n_pt1 + &\n" +
-      "(" + indexNames(spatialIndices(1)) + "-1)*pub_cell%n_pt1 + &\n" +
-      indexNames(spatialIndices(0))
+  def getProducerGenerator = Some(new ProducerGenerator {
+    def generate(names: NameManager) = {
+      val offset = getSphere(names)+"%offset + &\n" + 
+        "("+names(getPPDIndex)+"-1)*pub_cell%n_pts - 1 + &\n" +
+        "(" + names(spatialIndices(2)) + "-1)*pub_cell%n_pt2*pub_cell%n_pt1 + &\n" +
+        "(" + names(spatialIndices(1)) + "-1)*pub_cell%n_pt1 + &\n" +
+        names(spatialIndices(0))
 
-    data+"("+index+")"
-  }
+      data+"("+offset+")"
+    }
+  })
 }
 
 class BindingIndex(name : String) {