}
}
-object CodeGenerator {
- def getAllSpaces(term: IterationSpace) : Set[IterationSpace] =
- term.getOperands.toSet.flatMap(getAllSpaces(_: IterationSpace)) + term
-
- def sortSpaces(spaces : Traversable[IterationSpace]) : List[IterationSpace] = {
- val seen = collection.mutable.Set[IterationSpace]()
- spaces.toList.flatMap(sortSpacesHelper(_, seen))
- }
-
- private def sortSpacesHelper(input: IterationSpace, seen: collection.mutable.Set[IterationSpace]) : List[IterationSpace] =
- if (seen add input)
- input.getOperands.flatMap(sortSpacesHelper(_, seen)) ++ List(input)
- else
- Nil
-
- def sortIndices(indices: Traversable[Index]) : List[Index] = {
- val seen = collection.mutable.Set[Index]()
- indices.toList.flatMap(sortIndicesHelper(_, seen))
- }
-
- private def sortIndicesHelper(input: Index, seen: collection.mutable.Set[Index]) : List[Index] =
- if (seen add input)
- input.getDependencies.toList.flatMap(sortIndicesHelper(_, seen)) ++ List(input)
- else
- Nil
-}
-
class CodeGenerator {
val code = new StringBuilder()
val nameManager = new NameManager()
}
def generateCode(space: IterationSpace) {
- val allSpaces = CodeGenerator.getAllSpaces(space)
+ val allSpaces = IterationSpace.flattenPostorder(space)
val allIndices = allSpaces flatMap (_.getIndices)
println("dumping operations")
- for(op <- CodeGenerator.sortSpaces(allSpaces))
+ for(op <- IterationSpace.sort(allSpaces))
println(op)
println("finished dumping operations\n\ndumping indices")
- for (i <- CodeGenerator.sortIndices(allIndices))
+ for (i <- Index.sort(allIndices))
println(i)
println("finished dumping indices")
- val loopTree = LoopTree()
+ val loopTree = LoopTree(space)
// Next: we dump all these things into a prefix map
System.exit(0)
*/
object LoopTree {
- def apply() = new LoopTree(None)
+ def apply(root: IterationSpace) = {
+ val base = new LoopTree(None)
+ val sortedSpaces = IterationSpace.flattenPostorder(root)
+ val sortedIndices = Index.sort(sortedSpaces flatMap (_.getIndices))
+
+ for(space <- sortedSpaces) {
+ val indices = space.getIndices
+ val localSortedIndices = sortedIndices filter (indices.contains(_))
+ base.addIterationSpace(localSortedIndices, space)
+ println(localSortedIndices.toString + " -> "+space)
+ }
+
+ println(base)
+ base
+ }
}
case class LoopTree private(localIndex: Option[Index]) {
var subItems = ArrayBuffer[Either[IterationSpace, LoopTree]]()
- def addIterationSpace(space: IterationSpace) {
- addIterationSpace(getLoopIndices(space), space)
- }
-
def contains(space: IterationSpace) : Boolean = {
var found = false
for (item <- subItems)
found
}
- private def addIterationSpace(indices : List[Index], space: IterationSpace) {
- val size = subItems.size
- var insertPos = size
-
- for(candidatePos <- (size-1 to 0 by -1)) {
- val acceptable = (subItems(candidatePos) match {
- case (item: IterationSpace) => !hasDependency(space, item)
- case (item: LoopTree) => !hasDependency(space, item)
- })
-
- if (acceptable) insertPos = candidatePos
+ private def addIterationSpace(indices: List[Index], space: IterationSpace) {
+ indices match {
+ case Nil => subItems += Left(space)
+ case (head :: tail) => getEndLoop(head).addIterationSpace(tail, space)
}
+ }
- indices match {
- case head :: tail => {
- subItems(insertPos) match {
- case LoopTree(Some(head)) =>
- case _ => indices.insert(insertPos, space)
- }
+ private def getEndLoop(index: Index) : LoopTree = {
+ def newTree = { val tree = new LoopTree(Some(index)); subItems += Right(tree); tree}
+
+ if (subItems.isEmpty)
+ newTree
+ else
+ subItems.last match {
+ case Right(tree) => if (tree.getLocalIndex == Some(index)) tree else newTree
+ case _ => newTree
}
- case Nil => indices.insert(insertPos, space)
}
+ private def getLocalIndex = localIndex
+
private def getLoopIndices(space: IterationSpace) : List[Index] = space.getIndices.toList
- private def hasDependency(from: IterationSpace, to: IterationSpace) : Boolean = {
- (for(f <- from; if f!=from && f == to) yield f).nonEmpty
- }
+ private def hasDependency(from: IterationSpace, to: IterationSpace) : Boolean =
+ (for(f <- IterationSpace.flattenPostorder(from); if f!=from && f == to) yield f).nonEmpty
private def hasDependency(from: IterationSpace, to: LoopTree) = to.contains(from)
+
+ override def toString : String = toStrings.mkString("\n")
+
+ private def toStrings : List[String] = {
+ val result = ArrayBuffer[String]("Index: " + localIndex)
+
+ for (entry <- subItems) {
+ val subList = (entry match {
+ case Left(space) => List(space.toString)
+ case Right(tree) => tree.toStrings
+ })
+
+ result ++= "|--"+subList.head :: (subList.tail.map("| "+_))
+ }
+
+ result.toList
+ }
}
import ofc.parser.Identifier
import ofc.{InvalidInputException,UnimplementedException}
+object Index {
+ def sort(indices: Traversable[Index]) : List[Index] = {
+ def helper(input: Index, seen: collection.mutable.Set[Index]) : List[Index] =
+ if (seen add input)
+ input.getDependencies.toList.flatMap(helper(_, seen)) ++ List(input)
+ else
+ Nil
+
+ val seen = collection.mutable.Set[Index]()
+ indices.toList.flatMap(helper(_, seen))
+ }
+}
+
trait Index {
def getName : String
def getDependencies : Set[Index]
trait SpatialIndex extends Index
trait DiscreteIndex extends Index
-trait IterationSpace extends Traversable[IterationSpace] {
+object IterationSpace {
+ def sort(spaces : Traversable[IterationSpace]) : List[IterationSpace] = {
+ def helper(input: IterationSpace, seen: collection.mutable.Set[IterationSpace]) : List[IterationSpace] =
+ if (seen add input)
+ input.getOperands.flatMap(helper(_, seen)) ++ List(input)
+ else
+ Nil
+
+ val seen = collection.mutable.Set[IterationSpace]()
+ spaces.toList.flatMap(helper(_, seen))
+ }
+
+ def flattenPostorder(term: IterationSpace) : Traversable[IterationSpace] =
+ term.getOperands.toTraversable.flatMap(flattenPostorder(_)) ++ List(term)
+}
+
+trait IterationSpace {
def getAccessExpression(indexNames: NameManager) : String
def getOperands : List[IterationSpace]
def getSpatialIndices : List[SpatialIndex]
def getExternalIndices : Set[Index]
def getInternalIndices : Set[Index] = (getSpatialIndices ++ getDiscreteIndices).toSet
def getIndices : Set[Index] = getInternalIndices ++ getExternalIndices
- def foreach[U](f: IterationSpace => U) : Unit = {getOperands.foreach(f); f(this)}
}
trait DataSpace extends IterationSpace {
}
def buildIndexedTerm(term: parser.IndexedTerm) : IterationSpace = {
- val dataSpace = dictionary.getData(term.id) match {
- case (functionSet : PPDFunctionSet) => new GeneralInnerProduct(List(functionSet), Set(functionSet.getPPDIndex))
- case v => v
- }
-
+ val dataSpace = dictionary.getData(term.id)
val indices = for(bindingID <- term.indices) yield dictionary.getIndex(bindingID)
if (indices.size != dataSpace.getDiscreteIndices.size)