From: Francis Russell Date: Tue, 1 May 2012 04:33:32 +0000 (+0100) Subject: Skeleton for fragment-based code generation. X-Git-Url: https://git.unchartedbackwaters.co.uk/w/?a=commitdiff_plain;h=7e33a04bfca12d1f3b78bd65a5b0133a447f30e4;p=francis%2Fofc.git Skeleton for fragment-based code generation. --- diff --git a/src/ofc/expression/Expression.scala b/src/ofc/expression/Expression.scala index 53dc599..9fa2d48 100644 --- a/src/ofc/expression/Expression.scala +++ b/src/ofc/expression/Expression.scala @@ -76,5 +76,3 @@ class FunctionSet(val id: Identifier) extends Field with NamedOperand { class Matrix(val id: Identifier) extends Scalar with NamedOperand { def numIndices = 2 } - - diff --git a/src/ofc/generators/Onetep.scala b/src/ofc/generators/Onetep.scala index 5acbad3..22b2440 100644 --- a/src/ofc/generators/Onetep.scala +++ b/src/ofc/generators/Onetep.scala @@ -12,8 +12,9 @@ class Onetep extends Generator { expression.Assignment, targetSpecific : Seq[parser.TargetAssignment]) { buildDictionary(exprDictionary, targetSpecific) - val assignment = new Assignment(buildScalarExpression(exprAssignment.lhs), buildScalarExpression(exprAssignment.rhs)) + val codeGenerator = new CodeGenerator(dictionary) + codeGenerator(assignment) } private def buildDictionary(exprDictionary: expression.Dictionary, targetSpecific : Seq[parser.TargetAssignment]) { @@ -32,7 +33,7 @@ class Onetep extends Generator { } for(index <- exprDictionary.getIndices) { - dictionary.add(index.getIdentifier, new NamedIndex(index.getName)) + dictionary.addIndex(index.getIdentifier, new NamedIndex(index.getName)) } } @@ -53,48 +54,51 @@ class Onetep extends Generator { case expression.ScalarLiteral(s) => new ScalarLiteral(s) case expression.InnerProduct(l, r) => new InnerProduct(buildFieldExpression(l), buildFieldExpression(r)) case expression.ScalarIndexingOperation(op, indices) => buildScalarAccess(op, indices) - case (m: expression.Matrix) => dictionary.getScalar(m.getIdentifier) + case (_: expression.Matrix) => throw new InvalidInputException("Cannot handle un-indexed matrices.") } } - private def buildScalarAccess(op: expression.Scalar, indices: Seq[expression.Index]) : Scalar = { - val base = buildScalarExpression(op) - new ScalarAccess(base, getIndex(indices)) - } + private def buildScalarAccess(op: expression.Scalar, indices: Seq[expression.Index]) : Scalar = + op match { + case (matrix: expression.Matrix) => dictionary.getScalar(matrix.getIdentifier)(getIndex(indices)) + case _ => throw new InvalidInputException("Can only index leaf-matrices.") + } - private def buildFieldAccess(op: expression.Field, indices: Seq[expression.Index]) : Field = { - val base = buildFieldExpression(op) - new FieldAccess(base, getIndex(indices)) - } + private def buildFieldAccess(op: expression.Field, indices: Seq[expression.Index]) : Field = + op match { + case (functionSet: expression.FunctionSet) => dictionary.getField(functionSet.getIdentifier)(getIndex(indices)) + case _ => throw new InvalidInputException("Can only index function-sets.") + } private def buildFieldExpression(field: expression.Field) : Field = { field match { case expression.Laplacian(op) => new Laplacian(buildFieldExpression(op)) case expression.FieldScaling(op, scale) => new ScaledField(buildFieldExpression(op), buildScalarExpression(scale)) case expression.FieldIndexingOperation(op, indices) => buildFieldAccess(op, indices) - case (f: expression.FunctionSet) => dictionary.getField(f.getIdentifier) + case (_: expression.FunctionSet) => throw new InvalidInputException("Cannot handle un-indexed function sets.") } } - def buildMatrix(id: parser.Identifier, call : Option[parser.FunctionCall]) { + private def buildMatrix(id: parser.Identifier, call : Option[parser.FunctionCall]) { import parser._ call match { case Some(FunctionCall(matType, params)) => (matType, params) match { - case (Identifier("SPAM3"), ParameterList(StringParameter(name))) => dictionary.add(id, new SPAM3(name)) + case (Identifier("SPAM3"), ParameterList(StringParameter(name))) => + dictionary.addScalar(id, new SPAM3(name, _: Seq[NamedIndex])) case _ => throw new InvalidInputException("Unknown usage of type: "+matType.name) } case _ => throw new InvalidInputException("Undefined concrete type for matrix: "+id.name) } } - def buildFunctionSet(id: parser.Identifier, call : Option[parser.FunctionCall]) { + private def buildFunctionSet(id: parser.Identifier, call : Option[parser.FunctionCall]) { import parser._ call match { case Some(FunctionCall(fSetType, params)) => (fSetType, params) match { case (Identifier("PPDFunctionSet"), ParameterList(StringParameter(basis), StringParameter(data))) => - dictionary.add(id, new PPDFunctionSet(basis, data)) + dictionary.addField(id, new PPDFunctionSet(basis, data, _: Seq[NamedIndex])) case _ => throw new InvalidInputException("Unknown usage of type: "+fSetType.name) } case _ => throw new InvalidInputException("Undefined concrete type for function set: "+id.name) diff --git a/src/ofc/generators/onetep/Assignment.scala b/src/ofc/generators/onetep/Assignment.scala index 04e17ee..f5972c3 100644 --- a/src/ofc/generators/onetep/Assignment.scala +++ b/src/ofc/generators/onetep/Assignment.scala @@ -1,3 +1,3 @@ package ofc.generators.onetep -class Assignment(lhs: Scalar, rhs: Scalar) +class Assignment(val lhs: Scalar, val rhs: Scalar) diff --git a/src/ofc/generators/onetep/CodeGenerator.scala b/src/ofc/generators/onetep/CodeGenerator.scala new file mode 100644 index 0000000..6b429ed --- /dev/null +++ b/src/ofc/generators/onetep/CodeGenerator.scala @@ -0,0 +1,38 @@ +package ofc.generators.onetep +import ofc.codegen._ + +class CodeGenerator(dictionary: Dictionary) { + val indexMap : Map[NamedIndex, Expression[IntType]] = { + for(index <- dictionary.getIndices) yield + (index, new VarRef[IntType](new DeclaredVarSymbol[IntType](index.getName))) + }.toMap + + class Context extends GenerationContext { + val block = new BlockStatement + + def addDeclaration(sym: VarSymbol[_ <: Type]) { + block.addDeclaration(sym) + } + + def +=(stat: Statement) { + block += stat + } + + def getStatement: Statement = block + } + + def apply(assignment: Assignment) { + val lhs = assignment.lhs + val rhs = assignment.rhs + + val context = new Context + val rhsFragment = rhs.getFragment(indexMap) + + rhsFragment.setup(context) + rhsFragment.teardown(context) + + val generator = new FortranGenerator + val code = generator(context.getStatement) + println(code) + } +} diff --git a/src/ofc/generators/onetep/Dictionary.scala b/src/ofc/generators/onetep/Dictionary.scala index 851a255..4466e58 100644 --- a/src/ofc/generators/onetep/Dictionary.scala +++ b/src/ofc/generators/onetep/Dictionary.scala @@ -5,28 +5,28 @@ import ofc.InvalidInputException class Dictionary { import scala.collection.mutable.HashMap - var scalars = new HashMap[Identifier, Scalar] - var fields = new HashMap[Identifier, Field] + var scalars = new HashMap[Identifier, Seq[NamedIndex] => Scalar] + var fields = new HashMap[Identifier, Seq[NamedIndex] => Field] var indices = new HashMap[Identifier, NamedIndex] - def add(id: Identifier, scalar: Scalar) { - scalars += id -> scalar + def addScalar(id: Identifier, scalarGenerator: Seq[NamedIndex] => Scalar) { + scalars += id -> scalarGenerator } - def add(id: Identifier, field: Field) { - fields += id -> field + def addField(id: Identifier, fieldGenerator: Seq[NamedIndex] => Field) { + fields += id -> fieldGenerator } - def add(id: Identifier, index: NamedIndex) { + def addIndex(id: Identifier, index: NamedIndex) { indices += id -> index } - def getScalar(id: Identifier) : Scalar = scalars.get(id) match { + def getScalar(id: Identifier) = scalars.get(id) match { case Some(s) => s case None => throw new InvalidInputException("Unknown scalar operand "+id.getName) } - def getField(id: Identifier) : Field = fields.get(id) match { + def getField(id: Identifier) = fields.get(id) match { case Some(f) => f case None => throw new InvalidInputException("Unknown field operand "+id.getName) } @@ -35,4 +35,6 @@ class Dictionary { case Some(i) => i case None => throw new InvalidInputException("Unknown index operand "+id.getName) } + + def getIndices = indices.values } diff --git a/src/ofc/generators/onetep/Field.scala b/src/ofc/generators/onetep/Field.scala index 62e5805..1e8514c 100644 --- a/src/ofc/generators/onetep/Field.scala +++ b/src/ofc/generators/onetep/Field.scala @@ -1,4 +1,6 @@ package ofc.generators.onetep +import ofc.codegen._ trait Field { + def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : FieldFragment } diff --git a/src/ofc/generators/onetep/FieldAccess.scala b/src/ofc/generators/onetep/FieldAccess.scala deleted file mode 100644 index 120f156..0000000 --- a/src/ofc/generators/onetep/FieldAccess.scala +++ /dev/null @@ -1,4 +0,0 @@ -package ofc.generators.onetep - -class FieldAccess(op: Field, indices: Seq[NamedIndex]) extends Field - diff --git a/src/ofc/generators/onetep/FieldFragment.scala b/src/ofc/generators/onetep/FieldFragment.scala new file mode 100644 index 0000000..a937371 --- /dev/null +++ b/src/ofc/generators/onetep/FieldFragment.scala @@ -0,0 +1,11 @@ +package ofc.generators.onetep + +trait FieldFragment extends Fragment { + def toReciprocal : ReciprocalFragment +} + +trait PsincFragment extends FieldFragment + +trait ReciprocalFragment extends FieldFragment { + def toReciprocal = this +} diff --git a/src/ofc/generators/onetep/Fragment.scala b/src/ofc/generators/onetep/Fragment.scala new file mode 100644 index 0000000..8ed8c8a --- /dev/null +++ b/src/ofc/generators/onetep/Fragment.scala @@ -0,0 +1,6 @@ +package ofc.generators.onetep + +trait Fragment { + def setup(context: GenerationContext) + def teardown(context: GenerationContext) +} diff --git a/src/ofc/generators/onetep/GenerationContext.scala b/src/ofc/generators/onetep/GenerationContext.scala new file mode 100644 index 0000000..f12700e --- /dev/null +++ b/src/ofc/generators/onetep/GenerationContext.scala @@ -0,0 +1,7 @@ +package ofc.generators.onetep +import ofc.codegen._ + +trait GenerationContext { + def addDeclaration(sym: VarSymbol[_ <: Type]) + def +=(stat: Statement) +} diff --git a/src/ofc/generators/onetep/InnerProduct.scala b/src/ofc/generators/onetep/InnerProduct.scala index 516e299..7f77246 100644 --- a/src/ofc/generators/onetep/InnerProduct.scala +++ b/src/ofc/generators/onetep/InnerProduct.scala @@ -1,3 +1,20 @@ package ofc.generators.onetep +import ofc.codegen._ -class InnerProduct(left: Field, right: Field) extends Scalar +class InnerProduct(left: Field, right: Field) extends Scalar { + + class LocalFragment(left: FieldFragment, right: FieldFragment) extends ScalarFragment { + def setup(context: GenerationContext) { + left.setup(context) + right.setup(context) + } + + def teardown(context: GenerationContext) { + left.teardown(context) + right.teardown(context) + } + } + + def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : ScalarFragment = + new LocalFragment(left.getFragment(indices), right.getFragment(indices)) +} diff --git a/src/ofc/generators/onetep/Laplacian.scala b/src/ofc/generators/onetep/Laplacian.scala index dc31ba1..72a31b5 100644 --- a/src/ofc/generators/onetep/Laplacian.scala +++ b/src/ofc/generators/onetep/Laplacian.scala @@ -1,3 +1,6 @@ package ofc.generators.onetep +import ofc.codegen._ -class Laplacian(op: Field) extends Field +class Laplacian(op: Field) extends Field { + def getFragment(indices: Map[NamedIndex, Expression[IntType]]) = op.getFragment(indices) +} diff --git a/src/ofc/generators/onetep/Matrix.scala b/src/ofc/generators/onetep/Matrix.scala deleted file mode 100644 index 686d759..0000000 --- a/src/ofc/generators/onetep/Matrix.scala +++ /dev/null @@ -1,4 +0,0 @@ -package ofc.generators.onetep - -trait Matrix { -} diff --git a/src/ofc/generators/onetep/NamedIndex.scala b/src/ofc/generators/onetep/NamedIndex.scala index 271f5e6..bc50ca1 100644 --- a/src/ofc/generators/onetep/NamedIndex.scala +++ b/src/ofc/generators/onetep/NamedIndex.scala @@ -1,3 +1,5 @@ package ofc.generators.onetep -class NamedIndex(name: String) +class NamedIndex(name: String) { + def getName = name +} diff --git a/src/ofc/generators/onetep/PPDFunctionSet.scala b/src/ofc/generators/onetep/PPDFunctionSet.scala index 45081cf..23b87ae 100644 --- a/src/ofc/generators/onetep/PPDFunctionSet.scala +++ b/src/ofc/generators/onetep/PPDFunctionSet.scala @@ -97,16 +97,25 @@ object PPDFunctionSet { } */ -class PPDFunctionSet(basisName: String, dataName: String) extends Field +class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedIndex]) extends Field { + class LocalFragment(parent: PPDFunctionSet) extends PsincFragment { + def setup(context: GenerationContext) {} + def teardown(context: GenerationContext) {} + def toReciprocal : ReciprocalFragment = new LocalReciprocal(parent) + } + + class LocalReciprocal(parent: PPDFunctionSet) extends ReciprocalFragment { + val fftbox = new DeclaredVarSymbol[ArrayType[FloatType]]("fftbox", new ArrayType[FloatType](3)) + + def setup(context: GenerationContext) { + context.addDeclaration(fftbox) + } - /* -class PPDFunctionSet private(discreteIndices: Seq[DiscreteIndex], - spatialIndices: Seq[SpatialIndex], data: Expression[FloatType], - producer: ProducerStatement) extends FunctionSet { + def teardown(context: GenerationContext) { + } + } + + def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : FieldFragment = + new LocalFragment(this) - def getProducer = producer - def getDiscreteIndices = discreteIndices - def getSpatialIndices = spatialIndices - def getDataValue = data } -*/ diff --git a/src/ofc/generators/onetep/SPAM3.scala b/src/ofc/generators/onetep/SPAM3.scala index ff2f167..bcffbce 100644 --- a/src/ofc/generators/onetep/SPAM3.scala +++ b/src/ofc/generators/onetep/SPAM3.scala @@ -1,7 +1,15 @@ package ofc.generators.onetep -import ofc.codegen.{ProducerStatement,NullStatement,Comment, FloatLiteral} +import ofc.codegen._ -class SPAM3(name : String) extends Scalar { - override def toString = name - def getName = name +class SPAM3(name : String, indices: Seq[NamedIndex]) extends Scalar { + class LocalFragment extends ScalarFragment { + def setup(context: GenerationContext) { + } + + def teardown(context: GenerationContext) { + } + } + + def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : ScalarFragment = + new LocalFragment } diff --git a/src/ofc/generators/onetep/Scalar.scala b/src/ofc/generators/onetep/Scalar.scala index 5e673a6..89a0f51 100644 --- a/src/ofc/generators/onetep/Scalar.scala +++ b/src/ofc/generators/onetep/Scalar.scala @@ -1,4 +1,6 @@ package ofc.generators.onetep +import ofc.codegen._ trait Scalar { + def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : ScalarFragment } diff --git a/src/ofc/generators/onetep/ScalarAccess.scala b/src/ofc/generators/onetep/ScalarAccess.scala deleted file mode 100644 index 0b7e5c0..0000000 --- a/src/ofc/generators/onetep/ScalarAccess.scala +++ /dev/null @@ -1,4 +0,0 @@ -package ofc.generators.onetep - -class ScalarAccess(op: Scalar, indices: Seq[NamedIndex]) extends Scalar - diff --git a/src/ofc/generators/onetep/ScalarFragment.scala b/src/ofc/generators/onetep/ScalarFragment.scala new file mode 100644 index 0000000..20ff425 --- /dev/null +++ b/src/ofc/generators/onetep/ScalarFragment.scala @@ -0,0 +1,3 @@ +package ofc.generators.onetep + +trait ScalarFragment extends Fragment diff --git a/src/ofc/generators/onetep/ScalarLiteral.scala b/src/ofc/generators/onetep/ScalarLiteral.scala index fc38f9a..12fb2dc 100644 --- a/src/ofc/generators/onetep/ScalarLiteral.scala +++ b/src/ofc/generators/onetep/ScalarLiteral.scala @@ -1,3 +1,16 @@ package ofc.generators.onetep +import ofc.codegen._ -class ScalarLiteral(s: Double) extends Scalar +class ScalarLiteral(s: Double) extends Scalar { + class LocalFragment(s: Double) extends ScalarFragment { + def setup(context: GenerationContext) { + } + + def teardown(context: GenerationContext) { + } + } + + def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : ScalarFragment = + new LocalFragment(s) + +} diff --git a/src/ofc/generators/onetep/ScaledField.scala b/src/ofc/generators/onetep/ScaledField.scala index 070b0fd..d7aa6db 100644 --- a/src/ofc/generators/onetep/ScaledField.scala +++ b/src/ofc/generators/onetep/ScaledField.scala @@ -1,3 +1,7 @@ package ofc.generators.onetep +import ofc.codegen._ -class ScaledField(op: Field, factor: Scalar) extends Field +class ScaledField(op: Field, factor: Scalar) extends Field { + def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : FieldFragment = + op.getFragment(indices) +}