Stores the configuration of indices we will use for code generation.
*/
-object LoopTree {
- def apply(root: IterationSpace) = {
- val base = new LoopTree(None)
+object LoopNest {
+ def apply(root: IterationSpace) : LoopNest = {
+ val nest = new LoopNest
val sortedSpaces = IterationSpace.flattenPostorder(root)
val sortedIndices = Index.sort(sortedSpaces flatMap (_.getIndices))
for(space <- sortedSpaces) {
val indices = space.getIndices
val localSortedIndices = sortedIndices filter (indices.contains(_))
- base.addIterationSpace(localSortedIndices, space)
+ nest.addIterationSpace(localSortedIndices, space)
+ }
+
+ nest
+ }
+}
+
+class LoopNest {
+ val base = new LoopTree(None)
+ val nameManager = new NameManager()
+ var declarations = Set[String]()
+
+ def addIterationSpace(indices: List[Index], space: IterationSpace) = base.addIterationSpace(indices, space)
+ def getTree = base
+
+ def generateCode : String = {
+ val code = new StringBuilder()
+ declarations = base.collectDeclarations(nameManager)
+ code append declarations.mkString("\n")
+
+ val generationVisitor = new GenerationVisitor
+ base.accept(generationVisitor)
+ code append generationVisitor.getCode
+
+ code.mkString
+ }
+
+ class GenerationVisitor extends LoopTreeVisitor {
+ val code = new StringBuilder()
+
+ def enterTree(tree: LoopTree) {
+ tree.getLocalIndex match {
+ case None =>
+ case Some(index) => code append index.generateIterationHeader(nameManager)+"\n"
+ }
+ }
+
+ def exitTree(tree: LoopTree) {
+ tree.getLocalIndex match {
+ case None =>
+ case Some(index) => code append index.generateIterationFooter(nameManager)+"\n"
+ }
+ }
+
+ def visitSpace(space: IterationSpace) {
}
- println(base)
- base
+ def getCode = code.mkString
}
}
-case class LoopTree private(localIndex: Option[Index]) {
+trait LoopTreeVisitor {
+ def enterTree(tree: LoopTree)
+ def exitTree(tree: LoopTree)
+ def visitSpace(space: IterationSpace)
+}
+
+object LoopTree {
+ def collectSpaceDeclarations(term: IterationSpace, nameManager: NameManager) : Set[String] = {
+ val declarations = for(index <- term.getIndices;
+ declaration <- index.getDeclarations(nameManager)) yield declaration
+
+ var declarationsSet = declarations.toSet
+ for (op <- term.getOperands) declarationsSet ++= collectSpaceDeclarations(op, nameManager)
+ declarationsSet
+ }
+}
+
+case class LoopTree private[onetep](localIndex: Option[Index]) {
var subItems = ArrayBuffer[Either[IterationSpace, LoopTree]]()
def contains(space: IterationSpace) : Boolean = {
found
}
- private def addIterationSpace(indices: List[Index], space: IterationSpace) {
+ def accept(visitor: LoopTreeVisitor) {
+ for (item <- subItems)
+ item match {
+ case Left(space) => visitor.visitSpace(space)
+ case Right(tree) => {visitor.enterTree(tree); tree.accept(visitor); visitor.exitTree(tree)}
+ }
+ }
+
+ def collectDeclarations(nameManager: NameManager) : Set[String] = {
+ val result = collection.mutable.Set[String]()
+
+ for(item <- subItems)
+ item match {
+ case Left(space) => result ++= LoopTree.collectSpaceDeclarations(space, nameManager)
+ case Right(tree) => result ++= tree.collectDeclarations(nameManager)
+ }
+
+ result.toSet
+ }
+
+ def addIterationSpace(indices: List[Index], space: IterationSpace) {
indices match {
case Nil => subItems += Left(space)
case (head :: tail) => getEndLoop(head).addIterationSpace(tail, space)
}
}
- private def getLocalIndex = localIndex
+ def getLocalIndex = localIndex
private def getLoopIndices(space: IterationSpace) : List[Index] = space.getIndices.toList
private def toStrings : List[String] = {
val result = ArrayBuffer[String]("Index: " + localIndex)
- for (entry <- subItems) {
- val subList = (entry match {
+ for (entryID <- 0 until subItems.size) {
+ val subList = (subItems(entryID) match {
case Left(space) => List(space.toString)
case Right(tree) => tree.toStrings
})
- result ++= "|--"+subList.head :: (subList.tail.map("| "+_))
+ val prefix = if (entryID < subItems.size-1) "| " else " "
+ result ++= "|--"+subList.head :: (subList.tail.map(prefix+_))
}
result.toList
}
}
+
+
trait Index {
def getName : String
def getDependencies : Set[Index]
- def getDenseWidth : String
+ def getDenseWidth(names: NameManager) : String
def getDensePosition(names: NameManager) : String = names(this)
def generateIterationHeader(names: NameManager) : String
def generateIterationFooter(names: NameManager) : String
trait FunctionSet extends DataSpace
class Assignment(indexBindings: IndexBindings, lhs: DataSpace, rhs: IterationSpace) extends IterationSpace {
- override def toString = indexBindings.toString
def getIndexBindings = indexBindings
def getOperands = List(lhs,rhs)
def getSpatialIndices = Nil
class DenseSpatialIndex(parent: GeneralInnerProduct, original: SpatialIndex) extends SpatialIndex{
def getDependencies = Set()
def getName = "dense_spatial_index"
- def getDenseWidth = original.getDenseWidth
- def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth
+ def getDenseWidth(names: NameManager) = original.getDenseWidth(names)
+ def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth(names)
def generateIterationFooter(names: NameManager) = "end do"
def getDeclarations(names: NameManager) = List("integer :: "+names(this))
}
class DenseDiscreteIndex(parent: GeneralInnerProduct, original: DiscreteIndex) extends DiscreteIndex {
def getDependencies = Set()
def getName = "dense_discrete_index"
- def getDenseWidth = original.getDenseWidth
- def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth
+ def getDenseWidth(names: NameManager) = original.getDenseWidth(names)
+ def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth(names)
def generateIterationFooter(names: NameManager) = "end do"
def getDeclarations(names: NameManager) = List("integer :: "+names(this))
}
class BlockIndex(parent: Reciprocal, dimension: Int, original: SpatialIndex) extends SpatialIndex {
def getName = "reciprocal_index_" + dimension
def getDependencies = Set()
- def getDenseWidth = original.getDenseWidth
- def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth
+ def getDenseWidth(names: NameManager) = original.getDenseWidth(names)
+ def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth(names)
def generateIterationFooter(names: NameManager) = "end do"
def getDeclarations(names: NameManager) = List("integer :: "+names(this))
}
class RestrictedIndex(parent: SpatialRestriction, dimension: Int) extends SpatialIndex {
def getName = "restriction_index_" + dimension
def getDependencies = Set()
- def getDenseWidth = throw new UnimplementedException("Restriction unimplemnted")
+ def getDenseWidth(names: NameManager) = "pub_fftbox%total_pt"+(dimension+1)
- def generateIterationHeader(names: NameManager) = throw new UnimplementedException("how the hell does this work?")
- def generateIterationFooter(names: NameManager) = throw new UnimplementedException("how does this work either?")
+ def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth(names)
+ def generateIterationFooter(names: NameManager) = "end do"
def getDeclarations(names: NameManager) = Nil
}
class SPAM3(name : String) extends Matrix {
override def toString = name
+ def getName = name
class RowIndex(parent: SPAM3) extends DiscreteIndex {
override def toString = parent + ".row"
def getName = "row_index"
def getDependencies = Set()
- def getDenseWidth = throw new UnimplementedException("Matrix size unimplemented")
+ def getDenseWidth(names: NameManager) = "sparse_num_rows("+parent.getName+")"
def generateIterationHeader(names: NameManager) = {
val indexName = names(this)
- "do "+indexName+"=1,"+getDenseWidth
+ "do "+indexName+"=1,"+getDenseWidth(names)
}
def generateIterationFooter(names: NameManager) = "end do"
override def toString = parent + ".col"
def getName = "row_index"
def getDependencies = Set()
- def getDenseWidth = throw new UnimplementedException("Matrix size unimplemented")
+ def getDenseWidth(names: NameManager) = "sparse_num_cols("+parent.getName+")"
def generateIterationHeader(names: NameManager) = {
val indexName = names(this)
- "do "+indexName+"=1,"+getDenseWidth
+ "do "+indexName+"=1,"+getDenseWidth(names)
}
class SphereIndex(parent: PPDFunctionSet) extends DiscreteIndex {
def getName = "sphere_index"
def getDependencies = Set()
- def getDenseWidth = throw new UnimplementedException("Sphere count unimplemented")
+ def getDenseWidth(names: NameManager) = parent.getNumSpheres(names)
def generateIterationHeader(names: NameManager) = {
val indexName = names(this)
- "do "+indexName+"=1,"+getDenseWidth
+ "do "+indexName+"=1," + parent.getSphere(names) + "%n_ppds_sphere"
}
def generateIterationFooter(names: NameManager) = "end do"
def getDensePPDIndices = denseIndexNames
//TODO: def getDenseWidth = parent.getSphereIndex.getSphere + "%n_ppds_sphere"
- def getDenseWidth = parent.basis+"%max_n_ppds_sphere"
+ def getDenseWidth(names: NameManager) = parent.basis+"%max_n_ppds_sphere"
def generateIterationHeader(names: NameManager) = {
class IntraPPDIndex(parent: PPDFunctionSet, dimension: Int) extends SpatialIndex {
def getName = "intra_ppd_index_" + dimension
def getDependencies = Set[Index](parent.getPPDIndex)
- def getDenseWidth = "pub_cell%total_pt"+(dimension+1)
+ def getDenseWidth(names: NameManager) = "pub_cell%total_pt"+(dimension+1)
def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+"pub_cell%n_pt"+(dimension+1)
def generateIterationFooter(names: NameManager) = "end do"
def getDeclarations(names: NameManager) = List("integer :: "+names(this))
def getSphereIndex = sphereIndex
def getSphere(names: NameManager) = basis + "%spheres("+names(getSphereIndex)+")"
+ def getNumSpheres(names: NameManager) = {
+ // TODO: This number is dependent on the parallel distribution
+ basis + "%node_num"
+ }
+
def getSpatialIndices = spatialIndices.toList
def getDiscreteIndices = List(getSphereIndex)
def getExternalIndices = Set(getPPDIndex)