From: Francis Russell Date: Fri, 3 Feb 2012 16:45:25 +0000 (+0000) Subject: New index fusion implementation. X-Git-Url: https://git.unchartedbackwaters.co.uk/w/?a=commitdiff_plain;h=fe574fdcd34ec1fd082a71aaed7b98caf0637dbd;p=francis%2Fofc.git New index fusion implementation. --- diff --git a/src/ofc/OFC.scala b/src/ofc/OFC.scala index 96bc906..0bec89b 100644 --- a/src/ofc/OFC.scala +++ b/src/ofc/OFC.scala @@ -6,6 +6,7 @@ import generators.Generator class InvalidInputException(s: String) extends Exception(s) class UnimplementedException(s: String) extends Exception(s) +class SemanticError(s: String) extends Exception(s) object OFC extends Parser { diff --git a/src/ofc/generators/onetep/CodeGenerator.scala b/src/ofc/generators/onetep/CodeGenerator.scala index 807061e..5aa32bc 100644 --- a/src/ofc/generators/onetep/CodeGenerator.scala +++ b/src/ofc/generators/onetep/CodeGenerator.scala @@ -85,7 +85,7 @@ class CodeGenerator { code append "real(kind=DP), allocatable, dimension" + (":"*concreteIndexList.size).mkString("(",", &\n",")") + " :: " + storageName + "\n" code append "allocate("+ storageName + - (concreteIndexList map ((x : Index) => x.getDenseWidth)).mkString("(",",",")") + ", stat=ierr)\n" + (concreteIndexList map ((x : Index) => x.getDenseWidth(nameManager))).mkString("(",",",")") + ", stat=ierr)\n" // We've declared temporary storage, now create the loops to populate it for (index <- concreteIndexList) code append index.generateIterationHeader(nameManager) + "\n" diff --git a/src/ofc/generators/onetep/LoopTree.scala b/src/ofc/generators/onetep/LoopTree.scala index dcd4beb..8ef2094 100644 --- a/src/ofc/generators/onetep/LoopTree.scala +++ b/src/ofc/generators/onetep/LoopTree.scala @@ -1,6 +1,7 @@ package ofc.generators.onetep import scala.collection.mutable.ArrayBuffer +import ofc.SemanticError /* Stores the configuration of indices we will use for code generation. @@ -27,7 +28,11 @@ class LoopNest { val nameManager = new NameManager() var declarations = Set[String]() - def addIterationSpace(indices: List[Index], space: IterationSpace) = base.addIterationSpace(indices, space) + def addIterationSpace(indices: List[Index], space: IterationSpace) { + base.addIterationSpace(indices, space) + base.fuse() + } + def getTree = base def generateCode : String = { @@ -81,22 +86,34 @@ object LoopTree { for (op <- term.getOperands) declarationsSet ++= collectSpaceDeclarations(op, nameManager) declarationsSet } + + def attemptFusion(a: LoopTree, b: LoopTree, commonScope: LoopTree) : Option[LoopTree] = { + if (a == b) + None + else if (a.getLocalIndex != b.getLocalIndex) + None + else + Some(a + b) + } } -case class LoopTree private[onetep](localIndex: Option[Index]) { +class LoopTree private[onetep](localIndex: Option[Index]) { var subItems = ArrayBuffer[Either[IterationSpace, LoopTree]]() - def contains(space: IterationSpace) : Boolean = { + private def contains(space: IterationSpace, deep: Boolean) : Boolean = { var found = false for (item <- subItems) found |= (item match { - case (item: LoopTree) => item.contains(space) + case (item: LoopTree) => deep && item.contains(space, deep) case (item: IterationSpace) => item == space }) found } + def containsShallow(space: IterationSpace) : Boolean = contains(space, false) + def containsDeep(space: IterationSpace) : Boolean = contains(space, true) + def accept(visitor: LoopTreeVisitor) { for (item <- subItems) item match { @@ -105,6 +122,15 @@ case class LoopTree private[onetep](localIndex: Option[Index]) { } } + def +(b: LoopTree) : LoopTree = + if (getLocalIndex == b.getLocalIndex) { + val newTree = new LoopTree(getLocalIndex) + newTree.subItems = subItems ++ b.subItems + newTree + } else { + throw new SemanticError("Addition undefined for loops with different indices") + } + def collectDeclarations(nameManager: NameManager) : Set[String] = { val result = collection.mutable.Set[String]() @@ -117,38 +143,77 @@ case class LoopTree private[onetep](localIndex: Option[Index]) { result.toSet } + def getDependencies : Set[IterationSpace] = { + val dependencies = collection.mutable.Set[IterationSpace]() + + for (item <- subItems) + item match { + case Left(space) => dependencies ++= IterationSpace.flattenPostorder(space) + case Right(tree) => dependencies ++= tree.getDependencies + } + + dependencies.toSet + } + + def getSpaces : Set[IterationSpace] = { + val spaces = collection.mutable.Set[IterationSpace]() + + for (item <- subItems) + item match { + case Left(space) => spaces += space + case Right(tree) => spaces ++= tree.getSpaces + } + + spaces.toSet + } + + def addIterationSpace(indices: List[Index], space: IterationSpace) { indices match { case Nil => subItems += Left(space) - case (head :: tail) => getEndLoop(head).addIterationSpace(tail, space) + case (head :: tail) => { + val tree = new LoopTree(Some(head)) + subItems += Right(tree) + tree.addIterationSpace(tail, space) + } } } - private def getEndLoop(index: Index) : LoopTree = { - def newTree = { val tree = new LoopTree(Some(index)); subItems += Right(tree); tree} + def fuse() { + val trees = collection.mutable.Set[LoopTree]() + for (item <- subItems) + item match { + case Right(tree) => trees += tree + case _ => + } - var destination : Option[LoopTree]= None - for(item <- subItems) - item match { - case Right(tree) => if (tree.getLocalIndex == Some(index)) destination = Some(tree) - case _ => - } + subItems --= trees map (Right(_)) - destination match { - case Some(tree) => tree - case None => newTree + def attemptFusion(loops: collection.mutable.Set[LoopTree]) : Boolean = { + for(a <- loops) + for(b <- loops) + LoopTree.attemptFusion(a, b, this) match { + case Some(fused) => {loops -= a; loops -= b; loops += fused; return true} + case None => + } + false } + + while(attemptFusion(trees)) {} + trees map (_.fuse()) + subItems ++= trees map (Right(_)) + + sort() + } + + def sort() { + //TODO: Implement me! } def getLocalIndex = localIndex private def getLoopIndices(space: IterationSpace) : List[Index] = space.getIndices.toList - private def hasDependency(from: IterationSpace, to: IterationSpace) : Boolean = - (for(f <- IterationSpace.flattenPostorder(from); if f!=from && f == to) yield f).nonEmpty - - private def hasDependency(from: IterationSpace, to: LoopTree) = to.contains(from) - override def toString : String = toStrings.mkString("\n") private def toStrings : List[String] = { @@ -161,7 +226,7 @@ case class LoopTree private[onetep](localIndex: Option[Index]) { }) val prefix = if (entryID < subItems.size-1) "| " else " " - result ++= "|--"+subList.head :: (subList.tail.map(prefix+_)) + result ++= "`--"+subList.head :: (subList.tail.map(prefix+_)) } result.toList