]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Initial work on generating loop hierarchy.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Wed, 1 Feb 2012 19:54:34 +0000 (19:54 +0000)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Wed, 1 Feb 2012 19:54:34 +0000 (19:54 +0000)
src/ofc/generators/onetep/CodeGenerator.scala
src/ofc/generators/onetep/LoopTree.scala [new file with mode: 0644]
src/ofc/generators/onetep/Tree.scala

index 737c44fc0af380d5e4b6b24084536d48280d5466..b5fe47a7bdd4c42e7893b6f765d303075bba139e 100644 (file)
@@ -28,6 +28,28 @@ class NameManager {
 object CodeGenerator {
   def getAllSpaces(term: IterationSpace) : Set[IterationSpace] =
     term.getOperands.toSet.flatMap(getAllSpaces(_: IterationSpace)) + term
+
+  def sortSpaces(spaces : Traversable[IterationSpace]) : List[IterationSpace] = {
+    val seen = collection.mutable.Set[IterationSpace]()
+    spaces.toList.flatMap(sortSpacesHelper(_, seen))
+  }
+
+  private def sortSpacesHelper(input: IterationSpace, seen: collection.mutable.Set[IterationSpace]) : List[IterationSpace] =
+    if (seen add input)
+      input.getOperands.flatMap(sortSpacesHelper(_, seen)) ++ List(input)
+    else
+      Nil
+
+  def sortIndices(indices: Traversable[Index]) : List[Index] = {
+    val seen = collection.mutable.Set[Index]()
+    indices.toList.flatMap(sortIndicesHelper(_, seen))
+  }
+
+  private def sortIndicesHelper(input: Index, seen: collection.mutable.Set[Index]) : List[Index] =
+    if (seen add input)
+      input.getDependencies.toList.flatMap(sortIndicesHelper(_, seen)) ++ List(input)
+    else
+      Nil
 }
 
 class CodeGenerator {
@@ -54,13 +76,14 @@ class CodeGenerator {
     val allIndices = allSpaces flatMap (_.getIndices)
 
     println("dumping operations")
-    for(op <- allSpaces)
+    for(op <- CodeGenerator.sortSpaces(allSpaces))
       println(op)
     println("finished dumping operations\n\ndumping indices")
-    for (i <- allIndices)
+    for (i <- CodeGenerator.sortIndices(allIndices))
       println(i)
     println("finished dumping indices")
 
+    val loopTree = LoopTree()
     // Next: we dump all these things into a prefix map
     System.exit(0)
 
diff --git a/src/ofc/generators/onetep/LoopTree.scala b/src/ofc/generators/onetep/LoopTree.scala
new file mode 100644 (file)
index 0000000..564f50b
--- /dev/null
@@ -0,0 +1,61 @@
+package ofc.generators.onetep
+
+import scala.collection.mutable.ArrayBuffer
+
+/*
+Stores the configuration of indices we will use for code generation.
+*/
+
+object LoopTree {
+  def apply() = new LoopTree(None)
+}
+
+case class LoopTree private(localIndex: Option[Index]) {
+  var subItems = ArrayBuffer[Either[IterationSpace, LoopTree]]()
+
+  def addIterationSpace(space: IterationSpace) {
+    addIterationSpace(getLoopIndices(space), space)
+  }
+
+  def contains(space: IterationSpace) : Boolean = {
+    var found = false
+    for (item <- subItems)
+      found |= (item match {
+        case (item: LoopTree) => item.contains(space)
+        case (item: IterationSpace) => item == space
+      })
+    
+    found
+  }
+
+  private def addIterationSpace(indices : List[Index], space: IterationSpace) {
+    val size = subItems.size
+    var insertPos = size
+
+    for(candidatePos <- (size-1 to 0 by -1)) {
+      val acceptable = (subItems(candidatePos) match {
+        case (item: IterationSpace) => !hasDependency(space, item)
+        case (item: LoopTree) => !hasDependency(space, item)
+      })
+
+      if (acceptable) insertPos = candidatePos
+    }
+
+    indices match {
+      case head :: tail => {
+        subItems(insertPos) match {
+          case LoopTree(Some(head)) => 
+          case _ => indices.insert(insertPos, space)
+        }
+      }
+      case Nil => indices.insert(insertPos, space)
+  }
+
+  private def getLoopIndices(space: IterationSpace) : List[Index] = space.getIndices.toList
+
+  private def hasDependency(from: IterationSpace, to: IterationSpace) : Boolean = {
+    (for(f <- from; if f!=from && f == to) yield f).nonEmpty
+  }
+
+  private def hasDependency(from: IterationSpace, to: LoopTree) = to.contains(from)
+}
index 6d16e5f4dcdd0891ce39391c84ea23c67f386078..d80a12f433f8008dfc9a2992a2b50c266fbd8492 100644 (file)
@@ -17,7 +17,7 @@ trait Index {
 trait SpatialIndex extends Index
 trait DiscreteIndex extends Index
 
-trait IterationSpace {
+trait IterationSpace extends Traversable[IterationSpace] {
   def getAccessExpression(indexNames: NameManager) : String
   def getOperands : List[IterationSpace]
   def getSpatialIndices : List[SpatialIndex]
@@ -25,6 +25,7 @@ trait IterationSpace {
   def getExternalIndices : Set[Index]
   def getInternalIndices : Set[Index] = (getSpatialIndices ++ getDiscreteIndices).toSet
   def getIndices : Set[Index] = getInternalIndices ++ getExternalIndices
+  def foreach[U](f: IterationSpace => U) : Unit = {getOperands.foreach(f); f(this)}
 }
 
 trait DataSpace extends IterationSpace {