package ofc.generators.onetep
import scala.collection.mutable.ArrayBuffer
-import ofc.SemanticError
+import ofc.LogicError
/*
Stores the configuration of indices we will use for code generation.
val nameManager = new NameManager()
var declarations = Set[String]()
var spaceFragmentsInfo = Map[IterationSpace, FragmentsInfo]()
+ var consumerProducers = collection.mutable.Map[ConsumerFragment, Set[ProducerFragment]]()
- class FragmentsInfo(val consumer: Option[Fragment], val transform: Option[Fragment], val producer: Option[Fragment])
+ class FragmentsInfo(val consumer: Option[ConsumerFragment],
+ val transform: Option[Fragment],
+ val producer: Option[ProducerFragment])
def addIterationSpace(indices: List[Index], space: IterationSpace) {
val operands = space.getOperands
val fragmentDepends = for(operand <- operands; fragment <- spaceFragmentsInfo(operand).producer) yield fragment
+ // Construct Fragments
val consumer = space.getConsumerGenerator match {
case Some(consumerGenerator) => Some(new ConsumerFragment(space, fragmentDepends.toSet))
case None => None
case None => None
}
+ // Record consumer-producer relationship
+ consumer match {
+ case Some(c) => consumerProducers.getOrElseUpdate(c, fragmentDepends.toSet)
+ case _ =>
+ }
+
+ // Insert into tree
val fragmentsInfo = new FragmentsInfo(consumer, transform, producer)
spaceFragmentsInfo += (space -> fragmentsInfo)
def getTree = base
def generateCode : String = {
- val code = new StringBuilder()
+ computeBuffers()
+
+ val code = new StringBuilder
declarations = base.collectDeclarations(nameManager)
code append declarations.mkString("\n") + "\n\n"
code.mkString
}
+ private def computeBuffers() {
+ val producerConsumers = collection.mutable.Map[ProducerFragment, collection.mutable.Set[ConsumerFragment]]()
+
+ for ((consumer, producers) <- consumerProducers; producer <- producers)
+ producerConsumers.getOrElseUpdate(producer, collection.mutable.Set[ConsumerFragment]()) += consumer
+
+ val descriptors = for ((producer, consumers) <- producerConsumers) yield
+ new BufferDescriptor(producer, consumers.toSet)
+
+ descriptors.map(base.addBufferDescriptor(_))
+ }
+
class GenerationVisitor extends LoopTreeVisitor {
- val code = new StringBuilder()
+ val code = new StringBuilder
def enterTree(tree: LoopTree) {
tree.getLocalIndex match {
def visitFragment(fragment: Fragment) {
fragment match {
- case (c: ConsumerFragment) =>
+ case (c: ConsumerFragment) => code append "!"+c.toString+"\n"
case (p: ProducerFragment) => code append "!"+p.generate(nameManager)+"\n"
- case (t: TransformFragment) =>
+ case (t: TransformFragment) => code append "!"+t.toString+"\n"
}
}
def accept(visitor: LoopTreeVisitor) : Unit
}
-class ConsumerFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment {
+class ConsumerFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment {
+ def getSpace = parent
def getAllFragments = Set(this)
def getDependencies = dependencies
def collectDeclarations(nameManager: NameManager) = parent.getIndices.flatMap(_.getDeclarations(nameManager))
override def toString = "Consumer: " + parent.toString
}
-class ProducerFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment {
+class ProducerFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment {
+ def getSpace = parent
def getAllFragments = Set(this)
def getDependencies = dependencies
def collectDeclarations(nameManager: NameManager) = parent.getIndices.flatMap(_.getDeclarations(nameManager))
override def toString = "Producer: " + parent.toString
}
-class TransformFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment {
+class TransformFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment {
+ def getSpace = parent
def getAllFragments = Set(this)
def getDependencies = dependencies
def collectDeclarations(nameManager: NameManager) = parent.getIndices.flatMap(_.getDeclarations(nameManager))
def visitFragment(space: Fragment)
}
+class BufferDescriptor(producer: ProducerFragment, consumers: Set[ConsumerFragment]) {
+ def getProducer : ProducerFragment = producer
+ def getConsumers : Set[ConsumerFragment] = consumers
+ def getSpace : IterationSpace = producer.getSpace
+ def getIndices : Set[Index] = getSpace.getInternalIndices
+ override def toString = "Buffer: "+getSpace.toString
+}
+
object LoopTree {
def collectSpaceDeclarations(term: IterationSpace, nameManager: NameManager) : Set[String] = {
val declarations = for(index <- term.getIndices;
class LoopTree private[onetep](localIndex: Option[Index]) extends Fragment {
var subItems = ArrayBuffer[Fragment]()
+ var bufferDescriptors = ArrayBuffer[BufferDescriptor]()
def accept(visitor: LoopTreeVisitor) {
visitor.enterTree(this)
if (getLocalIndex == b.getLocalIndex) {
val newTree = new LoopTree(getLocalIndex)
newTree.subItems = subItems ++ b.subItems
+ newTree.bufferDescriptors = bufferDescriptors ++ b.bufferDescriptors
newTree
} else {
- throw new SemanticError("Addition undefined for loops with different indices")
+ throw new LogicError("Addition undefined for loops with different indices")
}
+ def addBufferDescriptor(descriptor: BufferDescriptor) {
+ val allFragments = getAllFragments
+
+ // FIXME: the toSet call should be unneccessary
+ if (!allFragments.contains(descriptor.getProducer) ||
+ !descriptor.getConsumers.toSet[Fragment].subsetOf(allFragments)) {
+ throw new LogicError("Attempt to place buffer in LoopTree missing consumer or producer.")
+ }
+
+ for(item <- subItems) item match {
+ case tree: LoopTree => {
+ val placeInside = tree.getLocalIndex match {
+ case Some(index) =>
+ descriptor.getIndices.contains(index) &&
+ (descriptor.getConsumers.toSet[Fragment] + descriptor.getProducer).subsetOf(item.getAllFragments)
+ case None => false
+ }
+
+ if (placeInside) {
+ tree.addBufferDescriptor(descriptor)
+ return
+ }
+ }
+ case _ =>
+ }
+
+ bufferDescriptors += descriptor
+ }
+
+ def getBufferDescriptors : List[BufferDescriptor] = bufferDescriptors.toList
+
def collectDeclarations(nameManager: NameManager) : Set[String] = {
val result = collection.mutable.Set[String]()
case Some(x) => x.toString
case None => "None"
})
+
+ for(bufferDescriptor <- bufferDescriptors)
+ result += "|--"+bufferDescriptor.toString
for (entryID <- 0 until subItems.size) {
val subList = (subItems(entryID) match {