package ofc.generators.onetep
import scala.collection.mutable.ArrayBuffer
+import ofc.SemanticError
/*
Stores the configuration of indices we will use for code generation.
val nameManager = new NameManager()
var declarations = Set[String]()
- def addIterationSpace(indices: List[Index], space: IterationSpace) = base.addIterationSpace(indices, space)
+ def addIterationSpace(indices: List[Index], space: IterationSpace) {
+ base.addIterationSpace(indices, space)
+ base.fuse()
+ }
+
def getTree = base
def generateCode : String = {
for (op <- term.getOperands) declarationsSet ++= collectSpaceDeclarations(op, nameManager)
declarationsSet
}
+
+ def attemptFusion(a: LoopTree, b: LoopTree, commonScope: LoopTree) : Option[LoopTree] = {
+ if (a == b)
+ None
+ else if (a.getLocalIndex != b.getLocalIndex)
+ None
+ else
+ Some(a + b)
+ }
}
-case class LoopTree private[onetep](localIndex: Option[Index]) {
+class LoopTree private[onetep](localIndex: Option[Index]) {
var subItems = ArrayBuffer[Either[IterationSpace, LoopTree]]()
- def contains(space: IterationSpace) : Boolean = {
+ private def contains(space: IterationSpace, deep: Boolean) : Boolean = {
var found = false
for (item <- subItems)
found |= (item match {
- case (item: LoopTree) => item.contains(space)
+ case (item: LoopTree) => deep && item.contains(space, deep)
case (item: IterationSpace) => item == space
})
found
}
+ def containsShallow(space: IterationSpace) : Boolean = contains(space, false)
+ def containsDeep(space: IterationSpace) : Boolean = contains(space, true)
+
def accept(visitor: LoopTreeVisitor) {
for (item <- subItems)
item match {
}
}
+ def +(b: LoopTree) : LoopTree =
+ if (getLocalIndex == b.getLocalIndex) {
+ val newTree = new LoopTree(getLocalIndex)
+ newTree.subItems = subItems ++ b.subItems
+ newTree
+ } else {
+ throw new SemanticError("Addition undefined for loops with different indices")
+ }
+
def collectDeclarations(nameManager: NameManager) : Set[String] = {
val result = collection.mutable.Set[String]()
result.toSet
}
+ def getDependencies : Set[IterationSpace] = {
+ val dependencies = collection.mutable.Set[IterationSpace]()
+
+ for (item <- subItems)
+ item match {
+ case Left(space) => dependencies ++= IterationSpace.flattenPostorder(space)
+ case Right(tree) => dependencies ++= tree.getDependencies
+ }
+
+ dependencies.toSet
+ }
+
+ def getSpaces : Set[IterationSpace] = {
+ val spaces = collection.mutable.Set[IterationSpace]()
+
+ for (item <- subItems)
+ item match {
+ case Left(space) => spaces += space
+ case Right(tree) => spaces ++= tree.getSpaces
+ }
+
+ spaces.toSet
+ }
+
+
def addIterationSpace(indices: List[Index], space: IterationSpace) {
indices match {
case Nil => subItems += Left(space)
- case (head :: tail) => getEndLoop(head).addIterationSpace(tail, space)
+ case (head :: tail) => {
+ val tree = new LoopTree(Some(head))
+ subItems += Right(tree)
+ tree.addIterationSpace(tail, space)
+ }
}
}
- private def getEndLoop(index: Index) : LoopTree = {
- def newTree = { val tree = new LoopTree(Some(index)); subItems += Right(tree); tree}
+ def fuse() {
+ val trees = collection.mutable.Set[LoopTree]()
+ for (item <- subItems)
+ item match {
+ case Right(tree) => trees += tree
+ case _ =>
+ }
- var destination : Option[LoopTree]= None
- for(item <- subItems)
- item match {
- case Right(tree) => if (tree.getLocalIndex == Some(index)) destination = Some(tree)
- case _ =>
- }
+ subItems --= trees map (Right(_))
- destination match {
- case Some(tree) => tree
- case None => newTree
+ def attemptFusion(loops: collection.mutable.Set[LoopTree]) : Boolean = {
+ for(a <- loops)
+ for(b <- loops)
+ LoopTree.attemptFusion(a, b, this) match {
+ case Some(fused) => {loops -= a; loops -= b; loops += fused; return true}
+ case None =>
+ }
+ false
}
+
+ while(attemptFusion(trees)) {}
+ trees map (_.fuse())
+ subItems ++= trees map (Right(_))
+
+ sort()
+ }
+
+ def sort() {
+ //TODO: Implement me!
}
def getLocalIndex = localIndex
private def getLoopIndices(space: IterationSpace) : List[Index] = space.getIndices.toList
- private def hasDependency(from: IterationSpace, to: IterationSpace) : Boolean =
- (for(f <- IterationSpace.flattenPostorder(from); if f!=from && f == to) yield f).nonEmpty
-
- private def hasDependency(from: IterationSpace, to: LoopTree) = to.contains(from)
-
override def toString : String = toStrings.mkString("\n")
private def toStrings : List[String] = {
})
val prefix = if (entryID < subItems.size-1) "| " else " "
- result ++= "|--"+subList.head :: (subList.tail.map(prefix+_))
+ result ++= "`--"+subList.head :: (subList.tail.map(prefix+_))
}
result.toList