]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Generate (rather flawed loop hierarchy).
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Thu, 2 Feb 2012 19:26:03 +0000 (19:26 +0000)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Thu, 2 Feb 2012 19:26:03 +0000 (19:26 +0000)
src/ofc/generators/onetep/CodeGenerator.scala
src/ofc/generators/onetep/LoopTree.scala
src/ofc/generators/onetep/Tree.scala

index b5fe47a7bdd4c42e7893b6f765d303075bba139e..26bb45903f0535138ce48906986eb1ff365af2b2 100644 (file)
@@ -25,33 +25,6 @@ class NameManager {
   }
 }
 
-object CodeGenerator {
-  def getAllSpaces(term: IterationSpace) : Set[IterationSpace] =
-    term.getOperands.toSet.flatMap(getAllSpaces(_: IterationSpace)) + term
-
-  def sortSpaces(spaces : Traversable[IterationSpace]) : List[IterationSpace] = {
-    val seen = collection.mutable.Set[IterationSpace]()
-    spaces.toList.flatMap(sortSpacesHelper(_, seen))
-  }
-
-  private def sortSpacesHelper(input: IterationSpace, seen: collection.mutable.Set[IterationSpace]) : List[IterationSpace] =
-    if (seen add input)
-      input.getOperands.flatMap(sortSpacesHelper(_, seen)) ++ List(input)
-    else
-      Nil
-
-  def sortIndices(indices: Traversable[Index]) : List[Index] = {
-    val seen = collection.mutable.Set[Index]()
-    indices.toList.flatMap(sortIndicesHelper(_, seen))
-  }
-
-  private def sortIndicesHelper(input: Index, seen: collection.mutable.Set[Index]) : List[Index] =
-    if (seen add input)
-      input.getDependencies.toList.flatMap(sortIndicesHelper(_, seen)) ++ List(input)
-    else
-      Nil
-}
-
 class CodeGenerator {
   val code = new StringBuilder()
   val nameManager = new NameManager()
@@ -72,18 +45,18 @@ class CodeGenerator {
   }
 
   def generateCode(space: IterationSpace) {
-    val allSpaces = CodeGenerator.getAllSpaces(space)
+    val allSpaces = IterationSpace.flattenPostorder(space)
     val allIndices = allSpaces flatMap (_.getIndices)
 
     println("dumping operations")
-    for(op <- CodeGenerator.sortSpaces(allSpaces))
+    for(op <- IterationSpace.sort(allSpaces))
       println(op)
     println("finished dumping operations\n\ndumping indices")
-    for (i <- CodeGenerator.sortIndices(allIndices))
+    for (i <- Index.sort(allIndices))
       println(i)
     println("finished dumping indices")
 
-    val loopTree = LoopTree()
+    val loopTree = LoopTree(space)
     // Next: we dump all these things into a prefix map
     System.exit(0)
 
index 564f50b7f38a030e586d7d8637ec0d950e0790c9..3d340435ec6325f8099826ec6bd4d5a7e15cdcb6 100644 (file)
@@ -7,16 +7,26 @@ Stores the configuration of indices we will use for code generation.
 */
 
 object LoopTree {
-  def apply() = new LoopTree(None)
+  def apply(root: IterationSpace) = {
+    val base = new LoopTree(None)
+    val sortedSpaces = IterationSpace.flattenPostorder(root)
+    val sortedIndices = Index.sort(sortedSpaces flatMap (_.getIndices))
+
+    for(space <- sortedSpaces) {
+      val indices = space.getIndices
+      val localSortedIndices = sortedIndices filter (indices.contains(_))
+      base.addIterationSpace(localSortedIndices, space)
+      println(localSortedIndices.toString + " -> "+space)
+    }
+
+    println(base)
+    base
+  }
 }
 
 case class LoopTree private(localIndex: Option[Index]) {
   var subItems = ArrayBuffer[Either[IterationSpace, LoopTree]]()
 
-  def addIterationSpace(space: IterationSpace) {
-    addIterationSpace(getLoopIndices(space), space)
-  }
-
   def contains(space: IterationSpace) : Boolean = {
     var found = false
     for (item <- subItems)
@@ -28,34 +38,48 @@ case class LoopTree private(localIndex: Option[Index]) {
     found
   }
 
-  private def addIterationSpace(indices : List[Index], space: IterationSpace) {
-    val size = subItems.size
-    var insertPos = size
-
-    for(candidatePos <- (size-1 to 0 by -1)) {
-      val acceptable = (subItems(candidatePos) match {
-        case (item: IterationSpace) => !hasDependency(space, item)
-        case (item: LoopTree) => !hasDependency(space, item)
-      })
-
-      if (acceptable) insertPos = candidatePos
+  private def addIterationSpace(indices: List[Index], space: IterationSpace) {
+    indices match {
+      case Nil => subItems += Left(space)
+      case (head :: tail) => getEndLoop(head).addIterationSpace(tail, space)
     }
+  }
 
-    indices match {
-      case head :: tail => {
-        subItems(insertPos) match {
-          case LoopTree(Some(head)) => 
-          case _ => indices.insert(insertPos, space)
-        }
+  private def getEndLoop(index: Index) : LoopTree = {
+    def newTree = { val tree = new LoopTree(Some(index)); subItems += Right(tree); tree} 
+
+    if (subItems.isEmpty) 
+      newTree
+    else
+      subItems.last match {
+        case Right(tree) => if (tree.getLocalIndex == Some(index)) tree else newTree
+        case _ => newTree
       }
-      case Nil => indices.insert(insertPos, space)
   }
 
+  private def getLocalIndex = localIndex
+
   private def getLoopIndices(space: IterationSpace) : List[Index] = space.getIndices.toList
 
-  private def hasDependency(from: IterationSpace, to: IterationSpace) : Boolean = {
-    (for(f <- from; if f!=from && f == to) yield f).nonEmpty
-  }
+  private def hasDependency(from: IterationSpace, to: IterationSpace) : Boolean =
+    (for(f <- IterationSpace.flattenPostorder(from); if f!=from && f == to) yield f).nonEmpty
 
   private def hasDependency(from: IterationSpace, to: LoopTree) = to.contains(from)
+
+  override def toString : String = toStrings.mkString("\n")
+
+  private def toStrings : List[String] = {
+    val result = ArrayBuffer[String]("Index: " + localIndex)
+    
+    for (entry <- subItems) {
+      val subList = (entry match {
+        case Left(space) => List(space.toString)
+        case Right(tree) => tree.toStrings
+      })
+
+      result ++= "|--"+subList.head :: (subList.tail.map("|  "+_))
+    }
+
+    result.toList
+  }
 }
index d80a12f433f8008dfc9a2992a2b50c266fbd8492..1a624c1cb2478b61ee13cf582bb51f05de6f670e 100644 (file)
@@ -4,6 +4,19 @@ import ofc.parser
 import ofc.parser.Identifier
 import ofc.{InvalidInputException,UnimplementedException}
 
+object Index {
+  def sort(indices: Traversable[Index]) : List[Index] = {
+    def helper(input: Index, seen: collection.mutable.Set[Index]) : List[Index] =
+    if (seen add input)
+      input.getDependencies.toList.flatMap(helper(_, seen)) ++ List(input)
+    else
+      Nil
+
+    val seen = collection.mutable.Set[Index]()
+    indices.toList.flatMap(helper(_, seen))
+  }
+}
+
 trait Index {
   def getName : String
   def getDependencies : Set[Index]
@@ -17,7 +30,23 @@ trait Index {
 trait SpatialIndex extends Index
 trait DiscreteIndex extends Index
 
-trait IterationSpace extends Traversable[IterationSpace] {
+object IterationSpace {
+  def sort(spaces : Traversable[IterationSpace]) : List[IterationSpace] = {
+    def helper(input: IterationSpace, seen: collection.mutable.Set[IterationSpace]) : List[IterationSpace] =
+      if (seen add input)
+        input.getOperands.flatMap(helper(_, seen)) ++ List(input)
+      else
+        Nil
+
+    val seen = collection.mutable.Set[IterationSpace]()
+    spaces.toList.flatMap(helper(_, seen))
+  }
+
+  def flattenPostorder(term: IterationSpace) : Traversable[IterationSpace] = 
+    term.getOperands.toTraversable.flatMap(flattenPostorder(_)) ++ List(term)
+}
+
+trait IterationSpace {
   def getAccessExpression(indexNames: NameManager) : String
   def getOperands : List[IterationSpace]
   def getSpatialIndices : List[SpatialIndex]
@@ -25,7 +54,6 @@ trait IterationSpace extends Traversable[IterationSpace] {
   def getExternalIndices : Set[Index]
   def getInternalIndices : Set[Index] = (getSpatialIndices ++ getDiscreteIndices).toSet
   def getIndices : Set[Index] = getInternalIndices ++ getExternalIndices
-  def foreach[U](f: IterationSpace => U) : Unit = {getOperands.foreach(f); f(this)}
 }
 
 trait DataSpace extends IterationSpace {
@@ -325,11 +353,7 @@ class TreeBuilder(dictionary : Dictionary) {
   }
 
   def buildIndexedTerm(term: parser.IndexedTerm) : IterationSpace = {
-    val dataSpace = dictionary.getData(term.id) match {
-      case (functionSet : PPDFunctionSet) => new GeneralInnerProduct(List(functionSet), Set(functionSet.getPPDIndex))
-      case v => v
-    }
-
+    val dataSpace = dictionary.getData(term.id)
     val indices = for(bindingID <- term.indices) yield dictionary.getIndex(bindingID)
 
     if (indices.size != dataSpace.getDiscreteIndices.size)