]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Work on buffer placement.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Wed, 8 Feb 2012 17:45:17 +0000 (17:45 +0000)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Wed, 8 Feb 2012 17:45:17 +0000 (17:45 +0000)
src/ofc/OFC.scala
src/ofc/generators/onetep/CodeGenerator.scala
src/ofc/generators/onetep/LoopTree.scala
src/ofc/generators/onetep/PPDFunctionSet.scala

index 0bec89b9661821cf533c396689b54b4377a0e321..b082eef76cd0464c38f72dc32e76c1f0f2ec028c 100644 (file)
@@ -6,7 +6,7 @@ import generators.Generator
 
 class InvalidInputException(s: String) extends Exception(s)
 class UnimplementedException(s: String) extends Exception(s)
-class SemanticError(s: String) extends Exception(s)
+class LogicError(s: String) extends Exception(s)
 
 object OFC extends Parser {
 
index fd6fe9fe34822f036948e97cc784ae451c5f685e..7735c9eb3b74eae01020c891165f0b6a36c7263f 100644 (file)
@@ -57,8 +57,8 @@ class CodeGenerator {
     println("")
 
     val loopNest = LoopNest(space)
+    println("Code:\n"+loopNest.generateCode+"\n")
     println("Loop Nest:\n"+ loopNest.getTree + "\n")
-    println("Code:\n"+loopNest.generateCode)
     // Next: we dump all these things into a prefix map
     System.exit(0)
 
index 7c218d29c3951e6c65da814e3f8a5768a0fdaf87..e34458b9119fc5b9dc2b39f481c62a9a770cf876 100644 (file)
@@ -1,7 +1,7 @@
 package ofc.generators.onetep
 
 import scala.collection.mutable.ArrayBuffer
-import ofc.SemanticError
+import ofc.LogicError
 
 /*
 Stores the configuration of indices we will use for code generation.
@@ -28,13 +28,17 @@ class LoopNest(sortedIndices: List[Index]) {
   val nameManager = new NameManager()
   var declarations = Set[String]()
   var spaceFragmentsInfo = Map[IterationSpace, FragmentsInfo]()
+  var consumerProducers = collection.mutable.Map[ConsumerFragment, Set[ProducerFragment]]()
 
-  class FragmentsInfo(val consumer: Option[Fragment], val transform: Option[Fragment], val producer: Option[Fragment])
+  class FragmentsInfo(val consumer: Option[ConsumerFragment], 
+                      val transform: Option[Fragment], 
+                      val producer: Option[ProducerFragment])
 
   def addIterationSpace(indices: List[Index], space: IterationSpace) {
     val operands = space.getOperands
     val fragmentDepends = for(operand <- operands; fragment <- spaceFragmentsInfo(operand).producer) yield fragment
 
+    // Construct Fragments
     val consumer = space.getConsumerGenerator match {
       case Some(consumerGenerator) => Some(new ConsumerFragment(space, fragmentDepends.toSet))
       case None => None
@@ -50,6 +54,13 @@ class LoopNest(sortedIndices: List[Index]) {
       case None => None
     }
 
+    // Record consumer-producer relationship
+    consumer match {
+      case Some(c) => consumerProducers.getOrElseUpdate(c, fragmentDepends.toSet)
+      case _ =>
+    }
+
+    // Insert into tree
     val fragmentsInfo = new FragmentsInfo(consumer, transform, producer)
     spaceFragmentsInfo += (space -> fragmentsInfo)
 
@@ -79,7 +90,9 @@ class LoopNest(sortedIndices: List[Index]) {
   def getTree = base
 
   def generateCode : String = {
-    val code = new StringBuilder()
+    computeBuffers()
+
+    val code = new StringBuilder
     declarations = base.collectDeclarations(nameManager)
     code append declarations.mkString("\n") + "\n\n"
 
@@ -90,8 +103,20 @@ class LoopNest(sortedIndices: List[Index]) {
     code.mkString
   }
 
+  private def computeBuffers() {
+    val producerConsumers = collection.mutable.Map[ProducerFragment, collection.mutable.Set[ConsumerFragment]]()
+
+    for ((consumer, producers) <- consumerProducers; producer <- producers)
+      producerConsumers.getOrElseUpdate(producer, collection.mutable.Set[ConsumerFragment]()) += consumer
+
+    val descriptors = for ((producer, consumers) <- producerConsumers) yield 
+      new BufferDescriptor(producer, consumers.toSet)
+
+    descriptors.map(base.addBufferDescriptor(_))
+  }
+
   class GenerationVisitor extends LoopTreeVisitor {
-    val code = new StringBuilder()
+    val code = new StringBuilder
 
     def enterTree(tree: LoopTree) {
       tree.getLocalIndex match {
@@ -109,9 +134,9 @@ class LoopNest(sortedIndices: List[Index]) {
 
     def visitFragment(fragment: Fragment) {
       fragment match {
-        case (c: ConsumerFragment) => 
+        case (c: ConsumerFragment) => code append "!"+c.toString+"\n"
         case (p: ProducerFragment) => code append "!"+p.generate(nameManager)+"\n"
-        case (t: TransformFragment) =>
+        case (t: TransformFragment) => code append "!"+t.toString+"\n"
       }
     }
 
@@ -126,7 +151,8 @@ trait Fragment {
   def accept(visitor: LoopTreeVisitor) : Unit
 }
 
-class ConsumerFragment(parent: IterationSpace, dependencies: Set[Fragment])  extends Fragment {
+class ConsumerFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment {
+  def getSpace = parent
   def getAllFragments = Set(this)
   def getDependencies = dependencies
   def collectDeclarations(nameManager: NameManager) = parent.getIndices.flatMap(_.getDeclarations(nameManager))
@@ -134,7 +160,8 @@ class ConsumerFragment(parent: IterationSpace, dependencies: Set[Fragment])  ext
   override def toString = "Consumer: " + parent.toString
 }
 
-class ProducerFragment(parent: IterationSpace, dependencies: Set[Fragment])  extends Fragment {
+class ProducerFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment {
+  def getSpace = parent
   def getAllFragments = Set(this)
   def getDependencies = dependencies
   def collectDeclarations(nameManager: NameManager) = parent.getIndices.flatMap(_.getDeclarations(nameManager))
@@ -143,7 +170,8 @@ class ProducerFragment(parent: IterationSpace, dependencies: Set[Fragment])  ext
   override def toString = "Producer: " + parent.toString
 }
 
-class TransformFragment(parent: IterationSpace, dependencies: Set[Fragment])  extends Fragment {
+class TransformFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment {
+  def getSpace = parent
   def getAllFragments = Set(this)
   def getDependencies = dependencies
   def collectDeclarations(nameManager: NameManager) = parent.getIndices.flatMap(_.getDeclarations(nameManager))
@@ -157,6 +185,14 @@ trait LoopTreeVisitor {
   def visitFragment(space: Fragment)
 }
 
+class BufferDescriptor(producer: ProducerFragment, consumers: Set[ConsumerFragment]) {
+  def getProducer : ProducerFragment = producer
+  def getConsumers : Set[ConsumerFragment] = consumers
+  def getSpace : IterationSpace = producer.getSpace
+  def getIndices : Set[Index] = getSpace.getInternalIndices
+  override def toString = "Buffer: "+getSpace.toString
+}
+
 object LoopTree {
   def collectSpaceDeclarations(term: IterationSpace, nameManager: NameManager) : Set[String] = {
     val declarations = for(index <- term.getIndices;
@@ -179,6 +215,7 @@ object LoopTree {
 
 class LoopTree private[onetep](localIndex: Option[Index]) extends Fragment {
   var subItems = ArrayBuffer[Fragment]()
+  var bufferDescriptors = ArrayBuffer[BufferDescriptor]()
 
   def accept(visitor: LoopTreeVisitor) {
     visitor.enterTree(this)
@@ -193,11 +230,43 @@ class LoopTree private[onetep](localIndex: Option[Index]) extends Fragment {
     if (getLocalIndex == b.getLocalIndex) {
       val newTree = new LoopTree(getLocalIndex)
       newTree.subItems = subItems ++ b.subItems
+      newTree.bufferDescriptors = bufferDescriptors ++ b.bufferDescriptors
       newTree
     } else {
-      throw new SemanticError("Addition undefined for loops with different indices")
+      throw new LogicError("Addition undefined for loops with different indices")
     }
 
+  def addBufferDescriptor(descriptor: BufferDescriptor) {
+    val allFragments = getAllFragments
+
+    // FIXME: the toSet call should be unneccessary
+    if (!allFragments.contains(descriptor.getProducer) ||
+        !descriptor.getConsumers.toSet[Fragment].subsetOf(allFragments)) {
+      throw new LogicError("Attempt to place buffer in LoopTree missing consumer or producer.")
+    }
+
+    for(item <- subItems) item match {
+      case tree: LoopTree => {
+        val placeInside = tree.getLocalIndex match {
+          case Some(index) =>
+            descriptor.getIndices.contains(index) &&
+            (descriptor.getConsumers.toSet[Fragment] + descriptor.getProducer).subsetOf(item.getAllFragments)
+          case None => false
+        }
+
+        if (placeInside) {
+          tree.addBufferDescriptor(descriptor)
+          return
+        }
+      }
+      case _ =>
+    }
+
+    bufferDescriptors += descriptor
+  }
+
+  def getBufferDescriptors : List[BufferDescriptor] = bufferDescriptors.toList
+
   def collectDeclarations(nameManager: NameManager) : Set[String] = {
     val result = collection.mutable.Set[String]()
 
@@ -283,6 +352,9 @@ class LoopTree private[onetep](localIndex: Option[Index]) extends Fragment {
       case Some(x) => x.toString
       case None => "None"
     })
+
+    for(bufferDescriptor <- bufferDescriptors)
+      result += "|--"+bufferDescriptor.toString
     
     for (entryID <- 0 until subItems.size) {
       val subList = (subItems(entryID) match {
index df0f900ceaa1050f309d9778b2732e5752faf6a3..f7adac480ce60d2949c3d730a4fa81708757e26b 100644 (file)
@@ -52,7 +52,7 @@ class PPDFunctionSet(val basis : String, data : String) extends FunctionSet {
       "+ ("+startNames(2)+"-1)*pub_cell%n_pt2*pub_cell%n_pt1 &\n"+
       "+ ("+startNames(1)+"-1)*pub_cell%n_pt1 + ("+startNames(0)+"-1)"
 
-      (List(findPPD) ++ computeRanges ++ List(loopDeclaration) ++ List(ppdOffsetCalc)).mkString("\n")
+      (List(findPPD) ++ computeRanges ++ List(loopDeclarationppdOffsetCalc)).mkString("\n")
     }
 
     def generateIterationFooter(names: NameManager) = "end do"