package ofc.codegen
+import ofc.LogicError
+
+class AssignStatement(lhs: Expression[_ <: Type], rhs: Expression[_ <: Type]) extends Statement {
+ if (lhs.getType != rhs.getType)
+ throw new LogicError("Assignment from incompatible type.")
-class Assignment(lhs: Expression[_ <: Type], rhs: Expression[_ <: Type]) extends Statement {
def getLHS : Expression[_] = lhs
def getRHS : Expression[_] = rhs
- // TODO: type check assignment
}
def %[FieldType <: Type](field: FieldSymbol[FieldType])(implicit witness: <:<[this.type, Expression[StructType]]) : Expression[FieldType] =
new FieldAccess[FieldType](witness(this), field)
- def readAt[T <: Type](index: Expression[IntType]*)(implicit witness: <:<[this.type, Expression[ArrayType[T]]]) : Expression[T] =
- new ArrayRead(witness(this), index.toList)
+ def at[T <: Type](index: Expression[IntType]*)(implicit witness: <:<[this.type, Expression[ArrayType[T]]]) : Expression[T] =
+ new ArrayAccess(witness(this), index.toList)
def unary_~[T <: Type]()(implicit witness: <:<[this.type, Expression[PointerType[T]]]) : Expression[T] =
new PointerDereference(witness(this))
def getType = field.getType
}
-class ArrayRead[E <: Type](expression: Expression[ArrayType[E]], index: Seq[Expression[IntType]]) extends Expression[E] {
+class ArrayAccess[E <: Type](expression: Expression[ArrayType[E]], index: Seq[Expression[IntType]]) extends Expression[E] {
if (index.size != expression.getType.getRank)
throw new LogicError("Array of rank "+expression.getType.getRank+" indexed with rank "+index.size+" index.")
def getType = expression.getType.getTargetType
}
+// Conversation
+class Conversion[F <: Type, T <: Type](expression: Expression[F], toType: T)(implicit convertible: IsConvertible[F,T]) extends Expression[T] {
+ def this(expression: Expression[F])(implicit builder: TypeBuilder[T], convertible: IsConvertible[F,T])
+ = this(expression, builder())(convertible)
+
+ def getType = toType
+ def getExpression = expression
+ def foreach[U](f: Expression[_] => U) = expression.foreach(f)
+}
+
// Literals
class IntegerLiteral(value: Int) extends Expression[IntType] with LeafExpression {
def getValue = value
case (x : BlockStatement) => processScope(x)
case (x : ProducerStatement) => processStatement(x.toConcrete)
case (x : ForLoop) => processForLoop(x)
- case (a : Assignment) => processAssignment(a)
+ case (a : AssignStatement) => processAssignment(a)
case (i : IfStatement) => processIf(i)
case (a : AllocateStatement) => processAllocate(a)
case (d : DeallocateStatement) => processDeallocate(d)
case (s: NamedUnboundVarSymbol[_]) => ExpHolder(maxPrec, s.getName)
case s => ExpHolder(maxPrec, symbolManager.getName(s))
}
- case (r: ArrayRead[_]) =>
+ case (r: ArrayAccess[_]) =>
ExpHolder(maxPrec, buildExpression(r.getArrayExpression) + r.getIndexExpressions.map(buildExpression(_)).mkString("(",", ",")"))
case (d: PointerDereference[_]) => buildExpression(d.getExpression)
case (c: ConditionalValue[_]) => buildConditionalValue(c)
case (c: NumericComparison[_]) => buildNumericComparison(c)
case (c: NumericOperator[_]) => buildNumericOperator(c)
+ case (c: Conversion[_,_]) => buildConversion(c)
case x => throw new UnimplementedException("Unknown expression type in FORTRAN generator: " + x.toString)
}
}
ExpHolder(opInfo.precedence, opInfo.template.format(lhs, rhs))
}
+ private def buildConversion(c: Conversion[_,_]) : ExpHolder = c.getType match {
+ case (_: FloatType) => ExpHolder(maxPrec, "real(%s)".format(buildExpression(c.getExpression)))
+ case (_: ComplexType) => ExpHolder(maxPrec, "cmplx(%s)".format(buildExpression(c.getExpression)))
+ case _ => throw new UnimplementedException("Fortran generator cannot handle conversion.")
+ }
+
+
private def buildNumericComparison(c: NumericComparison[_]) : ExpHolder =
buildBinaryOperation(getBinaryOpInfo(c.getOperation), buildExpression(c.getLeft), buildExpression(c.getRight))
}
}
- private def processAssignment(assignment: Assignment) {
+ private def processAssignment(assignment: AssignStatement) {
addLine("%s = %s".format(buildExpression(assignment.getLHS), buildExpression(assignment.getRHS)))
}
scope = ifStat
}
case DerivedExpression(sym, expression) => {
- val assignment = new Assignment(sym, expression)
+ val assignment = new AssignStatement(sym, expression)
scope.addDeclaration(sym)
scope += assignment
}
object HasProperty {
implicit val intNumeric = new HasProperty[IntType, Numeric]()
implicit val floatNumeric = new HasProperty[FloatType, Numeric]()
+ implicit val complexNumeric = new HasProperty[ComplexType, Numeric]()
+}
+
+class IsConvertible[From <: Type, To <: Type]
+object IsConvertible {
+ implicit val intToFloat = new IsConvertible[IntType, FloatType]()
+ implicit val floatToFloat = new IsConvertible[FloatType, ComplexType]()
}
package ofc.generators.onetep
+import ofc.codegen._
trait FieldFragment extends Fragment {
def toReciprocal : ReciprocalFragment
trait ReciprocalFragment extends FieldFragment {
def toReciprocal = this
+ def getSize : Seq[Expression[IntType]]
+ def getBuffer : Expression[ArrayType[ComplexType]]
}
import ofc.codegen._
class Laplacian(op: Field) extends Field {
- def getFragment(indices: Map[NamedIndex, Expression[IntType]]) = op.getFragment(indices).toReciprocal
+ class LocalFragment(parent: Laplacian, indices: Map[NamedIndex, Expression[IntType]]) extends ReciprocalFragment {
+ val transformed = new DeclaredVarSymbol[ArrayType[ComplexType]]("transformed", new ArrayType[ComplexType](3))
+ val opFragment = parent.getOperand.getFragment(indices).toReciprocal
+
+ def setup(context: GenerationContext) {
+ context.addDeclaration(transformed)
+ opFragment.setup(context)
+
+ context += new AllocateStatement(transformed, opFragment.getSize)
+
+ val indices = for(dim <- 0 to 2) yield {
+ val index = new DeclaredVarSymbol[IntType]("i"+(dim+1))
+ context.addDeclaration(index)
+ index
+ }
+
+ // Construct loops
+ val loops = for(dim <- 0 to 2) yield new ForLoop(indices(dim), 1, getSize(dim))
+
+ // Nest loops and add outer to context
+ for(dim <- 1 to 2) loops(dim) += loops(dim-1)
+ context += loops(2)
+
+ val reciprocalVector = for(dim <- 0 to 2) yield {
+ val component = new DeclaredVarSymbol[FloatType]("reciprocal_vector"+(dim+1))
+ context.addDeclaration(component)
+ new VarRef[FloatType](component)
+ }
+
+ for(dim <- 0 to 2) {
+ var component : Expression[FloatType] = new FloatLiteral(0.0)
+ for(vec <- 0 to 2) {
+ val vector = OnetepTypes.CellInfo.public % OnetepTypes.CellInfo.latticeReciprocal(vec)
+ component = component + vector % OnetepTypes.Point.coord(dim) * new Conversion[IntType, FloatType](indices(dim))
+ }
+ loops(0) += new AssignStatement(reciprocalVector(dim), component)
+ }
+
+ val reciprocalIndex = indices.map(new VarRef[IntType](_))
+
+ //TODO: Use a unary negation instead of multiplication by -1.0.
+ loops(0) += new AssignStatement(transformed.at(reciprocalIndex: _*),
+ opFragment.getBuffer.at(reciprocalIndex: _*) *
+ new Conversion[FloatType, ComplexType](magnitude(reciprocalVector) * new FloatLiteral(-1.0)))
+
+ opFragment.teardown(context)
+ }
+
+ private def magnitude(vector: Seq[Expression[FloatType]]) = {
+ var result : Expression[FloatType] = new FloatLiteral(0.0)
+ for(element <- vector) result += element * element
+ result
+ }
+
+ def teardown(context: GenerationContext) {
+ context += new DeallocateStatement(transformed)
+ }
+
+ def getSize = opFragment.getSize
+
+ def getBuffer = transformed
+ }
+
+ private def getOperand = op
+
+ def getFragment(indices: Map[NamedIndex, Expression[IntType]]) =
+ new LocalFragment(this, indices)
}
def getFortranAttributes = Set("type(SPHERE)")
}
+ object Point extends StructType {
+ val x = new FieldSymbol[FloatType]("X")
+ val y = new FieldSymbol[FloatType]("Y")
+ val z = new FieldSymbol[FloatType]("Z")
+ val coord = List(x,y,z)
+ def getFortranAttributes = Set("type(POINT)")
+ }
+
object CellInfo extends StructType {
val public = new NamedUnboundVarSymbol[StructType]("pub_cell", OnetepTypes.CellInfo)
val ppdWidth = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("n_pt"+dim)}.toSeq
val numPPDs = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("n_ppds_a"+dim)}.toSeq
val pointsInPPD = new FieldSymbol[IntType]("n_pts")
+ val latticeReciprocal = for(dim <- 1 to 3) yield new FieldSymbol[StructType]("b"+dim, Point)
def getFortranAttributes = Set("type(CELL_INFO)")
}
val producer = new ProducerStatement
val sphereIndex = producer.addIteration("sphere_index", 1, numSpheres)
- val numPPDs = (~(basis % FunctionBasis.numPPDsInSphere)).readAt(sphereIndex)
+ val numPPDs = (~(basis % FunctionBasis.numPPDsInSphere)).at(sphereIndex)
val ppdIndex = producer.addIteration("ppd_index", 1, numPPDs)
- val sphere = (~(basis % FunctionBasis.spheres)).readAt(sphereIndex)
- val ppdGlobalCount = (~(sphere % Sphere.ppdList)).readAt(ppdIndex, 1) - 1
+ val sphere = (~(basis % FunctionBasis.spheres)).at(sphereIndex)
+ val ppdGlobalCount = (~(sphere % Sphere.ppdList)).at(ppdIndex, 1) - 1
// The integer co-ordinates of the PPD (0-based)
val a3pos = producer.addExpression("ppd_pos1", ppdGlobalCount / (cellWidthPPDs(0)*cellWidthPPDs(1)))
val a1pos = producer.addExpression("ppd_pos3", ppdGlobalCount % cellWidthPPDs(0))
val ppdPos = List(a1pos, a2pos, a3pos)
- val tightbox = (~(basis % FunctionBasis.tightBoxes)).readAt(sphereIndex)
+ val tightbox = (~(basis % FunctionBasis.tightBoxes)).at(sphereIndex)
// The offsets into the PPDs for the edges of the tightbox
val ppdStartOffsets = for(dim <- 0 to 2) yield tightbox % TightBox.startPts(dim) - 1
+ ppdIndices(1) * (CellInfo.public % CellInfo.ppdWidth(0))
+ ppdIndices(0))
- val dataValue = producer.addExpression("data", data.readAt(ppdDataIndex))
+ val dataValue = producer.addExpression("data", data.at(ppdDataIndex))
val discreteIndices = List[DiscreteIndex](new SphereIndex("sphere", sphereIndex))
val spatialIndices = {
}
val fftbox = new DeclaredVarSymbol[ArrayType[FloatType]]("fftbox", new ArrayType[FloatType](3))
- val tightbox = (~(basis % FunctionBasis.tightBoxes)).readAt(sphereIndex)
- val sphere = (~(basis % FunctionBasis.spheres)).readAt(sphereIndex)
+ val tightbox = (~(basis % FunctionBasis.tightBoxes)).at(sphereIndex)
+ val sphere = (~(basis % FunctionBasis.spheres)).at(sphereIndex)
val fftboxOffset = for(dim <- 0 to 2) yield new DeclaredVarSymbol[IntType]("fftbox_offset"+(dim+1))
val reciprocalBox = new DeclaredVarSymbol[ArrayType[ComplexType]]("reciprocal_box", new ArrayType[ComplexType](3))
context.addDeclaration(reciprocalBox)
fftboxOffset.map(context.addDeclaration(_))
- val fftboxSize : Seq[Expression[IntType]] = for (dim <- 0 to 2) yield FFTBoxInfo.public % FFTBoxInfo.totalPts(dim)
+ val fftboxSize : Seq[Expression[IntType]] = getSize
context += new AllocateStatement(fftbox, fftboxSize)
context += new AllocateStatement(reciprocalBox, fftboxSize)
context += new FunctionCallStatement(new FunctionCall(OnetepFunctions.basis_ket_start_wrt_fftbox,
def teardown(context: GenerationContext) {
context += new DeallocateStatement(reciprocalBox)
}
+
+ def getSize = for (dim <- 0 to 2) yield OnetepTypes.FFTBoxInfo.public % OnetepTypes.FFTBoxInfo.totalPts(dim)
+
+ def getBuffer = reciprocalBox
}
- def getSphereIndex = indices.head
+ private def getSphereIndex = indices.head
def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : FieldFragment =
new LocalFragment(this, indices)