]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Work on indexing.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Tue, 24 Jan 2012 19:29:50 +0000 (19:29 +0000)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Tue, 24 Jan 2012 19:29:50 +0000 (19:29 +0000)
src/ofc/generators/Onetep.scala
src/ofc/generators/onetep/CodeGenerator.scala [new file with mode: 0644]
src/ofc/generators/onetep/Tree.scala

index 6f5e19b38365f16d6312e19b44c7517b3d5adeb5..a127e7df97320b317992fd938859ff9a555f325b 100644 (file)
@@ -87,7 +87,8 @@ class Onetep extends Generator {
     println(definition)
     val builder = new TreeBuilder(dictionary)
     val assignment = builder(definition.term, definition.expr)
-    println(assignment)
+    val codeGenerator = new CodeGenerator()
+    codeGenerator(assignment)
   }
 
   def buildDefinitions(statements : List[parser.Statement]) {
diff --git a/src/ofc/generators/onetep/CodeGenerator.scala b/src/ofc/generators/onetep/CodeGenerator.scala
new file mode 100644 (file)
index 0000000..934090d
--- /dev/null
@@ -0,0 +1,48 @@
+package ofc.generators.onetep
+import scala.collection.mutable.HashMap
+
+class IndexNames {
+  var nextIndexID = 0
+  val names = new HashMap[Index, String]()
+
+  def addIndex(index: Index) {
+    val name = index.getName + "_" + nextIndexID
+    nextIndexID += 1
+    names += (index -> name)
+  }
+
+  def apply(index: Index) = names(index)
+}
+
+class CodeGenerator {
+  val indexNames = new IndexNames()
+
+  def apply(assignment: Assignment) {
+    generateCode(assignment)
+  }
+
+  def generateCode(space: IterationSpace) {
+    val operands = space.getOperands
+
+    for(operand <- operands)
+      generateCode(operand)
+
+    val lowerIndices = operands flatMap (x => x.getDiscreteIndices ++ x.getSpatialIndices) toSet
+    val upperIndices = space.getDiscreteIndices ++ space.getSpatialIndices toSet
+
+    val createdIndices = upperIndices -- lowerIndices
+    val destroyedIndices = lowerIndices -- upperIndices
+
+    println("created: "+createdIndices.mkString(","))
+    println("destroyed: "+destroyedIndices.mkString(","))
+
+    if (!destroyedIndices.isEmpty) {
+      // We search for all indices bound to the one being destroyed
+      // We generate a composite iteration over those loops
+      // If GeneralInnerProduct rebuilds derived indices, we need to be able to construct a valid size
+      System.exit(0)
+    }
+
+    // When an index is destroyed -> generate a possibly composite loop over the index
+  }
+}
index c07e001c92c9d927d35e57db35a8e5cff915d441..02c51457ca061797c477643972228df9eab944a5 100644 (file)
 package ofc.generators.onetep
 
-import scala.collection.mutable.{HashMap,HashSet,Set}
 import ofc.parser
 import ofc.parser.Identifier
 import ofc.{InvalidInputException,UnimplementedException}
 
-trait Index
+trait Index {
+  def getName : String
+  def getDependencies : Set[Index]
+}
 trait SpatialIndex extends Index
 trait DiscreteIndex extends Index
 
 trait IterationSpace {
+  def getAccessExpression(indexNames: IndexNames) : String
+  def getOperands() : List[IterationSpace]
   def getSpatialIndices() : List[SpatialIndex]
   def getDiscreteIndices() : List[DiscreteIndex]
 }
 
 trait DataSpace extends IterationSpace {
+  def getOperands() = Nil
 }
 
 trait Matrix extends DataSpace
 trait FunctionSet extends DataSpace
 
+class Assignment(indexBindings: IndexBindings, lhs: DataSpace, rhs: IterationSpace) extends IterationSpace {
+  override def toString = indexBindings.toString
+  def getIndexBindings = indexBindings
+  def getOperands = List(lhs,rhs)
+  def getSpatialIndices = Nil
+  def getDiscreteIndices = Nil
+  def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed")
+}
+
 class Scalar(value: Double) extends IterationSpace {
+  def getOperands() = Nil
   def getSpatialIndices() = Nil
   def getDiscreteIndices() = Nil
+  def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed")
 }
 
-class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: List[Index]) extends IterationSpace {
-  def getSpatialIndices() = operands flatMap (op => op.getSpatialIndices filterNot (index => removedIndices.contains(index)))
-  def getDiscreteIndices() = operands flatMap (op => op.getDiscreteIndices filterNot (index => removedIndices.contains(index)))
+class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[Index]) extends IterationSpace {
+
+  class DenseSpatialIndex(parent: GeneralInnerProduct) extends SpatialIndex{
+    def getDependencies = Set()
+    def getName = "dense_spatial_index"
+  }
+
+  class DenseDiscreteIndex(parent: GeneralInnerProduct) extends DiscreteIndex {
+    def getDependencies = Set()
+    def getName = "dense_discrete_index"
+  }
+
+  val spatialIndices =
+    for(op <- operands; index <- op.getSpatialIndices; if (!removedIndices.contains(index))) yield
+      if (index.getDependencies.intersect(removedIndices).isEmpty)
+        index
+      else
+        new DenseSpatialIndex(this)
+
+  val discreteIndices =
+    for(op <- operands; index <- op.getDiscreteIndices; if (!removedIndices.contains(index))) yield
+      if (index.getDependencies.intersect(removedIndices).isEmpty)
+        index
+      else
+        new DenseDiscreteIndex(this)
+
+  def getOperands = operands
+  def getSpatialIndices() = spatialIndices
+  def getDiscreteIndices() = discreteIndices
+  def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed")
 }
 
 class Reciprocal(op: IterationSpace) extends IterationSpace {
-  class BlockIndex(parent: Reciprocal, dimension: Int)  extends SpatialIndex
+  class BlockIndex(parent: Reciprocal, dimension: Int)  extends SpatialIndex {
+    def getName = "reciprocal_index_" + dimension
+    def getDependencies = Set()
+  }
   val spatialIndices = for (dimension <- 0 to op.getSpatialIndices.size) yield new BlockIndex(this, dimension)
 
+  def getOperands = List(op)
   def getSpatialIndices() = spatialIndices.toList
   def getDiscreteIndices() = op.getDiscreteIndices
+  def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed")
 }
 
 class Laplacian(op: IterationSpace) extends IterationSpace {
+  def getOperands() = List(op)
   def getSpatialIndices() = op.getSpatialIndices
   def getDiscreteIndices() = op.getDiscreteIndices
+  def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed")
 }
 
 class SpatialRestriction(op: IterationSpace) extends IterationSpace {
-  class RestrictedIndex(parent: SpatialRestriction, dimension: Int)  extends SpatialIndex
+  class RestrictedIndex(parent: SpatialRestriction, dimension: Int)  extends SpatialIndex {
+    def getName = "restriction_index_" + dimension
+    def getDependencies = Set()
+  }
+
   val spatialIndices = for (dimension <- 0 to op.getSpatialIndices.size) yield new RestrictedIndex(this, dimension)
 
+  def getOperands() = List(op)
   def getSpatialIndices() = spatialIndices.toList
   def getDiscreteIndices() = op.getDiscreteIndices
+  def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed")
 }
 
 class SPAM3(name : String) extends Matrix {
@@ -56,9 +112,13 @@ class SPAM3(name : String) extends Matrix {
 
   class RowIndex(parent: SPAM3) extends DiscreteIndex {
     override def toString = parent + ".row"
+    def getName = "row_index"
+    def getDependencies = Set()
   }
   class ColIndex(parent: SPAM3) extends DiscreteIndex {
     override def toString = parent + ".col"
+    def getName = "row_index"
+    def getDependencies = Set()
   }
 
   val rowIndex = new RowIndex(this)
@@ -66,12 +126,24 @@ class SPAM3(name : String) extends Matrix {
 
   def getSpatialIndices() = Nil
   def getDiscreteIndices() = List(rowIndex, colIndex)
+  def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed")
 }
 
 class PPDFunctionSet(basis : String, data : String) extends FunctionSet {
-  class SphereIndex(parent: PPDFunctionSet) extends DiscreteIndex
-  class PPDIndex(parent: PPDFunctionSet) extends DiscreteIndex
-  class IntraPPDIndex(parent: PPDFunctionSet, dimension: Int) extends SpatialIndex
+  class SphereIndex(parent: PPDFunctionSet) extends DiscreteIndex {
+    def getName = "sphere_index"
+    def getDependencies = Set()
+  }
+
+  class PPDIndex(parent: PPDFunctionSet) extends DiscreteIndex {
+    def getName = "ppd_index"
+    def getDependencies = Set[Index](parent.getSphereIndex())
+  }
+
+  class IntraPPDIndex(parent: PPDFunctionSet, dimension: Int) extends SpatialIndex {
+    def getName = "intra_ppd_index_" + dimension
+    def getDependencies = Set[Index](parent.getPPDIndex)
+  }
 
   val ppdIndex = new PPDIndex(this)
   val sphereIndex = new SphereIndex(this)
@@ -82,6 +154,7 @@ class PPDFunctionSet(basis : String, data : String) extends FunctionSet {
 
   def getSpatialIndices() = spatialIndices.toList
   def getDiscreteIndices() = List(getSphereIndex(), getPPDIndex())
+  def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed")
 }
 
 class BindingIndex(name : String) {
@@ -89,6 +162,8 @@ class BindingIndex(name : String) {
 }
 
 class Dictionary {
+  import scala.collection.mutable.HashMap
+
   var matrices = new HashMap[Identifier, Matrix]
   var functionSets = new HashMap[Identifier, FunctionSet]
   var indices = new HashMap[Identifier, BindingIndex]
@@ -109,13 +184,11 @@ class Dictionary {
     }
 }
 
-class Assignment(indexBindings: IndexBindings, lhs: DataSpace, rhs: IterationSpace) {
-  override def toString = indexBindings.toString
-}
-
 class IndexBindings {
+  import scala.collection.mutable.{Set,HashSet, HashMap}
+
   val spatial = new HashMap[BindingIndex, Set[SpatialIndex]]
-  val discrete = new HashMap[BindingIndex, Set[DiscreteIndex]]
+  val discrete = new HashMap[BindingIndex,Set[DiscreteIndex]]
 
   def add(binding: BindingIndex, index: SpatialIndex) = spatial.getOrElseUpdate(binding, new HashSet()) += index
   def add(binding: BindingIndex, index: DiscreteIndex) = discrete.getOrElseUpdate(binding, new HashSet()) += index
@@ -139,13 +212,13 @@ class TreeBuilder(dictionary : Dictionary) {
 
     lhsTree match {
       case (lhsTree: DataSpace) => new Assignment(indexBindings, lhsTree, rhsTree)
-      case _ => new InvalidInputException("Non-assignable expression on LHS of assignment.")
+      case _ => throw new InvalidInputException("Non-assignable expression on LHS of assignment.")
     }
   }
 
   def buildIndexedTerm(term: parser.IndexedTerm) : IterationSpace = {
     val dataSpace = dictionary.getData(term.id) match {
-      case (functionSet : PPDFunctionSet) => new GeneralInnerProduct(List(functionSet), List(functionSet.getPPDIndex))
+      case (functionSet : PPDFunctionSet) => new GeneralInnerProduct(List(functionSet), Set(functionSet.getPPDIndex))
       case v => v
     }
 
@@ -167,7 +240,7 @@ class TreeBuilder(dictionary : Dictionary) {
       case (t: IndexedTerm) => buildIndexedTerm(t)
       case ScalarConstant(s) => new Scalar(s)
       case Multiplication(a, b) => 
-        new GeneralInnerProduct(List(buildExpression(a), buildExpression(b)), Nil)
+        new GeneralInnerProduct(List(buildExpression(a), buildExpression(b)), Set())
       case Division(a, b) => 
         throw new UnimplementedException("Semantics of division not yet defined, or implemented.")
       case Operator(Identifier("inner"), List(a,b)) => {
@@ -180,7 +253,7 @@ class TreeBuilder(dictionary : Dictionary) {
           indexBindings.add(bindingIndex, right)
         }
 
-        new GeneralInnerProduct(List(aExpression, bExpression), aExpression.getSpatialIndices ++ bExpression.getSpatialIndices)
+        new GeneralInnerProduct(List(aExpression, bExpression), (aExpression.getSpatialIndices ++ bExpression.getSpatialIndices).toSet)
       }
       case Operator(Identifier("reciprocal"), List(op)) => new Reciprocal(buildExpression(op))
       case Operator(Identifier("laplacian"), List(op)) => new Laplacian(buildExpression(op))