]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Implement application of laplacian operator.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Wed, 2 May 2012 19:30:38 +0000 (20:30 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Wed, 2 May 2012 19:30:38 +0000 (20:30 +0100)
src/ofc/codegen/Assignment.scala
src/ofc/codegen/Expression.scala
src/ofc/codegen/FortranGenerator.scala
src/ofc/codegen/ProducerStatement.scala
src/ofc/codegen/Type.scala
src/ofc/generators/onetep/FieldFragment.scala
src/ofc/generators/onetep/Laplacian.scala
src/ofc/generators/onetep/OnetepTypes.scala
src/ofc/generators/onetep/PPDFunctionSet.scala

index e932eba90f349831c6b533a315fdd1fdb70fc592..e0b4d0ce913fc06033c8d9f3a45e0c3b96c08970 100644 (file)
@@ -1,7 +1,10 @@
 package ofc.codegen
+import ofc.LogicError
+
+class AssignStatement(lhs: Expression[_ <: Type], rhs: Expression[_ <: Type]) extends Statement {
+  if (lhs.getType != rhs.getType)
+    throw new LogicError("Assignment from incompatible type.")
 
-class Assignment(lhs: Expression[_ <: Type], rhs: Expression[_ <: Type]) extends Statement {
   def getLHS : Expression[_] = lhs
   def getRHS : Expression[_] = rhs
-  // TODO: type check assignment
 }
index 5111d7b738ae00d195d83a28e2e0ee183a143250..b1433f1cf6775b3a2d3ba33248ee5454fddb226a 100644 (file)
@@ -59,8 +59,8 @@ abstract class Expression[T <: Type] extends Traversable[Expression[_]] {
   def %[FieldType <: Type](field: FieldSymbol[FieldType])(implicit witness: <:<[this.type, Expression[StructType]]) : Expression[FieldType] =
     new FieldAccess[FieldType](witness(this), field)
 
-  def readAt[T <: Type](index: Expression[IntType]*)(implicit witness: <:<[this.type, Expression[ArrayType[T]]]) : Expression[T] = 
-    new ArrayRead(witness(this), index.toList)
+  def at[T <: Type](index: Expression[IntType]*)(implicit witness: <:<[this.type, Expression[ArrayType[T]]]) : Expression[T] = 
+    new ArrayAccess(witness(this), index.toList)
 
   def unary_~[T <: Type]()(implicit witness: <:<[this.type, Expression[PointerType[T]]]) : Expression[T] =
     new PointerDereference(witness(this))
@@ -89,7 +89,7 @@ class FieldAccess[T <: Type](expression: Expression[StructType], field: FieldSym
   def getType = field.getType
 }
 
-class ArrayRead[E <: Type](expression: Expression[ArrayType[E]], index: Seq[Expression[IntType]]) extends Expression[E] {
+class ArrayAccess[E <: Type](expression: Expression[ArrayType[E]], index: Seq[Expression[IntType]]) extends Expression[E] {
   if (index.size != expression.getType.getRank)
     throw new LogicError("Array of rank "+expression.getType.getRank+" indexed with rank "+index.size+" index.")
 
@@ -104,6 +104,16 @@ class PointerDereference[E <: Type](expression: Expression[PointerType[E]]) exte
   def getType = expression.getType.getTargetType
 }
 
+// Conversation
+class Conversion[F <: Type, T <: Type](expression: Expression[F], toType: T)(implicit convertible: IsConvertible[F,T]) extends Expression[T] {
+  def this(expression: Expression[F])(implicit builder: TypeBuilder[T], convertible: IsConvertible[F,T]) 
+    = this(expression, builder())(convertible)
+
+  def getType = toType
+  def getExpression = expression
+  def foreach[U](f: Expression[_] => U) = expression.foreach(f)
+}
+
 // Literals
 class IntegerLiteral(value: Int) extends Expression[IntType] with LeafExpression {
   def getValue = value
index 2f8f8b9baac3b6091ecfcff9631c74386e5d25a1..9f0ecd4f091a0cefe23e2effe2bdb51580b04a99 100644 (file)
@@ -141,7 +141,7 @@ class FortranGenerator {
       case (x : BlockStatement) => processScope(x)
       case (x : ProducerStatement) => processStatement(x.toConcrete)
       case (x : ForLoop) => processForLoop(x)
-      case (a : Assignment) => processAssignment(a)
+      case (a : AssignStatement) => processAssignment(a)
       case (i : IfStatement) => processIf(i)
       case (a : AllocateStatement) => processAllocate(a)
       case (d : DeallocateStatement) => processDeallocate(d)
@@ -169,12 +169,13 @@ class FortranGenerator {
         case (s: NamedUnboundVarSymbol[_]) => ExpHolder(maxPrec, s.getName)
         case s => ExpHolder(maxPrec, symbolManager.getName(s))
       }
-      case (r: ArrayRead[_]) => 
+      case (r: ArrayAccess[_]) => 
         ExpHolder(maxPrec, buildExpression(r.getArrayExpression) + r.getIndexExpressions.map(buildExpression(_)).mkString("(",", ",")"))
       case (d: PointerDereference[_]) => buildExpression(d.getExpression)
       case (c: ConditionalValue[_]) => buildConditionalValue(c)
       case (c: NumericComparison[_]) => buildNumericComparison(c)
       case (c: NumericOperator[_]) => buildNumericOperator(c)
+      case (c: Conversion[_,_]) => buildConversion(c)
       case x => throw new UnimplementedException("Unknown expression type in FORTRAN generator: " + x.toString)
     }
   }
@@ -216,6 +217,13 @@ class FortranGenerator {
     ExpHolder(opInfo.precedence, opInfo.template.format(lhs, rhs))
   }
 
+  private def buildConversion(c: Conversion[_,_]) : ExpHolder = c.getType match {
+    case (_: FloatType) => ExpHolder(maxPrec, "real(%s)".format(buildExpression(c.getExpression)))
+    case (_: ComplexType) => ExpHolder(maxPrec, "cmplx(%s)".format(buildExpression(c.getExpression)))
+    case _ => throw new UnimplementedException("Fortran generator cannot handle conversion.")
+  }
+    
+
   private def buildNumericComparison(c: NumericComparison[_]) : ExpHolder =
     buildBinaryOperation(getBinaryOpInfo(c.getOperation), buildExpression(c.getLeft), buildExpression(c.getRight))
 
@@ -255,7 +263,7 @@ class FortranGenerator {
     }
   }
 
-  private def processAssignment(assignment: Assignment) {
+  private def processAssignment(assignment: AssignStatement) {
     addLine("%s = %s".format(buildExpression(assignment.getLHS), buildExpression(assignment.getRHS)))
   }
 
index 485781f4e89102a685d763d71a61eff38bdc486b..0c182ae0fafdbf1e1c8c9031ea0c9164bf3d8e62 100644 (file)
@@ -133,7 +133,7 @@ class ProducerStatement extends Statement {
           scope = ifStat
         }
         case DerivedExpression(sym, expression) => {
-          val assignment = new Assignment(sym, expression)
+          val assignment = new AssignStatement(sym, expression)
           scope.addDeclaration(sym)
           scope += assignment
         }
index cf94a6ecda5c9394afbfa69fdd21cf12f59d9508..1a114de1cebcd7668138f01577f986cb6807e698 100644 (file)
@@ -65,4 +65,11 @@ class HasProperty[T <: Type, P <: TypeProperty]
 object HasProperty {
   implicit val intNumeric = new HasProperty[IntType, Numeric]()
   implicit val floatNumeric = new HasProperty[FloatType, Numeric]()
+  implicit val complexNumeric = new HasProperty[ComplexType, Numeric]()
+}
+
+class IsConvertible[From <: Type, To <: Type]
+object IsConvertible {
+  implicit val intToFloat = new IsConvertible[IntType, FloatType]()
+  implicit val floatToFloat = new IsConvertible[FloatType, ComplexType]()
 }
index a937371990ab5771cad1918cd452d559c00cba6d..cd954d647c534607af7664ec2056db28c13db690 100644 (file)
@@ -1,4 +1,5 @@
 package ofc.generators.onetep
+import ofc.codegen._
 
 trait FieldFragment extends Fragment {
   def toReciprocal : ReciprocalFragment
@@ -8,4 +9,6 @@ trait PsincFragment extends FieldFragment
 
 trait ReciprocalFragment extends FieldFragment {
   def toReciprocal = this
+  def getSize : Seq[Expression[IntType]] 
+  def getBuffer : Expression[ArrayType[ComplexType]]
 }
index 9588fad3410901e853c15cf64a3576bc91414240..5f771f725ee388559f5ac0af5d8daa8b37cce03f 100644 (file)
@@ -2,5 +2,71 @@ package ofc.generators.onetep
 import ofc.codegen._
 
 class Laplacian(op: Field)  extends Field {
-  def getFragment(indices: Map[NamedIndex, Expression[IntType]]) = op.getFragment(indices).toReciprocal
+  class LocalFragment(parent: Laplacian, indices: Map[NamedIndex, Expression[IntType]]) extends ReciprocalFragment {
+    val transformed = new DeclaredVarSymbol[ArrayType[ComplexType]]("transformed", new ArrayType[ComplexType](3))
+    val opFragment = parent.getOperand.getFragment(indices).toReciprocal
+
+    def setup(context: GenerationContext) {
+      context.addDeclaration(transformed)
+      opFragment.setup(context)
+
+      context += new AllocateStatement(transformed, opFragment.getSize)
+
+      val indices = for(dim <- 0 to 2) yield {
+        val index = new DeclaredVarSymbol[IntType]("i"+(dim+1))
+        context.addDeclaration(index)
+        index
+      }
+
+      // Construct loops
+      val loops = for(dim <- 0 to 2) yield new ForLoop(indices(dim), 1, getSize(dim))
+
+      // Nest loops and add outer to context
+      for(dim <- 1 to 2) loops(dim) += loops(dim-1)
+      context += loops(2)
+      
+      val reciprocalVector = for(dim <- 0 to 2) yield {
+        val component = new DeclaredVarSymbol[FloatType]("reciprocal_vector"+(dim+1))
+        context.addDeclaration(component)
+        new VarRef[FloatType](component)
+      }
+
+      for(dim <- 0 to 2) {
+        var component : Expression[FloatType] = new FloatLiteral(0.0)
+        for(vec <- 0 to 2) {
+          val vector = OnetepTypes.CellInfo.public % OnetepTypes.CellInfo.latticeReciprocal(vec)
+          component = component + vector % OnetepTypes.Point.coord(dim) * new Conversion[IntType, FloatType](indices(dim))
+        }
+        loops(0) += new AssignStatement(reciprocalVector(dim), component)
+      }
+
+      val reciprocalIndex = indices.map(new VarRef[IntType](_))
+      
+      //TODO: Use a unary negation instead of multiplication by -1.0.
+      loops(0) += new AssignStatement(transformed.at(reciprocalIndex: _*), 
+        opFragment.getBuffer.at(reciprocalIndex: _*) * 
+        new Conversion[FloatType, ComplexType](magnitude(reciprocalVector) * new FloatLiteral(-1.0)))
+
+      opFragment.teardown(context)
+    }
+
+    private def magnitude(vector: Seq[Expression[FloatType]]) = {
+      var result : Expression[FloatType] = new FloatLiteral(0.0)
+      for(element <- vector) result += element * element
+      result
+    }
+
+    def teardown(context: GenerationContext) {
+      context += new DeallocateStatement(transformed)
+    }
+
+    def getSize = opFragment.getSize
+
+    def getBuffer = transformed
+  }
+  
+  private def getOperand = op
+
+  def getFragment(indices: Map[NamedIndex, Expression[IntType]]) = 
+    new LocalFragment(this, indices)
 }
index fd8eabdc117b0b77ef0b4cead65fb0eac3cdbd57..f2f1d7e2e27e17a4f05020a0a8e42d820ceff334 100644 (file)
@@ -34,11 +34,20 @@ object OnetepTypes {
     def getFortranAttributes = Set("type(SPHERE)")
   }
 
+  object Point extends StructType {
+    val x = new FieldSymbol[FloatType]("X")
+    val y = new FieldSymbol[FloatType]("Y")
+    val z = new FieldSymbol[FloatType]("Z")
+    val coord = List(x,y,z)
+    def getFortranAttributes = Set("type(POINT)")
+  }
+
   object CellInfo extends StructType {
     val public = new NamedUnboundVarSymbol[StructType]("pub_cell", OnetepTypes.CellInfo)
     val ppdWidth = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("n_pt"+dim)}.toSeq
     val numPPDs = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("n_ppds_a"+dim)}.toSeq
     val pointsInPPD = new FieldSymbol[IntType]("n_pts")
+    val latticeReciprocal = for(dim <- 1 to 3) yield new FieldSymbol[StructType]("b"+dim, Point)
     def getFortranAttributes = Set("type(CELL_INFO)")
   }
 
index 3bf5480d03cb2efddbc7d14df9be1bac8fc642d7..754e0e7bacd110bb849bafbd283d2e66ee030631 100644 (file)
@@ -27,10 +27,10 @@ object PPDFunctionSet {
 
     val producer = new ProducerStatement
     val sphereIndex = producer.addIteration("sphere_index", 1, numSpheres)
-    val numPPDs = (~(basis % FunctionBasis.numPPDsInSphere)).readAt(sphereIndex)
+    val numPPDs = (~(basis % FunctionBasis.numPPDsInSphere)).at(sphereIndex)
     val ppdIndex = producer.addIteration("ppd_index", 1, numPPDs)
-    val sphere = (~(basis % FunctionBasis.spheres)).readAt(sphereIndex) 
-    val ppdGlobalCount = (~(sphere % Sphere.ppdList)).readAt(ppdIndex, 1) - 1
+    val sphere = (~(basis % FunctionBasis.spheres)).at(sphereIndex) 
+    val ppdGlobalCount = (~(sphere % Sphere.ppdList)).at(ppdIndex, 1) - 1
 
     // The integer co-ordinates of the PPD (0-based)
     val a3pos = producer.addExpression("ppd_pos1", ppdGlobalCount / (cellWidthPPDs(0)*cellWidthPPDs(1)))
@@ -38,7 +38,7 @@ object PPDFunctionSet {
     val a1pos = producer.addExpression("ppd_pos3", ppdGlobalCount % cellWidthPPDs(0))
     val ppdPos = List(a1pos, a2pos, a3pos)
 
-    val tightbox = (~(basis % FunctionBasis.tightBoxes)).readAt(sphereIndex)
+    val tightbox = (~(basis % FunctionBasis.tightBoxes)).at(sphereIndex)
 
     // The offsets into the PPDs for the edges of the tightbox
     val ppdStartOffsets = for(dim <- 0 to 2) yield tightbox % TightBox.startPts(dim) - 1
@@ -85,7 +85,7 @@ object PPDFunctionSet {
       + ppdIndices(1) * (CellInfo.public % CellInfo.ppdWidth(0))
       + ppdIndices(0))
 
-    val dataValue = producer.addExpression("data", data.readAt(ppdDataIndex))
+    val dataValue = producer.addExpression("data", data.at(ppdDataIndex))
 
     val discreteIndices = List[DiscreteIndex](new SphereIndex("sphere", sphereIndex))
     val spatialIndices = {
@@ -117,8 +117,8 @@ class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedInde
     }
 
     val fftbox = new DeclaredVarSymbol[ArrayType[FloatType]]("fftbox", new ArrayType[FloatType](3))
-    val tightbox = (~(basis % FunctionBasis.tightBoxes)).readAt(sphereIndex)
-    val sphere = (~(basis % FunctionBasis.spheres)).readAt(sphereIndex) 
+    val tightbox = (~(basis % FunctionBasis.tightBoxes)).at(sphereIndex)
+    val sphere = (~(basis % FunctionBasis.spheres)).at(sphereIndex) 
     val fftboxOffset = for(dim <- 0 to 2) yield new DeclaredVarSymbol[IntType]("fftbox_offset"+(dim+1)) 
     val reciprocalBox = new DeclaredVarSymbol[ArrayType[ComplexType]]("reciprocal_box", new ArrayType[ComplexType](3))
 
@@ -129,7 +129,7 @@ class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedInde
       context.addDeclaration(reciprocalBox)
       fftboxOffset.map(context.addDeclaration(_))
 
-      val fftboxSize : Seq[Expression[IntType]] = for (dim <- 0 to 2) yield FFTBoxInfo.public % FFTBoxInfo.totalPts(dim)
+      val fftboxSize : Seq[Expression[IntType]] = getSize
       context += new AllocateStatement(fftbox, fftboxSize)
       context += new AllocateStatement(reciprocalBox, fftboxSize)
       context += new FunctionCallStatement(new FunctionCall(OnetepFunctions.basis_ket_start_wrt_fftbox,
@@ -153,9 +153,13 @@ class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedInde
     def teardown(context: GenerationContext) {
       context += new DeallocateStatement(reciprocalBox)
     }
+
+    def getSize = for (dim <- 0 to 2) yield OnetepTypes.FFTBoxInfo.public % OnetepTypes.FFTBoxInfo.totalPts(dim)
+
+    def getBuffer = reciprocalBox
   }
 
-  def getSphereIndex = indices.head
+  private def getSphereIndex = indices.head
 
   def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : FieldFragment =
     new LocalFragment(this, indices)