object LoopNest {
def apply(root: IterationSpace) : LoopNest = {
- val nest = new LoopNest
val sortedSpaces = IterationSpace.flattenPostorder(root)
val sortedIndices = Index.sort(sortedSpaces flatMap (_.getIndices))
+ val nest = new LoopNest(sortedIndices.toList)
for(space <- sortedSpaces) {
val indices = space.getIndices
}
}
-class LoopNest {
+class LoopNest(sortedIndices: List[Index]) {
val base = new LoopTree(None)
val nameManager = new NameManager()
var declarations = Set[String]()
+ var spaceFragmentsInfo = Map[IterationSpace, FragmentsInfo]()
+
+ class FragmentsInfo(val consumer: Option[Fragment], val transform: Option[Fragment], val producer: Option[Fragment])
def addIterationSpace(indices: List[Index], space: IterationSpace) {
- base.addIterationSpace(indices, space)
+ val operands = space.getOperands
+ val fragmentDepends = for(operand <- operands; fragment <- spaceFragmentsInfo(operand).producer) yield fragment
+
+ val consumer = space.getConsumerGenerator match {
+ case Some(consumerGenerator) => Some(new ConsumerFragment(space, fragmentDepends.toSet))
+ case None => None
+ }
+
+ val transform = space.getTransformGenerator match {
+ case Some(transformGenerator) => Some(new TransformFragment(space, consumer.toSet))
+ case None => None
+ }
+
+ val producer = space.getProducerGenerator match {
+ case Some(producerGenerator) => Some(new ProducerFragment(space, transform.toSet ++ consumer.toSet))
+ case None => None
+ }
+
+ val fragmentsInfo = new FragmentsInfo(consumer, transform, producer)
+ spaceFragmentsInfo += (space -> fragmentsInfo)
+
+ val consumerIndices = getSortedIndices((for (op <- operands; index <- op.getIndices) yield index).toSet)
+ consumer match {
+ case Some(fragment) => base.addFragment(consumerIndices, fragment)
+ case None =>
+ }
+
+ val transformIndices = getSortedIndices(consumerIndices.toSet & space.getIndices.toSet)
+ transform match {
+ case Some(fragment) => base.addFragment(transformIndices, fragment)
+ case None =>
+ }
+
+ val producerIndices = getSortedIndices(space.getIndices.toSet)
+ producer match {
+ case Some(fragment) => base.addFragment(producerIndices, fragment)
+ case None =>
+ }
+
base.fuse()
}
+ private def getSortedIndices(indices: Set[Index]) = sortedIndices filter (indices.contains(_))
+
def getTree = base
def generateCode : String = {
}
}
- def visitSpace(space: IterationSpace) {
+ def visitFragment(fragment: Fragment) {
}
def getCode = code.mkString
}
}
+trait Fragment {
+ def getAllFragments : Set[Fragment]
+ def getDependencies : Set[Fragment]
+ def collectDeclarations(nameManager: NameManager) : Set[String]
+ def accept(visitor: LoopTreeVisitor) : Unit
+}
+
+class ConsumerFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment {
+ def getAllFragments = Set(this)
+ def getDependencies = dependencies
+ def collectDeclarations(nameManager: NameManager) = Set[String]()
+ def accept(visitor: LoopTreeVisitor) = visitor.visitFragment(this)
+ override def toString = "Consumer: " + parent.toString
+}
+
+class ProducerFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment {
+ def getAllFragments = Set(this)
+ def getDependencies = dependencies
+ def collectDeclarations(nameManager: NameManager) = Set[String]()
+ def accept(visitor: LoopTreeVisitor) = visitor.visitFragment(this)
+ override def toString = "Producer: " + parent.toString
+}
+
+class TransformFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment {
+ def getAllFragments = Set(this)
+ def getDependencies = dependencies
+ def collectDeclarations(nameManager: NameManager) = Set[String]()
+ def accept(visitor: LoopTreeVisitor) = visitor.visitFragment(this)
+ override def toString = "Transform: " + parent.toString
+}
+
trait LoopTreeVisitor {
def enterTree(tree: LoopTree)
def exitTree(tree: LoopTree)
- def visitSpace(space: IterationSpace)
+ def visitFragment(space: Fragment)
}
object LoopTree {
}
}
-class LoopTree private[onetep](localIndex: Option[Index]) {
- var subItems = ArrayBuffer[Either[IterationSpace, LoopTree]]()
-
- private def contains(space: IterationSpace, deep: Boolean) : Boolean = {
- var found = false
- for (item <- subItems)
- found |= (item match {
- case (item: LoopTree) => deep && item.contains(space, deep)
- case (item: IterationSpace) => item == space
- })
-
- found
- }
-
- def containsShallow(space: IterationSpace) : Boolean = contains(space, false)
- def containsDeep(space: IterationSpace) : Boolean = contains(space, true)
+class LoopTree private[onetep](localIndex: Option[Index]) extends Fragment {
+ var subItems = ArrayBuffer[Fragment]()
def accept(visitor: LoopTreeVisitor) {
+ visitor.enterTree(this)
+
for (item <- subItems)
- item match {
- case Left(space) => visitor.visitSpace(space)
- case Right(tree) => {visitor.enterTree(tree); tree.accept(visitor); visitor.exitTree(tree)}
- }
+ item.accept(visitor)
+
+ visitor.exitTree(this)
}
def +(b: LoopTree) : LoopTree =
val result = collection.mutable.Set[String]()
for(item <- subItems)
- item match {
- case Left(space) => result ++= LoopTree.collectSpaceDeclarations(space, nameManager)
- case Right(tree) => result ++= tree.collectDeclarations(nameManager)
- }
+ result ++= item.collectDeclarations(nameManager)
result.toSet
}
- def getDependencies : Set[IterationSpace] = {
- val dependencies = collection.mutable.Set[IterationSpace]()
+ def getDependencies : Set[Fragment] = {
+ val dependencies = collection.mutable.Set[Fragment]()
for (item <- subItems)
- item match {
- case Left(space) => dependencies ++= space.getDependencies
- case Right(tree) => dependencies ++= tree.getDependencies
- }
-
+ dependencies ++= item.getDependencies
+
dependencies.toSet
}
- def getSpaces : Set[IterationSpace] = {
- val spaces = collection.mutable.Set[IterationSpace]()
+ def getAllFragments : Set[Fragment] = {
+ val fragments = collection.mutable.Set[Fragment]()
+ fragments += this
for (item <- subItems)
- item match {
- case Left(space) => spaces += space
- case Right(tree) => spaces ++= tree.getSpaces
- }
+ fragments ++= item.getAllFragments
- spaces.toSet
+ fragments.toSet
}
- def addIterationSpace(indices: List[Index], space: IterationSpace) {
+ def addFragment(indices: List[Index], fragment: Fragment) {
indices match {
- case Nil => subItems += Left(space)
+ case Nil => subItems += fragment
case (head :: tail) => {
val tree = new LoopTree(Some(head))
- subItems += Right(tree)
- tree.addIterationSpace(tail, space)
+ subItems += tree
+ tree.addFragment(tail, fragment)
}
}
}
val trees = collection.mutable.Set[LoopTree]()
for (item <- subItems)
item match {
- case Right(tree) => trees += tree
+ case (tree: LoopTree) => trees += tree
case _ =>
}
- subItems --= trees map (Right(_))
+ subItems --= trees
def attemptFusion(loops: collection.mutable.Set[LoopTree]) : Boolean = {
for(a <- loops)
while(attemptFusion(trees)) {}
trees map (_.fuse())
- subItems ++= trees map (Right(_))
+ subItems ++= trees
sort()
}
def sort() {
- def compareItems(before: Either[IterationSpace, LoopTree], after: Either[IterationSpace, LoopTree]) : Boolean =
- (before, after) match {
- case (Left(space1), Left(space2)) => space2.getDependencies.contains(space1)
- case (Right(tree1), Right(tree2)) => (tree2.getDependencies & tree1.getSpaces).nonEmpty
- case (Left(space), Right(tree)) => tree.getDependencies.contains(space)
- case (Right(tree), Left(space)) => (space.getDependencies & tree.getSpaces).nonEmpty
- }
+ def compareItems(before: Fragment, after: Fragment) : Boolean =
+ (after.getDependencies & before.getAllFragments).nonEmpty
subItems = subItems.sortWith(compareItems(_, _))
}
for (entryID <- 0 until subItems.size) {
val subList = (subItems(entryID) match {
- case Left(space) => List(space.toString)
- case Right(tree) => tree.toStrings
+ case (tree: LoopTree) => tree.toStrings
+ case x => List(x.toString)
})
val subListHeadPrefix = if (entryID < subItems.size-1) "|--" else "`--"
term.getOperands.toTraversable.flatMap(flattenPostorder(_)) ++ List(term)
}
+trait ConsumerGenerator {
+ def generate(names: NameManager, indices: Map[Index,String], values : Map[IterationSpace, String]) : String
+}
+
+trait ProducerGenerator {
+ def generate(names: NameManager) : String
+}
+
+trait TransformGenerator {
+ def generate(names: NameManager) : String
+}
+
trait IterationSpace {
- def getAccessExpression(indexNames: NameManager) : String
def getOperands : List[IterationSpace]
def getSpatialIndices : List[SpatialIndex]
def getDiscreteIndices : List[DiscreteIndex]
val operands = getOperands
operands.toSet ++ operands.flatMap(_.getDependencies)
}
+
+ // Code generation
+ def getConsumerGenerator : Option[ConsumerGenerator]
+ def getTransformGenerator : Option[TransformGenerator]
+ def getProducerGenerator : Option[ProducerGenerator]
}
trait DataSpace extends IterationSpace {
def getOperands = Nil
+ def getConsumerGenerator = None
+ def getTransformGenerator = None
}
trait Matrix extends DataSpace
class Assignment(indexBindings: IndexBindings, lhs: DataSpace, rhs: IterationSpace) extends IterationSpace {
def getIndexBindings = indexBindings
- def getOperands = List(lhs,rhs)
- def getSpatialIndices = Nil
- def getDiscreteIndices = Nil
+ def getOperands = List(rhs)
+ def getSpatialIndices = lhs.getSpatialIndices
+ def getDiscreteIndices = lhs.getDiscreteIndices
def getExternalIndices = Set()
- def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
+
+ def getConsumerGenerator = None
+ def getTransformGenerator = None
+ def getProducerGenerator = None
}
-class Scalar(value: Double) extends IterationSpace {
- def getOperands = Nil
+class Scalar(value: Double) extends DataSpace {
def getSpatialIndices = Nil
def getDiscreteIndices = Nil
def getExternalIndices = Set()
- def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
+ def getProducerGenerator = Some(new ProducerGenerator {
+ def generate(names: NameManager) = value.toString
+ })
}
class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[Index]) extends IterationSpace {
def getSpatialIndices = spatialIndices
def getDiscreteIndices = discreteIndices
def getExternalIndices = Set()
- def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
+
+ def getConsumerGenerator = None
+ def getTransformGenerator = None
+ def getProducerGenerator = None
}
class Reciprocal(op: IterationSpace) extends IterationSpace {
def getSpatialIndices = spatialIndices.toList
def getDiscreteIndices = op.getDiscreteIndices
def getExternalIndices = Set()
- def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
+
+ def getConsumerGenerator = None
+ def getTransformGenerator = None
+ def getProducerGenerator = None
}
class Laplacian(op: IterationSpace) extends IterationSpace {
def getSpatialIndices = op.getSpatialIndices
def getDiscreteIndices = op.getDiscreteIndices
def getExternalIndices = Set()
- def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
+
+ def getConsumerGenerator = None
+ def getTransformGenerator = None
+ def getProducerGenerator = None
}
class SpatialRestriction(op: IterationSpace) extends IterationSpace {
def getSpatialIndices = spatialIndices.toList
def getDiscreteIndices = op.getDiscreteIndices
def getExternalIndices = Set()
- def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
+
+ def getConsumerGenerator = None
+ def getTransformGenerator = None
+ def getProducerGenerator = None
}
class SPAM3(name : String) extends Matrix {
def getSpatialIndices = Nil
def getDiscreteIndices = List(rowIndex, colIndex)
def getExternalIndices = Set()
- def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed")
+
+ def getProducerGenerator = None
}
class PPDFunctionSet(val basis : String, data : String) extends FunctionSet {
def getDenseWidth(names: NameManager) = parent.basis+"%max_n_ppds_sphere"
def generateIterationHeader(names: NameManager) = {
-
val initDense = "call basis_find_ppd_in_neighbour(" + denseIndexNames.mkString(",") + ", &\n" +
parent.getSphere(names) + "%ppd_list(1," + names(this) + "), &\n" +
parent.getSphere(names) + "%ppd_list(2," + names(this) + "), &\n" +
def getDiscreteIndices = List(getSphereIndex)
def getExternalIndices = Set(getPPDIndex)
- def getAccessExpression(indexNames: NameManager) = {
- val index = getSphere(indexNames)+"%offset + &\n" +
- "("+indexNames(getPPDIndex)+"-1)*pub_cell%n_pts - 1 + &\n" +
- "(" + indexNames(spatialIndices(2)) + "-1)*pub_cell%n_pt2*pub_cell%n_pt1 + &\n" +
- "(" + indexNames(spatialIndices(1)) + "-1)*pub_cell%n_pt1 + &\n" +
- indexNames(spatialIndices(0))
+ def getProducerGenerator = Some(new ProducerGenerator {
+ def generate(names: NameManager) = {
+ val offset = getSphere(names)+"%offset + &\n" +
+ "("+names(getPPDIndex)+"-1)*pub_cell%n_pts - 1 + &\n" +
+ "(" + names(spatialIndices(2)) + "-1)*pub_cell%n_pt2*pub_cell%n_pt1 + &\n" +
+ "(" + names(spatialIndices(1)) + "-1)*pub_cell%n_pt1 + &\n" +
+ names(spatialIndices(0))
- data+"("+index+")"
- }
+ data+"("+offset+")"
+ }
+ })
}
class BindingIndex(name : String) {