]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
New index fusion implementation.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Fri, 3 Feb 2012 16:45:25 +0000 (16:45 +0000)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Fri, 3 Feb 2012 16:45:25 +0000 (16:45 +0000)
src/ofc/OFC.scala
src/ofc/generators/onetep/CodeGenerator.scala
src/ofc/generators/onetep/LoopTree.scala

index 96bc906af9f80e23c5404d2118d7b9e83dc5111b..0bec89b9661821cf533c396689b54b4377a0e321 100644 (file)
@@ -6,6 +6,7 @@ import generators.Generator
 
 class InvalidInputException(s: String) extends Exception(s)
 class UnimplementedException(s: String) extends Exception(s)
+class SemanticError(s: String) extends Exception(s)
 
 object OFC extends Parser {
 
index 807061edffb9c95fe4c76eaacbb51ca28c78b177..5aa32bc383c61aca4f16af7e04eb325727d0340d 100644 (file)
@@ -85,7 +85,7 @@ class CodeGenerator {
         code append "real(kind=DP), allocatable, dimension" + (":"*concreteIndexList.size).mkString("(",", &\n",")") + " :: " + 
           storageName + "\n"
         code append "allocate("+ storageName +
-          (concreteIndexList map ((x : Index) => x.getDenseWidth)).mkString("(",",",")") + ", stat=ierr)\n"
+          (concreteIndexList map ((x : Index) => x.getDenseWidth(nameManager))).mkString("(",",",")") + ", stat=ierr)\n"
   
         // We've declared temporary storage, now create the loops to populate it
         for (index <- concreteIndexList) code append index.generateIterationHeader(nameManager) + "\n"
index dcd4bebdf2c2ce811a6a4a74e1cb264816972727..8ef209462b91e91f56e497dd7f9f8b3883389930 100644 (file)
@@ -1,6 +1,7 @@
 package ofc.generators.onetep
 
 import scala.collection.mutable.ArrayBuffer
+import ofc.SemanticError
 
 /*
 Stores the configuration of indices we will use for code generation.
@@ -27,7 +28,11 @@ class LoopNest {
   val nameManager = new NameManager()
   var declarations = Set[String]()
 
-  def addIterationSpace(indices: List[Index], space: IterationSpace) = base.addIterationSpace(indices, space)
+  def addIterationSpace(indices: List[Index], space: IterationSpace) {
+    base.addIterationSpace(indices, space)
+    base.fuse()
+  }
+
   def getTree = base
 
   def generateCode : String = {
@@ -81,22 +86,34 @@ object LoopTree {
     for (op <- term.getOperands) declarationsSet ++= collectSpaceDeclarations(op, nameManager)
     declarationsSet
   }
+
+  def attemptFusion(a: LoopTree, b: LoopTree, commonScope: LoopTree) : Option[LoopTree] = {
+    if (a == b)
+      None
+    else if (a.getLocalIndex != b.getLocalIndex)
+      None
+    else
+      Some(a + b)
+  }
 }
 
-case class LoopTree private[onetep](localIndex: Option[Index]) {
+class LoopTree private[onetep](localIndex: Option[Index]) {
   var subItems = ArrayBuffer[Either[IterationSpace, LoopTree]]()
 
-  def contains(space: IterationSpace) : Boolean = {
+  private def contains(space: IterationSpace, deep: Boolean) : Boolean = {
     var found = false
     for (item <- subItems)
       found |= (item match {
-        case (item: LoopTree) => item.contains(space)
+        case (item: LoopTree) => deep && item.contains(space, deep)
         case (item: IterationSpace) => item == space
       })
     
     found
   }
 
+  def containsShallow(space: IterationSpace) : Boolean = contains(space, false)
+  def containsDeep(space: IterationSpace) : Boolean = contains(space, true)
+
   def accept(visitor: LoopTreeVisitor) {
     for (item <- subItems)
       item match {
@@ -105,6 +122,15 @@ case class LoopTree private[onetep](localIndex: Option[Index]) {
       }
   }
 
+  def +(b: LoopTree) : LoopTree =
+    if (getLocalIndex == b.getLocalIndex) {
+      val newTree = new LoopTree(getLocalIndex)
+      newTree.subItems = subItems ++ b.subItems
+      newTree
+    } else {
+      throw new SemanticError("Addition undefined for loops with different indices")
+    }
+
   def collectDeclarations(nameManager: NameManager) : Set[String] = {
     val result = collection.mutable.Set[String]()
 
@@ -117,38 +143,77 @@ case class LoopTree private[onetep](localIndex: Option[Index]) {
     result.toSet
   }
 
+  def getDependencies : Set[IterationSpace] = {
+    val dependencies = collection.mutable.Set[IterationSpace]()
+
+    for (item <- subItems)
+      item match {
+        case Left(space) => dependencies ++= IterationSpace.flattenPostorder(space)
+        case Right(tree) => dependencies ++= tree.getDependencies
+      }
+
+    dependencies.toSet
+  }
+
+  def getSpaces : Set[IterationSpace] = {
+    val spaces = collection.mutable.Set[IterationSpace]()
+
+    for (item <- subItems)
+      item match {
+        case Left(space) => spaces += space
+        case Right(tree) => spaces ++= tree.getSpaces
+      }
+
+    spaces.toSet
+  }
+
+
   def addIterationSpace(indices: List[Index], space: IterationSpace) {
     indices match {
       case Nil => subItems += Left(space)
-      case (head :: tail) => getEndLoop(head).addIterationSpace(tail, space)
+      case (head :: tail) => {
+        val tree = new LoopTree(Some(head))
+        subItems += Right(tree)
+        tree.addIterationSpace(tail, space)
+      }
     }
   }
 
-  private def getEndLoop(index: Index) : LoopTree = {
-    def newTree = { val tree = new LoopTree(Some(index)); subItems += Right(tree); tree} 
+  def fuse() {
+    val trees = collection.mutable.Set[LoopTree]()
+    for (item <- subItems)
+      item match {
+        case Right(tree) => trees += tree
+        case _ =>
+      }
 
-    var destination : Option[LoopTree]= None
-    for(item <- subItems)
-      item  match {
-      case Right(tree) => if (tree.getLocalIndex == Some(index)) destination = Some(tree)
-      case _ =>
-    }
+    subItems --= trees map (Right(_))
 
-    destination match {
-      case Some(tree) => tree
-      case None => newTree
+    def attemptFusion(loops: collection.mutable.Set[LoopTree]) : Boolean = {
+      for(a <- loops) 
+        for(b <- loops)
+          LoopTree.attemptFusion(a, b, this) match {
+            case Some(fused) => {loops -= a; loops -= b; loops += fused; return true}
+            case None =>
+          }
+      false
     }
+
+    while(attemptFusion(trees)) {}
+    trees map (_.fuse())
+    subItems ++= trees map (Right(_))
+
+    sort()
+  }
+
+  def sort() {
+    //TODO: Implement me!
   }
 
   def getLocalIndex = localIndex
 
   private def getLoopIndices(space: IterationSpace) : List[Index] = space.getIndices.toList
 
-  private def hasDependency(from: IterationSpace, to: IterationSpace) : Boolean =
-    (for(f <- IterationSpace.flattenPostorder(from); if f!=from && f == to) yield f).nonEmpty
-
-  private def hasDependency(from: IterationSpace, to: LoopTree) = to.contains(from)
-
   override def toString : String = toStrings.mkString("\n")
 
   private def toStrings : List[String] = {
@@ -161,7 +226,7 @@ case class LoopTree private[onetep](localIndex: Option[Index]) {
       })
 
       val prefix = if (entryID < subItems.size-1) "|  " else "   "
-      result ++= "|--"+subList.head :: (subList.tail.map(prefix+_))
+      result ++= "`--"+subList.head :: (subList.tail.map(prefix+_))
     }
 
     result.toList