]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Add code generation for inner product.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Thu, 3 May 2012 16:32:26 +0000 (17:32 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Thu, 3 May 2012 16:32:26 +0000 (17:32 +0100)
src/ofc/generators/onetep/CodeGenerator.scala
src/ofc/generators/onetep/InnerProduct.scala
src/ofc/generators/onetep/Laplacian.scala
src/ofc/generators/onetep/PPDFunctionSet.scala

index a61f87f700eeb7b421b5888fe0d11a652125c61e..3b6a6843bfe9a72a5cdebafd245d70feca736588 100644 (file)
@@ -9,7 +9,7 @@ class CodeGenerator(dictionary: Dictionary) {
 
   val indexMap : Map[NamedIndex, Expression[IntType]] = {
     for((index, sym) <- indexSyms) yield
-      (index, new VarRef[IntType](sym))
+      (index, sym: Expression[IntType])
   }.toMap
 
   class Context extends GenerationContext {
index c619b73fb20459b2cefaed62103f1cf6f2ac2c8c..67e41343f6f093a98e6541d22566830c19f196f2 100644 (file)
@@ -4,18 +4,51 @@ import ofc.codegen._
 class InnerProduct(left: Field, right: Field) extends Scalar {
 
   class LocalFragment(left: FieldFragment, right: FieldFragment) extends ScalarFragment {
+    val result = new DeclaredVarSymbol[FloatType]("inner_product_result")
     val leftDense = left.toDensePsinc
     val rightDense = right.toDensePsinc
 
     def setup(context: GenerationContext) {
+      context.addDeclaration(result)
       leftDense.setup(context)
       rightDense.setup(context)
 
+      val leftOrigin = leftDense.getOrigin
+      val leftSize = leftDense.getSize
+
+      val rightOrigin = rightDense.getOrigin
+      val rightSize = rightDense.getSize
+
+      val topLeft : Seq[Expression[IntType]] = 
+        for (dim <- 0 to 2) yield new Max[IntType](leftOrigin(dim), rightOrigin(dim))
+
+      val bottomRight : Seq[Expression[IntType]] = 
+        for (dim <- 0 to 2) yield new Min[IntType](leftOrigin(dim) + leftSize(dim), rightOrigin(dim) + rightSize(dim))
+
+      val indices = for(dim <- 0 to 2) yield {
+        val index = new DeclaredVarSymbol[IntType]("i"+(dim+1))
+        context.addDeclaration(index)
+        index
+      }
+
+      val loops = for(dim <- 0 to 2) yield new ForLoop(indices(dim), topLeft(dim), bottomRight(dim))
+      for(dim <- 1 to 2) loops(dim) += loops(dim-1)
+
+      context += new AssignStatement(result, new FloatLiteral(0.0))
+      context += loops(2)
+
+      val leftIndex = for (dim <- 0 to 2) yield indices(dim) - leftOrigin(dim)
+      val rightIndex = for (dim <- 0 to 2) yield indices(dim) - rightOrigin(dim)
+
+      loops(0) += new AssignStatement(result, (result : Expression[FloatType]) + 
+        leftDense.getBuffer.at(leftIndex: _*) * 
+        rightDense.getBuffer.at(rightIndex: _*))
+
       leftDense.teardown(context)
       rightDense.teardown(context)
     }
 
-    def getValue = throw new ofc.UnimplementedException("rargh!")
+    def getValue = result
 
     def teardown(context: GenerationContext) {
     }
index ae7ab154b52e0c0e810971aaa2b22731bd19242e..46a295e8a668d551ae39a24cddd3f51df3983c21 100644 (file)
@@ -28,7 +28,7 @@ class Laplacian(op: Field)  extends Field {
       val reciprocalVector = for(dim <- 0 to 2) yield {
         val component = new DeclaredVarSymbol[FloatType]("reciprocal_vector"+(dim+1))
         context.addDeclaration(component)
-        new VarRef[FloatType](component)
+        (component : Expression[FloatType])
       }
 
       for(dim <- 0 to 2) {
index cb0c111825da32a60d55f68c635051e5478bdb47..d62400372277dbfa2d06efad7381031a04c0238e 100644 (file)
@@ -121,6 +121,7 @@ class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedInde
     val tightbox = (~(basis % FunctionBasis.tightBoxes)).at(sphereIndex)
     val sphere = (~(basis % FunctionBasis.spheres)).at(sphereIndex) 
     val fftboxOffset = for(dim <- 0 to 2) yield new DeclaredVarSymbol[IntType]("fftbox_offset"+(dim+1)) 
+    val tightboxOrigin = for(dim <- 0 to 2) yield new DeclaredVarSymbol[IntType]("tightbox_origin"+(dim+1)) 
 
     def setup(context: GenerationContext) {
       import OnetepTypes.FFTBoxInfo
@@ -134,13 +135,23 @@ class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedInde
         fftboxOffset.map(new VarRef[IntType](_)) ++ fftboxSize))
 
       var basisCopyParams : Seq[Expression[_]] = Nil
-      basisCopyParams :+= new VarRef[ArrayType[FloatType]](fftbox)
+      basisCopyParams :+= (fftbox: Expression[ArrayType[FloatType]])
       basisCopyParams ++= fftboxOffset.map(new VarRef[IntType](_))
       basisCopyParams :+= tightbox
-      basisCopyParams :+= new VarRef[ArrayType[FloatType]](parent.data)
+      basisCopyParams :+= (parent.data: Expression[ArrayType[FloatType]])
       basisCopyParams :+= sphere
 
       context += new FunctionCallStatement(new FunctionCall(OnetepFunctions.basis_copy_function_to_fftbox, basisCopyParams))
+
+      for (dim <- 0 to 2) yield {
+        import OnetepTypes._
+        val startPPD = tightbox % TightBox.startPPD(dim) - 1
+        val startPPDPoint = startPPD * (CellInfo.public % CellInfo.ppdWidth(dim))
+        val startPoint = startPPDPoint + tightbox % TightBox.startPts(dim)
+
+        context.addDeclaration(tightboxOrigin(dim))
+        context += new AssignStatement(tightboxOrigin(dim), startPoint)
+      }
     }
 
     def teardown(context: GenerationContext) {
@@ -149,19 +160,9 @@ class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedInde
 
     def getSize = for (dim <- 0 to 2) yield OnetepTypes.FFTBoxInfo.public % OnetepTypes.FFTBoxInfo.totalPts(dim)
 
-    private def getTightBoxOrigin = for (dim <- 0 to 2) yield {
-      import OnetepTypes._
-      val startPPD = tightbox % TightBox.startPPD(dim) - 1
-      val startPPDPoint = startPPD * (CellInfo.public % CellInfo.ppdWidth(dim))
-      val startPoint = startPPDPoint + tightbox % TightBox.startPts(dim)
-      startPoint
-    }
-
     def getOrigin = {
-      val tightBoxOrigin = getTightBoxOrigin
-
       for (dim <- 0 to 2) yield
-        tightBoxOrigin(dim) - fftboxOffset(dim)
+        tightboxOrigin(dim) - fftboxOffset(dim)
     }
 
     def getBuffer = fftbox