]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Initial work on re-enabling tree builder and code generator.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Thu, 5 Apr 2012 17:13:09 +0000 (18:13 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Thu, 5 Apr 2012 17:13:09 +0000 (18:13 +0100)
14 files changed:
examples/test.ofl [new file with mode: 0644]
src/ofc/OFC.scala
src/ofc/codegen/Expression.scala
src/ofc/codegen/ProducerStatement.scala
src/ofc/codegen/ScopeStatement.scala
src/ofc/generators/Generator.scala
src/ofc/generators/Onetep.scala
src/ofc/generators/onetep/Assignment.scala
src/ofc/generators/onetep/CodeGenerator.scala
src/ofc/generators/onetep/Index.scala
src/ofc/generators/onetep/IterationSpace.scala
src/ofc/generators/onetep/LoopTree.scala [deleted file]
src/ofc/generators/onetep/SPAM3.scala
src/ofc/generators/onetep/TreeBuilder.scala

diff --git a/examples/test.ofl b/examples/test.ofl
new file mode 100644 (file)
index 0000000..d65c864
--- /dev/null
@@ -0,0 +1,12 @@
+# Parameter information
+Matrix kinet
+FunctionSet bra
+
+# Computation
+kinet = bra
+
+# Implementation specific
+target ONETEP
+kinet is SPAM3("kinet")
+bra is PPDFunctionSet("bra_basis", "bras_on_grid")
+output is FortranFunction("integrals_kinetic", ["kinet", "bras_on_grid", "bra_basis", "kets_on_grid", "ket_basis"])
index b082eef76cd0464c38f72dc32e76c1f0f2ec028c..b5b303d198923be50fcf951ee477a9fc960abdb1 100644 (file)
@@ -35,13 +35,13 @@ object OFC extends Parser {
     }
   }
 
-  def processAST(statements : List[Statement]) = {
+  def processAST(statements : Seq[Statement]) = {
     val targetStatements = 
-      statements.filter(_ match { case _ : Target => true; case _ => false }).asInstanceOf[List[Target]]
+      statements.filter(_ match { case _ : Target => true; case _ => false }).asInstanceOf[Seq[Target]]
 
     val generator : Generator = targetStatements match {
-      case List(Target(Identifier("ONETEP"))) => new generators.Onetep
-      case List(Target(Identifier(x))) => throw new InvalidInputException("Unknown target: " + x)
+      case Seq(Target(Identifier("ONETEP"))) => new generators.Onetep
+      case Seq(Target(Identifier(x))) => throw new InvalidInputException("Unknown target: " + x)
       case _ => throw new InvalidInputException("OFL file should have single target statement.")
     }
 
index fcd39016c37f3603b20583899a3100fb6eb4f3b5..71b863353ba36fb3a94555c81b87cade7c03a189 100644 (file)
@@ -73,7 +73,7 @@ class VarRef[T <: Type](symbol: VarSymbol[T]) extends Expression[T]
 
 // Struct and array accesses
 class FieldAccess[T <: Type](expression: Expression[StructType], field: FieldSymbol[T]) extends Expression[T]
-class ArrayRead[E <: Type](expression: Expression[ArrayType[E]], index: List[Expression[IntType]]) extends Expression[E]
+class ArrayRead[E <: Type](expression: Expression[ArrayType[E]], index: Seq[Expression[IntType]]) extends Expression[E]
 class PointerDereference[E <: Type](expression: Expression[PointerType[E]]) extends Expression[E]
 
 // Literals
index 99d858ee95db2bcd5db70b8aef4f2f5089791f9b..e3b3c72d136dec9be033b7d730e03a91754383a2 100644 (file)
@@ -5,8 +5,8 @@ class ProducerStatement extends Statement {
   class Predicate
 
   var statement = new NullStatement
-  var ranges : List[VariableRange] = List.empty
-  var predicates : List[Predicate] = List.empty
+  var ranges : Seq[VariableRange] = Nil
+  var predicates : Seq[Predicate] = Nil
   var expressions : Map[Symbol, Expression[_]] = Map.empty
 
   def addExpression[T <: Type](name: String, expression: Expression[T]) : VarSymbol[T] = {
index 4606eadb098929d7ded6f63dedc4f30066b53544..1614c186b4ab3a4913fade34ce24ed3f384e108c 100644 (file)
@@ -1,12 +1,13 @@
 package ofc.codegen
 import scala.collection.mutable.ArrayBuffer
 
-class ScopeStatement(initialStatements: List[Statement]) extends Statement {
-
+abstract class ScopeStatement(initialStatements: Seq[Statement] = Nil) extends Statement {
   val statements = initialStatements.toBuffer
 
-  def this() = this(List.empty)
   def +=(stat: Statement) {
     statements += stat
   }
 }
+
+class BlockStatement(initialStatements: Seq[Statement] = Nil) extends ScopeStatement(initialStatements) {
+}
index 5f61405fc8d779435402fdd9e53cd96c0d4ca41a..cf938eb979c82365ce12395dbf3fdb1b89ba438d 100644 (file)
@@ -2,5 +2,5 @@ package ofc.generators
 import ofc.parser.Statement
 
 trait Generator {
-  def acceptInput(program : List[Statement]) : Unit
+  def acceptInput(program : Seq[Statement]) : Unit
 }
index d5d125f28863a1eb4047ee71892c5fe5410f4b6e..3b9e6fde77fcf26c2cc5dc9ab12c0f7b5b6d2d21 100644 (file)
@@ -10,19 +10,19 @@ class Onetep extends Generator {
 
   var dictionary = new Dictionary
 
-  def acceptInput(program : List[parser.Statement]) = {
+  def acceptInput(program : Seq[parser.Statement]) = {
     println("Parsed input:\n"+program.mkString("\n") + "\n")
     buildDictionary(program)
     buildDefinitions(program)
   }
 
-  def filterStatements[T <: parser.Statement](statements : List[parser.Statement])(implicit m: Manifest[T]) =
+  def filterStatements[T <: parser.Statement](statements : Seq[parser.Statement])(implicit m: Manifest[T]) =
     statements.foldLeft(List[T]())((list, item) => item match {
       case s if (singleType(s) <:< m) =>  s.asInstanceOf[T] :: list
       case _ => list
     })
 
-  def getDeclarations(statements : List[parser.Statement]) : Map[parser.Identifier, parser.OFLType] = {
+  def getDeclarations(statements : Seq[parser.Statement]) : Map[parser.Identifier, parser.OFLType] = {
     def getMappings(dl : parser.DeclarationList) =
       for (name <- dl.names) yield
         (name, dl.oflType)
@@ -35,8 +35,8 @@ class Onetep extends Generator {
 
     call match {
       case Some(FunctionCall(matType, params)) => (matType, params) match {
-        //case (Identifier("SPAM3"), ParameterList(StringParameter(name))) => 
-        //  dictionary.matrices += (id -> new SPAM3(name))
+        case (Identifier("SPAM3"), ParameterList(StringParameter(name))) => 
+          dictionary.matrices += (id -> new SPAM3(name))
         case _ => throw new InvalidInputException("Unknown usage of type: "+matType.name)
       }
       case _ => throw new InvalidInputException("Undefined concrete type for matrix: "+id.name)
@@ -48,8 +48,8 @@ class Onetep extends Generator {
 
     call match {
       case Some(FunctionCall(fSetType, params)) => (fSetType, params) match {
-        //case (Identifier("PPDFunctionSet"), ParameterList(StringParameter(basis), StringParameter(data))) => 
-        //  dictionary.functionSets += id -> new PPDFunctionSet(basis, data)
+        case (Identifier("PPDFunctionSet"), ParameterList(StringParameter(basis), StringParameter(data))) => 
+          dictionary.functionSets += id -> new PPDFunctionSet(basis, data)
         case _ => throw new InvalidInputException("Unknown usage of type: "+fSetType.name)
       }
       case _ => throw new InvalidInputException("Undefined concrete type for function set: "+id.name)
@@ -63,7 +63,7 @@ class Onetep extends Generator {
     }
   }
 
-  def buildDictionary(statements : List[parser.Statement]) {
+  def buildDictionary(statements : Seq[parser.Statement]) {
     val targetDeclarations = filterStatements[parser.TargetAssignment](statements)
     val declarations = getDeclarations(statements)
 
@@ -86,11 +86,11 @@ class Onetep extends Generator {
   def buildDefinition(definition : parser.Definition) {
     val builder = new TreeBuilder(dictionary)
     val assignment = builder(definition.term, definition.expr)
-    //val codeGenerator = new CodeGenerator()
-    //codeGenerator(assignment)
+    val codeGenerator = new CodeGenerator()
+    codeGenerator(assignment)
   }
 
-  def buildDefinitions(statements : List[parser.Statement]) {
+  def buildDefinitions(statements : Seq[parser.Statement]) {
     val definitions = filterStatements[parser.Definition](statements)
     if (definitions.size != 1)
       throw new InvalidInputException("Input file should only contain a single definition.")
index 23c34a79ab0c8f82c9b67f494aa316849c818d10..584678610e303e37a14661f66093b82ff43bda53 100644 (file)
@@ -1,14 +1,10 @@
 package ofc.generators.onetep
-/*
-class Assignment(indexBindings: IndexBindings, lhs: DataSpace, rhs: IterationSpace) extends IterationSpace {
-  def getIndexBindings = indexBindings
-  def getOperands = List(rhs)
-  def getSpatialIndices = lhs.getSpatialIndices
-  def getDiscreteIndices = lhs.getDiscreteIndices
-  def getExternalIndices = Set()
+import ofc.codegen.NullStatement
 
-  def getConsumerGenerator = None
-  def getTransformGenerator = None
-  def getProducerGenerator = None
+class Assignment(indexBindings: IndexBindings, lhs: DataSpace, rhs: IterationSpace) extends IterationSpace {
+  def getOperands = List(lhs, rhs)
+  def getSpatialIndices = Nil
+  def getDiscreteIndices = Nil
+  def getReaderFragment = new NullStatement
+  def getSuffixFragment = new NullStatement
 }
-*/
index 9f7c673942626232f596333f9b0dd2fae2ed66b5..d164fd95f97e5409d553a14dcbeb4e4fed16ee5c 100644 (file)
@@ -1,5 +1,7 @@
 package ofc.generators.onetep
 import scala.collection.mutable.HashMap
+import ofc.codegen._
+
 /*
 class NameManager {
   var nextIndexID = 0
@@ -24,8 +26,17 @@ class NameManager {
     name
   }
 }
+*/
 
 class CodeGenerator {
+  def apply(assignment: Assignment) {
+    //val declarations = collectDeclarations(assignment)
+    //for(declaration <- declarations) code append declaration+"\n"
+    generateCode(assignment)
+  }
+
+
+  /*
   val code = new StringBuilder()
   val nameManager = new NameManager()
   
@@ -37,12 +48,7 @@ class CodeGenerator {
     for (op <- term.getOperands) declarationsSet ++= collectDeclarations(op)
     declarationsSet
   }
-
-  def apply(assignment: Assignment) {
-    val declarations = collectDeclarations(assignment)
-    for(declaration <- declarations) code append declaration+"\n"
-    generateCode(assignment)
-  }
+  */
 
   def generateCode(space: IterationSpace) {
     val allSpaces = IterationSpace.flattenPostorder(space)
@@ -52,59 +58,14 @@ class CodeGenerator {
     for(op <- IterationSpace.sort(allSpaces))
       println(op)
     println("\nIndices:")
-    for (i <- Index.sort(allIndices))
+    for (i <- allIndices)
       println(i)
     println("")
 
-    val loopNest = LoopNest(space)
-    println("Code:\n"+loopNest.generateCode+"\n")
-    println("Loop Nest:\n"+ loopNest.getTree + "\n")
-    // Next: we dump all these things into a prefix map
-    System.exit(0)
-
-    val operands = space.getOperands
-
-    for(operand <- operands)
-      generateCode(operand)
-
-    val lowerIndices = operands flatMap (x => x.getDiscreteIndices ++ x.getSpatialIndices) toSet
-    val upperIndices = space.getDiscreteIndices ++ space.getSpatialIndices toSet
-
-    val destroyedIndices = lowerIndices -- upperIndices
-    println("destroyed: "+destroyedIndices.mkString(","))
-
-    for (op <- operands) {
-      val opDestroyedIndices = (op.getSpatialIndices ++ op.getDiscreteIndices).toSet & destroyedIndices
-
-      if (!opDestroyedIndices.isEmpty) {
-        // We search for all indices bound to the one being destroyed
-        // We generate a composite iteration over those loops
-        // If GeneralInnerProduct rebuilds derived indices, we need to be able to construct a valid size
-        val concreteIndexList = opDestroyedIndices.toList
-        val storageName = nameManager.newIdentifier("dense")
-        code append "real(kind=DP), allocatable, dimension" + (":"*concreteIndexList.size).mkString("(",", &\n",")") + " :: " + 
-          storageName + "\n"
-        code append "allocate("+ storageName +
-          (concreteIndexList map ((x : Index) => x.getDenseWidth(nameManager))).mkString("(",",",")") + ", stat=ierr)\n"
-  
-        // We've declared temporary storage, now create the loops to populate it
-        for (index <- concreteIndexList) code append index.generateIterationHeader(nameManager) + "\n"
-        //val lhs = storageName + (concreteIndexList map ((x: Index) => x.getDensePosition(nameManager))).mkString("(",", &\n",")")
-        //val rhs = op.getAccessExpression(nameManager)
-        //code append lhs + " = &\n" + rhs + "\n"
-        for (index <- concreteIndexList) code append index.generateIterationFooter(nameManager) + "\n"
-  
-        println(code.mkString)
-        System.exit(0)
-      }
+    val statements = new BlockStatement
+    for (op <- IterationSpace.sort(allSpaces)) {
+      statements += op.getReaderFragment
+      statements += op.getSuffixFragment
     }
-
-    val createdIndices = upperIndices -- lowerIndices
-    println("created: "+createdIndices.mkString(","))
-
-  
-    // We've now moved al necessary destroyed indices into dense buffers
-    // We now generate the actual loop for space. This may involve a composite iteration construction
   }
 }
-*/
index caabe2b595eb23fbef229515e9ac5710c6fe1add..d340f256d85cacf08d1f4de03d32d8355b4d0da2 100644 (file)
@@ -1,21 +1,6 @@
 package ofc.generators.onetep
 import ofc.codegen.{Expression,IntType}
 
-/*
-object Index {
-  def sort(indices: Traversable[Index]) : List[Index] = {
-    def helper(input: Index, seen: collection.mutable.Set[Index]) : List[Index] =
-    if (seen add input)
-      input.getDependencies.toList.flatMap(helper(_, seen)) ++ List(input)
-    else
-      Nil
-
-    val seen = collection.mutable.Set[Index]()
-    indices.toList.flatMap(helper(_, seen))
-  }
-}
-*/
-
 trait Index {
   def getName : String
   def getMinimumValue : Expression[IntType]
index 70f86a7031ddf7783baa38e3e5e50b5fb50f6c01..3d7f16ed5c5a3ebe668a6f2ef154d5f3b145a19e 100644 (file)
@@ -1,12 +1,11 @@
 package ofc.generators.onetep
 import ofc.codegen.{Statement,NullStatement}
 
-/*
 object IterationSpace {
-  def sort(spaces : Traversable[IterationSpace]) : List[IterationSpace] = {
-    def helper(input: IterationSpace, seen: collection.mutable.Set[IterationSpace]) : List[IterationSpace] =
+  def sort(spaces : Traversable[IterationSpace]) : Seq[IterationSpace] = {
+    def helper(input: IterationSpace, seen: collection.mutable.Set[IterationSpace]) : Seq[IterationSpace] =
       if (seen add input)
-        input.getOperands.flatMap(helper(_, seen)) ++ List(input)
+        input.getOperands.flatMap(helper(_, seen)) :+ input
       else
         Nil
 
@@ -14,15 +13,14 @@ object IterationSpace {
     spaces.toList.flatMap(helper(_, seen))
   }
 
-  def flattenPostorder(term: IterationSpace) : Traversable[IterationSpace] = 
-    term.getOperands.toTraversable.flatMap(flattenPostorder(_)) ++ List(term)
+  def flattenPostorder(term: IterationSpace) : Seq[IterationSpace] = 
+    term.getOperands.toSeq.flatMap(flattenPostorder(_)).+:(term)
 }
-*/
 
 trait IterationSpace {
-  def getOperands : List[IterationSpace]
-  def getSpatialIndices : List[SpatialIndex]
-  def getDiscreteIndices : List[DiscreteIndex]
+  def getOperands : Seq[IterationSpace]
+  def getSpatialIndices : Seq[SpatialIndex]
+  def getDiscreteIndices : Seq[DiscreteIndex]
   def getIndices : Set[Index] = (getSpatialIndices ++ getDiscreteIndices).toSet
   def getDependencies : Set[IterationSpace] = {
     val operands = getOperands
diff --git a/src/ofc/generators/onetep/LoopTree.scala b/src/ofc/generators/onetep/LoopTree.scala
deleted file mode 100644 (file)
index a2ae25f..0000000
+++ /dev/null
@@ -1,379 +0,0 @@
-package ofc.generators.onetep
-
-import scala.collection.mutable.ArrayBuffer
-import ofc.LogicError
-
-/*
-Stores the configuration of indices we will use for code generation.
-*/
-
-/*
-object LoopNest {
-  def apply(root: IterationSpace) : LoopNest = {
-    val sortedSpaces = IterationSpace.flattenPostorder(root)
-    val sortedIndices = Index.sort(sortedSpaces flatMap (_.getIndices))
-    val nest = new LoopNest(sortedIndices.toList)
-
-    for(space <- sortedSpaces) {
-      val indices = space.getIndices
-      val localSortedIndices = sortedIndices filter (indices.contains(_))
-      nest.addIterationSpace(localSortedIndices, space)
-    }
-
-    nest
-  }
-}
-
-class LoopNest(sortedIndices: List[Index]) {
-  val base = new LoopTree(None)
-  val nameManager = new NameManager()
-  var declarations = Set[String]()
-  var spaceFragmentsInfo = Map[IterationSpace, FragmentsInfo]()
-  var consumerProducers = collection.mutable.Map[ConsumerFragment, Set[ProducerFragment]]()
-
-  class FragmentsInfo(val consumer: Option[ConsumerFragment], 
-                      val transform: Option[Fragment], 
-                      val producer: Option[ProducerFragment])
-
-  def addIterationSpace(indices: List[Index], space: IterationSpace) {
-    val operands = space.getOperands
-    val fragmentDepends = for(operand <- operands; fragment <- spaceFragmentsInfo(operand).producer) yield fragment
-
-    // Construct Fragments
-    val consumer = space.getConsumerGenerator match {
-      case Some(consumerGenerator) => Some(new ConsumerFragment(space, fragmentDepends.toSet))
-      case None => None
-    }
-
-    val transform = space.getTransformGenerator match {
-      case Some(transformGenerator) => Some(new TransformFragment(space, consumer.toSet))
-      case None => None
-    }
-
-    val producer = space.getProducerGenerator match {
-      case Some(producerGenerator) => Some(new ProducerFragment(space, transform.toSet ++ consumer.toSet))
-      case None => None
-    }
-
-    // Record consumer-producer relationship
-    consumer match {
-      case Some(c) => consumerProducers.getOrElseUpdate(c, fragmentDepends.toSet)
-      case _ =>
-    }
-
-    // Insert into tree
-    val fragmentsInfo = new FragmentsInfo(consumer, transform, producer)
-    spaceFragmentsInfo += (space -> fragmentsInfo)
-
-    val consumerIndices = getSortedIndices((for (op <- operands; index <- op.getInternalIndices) yield index).toSet)
-    consumer match {
-      case Some(fragment) => base.addFragment(consumerIndices, fragment)
-      case None =>
-    }
-
-    val transformIndices = getSortedIndices(consumerIndices.toSet & space.getIndices.toSet)
-    transform match {
-      case Some(fragment) => base.addFragment(transformIndices, fragment)
-      case None =>
-    }
-
-    val producerIndices = getSortedIndices(space.getIndices.toSet)
-    producer match {
-      case Some(fragment) => base.addFragment(producerIndices, fragment)
-      case None =>
-    }
-
-    base.fuse()
-  }
-
-  private def getSortedIndices(indices: Set[Index]) = sortedIndices filter (indices.contains(_))
-
-  def getTree = base
-
-  def generateCode : String = {
-    computeBuffers()
-
-    val code = new StringBuilder
-    declarations = base.collectDeclarations(nameManager)
-    code append declarations.mkString("\n") + "\n\n"
-
-    val generationVisitor = new GenerationVisitor
-    base.accept(generationVisitor)
-    code append generationVisitor.getCode
-
-    code.mkString
-  }
-
-  private def computeBuffers() {
-    val producerConsumers = collection.mutable.Map[ProducerFragment, collection.mutable.Set[ConsumerFragment]]()
-
-    for ((consumer, producers) <- consumerProducers; producer <- producers)
-      producerConsumers.getOrElseUpdate(producer, collection.mutable.Set[ConsumerFragment]()) += consumer
-
-    val descriptors = for ((producer, consumers) <- producerConsumers) yield 
-      new BufferDescriptor(producer, consumers.toSet)
-
-    descriptors.map(base.addBufferDescriptor(_))
-  }
-
-  class GenerationVisitor extends LoopTreeVisitor {
-    val code = new StringBuilder
-
-    def enterTree(tree: LoopTree) {
-      tree.getLocalIndex match {
-        case None =>
-        case Some(index) => code append index.generateIterationHeader(nameManager)+"\n"
-      }
-    }
-
-    def exitTree(tree: LoopTree) {
-      tree.getLocalIndex match {
-        case None =>
-        case Some(index) => code append index.generateIterationFooter(nameManager)+"\n"
-      }
-    }
-
-    def visitFragment(fragment: Fragment) {
-      fragment match {
-        case (c: ConsumerFragment) => code append "!"+c.toString+"\n"
-        case (p: ProducerFragment) => code append "!"+p.generate(nameManager)+"\n"
-        case (t: TransformFragment) => code append "!"+t.toString+"\n"
-      }
-    }
-
-    def getCode = code.mkString
-  }
-}
-
-trait Fragment {
-  def getAllFragments : Set[Fragment]
-  def getDependencies : Set[Fragment]
-  def collectDeclarations(nameManager: NameManager) : Set[String]
-  def accept(visitor: LoopTreeVisitor) : Unit
-}
-
-class ConsumerFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment {
-  def getSpace = parent
-  def getAllFragments = Set(this)
-  def getDependencies = dependencies
-  def collectDeclarations(nameManager: NameManager) = parent.getIndices.flatMap(_.getDeclarations(nameManager))
-  def accept(visitor: LoopTreeVisitor) = visitor.visitFragment(this)
-  override def toString = "Consumer: " + parent.toString
-}
-
-class ProducerFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment {
-  def getSpace = parent
-  def getAllFragments = Set(this)
-  def getDependencies = dependencies
-  def collectDeclarations(nameManager: NameManager) = parent.getIndices.flatMap(_.getDeclarations(nameManager))
-  def accept(visitor: LoopTreeVisitor) = visitor.visitFragment(this)
-  def generate(nameManager: NameManager) = parent.getProducerGenerator.get.generate(nameManager)
-  override def toString = "Producer: " + parent.toString
-}
-
-class TransformFragment(parent: IterationSpace, dependencies: Set[Fragment]) extends Fragment {
-  def getSpace = parent
-  def getAllFragments = Set(this)
-  def getDependencies = dependencies
-  def collectDeclarations(nameManager: NameManager) = parent.getIndices.flatMap(_.getDeclarations(nameManager))
-  def accept(visitor: LoopTreeVisitor) = visitor.visitFragment(this)
-  override def toString = "Transform: " + parent.toString
-}
-
-trait LoopTreeVisitor {
-  def enterTree(tree: LoopTree)
-  def exitTree(tree: LoopTree)
-  def visitFragment(space: Fragment)
-}
-
-class BufferDescriptor(producer: ProducerFragment, consumers: Set[ConsumerFragment]) {
-  def getProducer : ProducerFragment = producer
-  def getConsumers : Set[ConsumerFragment] = consumers
-  def getSpace : IterationSpace = producer.getSpace
-  def getIndices : Set[Index] = getSpace.getInternalIndices
-  override def toString = "Buffer: "+getSpace.toString
-}
-
-object LoopTree {
-  def collectSpaceDeclarations(term: IterationSpace, nameManager: NameManager) : Set[String] = {
-    val declarations = for(index <- term.getIndices;
-                           declaration <- index.getDeclarations(nameManager)) yield declaration 
-
-    var declarationsSet = declarations.toSet
-    for (op <- term.getOperands) declarationsSet ++= collectSpaceDeclarations(op, nameManager)
-    declarationsSet
-  }
-
-  def attemptFusion(a: LoopTree, b: LoopTree, commonScope: LoopTree) : Option[LoopTree] = {
-    if (a == b)
-      None
-    else if (a.getLocalIndex != b.getLocalIndex)
-      None
-    else
-      Some(a + b)
-  }
-}
-
-class LoopTree private[onetep](localIndex: Option[Index]) extends Fragment {
-  var subItems = ArrayBuffer[Fragment]()
-  var bufferDescriptors = ArrayBuffer[BufferDescriptor]()
-
-  def accept(visitor: LoopTreeVisitor) {
-    visitor.enterTree(this)
-
-    for (item <- subItems)
-      item.accept(visitor)
-
-    visitor.exitTree(this)
-  }
-
-  def +(b: LoopTree) : LoopTree =
-    if (getLocalIndex == b.getLocalIndex) {
-      val newTree = new LoopTree(getLocalIndex)
-      newTree.subItems = subItems ++ b.subItems
-      newTree.bufferDescriptors = bufferDescriptors ++ b.bufferDescriptors
-      newTree
-    } else {
-      throw new LogicError("Addition undefined for loops with different indices")
-    }
-
-  def addBufferDescriptor(descriptor: BufferDescriptor) {
-    val allFragments = getAllFragments
-
-    // FIXME: the toSet call should be unneccessary
-    if (!allFragments.contains(descriptor.getProducer) ||
-        !descriptor.getConsumers.toSet[Fragment].subsetOf(allFragments)) {
-      throw new LogicError("Attempt to place buffer in LoopTree missing consumer or producer.")
-    }
-
-    for(item <- subItems) item match {
-      case tree: LoopTree => {
-        val placeInside = tree.getLocalIndex match {
-          case Some(index) =>
-            descriptor.getIndices.contains(index) &&
-            (descriptor.getConsumers.toSet[Fragment] + descriptor.getProducer).subsetOf(item.getAllFragments)
-          case None => false
-        }
-
-        if (placeInside) {
-          tree.addBufferDescriptor(descriptor)
-          return
-        }
-      }
-      case _ =>
-    }
-
-    bufferDescriptors += descriptor
-  }
-
-  def getBufferDescriptors : List[BufferDescriptor] = bufferDescriptors.toList
-
-  def collectDeclarations(nameManager: NameManager) : Set[String] = {
-    val result = collection.mutable.Set[String]()
-
-    for(item <- subItems)
-      result ++= item.collectDeclarations(nameManager)
-
-    result.toSet
-  }
-
-  def getDependencies : Set[Fragment] = {
-    val dependencies = collection.mutable.Set[Fragment]()
-
-    for (item <- subItems)
-      dependencies ++= item.getDependencies
-    
-    dependencies.toSet
-  }
-
-  def getAllFragments : Set[Fragment] = {
-    val fragments = collection.mutable.Set[Fragment]()
-    fragments += this
-
-    for (item <- subItems)
-      fragments ++= item.getAllFragments
-
-    fragments.toSet
-  }
-
-  def addFragment(indices: List[Index], fragment: Fragment) {
-    indices match {
-      case Nil => subItems += fragment
-      case (head :: tail) => {
-        val tree = new LoopTree(Some(head))
-        subItems += tree
-        tree.addFragment(tail, fragment)
-      }
-    }
-  }
-
-  def fuse() {
-    val trees = collection.mutable.Set[LoopTree]()
-    for (item <- subItems)
-      item match {
-        case (tree: LoopTree) => trees += tree
-        case _ =>
-      }
-
-    subItems --= trees
-
-    def attemptFusion(loops: collection.mutable.Set[LoopTree]) : Boolean = {
-      for(a <- loops) 
-        for(b <- loops)
-          LoopTree.attemptFusion(a, b, this) match {
-            case Some(fused) => {loops -= a; loops -= b; loops += fused; return true}
-            case None =>
-          }
-      false
-    }
-
-    while(attemptFusion(trees)) {}
-    trees map (_.fuse())
-    subItems ++= trees
-
-    sort()
-  }
-
-  def sort() {
-    def compareItems(before: Fragment, after: Fragment) : Boolean =
-      (after.getDependencies & before.getAllFragments).nonEmpty
-
-    subItems = subItems.sortWith(compareItems(_, _))
-  }
-
-  def getLocalIndex = localIndex
-
-  private def getLoopIndices(space: IterationSpace) : List[Index] = space.getIndices.toList
-
-  override def toString : String = toStrings.mkString("\n")
-
-  private def toStrings : List[String] = {
-    val stringTree = ArrayBuffer[List[String]]()
-    val result = ArrayBuffer[String]()
-    result += "Index: " + (localIndex match {
-      case Some(x) => x.toString
-      case None => "None"
-    })
-
-    for(bufferDescriptor <- bufferDescriptors)
-      stringTree += List(bufferDescriptor.toString)
-    
-    for (subItem <- subItems) {
-      val subList = (subItem match {
-        case (tree: LoopTree) => tree.toStrings
-        case x => List(x.toString)
-      })
-
-      stringTree += subList
-    }
-
-    for((subTree, subTreeIndex) <- stringTree.zipWithIndex) {
-      val subTreeHeadPrefix = if (subTreeIndex < stringTree.size-1) "|--" else "`--"
-      val subTreeTailPrefix = if (subTreeIndex < stringTree.size-1) "|  " else "   "
-      result ++= (subTreeHeadPrefix+subTree.head) :: (subTree.tail.map(subTreeTailPrefix+_))
-    }
-
-    result.toList
-  }
-}
-*/
index b9134448dbae99b87fa77cb62d701c4f29d2a2c7..81988227d084b15ab55f2a5f17b27181ebdc9852 100644 (file)
@@ -1,47 +1,11 @@
 package ofc.generators.onetep
-/*
+import ofc.codegen.NullStatement
+
 class SPAM3(name : String) extends Matrix {
   override def toString = name
   def getName = name
 
-  class RowIndex(parent: SPAM3) extends DiscreteIndex {
-    override def toString = parent + ".row"
-    def getName = "row_index"
-    def getDependencies = Set()
-    def getDenseWidth(names: NameManager) = "sparse_num_rows("+parent.getName+")"
-
-    def generateIterationHeader(names: NameManager) = {
-      val indexName = names(this)
-      "do "+indexName+"=1,"+getDenseWidth(names)
-    }
-
-    def generateIterationFooter(names: NameManager) = "end do"
-    def getDeclarations(names: NameManager) = List("integer :: "+names(this))
-  }
-
-  class ColIndex(parent: SPAM3) extends DiscreteIndex {
-    override def toString = parent + ".col"
-    def getName = "row_index"
-    def getDependencies = Set()
-    def getDenseWidth(names: NameManager) = "sparse_num_cols("+parent.getName+")"
-
-    def generateIterationHeader(names: NameManager) = {
-      val indexName = names(this)
-      "do "+indexName+"=1,"+getDenseWidth(names)
-    }
-
-
-    def generateIterationFooter(names: NameManager) = "end do"
-    def getDeclarations(names: NameManager) = List("integer :: "+names(this))
-  }
-
-  val rowIndex = new RowIndex(this)
-  val colIndex = new ColIndex(this)
-
   def getSpatialIndices = Nil
-  def getDiscreteIndices = List(rowIndex, colIndex)
-  def getExternalIndices = Set()
-
-  def getProducerGenerator = None
+  def getDiscreteIndices = Nil
+  def getSuffixFragment = null
 }
-*/
index 9b878bd786c5dd48bb4752a8a1fd244796a817fe..f5d49833dd7fd8a3413e745d3bc5463507ca6831 100644 (file)
@@ -58,7 +58,7 @@ class TreeBuilder(dictionary : Dictionary) {
     val rhsTree = buildExpression(rhs)
 
     lhsTree match {
-      //case (lhsTree: DataSpace) => new Assignment(indexBindings, lhsTree, rhsTree)
+      case (lhsTree: DataSpace) => new Assignment(indexBindings, lhsTree, rhsTree)
       case _ => throw new InvalidInputException("Non-assignable expression on LHS of assignment.")
     }
   }