]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Generate code to transform into reciprocal space.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Wed, 2 May 2012 15:13:45 +0000 (16:13 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Wed, 2 May 2012 15:13:45 +0000 (16:13 +0100)
src/ofc/codegen/Expression.scala
src/ofc/codegen/FortranGenerator.scala
src/ofc/codegen/Type.scala
src/ofc/generators/onetep/OnetepFunctions.scala
src/ofc/generators/onetep/PPDFunctionSet.scala

index 1e3e9ff39645c3dad2661d789adc41c24244b6a6..5111d7b738ae00d195d83a28e2e0ee183a143250 100644 (file)
@@ -114,3 +114,8 @@ class FloatLiteral(value: Double) extends Expression[FloatType] with LeafExpress
   def getValue = value
   def getType = new FloatType
 }
+
+class CharLiteral(value: Char) extends Expression[CharType] with LeafExpression {
+  def getValue = value
+  def getType = new CharType
+}
index 69b2568d507e71fb628b76d0965e957f352f4861..2f8f8b9baac3b6091ecfcff9631c74386e5d25a1 100644 (file)
@@ -163,6 +163,7 @@ class FortranGenerator {
     expression match {
       case (i : IntegerLiteral) => ExpHolder(maxPrec, i.getValue.toString)
       case (i : FloatLiteral) => ExpHolder(maxPrec, i.getValue.toString)
+      case (i : CharLiteral) => ExpHolder(maxPrec, "'%s'".format(i.getValue.toString))
       case (a : FieldAccess[_]) => ExpHolder(maxPrec, "%s%%%s".format(buildExpression(a.getStructExpression), a.getField.getName))
       case (r : VarRef[_]) => r.getSymbol match {
         case (s: NamedUnboundVarSymbol[_]) => ExpHolder(maxPrec, s.getName)
index a9fe75f12f61ade10850f590f6c2604ccdd4e00a..cf94a6ecda5c9394afbfa69fdd21cf12f59d9508 100644 (file)
@@ -23,6 +23,14 @@ final case class VoidType() extends PrimitiveType {
   def getFortranAttributes = throw new LogicError("void type does not exist in Fortran.")
 }
 
+final case class CharType() extends PrimitiveType {
+  def getFortranAttributes = Set("character")
+}
+
+final case class ComplexType() extends PrimitiveType {
+  def getFortranAttributes = Set("complex(kind=DP)")
+}
+
 final case class ArrayType[ElementType <: Type](rank: Int, eType: ElementType) extends Type {
   def this(rank: Int)(implicit builder: TypeBuilder[ElementType]) = this(rank, builder())
   def getElementType = eType
@@ -43,8 +51,10 @@ trait TypeBuilder[T <: Type] {
 }
 
 object TypeBuilder {
+  implicit val charBuilder = new TypeBuilder[CharType] { def apply() = new CharType }
   implicit val intBuilder = new TypeBuilder[IntType] { def apply() = new IntType }
   implicit val floatBuilder = new TypeBuilder[FloatType] { def apply() = new FloatType }
+  implicit val complexBuilder = new TypeBuilder[ComplexType] { def apply() = new ComplexType }
   implicit val boolBuilder = new TypeBuilder[BoolType] { def apply() = new BoolType }
 }
 
@@ -56,5 +66,3 @@ object HasProperty {
   implicit val intNumeric = new HasProperty[IntType, Numeric]()
   implicit val floatNumeric = new HasProperty[FloatType, Numeric]()
 }
-
-
index 0e83e04b436c0d4fe647104b9afc8eada5bc4a27..c056ebb74d00a514bada090deb4d4b5a15c81369 100644 (file)
@@ -18,4 +18,11 @@ object OnetepFunctions {
         ("n1", new IntType),
         ("n2", new IntType),
         ("n3", new IntType)))
+
+  val fourier_apply_box_pair = new FortranSubroutineSignature("fourier_apply_box_pair",
+    Seq(("grid", new CharType),
+        ("dir", new CharType),
+        ("rspc1", new ArrayType[FloatType](3)),
+        ("rspc2", new ArrayType[FloatType](3)),
+        ("gspc", new ArrayType[ComplexType](3))))
 }
index cf27b2132c0cb2813337c4adb8a853e17656be22..3bf5480d03cb2efddbc7d14df9be1bac8fc642d7 100644 (file)
@@ -120,15 +120,18 @@ class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedInde
     val tightbox = (~(basis % FunctionBasis.tightBoxes)).readAt(sphereIndex)
     val sphere = (~(basis % FunctionBasis.spheres)).readAt(sphereIndex) 
     val fftboxOffset = for(dim <- 0 to 2) yield new DeclaredVarSymbol[IntType]("fftbox_offset"+(dim+1)) 
+    val reciprocalBox = new DeclaredVarSymbol[ArrayType[ComplexType]]("reciprocal_box", new ArrayType[ComplexType](3))
 
     def setup(context: GenerationContext) {
       import OnetepTypes.FFTBoxInfo
 
       context.addDeclaration(fftbox)
+      context.addDeclaration(reciprocalBox)
       fftboxOffset.map(context.addDeclaration(_))
 
       val fftboxSize : Seq[Expression[IntType]] = for (dim <- 0 to 2) yield FFTBoxInfo.public % FFTBoxInfo.totalPts(dim)
       context += new AllocateStatement(fftbox, fftboxSize)
+      context += new AllocateStatement(reciprocalBox, fftboxSize)
       context += new FunctionCallStatement(new FunctionCall(OnetepFunctions.basis_ket_start_wrt_fftbox,
         fftboxOffset.map(new VarRef[IntType](_)) ++ fftboxSize))
 
@@ -140,10 +143,15 @@ class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedInde
       basisCopyParams :+= sphere
 
       context += new FunctionCallStatement(new FunctionCall(OnetepFunctions.basis_copy_function_to_fftbox, basisCopyParams))
+
+      val fourierParams : Seq[Expression[_]] = Seq(new CharLiteral('C'), new CharLiteral('F'), fftbox, fftbox, reciprocalBox)
+      context += new FunctionCallStatement(new FunctionCall(OnetepFunctions.fourier_apply_box_pair, fourierParams))
+
+      context += new DeallocateStatement(fftbox)
     }
 
     def teardown(context: GenerationContext) {
-      context += new DeallocateStatement(fftbox)
+      context += new DeallocateStatement(reciprocalBox)
     }
   }