From 40d20252066795be20af6831c37deeb4d07309ba Mon Sep 17 00:00:00 2001 From: Francis Russell Date: Mon, 30 Apr 2012 19:15:48 +0100 Subject: [PATCH] More work on tree generation. --- src/ofc/expression/Assignment.scala | 2 +- src/ofc/expression/Dictionary.scala | 4 + src/ofc/expression/Expression.scala | 50 ++++---- src/ofc/expression/TreeBuilder.scala | 10 +- src/ofc/generators/Onetep.scala | 102 ++++++++++----- src/ofc/generators/onetep/Assignment.scala | 10 +- src/ofc/generators/onetep/CodeGenerator.scala | 105 ---------------- src/ofc/generators/onetep/Dictionary.scala | 38 ++++++ src/ofc/generators/onetep/Field.scala | 4 + src/ofc/generators/onetep/FieldAccess.scala | 4 + .../onetep/GeneralInnerProduct.scala | 46 ------- src/ofc/generators/onetep/Index.scala | 15 --- src/ofc/generators/onetep/IndexBindings.scala | 40 ------ src/ofc/generators/onetep/InnerProduct.scala | 3 + .../generators/onetep/IterationSpace.scala | 40 ------ src/ofc/generators/onetep/Laplacian.scala | 12 +- src/ofc/generators/onetep/Matrix.scala | 4 + src/ofc/generators/onetep/NamedIndex.scala | 3 + .../generators/onetep/PPDFunctionSet.scala | 7 +- src/ofc/generators/onetep/Reciprocal.scala | 25 ---- src/ofc/generators/onetep/SPAM3.scala | 7 +- src/ofc/generators/onetep/Scalar.scala | 11 +- src/ofc/generators/onetep/ScalarAccess.scala | 4 + src/ofc/generators/onetep/ScalarLiteral.scala | 3 + src/ofc/generators/onetep/ScaledField.scala | 3 + .../onetep/SpatialRestriction.scala | 54 -------- src/ofc/generators/onetep/TreeBuilder.scala | 119 ------------------ 27 files changed, 181 insertions(+), 544 deletions(-) delete mode 100644 src/ofc/generators/onetep/CodeGenerator.scala create mode 100644 src/ofc/generators/onetep/Dictionary.scala create mode 100644 src/ofc/generators/onetep/Field.scala create mode 100644 src/ofc/generators/onetep/FieldAccess.scala delete mode 100644 src/ofc/generators/onetep/GeneralInnerProduct.scala delete mode 100644 src/ofc/generators/onetep/Index.scala delete mode 100644 src/ofc/generators/onetep/IndexBindings.scala create mode 100644 src/ofc/generators/onetep/InnerProduct.scala delete mode 100644 src/ofc/generators/onetep/IterationSpace.scala create mode 100644 src/ofc/generators/onetep/Matrix.scala create mode 100644 src/ofc/generators/onetep/NamedIndex.scala delete mode 100644 src/ofc/generators/onetep/Reciprocal.scala create mode 100644 src/ofc/generators/onetep/ScalarAccess.scala create mode 100644 src/ofc/generators/onetep/ScalarLiteral.scala create mode 100644 src/ofc/generators/onetep/ScaledField.scala delete mode 100644 src/ofc/generators/onetep/SpatialRestriction.scala delete mode 100644 src/ofc/generators/onetep/TreeBuilder.scala diff --git a/src/ofc/expression/Assignment.scala b/src/ofc/expression/Assignment.scala index fe44991..0a2ea2c 100644 --- a/src/ofc/expression/Assignment.scala +++ b/src/ofc/expression/Assignment.scala @@ -1,5 +1,5 @@ package ofc.expression -case class Assignment(lhs: ScalarExpression, rhs: ScalarExpression) { +case class Assignment(lhs: Scalar, rhs: Scalar) { override def toString = lhs.toString + " = " + rhs.toString } diff --git a/src/ofc/expression/Dictionary.scala b/src/ofc/expression/Dictionary.scala index 74d60e6..0b95012 100644 --- a/src/ofc/expression/Dictionary.scala +++ b/src/ofc/expression/Dictionary.scala @@ -25,4 +25,8 @@ class Dictionary { def getFunctionSet(id: Identifier) : Option[FunctionSet] = functionSets.get(id) def getIndex(id: Identifier) : Option[Index] = indices.get(id) + + def getOperands = matrices.values ++ functionSets.values + + def getIndices = indices.values } diff --git a/src/ofc/expression/Expression.scala b/src/ofc/expression/Expression.scala index 4b5d260..53dc599 100644 --- a/src/ofc/expression/Expression.scala +++ b/src/ofc/expression/Expression.scala @@ -7,14 +7,22 @@ case class Index(id: Identifier) { override def toString() = id.getName } -trait Expression { +sealed trait Expression { def isAssignable : Boolean def numIndices : Int def getDependentIndices : Set[Index] } -trait ScalarExpression extends Expression -trait FieldExpression extends Expression +sealed trait NamedOperand { + val id: Identifier + def getIdentifier = id + def isAssignable = true + def getDependentIndices : Set[Index] = Set.empty + override def toString = id.getName +} + +sealed trait Scalar extends Expression +sealed trait Field extends Expression trait IndexingOperation { val op: Expression @@ -25,54 +33,48 @@ trait IndexingOperation { override def toString = op.toString + getIndices.map(_.getName).mkString("[",",","]") } -trait NamedOperand { - val id: Identifier - def getIdentifier = id - def isAssignable = true - def getDependentIndices : Set[Index] = Set.empty - override def toString = id.getName -} - -case class ScalarIndexingOperation(val op: ScalarExpression, indices: List[Index]) extends IndexingOperation with ScalarExpression { +case class ScalarIndexingOperation(val op: Scalar, indices: List[Index]) extends IndexingOperation with Scalar { def getIndices = indices } -case class FieldIndexingOperation(val op: FieldExpression, indices: List[Index]) extends IndexingOperation with FieldExpression { +case class FieldIndexingOperation(val op: Field, indices: List[Index]) extends IndexingOperation with Field { def getIndices = indices } -case class InnerProduct(left: FieldExpression, right: FieldExpression) extends ScalarExpression { +case class InnerProduct(left: Field, right: Field) extends Scalar { override def toString = "inner(" + left.toString + ", " + right.toString+")" def isAssignable = false def numIndices = 0 def getDependentIndices = left.getDependentIndices ++ right.getDependentIndices } -case class Laplacian(op: FieldExpression) extends FieldExpression { +case class Laplacian(op: Field) extends Field { override def toString = "laplacian("+op.toString+")" def isAssignable = false def numIndices = 0 def getDependentIndices = op.getDependentIndices } -case class FieldScaling(op: FieldExpression, scale: ScalarExpression) extends FieldExpression { +case class FieldScaling(op: Field, scale: Scalar) extends Field { override def toString = "(" + op.toString + "*" + scale.toString + ")" def isAssignable = false def numIndices = 0 def getDependentIndices = op.getDependentIndices ++ scale.getDependentIndices } -class FunctionSet(val id: Identifier) extends FieldExpression with NamedOperand { +case class ScalarLiteral(literal: Double) extends Scalar { + override def toString = literal.toString + def isAssignable = false + def numIndices = 0 + def getDependentIndices = Set.empty +} + +class FunctionSet(val id: Identifier) extends Field with NamedOperand { def numIndices = 1 } -class Matrix(val id: Identifier) extends ScalarExpression with NamedOperand { +class Matrix(val id: Identifier) extends Scalar with NamedOperand { def numIndices = 2 } -class ScalarLiteral(literal: Double) extends ScalarExpression { - override def toString = literal.toString - def isAssignable = false - def numIndices = 0 - def getDependentIndices = Set.empty -} + diff --git a/src/ofc/expression/TreeBuilder.scala b/src/ofc/expression/TreeBuilder.scala index 357c075..3d656d2 100644 --- a/src/ofc/expression/TreeBuilder.scala +++ b/src/ofc/expression/TreeBuilder.scala @@ -12,7 +12,7 @@ class TreeBuilder(dictionary : Dictionary) { if (!lhsTree.isAssignable) throw new InvalidInputException("Non-assignable expression on LHS of assignment.") else (lhsTree, rhsTree) match { - case (lhs: ScalarExpression, rhs: ScalarExpression) => new Assignment(lhs, rhs) + case (lhs: Scalar, rhs: Scalar) => new Assignment(lhs, rhs) case _ => throw new InvalidInputException("Assignment must be of scalar type.") } } @@ -53,16 +53,16 @@ class TreeBuilder(dictionary : Dictionary) { case Division(a, b) => throw new UnimplementedException("Semantics of division not yet defined, or implemented.") case Multiplication(left, right) => (buildExpression(left), buildExpression(right)) match { - case (field: FieldExpression, factor: ScalarExpression) => new FieldScaling(field, factor) - case (factor: ScalarExpression, field: FieldExpression) => new FieldScaling(field, factor) + case (field: Field, factor: Scalar) => new FieldScaling(field, factor) + case (factor: Scalar, field: Field) => new FieldScaling(field, factor) case _ => throw new InvalidInputException("Cannot multiply "+left+" and "+right+".") } case Operator(Identifier("inner"), List(a,b)) => (buildExpression(a), buildExpression(b)) match { - case (left: FieldExpression, right: FieldExpression) => new InnerProduct(left, right) + case (left: Field, right: Field) => new InnerProduct(left, right) case _ => throw new InvalidInputException("inner requires both operands to be fields.") } case Operator(Identifier("laplacian"), List(op)) => buildExpression(op) match { - case (field: FieldExpression) => new Laplacian(field) + case (field: Field) => new Laplacian(field) case _ => throw new InvalidInputException("laplacian can only be applied to a field.") } case Operator(Identifier(name), _) => throw new UnimplementedException("Unknown or incorrectly called operator: "+name) diff --git a/src/ofc/generators/Onetep.scala b/src/ofc/generators/Onetep.scala index c2502b6..5acbad3 100644 --- a/src/ofc/generators/Onetep.scala +++ b/src/ofc/generators/Onetep.scala @@ -1,34 +1,87 @@ package ofc.generators -import ofc.generators.onetep._ +import ofc.InvalidInputException import ofc.parser -import ofc.expression.{Assignment,Expression,Dictionary} +import ofc.expression +import ofc.generators.onetep._ class Onetep extends Generator { - def acceptInput(dictionary: Dictionary, assignment: Assignment, targetSpecific : Seq[parser.TargetAssignment]) { - println(assignment) - if (matchLHS(assignment.lhs)) - println("ha!") + val dictionary = new Dictionary + + def acceptInput(exprDictionary: expression.Dictionary, exprAssignment: + expression.Assignment, targetSpecific : Seq[parser.TargetAssignment]) { + + buildDictionary(exprDictionary, targetSpecific) + + val assignment = new Assignment(buildScalarExpression(exprAssignment.lhs), buildScalarExpression(exprAssignment.rhs)) + } + + private def buildDictionary(exprDictionary: expression.Dictionary, targetSpecific : Seq[parser.TargetAssignment]) { + for(operand <- exprDictionary.getOperands) { + // Find corresponding target-specific declaration if it exists. + val targetDeclarationCall = targetSpecific.filter(_.id == operand.getIdentifier) match { + case Seq(x) => Some(x.value) + case Seq(_,_,_*) => throw new InvalidInputException("Invalid multiple target declarations for symbol " + operand.getIdentifier + ".") + case Nil => None + } + + operand match { + case (m: expression.Matrix) => buildMatrix(operand.getIdentifier, targetDeclarationCall) + case (f: expression.FunctionSet) => buildFunctionSet(operand.getIdentifier, targetDeclarationCall) + } + } + + for(index <- exprDictionary.getIndices) { + dictionary.add(index.getIdentifier, new NamedIndex(index.getName)) + } } - private def matchLHS(expression: Expression) : Boolean = { - import ofc.expression._ + private def getIndex(exprIndex: Seq[expression.Index]) : Seq[NamedIndex] = { + for(index <- exprIndex) yield + dictionary.getIndex(index.getIdentifier) + } - expression match { - case ScalarIndexingOperation(_: Matrix, List(bra, ket)) => true + private def matchLHS(lhs: expression.Scalar) : Boolean = { + lhs match { + case expression.ScalarIndexingOperation(_: expression.Matrix, List(bra, ket)) => true case _ => false } } - /* + private def buildScalarExpression(scalar: expression.Scalar) : Scalar = { + scalar match { + case expression.ScalarLiteral(s) => new ScalarLiteral(s) + case expression.InnerProduct(l, r) => new InnerProduct(buildFieldExpression(l), buildFieldExpression(r)) + case expression.ScalarIndexingOperation(op, indices) => buildScalarAccess(op, indices) + case (m: expression.Matrix) => dictionary.getScalar(m.getIdentifier) + } + } + + private def buildScalarAccess(op: expression.Scalar, indices: Seq[expression.Index]) : Scalar = { + val base = buildScalarExpression(op) + new ScalarAccess(base, getIndex(indices)) + } + + private def buildFieldAccess(op: expression.Field, indices: Seq[expression.Index]) : Field = { + val base = buildFieldExpression(op) + new FieldAccess(base, getIndex(indices)) + } + + private def buildFieldExpression(field: expression.Field) : Field = { + field match { + case expression.Laplacian(op) => new Laplacian(buildFieldExpression(op)) + case expression.FieldScaling(op, scale) => new ScaledField(buildFieldExpression(op), buildScalarExpression(scale)) + case expression.FieldIndexingOperation(op, indices) => buildFieldAccess(op, indices) + case (f: expression.FunctionSet) => dictionary.getField(f.getIdentifier) + } + } def buildMatrix(id: parser.Identifier, call : Option[parser.FunctionCall]) { import parser._ call match { case Some(FunctionCall(matType, params)) => (matType, params) match { - case (Identifier("SPAM3"), ParameterList(StringParameter(name))) => - dictionary.matrices += (id -> new SPAM3(name)) + case (Identifier("SPAM3"), ParameterList(StringParameter(name))) => dictionary.add(id, new SPAM3(name)) case _ => throw new InvalidInputException("Unknown usage of type: "+matType.name) } case _ => throw new InvalidInputException("Undefined concrete type for matrix: "+id.name) @@ -41,13 +94,14 @@ class Onetep extends Generator { call match { case Some(FunctionCall(fSetType, params)) => (fSetType, params) match { case (Identifier("PPDFunctionSet"), ParameterList(StringParameter(basis), StringParameter(data))) => - dictionary.functionSets += id -> PPDFunctionSet(basis, data) + dictionary.add(id, new PPDFunctionSet(basis, data)) case _ => throw new InvalidInputException("Unknown usage of type: "+fSetType.name) } case _ => throw new InvalidInputException("Undefined concrete type for function set: "+id.name) } } + /* def buildBindingIndex(id: parser.Identifier, call : Option[parser.FunctionCall]) { call match { case Some(_) => throw new InvalidInputException("Index "+id.name+" cannot have concrete type.") @@ -55,26 +109,6 @@ class Onetep extends Generator { } } - def buildDictionary(statements : Seq[parser.Statement]) { - val targetDeclarations = filterStatements[parser.TargetAssignment](statements) - val declarations = getDeclarations(statements) - - for(d <- declarations) { - // Find corresponding target-specific declaration if it exists. - val targetDeclarationCall = targetDeclarations.filter(_.id == d._1) match { - case List(x) => Some(x.value) - case List(_,_,_*) => throw new InvalidInputException("Invalid multiple target declarations for symbol " + d._1 + ".") - case Nil => None - } - - d match { - case (id, parser.Matrix()) => buildMatrix(id, targetDeclarationCall) - case (id, parser.FunctionSet()) => buildFunctionSet(id, targetDeclarationCall) - case (id, parser.Index()) => buildBindingIndex(id, targetDeclarationCall) - } - } - } - def buildDefinition(definition : parser.Definition) { val builder = new TreeBuilder(dictionary) val assignment = builder(definition.term, definition.expr) diff --git a/src/ofc/generators/onetep/Assignment.scala b/src/ofc/generators/onetep/Assignment.scala index 36414b8..04e17ee 100644 --- a/src/ofc/generators/onetep/Assignment.scala +++ b/src/ofc/generators/onetep/Assignment.scala @@ -1,11 +1,3 @@ package ofc.generators.onetep -import ofc.codegen.{ProducerStatement,NullStatement,FloatLiteral} -class Assignment(indexBindings: IndexBindings, lhs: DataSpace, rhs: IterationSpace) extends IterationSpace { - // TODO: Implement assignment - def getOperands = List(rhs) - def getSpatialIndices = Nil - def getDiscreteIndices = Nil - def getProducer(ancestors: Map[IterationSpace, ProducerStatement]) = ancestors.get(rhs).get - def getDataValue = new FloatLiteral(0.0) -} +class Assignment(lhs: Scalar, rhs: Scalar) diff --git a/src/ofc/generators/onetep/CodeGenerator.scala b/src/ofc/generators/onetep/CodeGenerator.scala deleted file mode 100644 index 4d6737e..0000000 --- a/src/ofc/generators/onetep/CodeGenerator.scala +++ /dev/null @@ -1,105 +0,0 @@ -package ofc.generators.onetep -import scala.collection.mutable.HashMap -import ofc.codegen._ - -/* -class NameManager { - var nextIndexID = 0 - val names = new HashMap[Index, String]() - - def addIndex(index: Index) = { - val name = index.getName + "_" + nextIndexID - nextIndexID += 1 - names += (index -> name) - name - } - - def apply(index: Index) = - if (names.contains(index)) - names(index) - else - addIndex(index) - - def newIdentifier(prefix: String) = { - val name = prefix + "_" + nextIndexID - nextIndexID +=1 - name - } -} -*/ - -class CodeGenerator(indexBindings: IndexBindings) { - private val indexSymbols = { - def createMapping(index: BindingIndex) = (index, new NamedUnboundVarSymbol[IntType](index.getName)) - indexBindings.getBindingIndices.map(createMapping(_)).toMap - } - - def apply(assignment: Assignment) { - //val declarations = collectDeclarations(assignment) - //for(declaration <- declarations) code append declaration+"\n" - generateCode(assignment) - } - - - /* - val code = new StringBuilder() - val nameManager = new NameManager() - - def collectDeclarations(term: IterationSpace) : Set[String] = { - val declarations = for(index <- term.getIndices; - declaration <- index.getDeclarations(nameManager)) yield declaration - - var declarationsSet = declarations.toSet - for (op <- term.getOperands) declarationsSet ++= collectDeclarations(op) - declarationsSet - } - */ - - private def buildStatement(space: IterationSpace) : ProducerStatement = { - // TODO: Until we can handle multi-operand nodes - assert(space.getOperands.size < 2) - - var operandProducers : Map[IterationSpace, ProducerStatement] = Map.empty - for(operand <- space.getOperands) { - val opProducer = buildStatement(operand) - - for(discreteIndex <- operand.getDiscreteIndices) indexBindings.getBindingIndex(discreteIndex) match { - case Some(bindingIndex) => { - val symbol = indexSymbols.get(bindingIndex).get - val newData = opProducer.addPredicate(symbol |==| discreteIndex.getValue) - } - case _ => () - } - - for(spatialIndex <- operand.getSpatialIndices) indexBindings.getBindingIndex(spatialIndex) match { - case Some(bindingIndex) => { - val symbol = indexSymbols.get(bindingIndex).get - val newData = opProducer.addPredicate(symbol |==| spatialIndex.getValue) - } - case _ => () - } - - operandProducers += operand -> opProducer - } - - space.getProducer(operandProducers) - } - - def generateCode(space: IterationSpace) { - val allSpaces = IterationSpace.flattenPostorder(space) - val allIndices = allSpaces flatMap (_.getIndices) - - println("Operations:") - for(op <- IterationSpace.sort(allSpaces)) - println(op) - println("\nIndices:") - for (i <- allIndices) - println(i) - println("") - - val statement = buildStatement(space) - val fortranGenerator = new FortranGenerator - val code = fortranGenerator(statement) - println(code) - } -} diff --git a/src/ofc/generators/onetep/Dictionary.scala b/src/ofc/generators/onetep/Dictionary.scala new file mode 100644 index 0000000..851a255 --- /dev/null +++ b/src/ofc/generators/onetep/Dictionary.scala @@ -0,0 +1,38 @@ +package ofc.generators.onetep +import ofc.parser.Identifier +import ofc.InvalidInputException + +class Dictionary { + import scala.collection.mutable.HashMap + + var scalars = new HashMap[Identifier, Scalar] + var fields = new HashMap[Identifier, Field] + var indices = new HashMap[Identifier, NamedIndex] + + def add(id: Identifier, scalar: Scalar) { + scalars += id -> scalar + } + + def add(id: Identifier, field: Field) { + fields += id -> field + } + + def add(id: Identifier, index: NamedIndex) { + indices += id -> index + } + + def getScalar(id: Identifier) : Scalar = scalars.get(id) match { + case Some(s) => s + case None => throw new InvalidInputException("Unknown scalar operand "+id.getName) + } + + def getField(id: Identifier) : Field = fields.get(id) match { + case Some(f) => f + case None => throw new InvalidInputException("Unknown field operand "+id.getName) + } + + def getIndex(id: Identifier) : NamedIndex = indices.get(id) match { + case Some(i) => i + case None => throw new InvalidInputException("Unknown index operand "+id.getName) + } +} diff --git a/src/ofc/generators/onetep/Field.scala b/src/ofc/generators/onetep/Field.scala new file mode 100644 index 0000000..62e5805 --- /dev/null +++ b/src/ofc/generators/onetep/Field.scala @@ -0,0 +1,4 @@ +package ofc.generators.onetep + +trait Field { +} diff --git a/src/ofc/generators/onetep/FieldAccess.scala b/src/ofc/generators/onetep/FieldAccess.scala new file mode 100644 index 0000000..120f156 --- /dev/null +++ b/src/ofc/generators/onetep/FieldAccess.scala @@ -0,0 +1,4 @@ +package ofc.generators.onetep + +class FieldAccess(op: Field, indices: Seq[NamedIndex]) extends Field + diff --git a/src/ofc/generators/onetep/GeneralInnerProduct.scala b/src/ofc/generators/onetep/GeneralInnerProduct.scala deleted file mode 100644 index 267b7e5..0000000 --- a/src/ofc/generators/onetep/GeneralInnerProduct.scala +++ /dev/null @@ -1,46 +0,0 @@ -package ofc.generators.onetep -/* -class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[Index]) extends IterationSpace { - - class DenseSpatialIndex(parent: GeneralInnerProduct, original: SpatialIndex) extends SpatialIndex{ - def getDependencies = Set() - def getName = "dense_spatial_index" - def getDenseWidth(names: NameManager) = original.getDenseWidth(names) - def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth(names) - def generateIterationFooter(names: NameManager) = "end do" - def getDeclarations(names: NameManager) = List("integer :: "+names(this)) - } - - class DenseDiscreteIndex(parent: GeneralInnerProduct, original: DiscreteIndex) extends DiscreteIndex { - def getDependencies = Set() - def getName = "dense_discrete_index" - def getDenseWidth(names: NameManager) = original.getDenseWidth(names) - def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth(names) - def generateIterationFooter(names: NameManager) = "end do" - def getDeclarations(names: NameManager) = List("integer :: "+names(this)) - } - - val spatialIndices = - for(op <- operands; index <- op.getSpatialIndices; if (!removedIndices.contains(index))) yield - if (index.getDependencies.intersect(removedIndices).isEmpty) - index - else - new DenseSpatialIndex(this, index) - - val discreteIndices = - for(op <- operands; index <- op.getDiscreteIndices; if (!removedIndices.contains(index))) yield - if (index.getDependencies.intersect(removedIndices).isEmpty) - index - else - new DenseDiscreteIndex(this, index) - - def getOperands = operands - def getSpatialIndices = spatialIndices - def getDiscreteIndices = discreteIndices - def getExternalIndices = Set() - - def getConsumerGenerator = None - def getTransformGenerator = None - def getProducerGenerator = None -} -*/ diff --git a/src/ofc/generators/onetep/Index.scala b/src/ofc/generators/onetep/Index.scala deleted file mode 100644 index 7eee2d7..0000000 --- a/src/ofc/generators/onetep/Index.scala +++ /dev/null @@ -1,15 +0,0 @@ -package ofc.generators.onetep -import ofc.codegen.{Expression,IntType} - -trait Index { - def getName: String - def getValue: Expression[IntType] - //def getMinimumValue : Expression[IntType] - //def getLength : Expression[IntType] -} - -trait DiscreteIndex extends Index -trait SpatialIndex extends Index -trait FunctionSpatialIndex extends SpatialIndex { - def getFunctionCentre: Expression[IntType] -} diff --git a/src/ofc/generators/onetep/IndexBindings.scala b/src/ofc/generators/onetep/IndexBindings.scala deleted file mode 100644 index 7cab597..0000000 --- a/src/ofc/generators/onetep/IndexBindings.scala +++ /dev/null @@ -1,40 +0,0 @@ -package ofc.generators.onetep -import ofc.LogicError - -class IndexBindings { - import scala.collection.mutable.{Set,HashSet, HashMap} - - val spatial = new HashMap[BindingIndex, Set[SpatialIndex]] - val discrete = new HashMap[BindingIndex,Set[DiscreteIndex]] - - override def toString = spatial.mkString("\n") + "\n" + discrete.mkString("\n") - - def add(binding: BindingIndex, index: SpatialIndex) = spatial.getOrElseUpdate(binding, new HashSet()) += index - def add(binding: BindingIndex, index: DiscreteIndex) = discrete.getOrElseUpdate(binding, new HashSet()) += index - - def contains(index: SpatialIndex) : Boolean = getBindingIndex(index) match { - case Some(_) => true - case _ => false - } - - def contains(index: DiscreteIndex) : Boolean = getBindingIndex(index) match { - case Some(_) => true - case _ => false - } - - def getBindingIndex(index: SpatialIndex) : Option[BindingIndex] = { - for((bindingIndex, spatialIndices) <- spatial; if spatialIndices.contains(index)) - return Some(bindingIndex) - - None - } - - def getBindingIndex(index: DiscreteIndex) : Option[BindingIndex] = { - for((bindingIndex, discreteIndices) <- discrete; if discreteIndices.contains(index)) - return Some(bindingIndex) - - None - } - - def getBindingIndices = spatial.keys ++ discrete.keys -} diff --git a/src/ofc/generators/onetep/InnerProduct.scala b/src/ofc/generators/onetep/InnerProduct.scala new file mode 100644 index 0000000..516e299 --- /dev/null +++ b/src/ofc/generators/onetep/InnerProduct.scala @@ -0,0 +1,3 @@ +package ofc.generators.onetep + +class InnerProduct(left: Field, right: Field) extends Scalar diff --git a/src/ofc/generators/onetep/IterationSpace.scala b/src/ofc/generators/onetep/IterationSpace.scala deleted file mode 100644 index d159245..0000000 --- a/src/ofc/generators/onetep/IterationSpace.scala +++ /dev/null @@ -1,40 +0,0 @@ -package ofc.generators.onetep -import ofc.codegen.{Statement,ProducerStatement,NullStatement,Expression,FloatType} - -object IterationSpace { - def sort(spaces : Traversable[IterationSpace]) : Seq[IterationSpace] = { - def helper(input: IterationSpace, seen: collection.mutable.Set[IterationSpace]) : Seq[IterationSpace] = - if (seen add input) - input.getOperands.flatMap(helper(_, seen)) :+ input - else - Nil - - val seen = collection.mutable.Set[IterationSpace]() - spaces.toList.flatMap(helper(_, seen)) - } - - def flattenPostorder(term: IterationSpace) : Seq[IterationSpace] = - term.getOperands.toSeq.flatMap(flattenPostorder(_)).+:(term) -} - -trait IterationSpace { - def getOperands : Seq[IterationSpace] - def getSpatialIndices : Seq[SpatialIndex] - def getDiscreteIndices : Seq[DiscreteIndex] - def getDataValue : Expression[FloatType] - def getIndices : Set[Index] = (getSpatialIndices ++ getDiscreteIndices).toSet - def getDependencies : Set[IterationSpace] = { - val operands = getOperands - operands.toSet ++ operands.flatMap(_.getDependencies) - } - def getProducer(ancestors: Map[IterationSpace, ProducerStatement]) : ProducerStatement -} - -trait DataSpace extends IterationSpace { - def getOperands = Nil - def getProducer(ancestors: Map[IterationSpace, ProducerStatement]) = this.getProducer - def getProducer : ProducerStatement -} - -trait Matrix extends DataSpace -trait FunctionSet extends DataSpace diff --git a/src/ofc/generators/onetep/Laplacian.scala b/src/ofc/generators/onetep/Laplacian.scala index abab4c4..dc31ba1 100644 --- a/src/ofc/generators/onetep/Laplacian.scala +++ b/src/ofc/generators/onetep/Laplacian.scala @@ -1,13 +1,3 @@ package ofc.generators.onetep -/* -class Laplacian(op: IterationSpace) extends IterationSpace { - def getOperands = List(op) - def getSpatialIndices = op.getSpatialIndices - def getDiscreteIndices = op.getDiscreteIndices - def getExternalIndices = Set() - def getConsumerGenerator = None - def getTransformGenerator = None - def getProducerGenerator = None -} -*/ +class Laplacian(op: Field) extends Field diff --git a/src/ofc/generators/onetep/Matrix.scala b/src/ofc/generators/onetep/Matrix.scala new file mode 100644 index 0000000..686d759 --- /dev/null +++ b/src/ofc/generators/onetep/Matrix.scala @@ -0,0 +1,4 @@ +package ofc.generators.onetep + +trait Matrix { +} diff --git a/src/ofc/generators/onetep/NamedIndex.scala b/src/ofc/generators/onetep/NamedIndex.scala new file mode 100644 index 0000000..271f5e6 --- /dev/null +++ b/src/ofc/generators/onetep/NamedIndex.scala @@ -0,0 +1,3 @@ +package ofc.generators.onetep + +class NamedIndex(name: String) diff --git a/src/ofc/generators/onetep/PPDFunctionSet.scala b/src/ofc/generators/onetep/PPDFunctionSet.scala index 60d3918..45081cf 100644 --- a/src/ofc/generators/onetep/PPDFunctionSet.scala +++ b/src/ofc/generators/onetep/PPDFunctionSet.scala @@ -1,6 +1,6 @@ package ofc.generators.onetep import ofc.codegen._ - +/* object PPDFunctionSet { private class SphereIndex(name: String, value: Expression[IntType]) extends DiscreteIndex { def getName = name @@ -95,7 +95,11 @@ object PPDFunctionSet { new PPDFunctionSet(discreteIndices, spatialIndices, dataValue, producer) } } +*/ + +class PPDFunctionSet(basisName: String, dataName: String) extends Field + /* class PPDFunctionSet private(discreteIndices: Seq[DiscreteIndex], spatialIndices: Seq[SpatialIndex], data: Expression[FloatType], producer: ProducerStatement) extends FunctionSet { @@ -105,3 +109,4 @@ class PPDFunctionSet private(discreteIndices: Seq[DiscreteIndex], def getSpatialIndices = spatialIndices def getDataValue = data } +*/ diff --git a/src/ofc/generators/onetep/Reciprocal.scala b/src/ofc/generators/onetep/Reciprocal.scala deleted file mode 100644 index 1391b79..0000000 --- a/src/ofc/generators/onetep/Reciprocal.scala +++ /dev/null @@ -1,25 +0,0 @@ -package ofc.generators.onetep -/* -class Reciprocal(op: IterationSpace) extends IterationSpace { - class BlockIndex(parent: Reciprocal, dimension: Int, original: SpatialIndex) extends SpatialIndex { - def getName = "reciprocal_index_" + dimension - def getDependencies = Set() - def getDenseWidth(names: NameManager) = original.getDenseWidth(names) - def generateIterationHeader(names: NameManager) = "do "+names(this)+"=1,"+getDenseWidth(names) - def generateIterationFooter(names: NameManager) = "end do" - def getDeclarations(names: NameManager) = List("integer :: "+names(this)) - } - - val spatialIndices = for (dimension <- 0 until op.getSpatialIndices.size) yield - new BlockIndex(this, dimension, op.getSpatialIndices(dimension)) - - def getOperands = List(op) - def getSpatialIndices = spatialIndices.toList - def getDiscreteIndices = op.getDiscreteIndices - def getExternalIndices = Set() - - def getConsumerGenerator = None - def getTransformGenerator = None - def getProducerGenerator = None -} -*/ diff --git a/src/ofc/generators/onetep/SPAM3.scala b/src/ofc/generators/onetep/SPAM3.scala index 6c50428..ff2f167 100644 --- a/src/ofc/generators/onetep/SPAM3.scala +++ b/src/ofc/generators/onetep/SPAM3.scala @@ -1,12 +1,7 @@ package ofc.generators.onetep import ofc.codegen.{ProducerStatement,NullStatement,Comment, FloatLiteral} -class SPAM3(name : String) extends Matrix { +class SPAM3(name : String) extends Scalar { override def toString = name def getName = name - - def getSpatialIndices = Nil - def getDiscreteIndices = Nil - def getDataValue = new FloatLiteral(0.0) - def getProducer = new ProducerStatement } diff --git a/src/ofc/generators/onetep/Scalar.scala b/src/ofc/generators/onetep/Scalar.scala index 97fe64e..5e673a6 100644 --- a/src/ofc/generators/onetep/Scalar.scala +++ b/src/ofc/generators/onetep/Scalar.scala @@ -1,11 +1,4 @@ package ofc.generators.onetep -/* -class Scalar(value: Double) extends DataSpace { - def getSpatialIndices = Nil - def getDiscreteIndices = Nil - def getExternalIndices = Set() - def getProducerGenerator = Some(new ProducerGenerator { - def generate(names: NameManager) = value.toString - }) + +trait Scalar { } -*/ diff --git a/src/ofc/generators/onetep/ScalarAccess.scala b/src/ofc/generators/onetep/ScalarAccess.scala new file mode 100644 index 0000000..0b7e5c0 --- /dev/null +++ b/src/ofc/generators/onetep/ScalarAccess.scala @@ -0,0 +1,4 @@ +package ofc.generators.onetep + +class ScalarAccess(op: Scalar, indices: Seq[NamedIndex]) extends Scalar + diff --git a/src/ofc/generators/onetep/ScalarLiteral.scala b/src/ofc/generators/onetep/ScalarLiteral.scala new file mode 100644 index 0000000..fc38f9a --- /dev/null +++ b/src/ofc/generators/onetep/ScalarLiteral.scala @@ -0,0 +1,3 @@ +package ofc.generators.onetep + +class ScalarLiteral(s: Double) extends Scalar diff --git a/src/ofc/generators/onetep/ScaledField.scala b/src/ofc/generators/onetep/ScaledField.scala new file mode 100644 index 0000000..070b0fd --- /dev/null +++ b/src/ofc/generators/onetep/ScaledField.scala @@ -0,0 +1,3 @@ +package ofc.generators.onetep + +class ScaledField(op: Field, factor: Scalar) extends Field diff --git a/src/ofc/generators/onetep/SpatialRestriction.scala b/src/ofc/generators/onetep/SpatialRestriction.scala deleted file mode 100644 index 4809345..0000000 --- a/src/ofc/generators/onetep/SpatialRestriction.scala +++ /dev/null @@ -1,54 +0,0 @@ -package ofc.generators.onetep -import ofc.InvalidInputException -import ofc.codegen._ - -object SpatialRestriction { - import OnetepTypes._ - - private val pubFFTBoxWidth = for (dim <- 0 to 2) yield FFTBoxInfo.public % FFTBoxInfo.totalPts(dim) - private val ppdWidth = for(dim <- 0 to 2) yield CellInfo.public % CellInfo.ppdWidth(dim) - - private class RestrictionIndex(name: String, value: Expression[IntType]) extends SpatialIndex { - def getName = name - def getValue = value - } - - def apply(op: IterationSpace) : SpatialRestriction = { - import OnetepTypes._ - - val inputIndices = for(index <- op.getSpatialIndices) yield - index match { - case (f: FunctionSpatialIndex) => f - case _ => throw new InvalidInputException("Input to SpatialRestriction must be a function") - } - - val ppdWidths = for(dim <- 0 to 2) yield CellInfo.public % CellInfo.ppdWidth(dim) - val cellWidthPPDs = for(dim <- 0 to 2) yield CellInfo.public % CellInfo.numPPDs(dim) - val cellWidthPts = for(dim <- 0 to 2) yield cellWidthPPDs(dim) * ppdWidths(dim) - val functionCentre = inputIndices.map(_.getFunctionCentre) - - val producer = new ProducerStatement - val origin = for (dim <- 0 to 2) - yield producer.addExpression("fftbox_origin_pt"+(dim+1), (cellWidthPts(dim) + functionCentre(dim) - pubFFTBoxWidth(dim)/2) % cellWidthPts(dim)) - - val offset = for (dim <- 0 to 2) - yield producer.addExpression("fftbox_offset_pt"+(dim+1), - (inputIndices(dim).getValue - origin(dim) + cellWidthPts(dim)) % cellWidthPts(dim)) - - for (dim <- 0 to 2) - producer.addPredicate(offset(dim) |<| pubFFTBoxWidth(dim)) - - val indices = for(dim <- 0 to 2) yield new RestrictionIndex("restriction_pos"+(dim+1), offset(dim)) - new SpatialRestriction(op, indices, producer) - } -} - -class SpatialRestriction private(op: IterationSpace, - spatialIndices: Seq[SpatialIndex], producer: ProducerStatement) extends IterationSpace { - def getOperands = List(op) - def getDiscreteIndices = Nil - def getDataValue = op.getDataValue - def getProducer(ancestors: Map[IterationSpace, ProducerStatement]) = producer merge ancestors.get(op).get - def getSpatialIndices = spatialIndices -} - diff --git a/src/ofc/generators/onetep/TreeBuilder.scala b/src/ofc/generators/onetep/TreeBuilder.scala deleted file mode 100644 index 4011db1..0000000 --- a/src/ofc/generators/onetep/TreeBuilder.scala +++ /dev/null @@ -1,119 +0,0 @@ -package ofc.generators.onetep - -import ofc.parser -import ofc.parser.Identifier -import ofc.{InvalidInputException,UnimplementedException} - -case class BindingIndex(name : String) { - override def toString() = name - def getName = name -} - -class Dictionary { - import scala.collection.mutable.HashMap - - var matrices = new HashMap[Identifier, Matrix] - var functionSets = new HashMap[Identifier, FunctionSet] - var indices = new HashMap[Identifier, BindingIndex] - - def getData(id: Identifier) = - matrices.get(id) match { - case Some(mat) => mat - case None => functionSets.get(id) match { - case Some(functionSet) => functionSet - case None => throw new InvalidInputException("Unknown identifier "+id.name) - } - } - - def getIndex(id: Identifier) = - indices.get(id) match { - case Some(index) => index - case None => throw new InvalidInputException("Unknown index "+id.name) - } -} - -class TreeBuilder(dictionary : Dictionary) { - val indexBindings = new IndexBindings - var nextBindingIndexID = 0 - - private def newBindingIndex() = { - val index = new BindingIndex("synthetic_"+nextBindingIndexID) - nextBindingIndexID += 1 - index - } - - def apply(lhs: parser.IndexedIdentifier, rhs: parser.Expression) = { - val lhsTree = buildIndexedSpace(lhs) - val rhsTree = buildExpression(rhs) - - lhsTree match { - case (lhsTree: DataSpace) => new Assignment(indexBindings, lhsTree, rhsTree) - case _ => throw new InvalidInputException("Non-assignable expression on LHS of assignment.") - } - } - - def getIndexBindings = indexBindings - - private def buildIndexedSpace(term: parser.IndexedIdentifier) : IterationSpace = { - val dataSpace = dictionary.getData(term.id) - val indices = for(bindingID <- term.indices) yield dictionary.getIndex(bindingID) - - if (indices.size != dataSpace.getDiscreteIndices.size) - throw new InvalidInputException("Incorrect number of indices for object "+term.id.name) - - for(i <- indices zip dataSpace.getDiscreteIndices) - indexBindings.add(i._1, i._2) - - /* - dataSpace match { - case (dataSpace: DataSpace) => new DataSpaceIndexBinding(dataSpace) - case iterationSpace => new IterationSpaceIndexBinding(iterationSpace) - } - */ - dataSpace - } - - private def buildIndex(term: parser.Expression) : BindingIndex = { - term match { - case (indexedID: parser.IndexedIdentifier) => { - if (indexedID.indices.nonEmpty) - throw new InvalidInputException("Tried to parse expression "+term+" as index but it is indexed.") - else - dictionary.getIndex(indexedID.id) - } - case other => throw new InvalidInputException("Cannot parse expression "+other+" as index.") - } - } - - private def buildExpression(term: parser.Expression) : IterationSpace = { - import parser._ - - term match { - case (t: IndexedIdentifier) => buildIndexedSpace(t) - case Operator(Identifier("fftbox"), List(op)) => SpatialRestriction(buildExpression(op)) - - /* - case ScalarConstant(s) => new Scalar(s) - case Multiplication(a, b) => - new GeneralInnerProduct(List(buildExpression(a), buildExpression(b)), Set()) - case Division(a, b) => - throw new UnimplementedException("Semantics of division not yet defined, or implemented.") - case Operator(Identifier("inner"), List(a,b)) => { - val aExpression = buildExpression(a) - val bExpression = buildExpression(b) - - for ((left,right) <- aExpression.getSpatialIndices zip bExpression.getSpatialIndices) { - val bindingIndex = newBindingIndex() - indexBindings.add(bindingIndex, left) - indexBindings.add(bindingIndex, right) - } - - new GeneralInnerProduct(List(aExpression, bExpression), (aExpression.getSpatialIndices ++ bExpression.getSpatialIndices).toSet) - } - case Operator(Identifier("reciprocal"), List(op)) => new Reciprocal(buildExpression(op)) - case Operator(Identifier("laplacian"), List(op)) => new Laplacian(buildExpression(op)) - case Operator(Identifier(name), _) => throw new UnimplementedException("Unknown or incorrectly called operator: "+name) - */ - } - } -} -- 2.47.3