From 5007ed0ac8e67aabab08bf4e233cbff04bd436b4 Mon Sep 17 00:00:00 2001 From: Francis Russell Date: Thu, 5 Apr 2012 18:13:09 +0100 Subject: [PATCH] Initial work on re-enabling tree builder and code generator. --- examples/test.ofl | 12 + src/ofc/OFC.scala | 8 +- src/ofc/codegen/Expression.scala | 2 +- src/ofc/codegen/ProducerStatement.scala | 4 +- src/ofc/codegen/ScopeStatement.scala | 7 +- src/ofc/generators/Generator.scala | 2 +- src/ofc/generators/Onetep.scala | 22 +- src/ofc/generators/onetep/Assignment.scala | 18 +- src/ofc/generators/onetep/CodeGenerator.scala | 73 +--- src/ofc/generators/onetep/Index.scala | 15 - .../generators/onetep/IterationSpace.scala | 18 +- src/ofc/generators/onetep/LoopTree.scala | 379 ------------------ src/ofc/generators/onetep/SPAM3.scala | 44 +- src/ofc/generators/onetep/TreeBuilder.scala | 2 +- 14 files changed, 72 insertions(+), 534 deletions(-) create mode 100644 examples/test.ofl delete mode 100644 src/ofc/generators/onetep/LoopTree.scala diff --git a/examples/test.ofl b/examples/test.ofl new file mode 100644 index 0000000..d65c864 --- /dev/null +++ b/examples/test.ofl @@ -0,0 +1,12 @@ +# Parameter information +Matrix kinet +FunctionSet bra + +# Computation +kinet = bra + +# Implementation specific +target ONETEP +kinet is SPAM3("kinet") +bra is PPDFunctionSet("bra_basis", "bras_on_grid") +output is FortranFunction("integrals_kinetic", ["kinet", "bras_on_grid", "bra_basis", "kets_on_grid", "ket_basis"]) diff --git a/src/ofc/OFC.scala b/src/ofc/OFC.scala index b082eef..b5b303d 100644 --- a/src/ofc/OFC.scala +++ b/src/ofc/OFC.scala @@ -35,13 +35,13 @@ object OFC extends Parser { } } - def processAST(statements : List[Statement]) = { + def processAST(statements : Seq[Statement]) = { val targetStatements = - statements.filter(_ match { case _ : Target => true; case _ => false }).asInstanceOf[List[Target]] + statements.filter(_ match { case _ : Target => true; case _ => false }).asInstanceOf[Seq[Target]] val generator : Generator = targetStatements match { - case List(Target(Identifier("ONETEP"))) => new generators.Onetep - case List(Target(Identifier(x))) => throw new InvalidInputException("Unknown target: " + x) + case Seq(Target(Identifier("ONETEP"))) => new generators.Onetep + case Seq(Target(Identifier(x))) => throw new InvalidInputException("Unknown target: " + x) case _ => throw new InvalidInputException("OFL file should have single target statement.") } diff --git a/src/ofc/codegen/Expression.scala b/src/ofc/codegen/Expression.scala index fcd3901..71b8633 100644 --- a/src/ofc/codegen/Expression.scala +++ b/src/ofc/codegen/Expression.scala @@ -73,7 +73,7 @@ class VarRef[T <: Type](symbol: VarSymbol[T]) extends Expression[T] // Struct and array accesses class FieldAccess[T <: Type](expression: Expression[StructType], field: FieldSymbol[T]) extends Expression[T] -class ArrayRead[E <: Type](expression: Expression[ArrayType[E]], index: List[Expression[IntType]]) extends Expression[E] +class ArrayRead[E <: Type](expression: Expression[ArrayType[E]], index: Seq[Expression[IntType]]) extends Expression[E] class PointerDereference[E <: Type](expression: Expression[PointerType[E]]) extends Expression[E] // Literals diff --git a/src/ofc/codegen/ProducerStatement.scala b/src/ofc/codegen/ProducerStatement.scala index 99d858e..e3b3c72 100644 --- a/src/ofc/codegen/ProducerStatement.scala +++ b/src/ofc/codegen/ProducerStatement.scala @@ -5,8 +5,8 @@ class ProducerStatement extends Statement { class Predicate var statement = new NullStatement - var ranges : List[VariableRange] = List.empty - var predicates : List[Predicate] = List.empty + var ranges : Seq[VariableRange] = Nil + var predicates : Seq[Predicate] = Nil var expressions : Map[Symbol, Expression[_]] = Map.empty def addExpression[T <: Type](name: String, expression: Expression[T]) : VarSymbol[T] = { diff --git a/src/ofc/codegen/ScopeStatement.scala b/src/ofc/codegen/ScopeStatement.scala index 4606ead..1614c18 100644 --- a/src/ofc/codegen/ScopeStatement.scala +++ b/src/ofc/codegen/ScopeStatement.scala @@ -1,12 +1,13 @@ package ofc.codegen import scala.collection.mutable.ArrayBuffer -class ScopeStatement(initialStatements: List[Statement]) extends Statement { - +abstract class ScopeStatement(initialStatements: Seq[Statement] = Nil) extends Statement { val statements = initialStatements.toBuffer - def this() = this(List.empty) def +=(stat: Statement) { statements += stat } } + +class BlockStatement(initialStatements: Seq[Statement] = Nil) extends ScopeStatement(initialStatements) { +} diff --git a/src/ofc/generators/Generator.scala b/src/ofc/generators/Generator.scala index 5f61405..cf938eb 100644 --- a/src/ofc/generators/Generator.scala +++ b/src/ofc/generators/Generator.scala @@ -2,5 +2,5 @@ package ofc.generators import ofc.parser.Statement trait Generator { - def acceptInput(program : List[Statement]) : Unit + def acceptInput(program : Seq[Statement]) : Unit } diff --git a/src/ofc/generators/Onetep.scala b/src/ofc/generators/Onetep.scala index d5d125f..3b9e6fd 100644 --- a/src/ofc/generators/Onetep.scala +++ b/src/ofc/generators/Onetep.scala @@ -10,19 +10,19 @@ class Onetep extends Generator { var dictionary = new Dictionary - def acceptInput(program : List[parser.Statement]) = { + def acceptInput(program : Seq[parser.Statement]) = { println("Parsed input:\n"+program.mkString("\n") + "\n") buildDictionary(program) buildDefinitions(program) } - def filterStatements[T <: parser.Statement](statements : List[parser.Statement])(implicit m: Manifest[T]) = + def filterStatements[T <: parser.Statement](statements : Seq[parser.Statement])(implicit m: Manifest[T]) = statements.foldLeft(List[T]())((list, item) => item match { case s if (singleType(s) <:< m) => s.asInstanceOf[T] :: list case _ => list }) - def getDeclarations(statements : List[parser.Statement]) : Map[parser.Identifier, parser.OFLType] = { + def getDeclarations(statements : Seq[parser.Statement]) : Map[parser.Identifier, parser.OFLType] = { def getMappings(dl : parser.DeclarationList) = for (name <- dl.names) yield (name, dl.oflType) @@ -35,8 +35,8 @@ class Onetep extends Generator { call match { case Some(FunctionCall(matType, params)) => (matType, params) match { - //case (Identifier("SPAM3"), ParameterList(StringParameter(name))) => - // dictionary.matrices += (id -> new SPAM3(name)) + case (Identifier("SPAM3"), ParameterList(StringParameter(name))) => + dictionary.matrices += (id -> new SPAM3(name)) case _ => throw new InvalidInputException("Unknown usage of type: "+matType.name) } case _ => throw new InvalidInputException("Undefined concrete type for matrix: "+id.name) @@ -48,8 +48,8 @@ class Onetep extends Generator { call match { case Some(FunctionCall(fSetType, params)) => (fSetType, params) match { - //case (Identifier("PPDFunctionSet"), ParameterList(StringParameter(basis), StringParameter(data))) => - // dictionary.functionSets += id -> new PPDFunctionSet(basis, data) + case (Identifier("PPDFunctionSet"), ParameterList(StringParameter(basis), StringParameter(data))) => + dictionary.functionSets += id -> new PPDFunctionSet(basis, data) case _ => throw new InvalidInputException("Unknown usage of type: "+fSetType.name) } case _ => throw new InvalidInputException("Undefined concrete type for function set: "+id.name) @@ -63,7 +63,7 @@ class Onetep extends Generator { } } - def buildDictionary(statements : List[parser.Statement]) { + def buildDictionary(statements : Seq[parser.Statement]) { val targetDeclarations = filterStatements[parser.TargetAssignment](statements) val declarations = getDeclarations(statements) @@ -86,11 +86,11 @@ class Onetep extends Generator { def buildDefinition(definition : parser.Definition) { val builder = new TreeBuilder(dictionary) val assignment = builder(definition.term, definition.expr) - //val codeGenerator = new CodeGenerator() - //codeGenerator(assignment) + val codeGenerator = new CodeGenerator() + codeGenerator(assignment) } - def buildDefinitions(statements : List[parser.Statement]) { + def buildDefinitions(statements : Seq[parser.Statement]) { val definitions = filterStatements[parser.Definition](statements) if (definitions.size != 1) throw new InvalidInputException("Input file should only contain a single definition.") diff --git a/src/ofc/generators/onetep/Assignment.scala b/src/ofc/generators/onetep/Assignment.scala index 23c34a7..5846786 100644 --- a/src/ofc/generators/onetep/Assignment.scala +++ b/src/ofc/generators/onetep/Assignment.scala @@ -1,14 +1,10 @@ package ofc.generators.onetep -/* -class Assignment(indexBindings: IndexBindings, lhs: DataSpace, rhs: IterationSpace) extends IterationSpace { - def getIndexBindings = indexBindings - def getOperands = List(rhs) - def getSpatialIndices = lhs.getSpatialIndices - def getDiscreteIndices = lhs.getDiscreteIndices - def getExternalIndices = Set() +import ofc.codegen.NullStatement - def getConsumerGenerator = None - def getTransformGenerator = None - def getProducerGenerator = None +class Assignment(indexBindings: IndexBindings, lhs: DataSpace, rhs: IterationSpace) extends IterationSpace { + def getOperands = List(lhs, rhs) + def getSpatialIndices = Nil + def getDiscreteIndices = Nil + def getReaderFragment = new NullStatement + def getSuffixFragment = new NullStatement } -*/ diff --git a/src/ofc/generators/onetep/CodeGenerator.scala b/src/ofc/generators/onetep/CodeGenerator.scala index 9f7c673..d164fd9 100644 --- a/src/ofc/generators/onetep/CodeGenerator.scala +++ b/src/ofc/generators/onetep/CodeGenerator.scala @@ -1,5 +1,7 @@ package ofc.generators.onetep import scala.collection.mutable.HashMap +import ofc.codegen._ + /* class NameManager { var nextIndexID = 0 @@ -24,8 +26,17 @@ class NameManager { name } } +*/ class CodeGenerator { + def apply(assignment: Assignment) { + //val declarations = collectDeclarations(assignment) + //for(declaration <- declarations) code append declaration+"\n" + generateCode(assignment) + } + + + /* val code = new StringBuilder() val nameManager = new NameManager() @@ -37,12 +48,7 @@ class CodeGenerator { for (op <- term.getOperands) declarationsSet ++= collectDeclarations(op) declarationsSet } - - def apply(assignment: Assignment) { - val declarations = collectDeclarations(assignment) - for(declaration <- declarations) code append declaration+"\n" - generateCode(assignment) - } + */ def generateCode(space: IterationSpace) { val allSpaces = IterationSpace.flattenPostorder(space) @@ -52,59 +58,14 @@ class CodeGenerator { for(op <- IterationSpace.sort(allSpaces)) println(op) println("\nIndices:") - for (i <- Index.sort(allIndices)) + for (i <- allIndices) println(i) println("") - val loopNest = LoopNest(space) - println("Code:\n"+loopNest.generateCode+"\n") - println("Loop Nest:\n"+ loopNest.getTree + "\n") - // Next: we dump all these things into a prefix map - System.exit(0) - - val operands = space.getOperands - - for(operand <- operands) - generateCode(operand) - - val lowerIndices = operands flatMap (x => x.getDiscreteIndices ++ x.getSpatialIndices) toSet - val upperIndices = space.getDiscreteIndices ++ space.getSpatialIndices toSet - - val destroyedIndices = lowerIndices -- upperIndices - println("destroyed: "+destroyedIndices.mkString(",")) - - for (op <- operands) { - val opDestroyedIndices = (op.getSpatialIndices ++ op.getDiscreteIndices).toSet & destroyedIndices - - if (!opDestroyedIndices.isEmpty) { - // We search for all indices bound to the one being destroyed - // We generate a composite iteration over those loops - // If GeneralInnerProduct rebuilds derived indices, we need to be able to construct a valid size - val concreteIndexList = opDestroyedIndices.toList - val storageName = nameManager.newIdentifier("dense") - code append "real(kind=DP), allocatable, dimension" + (":"*concreteIndexList.size).mkString("(",", &\n",")") + " :: " + - storageName + "\n" - code append "allocate("+ storageName + - (concreteIndexList map ((x : Index) => x.getDenseWidth(nameManager))).mkString("(",",",")") + ", stat=ierr)\n" - - // We've declared temporary storage, now create the loops to populate it - for (index <- concreteIndexList) code append index.generateIterationHeader(nameManager) + "\n" - //val lhs = storageName + (concreteIndexList map ((x: Index) => x.getDensePosition(nameManager))).mkString("(",", &\n",")") - //val rhs = op.getAccessExpression(nameManager) - //code append lhs + " = &\n" + rhs + "\n" - for (index <- concreteIndexList) code append index.generateIterationFooter(nameManager) + "\n" - - println(code.mkString) - System.exit(0) - } + val statements = new BlockStatement + for (op <- IterationSpace.sort(allSpaces)) { + statements += op.getReaderFragment + statements += op.getSuffixFragment } - - val createdIndices = upperIndices -- lowerIndices - println("created: "+createdIndices.mkString(",")) - - - // We've now moved al necessary destroyed indices into dense buffers - // We now generate the actual loop for space. This may involve a composite iteration construction } } -*/ diff --git a/src/ofc/generators/onetep/Index.scala b/src/ofc/generators/onetep/Index.scala index caabe2b..d340f25 100644 --- a/src/ofc/generators/onetep/Index.scala +++ b/src/ofc/generators/onetep/Index.scala @@ -1,21 +1,6 @@ package ofc.generators.onetep import ofc.codegen.{Expression,IntType} -/* -object Index { - def sort(indices: Traversable[Index]) : List[Index] = { - def helper(input: Index, seen: collection.mutable.Set[Index]) : List[Index] = - if (seen add input) - input.getDependencies.toList.flatMap(helper(_, seen)) ++ List(input) - else - Nil - - val seen = collection.mutable.Set[Index]() - indices.toList.flatMap(helper(_, seen)) - } -} -*/ - trait Index { def getName : String def getMinimumValue : Expression[IntType] diff --git a/src/ofc/generators/onetep/IterationSpace.scala b/src/ofc/generators/onetep/IterationSpace.scala index 70f86a7..3d7f16e 100644 --- a/src/ofc/generators/onetep/IterationSpace.scala +++ b/src/ofc/generators/onetep/IterationSpace.scala @@ -1,12 +1,11 @@ package ofc.generators.onetep import ofc.codegen.{Statement,NullStatement} -/* object IterationSpace { - def sort(spaces : Traversable[IterationSpace]) : List[IterationSpace] = { - def helper(input: IterationSpace, seen: collection.mutable.Set[IterationSpace]) : List[IterationSpace] = + def sort(spaces : Traversable[IterationSpace]) : Seq[IterationSpace] = { + def helper(input: IterationSpace, seen: collection.mutable.Set[IterationSpace]) : Seq[IterationSpace] = if (seen add input) - input.getOperands.flatMap(helper(_, seen)) ++ List(input) + input.getOperands.flatMap(helper(_, seen)) :+ input else Nil @@ -14,15 +13,14 @@ object IterationSpace { spaces.toList.flatMap(helper(_, seen)) } - def flattenPostorder(term: IterationSpace) : Traversable[IterationSpace] = - term.getOperands.toTraversable.flatMap(flattenPostorder(_)) ++ List(term) + def flattenPostorder(term: IterationSpace) : Seq[IterationSpace] = + term.getOperands.toSeq.flatMap(flattenPostorder(_)).+:(term) } -*/ trait IterationSpace { - def getOperands : List[IterationSpace] - def getSpatialIndices : List[SpatialIndex] - def getDiscreteIndices : List[DiscreteIndex] + def getOperands : Seq[IterationSpace] + def getSpatialIndices : Seq[SpatialIndex] + def getDiscreteIndices : Seq[DiscreteIndex] def getIndices : Set[Index] = (getSpatialIndices ++ getDiscreteIndices).toSet def getDependencies : Set[IterationSpace] = { val operands = getOperands diff --git a/src/ofc/generators/onetep/LoopTree.scala b/src/ofc/generators/onetep/LoopTree.scala deleted file mode 100644 index a2ae25f..0000000 --- a/src/ofc/generators/onetep/LoopTree.scala +++ /dev/null @@ -1,379 +0,0 @@ -package ofc.generators.onetep - -import scala.collection.mutable.ArrayBuffer -import ofc.LogicError - -/* -Stores the configuration of indices we will use for code generation. -*/ - -/* -object LoopNest { - def apply(root: IterationSpace) : LoopNest = { - val sortedSpaces = IterationSpace.flattenPostorder(root) - val sortedIndices = Index.sort(sortedSpaces flatMap (_.getIndices)) - val nest = new LoopNest(sortedIndices.toList) - - for(space <- sortedSpaces) { - val indices = space.getIndices - val localSortedIndices = sortedIndices filter (indices.contains(_)) - nest.addIterationSpace(localSortedIndices, space) - } - - nest - } -} - -class LoopNest(sortedIndices: List[Index]) { - val base = new LoopTree(None) - val nameManager = new NameManager() - var declarations = Set[String]() - var spaceFragmentsInfo = Map[IterationSpace, FragmentsInfo]() - var consumerProducers = collection.mutable.Map[ConsumerFragment, Set[ProducerFragment]]() - - class FragmentsInfo(val consumer: Option[ConsumerFragment], - val transform: Option[Fragment], - val producer: Option[ProducerFragment]) - - def addIterationSpace(indices: List[Index], space: IterationSpace) { - val operands = space.getOperands - val fragmentDepends = for(operand <- operands; fragment <- spaceFragmentsInfo(operand).producer) yield fragment - - // Construct Fragments - val consumer = space.getConsumerGenerator match { - case Some(consumerGenerator) => Some(new ConsumerFragment(space, fragmentDepends.toSet)) - case None => None - } - - val transform = space.getTransformGenerator match { - case Some(transformGenerator) => Some(new TransformFragment(space, consumer.toSet)) - case None => None - } - - val producer = space.getProducerGenerator match { - case Some(producerGenerator) => Some(new ProducerFragment(space, transform.toSet ++ consumer.toSet)) - case None => None - } - - // Record consumer-producer relationship - consumer match { - case Some(c) => consumerProducers.getOrElseUpdate(c, fragmentDepends.toSet) - case _ => - } - - // Insert into tree - val fragmentsInfo = new FragmentsInfo(consumer, transform, producer) - spaceFragmentsInfo += (space -> fragmentsInfo) - - val consumerIndices = getSortedIndices((for (op <- operands; index <- op.getInternalIndices) yield index).toSet) - consumer match { - case Some(fragment) => base.addFragment(consumerIndices, fragment) - case None => - } - - val transformIndices = getSortedIndices(consumerIndices.toSet & space.getIndices.toSet) - transform match { - case Some(fragment) => base.addFragment(transformIndices, fragment) - case None => - } - - val producerIndices = getSortedIndices(space.getIndices.toSet) - producer match { - case Some(fragment) => base.addFragment(producerIndices, fragment) - case None => - } - - base.fuse() - } - - private def getSortedIndices(indices: Set[Index]) = sortedIndices filter (indices.contains(_)) - - def getTree = base - - def generateCode : String = { - computeBuffers() - - val code = new StringBuilder - declarations = base.collectDeclarations(nameManager) - code append declarations.mkString("\n") + "\n\n" - - val generationVisitor = new GenerationVisitor - base.accept(generationVisitor) - code append generationVisitor.getCode - - code.mkString - } - - private def computeBuffers() { - val producerConsumers = collection.mutable.Map[ProducerFragment, collection.mutable.Set[ConsumerFragment]]() - - for ((consumer, producers) <- consumerProducers; producer <- producers) - producerConsumers.getOrElseUpdate(producer, collection.mutable.Set[ConsumerFragment]()) += consumer - - val descriptors = for ((producer, consumers) <- producerConsumers) yield - new BufferDescriptor(producer, consumers.toSet) - - descriptors.map(base.addBufferDescriptor(_)) - } - - class GenerationVisitor extends LoopTreeVisitor { - val code = new StringBuilder - - def enterTree(tree: LoopTree) { - tree.getLocalIndex match { - case None => - case Some(index) => code append index.generateIterationHeader(nameManager)+"\n" - } - } - - def exitTree(tree: LoopTree) { - tree.getLocalIndex match { - case None => - case Some(index) => code append index.generateIterationFooter(nameManager)+"\n" - } - } - - def visitFragment(fragment: Fragment) { - fragment match { - case (c: ConsumerFragment) => code append "!"+c.toString+"\n" - case (p: ProducerFragment) => code append "!"+p.generate(nameManager)+"\n" - case (t: TransformFragment) => code append "!"+t.toString+"\n" - } - } - - def getCode = code.mkString - } -} - -trait Fragment { - def getAllFragments : Set[Fragment] - def getDependencies : Set[Fragment] - def collectDeclarations(nameManager: NameManager) : Set[String] - def accept(visitor: LoopTreeVisitor) : Unit -} - -class ConsumerFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment { - def getSpace = parent - def getAllFragments = Set(this) - def getDependencies = dependencies - def collectDeclarations(nameManager: NameManager) = parent.getIndices.flatMap(_.getDeclarations(nameManager)) - def accept(visitor: LoopTreeVisitor) = visitor.visitFragment(this) - override def toString = "Consumer: " + parent.toString -} - -class ProducerFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment { - def getSpace = parent - def getAllFragments = Set(this) - def getDependencies = dependencies - def collectDeclarations(nameManager: NameManager) = parent.getIndices.flatMap(_.getDeclarations(nameManager)) - def accept(visitor: LoopTreeVisitor) = visitor.visitFragment(this) - def generate(nameManager: NameManager) = parent.getProducerGenerator.get.generate(nameManager) - override def toString = "Producer: " + parent.toString -} - -class TransformFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment { - def getSpace = parent - def getAllFragments = Set(this) - def getDependencies = dependencies - def collectDeclarations(nameManager: NameManager) = parent.getIndices.flatMap(_.getDeclarations(nameManager)) - def accept(visitor: LoopTreeVisitor) = visitor.visitFragment(this) - override def toString = "Transform: " + parent.toString -} - -trait LoopTreeVisitor { - def enterTree(tree: LoopTree) - def exitTree(tree: LoopTree) - def visitFragment(space: Fragment) -} - -class BufferDescriptor(producer: ProducerFragment, consumers: Set[ConsumerFragment]) { - def getProducer : ProducerFragment = producer - def getConsumers : Set[ConsumerFragment] = consumers - def getSpace : IterationSpace = producer.getSpace - def getIndices : Set[Index] = getSpace.getInternalIndices - override def toString = "Buffer: "+getSpace.toString -} - -object LoopTree { - def collectSpaceDeclarations(term: IterationSpace, nameManager: NameManager) : Set[String] = { - val declarations = for(index <- term.getIndices; - declaration <- index.getDeclarations(nameManager)) yield declaration - - var declarationsSet = declarations.toSet - for (op <- term.getOperands) declarationsSet ++= collectSpaceDeclarations(op, nameManager) - declarationsSet - } - - def attemptFusion(a: LoopTree, b: LoopTree, commonScope: LoopTree) : Option[LoopTree] = { - if (a == b) - None - else if (a.getLocalIndex != b.getLocalIndex) - None - else - Some(a + b) - } -} - -class LoopTree private[onetep](localIndex: Option[Index]) extends Fragment { - var subItems = ArrayBuffer[Fragment]() - var bufferDescriptors = ArrayBuffer[BufferDescriptor]() - - def accept(visitor: LoopTreeVisitor) { - visitor.enterTree(this) - - for (item <- subItems) - item.accept(visitor) - - visitor.exitTree(this) - } - - def +(b: LoopTree) : LoopTree = - if (getLocalIndex == b.getLocalIndex) { - val newTree = new LoopTree(getLocalIndex) - newTree.subItems = subItems ++ b.subItems - newTree.bufferDescriptors = bufferDescriptors ++ b.bufferDescriptors - newTree - } else { - throw new LogicError("Addition undefined for loops with different indices") - } - - def addBufferDescriptor(descriptor: BufferDescriptor) { - val allFragments = getAllFragments - - // FIXME: the toSet call should be unneccessary - if (!allFragments.contains(descriptor.getProducer) || - !descriptor.getConsumers.toSet[Fragment].subsetOf(allFragments)) { - throw new LogicError("Attempt to place buffer in LoopTree missing consumer or producer.") - } - - for(item <- subItems) item match { - case tree: LoopTree => { - val placeInside = tree.getLocalIndex match { - case Some(index) => - descriptor.getIndices.contains(index) && - (descriptor.getConsumers.toSet[Fragment] + descriptor.getProducer).subsetOf(item.getAllFragments) - case None => false - } - - if (placeInside) { - tree.addBufferDescriptor(descriptor) - return - } - } - case _ => - } - - bufferDescriptors += descriptor - } - - def getBufferDescriptors : List[BufferDescriptor] = bufferDescriptors.toList - - def collectDeclarations(nameManager: NameManager) : Set[String] = { - val result = collection.mutable.Set[String]() - - for(item <- subItems) - result ++= item.collectDeclarations(nameManager) - - result.toSet - } - - def getDependencies : Set[Fragment] = { - val dependencies = collection.mutable.Set[Fragment]() - - for (item <- subItems) - dependencies ++= item.getDependencies - - dependencies.toSet - } - - def getAllFragments : Set[Fragment] = { - val fragments = collection.mutable.Set[Fragment]() - fragments += this - - for (item <- subItems) - fragments ++= item.getAllFragments - - fragments.toSet - } - - def addFragment(indices: List[Index], fragment: Fragment) { - indices match { - case Nil => subItems += fragment - case (head :: tail) => { - val tree = new LoopTree(Some(head)) - subItems += tree - tree.addFragment(tail, fragment) - } - } - } - - def fuse() { - val trees = collection.mutable.Set[LoopTree]() - for (item <- subItems) - item match { - case (tree: LoopTree) => trees += tree - case _ => - } - - subItems --= trees - - def attemptFusion(loops: collection.mutable.Set[LoopTree]) : Boolean = { - for(a <- loops) - for(b <- loops) - LoopTree.attemptFusion(a, b, this) match { - case Some(fused) => {loops -= a; loops -= b; loops += fused; return true} - case None => - } - false - } - - while(attemptFusion(trees)) {} - trees map (_.fuse()) - subItems ++= trees - - sort() - } - - def sort() { - def compareItems(before: Fragment, after: Fragment) : Boolean = - (after.getDependencies & before.getAllFragments).nonEmpty - - subItems = subItems.sortWith(compareItems(_, _)) - } - - def getLocalIndex = localIndex - - private def getLoopIndices(space: IterationSpace) : List[Index] = space.getIndices.toList - - override def toString : String = toStrings.mkString("\n") - - private def toStrings : List[String] = { - val stringTree = ArrayBuffer[List[String]]() - val result = ArrayBuffer[String]() - result += "Index: " + (localIndex match { - case Some(x) => x.toString - case None => "None" - }) - - for(bufferDescriptor <- bufferDescriptors) - stringTree += List(bufferDescriptor.toString) - - for (subItem <- subItems) { - val subList = (subItem match { - case (tree: LoopTree) => tree.toStrings - case x => List(x.toString) - }) - - stringTree += subList - } - - for((subTree, subTreeIndex) <- stringTree.zipWithIndex) { - val subTreeHeadPrefix = if (subTreeIndex < stringTree.size-1) "|--" else "`--" - val subTreeTailPrefix = if (subTreeIndex < stringTree.size-1) "| " else " " - result ++= (subTreeHeadPrefix+subTree.head) :: (subTree.tail.map(subTreeTailPrefix+_)) - } - - result.toList - } -} -*/ diff --git a/src/ofc/generators/onetep/SPAM3.scala b/src/ofc/generators/onetep/SPAM3.scala index b913444..8198822 100644 --- a/src/ofc/generators/onetep/SPAM3.scala +++ b/src/ofc/generators/onetep/SPAM3.scala @@ -1,47 +1,11 @@ package ofc.generators.onetep -/* +import ofc.codegen.NullStatement + class SPAM3(name : String) extends Matrix { override def toString = name def getName = name - class RowIndex(parent: SPAM3) extends DiscreteIndex { - override def toString = parent + ".row" - def getName = "row_index" - def getDependencies = Set() - def getDenseWidth(names: NameManager) = "sparse_num_rows("+parent.getName+")" - - def generateIterationHeader(names: NameManager) = { - val indexName = names(this) - "do "+indexName+"=1,"+getDenseWidth(names) - } - - def generateIterationFooter(names: NameManager) = "end do" - def getDeclarations(names: NameManager) = List("integer :: "+names(this)) - } - - class ColIndex(parent: SPAM3) extends DiscreteIndex { - override def toString = parent + ".col" - def getName = "row_index" - def getDependencies = Set() - def getDenseWidth(names: NameManager) = "sparse_num_cols("+parent.getName+")" - - def generateIterationHeader(names: NameManager) = { - val indexName = names(this) - "do "+indexName+"=1,"+getDenseWidth(names) - } - - - def generateIterationFooter(names: NameManager) = "end do" - def getDeclarations(names: NameManager) = List("integer :: "+names(this)) - } - - val rowIndex = new RowIndex(this) - val colIndex = new ColIndex(this) - def getSpatialIndices = Nil - def getDiscreteIndices = List(rowIndex, colIndex) - def getExternalIndices = Set() - - def getProducerGenerator = None + def getDiscreteIndices = Nil + def getSuffixFragment = null } -*/ diff --git a/src/ofc/generators/onetep/TreeBuilder.scala b/src/ofc/generators/onetep/TreeBuilder.scala index 9b878bd..f5d4983 100644 --- a/src/ofc/generators/onetep/TreeBuilder.scala +++ b/src/ofc/generators/onetep/TreeBuilder.scala @@ -58,7 +58,7 @@ class TreeBuilder(dictionary : Dictionary) { val rhsTree = buildExpression(rhs) lhsTree match { - //case (lhsTree: DataSpace) => new Assignment(indexBindings, lhsTree, rhsTree) + case (lhsTree: DataSpace) => new Assignment(indexBindings, lhsTree, rhsTree) case _ => throw new InvalidInputException("Non-assignable expression on LHS of assignment.") } } -- 2.47.3