From a5b0dac435054e6376146b13251a4d1485eb5c48 Mon Sep 17 00:00:00 2001 From: Francis Russell Date: Wed, 1 Feb 2012 19:54:34 +0000 Subject: [PATCH] Initial work on generating loop hierarchy. --- src/ofc/generators/onetep/CodeGenerator.scala | 27 +++++++- src/ofc/generators/onetep/LoopTree.scala | 61 +++++++++++++++++++ src/ofc/generators/onetep/Tree.scala | 3 +- 3 files changed, 88 insertions(+), 3 deletions(-) create mode 100644 src/ofc/generators/onetep/LoopTree.scala diff --git a/src/ofc/generators/onetep/CodeGenerator.scala b/src/ofc/generators/onetep/CodeGenerator.scala index 737c44f..b5fe47a 100644 --- a/src/ofc/generators/onetep/CodeGenerator.scala +++ b/src/ofc/generators/onetep/CodeGenerator.scala @@ -28,6 +28,28 @@ class NameManager { object CodeGenerator { def getAllSpaces(term: IterationSpace) : Set[IterationSpace] = term.getOperands.toSet.flatMap(getAllSpaces(_: IterationSpace)) + term + + def sortSpaces(spaces : Traversable[IterationSpace]) : List[IterationSpace] = { + val seen = collection.mutable.Set[IterationSpace]() + spaces.toList.flatMap(sortSpacesHelper(_, seen)) + } + + private def sortSpacesHelper(input: IterationSpace, seen: collection.mutable.Set[IterationSpace]) : List[IterationSpace] = + if (seen add input) + input.getOperands.flatMap(sortSpacesHelper(_, seen)) ++ List(input) + else + Nil + + def sortIndices(indices: Traversable[Index]) : List[Index] = { + val seen = collection.mutable.Set[Index]() + indices.toList.flatMap(sortIndicesHelper(_, seen)) + } + + private def sortIndicesHelper(input: Index, seen: collection.mutable.Set[Index]) : List[Index] = + if (seen add input) + input.getDependencies.toList.flatMap(sortIndicesHelper(_, seen)) ++ List(input) + else + Nil } class CodeGenerator { @@ -54,13 +76,14 @@ class CodeGenerator { val allIndices = allSpaces flatMap (_.getIndices) println("dumping operations") - for(op <- allSpaces) + for(op <- CodeGenerator.sortSpaces(allSpaces)) println(op) println("finished dumping operations\n\ndumping indices") - for (i <- allIndices) + for (i <- CodeGenerator.sortIndices(allIndices)) println(i) println("finished dumping indices") + val loopTree = LoopTree() // Next: we dump all these things into a prefix map System.exit(0) diff --git a/src/ofc/generators/onetep/LoopTree.scala b/src/ofc/generators/onetep/LoopTree.scala new file mode 100644 index 0000000..564f50b --- /dev/null +++ b/src/ofc/generators/onetep/LoopTree.scala @@ -0,0 +1,61 @@ +package ofc.generators.onetep + +import scala.collection.mutable.ArrayBuffer + +/* +Stores the configuration of indices we will use for code generation. +*/ + +object LoopTree { + def apply() = new LoopTree(None) +} + +case class LoopTree private(localIndex: Option[Index]) { + var subItems = ArrayBuffer[Either[IterationSpace, LoopTree]]() + + def addIterationSpace(space: IterationSpace) { + addIterationSpace(getLoopIndices(space), space) + } + + def contains(space: IterationSpace) : Boolean = { + var found = false + for (item <- subItems) + found |= (item match { + case (item: LoopTree) => item.contains(space) + case (item: IterationSpace) => item == space + }) + + found + } + + private def addIterationSpace(indices : List[Index], space: IterationSpace) { + val size = subItems.size + var insertPos = size + + for(candidatePos <- (size-1 to 0 by -1)) { + val acceptable = (subItems(candidatePos) match { + case (item: IterationSpace) => !hasDependency(space, item) + case (item: LoopTree) => !hasDependency(space, item) + }) + + if (acceptable) insertPos = candidatePos + } + + indices match { + case head :: tail => { + subItems(insertPos) match { + case LoopTree(Some(head)) => + case _ => indices.insert(insertPos, space) + } + } + case Nil => indices.insert(insertPos, space) + } + + private def getLoopIndices(space: IterationSpace) : List[Index] = space.getIndices.toList + + private def hasDependency(from: IterationSpace, to: IterationSpace) : Boolean = { + (for(f <- from; if f!=from && f == to) yield f).nonEmpty + } + + private def hasDependency(from: IterationSpace, to: LoopTree) = to.contains(from) +} diff --git a/src/ofc/generators/onetep/Tree.scala b/src/ofc/generators/onetep/Tree.scala index 6d16e5f..d80a12f 100644 --- a/src/ofc/generators/onetep/Tree.scala +++ b/src/ofc/generators/onetep/Tree.scala @@ -17,7 +17,7 @@ trait Index { trait SpatialIndex extends Index trait DiscreteIndex extends Index -trait IterationSpace { +trait IterationSpace extends Traversable[IterationSpace] { def getAccessExpression(indexNames: NameManager) : String def getOperands : List[IterationSpace] def getSpatialIndices : List[SpatialIndex] @@ -25,6 +25,7 @@ trait IterationSpace { def getExternalIndices : Set[Index] def getInternalIndices : Set[Index] = (getSpatialIndices ++ getDiscreteIndices).toSet def getIndices : Set[Index] = getInternalIndices ++ getExternalIndices + def foreach[U](f: IterationSpace => U) : Unit = {getOperands.foreach(f); f(this)} } trait DataSpace extends IterationSpace { -- 2.47.3