From 5d66acdee77e4877f017503aea90617421209e6a Mon Sep 17 00:00:00 2001 From: Francis Russell Date: Wed, 8 Feb 2012 17:45:17 +0000 Subject: [PATCH] Work on buffer placement. --- src/ofc/OFC.scala | 2 +- src/ofc/generators/onetep/CodeGenerator.scala | 2 +- src/ofc/generators/onetep/LoopTree.scala | 92 +++++++++++++++++-- .../generators/onetep/PPDFunctionSet.scala | 2 +- 4 files changed, 85 insertions(+), 13 deletions(-) diff --git a/src/ofc/OFC.scala b/src/ofc/OFC.scala index 0bec89b..b082eef 100644 --- a/src/ofc/OFC.scala +++ b/src/ofc/OFC.scala @@ -6,7 +6,7 @@ import generators.Generator class InvalidInputException(s: String) extends Exception(s) class UnimplementedException(s: String) extends Exception(s) -class SemanticError(s: String) extends Exception(s) +class LogicError(s: String) extends Exception(s) object OFC extends Parser { diff --git a/src/ofc/generators/onetep/CodeGenerator.scala b/src/ofc/generators/onetep/CodeGenerator.scala index fd6fe9f..7735c9e 100644 --- a/src/ofc/generators/onetep/CodeGenerator.scala +++ b/src/ofc/generators/onetep/CodeGenerator.scala @@ -57,8 +57,8 @@ class CodeGenerator { println("") val loopNest = LoopNest(space) + println("Code:\n"+loopNest.generateCode+"\n") println("Loop Nest:\n"+ loopNest.getTree + "\n") - println("Code:\n"+loopNest.generateCode) // Next: we dump all these things into a prefix map System.exit(0) diff --git a/src/ofc/generators/onetep/LoopTree.scala b/src/ofc/generators/onetep/LoopTree.scala index 7c218d2..e34458b 100644 --- a/src/ofc/generators/onetep/LoopTree.scala +++ b/src/ofc/generators/onetep/LoopTree.scala @@ -1,7 +1,7 @@ package ofc.generators.onetep import scala.collection.mutable.ArrayBuffer -import ofc.SemanticError +import ofc.LogicError /* Stores the configuration of indices we will use for code generation. @@ -28,13 +28,17 @@ class LoopNest(sortedIndices: List[Index]) { val nameManager = new NameManager() var declarations = Set[String]() var spaceFragmentsInfo = Map[IterationSpace, FragmentsInfo]() + var consumerProducers = collection.mutable.Map[ConsumerFragment, Set[ProducerFragment]]() - class FragmentsInfo(val consumer: Option[Fragment], val transform: Option[Fragment], val producer: Option[Fragment]) + class FragmentsInfo(val consumer: Option[ConsumerFragment], + val transform: Option[Fragment], + val producer: Option[ProducerFragment]) def addIterationSpace(indices: List[Index], space: IterationSpace) { val operands = space.getOperands val fragmentDepends = for(operand <- operands; fragment <- spaceFragmentsInfo(operand).producer) yield fragment + // Construct Fragments val consumer = space.getConsumerGenerator match { case Some(consumerGenerator) => Some(new ConsumerFragment(space, fragmentDepends.toSet)) case None => None @@ -50,6 +54,13 @@ class LoopNest(sortedIndices: List[Index]) { case None => None } + // Record consumer-producer relationship + consumer match { + case Some(c) => consumerProducers.getOrElseUpdate(c, fragmentDepends.toSet) + case _ => + } + + // Insert into tree val fragmentsInfo = new FragmentsInfo(consumer, transform, producer) spaceFragmentsInfo += (space -> fragmentsInfo) @@ -79,7 +90,9 @@ class LoopNest(sortedIndices: List[Index]) { def getTree = base def generateCode : String = { - val code = new StringBuilder() + computeBuffers() + + val code = new StringBuilder declarations = base.collectDeclarations(nameManager) code append declarations.mkString("\n") + "\n\n" @@ -90,8 +103,20 @@ class LoopNest(sortedIndices: List[Index]) { code.mkString } + private def computeBuffers() { + val producerConsumers = collection.mutable.Map[ProducerFragment, collection.mutable.Set[ConsumerFragment]]() + + for ((consumer, producers) <- consumerProducers; producer <- producers) + producerConsumers.getOrElseUpdate(producer, collection.mutable.Set[ConsumerFragment]()) += consumer + + val descriptors = for ((producer, consumers) <- producerConsumers) yield + new BufferDescriptor(producer, consumers.toSet) + + descriptors.map(base.addBufferDescriptor(_)) + } + class GenerationVisitor extends LoopTreeVisitor { - val code = new StringBuilder() + val code = new StringBuilder def enterTree(tree: LoopTree) { tree.getLocalIndex match { @@ -109,9 +134,9 @@ class LoopNest(sortedIndices: List[Index]) { def visitFragment(fragment: Fragment) { fragment match { - case (c: ConsumerFragment) => + case (c: ConsumerFragment) => code append "!"+c.toString+"\n" case (p: ProducerFragment) => code append "!"+p.generate(nameManager)+"\n" - case (t: TransformFragment) => + case (t: TransformFragment) => code append "!"+t.toString+"\n" } } @@ -126,7 +151,8 @@ trait Fragment { def accept(visitor: LoopTreeVisitor) : Unit } -class ConsumerFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment { +class ConsumerFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment { + def getSpace = parent def getAllFragments = Set(this) def getDependencies = dependencies def collectDeclarations(nameManager: NameManager) = parent.getIndices.flatMap(_.getDeclarations(nameManager)) @@ -134,7 +160,8 @@ class ConsumerFragment(parent: IterationSpace, dependencies: Set[Fragment]) ext override def toString = "Consumer: " + parent.toString } -class ProducerFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment { +class ProducerFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment { + def getSpace = parent def getAllFragments = Set(this) def getDependencies = dependencies def collectDeclarations(nameManager: NameManager) = parent.getIndices.flatMap(_.getDeclarations(nameManager)) @@ -143,7 +170,8 @@ class ProducerFragment(parent: IterationSpace, dependencies: Set[Fragment]) ext override def toString = "Producer: " + parent.toString } -class TransformFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment { +class TransformFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment { + def getSpace = parent def getAllFragments = Set(this) def getDependencies = dependencies def collectDeclarations(nameManager: NameManager) = parent.getIndices.flatMap(_.getDeclarations(nameManager)) @@ -157,6 +185,14 @@ trait LoopTreeVisitor { def visitFragment(space: Fragment) } +class BufferDescriptor(producer: ProducerFragment, consumers: Set[ConsumerFragment]) { + def getProducer : ProducerFragment = producer + def getConsumers : Set[ConsumerFragment] = consumers + def getSpace : IterationSpace = producer.getSpace + def getIndices : Set[Index] = getSpace.getInternalIndices + override def toString = "Buffer: "+getSpace.toString +} + object LoopTree { def collectSpaceDeclarations(term: IterationSpace, nameManager: NameManager) : Set[String] = { val declarations = for(index <- term.getIndices; @@ -179,6 +215,7 @@ object LoopTree { class LoopTree private[onetep](localIndex: Option[Index]) extends Fragment { var subItems = ArrayBuffer[Fragment]() + var bufferDescriptors = ArrayBuffer[BufferDescriptor]() def accept(visitor: LoopTreeVisitor) { visitor.enterTree(this) @@ -193,11 +230,43 @@ class LoopTree private[onetep](localIndex: Option[Index]) extends Fragment { if (getLocalIndex == b.getLocalIndex) { val newTree = new LoopTree(getLocalIndex) newTree.subItems = subItems ++ b.subItems + newTree.bufferDescriptors = bufferDescriptors ++ b.bufferDescriptors newTree } else { - throw new SemanticError("Addition undefined for loops with different indices") + throw new LogicError("Addition undefined for loops with different indices") } + def addBufferDescriptor(descriptor: BufferDescriptor) { + val allFragments = getAllFragments + + // FIXME: the toSet call should be unneccessary + if (!allFragments.contains(descriptor.getProducer) || + !descriptor.getConsumers.toSet[Fragment].subsetOf(allFragments)) { + throw new LogicError("Attempt to place buffer in LoopTree missing consumer or producer.") + } + + for(item <- subItems) item match { + case tree: LoopTree => { + val placeInside = tree.getLocalIndex match { + case Some(index) => + descriptor.getIndices.contains(index) && + (descriptor.getConsumers.toSet[Fragment] + descriptor.getProducer).subsetOf(item.getAllFragments) + case None => false + } + + if (placeInside) { + tree.addBufferDescriptor(descriptor) + return + } + } + case _ => + } + + bufferDescriptors += descriptor + } + + def getBufferDescriptors : List[BufferDescriptor] = bufferDescriptors.toList + def collectDeclarations(nameManager: NameManager) : Set[String] = { val result = collection.mutable.Set[String]() @@ -283,6 +352,9 @@ class LoopTree private[onetep](localIndex: Option[Index]) extends Fragment { case Some(x) => x.toString case None => "None" }) + + for(bufferDescriptor <- bufferDescriptors) + result += "|--"+bufferDescriptor.toString for (entryID <- 0 until subItems.size) { val subList = (subItems(entryID) match { diff --git a/src/ofc/generators/onetep/PPDFunctionSet.scala b/src/ofc/generators/onetep/PPDFunctionSet.scala index df0f900..f7adac4 100644 --- a/src/ofc/generators/onetep/PPDFunctionSet.scala +++ b/src/ofc/generators/onetep/PPDFunctionSet.scala @@ -52,7 +52,7 @@ class PPDFunctionSet(val basis : String, data : String) extends FunctionSet { "+ ("+startNames(2)+"-1)*pub_cell%n_pt2*pub_cell%n_pt1 &\n"+ "+ ("+startNames(1)+"-1)*pub_cell%n_pt1 + ("+startNames(0)+"-1)" - (List(findPPD) ++ computeRanges ++ List(loopDeclaration) ++ List(ppdOffsetCalc)).mkString("\n") + (List(findPPD) ++ computeRanges ++ List(loopDeclaration, ppdOffsetCalc)).mkString("\n") } def generateIterationFooter(names: NameManager) = "end do" -- 2.47.3