From 3f12368a850ca298818371c56824dd94bb606649 Mon Sep 17 00:00:00 2001 From: Francis Russell Date: Tue, 8 May 2012 17:13:01 +0100 Subject: [PATCH] Add matrix element assignment. --- src/ofc/generators/onetep/CodeGenerator.scala | 2 ++ src/ofc/generators/onetep/InnerProduct.scala | 3 +-- .../generators/onetep/OnetepFunctions.scala | 7 +++++++ src/ofc/generators/onetep/SPAM3.scala | 18 ++++++++++++------ src/ofc/generators/onetep/ScalarFragment.scala | 7 +++++++ src/ofc/generators/onetep/ScalarLiteral.scala | 2 +- 6 files changed, 30 insertions(+), 9 deletions(-) diff --git a/src/ofc/generators/onetep/CodeGenerator.scala b/src/ofc/generators/onetep/CodeGenerator.scala index e470d2a..d93d0c0 100644 --- a/src/ofc/generators/onetep/CodeGenerator.scala +++ b/src/ofc/generators/onetep/CodeGenerator.scala @@ -36,9 +36,11 @@ class CodeGenerator(dictionary: Dictionary) { val context = new Context val indexMap = iterationInfo.getIndexMappings + val lhsFragment = lhs.getFragment(indexMap) val rhsFragment = rhs.getFragment(indexMap) rhsFragment.setup(context) + lhsFragment.setValue(context, rhsFragment.getValue) rhsFragment.teardown(context) val generator = new FortranGenerator diff --git a/src/ofc/generators/onetep/InnerProduct.scala b/src/ofc/generators/onetep/InnerProduct.scala index b13220b..a07e812 100644 --- a/src/ofc/generators/onetep/InnerProduct.scala +++ b/src/ofc/generators/onetep/InnerProduct.scala @@ -2,8 +2,7 @@ package ofc.generators.onetep import ofc.codegen._ class InnerProduct(left: Field, right: Field) extends Scalar { - - class LocalFragment(left: FieldFragment, right: FieldFragment) extends ScalarFragment { + class LocalFragment(left: FieldFragment, right: FieldFragment) extends ScalarFragment with NonAssignableScalarFragment { val result = new DeclaredVarSymbol[FloatType]("inner_product_result") val leftDense = left.toDensePsinc val rightDense = right.toDensePsinc diff --git a/src/ofc/generators/onetep/OnetepFunctions.scala b/src/ofc/generators/onetep/OnetepFunctions.scala index f862e3c..005bac4 100644 --- a/src/ofc/generators/onetep/OnetepFunctions.scala +++ b/src/ofc/generators/onetep/OnetepFunctions.scala @@ -42,4 +42,11 @@ object OnetepFunctions { Seq(("elem", new IntType), ("mat", OnetepTypes.SPAM3), ("rowcol", new CharType))) + + val sparse_put_element_real = new FortranSubroutineSignature("sparse_put_element_real", + Seq(("el", new FloatType), + ("mat", OnetepTypes.SPAM3), + ("jrow", new IntType), + ("jcol", new IntType))) + } diff --git a/src/ofc/generators/onetep/SPAM3.scala b/src/ofc/generators/onetep/SPAM3.scala index 33235d4..064a0fc 100644 --- a/src/ofc/generators/onetep/SPAM3.scala +++ b/src/ofc/generators/onetep/SPAM3.scala @@ -1,21 +1,27 @@ package ofc.generators.onetep import ofc.codegen._ -class SPAM3(name : String, indices: Seq[NamedIndex]) extends Scalar { +class SPAM3(name : String, position: Seq[NamedIndex]) extends Scalar { val mat = new NamedUnboundVarSymbol[StructType](name, OnetepTypes.SPAM3) - class LocalFragment extends ScalarFragment { + class LocalFragment(row: Expression[IntType], col: Expression[IntType]) extends ScalarFragment { def setup(context: GenerationContext) { } - def getValue = throw new ofc.UnimplementedException("rargh!") + def getValue = throw new ofc.UnimplementedException("get unimplemented for SPAM3") + + def setValue(context: GenerationContext, value: Expression[FloatType]) { + val functionCall = new FunctionCall(OnetepFunctions.sparse_put_element_real, + Seq(value, mat, row, col)) + context += new FunctionCallStatement(functionCall) + } def teardown(context: GenerationContext) { } } def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : ScalarFragment = - new LocalFragment + new LocalFragment(indices.get(position(0)).get, indices.get(position(1)).get) def getIterationInfo : IterationInfo = { val context = new IterationContext @@ -67,8 +73,8 @@ class SPAM3(name : String, indices: Seq[NamedIndex]) extends Scalar { context.addPredicate(index.at(rowIdx) |==| rowAtom) var indexMappings : Map[NamedIndex, Expression[IntType]] = Map.empty - indexMappings += indices(0) -> row - indexMappings += indices(1) -> col + indexMappings += position(0) -> row + indexMappings += position(1) -> col new IterationInfo(context, indexMappings) } diff --git a/src/ofc/generators/onetep/ScalarFragment.scala b/src/ofc/generators/onetep/ScalarFragment.scala index 850c213..936bbc9 100644 --- a/src/ofc/generators/onetep/ScalarFragment.scala +++ b/src/ofc/generators/onetep/ScalarFragment.scala @@ -3,4 +3,11 @@ import ofc.codegen._ trait ScalarFragment extends Fragment { def getValue : Expression[FloatType] + def setValue(context: GenerationContext, value: Expression[FloatType]) +} + +trait NonAssignableScalarFragment { + def setValue(context: GenerationContext, value: Expression[FloatType]) { + throw new ofc.LogicError("Expression: "+this+" is not assignable.") + } } diff --git a/src/ofc/generators/onetep/ScalarLiteral.scala b/src/ofc/generators/onetep/ScalarLiteral.scala index 8c0704c..277613e 100644 --- a/src/ofc/generators/onetep/ScalarLiteral.scala +++ b/src/ofc/generators/onetep/ScalarLiteral.scala @@ -2,7 +2,7 @@ package ofc.generators.onetep import ofc.codegen._ class ScalarLiteral(s: Double) extends Scalar { - class LocalFragment(s: Double) extends ScalarFragment { + class LocalFragment(s: Double) extends ScalarFragment with NonAssignableScalarFragment { def setup(context: GenerationContext) { } -- 2.47.3