From 204c40760c467af873cf87c0f1f4e6436fa688a2 Mon Sep 17 00:00:00 2001 From: Francis Russell Date: Tue, 24 Jan 2012 19:29:50 +0000 Subject: [PATCH] Work on indexing. --- src/ofc/generators/Onetep.scala | 3 +- src/ofc/generators/onetep/CodeGenerator.scala | 48 ++++++++ src/ofc/generators/onetep/Tree.scala | 111 +++++++++++++++--- 3 files changed, 142 insertions(+), 20 deletions(-) create mode 100644 src/ofc/generators/onetep/CodeGenerator.scala diff --git a/src/ofc/generators/Onetep.scala b/src/ofc/generators/Onetep.scala index 6f5e19b..a127e7d 100644 --- a/src/ofc/generators/Onetep.scala +++ b/src/ofc/generators/Onetep.scala @@ -87,7 +87,8 @@ class Onetep extends Generator { println(definition) val builder = new TreeBuilder(dictionary) val assignment = builder(definition.term, definition.expr) - println(assignment) + val codeGenerator = new CodeGenerator() + codeGenerator(assignment) } def buildDefinitions(statements : List[parser.Statement]) { diff --git a/src/ofc/generators/onetep/CodeGenerator.scala b/src/ofc/generators/onetep/CodeGenerator.scala new file mode 100644 index 0000000..934090d --- /dev/null +++ b/src/ofc/generators/onetep/CodeGenerator.scala @@ -0,0 +1,48 @@ +package ofc.generators.onetep +import scala.collection.mutable.HashMap + +class IndexNames { + var nextIndexID = 0 + val names = new HashMap[Index, String]() + + def addIndex(index: Index) { + val name = index.getName + "_" + nextIndexID + nextIndexID += 1 + names += (index -> name) + } + + def apply(index: Index) = names(index) +} + +class CodeGenerator { + val indexNames = new IndexNames() + + def apply(assignment: Assignment) { + generateCode(assignment) + } + + def generateCode(space: IterationSpace) { + val operands = space.getOperands + + for(operand <- operands) + generateCode(operand) + + val lowerIndices = operands flatMap (x => x.getDiscreteIndices ++ x.getSpatialIndices) toSet + val upperIndices = space.getDiscreteIndices ++ space.getSpatialIndices toSet + + val createdIndices = upperIndices -- lowerIndices + val destroyedIndices = lowerIndices -- upperIndices + + println("created: "+createdIndices.mkString(",")) + println("destroyed: "+destroyedIndices.mkString(",")) + + if (!destroyedIndices.isEmpty) { + // We search for all indices bound to the one being destroyed + // We generate a composite iteration over those loops + // If GeneralInnerProduct rebuilds derived indices, we need to be able to construct a valid size + System.exit(0) + } + + // When an index is destroyed -> generate a possibly composite loop over the index + } +} diff --git a/src/ofc/generators/onetep/Tree.scala b/src/ofc/generators/onetep/Tree.scala index c07e001..02c5145 100644 --- a/src/ofc/generators/onetep/Tree.scala +++ b/src/ofc/generators/onetep/Tree.scala @@ -1,54 +1,110 @@ package ofc.generators.onetep -import scala.collection.mutable.{HashMap,HashSet,Set} import ofc.parser import ofc.parser.Identifier import ofc.{InvalidInputException,UnimplementedException} -trait Index +trait Index { + def getName : String + def getDependencies : Set[Index] +} trait SpatialIndex extends Index trait DiscreteIndex extends Index trait IterationSpace { + def getAccessExpression(indexNames: IndexNames) : String + def getOperands() : List[IterationSpace] def getSpatialIndices() : List[SpatialIndex] def getDiscreteIndices() : List[DiscreteIndex] } trait DataSpace extends IterationSpace { + def getOperands() = Nil } trait Matrix extends DataSpace trait FunctionSet extends DataSpace +class Assignment(indexBindings: IndexBindings, lhs: DataSpace, rhs: IterationSpace) extends IterationSpace { + override def toString = indexBindings.toString + def getIndexBindings = indexBindings + def getOperands = List(lhs,rhs) + def getSpatialIndices = Nil + def getDiscreteIndices = Nil + def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed") +} + class Scalar(value: Double) extends IterationSpace { + def getOperands() = Nil def getSpatialIndices() = Nil def getDiscreteIndices() = Nil + def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed") } -class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: List[Index]) extends IterationSpace { - def getSpatialIndices() = operands flatMap (op => op.getSpatialIndices filterNot (index => removedIndices.contains(index))) - def getDiscreteIndices() = operands flatMap (op => op.getDiscreteIndices filterNot (index => removedIndices.contains(index))) +class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[Index]) extends IterationSpace { + + class DenseSpatialIndex(parent: GeneralInnerProduct) extends SpatialIndex{ + def getDependencies = Set() + def getName = "dense_spatial_index" + } + + class DenseDiscreteIndex(parent: GeneralInnerProduct) extends DiscreteIndex { + def getDependencies = Set() + def getName = "dense_discrete_index" + } + + val spatialIndices = + for(op <- operands; index <- op.getSpatialIndices; if (!removedIndices.contains(index))) yield + if (index.getDependencies.intersect(removedIndices).isEmpty) + index + else + new DenseSpatialIndex(this) + + val discreteIndices = + for(op <- operands; index <- op.getDiscreteIndices; if (!removedIndices.contains(index))) yield + if (index.getDependencies.intersect(removedIndices).isEmpty) + index + else + new DenseDiscreteIndex(this) + + def getOperands = operands + def getSpatialIndices() = spatialIndices + def getDiscreteIndices() = discreteIndices + def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed") } class Reciprocal(op: IterationSpace) extends IterationSpace { - class BlockIndex(parent: Reciprocal, dimension: Int) extends SpatialIndex + class BlockIndex(parent: Reciprocal, dimension: Int) extends SpatialIndex { + def getName = "reciprocal_index_" + dimension + def getDependencies = Set() + } val spatialIndices = for (dimension <- 0 to op.getSpatialIndices.size) yield new BlockIndex(this, dimension) + def getOperands = List(op) def getSpatialIndices() = spatialIndices.toList def getDiscreteIndices() = op.getDiscreteIndices + def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed") } class Laplacian(op: IterationSpace) extends IterationSpace { + def getOperands() = List(op) def getSpatialIndices() = op.getSpatialIndices def getDiscreteIndices() = op.getDiscreteIndices + def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed") } class SpatialRestriction(op: IterationSpace) extends IterationSpace { - class RestrictedIndex(parent: SpatialRestriction, dimension: Int) extends SpatialIndex + class RestrictedIndex(parent: SpatialRestriction, dimension: Int) extends SpatialIndex { + def getName = "restriction_index_" + dimension + def getDependencies = Set() + } + val spatialIndices = for (dimension <- 0 to op.getSpatialIndices.size) yield new RestrictedIndex(this, dimension) + def getOperands() = List(op) def getSpatialIndices() = spatialIndices.toList def getDiscreteIndices() = op.getDiscreteIndices + def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed") } class SPAM3(name : String) extends Matrix { @@ -56,9 +112,13 @@ class SPAM3(name : String) extends Matrix { class RowIndex(parent: SPAM3) extends DiscreteIndex { override def toString = parent + ".row" + def getName = "row_index" + def getDependencies = Set() } class ColIndex(parent: SPAM3) extends DiscreteIndex { override def toString = parent + ".col" + def getName = "row_index" + def getDependencies = Set() } val rowIndex = new RowIndex(this) @@ -66,12 +126,24 @@ class SPAM3(name : String) extends Matrix { def getSpatialIndices() = Nil def getDiscreteIndices() = List(rowIndex, colIndex) + def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed") } class PPDFunctionSet(basis : String, data : String) extends FunctionSet { - class SphereIndex(parent: PPDFunctionSet) extends DiscreteIndex - class PPDIndex(parent: PPDFunctionSet) extends DiscreteIndex - class IntraPPDIndex(parent: PPDFunctionSet, dimension: Int) extends SpatialIndex + class SphereIndex(parent: PPDFunctionSet) extends DiscreteIndex { + def getName = "sphere_index" + def getDependencies = Set() + } + + class PPDIndex(parent: PPDFunctionSet) extends DiscreteIndex { + def getName = "ppd_index" + def getDependencies = Set[Index](parent.getSphereIndex()) + } + + class IntraPPDIndex(parent: PPDFunctionSet, dimension: Int) extends SpatialIndex { + def getName = "intra_ppd_index_" + dimension + def getDependencies = Set[Index](parent.getPPDIndex) + } val ppdIndex = new PPDIndex(this) val sphereIndex = new SphereIndex(this) @@ -82,6 +154,7 @@ class PPDFunctionSet(basis : String, data : String) extends FunctionSet { def getSpatialIndices() = spatialIndices.toList def getDiscreteIndices() = List(getSphereIndex(), getPPDIndex()) + def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed") } class BindingIndex(name : String) { @@ -89,6 +162,8 @@ class BindingIndex(name : String) { } class Dictionary { + import scala.collection.mutable.HashMap + var matrices = new HashMap[Identifier, Matrix] var functionSets = new HashMap[Identifier, FunctionSet] var indices = new HashMap[Identifier, BindingIndex] @@ -109,13 +184,11 @@ class Dictionary { } } -class Assignment(indexBindings: IndexBindings, lhs: DataSpace, rhs: IterationSpace) { - override def toString = indexBindings.toString -} - class IndexBindings { + import scala.collection.mutable.{Set,HashSet, HashMap} + val spatial = new HashMap[BindingIndex, Set[SpatialIndex]] - val discrete = new HashMap[BindingIndex, Set[DiscreteIndex]] + val discrete = new HashMap[BindingIndex,Set[DiscreteIndex]] def add(binding: BindingIndex, index: SpatialIndex) = spatial.getOrElseUpdate(binding, new HashSet()) += index def add(binding: BindingIndex, index: DiscreteIndex) = discrete.getOrElseUpdate(binding, new HashSet()) += index @@ -139,13 +212,13 @@ class TreeBuilder(dictionary : Dictionary) { lhsTree match { case (lhsTree: DataSpace) => new Assignment(indexBindings, lhsTree, rhsTree) - case _ => new InvalidInputException("Non-assignable expression on LHS of assignment.") + case _ => throw new InvalidInputException("Non-assignable expression on LHS of assignment.") } } def buildIndexedTerm(term: parser.IndexedTerm) : IterationSpace = { val dataSpace = dictionary.getData(term.id) match { - case (functionSet : PPDFunctionSet) => new GeneralInnerProduct(List(functionSet), List(functionSet.getPPDIndex)) + case (functionSet : PPDFunctionSet) => new GeneralInnerProduct(List(functionSet), Set(functionSet.getPPDIndex)) case v => v } @@ -167,7 +240,7 @@ class TreeBuilder(dictionary : Dictionary) { case (t: IndexedTerm) => buildIndexedTerm(t) case ScalarConstant(s) => new Scalar(s) case Multiplication(a, b) => - new GeneralInnerProduct(List(buildExpression(a), buildExpression(b)), Nil) + new GeneralInnerProduct(List(buildExpression(a), buildExpression(b)), Set()) case Division(a, b) => throw new UnimplementedException("Semantics of division not yet defined, or implemented.") case Operator(Identifier("inner"), List(a,b)) => { @@ -180,7 +253,7 @@ class TreeBuilder(dictionary : Dictionary) { indexBindings.add(bindingIndex, right) } - new GeneralInnerProduct(List(aExpression, bExpression), aExpression.getSpatialIndices ++ bExpression.getSpatialIndices) + new GeneralInnerProduct(List(aExpression, bExpression), (aExpression.getSpatialIndices ++ bExpression.getSpatialIndices).toSet) } case Operator(Identifier("reciprocal"), List(op)) => new Reciprocal(buildExpression(op)) case Operator(Identifier("laplacian"), List(op)) => new Laplacian(buildExpression(op)) -- 2.47.3