From 0cff3254a3e3128a8beeee21155162dc6a420e15 Mon Sep 17 00:00:00 2001 From: Francis Russell Date: Wed, 2 May 2012 20:30:38 +0100 Subject: [PATCH] Implement application of laplacian operator. --- src/ofc/codegen/Assignment.scala | 7 +- src/ofc/codegen/Expression.scala | 16 ++++- src/ofc/codegen/FortranGenerator.scala | 14 +++- src/ofc/codegen/ProducerStatement.scala | 2 +- src/ofc/codegen/Type.scala | 7 ++ src/ofc/generators/onetep/FieldFragment.scala | 3 + src/ofc/generators/onetep/Laplacian.scala | 68 ++++++++++++++++++- src/ofc/generators/onetep/OnetepTypes.scala | 9 +++ .../generators/onetep/PPDFunctionSet.scala | 22 +++--- 9 files changed, 129 insertions(+), 19 deletions(-) diff --git a/src/ofc/codegen/Assignment.scala b/src/ofc/codegen/Assignment.scala index e932eba..e0b4d0c 100644 --- a/src/ofc/codegen/Assignment.scala +++ b/src/ofc/codegen/Assignment.scala @@ -1,7 +1,10 @@ package ofc.codegen +import ofc.LogicError + +class AssignStatement(lhs: Expression[_ <: Type], rhs: Expression[_ <: Type]) extends Statement { + if (lhs.getType != rhs.getType) + throw new LogicError("Assignment from incompatible type.") -class Assignment(lhs: Expression[_ <: Type], rhs: Expression[_ <: Type]) extends Statement { def getLHS : Expression[_] = lhs def getRHS : Expression[_] = rhs - // TODO: type check assignment } diff --git a/src/ofc/codegen/Expression.scala b/src/ofc/codegen/Expression.scala index 5111d7b..b1433f1 100644 --- a/src/ofc/codegen/Expression.scala +++ b/src/ofc/codegen/Expression.scala @@ -59,8 +59,8 @@ abstract class Expression[T <: Type] extends Traversable[Expression[_]] { def %[FieldType <: Type](field: FieldSymbol[FieldType])(implicit witness: <:<[this.type, Expression[StructType]]) : Expression[FieldType] = new FieldAccess[FieldType](witness(this), field) - def readAt[T <: Type](index: Expression[IntType]*)(implicit witness: <:<[this.type, Expression[ArrayType[T]]]) : Expression[T] = - new ArrayRead(witness(this), index.toList) + def at[T <: Type](index: Expression[IntType]*)(implicit witness: <:<[this.type, Expression[ArrayType[T]]]) : Expression[T] = + new ArrayAccess(witness(this), index.toList) def unary_~[T <: Type]()(implicit witness: <:<[this.type, Expression[PointerType[T]]]) : Expression[T] = new PointerDereference(witness(this)) @@ -89,7 +89,7 @@ class FieldAccess[T <: Type](expression: Expression[StructType], field: FieldSym def getType = field.getType } -class ArrayRead[E <: Type](expression: Expression[ArrayType[E]], index: Seq[Expression[IntType]]) extends Expression[E] { +class ArrayAccess[E <: Type](expression: Expression[ArrayType[E]], index: Seq[Expression[IntType]]) extends Expression[E] { if (index.size != expression.getType.getRank) throw new LogicError("Array of rank "+expression.getType.getRank+" indexed with rank "+index.size+" index.") @@ -104,6 +104,16 @@ class PointerDereference[E <: Type](expression: Expression[PointerType[E]]) exte def getType = expression.getType.getTargetType } +// Conversation +class Conversion[F <: Type, T <: Type](expression: Expression[F], toType: T)(implicit convertible: IsConvertible[F,T]) extends Expression[T] { + def this(expression: Expression[F])(implicit builder: TypeBuilder[T], convertible: IsConvertible[F,T]) + = this(expression, builder())(convertible) + + def getType = toType + def getExpression = expression + def foreach[U](f: Expression[_] => U) = expression.foreach(f) +} + // Literals class IntegerLiteral(value: Int) extends Expression[IntType] with LeafExpression { def getValue = value diff --git a/src/ofc/codegen/FortranGenerator.scala b/src/ofc/codegen/FortranGenerator.scala index 2f8f8b9..9f0ecd4 100644 --- a/src/ofc/codegen/FortranGenerator.scala +++ b/src/ofc/codegen/FortranGenerator.scala @@ -141,7 +141,7 @@ class FortranGenerator { case (x : BlockStatement) => processScope(x) case (x : ProducerStatement) => processStatement(x.toConcrete) case (x : ForLoop) => processForLoop(x) - case (a : Assignment) => processAssignment(a) + case (a : AssignStatement) => processAssignment(a) case (i : IfStatement) => processIf(i) case (a : AllocateStatement) => processAllocate(a) case (d : DeallocateStatement) => processDeallocate(d) @@ -169,12 +169,13 @@ class FortranGenerator { case (s: NamedUnboundVarSymbol[_]) => ExpHolder(maxPrec, s.getName) case s => ExpHolder(maxPrec, symbolManager.getName(s)) } - case (r: ArrayRead[_]) => + case (r: ArrayAccess[_]) => ExpHolder(maxPrec, buildExpression(r.getArrayExpression) + r.getIndexExpressions.map(buildExpression(_)).mkString("(",", ",")")) case (d: PointerDereference[_]) => buildExpression(d.getExpression) case (c: ConditionalValue[_]) => buildConditionalValue(c) case (c: NumericComparison[_]) => buildNumericComparison(c) case (c: NumericOperator[_]) => buildNumericOperator(c) + case (c: Conversion[_,_]) => buildConversion(c) case x => throw new UnimplementedException("Unknown expression type in FORTRAN generator: " + x.toString) } } @@ -216,6 +217,13 @@ class FortranGenerator { ExpHolder(opInfo.precedence, opInfo.template.format(lhs, rhs)) } + private def buildConversion(c: Conversion[_,_]) : ExpHolder = c.getType match { + case (_: FloatType) => ExpHolder(maxPrec, "real(%s)".format(buildExpression(c.getExpression))) + case (_: ComplexType) => ExpHolder(maxPrec, "cmplx(%s)".format(buildExpression(c.getExpression))) + case _ => throw new UnimplementedException("Fortran generator cannot handle conversion.") + } + + private def buildNumericComparison(c: NumericComparison[_]) : ExpHolder = buildBinaryOperation(getBinaryOpInfo(c.getOperation), buildExpression(c.getLeft), buildExpression(c.getRight)) @@ -255,7 +263,7 @@ class FortranGenerator { } } - private def processAssignment(assignment: Assignment) { + private def processAssignment(assignment: AssignStatement) { addLine("%s = %s".format(buildExpression(assignment.getLHS), buildExpression(assignment.getRHS))) } diff --git a/src/ofc/codegen/ProducerStatement.scala b/src/ofc/codegen/ProducerStatement.scala index 485781f..0c182ae 100644 --- a/src/ofc/codegen/ProducerStatement.scala +++ b/src/ofc/codegen/ProducerStatement.scala @@ -133,7 +133,7 @@ class ProducerStatement extends Statement { scope = ifStat } case DerivedExpression(sym, expression) => { - val assignment = new Assignment(sym, expression) + val assignment = new AssignStatement(sym, expression) scope.addDeclaration(sym) scope += assignment } diff --git a/src/ofc/codegen/Type.scala b/src/ofc/codegen/Type.scala index cf94a6e..1a114de 100644 --- a/src/ofc/codegen/Type.scala +++ b/src/ofc/codegen/Type.scala @@ -65,4 +65,11 @@ class HasProperty[T <: Type, P <: TypeProperty] object HasProperty { implicit val intNumeric = new HasProperty[IntType, Numeric]() implicit val floatNumeric = new HasProperty[FloatType, Numeric]() + implicit val complexNumeric = new HasProperty[ComplexType, Numeric]() +} + +class IsConvertible[From <: Type, To <: Type] +object IsConvertible { + implicit val intToFloat = new IsConvertible[IntType, FloatType]() + implicit val floatToFloat = new IsConvertible[FloatType, ComplexType]() } diff --git a/src/ofc/generators/onetep/FieldFragment.scala b/src/ofc/generators/onetep/FieldFragment.scala index a937371..cd954d6 100644 --- a/src/ofc/generators/onetep/FieldFragment.scala +++ b/src/ofc/generators/onetep/FieldFragment.scala @@ -1,4 +1,5 @@ package ofc.generators.onetep +import ofc.codegen._ trait FieldFragment extends Fragment { def toReciprocal : ReciprocalFragment @@ -8,4 +9,6 @@ trait PsincFragment extends FieldFragment trait ReciprocalFragment extends FieldFragment { def toReciprocal = this + def getSize : Seq[Expression[IntType]] + def getBuffer : Expression[ArrayType[ComplexType]] } diff --git a/src/ofc/generators/onetep/Laplacian.scala b/src/ofc/generators/onetep/Laplacian.scala index 9588fad..5f771f7 100644 --- a/src/ofc/generators/onetep/Laplacian.scala +++ b/src/ofc/generators/onetep/Laplacian.scala @@ -2,5 +2,71 @@ package ofc.generators.onetep import ofc.codegen._ class Laplacian(op: Field) extends Field { - def getFragment(indices: Map[NamedIndex, Expression[IntType]]) = op.getFragment(indices).toReciprocal + class LocalFragment(parent: Laplacian, indices: Map[NamedIndex, Expression[IntType]]) extends ReciprocalFragment { + val transformed = new DeclaredVarSymbol[ArrayType[ComplexType]]("transformed", new ArrayType[ComplexType](3)) + val opFragment = parent.getOperand.getFragment(indices).toReciprocal + + def setup(context: GenerationContext) { + context.addDeclaration(transformed) + opFragment.setup(context) + + context += new AllocateStatement(transformed, opFragment.getSize) + + val indices = for(dim <- 0 to 2) yield { + val index = new DeclaredVarSymbol[IntType]("i"+(dim+1)) + context.addDeclaration(index) + index + } + + // Construct loops + val loops = for(dim <- 0 to 2) yield new ForLoop(indices(dim), 1, getSize(dim)) + + // Nest loops and add outer to context + for(dim <- 1 to 2) loops(dim) += loops(dim-1) + context += loops(2) + + val reciprocalVector = for(dim <- 0 to 2) yield { + val component = new DeclaredVarSymbol[FloatType]("reciprocal_vector"+(dim+1)) + context.addDeclaration(component) + new VarRef[FloatType](component) + } + + for(dim <- 0 to 2) { + var component : Expression[FloatType] = new FloatLiteral(0.0) + for(vec <- 0 to 2) { + val vector = OnetepTypes.CellInfo.public % OnetepTypes.CellInfo.latticeReciprocal(vec) + component = component + vector % OnetepTypes.Point.coord(dim) * new Conversion[IntType, FloatType](indices(dim)) + } + loops(0) += new AssignStatement(reciprocalVector(dim), component) + } + + val reciprocalIndex = indices.map(new VarRef[IntType](_)) + + //TODO: Use a unary negation instead of multiplication by -1.0. + loops(0) += new AssignStatement(transformed.at(reciprocalIndex: _*), + opFragment.getBuffer.at(reciprocalIndex: _*) * + new Conversion[FloatType, ComplexType](magnitude(reciprocalVector) * new FloatLiteral(-1.0))) + + opFragment.teardown(context) + } + + private def magnitude(vector: Seq[Expression[FloatType]]) = { + var result : Expression[FloatType] = new FloatLiteral(0.0) + for(element <- vector) result += element * element + result + } + + def teardown(context: GenerationContext) { + context += new DeallocateStatement(transformed) + } + + def getSize = opFragment.getSize + + def getBuffer = transformed + } + + private def getOperand = op + + def getFragment(indices: Map[NamedIndex, Expression[IntType]]) = + new LocalFragment(this, indices) } diff --git a/src/ofc/generators/onetep/OnetepTypes.scala b/src/ofc/generators/onetep/OnetepTypes.scala index fd8eabd..f2f1d7e 100644 --- a/src/ofc/generators/onetep/OnetepTypes.scala +++ b/src/ofc/generators/onetep/OnetepTypes.scala @@ -34,11 +34,20 @@ object OnetepTypes { def getFortranAttributes = Set("type(SPHERE)") } + object Point extends StructType { + val x = new FieldSymbol[FloatType]("X") + val y = new FieldSymbol[FloatType]("Y") + val z = new FieldSymbol[FloatType]("Z") + val coord = List(x,y,z) + def getFortranAttributes = Set("type(POINT)") + } + object CellInfo extends StructType { val public = new NamedUnboundVarSymbol[StructType]("pub_cell", OnetepTypes.CellInfo) val ppdWidth = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("n_pt"+dim)}.toSeq val numPPDs = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("n_ppds_a"+dim)}.toSeq val pointsInPPD = new FieldSymbol[IntType]("n_pts") + val latticeReciprocal = for(dim <- 1 to 3) yield new FieldSymbol[StructType]("b"+dim, Point) def getFortranAttributes = Set("type(CELL_INFO)") } diff --git a/src/ofc/generators/onetep/PPDFunctionSet.scala b/src/ofc/generators/onetep/PPDFunctionSet.scala index 3bf5480..754e0e7 100644 --- a/src/ofc/generators/onetep/PPDFunctionSet.scala +++ b/src/ofc/generators/onetep/PPDFunctionSet.scala @@ -27,10 +27,10 @@ object PPDFunctionSet { val producer = new ProducerStatement val sphereIndex = producer.addIteration("sphere_index", 1, numSpheres) - val numPPDs = (~(basis % FunctionBasis.numPPDsInSphere)).readAt(sphereIndex) + val numPPDs = (~(basis % FunctionBasis.numPPDsInSphere)).at(sphereIndex) val ppdIndex = producer.addIteration("ppd_index", 1, numPPDs) - val sphere = (~(basis % FunctionBasis.spheres)).readAt(sphereIndex) - val ppdGlobalCount = (~(sphere % Sphere.ppdList)).readAt(ppdIndex, 1) - 1 + val sphere = (~(basis % FunctionBasis.spheres)).at(sphereIndex) + val ppdGlobalCount = (~(sphere % Sphere.ppdList)).at(ppdIndex, 1) - 1 // The integer co-ordinates of the PPD (0-based) val a3pos = producer.addExpression("ppd_pos1", ppdGlobalCount / (cellWidthPPDs(0)*cellWidthPPDs(1))) @@ -38,7 +38,7 @@ object PPDFunctionSet { val a1pos = producer.addExpression("ppd_pos3", ppdGlobalCount % cellWidthPPDs(0)) val ppdPos = List(a1pos, a2pos, a3pos) - val tightbox = (~(basis % FunctionBasis.tightBoxes)).readAt(sphereIndex) + val tightbox = (~(basis % FunctionBasis.tightBoxes)).at(sphereIndex) // The offsets into the PPDs for the edges of the tightbox val ppdStartOffsets = for(dim <- 0 to 2) yield tightbox % TightBox.startPts(dim) - 1 @@ -85,7 +85,7 @@ object PPDFunctionSet { + ppdIndices(1) * (CellInfo.public % CellInfo.ppdWidth(0)) + ppdIndices(0)) - val dataValue = producer.addExpression("data", data.readAt(ppdDataIndex)) + val dataValue = producer.addExpression("data", data.at(ppdDataIndex)) val discreteIndices = List[DiscreteIndex](new SphereIndex("sphere", sphereIndex)) val spatialIndices = { @@ -117,8 +117,8 @@ class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedInde } val fftbox = new DeclaredVarSymbol[ArrayType[FloatType]]("fftbox", new ArrayType[FloatType](3)) - val tightbox = (~(basis % FunctionBasis.tightBoxes)).readAt(sphereIndex) - val sphere = (~(basis % FunctionBasis.spheres)).readAt(sphereIndex) + val tightbox = (~(basis % FunctionBasis.tightBoxes)).at(sphereIndex) + val sphere = (~(basis % FunctionBasis.spheres)).at(sphereIndex) val fftboxOffset = for(dim <- 0 to 2) yield new DeclaredVarSymbol[IntType]("fftbox_offset"+(dim+1)) val reciprocalBox = new DeclaredVarSymbol[ArrayType[ComplexType]]("reciprocal_box", new ArrayType[ComplexType](3)) @@ -129,7 +129,7 @@ class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedInde context.addDeclaration(reciprocalBox) fftboxOffset.map(context.addDeclaration(_)) - val fftboxSize : Seq[Expression[IntType]] = for (dim <- 0 to 2) yield FFTBoxInfo.public % FFTBoxInfo.totalPts(dim) + val fftboxSize : Seq[Expression[IntType]] = getSize context += new AllocateStatement(fftbox, fftboxSize) context += new AllocateStatement(reciprocalBox, fftboxSize) context += new FunctionCallStatement(new FunctionCall(OnetepFunctions.basis_ket_start_wrt_fftbox, @@ -153,9 +153,13 @@ class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedInde def teardown(context: GenerationContext) { context += new DeallocateStatement(reciprocalBox) } + + def getSize = for (dim <- 0 to 2) yield OnetepTypes.FFTBoxInfo.public % OnetepTypes.FFTBoxInfo.totalPts(dim) + + def getBuffer = reciprocalBox } - def getSphereIndex = indices.head + private def getSphereIndex = indices.head def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : FieldFragment = new LocalFragment(this, indices) -- 2.47.3