val indexMap : Map[NamedIndex, Expression[IntType]] = {
for((index, sym) <- indexSyms) yield
- (index, new VarRef[IntType](sym))
+ (index, sym: Expression[IntType])
}.toMap
class Context extends GenerationContext {
class InnerProduct(left: Field, right: Field) extends Scalar {
class LocalFragment(left: FieldFragment, right: FieldFragment) extends ScalarFragment {
+ val result = new DeclaredVarSymbol[FloatType]("inner_product_result")
val leftDense = left.toDensePsinc
val rightDense = right.toDensePsinc
def setup(context: GenerationContext) {
+ context.addDeclaration(result)
leftDense.setup(context)
rightDense.setup(context)
+ val leftOrigin = leftDense.getOrigin
+ val leftSize = leftDense.getSize
+
+ val rightOrigin = rightDense.getOrigin
+ val rightSize = rightDense.getSize
+
+ val topLeft : Seq[Expression[IntType]] =
+ for (dim <- 0 to 2) yield new Max[IntType](leftOrigin(dim), rightOrigin(dim))
+
+ val bottomRight : Seq[Expression[IntType]] =
+ for (dim <- 0 to 2) yield new Min[IntType](leftOrigin(dim) + leftSize(dim), rightOrigin(dim) + rightSize(dim))
+
+ val indices = for(dim <- 0 to 2) yield {
+ val index = new DeclaredVarSymbol[IntType]("i"+(dim+1))
+ context.addDeclaration(index)
+ index
+ }
+
+ val loops = for(dim <- 0 to 2) yield new ForLoop(indices(dim), topLeft(dim), bottomRight(dim))
+ for(dim <- 1 to 2) loops(dim) += loops(dim-1)
+
+ context += new AssignStatement(result, new FloatLiteral(0.0))
+ context += loops(2)
+
+ val leftIndex = for (dim <- 0 to 2) yield indices(dim) - leftOrigin(dim)
+ val rightIndex = for (dim <- 0 to 2) yield indices(dim) - rightOrigin(dim)
+
+ loops(0) += new AssignStatement(result, (result : Expression[FloatType]) +
+ leftDense.getBuffer.at(leftIndex: _*) *
+ rightDense.getBuffer.at(rightIndex: _*))
+
leftDense.teardown(context)
rightDense.teardown(context)
}
- def getValue = throw new ofc.UnimplementedException("rargh!")
+ def getValue = result
def teardown(context: GenerationContext) {
}
val reciprocalVector = for(dim <- 0 to 2) yield {
val component = new DeclaredVarSymbol[FloatType]("reciprocal_vector"+(dim+1))
context.addDeclaration(component)
- new VarRef[FloatType](component)
+ (component : Expression[FloatType])
}
for(dim <- 0 to 2) {
val tightbox = (~(basis % FunctionBasis.tightBoxes)).at(sphereIndex)
val sphere = (~(basis % FunctionBasis.spheres)).at(sphereIndex)
val fftboxOffset = for(dim <- 0 to 2) yield new DeclaredVarSymbol[IntType]("fftbox_offset"+(dim+1))
+ val tightboxOrigin = for(dim <- 0 to 2) yield new DeclaredVarSymbol[IntType]("tightbox_origin"+(dim+1))
def setup(context: GenerationContext) {
import OnetepTypes.FFTBoxInfo
fftboxOffset.map(new VarRef[IntType](_)) ++ fftboxSize))
var basisCopyParams : Seq[Expression[_]] = Nil
- basisCopyParams :+= new VarRef[ArrayType[FloatType]](fftbox)
+ basisCopyParams :+= (fftbox: Expression[ArrayType[FloatType]])
basisCopyParams ++= fftboxOffset.map(new VarRef[IntType](_))
basisCopyParams :+= tightbox
- basisCopyParams :+= new VarRef[ArrayType[FloatType]](parent.data)
+ basisCopyParams :+= (parent.data: Expression[ArrayType[FloatType]])
basisCopyParams :+= sphere
context += new FunctionCallStatement(new FunctionCall(OnetepFunctions.basis_copy_function_to_fftbox, basisCopyParams))
+
+ for (dim <- 0 to 2) yield {
+ import OnetepTypes._
+ val startPPD = tightbox % TightBox.startPPD(dim) - 1
+ val startPPDPoint = startPPD * (CellInfo.public % CellInfo.ppdWidth(dim))
+ val startPoint = startPPDPoint + tightbox % TightBox.startPts(dim)
+
+ context.addDeclaration(tightboxOrigin(dim))
+ context += new AssignStatement(tightboxOrigin(dim), startPoint)
+ }
}
def teardown(context: GenerationContext) {
def getSize = for (dim <- 0 to 2) yield OnetepTypes.FFTBoxInfo.public % OnetepTypes.FFTBoxInfo.totalPts(dim)
- private def getTightBoxOrigin = for (dim <- 0 to 2) yield {
- import OnetepTypes._
- val startPPD = tightbox % TightBox.startPPD(dim) - 1
- val startPPDPoint = startPPD * (CellInfo.public % CellInfo.ppdWidth(dim))
- val startPoint = startPPDPoint + tightbox % TightBox.startPts(dim)
- startPoint
- }
-
def getOrigin = {
- val tightBoxOrigin = getTightBoxOrigin
-
for (dim <- 0 to 2) yield
- tightBoxOrigin(dim) - fftboxOffset(dim)
+ tightboxOrigin(dim) - fftboxOffset(dim)
}
def getBuffer = fftbox