]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Skeleton for fragment-based code generation.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Tue, 1 May 2012 04:33:32 +0000 (05:33 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Tue, 1 May 2012 04:33:32 +0000 (05:33 +0100)
21 files changed:
src/ofc/expression/Expression.scala
src/ofc/generators/Onetep.scala
src/ofc/generators/onetep/Assignment.scala
src/ofc/generators/onetep/CodeGenerator.scala [new file with mode: 0644]
src/ofc/generators/onetep/Dictionary.scala
src/ofc/generators/onetep/Field.scala
src/ofc/generators/onetep/FieldAccess.scala [deleted file]
src/ofc/generators/onetep/FieldFragment.scala [new file with mode: 0644]
src/ofc/generators/onetep/Fragment.scala [new file with mode: 0644]
src/ofc/generators/onetep/GenerationContext.scala [new file with mode: 0644]
src/ofc/generators/onetep/InnerProduct.scala
src/ofc/generators/onetep/Laplacian.scala
src/ofc/generators/onetep/Matrix.scala [deleted file]
src/ofc/generators/onetep/NamedIndex.scala
src/ofc/generators/onetep/PPDFunctionSet.scala
src/ofc/generators/onetep/SPAM3.scala
src/ofc/generators/onetep/Scalar.scala
src/ofc/generators/onetep/ScalarAccess.scala [deleted file]
src/ofc/generators/onetep/ScalarFragment.scala [new file with mode: 0644]
src/ofc/generators/onetep/ScalarLiteral.scala
src/ofc/generators/onetep/ScaledField.scala

index 53dc5990b992d948b2b942e33e250efdd4abc3fa..9fa2d48bd28751ee4c7fe2fbe40a77c698c8bd19 100644 (file)
@@ -76,5 +76,3 @@ class FunctionSet(val id: Identifier) extends Field with NamedOperand {
 class Matrix(val id: Identifier) extends Scalar with NamedOperand {
   def numIndices = 2
 }
-
-
index 5acbad34e5210d72b724292e6d7fbfbf289cec1a..22b24401f52a4d4ddfeb2952dba842d00fa9712d 100644 (file)
@@ -12,8 +12,9 @@ class Onetep extends Generator {
     expression.Assignment, targetSpecific : Seq[parser.TargetAssignment]) {
 
     buildDictionary(exprDictionary, targetSpecific)
-
     val assignment = new Assignment(buildScalarExpression(exprAssignment.lhs), buildScalarExpression(exprAssignment.rhs))
+    val codeGenerator = new CodeGenerator(dictionary)
+    codeGenerator(assignment)
   }
 
   private def buildDictionary(exprDictionary: expression.Dictionary, targetSpecific : Seq[parser.TargetAssignment]) {
@@ -32,7 +33,7 @@ class Onetep extends Generator {
     }
 
     for(index <- exprDictionary.getIndices) {
-      dictionary.add(index.getIdentifier, new NamedIndex(index.getName))
+      dictionary.addIndex(index.getIdentifier, new NamedIndex(index.getName))
     }
   }
 
@@ -53,48 +54,51 @@ class Onetep extends Generator {
       case expression.ScalarLiteral(s) => new ScalarLiteral(s)
       case expression.InnerProduct(l, r) => new InnerProduct(buildFieldExpression(l), buildFieldExpression(r))
       case expression.ScalarIndexingOperation(op, indices) => buildScalarAccess(op, indices)
-      case (m: expression.Matrix) => dictionary.getScalar(m.getIdentifier)
+      case (_: expression.Matrix) => throw new InvalidInputException("Cannot handle un-indexed matrices.")
     }
   }
 
-  private def buildScalarAccess(op: expression.Scalar, indices: Seq[expression.Index]) : Scalar = {
-    val base = buildScalarExpression(op)
-    new ScalarAccess(base, getIndex(indices))
-  }
+  private def buildScalarAccess(op: expression.Scalar, indices: Seq[expression.Index]) : Scalar =
+    op match {
+      case (matrix: expression.Matrix) => dictionary.getScalar(matrix.getIdentifier)(getIndex(indices))
+      case _ => throw new InvalidInputException("Can only index leaf-matrices.")
+    }
 
-  private def buildFieldAccess(op: expression.Field, indices: Seq[expression.Index]) : Field = {
-    val base = buildFieldExpression(op)
-    new FieldAccess(base, getIndex(indices))
-  }
+  private def buildFieldAccess(op: expression.Field, indices: Seq[expression.Index]) : Field =
+    op match {
+      case (functionSet: expression.FunctionSet) => dictionary.getField(functionSet.getIdentifier)(getIndex(indices))
+      case _ => throw new InvalidInputException("Can only index function-sets.")
+    }
 
   private def buildFieldExpression(field: expression.Field) : Field = {
     field match {
       case expression.Laplacian(op) => new Laplacian(buildFieldExpression(op))
       case expression.FieldScaling(op, scale) => new ScaledField(buildFieldExpression(op), buildScalarExpression(scale))
       case expression.FieldIndexingOperation(op, indices) => buildFieldAccess(op, indices)
-      case (f: expression.FunctionSet) => dictionary.getField(f.getIdentifier)
+      case (_: expression.FunctionSet) => throw new InvalidInputException("Cannot handle un-indexed function sets.")
     }
   }
 
-  def buildMatrix(id: parser.Identifier, call : Option[parser.FunctionCall]) {
+  private def buildMatrix(id: parser.Identifier, call : Option[parser.FunctionCall]) {
     import parser._
 
     call match {
       case Some(FunctionCall(matType, params)) => (matType, params) match {
-        case (Identifier("SPAM3"), ParameterList(StringParameter(name))) => dictionary.add(id, new SPAM3(name))
+        case (Identifier("SPAM3"), ParameterList(StringParameter(name))) => 
+          dictionary.addScalar(id, new SPAM3(name, _: Seq[NamedIndex]))
         case _ => throw new InvalidInputException("Unknown usage of type: "+matType.name)
       }
       case _ => throw new InvalidInputException("Undefined concrete type for matrix: "+id.name)
     }
   }
 
-  def buildFunctionSet(id: parser.Identifier, call : Option[parser.FunctionCall]) {
+  private def buildFunctionSet(id: parser.Identifier, call : Option[parser.FunctionCall]) {
     import parser._
 
     call match {
       case Some(FunctionCall(fSetType, params)) => (fSetType, params) match {
         case (Identifier("PPDFunctionSet"), ParameterList(StringParameter(basis), StringParameter(data))) => 
-          dictionary.add(id, new PPDFunctionSet(basis, data))
+          dictionary.addField(id, new PPDFunctionSet(basis, data, _: Seq[NamedIndex]))
         case _ => throw new InvalidInputException("Unknown usage of type: "+fSetType.name)
       }
       case _ => throw new InvalidInputException("Undefined concrete type for function set: "+id.name)
index 04e17ee9f81eec0c99ca5554f00e8af4d84fda28..f5972c3774f94e09f914cba553115d90d2bffa40 100644 (file)
@@ -1,3 +1,3 @@
 package ofc.generators.onetep
 
-class Assignment(lhs: Scalar, rhs: Scalar)
+class Assignment(val lhs: Scalar, val rhs: Scalar)
diff --git a/src/ofc/generators/onetep/CodeGenerator.scala b/src/ofc/generators/onetep/CodeGenerator.scala
new file mode 100644 (file)
index 0000000..6b429ed
--- /dev/null
@@ -0,0 +1,38 @@
+package ofc.generators.onetep
+import ofc.codegen._
+
+class CodeGenerator(dictionary: Dictionary) {
+  val indexMap : Map[NamedIndex, Expression[IntType]] = {
+    for(index <- dictionary.getIndices) yield
+      (index, new VarRef[IntType](new DeclaredVarSymbol[IntType](index.getName)))
+  }.toMap
+
+  class Context extends GenerationContext {
+    val block = new BlockStatement
+
+    def addDeclaration(sym: VarSymbol[_ <: Type]) {
+      block.addDeclaration(sym)
+    }
+
+    def +=(stat: Statement) {
+      block += stat
+    }
+
+    def getStatement: Statement = block
+  }
+
+  def apply(assignment: Assignment) {
+    val lhs = assignment.lhs
+    val rhs = assignment.rhs
+
+    val context = new Context
+    val rhsFragment = rhs.getFragment(indexMap)
+
+    rhsFragment.setup(context)
+    rhsFragment.teardown(context)
+
+    val generator = new FortranGenerator
+    val code = generator(context.getStatement)
+    println(code)
+  }
+}
index 851a255a7b42c94a311e1e2a073bae7f5686fa83..4466e58de885e21d2d4a6b7edf38782807ec83be 100644 (file)
@@ -5,28 +5,28 @@ import ofc.InvalidInputException
 class Dictionary {
   import scala.collection.mutable.HashMap
 
-  var scalars = new HashMap[Identifier, Scalar]
-  var fields = new HashMap[Identifier, Field]
+  var scalars = new HashMap[Identifier, Seq[NamedIndex] => Scalar]
+  var fields = new HashMap[Identifier, Seq[NamedIndex] => Field]
   var indices = new HashMap[Identifier, NamedIndex]
 
-  def add(id: Identifier, scalar: Scalar) {
-    scalars += id -> scalar
+  def addScalar(id: Identifier, scalarGenerator: Seq[NamedIndex] => Scalar) {
+    scalars += id -> scalarGenerator
   }
 
-  def add(id: Identifier, field: Field) {
-    fields += id -> field
+  def addField(id: Identifier, fieldGenerator: Seq[NamedIndex] => Field) {
+    fields += id -> fieldGenerator
   }
 
-  def add(id: Identifier, index: NamedIndex) {
+  def addIndex(id: Identifier, index: NamedIndex) {
     indices += id -> index
   }
 
-  def getScalar(id: Identifier) : Scalar = scalars.get(id) match {
+  def getScalar(id: Identifier) = scalars.get(id) match {
     case Some(s) => s
     case None => throw new InvalidInputException("Unknown scalar operand "+id.getName)
   }
 
-  def getField(id: Identifier) : Field = fields.get(id) match {
+  def getField(id: Identifier) = fields.get(id) match {
     case Some(f) => f
     case None => throw new InvalidInputException("Unknown field operand "+id.getName)
   }
@@ -35,4 +35,6 @@ class Dictionary {
     case Some(i) => i
     case None => throw new InvalidInputException("Unknown index operand "+id.getName)
   }
+
+  def getIndices = indices.values
 }
index 62e5805a7827c8c9865963da42e45fee5ffb2142..1e8514cbb74d8773a1e5ee93825ede6e139f5fd7 100644 (file)
@@ -1,4 +1,6 @@
 package ofc.generators.onetep
+import ofc.codegen._
 
 trait Field {
+  def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : FieldFragment
 }
diff --git a/src/ofc/generators/onetep/FieldAccess.scala b/src/ofc/generators/onetep/FieldAccess.scala
deleted file mode 100644 (file)
index 120f156..0000000
+++ /dev/null
@@ -1,4 +0,0 @@
-package ofc.generators.onetep
-
-class FieldAccess(op: Field, indices: Seq[NamedIndex]) extends Field
-
diff --git a/src/ofc/generators/onetep/FieldFragment.scala b/src/ofc/generators/onetep/FieldFragment.scala
new file mode 100644 (file)
index 0000000..a937371
--- /dev/null
@@ -0,0 +1,11 @@
+package ofc.generators.onetep
+
+trait FieldFragment extends Fragment {
+  def toReciprocal : ReciprocalFragment
+}
+
+trait PsincFragment extends FieldFragment
+
+trait ReciprocalFragment extends FieldFragment {
+  def toReciprocal = this
+}
diff --git a/src/ofc/generators/onetep/Fragment.scala b/src/ofc/generators/onetep/Fragment.scala
new file mode 100644 (file)
index 0000000..8ed8c8a
--- /dev/null
@@ -0,0 +1,6 @@
+package ofc.generators.onetep
+
+trait Fragment {
+  def setup(context: GenerationContext)
+  def teardown(context: GenerationContext)
+}
diff --git a/src/ofc/generators/onetep/GenerationContext.scala b/src/ofc/generators/onetep/GenerationContext.scala
new file mode 100644 (file)
index 0000000..f12700e
--- /dev/null
@@ -0,0 +1,7 @@
+package ofc.generators.onetep
+import ofc.codegen._
+
+trait GenerationContext {
+  def addDeclaration(sym: VarSymbol[_ <: Type])
+  def +=(stat: Statement)
+}
index 516e29940887131df1b2f6f60da0679abad98583..7f77246ae5583c7945a59b0eb4e0ca7215f9274d 100644 (file)
@@ -1,3 +1,20 @@
 package ofc.generators.onetep
+import ofc.codegen._
 
-class InnerProduct(left: Field, right: Field) extends Scalar
+class InnerProduct(left: Field, right: Field) extends Scalar {
+
+  class LocalFragment(left: FieldFragment, right: FieldFragment) extends ScalarFragment {
+    def setup(context: GenerationContext) {
+      left.setup(context)
+      right.setup(context)
+    }
+
+    def teardown(context: GenerationContext) {
+      left.teardown(context)
+      right.teardown(context)
+    }
+  }
+
+  def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : ScalarFragment = 
+    new LocalFragment(left.getFragment(indices), right.getFragment(indices))
+}
index dc31ba11497644b3342b341261bc396427876d1d..72a31b53d2b0d48a5321a70c50c329897552f884 100644 (file)
@@ -1,3 +1,6 @@
 package ofc.generators.onetep
+import ofc.codegen._
 
-class Laplacian(op: Field)  extends Field
+class Laplacian(op: Field)  extends Field {
+  def getFragment(indices: Map[NamedIndex, Expression[IntType]]) = op.getFragment(indices)
+}
diff --git a/src/ofc/generators/onetep/Matrix.scala b/src/ofc/generators/onetep/Matrix.scala
deleted file mode 100644 (file)
index 686d759..0000000
+++ /dev/null
@@ -1,4 +0,0 @@
-package ofc.generators.onetep
-
-trait Matrix {
-}
index 271f5e6ccb685cd5cf17347c00846eed177092f1..bc50ca1e50c74f4add9b70035b84ae8b594b225b 100644 (file)
@@ -1,3 +1,5 @@
 package ofc.generators.onetep
 
-class NamedIndex(name: String)
+class NamedIndex(name: String) {
+  def getName = name
+}
index 45081cf1f8d8e51912d8aece192c5f2915d12211..23b87ae26b9235ba1e2ba2576d2edcac47351d45 100644 (file)
@@ -97,16 +97,25 @@ object PPDFunctionSet {
 }
 */
 
-class PPDFunctionSet(basisName: String, dataName: String) extends Field
+class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedIndex]) extends Field {
+  class LocalFragment(parent: PPDFunctionSet) extends PsincFragment {
+    def setup(context: GenerationContext) {}
+    def teardown(context: GenerationContext) {}
+    def toReciprocal : ReciprocalFragment = new LocalReciprocal(parent)
+  }
+
+  class LocalReciprocal(parent: PPDFunctionSet) extends ReciprocalFragment {
+    val fftbox = new DeclaredVarSymbol[ArrayType[FloatType]]("fftbox", new ArrayType[FloatType](3))
+
+    def setup(context: GenerationContext) {
+      context.addDeclaration(fftbox)
+    }
 
-  /*
-class PPDFunctionSet private(discreteIndices: Seq[DiscreteIndex], 
-  spatialIndices: Seq[SpatialIndex], data: Expression[FloatType], 
-  producer: ProducerStatement) extends FunctionSet {
+    def teardown(context: GenerationContext) {
+    }
+  }
+
+  def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : FieldFragment =
+    new LocalFragment(this)
 
-  def getProducer = producer
-  def getDiscreteIndices = discreteIndices
-  def getSpatialIndices = spatialIndices
-  def getDataValue = data
 }
-*/
index ff2f1676eeffb6f64797f764262198730a83d93e..bcffbce1a26983d9ecaa2e04cb158f02fc2ff5b6 100644 (file)
@@ -1,7 +1,15 @@
 package ofc.generators.onetep
-import ofc.codegen.{ProducerStatement,NullStatement,Comment, FloatLiteral}
+import ofc.codegen._
 
-class SPAM3(name : String) extends Scalar {
-  override def toString = name
-  def getName = name
+class SPAM3(name : String, indices: Seq[NamedIndex]) extends Scalar {
+  class LocalFragment extends ScalarFragment {
+    def setup(context: GenerationContext) {
+    }
+
+    def teardown(context: GenerationContext) {
+    }
+  }
+
+  def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : ScalarFragment =
+    new LocalFragment
 }
index 5e673a6d82b41f85ad54f520b3148edec38e242e..89a0f5136932bace9ae258fe2c21619587bd684d 100644 (file)
@@ -1,4 +1,6 @@
 package ofc.generators.onetep
+import ofc.codegen._
 
 trait Scalar {
+  def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : ScalarFragment
 }
diff --git a/src/ofc/generators/onetep/ScalarAccess.scala b/src/ofc/generators/onetep/ScalarAccess.scala
deleted file mode 100644 (file)
index 0b7e5c0..0000000
+++ /dev/null
@@ -1,4 +0,0 @@
-package ofc.generators.onetep
-
-class ScalarAccess(op: Scalar, indices: Seq[NamedIndex]) extends Scalar
-
diff --git a/src/ofc/generators/onetep/ScalarFragment.scala b/src/ofc/generators/onetep/ScalarFragment.scala
new file mode 100644 (file)
index 0000000..20ff425
--- /dev/null
@@ -0,0 +1,3 @@
+package ofc.generators.onetep
+
+trait ScalarFragment extends Fragment
index fc38f9a6cc7dad5e9c0c0789b65bf11a6402f650..12fb2dcd5657339f8b49a7f008dca80a2d7e7837 100644 (file)
@@ -1,3 +1,16 @@
 package ofc.generators.onetep
+import ofc.codegen._
 
-class ScalarLiteral(s: Double) extends Scalar
+class ScalarLiteral(s: Double) extends Scalar {
+  class LocalFragment(s: Double) extends ScalarFragment {
+    def setup(context: GenerationContext) {
+    }
+
+    def teardown(context: GenerationContext) {
+    }
+  }
+
+  def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : ScalarFragment =
+    new LocalFragment(s)
+
+}
index 070b0fd3f9a02fc9985cada25a511d0be8084724..d7aa6dbb58dcc83c90096c064cf1e48a9628dd46 100644 (file)
@@ -1,3 +1,7 @@
 package ofc.generators.onetep
+import ofc.codegen._
 
-class ScaledField(op: Field, factor: Scalar) extends Field
+class ScaledField(op: Field, factor: Scalar) extends Field {
+  def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : FieldFragment =
+    op.getFragment(indices)
+}