From: Francis Russell Date: Mon, 6 Feb 2012 20:59:11 +0000 (+0000) Subject: Work on producer/consumer model of transforms. X-Git-Url: https://git.unchartedbackwaters.co.uk/w/?a=commitdiff_plain;h=7fc96f0abc6dc2516fd9b2dd62a162411191e87d;p=francis%2Fofc.git Work on producer/consumer model of transforms. --- diff --git a/src/ofc/generators/onetep/CodeGenerator.scala b/src/ofc/generators/onetep/CodeGenerator.scala index f83b6f3..fd6fe9f 100644 --- a/src/ofc/generators/onetep/CodeGenerator.scala +++ b/src/ofc/generators/onetep/CodeGenerator.scala @@ -89,9 +89,9 @@ class CodeGenerator { // We've declared temporary storage, now create the loops to populate it for (index <- concreteIndexList) code append index.generateIterationHeader(nameManager) + "\n" - val lhs = storageName + (concreteIndexList map ((x: Index) => x.getDensePosition(nameManager))).mkString("(",", &\n",")") - val rhs = op.getAccessExpression(nameManager) - code append lhs + " = &\n" + rhs + "\n" + //val lhs = storageName + (concreteIndexList map ((x: Index) => x.getDensePosition(nameManager))).mkString("(",", &\n",")") + //val rhs = op.getAccessExpression(nameManager) + //code append lhs + " = &\n" + rhs + "\n" for (index <- concreteIndexList) code append index.generateIterationFooter(nameManager) + "\n" println(code.mkString) diff --git a/src/ofc/generators/onetep/LoopTree.scala b/src/ofc/generators/onetep/LoopTree.scala index e4edd88..6ce34cf 100644 --- a/src/ofc/generators/onetep/LoopTree.scala +++ b/src/ofc/generators/onetep/LoopTree.scala @@ -9,9 +9,9 @@ Stores the configuration of indices we will use for code generation. object LoopNest { def apply(root: IterationSpace) : LoopNest = { - val nest = new LoopNest val sortedSpaces = IterationSpace.flattenPostorder(root) val sortedIndices = Index.sort(sortedSpaces flatMap (_.getIndices)) + val nest = new LoopNest(sortedIndices.toList) for(space <- sortedSpaces) { val indices = space.getIndices @@ -23,16 +23,59 @@ object LoopNest { } } -class LoopNest { +class LoopNest(sortedIndices: List[Index]) { val base = new LoopTree(None) val nameManager = new NameManager() var declarations = Set[String]() + var spaceFragmentsInfo = Map[IterationSpace, FragmentsInfo]() + + class FragmentsInfo(val consumer: Option[Fragment], val transform: Option[Fragment], val producer: Option[Fragment]) def addIterationSpace(indices: List[Index], space: IterationSpace) { - base.addIterationSpace(indices, space) + val operands = space.getOperands + val fragmentDepends = for(operand <- operands; fragment <- spaceFragmentsInfo(operand).producer) yield fragment + + val consumer = space.getConsumerGenerator match { + case Some(consumerGenerator) => Some(new ConsumerFragment(space, fragmentDepends.toSet)) + case None => None + } + + val transform = space.getTransformGenerator match { + case Some(transformGenerator) => Some(new TransformFragment(space, consumer.toSet)) + case None => None + } + + val producer = space.getProducerGenerator match { + case Some(producerGenerator) => Some(new ProducerFragment(space, transform.toSet ++ consumer.toSet)) + case None => None + } + + val fragmentsInfo = new FragmentsInfo(consumer, transform, producer) + spaceFragmentsInfo += (space -> fragmentsInfo) + + val consumerIndices = getSortedIndices((for (op <- operands; index <- op.getIndices) yield index).toSet) + consumer match { + case Some(fragment) => base.addFragment(consumerIndices, fragment) + case None => + } + + val transformIndices = getSortedIndices(consumerIndices.toSet & space.getIndices.toSet) + transform match { + case Some(fragment) => base.addFragment(transformIndices, fragment) + case None => + } + + val producerIndices = getSortedIndices(space.getIndices.toSet) + producer match { + case Some(fragment) => base.addFragment(producerIndices, fragment) + case None => + } + base.fuse() } + private def getSortedIndices(indices: Set[Index]) = sortedIndices filter (indices.contains(_)) + def getTree = base def generateCode : String = { @@ -64,17 +107,48 @@ class LoopNest { } } - def visitSpace(space: IterationSpace) { + def visitFragment(fragment: Fragment) { } def getCode = code.mkString } } +trait Fragment { + def getAllFragments : Set[Fragment] + def getDependencies : Set[Fragment] + def collectDeclarations(nameManager: NameManager) : Set[String] + def accept(visitor: LoopTreeVisitor) : Unit +} + +class ConsumerFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment { + def getAllFragments = Set(this) + def getDependencies = dependencies + def collectDeclarations(nameManager: NameManager) = Set[String]() + def accept(visitor: LoopTreeVisitor) = visitor.visitFragment(this) + override def toString = "Consumer: " + parent.toString +} + +class ProducerFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment { + def getAllFragments = Set(this) + def getDependencies = dependencies + def collectDeclarations(nameManager: NameManager) = Set[String]() + def accept(visitor: LoopTreeVisitor) = visitor.visitFragment(this) + override def toString = "Producer: " + parent.toString +} + +class TransformFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment { + def getAllFragments = Set(this) + def getDependencies = dependencies + def collectDeclarations(nameManager: NameManager) = Set[String]() + def accept(visitor: LoopTreeVisitor) = visitor.visitFragment(this) + override def toString = "Transform: " + parent.toString +} + trait LoopTreeVisitor { def enterTree(tree: LoopTree) def exitTree(tree: LoopTree) - def visitSpace(space: IterationSpace) + def visitFragment(space: Fragment) } object LoopTree { @@ -97,29 +171,16 @@ object LoopTree { } } -class LoopTree private[onetep](localIndex: Option[Index]) { - var subItems = ArrayBuffer[Either[IterationSpace, LoopTree]]() - - private def contains(space: IterationSpace, deep: Boolean) : Boolean = { - var found = false - for (item <- subItems) - found |= (item match { - case (item: LoopTree) => deep && item.contains(space, deep) - case (item: IterationSpace) => item == space - }) - - found - } - - def containsShallow(space: IterationSpace) : Boolean = contains(space, false) - def containsDeep(space: IterationSpace) : Boolean = contains(space, true) +class LoopTree private[onetep](localIndex: Option[Index]) extends Fragment { + var subItems = ArrayBuffer[Fragment]() def accept(visitor: LoopTreeVisitor) { + visitor.enterTree(this) + for (item <- subItems) - item match { - case Left(space) => visitor.visitSpace(space) - case Right(tree) => {visitor.enterTree(tree); tree.accept(visitor); visitor.exitTree(tree)} - } + item.accept(visitor) + + visitor.exitTree(this) } def +(b: LoopTree) : LoopTree = @@ -135,45 +196,37 @@ class LoopTree private[onetep](localIndex: Option[Index]) { val result = collection.mutable.Set[String]() for(item <- subItems) - item match { - case Left(space) => result ++= LoopTree.collectSpaceDeclarations(space, nameManager) - case Right(tree) => result ++= tree.collectDeclarations(nameManager) - } + result ++= item.collectDeclarations(nameManager) result.toSet } - def getDependencies : Set[IterationSpace] = { - val dependencies = collection.mutable.Set[IterationSpace]() + def getDependencies : Set[Fragment] = { + val dependencies = collection.mutable.Set[Fragment]() for (item <- subItems) - item match { - case Left(space) => dependencies ++= space.getDependencies - case Right(tree) => dependencies ++= tree.getDependencies - } - + dependencies ++= item.getDependencies + dependencies.toSet } - def getSpaces : Set[IterationSpace] = { - val spaces = collection.mutable.Set[IterationSpace]() + def getAllFragments : Set[Fragment] = { + val fragments = collection.mutable.Set[Fragment]() + fragments += this for (item <- subItems) - item match { - case Left(space) => spaces += space - case Right(tree) => spaces ++= tree.getSpaces - } + fragments ++= item.getAllFragments - spaces.toSet + fragments.toSet } - def addIterationSpace(indices: List[Index], space: IterationSpace) { + def addFragment(indices: List[Index], fragment: Fragment) { indices match { - case Nil => subItems += Left(space) + case Nil => subItems += fragment case (head :: tail) => { val tree = new LoopTree(Some(head)) - subItems += Right(tree) - tree.addIterationSpace(tail, space) + subItems += tree + tree.addFragment(tail, fragment) } } } @@ -182,11 +235,11 @@ class LoopTree private[onetep](localIndex: Option[Index]) { val trees = collection.mutable.Set[LoopTree]() for (item <- subItems) item match { - case Right(tree) => trees += tree + case (tree: LoopTree) => trees += tree case _ => } - subItems --= trees map (Right(_)) + subItems --= trees def attemptFusion(loops: collection.mutable.Set[LoopTree]) : Boolean = { for(a <- loops) @@ -200,19 +253,14 @@ class LoopTree private[onetep](localIndex: Option[Index]) { while(attemptFusion(trees)) {} trees map (_.fuse()) - subItems ++= trees map (Right(_)) + subItems ++= trees sort() } def sort() { - def compareItems(before: Either[IterationSpace, LoopTree], after: Either[IterationSpace, LoopTree]) : Boolean = - (before, after) match { - case (Left(space1), Left(space2)) => space2.getDependencies.contains(space1) - case (Right(tree1), Right(tree2)) => (tree2.getDependencies & tree1.getSpaces).nonEmpty - case (Left(space), Right(tree)) => tree.getDependencies.contains(space) - case (Right(tree), Left(space)) => (space.getDependencies & tree.getSpaces).nonEmpty - } + def compareItems(before: Fragment, after: Fragment) : Boolean = + (after.getDependencies & before.getAllFragments).nonEmpty subItems = subItems.sortWith(compareItems(_, _)) } @@ -232,8 +280,8 @@ class LoopTree private[onetep](localIndex: Option[Index]) { for (entryID <- 0 until subItems.size) { val subList = (subItems(entryID) match { - case Left(space) => List(space.toString) - case Right(tree) => tree.toStrings + case (tree: LoopTree) => tree.toStrings + case x => List(x.toString) }) val subListHeadPrefix = if (entryID < subItems.size-1) "|--" else "`--" diff --git a/src/ofc/generators/onetep/Tree.scala b/src/ofc/generators/onetep/Tree.scala index 218fc66..5dff2b0 100644 --- a/src/ofc/generators/onetep/Tree.scala +++ b/src/ofc/generators/onetep/Tree.scala @@ -46,8 +46,19 @@ object IterationSpace { term.getOperands.toTraversable.flatMap(flattenPostorder(_)) ++ List(term) } +trait ConsumerGenerator { + def generate(names: NameManager, indices: Map[Index,String], values : Map[IterationSpace, String]) : String +} + +trait ProducerGenerator { + def generate(names: NameManager) : String +} + +trait TransformGenerator { + def generate(names: NameManager) : String +} + trait IterationSpace { - def getAccessExpression(indexNames: NameManager) : String def getOperands : List[IterationSpace] def getSpatialIndices : List[SpatialIndex] def getDiscreteIndices : List[DiscreteIndex] @@ -58,10 +69,17 @@ trait IterationSpace { val operands = getOperands operands.toSet ++ operands.flatMap(_.getDependencies) } + + // Code generation + def getConsumerGenerator : Option[ConsumerGenerator] + def getTransformGenerator : Option[TransformGenerator] + def getProducerGenerator : Option[ProducerGenerator] } trait DataSpace extends IterationSpace { def getOperands = Nil + def getConsumerGenerator = None + def getTransformGenerator = None } trait Matrix extends DataSpace @@ -69,19 +87,23 @@ trait FunctionSet extends DataSpace class Assignment(indexBindings: IndexBindings, lhs: DataSpace, rhs: IterationSpace) extends IterationSpace { def getIndexBindings = indexBindings - def getOperands = List(lhs,rhs) - def getSpatialIndices = Nil - def getDiscreteIndices = Nil + def getOperands = List(rhs) + def getSpatialIndices = lhs.getSpatialIndices + def getDiscreteIndices = lhs.getDiscreteIndices def getExternalIndices = Set() - def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed") + + def getConsumerGenerator = None + def getTransformGenerator = None + def getProducerGenerator = None } -class Scalar(value: Double) extends IterationSpace { - def getOperands = Nil +class Scalar(value: Double) extends DataSpace { def getSpatialIndices = Nil def getDiscreteIndices = Nil def getExternalIndices = Set() - def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed") + def getProducerGenerator = Some(new ProducerGenerator { + def generate(names: NameManager) = value.toString + }) } class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[Index]) extends IterationSpace { @@ -122,7 +144,10 @@ class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[In def getSpatialIndices = spatialIndices def getDiscreteIndices = discreteIndices def getExternalIndices = Set() - def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed") + + def getConsumerGenerator = None + def getTransformGenerator = None + def getProducerGenerator = None } class Reciprocal(op: IterationSpace) extends IterationSpace { @@ -142,7 +167,10 @@ class Reciprocal(op: IterationSpace) extends IterationSpace { def getSpatialIndices = spatialIndices.toList def getDiscreteIndices = op.getDiscreteIndices def getExternalIndices = Set() - def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed") + + def getConsumerGenerator = None + def getTransformGenerator = None + def getProducerGenerator = None } class Laplacian(op: IterationSpace) extends IterationSpace { @@ -150,7 +178,10 @@ class Laplacian(op: IterationSpace) extends IterationSpace { def getSpatialIndices = op.getSpatialIndices def getDiscreteIndices = op.getDiscreteIndices def getExternalIndices = Set() - def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed") + + def getConsumerGenerator = None + def getTransformGenerator = None + def getProducerGenerator = None } class SpatialRestriction(op: IterationSpace) extends IterationSpace { @@ -170,7 +201,10 @@ class SpatialRestriction(op: IterationSpace) extends IterationSpace { def getSpatialIndices = spatialIndices.toList def getDiscreteIndices = op.getDiscreteIndices def getExternalIndices = Set() - def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed") + + def getConsumerGenerator = None + def getTransformGenerator = None + def getProducerGenerator = None } class SPAM3(name : String) extends Matrix { @@ -214,7 +248,8 @@ class SPAM3(name : String) extends Matrix { def getSpatialIndices = Nil def getDiscreteIndices = List(rowIndex, colIndex) def getExternalIndices = Set() - def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed") + + def getProducerGenerator = None } class PPDFunctionSet(val basis : String, data : String) extends FunctionSet { @@ -243,7 +278,6 @@ class PPDFunctionSet(val basis : String, data : String) extends FunctionSet { def getDenseWidth(names: NameManager) = parent.basis+"%max_n_ppds_sphere" def generateIterationHeader(names: NameManager) = { - val initDense = "call basis_find_ppd_in_neighbour(" + denseIndexNames.mkString(",") + ", &\n" + parent.getSphere(names) + "%ppd_list(1," + names(this) + "), &\n" + parent.getSphere(names) + "%ppd_list(2," + names(this) + "), &\n" + @@ -291,15 +325,17 @@ class PPDFunctionSet(val basis : String, data : String) extends FunctionSet { def getDiscreteIndices = List(getSphereIndex) def getExternalIndices = Set(getPPDIndex) - def getAccessExpression(indexNames: NameManager) = { - val index = getSphere(indexNames)+"%offset + &\n" + - "("+indexNames(getPPDIndex)+"-1)*pub_cell%n_pts - 1 + &\n" + - "(" + indexNames(spatialIndices(2)) + "-1)*pub_cell%n_pt2*pub_cell%n_pt1 + &\n" + - "(" + indexNames(spatialIndices(1)) + "-1)*pub_cell%n_pt1 + &\n" + - indexNames(spatialIndices(0)) + def getProducerGenerator = Some(new ProducerGenerator { + def generate(names: NameManager) = { + val offset = getSphere(names)+"%offset + &\n" + + "("+names(getPPDIndex)+"-1)*pub_cell%n_pts - 1 + &\n" + + "(" + names(spatialIndices(2)) + "-1)*pub_cell%n_pt2*pub_cell%n_pt1 + &\n" + + "(" + names(spatialIndices(1)) + "-1)*pub_cell%n_pt1 + &\n" + + names(spatialIndices(0)) - data+"("+index+")" - } + data+"("+offset+")" + } + }) } class BindingIndex(name : String) {