]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Generate calls for copying PPD data to FFT-box.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Tue, 1 May 2012 17:36:26 +0000 (18:36 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Tue, 1 May 2012 17:36:26 +0000 (18:36 +0100)
src/ofc/codegen/FortranGenerator.scala
src/ofc/codegen/FunctionCall.scala [new file with mode: 0644]
src/ofc/codegen/FunctionCallStatement.scala [new file with mode: 0644]
src/ofc/codegen/FunctionSignature.scala
src/ofc/codegen/Type.scala
src/ofc/generators/onetep/CodeGenerator.scala
src/ofc/generators/onetep/OnetepFunctions.scala [new file with mode: 0644]
src/ofc/generators/onetep/PPDFunctionSet.scala

index 080171d36d5476b72bd5b5aa35c963ed1f6d354e..69b2568d507e71fb628b76d0965e957f352f4861 100644 (file)
@@ -32,7 +32,7 @@ class SymbolManager {
         names += name
         symbols += s -> new SymbolInfo(name)
       } else {
-        throw new LogicError("Attempted to add symbol more than once (multiple declarations?): "+sym.toString)
+        throw new LogicError("Attempted to add symbol more than once (multiple declarations?): "+sym.getName)
       }
 
       case (_: NamedUnboundVarSymbol[_]) => throw new LogicError("Attempted to add unbound symbol to SymbolManager.")
@@ -41,7 +41,7 @@ class SymbolManager {
 
   def getName(sym: VarSymbol[_ <: Type]) =
     symbols.get(sym) match {
-      case None => throw new LogicError("Unknown symbol (missing declaration?): "+sym.toString)
+      case None => throw new LogicError("Unknown symbol (missing declaration?): "+sym.getName)
       case Some(info) => info.getName
     }
 
@@ -145,6 +145,7 @@ class FortranGenerator {
       case (i : IfStatement) => processIf(i)
       case (a : AllocateStatement) => processAllocate(a)
       case (d : DeallocateStatement) => processDeallocate(d)
+      case (f : FunctionCallStatement) => processFunctionCallStatement(f)
       case x => throw new UnimplementedException("Unknown statement type in FORTRAN generator: " + x.toString)
     }
   }
@@ -235,6 +236,15 @@ class FortranGenerator {
     addLine(footer)
   }
 
+  private def processFunctionCallStatement(stat: FunctionCallStatement) {
+    val call = stat.getCall
+    call.getSignature match {
+      case (fortSub: FortranSubroutineSignature) => 
+        addLine("call %s(%s)".format(fortSub.getName, call.getParams.map(buildExpression(_)).mkString(", ")))
+      case _ => throw new LogicError("Fortran generator only knows how to call Fortran functions.")
+    }
+  }
+
   private def processScope(scope: ScopeStatement) {
     for (sym <- scope.getDeclarations) {
       symbolManager.addSymbol(sym)
diff --git a/src/ofc/codegen/FunctionCall.scala b/src/ofc/codegen/FunctionCall.scala
new file mode 100644 (file)
index 0000000..f5925f9
--- /dev/null
@@ -0,0 +1,16 @@
+package ofc.codegen
+import ofc.LogicError
+
+class FunctionCall[R <: Type](signature: FunctionSignature[R], params: Seq[Expression[_]]) extends Expression[R] {
+  if (params.size != signature.getParams.size)
+    throw new LogicError("Function "+signature.getName+" called with incorrect number of parameters.")
+
+  for((param, (name, paramType)) <- params.zip(signature.getParams))
+    if (param.getType != paramType)
+      throw new LogicError("Type mismatch on parameter "+name+" when calling function "+signature.getName)
+
+  def getType = signature.getReturnType
+  def foreach[U](f: Expression[_] => U) = params.foreach(f)
+  def getSignature = signature
+  def getParams = params
+}
diff --git a/src/ofc/codegen/FunctionCallStatement.scala b/src/ofc/codegen/FunctionCallStatement.scala
new file mode 100644 (file)
index 0000000..f4c7fd6
--- /dev/null
@@ -0,0 +1,5 @@
+package ofc.codegen
+
+class FunctionCallStatement(call: FunctionCall[VoidType]) extends Statement {
+  def getCall : FunctionCall[VoidType] = call
+}
index 427c7b55d95a1e43ac579e8fdd60a2f7d40857e0..b27d8a8741da472e2e8b593beedf54f4251c2c5a 100644 (file)
@@ -2,10 +2,14 @@ package ofc.codegen
 
 trait FunctionSignature[R <: Type] {
   def getName: String
+  def getReturnType: R
+  def getParams: Seq[(String, Type)]
 }
 
-class FortranFunctionSignature[R <: Type](name: String, 
-  returnType: R, params: Seq[(String, Type)]) extends FunctionSignature[R] {
+class FortranSubroutineSignature(name: String, 
+  params: Seq[(String, Type)]) extends FunctionSignature[VoidType] {
 
   def getName = name
+  def getReturnType = new VoidType
+  def getParams = params
 }
index fbcfb900e103da684dde36b33fd3d1e7075751a7..a9fe75f12f61ade10850f590f6c2604ccdd4e00a 100644 (file)
@@ -1,4 +1,5 @@
 package ofc.codegen
+import ofc.LogicError
 
 sealed abstract class Type {
   def getFortranAttributes : Set[String]
@@ -18,6 +19,10 @@ final case class BoolType() extends PrimitiveType {
   def getFortranAttributes = Set("logical")
 }
 
+final case class VoidType() extends PrimitiveType {
+  def getFortranAttributes = throw new LogicError("void type does not exist in Fortran.")
+}
+
 final case class ArrayType[ElementType <: Type](rank: Int, eType: ElementType) extends Type {
   def this(rank: Int)(implicit builder: TypeBuilder[ElementType]) = this(rank, builder())
   def getElementType = eType
index 6b429ed296164640e23b97418876144fbad097f8..a61f87f700eeb7b421b5888fe0d11a652125c61e 100644 (file)
@@ -2,9 +2,14 @@ package ofc.generators.onetep
 import ofc.codegen._
 
 class CodeGenerator(dictionary: Dictionary) {
-  val indexMap : Map[NamedIndex, Expression[IntType]] = {
+  val indexSyms : Map[NamedIndex, DeclaredVarSymbol[IntType]] = {
     for(index <- dictionary.getIndices) yield
-      (index, new VarRef[IntType](new DeclaredVarSymbol[IntType](index.getName)))
+      (index, new DeclaredVarSymbol[IntType](index.getName))
+  }.toMap
+
+  val indexMap : Map[NamedIndex, Expression[IntType]] = {
+    for((index, sym) <- indexSyms) yield
+      (index, new VarRef[IntType](sym))
   }.toMap
 
   class Context extends GenerationContext {
@@ -26,6 +31,11 @@ class CodeGenerator(dictionary: Dictionary) {
     val rhs = assignment.rhs
 
     val context = new Context
+
+    //TODO: Remove me when symbols are created properly
+    for((index, sym) <- indexSyms)
+      context.addDeclaration(sym)
+
     val rhsFragment = rhs.getFragment(indexMap)
 
     rhsFragment.setup(context)
diff --git a/src/ofc/generators/onetep/OnetepFunctions.scala b/src/ofc/generators/onetep/OnetepFunctions.scala
new file mode 100644 (file)
index 0000000..0e83e04
--- /dev/null
@@ -0,0 +1,21 @@
+package ofc.generators.onetep
+import ofc.codegen._
+
+object OnetepFunctions {
+  val basis_copy_function_to_fftbox = new FortranSubroutineSignature("basis_copy_function_to_fftbox", 
+    Seq(("fa_fftbox", new ArrayType[FloatType](3)),
+        ("fa_start1", new IntType),
+        ("fa_start2", new IntType),
+        ("fa_start3", new IntType),
+        ("fa_tightbox", OnetepTypes.TightBox),
+        ("fa_on_grid", new ArrayType[FloatType](1)),
+        ("fa_sphere", OnetepTypes.Sphere)))
+
+  val basis_ket_start_wrt_fftbox = new FortranSubroutineSignature("basis_ket_start_wrt_fftbox",
+    Seq(("row_start1", new IntType),
+        ("row_start1", new IntType),
+        ("row_start3", new IntType),
+        ("n1", new IntType),
+        ("n2", new IntType),
+        ("n3", new IntType)))
+}
index 176e5bd6230056478d95dee79c5decb4fc7099c8..cf27b2132c0cb2813337c4adb8a853e17656be22 100644 (file)
@@ -1,5 +1,6 @@
 package ofc.generators.onetep
 import ofc.codegen._
+import ofc.LogicError
 /*
 object PPDFunctionSet {
   private class SphereIndex(name: String, value: Expression[IntType]) extends DiscreteIndex {
@@ -98,20 +99,47 @@ object PPDFunctionSet {
 */
 
 class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedIndex]) extends Field {
-  class LocalFragment(parent: PPDFunctionSet) extends PsincFragment {
+  val basis = new NamedUnboundVarSymbol[StructType](basisName, OnetepTypes.FunctionBasis)
+  val data =  new NamedUnboundVarSymbol[ArrayType[FloatType]](dataName, new ArrayType[FloatType](1))
+
+  class LocalFragment(parent: PPDFunctionSet, indices: Map[NamedIndex, Expression[IntType]]) extends PsincFragment {
     def setup(context: GenerationContext) {}
     def teardown(context: GenerationContext) {}
-    def toReciprocal : ReciprocalFragment = new LocalReciprocal(parent)
+    def toReciprocal : ReciprocalFragment = new LocalReciprocal(parent, indices)
   }
 
-  class LocalReciprocal(parent: PPDFunctionSet) extends ReciprocalFragment {
+  class LocalReciprocal(parent: PPDFunctionSet, indices: Map[NamedIndex, Expression[IntType]]) extends ReciprocalFragment {
+    import OnetepTypes.FunctionBasis
+
+    val sphereIndex = indices.get(parent.getSphereIndex) match {
+      case Some(expression) => expression
+      case None => throw new LogicError("Cannot find expression for index "+parent.getSphereIndex)
+    }
+
     val fftbox = new DeclaredVarSymbol[ArrayType[FloatType]]("fftbox", new ArrayType[FloatType](3))
+    val tightbox = (~(basis % FunctionBasis.tightBoxes)).readAt(sphereIndex)
+    val sphere = (~(basis % FunctionBasis.spheres)).readAt(sphereIndex) 
+    val fftboxOffset = for(dim <- 0 to 2) yield new DeclaredVarSymbol[IntType]("fftbox_offset"+(dim+1)) 
 
     def setup(context: GenerationContext) {
       import OnetepTypes.FFTBoxInfo
 
       context.addDeclaration(fftbox)
-      context += new AllocateStatement(fftbox, for (dim <- 0 to 2) yield FFTBoxInfo.public % FFTBoxInfo.totalPts(dim))
+      fftboxOffset.map(context.addDeclaration(_))
+
+      val fftboxSize : Seq[Expression[IntType]] = for (dim <- 0 to 2) yield FFTBoxInfo.public % FFTBoxInfo.totalPts(dim)
+      context += new AllocateStatement(fftbox, fftboxSize)
+      context += new FunctionCallStatement(new FunctionCall(OnetepFunctions.basis_ket_start_wrt_fftbox,
+        fftboxOffset.map(new VarRef[IntType](_)) ++ fftboxSize))
+
+      var basisCopyParams : Seq[Expression[_]] = Nil
+      basisCopyParams :+= new VarRef[ArrayType[FloatType]](fftbox)
+      basisCopyParams ++= fftboxOffset.map(new VarRef[IntType](_))
+      basisCopyParams :+= tightbox
+      basisCopyParams :+= new VarRef[ArrayType[FloatType]](parent.data)
+      basisCopyParams :+= sphere
+
+      context += new FunctionCallStatement(new FunctionCall(OnetepFunctions.basis_copy_function_to_fftbox, basisCopyParams))
     }
 
     def teardown(context: GenerationContext) {
@@ -119,6 +147,8 @@ class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedInde
     }
   }
 
+  def getSphereIndex = indices.head
+
   def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : FieldFragment =
-    new LocalFragment(this)
+    new LocalFragment(this, indices)
 }