From e68fe69ebb8f995421bb7048ce2e1170b3fe2490 Mon Sep 17 00:00:00 2001 From: Francis Russell Date: Thu, 3 May 2012 17:32:26 +0100 Subject: [PATCH] Add code generation for inner product. --- src/ofc/generators/onetep/CodeGenerator.scala | 2 +- src/ofc/generators/onetep/InnerProduct.scala | 35 ++++++++++++++++++- src/ofc/generators/onetep/Laplacian.scala | 2 +- .../generators/onetep/PPDFunctionSet.scala | 27 +++++++------- 4 files changed, 50 insertions(+), 16 deletions(-) diff --git a/src/ofc/generators/onetep/CodeGenerator.scala b/src/ofc/generators/onetep/CodeGenerator.scala index a61f87f..3b6a684 100644 --- a/src/ofc/generators/onetep/CodeGenerator.scala +++ b/src/ofc/generators/onetep/CodeGenerator.scala @@ -9,7 +9,7 @@ class CodeGenerator(dictionary: Dictionary) { val indexMap : Map[NamedIndex, Expression[IntType]] = { for((index, sym) <- indexSyms) yield - (index, new VarRef[IntType](sym)) + (index, sym: Expression[IntType]) }.toMap class Context extends GenerationContext { diff --git a/src/ofc/generators/onetep/InnerProduct.scala b/src/ofc/generators/onetep/InnerProduct.scala index c619b73..67e4134 100644 --- a/src/ofc/generators/onetep/InnerProduct.scala +++ b/src/ofc/generators/onetep/InnerProduct.scala @@ -4,18 +4,51 @@ import ofc.codegen._ class InnerProduct(left: Field, right: Field) extends Scalar { class LocalFragment(left: FieldFragment, right: FieldFragment) extends ScalarFragment { + val result = new DeclaredVarSymbol[FloatType]("inner_product_result") val leftDense = left.toDensePsinc val rightDense = right.toDensePsinc def setup(context: GenerationContext) { + context.addDeclaration(result) leftDense.setup(context) rightDense.setup(context) + val leftOrigin = leftDense.getOrigin + val leftSize = leftDense.getSize + + val rightOrigin = rightDense.getOrigin + val rightSize = rightDense.getSize + + val topLeft : Seq[Expression[IntType]] = + for (dim <- 0 to 2) yield new Max[IntType](leftOrigin(dim), rightOrigin(dim)) + + val bottomRight : Seq[Expression[IntType]] = + for (dim <- 0 to 2) yield new Min[IntType](leftOrigin(dim) + leftSize(dim), rightOrigin(dim) + rightSize(dim)) + + val indices = for(dim <- 0 to 2) yield { + val index = new DeclaredVarSymbol[IntType]("i"+(dim+1)) + context.addDeclaration(index) + index + } + + val loops = for(dim <- 0 to 2) yield new ForLoop(indices(dim), topLeft(dim), bottomRight(dim)) + for(dim <- 1 to 2) loops(dim) += loops(dim-1) + + context += new AssignStatement(result, new FloatLiteral(0.0)) + context += loops(2) + + val leftIndex = for (dim <- 0 to 2) yield indices(dim) - leftOrigin(dim) + val rightIndex = for (dim <- 0 to 2) yield indices(dim) - rightOrigin(dim) + + loops(0) += new AssignStatement(result, (result : Expression[FloatType]) + + leftDense.getBuffer.at(leftIndex: _*) * + rightDense.getBuffer.at(rightIndex: _*)) + leftDense.teardown(context) rightDense.teardown(context) } - def getValue = throw new ofc.UnimplementedException("rargh!") + def getValue = result def teardown(context: GenerationContext) { } diff --git a/src/ofc/generators/onetep/Laplacian.scala b/src/ofc/generators/onetep/Laplacian.scala index ae7ab15..46a295e 100644 --- a/src/ofc/generators/onetep/Laplacian.scala +++ b/src/ofc/generators/onetep/Laplacian.scala @@ -28,7 +28,7 @@ class Laplacian(op: Field) extends Field { val reciprocalVector = for(dim <- 0 to 2) yield { val component = new DeclaredVarSymbol[FloatType]("reciprocal_vector"+(dim+1)) context.addDeclaration(component) - new VarRef[FloatType](component) + (component : Expression[FloatType]) } for(dim <- 0 to 2) { diff --git a/src/ofc/generators/onetep/PPDFunctionSet.scala b/src/ofc/generators/onetep/PPDFunctionSet.scala index cb0c111..d624003 100644 --- a/src/ofc/generators/onetep/PPDFunctionSet.scala +++ b/src/ofc/generators/onetep/PPDFunctionSet.scala @@ -121,6 +121,7 @@ class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedInde val tightbox = (~(basis % FunctionBasis.tightBoxes)).at(sphereIndex) val sphere = (~(basis % FunctionBasis.spheres)).at(sphereIndex) val fftboxOffset = for(dim <- 0 to 2) yield new DeclaredVarSymbol[IntType]("fftbox_offset"+(dim+1)) + val tightboxOrigin = for(dim <- 0 to 2) yield new DeclaredVarSymbol[IntType]("tightbox_origin"+(dim+1)) def setup(context: GenerationContext) { import OnetepTypes.FFTBoxInfo @@ -134,13 +135,23 @@ class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedInde fftboxOffset.map(new VarRef[IntType](_)) ++ fftboxSize)) var basisCopyParams : Seq[Expression[_]] = Nil - basisCopyParams :+= new VarRef[ArrayType[FloatType]](fftbox) + basisCopyParams :+= (fftbox: Expression[ArrayType[FloatType]]) basisCopyParams ++= fftboxOffset.map(new VarRef[IntType](_)) basisCopyParams :+= tightbox - basisCopyParams :+= new VarRef[ArrayType[FloatType]](parent.data) + basisCopyParams :+= (parent.data: Expression[ArrayType[FloatType]]) basisCopyParams :+= sphere context += new FunctionCallStatement(new FunctionCall(OnetepFunctions.basis_copy_function_to_fftbox, basisCopyParams)) + + for (dim <- 0 to 2) yield { + import OnetepTypes._ + val startPPD = tightbox % TightBox.startPPD(dim) - 1 + val startPPDPoint = startPPD * (CellInfo.public % CellInfo.ppdWidth(dim)) + val startPoint = startPPDPoint + tightbox % TightBox.startPts(dim) + + context.addDeclaration(tightboxOrigin(dim)) + context += new AssignStatement(tightboxOrigin(dim), startPoint) + } } def teardown(context: GenerationContext) { @@ -149,19 +160,9 @@ class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedInde def getSize = for (dim <- 0 to 2) yield OnetepTypes.FFTBoxInfo.public % OnetepTypes.FFTBoxInfo.totalPts(dim) - private def getTightBoxOrigin = for (dim <- 0 to 2) yield { - import OnetepTypes._ - val startPPD = tightbox % TightBox.startPPD(dim) - 1 - val startPPDPoint = startPPD * (CellInfo.public % CellInfo.ppdWidth(dim)) - val startPoint = startPPDPoint + tightbox % TightBox.startPts(dim) - startPoint - } - def getOrigin = { - val tightBoxOrigin = getTightBoxOrigin - for (dim <- 0 to 2) yield - tightBoxOrigin(dim) - fftboxOffset(dim) + tightboxOrigin(dim) - fftboxOffset(dim) } def getBuffer = fftbox -- 2.47.3