From ea03036ee45828f15f116ba024bafb18f13353a0 Mon Sep 17 00:00:00 2001 From: Francis Russell Date: Fri, 3 Feb 2012 08:41:45 +0000 Subject: [PATCH] Generate loop headers and footers. --- src/ofc/generators/onetep/CodeGenerator.scala | 4 +- src/ofc/generators/onetep/LoopTree.scala | 107 ++++++++++++++++-- src/ofc/generators/onetep/Tree.scala | 43 +++---- 3 files changed, 122 insertions(+), 32 deletions(-) diff --git a/src/ofc/generators/onetep/CodeGenerator.scala b/src/ofc/generators/onetep/CodeGenerator.scala index 26bb459..807061e 100644 --- a/src/ofc/generators/onetep/CodeGenerator.scala +++ b/src/ofc/generators/onetep/CodeGenerator.scala @@ -56,7 +56,9 @@ class CodeGenerator { println(i) println("finished dumping indices") - val loopTree = LoopTree(space) + val loopNest = LoopNest(space) + println(loopNest.getTree) + println(loopNest.generateCode) // Next: we dump all these things into a prefix map System.exit(0) diff --git a/src/ofc/generators/onetep/LoopTree.scala b/src/ofc/generators/onetep/LoopTree.scala index c711a5a..dcd4beb 100644 --- a/src/ofc/generators/onetep/LoopTree.scala +++ b/src/ofc/generators/onetep/LoopTree.scala @@ -6,24 +6,84 @@ import scala.collection.mutable.ArrayBuffer Stores the configuration of indices we will use for code generation. */ -object LoopTree { - def apply(root: IterationSpace) = { - val base = new LoopTree(None) +object LoopNest { + def apply(root: IterationSpace) : LoopNest = { + val nest = new LoopNest val sortedSpaces = IterationSpace.flattenPostorder(root) val sortedIndices = Index.sort(sortedSpaces flatMap (_.getIndices)) for(space <- sortedSpaces) { val indices = space.getIndices val localSortedIndices = sortedIndices filter (indices.contains(_)) - base.addIterationSpace(localSortedIndices, space) + nest.addIterationSpace(localSortedIndices, space) + } + + nest + } +} + +class LoopNest { + val base = new LoopTree(None) + val nameManager = new NameManager() + var declarations = Set[String]() + + def addIterationSpace(indices: List[Index], space: IterationSpace) = base.addIterationSpace(indices, space) + def getTree = base + + def generateCode : String = { + val code = new StringBuilder() + declarations = base.collectDeclarations(nameManager) + code append declarations.mkString("\n") + + val generationVisitor = new GenerationVisitor + base.accept(generationVisitor) + code append generationVisitor.getCode + + code.mkString + } + + class GenerationVisitor extends LoopTreeVisitor { + val code = new StringBuilder() + + def enterTree(tree: LoopTree) { + tree.getLocalIndex match { + case None => + case Some(index) => code append index.generateIterationHeader(nameManager)+"\n" + } + } + + def exitTree(tree: LoopTree) { + tree.getLocalIndex match { + case None => + case Some(index) => code append index.generateIterationFooter(nameManager)+"\n" + } + } + + def visitSpace(space: IterationSpace) { } - println(base) - base + def getCode = code.mkString } } -case class LoopTree private(localIndex: Option[Index]) { +trait LoopTreeVisitor { + def enterTree(tree: LoopTree) + def exitTree(tree: LoopTree) + def visitSpace(space: IterationSpace) +} + +object LoopTree { + def collectSpaceDeclarations(term: IterationSpace, nameManager: NameManager) : Set[String] = { + val declarations = for(index <- term.getIndices; + declaration <- index.getDeclarations(nameManager)) yield declaration + + var declarationsSet = declarations.toSet + for (op <- term.getOperands) declarationsSet ++= collectSpaceDeclarations(op, nameManager) + declarationsSet + } +} + +case class LoopTree private[onetep](localIndex: Option[Index]) { var subItems = ArrayBuffer[Either[IterationSpace, LoopTree]]() def contains(space: IterationSpace) : Boolean = { @@ -37,7 +97,27 @@ case class LoopTree private(localIndex: Option[Index]) { found } - private def addIterationSpace(indices: List[Index], space: IterationSpace) { + def accept(visitor: LoopTreeVisitor) { + for (item <- subItems) + item match { + case Left(space) => visitor.visitSpace(space) + case Right(tree) => {visitor.enterTree(tree); tree.accept(visitor); visitor.exitTree(tree)} + } + } + + def collectDeclarations(nameManager: NameManager) : Set[String] = { + val result = collection.mutable.Set[String]() + + for(item <- subItems) + item match { + case Left(space) => result ++= LoopTree.collectSpaceDeclarations(space, nameManager) + case Right(tree) => result ++= tree.collectDeclarations(nameManager) + } + + result.toSet + } + + def addIterationSpace(indices: List[Index], space: IterationSpace) { indices match { case Nil => subItems += Left(space) case (head :: tail) => getEndLoop(head).addIterationSpace(tail, space) @@ -60,7 +140,7 @@ case class LoopTree private(localIndex: Option[Index]) { } } - private def getLocalIndex = localIndex + def getLocalIndex = localIndex private def getLoopIndices(space: IterationSpace) : List[Index] = space.getIndices.toList @@ -74,15 +154,18 @@ case class LoopTree private(localIndex: Option[Index]) { private def toStrings : List[String] = { val result = ArrayBuffer[String]("Index: " + localIndex) - for (entry <- subItems) { - val subList = (entry match { + for (entryID <- 0 until subItems.size) { + val subList = (subItems(entryID) match { case Left(space) => List(space.toString) case Right(tree) => tree.toStrings }) - result ++= "|--"+subList.head :: (subList.tail.map("| "+_)) + val prefix = if (entryID < subItems.size-1) "| " else " " + result ++= "|--"+subList.head :: (subList.tail.map(prefix+_)) } result.toList } } + + diff --git a/src/ofc/generators/onetep/Tree.scala b/src/ofc/generators/onetep/Tree.scala index 1a624c1..0b7d3ab 100644 --- a/src/ofc/generators/onetep/Tree.scala +++ b/src/ofc/generators/onetep/Tree.scala @@ -20,7 +20,7 @@ object Index { trait Index { def getName : String def getDependencies : Set[Index] - def getDenseWidth : String + def getDenseWidth(names: NameManager) : String def getDensePosition(names: NameManager) : String = names(this) def generateIterationHeader(names: NameManager) : String def generateIterationFooter(names: NameManager) : String @@ -64,7 +64,6 @@ trait Matrix extends DataSpace trait FunctionSet extends DataSpace class Assignment(indexBindings: IndexBindings, lhs: DataSpace, rhs: IterationSpace) extends IterationSpace { - override def toString = indexBindings.toString def getIndexBindings = indexBindings def getOperands = List(lhs,rhs) def getSpatialIndices = Nil @@ -86,8 +85,8 @@ class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[In class DenseSpatialIndex(parent: GeneralInnerProduct, original: SpatialIndex) extends SpatialIndex{ def getDependencies = Set() def getName = "dense_spatial_index" - def getDenseWidth = original.getDenseWidth - def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth + def getDenseWidth(names: NameManager) = original.getDenseWidth(names) + def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth(names) def generateIterationFooter(names: NameManager) = "end do" def getDeclarations(names: NameManager) = List("integer :: "+names(this)) } @@ -95,8 +94,8 @@ class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[In class DenseDiscreteIndex(parent: GeneralInnerProduct, original: DiscreteIndex) extends DiscreteIndex { def getDependencies = Set() def getName = "dense_discrete_index" - def getDenseWidth = original.getDenseWidth - def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth + def getDenseWidth(names: NameManager) = original.getDenseWidth(names) + def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth(names) def generateIterationFooter(names: NameManager) = "end do" def getDeclarations(names: NameManager) = List("integer :: "+names(this)) } @@ -126,8 +125,8 @@ class Reciprocal(op: IterationSpace) extends IterationSpace { class BlockIndex(parent: Reciprocal, dimension: Int, original: SpatialIndex) extends SpatialIndex { def getName = "reciprocal_index_" + dimension def getDependencies = Set() - def getDenseWidth = original.getDenseWidth - def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth + def getDenseWidth(names: NameManager) = original.getDenseWidth(names) + def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth(names) def generateIterationFooter(names: NameManager) = "end do" def getDeclarations(names: NameManager) = List("integer :: "+names(this)) } @@ -154,10 +153,10 @@ class SpatialRestriction(op: IterationSpace) extends IterationSpace { class RestrictedIndex(parent: SpatialRestriction, dimension: Int) extends SpatialIndex { def getName = "restriction_index_" + dimension def getDependencies = Set() - def getDenseWidth = throw new UnimplementedException("Restriction unimplemnted") + def getDenseWidth(names: NameManager) = "pub_fftbox%total_pt"+(dimension+1) - def generateIterationHeader(names: NameManager) = throw new UnimplementedException("how the hell does this work?") - def generateIterationFooter(names: NameManager) = throw new UnimplementedException("how does this work either?") + def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth(names) + def generateIterationFooter(names: NameManager) = "end do" def getDeclarations(names: NameManager) = Nil } @@ -172,16 +171,17 @@ class SpatialRestriction(op: IterationSpace) extends IterationSpace { class SPAM3(name : String) extends Matrix { override def toString = name + def getName = name class RowIndex(parent: SPAM3) extends DiscreteIndex { override def toString = parent + ".row" def getName = "row_index" def getDependencies = Set() - def getDenseWidth = throw new UnimplementedException("Matrix size unimplemented") + def getDenseWidth(names: NameManager) = "sparse_num_rows("+parent.getName+")" def generateIterationHeader(names: NameManager) = { val indexName = names(this) - "do "+indexName+"=1,"+getDenseWidth + "do "+indexName+"=1,"+getDenseWidth(names) } def generateIterationFooter(names: NameManager) = "end do" @@ -192,11 +192,11 @@ class SPAM3(name : String) extends Matrix { override def toString = parent + ".col" def getName = "row_index" def getDependencies = Set() - def getDenseWidth = throw new UnimplementedException("Matrix size unimplemented") + def getDenseWidth(names: NameManager) = "sparse_num_cols("+parent.getName+")" def generateIterationHeader(names: NameManager) = { val indexName = names(this) - "do "+indexName+"=1,"+getDenseWidth + "do "+indexName+"=1,"+getDenseWidth(names) } @@ -217,11 +217,11 @@ class PPDFunctionSet(val basis : String, data : String) extends FunctionSet { class SphereIndex(parent: PPDFunctionSet) extends DiscreteIndex { def getName = "sphere_index" def getDependencies = Set() - def getDenseWidth = throw new UnimplementedException("Sphere count unimplemented") + def getDenseWidth(names: NameManager) = parent.getNumSpheres(names) def generateIterationHeader(names: NameManager) = { val indexName = names(this) - "do "+indexName+"=1,"+getDenseWidth + "do "+indexName+"=1," + parent.getSphere(names) + "%n_ppds_sphere" } def generateIterationFooter(names: NameManager) = "end do" @@ -236,7 +236,7 @@ class PPDFunctionSet(val basis : String, data : String) extends FunctionSet { def getDensePPDIndices = denseIndexNames //TODO: def getDenseWidth = parent.getSphereIndex.getSphere + "%n_ppds_sphere" - def getDenseWidth = parent.basis+"%max_n_ppds_sphere" + def getDenseWidth(names: NameManager) = parent.basis+"%max_n_ppds_sphere" def generateIterationHeader(names: NameManager) = { @@ -261,7 +261,7 @@ class PPDFunctionSet(val basis : String, data : String) extends FunctionSet { class IntraPPDIndex(parent: PPDFunctionSet, dimension: Int) extends SpatialIndex { def getName = "intra_ppd_index_" + dimension def getDependencies = Set[Index](parent.getPPDIndex) - def getDenseWidth = "pub_cell%total_pt"+(dimension+1) + def getDenseWidth(names: NameManager) = "pub_cell%total_pt"+(dimension+1) def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+"pub_cell%n_pt"+(dimension+1) def generateIterationFooter(names: NameManager) = "end do" def getDeclarations(names: NameManager) = List("integer :: "+names(this)) @@ -278,6 +278,11 @@ class PPDFunctionSet(val basis : String, data : String) extends FunctionSet { def getSphereIndex = sphereIndex def getSphere(names: NameManager) = basis + "%spheres("+names(getSphereIndex)+")" + def getNumSpheres(names: NameManager) = { + // TODO: This number is dependent on the parallel distribution + basis + "%node_num" + } + def getSpatialIndices = spatialIndices.toList def getDiscreteIndices = List(getSphereIndex) def getExternalIndices = Set(getPPDIndex) -- 2.47.3