]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Add matrix element assignment.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Tue, 8 May 2012 16:13:01 +0000 (17:13 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Tue, 8 May 2012 16:13:01 +0000 (17:13 +0100)
src/ofc/generators/onetep/CodeGenerator.scala
src/ofc/generators/onetep/InnerProduct.scala
src/ofc/generators/onetep/OnetepFunctions.scala
src/ofc/generators/onetep/SPAM3.scala
src/ofc/generators/onetep/ScalarFragment.scala
src/ofc/generators/onetep/ScalarLiteral.scala

index e470d2ad073a8407c0ce3a47a9858945565703dc..d93d0c0482e57399ffa0d9a408b2cbf10f96a75f 100644 (file)
@@ -36,9 +36,11 @@ class CodeGenerator(dictionary: Dictionary) {
     val context = new Context
 
     val indexMap = iterationInfo.getIndexMappings
+    val lhsFragment = lhs.getFragment(indexMap)
     val rhsFragment = rhs.getFragment(indexMap)
 
     rhsFragment.setup(context)
+    lhsFragment.setValue(context, rhsFragment.getValue)
     rhsFragment.teardown(context)
 
     val generator = new FortranGenerator
index b13220be13d9e3bc5eca1b20fe3224c38a02aea8..a07e8121d5cb857bcc9c7c708e3d6d291389a01a 100644 (file)
@@ -2,8 +2,7 @@ package ofc.generators.onetep
 import ofc.codegen._
 
 class InnerProduct(left: Field, right: Field) extends Scalar {
-
-  class LocalFragment(left: FieldFragment, right: FieldFragment) extends ScalarFragment {
+  class LocalFragment(left: FieldFragment, right: FieldFragment) extends ScalarFragment with NonAssignableScalarFragment {
     val result = new DeclaredVarSymbol[FloatType]("inner_product_result")
     val leftDense = left.toDensePsinc
     val rightDense = right.toDensePsinc
index f862e3ccbdffa510e52f189750409867e11693e8..005bac4fc65e3f0ab7f739f42219b230a6b0b7be 100644 (file)
@@ -42,4 +42,11 @@ object OnetepFunctions {
     Seq(("elem", new IntType),
         ("mat", OnetepTypes.SPAM3),
         ("rowcol", new CharType)))
+
+  val sparse_put_element_real = new FortranSubroutineSignature("sparse_put_element_real",
+    Seq(("el", new FloatType),
+        ("mat", OnetepTypes.SPAM3),
+        ("jrow", new IntType),
+        ("jcol", new IntType)))
+
 }
index 33235d4c1cd200139d73a8371f3dbd441583b885..064a0fcea966d6e3277514a2da9f11330a4b327b 100644 (file)
@@ -1,21 +1,27 @@
 package ofc.generators.onetep
 import ofc.codegen._
 
-class SPAM3(name : String, indices: Seq[NamedIndex]) extends Scalar {
+class SPAM3(name : String, position: Seq[NamedIndex]) extends Scalar {
   val mat = new NamedUnboundVarSymbol[StructType](name, OnetepTypes.SPAM3)
 
-  class LocalFragment extends ScalarFragment {
+  class LocalFragment(row: Expression[IntType], col: Expression[IntType]) extends ScalarFragment {
     def setup(context: GenerationContext) {
     }
 
-    def getValue = throw new ofc.UnimplementedException("rargh!")
+    def getValue = throw new ofc.UnimplementedException("get unimplemented for SPAM3")
+    
+    def setValue(context: GenerationContext, value: Expression[FloatType]) {
+      val functionCall = new FunctionCall(OnetepFunctions.sparse_put_element_real,
+        Seq(value, mat, row, col))
+      context += new FunctionCallStatement(functionCall)
+    }
 
     def teardown(context: GenerationContext) {
     }
   }
 
   def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : ScalarFragment =
-    new LocalFragment
+    new LocalFragment(indices.get(position(0)).get, indices.get(position(1)).get)
 
   def getIterationInfo : IterationInfo = {
     val context = new IterationContext
@@ -67,8 +73,8 @@ class SPAM3(name : String, indices: Seq[NamedIndex]) extends Scalar {
     context.addPredicate(index.at(rowIdx) |==| rowAtom)
 
     var indexMappings : Map[NamedIndex, Expression[IntType]] = Map.empty
-    indexMappings += indices(0) -> row
-    indexMappings += indices(1) -> col
+    indexMappings += position(0) -> row
+    indexMappings += position(1) -> col
 
     new IterationInfo(context, indexMappings)
   }
index 850c21314ea5ec3897a9dab9fc8e713a35935349..936bbc95df7989bb83d67328a64a99b506f62c1f 100644 (file)
@@ -3,4 +3,11 @@ import ofc.codegen._
 
 trait ScalarFragment extends Fragment {
   def getValue : Expression[FloatType]
+  def setValue(context: GenerationContext, value: Expression[FloatType])
+}
+
+trait NonAssignableScalarFragment {
+  def setValue(context: GenerationContext, value: Expression[FloatType]) {
+    throw new ofc.LogicError("Expression: "+this+" is not assignable.")
+  }
 }
index 8c0704c371e129042e06a0751dfa860a99ae5ccb..277613e4da42aad9aa003cb5d7772e6d1c2d5378 100644 (file)
@@ -2,7 +2,7 @@ package ofc.generators.onetep
 import ofc.codegen._
 
 class ScalarLiteral(s: Double) extends Scalar {
-  class LocalFragment(s: Double) extends ScalarFragment {
+  class LocalFragment(s: Double) extends ScalarFragment with NonAssignableScalarFragment {
     def setup(context: GenerationContext) {
     }