From: Francis Russell Date: Thu, 2 Feb 2012 19:26:03 +0000 (+0000) Subject: Generate (rather flawed loop hierarchy). X-Git-Url: https://git.unchartedbackwaters.co.uk/w/?a=commitdiff_plain;h=3df1c44c704d2b1082ddfcc760bb97984ff34c7f;p=francis%2Fofc.git Generate (rather flawed loop hierarchy). --- diff --git a/src/ofc/generators/onetep/CodeGenerator.scala b/src/ofc/generators/onetep/CodeGenerator.scala index b5fe47a..26bb459 100644 --- a/src/ofc/generators/onetep/CodeGenerator.scala +++ b/src/ofc/generators/onetep/CodeGenerator.scala @@ -25,33 +25,6 @@ class NameManager { } } -object CodeGenerator { - def getAllSpaces(term: IterationSpace) : Set[IterationSpace] = - term.getOperands.toSet.flatMap(getAllSpaces(_: IterationSpace)) + term - - def sortSpaces(spaces : Traversable[IterationSpace]) : List[IterationSpace] = { - val seen = collection.mutable.Set[IterationSpace]() - spaces.toList.flatMap(sortSpacesHelper(_, seen)) - } - - private def sortSpacesHelper(input: IterationSpace, seen: collection.mutable.Set[IterationSpace]) : List[IterationSpace] = - if (seen add input) - input.getOperands.flatMap(sortSpacesHelper(_, seen)) ++ List(input) - else - Nil - - def sortIndices(indices: Traversable[Index]) : List[Index] = { - val seen = collection.mutable.Set[Index]() - indices.toList.flatMap(sortIndicesHelper(_, seen)) - } - - private def sortIndicesHelper(input: Index, seen: collection.mutable.Set[Index]) : List[Index] = - if (seen add input) - input.getDependencies.toList.flatMap(sortIndicesHelper(_, seen)) ++ List(input) - else - Nil -} - class CodeGenerator { val code = new StringBuilder() val nameManager = new NameManager() @@ -72,18 +45,18 @@ class CodeGenerator { } def generateCode(space: IterationSpace) { - val allSpaces = CodeGenerator.getAllSpaces(space) + val allSpaces = IterationSpace.flattenPostorder(space) val allIndices = allSpaces flatMap (_.getIndices) println("dumping operations") - for(op <- CodeGenerator.sortSpaces(allSpaces)) + for(op <- IterationSpace.sort(allSpaces)) println(op) println("finished dumping operations\n\ndumping indices") - for (i <- CodeGenerator.sortIndices(allIndices)) + for (i <- Index.sort(allIndices)) println(i) println("finished dumping indices") - val loopTree = LoopTree() + val loopTree = LoopTree(space) // Next: we dump all these things into a prefix map System.exit(0) diff --git a/src/ofc/generators/onetep/LoopTree.scala b/src/ofc/generators/onetep/LoopTree.scala index 564f50b..3d34043 100644 --- a/src/ofc/generators/onetep/LoopTree.scala +++ b/src/ofc/generators/onetep/LoopTree.scala @@ -7,16 +7,26 @@ Stores the configuration of indices we will use for code generation. */ object LoopTree { - def apply() = new LoopTree(None) + def apply(root: IterationSpace) = { + val base = new LoopTree(None) + val sortedSpaces = IterationSpace.flattenPostorder(root) + val sortedIndices = Index.sort(sortedSpaces flatMap (_.getIndices)) + + for(space <- sortedSpaces) { + val indices = space.getIndices + val localSortedIndices = sortedIndices filter (indices.contains(_)) + base.addIterationSpace(localSortedIndices, space) + println(localSortedIndices.toString + " -> "+space) + } + + println(base) + base + } } case class LoopTree private(localIndex: Option[Index]) { var subItems = ArrayBuffer[Either[IterationSpace, LoopTree]]() - def addIterationSpace(space: IterationSpace) { - addIterationSpace(getLoopIndices(space), space) - } - def contains(space: IterationSpace) : Boolean = { var found = false for (item <- subItems) @@ -28,34 +38,48 @@ case class LoopTree private(localIndex: Option[Index]) { found } - private def addIterationSpace(indices : List[Index], space: IterationSpace) { - val size = subItems.size - var insertPos = size - - for(candidatePos <- (size-1 to 0 by -1)) { - val acceptable = (subItems(candidatePos) match { - case (item: IterationSpace) => !hasDependency(space, item) - case (item: LoopTree) => !hasDependency(space, item) - }) - - if (acceptable) insertPos = candidatePos + private def addIterationSpace(indices: List[Index], space: IterationSpace) { + indices match { + case Nil => subItems += Left(space) + case (head :: tail) => getEndLoop(head).addIterationSpace(tail, space) } + } - indices match { - case head :: tail => { - subItems(insertPos) match { - case LoopTree(Some(head)) => - case _ => indices.insert(insertPos, space) - } + private def getEndLoop(index: Index) : LoopTree = { + def newTree = { val tree = new LoopTree(Some(index)); subItems += Right(tree); tree} + + if (subItems.isEmpty) + newTree + else + subItems.last match { + case Right(tree) => if (tree.getLocalIndex == Some(index)) tree else newTree + case _ => newTree } - case Nil => indices.insert(insertPos, space) } + private def getLocalIndex = localIndex + private def getLoopIndices(space: IterationSpace) : List[Index] = space.getIndices.toList - private def hasDependency(from: IterationSpace, to: IterationSpace) : Boolean = { - (for(f <- from; if f!=from && f == to) yield f).nonEmpty - } + private def hasDependency(from: IterationSpace, to: IterationSpace) : Boolean = + (for(f <- IterationSpace.flattenPostorder(from); if f!=from && f == to) yield f).nonEmpty private def hasDependency(from: IterationSpace, to: LoopTree) = to.contains(from) + + override def toString : String = toStrings.mkString("\n") + + private def toStrings : List[String] = { + val result = ArrayBuffer[String]("Index: " + localIndex) + + for (entry <- subItems) { + val subList = (entry match { + case Left(space) => List(space.toString) + case Right(tree) => tree.toStrings + }) + + result ++= "|--"+subList.head :: (subList.tail.map("| "+_)) + } + + result.toList + } } diff --git a/src/ofc/generators/onetep/Tree.scala b/src/ofc/generators/onetep/Tree.scala index d80a12f..1a624c1 100644 --- a/src/ofc/generators/onetep/Tree.scala +++ b/src/ofc/generators/onetep/Tree.scala @@ -4,6 +4,19 @@ import ofc.parser import ofc.parser.Identifier import ofc.{InvalidInputException,UnimplementedException} +object Index { + def sort(indices: Traversable[Index]) : List[Index] = { + def helper(input: Index, seen: collection.mutable.Set[Index]) : List[Index] = + if (seen add input) + input.getDependencies.toList.flatMap(helper(_, seen)) ++ List(input) + else + Nil + + val seen = collection.mutable.Set[Index]() + indices.toList.flatMap(helper(_, seen)) + } +} + trait Index { def getName : String def getDependencies : Set[Index] @@ -17,7 +30,23 @@ trait Index { trait SpatialIndex extends Index trait DiscreteIndex extends Index -trait IterationSpace extends Traversable[IterationSpace] { +object IterationSpace { + def sort(spaces : Traversable[IterationSpace]) : List[IterationSpace] = { + def helper(input: IterationSpace, seen: collection.mutable.Set[IterationSpace]) : List[IterationSpace] = + if (seen add input) + input.getOperands.flatMap(helper(_, seen)) ++ List(input) + else + Nil + + val seen = collection.mutable.Set[IterationSpace]() + spaces.toList.flatMap(helper(_, seen)) + } + + def flattenPostorder(term: IterationSpace) : Traversable[IterationSpace] = + term.getOperands.toTraversable.flatMap(flattenPostorder(_)) ++ List(term) +} + +trait IterationSpace { def getAccessExpression(indexNames: NameManager) : String def getOperands : List[IterationSpace] def getSpatialIndices : List[SpatialIndex] @@ -25,7 +54,6 @@ trait IterationSpace extends Traversable[IterationSpace] { def getExternalIndices : Set[Index] def getInternalIndices : Set[Index] = (getSpatialIndices ++ getDiscreteIndices).toSet def getIndices : Set[Index] = getInternalIndices ++ getExternalIndices - def foreach[U](f: IterationSpace => U) : Unit = {getOperands.foreach(f); f(this)} } trait DataSpace extends IterationSpace { @@ -325,11 +353,7 @@ class TreeBuilder(dictionary : Dictionary) { } def buildIndexedTerm(term: parser.IndexedTerm) : IterationSpace = { - val dataSpace = dictionary.getData(term.id) match { - case (functionSet : PPDFunctionSet) => new GeneralInnerProduct(List(functionSet), Set(functionSet.getPPDIndex)) - case v => v - } - + val dataSpace = dictionary.getData(term.id) val indices = for(bindingID <- term.indices) yield dictionary.getIndex(bindingID) if (indices.size != dataSpace.getDiscreteIndices.size)