From: Francis Russell Date: Tue, 1 May 2012 17:36:26 +0000 (+0100) Subject: Generate calls for copying PPD data to FFT-box. X-Git-Url: https://git.unchartedbackwaters.co.uk/w/?a=commitdiff_plain;h=ec1fba551c3a13746b71becb8cb0673adf97e2ed;p=francis%2Fofc.git Generate calls for copying PPD data to FFT-box. --- diff --git a/src/ofc/codegen/FortranGenerator.scala b/src/ofc/codegen/FortranGenerator.scala index 080171d..69b2568 100644 --- a/src/ofc/codegen/FortranGenerator.scala +++ b/src/ofc/codegen/FortranGenerator.scala @@ -32,7 +32,7 @@ class SymbolManager { names += name symbols += s -> new SymbolInfo(name) } else { - throw new LogicError("Attempted to add symbol more than once (multiple declarations?): "+sym.toString) + throw new LogicError("Attempted to add symbol more than once (multiple declarations?): "+sym.getName) } case (_: NamedUnboundVarSymbol[_]) => throw new LogicError("Attempted to add unbound symbol to SymbolManager.") @@ -41,7 +41,7 @@ class SymbolManager { def getName(sym: VarSymbol[_ <: Type]) = symbols.get(sym) match { - case None => throw new LogicError("Unknown symbol (missing declaration?): "+sym.toString) + case None => throw new LogicError("Unknown symbol (missing declaration?): "+sym.getName) case Some(info) => info.getName } @@ -145,6 +145,7 @@ class FortranGenerator { case (i : IfStatement) => processIf(i) case (a : AllocateStatement) => processAllocate(a) case (d : DeallocateStatement) => processDeallocate(d) + case (f : FunctionCallStatement) => processFunctionCallStatement(f) case x => throw new UnimplementedException("Unknown statement type in FORTRAN generator: " + x.toString) } } @@ -235,6 +236,15 @@ class FortranGenerator { addLine(footer) } + private def processFunctionCallStatement(stat: FunctionCallStatement) { + val call = stat.getCall + call.getSignature match { + case (fortSub: FortranSubroutineSignature) => + addLine("call %s(%s)".format(fortSub.getName, call.getParams.map(buildExpression(_)).mkString(", "))) + case _ => throw new LogicError("Fortran generator only knows how to call Fortran functions.") + } + } + private def processScope(scope: ScopeStatement) { for (sym <- scope.getDeclarations) { symbolManager.addSymbol(sym) diff --git a/src/ofc/codegen/FunctionCall.scala b/src/ofc/codegen/FunctionCall.scala new file mode 100644 index 0000000..f5925f9 --- /dev/null +++ b/src/ofc/codegen/FunctionCall.scala @@ -0,0 +1,16 @@ +package ofc.codegen +import ofc.LogicError + +class FunctionCall[R <: Type](signature: FunctionSignature[R], params: Seq[Expression[_]]) extends Expression[R] { + if (params.size != signature.getParams.size) + throw new LogicError("Function "+signature.getName+" called with incorrect number of parameters.") + + for((param, (name, paramType)) <- params.zip(signature.getParams)) + if (param.getType != paramType) + throw new LogicError("Type mismatch on parameter "+name+" when calling function "+signature.getName) + + def getType = signature.getReturnType + def foreach[U](f: Expression[_] => U) = params.foreach(f) + def getSignature = signature + def getParams = params +} diff --git a/src/ofc/codegen/FunctionCallStatement.scala b/src/ofc/codegen/FunctionCallStatement.scala new file mode 100644 index 0000000..f4c7fd6 --- /dev/null +++ b/src/ofc/codegen/FunctionCallStatement.scala @@ -0,0 +1,5 @@ +package ofc.codegen + +class FunctionCallStatement(call: FunctionCall[VoidType]) extends Statement { + def getCall : FunctionCall[VoidType] = call +} diff --git a/src/ofc/codegen/FunctionSignature.scala b/src/ofc/codegen/FunctionSignature.scala index 427c7b5..b27d8a8 100644 --- a/src/ofc/codegen/FunctionSignature.scala +++ b/src/ofc/codegen/FunctionSignature.scala @@ -2,10 +2,14 @@ package ofc.codegen trait FunctionSignature[R <: Type] { def getName: String + def getReturnType: R + def getParams: Seq[(String, Type)] } -class FortranFunctionSignature[R <: Type](name: String, - returnType: R, params: Seq[(String, Type)]) extends FunctionSignature[R] { +class FortranSubroutineSignature(name: String, + params: Seq[(String, Type)]) extends FunctionSignature[VoidType] { def getName = name + def getReturnType = new VoidType + def getParams = params } diff --git a/src/ofc/codegen/Type.scala b/src/ofc/codegen/Type.scala index fbcfb90..a9fe75f 100644 --- a/src/ofc/codegen/Type.scala +++ b/src/ofc/codegen/Type.scala @@ -1,4 +1,5 @@ package ofc.codegen +import ofc.LogicError sealed abstract class Type { def getFortranAttributes : Set[String] @@ -18,6 +19,10 @@ final case class BoolType() extends PrimitiveType { def getFortranAttributes = Set("logical") } +final case class VoidType() extends PrimitiveType { + def getFortranAttributes = throw new LogicError("void type does not exist in Fortran.") +} + final case class ArrayType[ElementType <: Type](rank: Int, eType: ElementType) extends Type { def this(rank: Int)(implicit builder: TypeBuilder[ElementType]) = this(rank, builder()) def getElementType = eType diff --git a/src/ofc/generators/onetep/CodeGenerator.scala b/src/ofc/generators/onetep/CodeGenerator.scala index 6b429ed..a61f87f 100644 --- a/src/ofc/generators/onetep/CodeGenerator.scala +++ b/src/ofc/generators/onetep/CodeGenerator.scala @@ -2,9 +2,14 @@ package ofc.generators.onetep import ofc.codegen._ class CodeGenerator(dictionary: Dictionary) { - val indexMap : Map[NamedIndex, Expression[IntType]] = { + val indexSyms : Map[NamedIndex, DeclaredVarSymbol[IntType]] = { for(index <- dictionary.getIndices) yield - (index, new VarRef[IntType](new DeclaredVarSymbol[IntType](index.getName))) + (index, new DeclaredVarSymbol[IntType](index.getName)) + }.toMap + + val indexMap : Map[NamedIndex, Expression[IntType]] = { + for((index, sym) <- indexSyms) yield + (index, new VarRef[IntType](sym)) }.toMap class Context extends GenerationContext { @@ -26,6 +31,11 @@ class CodeGenerator(dictionary: Dictionary) { val rhs = assignment.rhs val context = new Context + + //TODO: Remove me when symbols are created properly + for((index, sym) <- indexSyms) + context.addDeclaration(sym) + val rhsFragment = rhs.getFragment(indexMap) rhsFragment.setup(context) diff --git a/src/ofc/generators/onetep/OnetepFunctions.scala b/src/ofc/generators/onetep/OnetepFunctions.scala new file mode 100644 index 0000000..0e83e04 --- /dev/null +++ b/src/ofc/generators/onetep/OnetepFunctions.scala @@ -0,0 +1,21 @@ +package ofc.generators.onetep +import ofc.codegen._ + +object OnetepFunctions { + val basis_copy_function_to_fftbox = new FortranSubroutineSignature("basis_copy_function_to_fftbox", + Seq(("fa_fftbox", new ArrayType[FloatType](3)), + ("fa_start1", new IntType), + ("fa_start2", new IntType), + ("fa_start3", new IntType), + ("fa_tightbox", OnetepTypes.TightBox), + ("fa_on_grid", new ArrayType[FloatType](1)), + ("fa_sphere", OnetepTypes.Sphere))) + + val basis_ket_start_wrt_fftbox = new FortranSubroutineSignature("basis_ket_start_wrt_fftbox", + Seq(("row_start1", new IntType), + ("row_start1", new IntType), + ("row_start3", new IntType), + ("n1", new IntType), + ("n2", new IntType), + ("n3", new IntType))) +} diff --git a/src/ofc/generators/onetep/PPDFunctionSet.scala b/src/ofc/generators/onetep/PPDFunctionSet.scala index 176e5bd..cf27b21 100644 --- a/src/ofc/generators/onetep/PPDFunctionSet.scala +++ b/src/ofc/generators/onetep/PPDFunctionSet.scala @@ -1,5 +1,6 @@ package ofc.generators.onetep import ofc.codegen._ +import ofc.LogicError /* object PPDFunctionSet { private class SphereIndex(name: String, value: Expression[IntType]) extends DiscreteIndex { @@ -98,20 +99,47 @@ object PPDFunctionSet { */ class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedIndex]) extends Field { - class LocalFragment(parent: PPDFunctionSet) extends PsincFragment { + val basis = new NamedUnboundVarSymbol[StructType](basisName, OnetepTypes.FunctionBasis) + val data = new NamedUnboundVarSymbol[ArrayType[FloatType]](dataName, new ArrayType[FloatType](1)) + + class LocalFragment(parent: PPDFunctionSet, indices: Map[NamedIndex, Expression[IntType]]) extends PsincFragment { def setup(context: GenerationContext) {} def teardown(context: GenerationContext) {} - def toReciprocal : ReciprocalFragment = new LocalReciprocal(parent) + def toReciprocal : ReciprocalFragment = new LocalReciprocal(parent, indices) } - class LocalReciprocal(parent: PPDFunctionSet) extends ReciprocalFragment { + class LocalReciprocal(parent: PPDFunctionSet, indices: Map[NamedIndex, Expression[IntType]]) extends ReciprocalFragment { + import OnetepTypes.FunctionBasis + + val sphereIndex = indices.get(parent.getSphereIndex) match { + case Some(expression) => expression + case None => throw new LogicError("Cannot find expression for index "+parent.getSphereIndex) + } + val fftbox = new DeclaredVarSymbol[ArrayType[FloatType]]("fftbox", new ArrayType[FloatType](3)) + val tightbox = (~(basis % FunctionBasis.tightBoxes)).readAt(sphereIndex) + val sphere = (~(basis % FunctionBasis.spheres)).readAt(sphereIndex) + val fftboxOffset = for(dim <- 0 to 2) yield new DeclaredVarSymbol[IntType]("fftbox_offset"+(dim+1)) def setup(context: GenerationContext) { import OnetepTypes.FFTBoxInfo context.addDeclaration(fftbox) - context += new AllocateStatement(fftbox, for (dim <- 0 to 2) yield FFTBoxInfo.public % FFTBoxInfo.totalPts(dim)) + fftboxOffset.map(context.addDeclaration(_)) + + val fftboxSize : Seq[Expression[IntType]] = for (dim <- 0 to 2) yield FFTBoxInfo.public % FFTBoxInfo.totalPts(dim) + context += new AllocateStatement(fftbox, fftboxSize) + context += new FunctionCallStatement(new FunctionCall(OnetepFunctions.basis_ket_start_wrt_fftbox, + fftboxOffset.map(new VarRef[IntType](_)) ++ fftboxSize)) + + var basisCopyParams : Seq[Expression[_]] = Nil + basisCopyParams :+= new VarRef[ArrayType[FloatType]](fftbox) + basisCopyParams ++= fftboxOffset.map(new VarRef[IntType](_)) + basisCopyParams :+= tightbox + basisCopyParams :+= new VarRef[ArrayType[FloatType]](parent.data) + basisCopyParams :+= sphere + + context += new FunctionCallStatement(new FunctionCall(OnetepFunctions.basis_copy_function_to_fftbox, basisCopyParams)) } def teardown(context: GenerationContext) { @@ -119,6 +147,8 @@ class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedInde } } + def getSphereIndex = indices.head + def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : FieldFragment = - new LocalFragment(this) + new LocalFragment(this, indices) }