names += name
symbols += s -> new SymbolInfo(name)
} else {
- throw new LogicError("Attempted to add symbol more than once (multiple declarations?): "+sym.toString)
+ throw new LogicError("Attempted to add symbol more than once (multiple declarations?): "+sym.getName)
}
case (_: NamedUnboundVarSymbol[_]) => throw new LogicError("Attempted to add unbound symbol to SymbolManager.")
def getName(sym: VarSymbol[_ <: Type]) =
symbols.get(sym) match {
- case None => throw new LogicError("Unknown symbol (missing declaration?): "+sym.toString)
+ case None => throw new LogicError("Unknown symbol (missing declaration?): "+sym.getName)
case Some(info) => info.getName
}
case (i : IfStatement) => processIf(i)
case (a : AllocateStatement) => processAllocate(a)
case (d : DeallocateStatement) => processDeallocate(d)
+ case (f : FunctionCallStatement) => processFunctionCallStatement(f)
case x => throw new UnimplementedException("Unknown statement type in FORTRAN generator: " + x.toString)
}
}
addLine(footer)
}
+ private def processFunctionCallStatement(stat: FunctionCallStatement) {
+ val call = stat.getCall
+ call.getSignature match {
+ case (fortSub: FortranSubroutineSignature) =>
+ addLine("call %s(%s)".format(fortSub.getName, call.getParams.map(buildExpression(_)).mkString(", ")))
+ case _ => throw new LogicError("Fortran generator only knows how to call Fortran functions.")
+ }
+ }
+
private def processScope(scope: ScopeStatement) {
for (sym <- scope.getDeclarations) {
symbolManager.addSymbol(sym)
--- /dev/null
+package ofc.codegen
+import ofc.LogicError
+
+class FunctionCall[R <: Type](signature: FunctionSignature[R], params: Seq[Expression[_]]) extends Expression[R] {
+ if (params.size != signature.getParams.size)
+ throw new LogicError("Function "+signature.getName+" called with incorrect number of parameters.")
+
+ for((param, (name, paramType)) <- params.zip(signature.getParams))
+ if (param.getType != paramType)
+ throw new LogicError("Type mismatch on parameter "+name+" when calling function "+signature.getName)
+
+ def getType = signature.getReturnType
+ def foreach[U](f: Expression[_] => U) = params.foreach(f)
+ def getSignature = signature
+ def getParams = params
+}
--- /dev/null
+package ofc.codegen
+
+class FunctionCallStatement(call: FunctionCall[VoidType]) extends Statement {
+ def getCall : FunctionCall[VoidType] = call
+}
trait FunctionSignature[R <: Type] {
def getName: String
+ def getReturnType: R
+ def getParams: Seq[(String, Type)]
}
-class FortranFunctionSignature[R <: Type](name: String,
- returnType: R, params: Seq[(String, Type)]) extends FunctionSignature[R] {
+class FortranSubroutineSignature(name: String,
+ params: Seq[(String, Type)]) extends FunctionSignature[VoidType] {
def getName = name
+ def getReturnType = new VoidType
+ def getParams = params
}
package ofc.codegen
+import ofc.LogicError
sealed abstract class Type {
def getFortranAttributes : Set[String]
def getFortranAttributes = Set("logical")
}
+final case class VoidType() extends PrimitiveType {
+ def getFortranAttributes = throw new LogicError("void type does not exist in Fortran.")
+}
+
final case class ArrayType[ElementType <: Type](rank: Int, eType: ElementType) extends Type {
def this(rank: Int)(implicit builder: TypeBuilder[ElementType]) = this(rank, builder())
def getElementType = eType
import ofc.codegen._
class CodeGenerator(dictionary: Dictionary) {
- val indexMap : Map[NamedIndex, Expression[IntType]] = {
+ val indexSyms : Map[NamedIndex, DeclaredVarSymbol[IntType]] = {
for(index <- dictionary.getIndices) yield
- (index, new VarRef[IntType](new DeclaredVarSymbol[IntType](index.getName)))
+ (index, new DeclaredVarSymbol[IntType](index.getName))
+ }.toMap
+
+ val indexMap : Map[NamedIndex, Expression[IntType]] = {
+ for((index, sym) <- indexSyms) yield
+ (index, new VarRef[IntType](sym))
}.toMap
class Context extends GenerationContext {
val rhs = assignment.rhs
val context = new Context
+
+ //TODO: Remove me when symbols are created properly
+ for((index, sym) <- indexSyms)
+ context.addDeclaration(sym)
+
val rhsFragment = rhs.getFragment(indexMap)
rhsFragment.setup(context)
--- /dev/null
+package ofc.generators.onetep
+import ofc.codegen._
+
+object OnetepFunctions {
+ val basis_copy_function_to_fftbox = new FortranSubroutineSignature("basis_copy_function_to_fftbox",
+ Seq(("fa_fftbox", new ArrayType[FloatType](3)),
+ ("fa_start1", new IntType),
+ ("fa_start2", new IntType),
+ ("fa_start3", new IntType),
+ ("fa_tightbox", OnetepTypes.TightBox),
+ ("fa_on_grid", new ArrayType[FloatType](1)),
+ ("fa_sphere", OnetepTypes.Sphere)))
+
+ val basis_ket_start_wrt_fftbox = new FortranSubroutineSignature("basis_ket_start_wrt_fftbox",
+ Seq(("row_start1", new IntType),
+ ("row_start1", new IntType),
+ ("row_start3", new IntType),
+ ("n1", new IntType),
+ ("n2", new IntType),
+ ("n3", new IntType)))
+}
package ofc.generators.onetep
import ofc.codegen._
+import ofc.LogicError
/*
object PPDFunctionSet {
private class SphereIndex(name: String, value: Expression[IntType]) extends DiscreteIndex {
*/
class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedIndex]) extends Field {
- class LocalFragment(parent: PPDFunctionSet) extends PsincFragment {
+ val basis = new NamedUnboundVarSymbol[StructType](basisName, OnetepTypes.FunctionBasis)
+ val data = new NamedUnboundVarSymbol[ArrayType[FloatType]](dataName, new ArrayType[FloatType](1))
+
+ class LocalFragment(parent: PPDFunctionSet, indices: Map[NamedIndex, Expression[IntType]]) extends PsincFragment {
def setup(context: GenerationContext) {}
def teardown(context: GenerationContext) {}
- def toReciprocal : ReciprocalFragment = new LocalReciprocal(parent)
+ def toReciprocal : ReciprocalFragment = new LocalReciprocal(parent, indices)
}
- class LocalReciprocal(parent: PPDFunctionSet) extends ReciprocalFragment {
+ class LocalReciprocal(parent: PPDFunctionSet, indices: Map[NamedIndex, Expression[IntType]]) extends ReciprocalFragment {
+ import OnetepTypes.FunctionBasis
+
+ val sphereIndex = indices.get(parent.getSphereIndex) match {
+ case Some(expression) => expression
+ case None => throw new LogicError("Cannot find expression for index "+parent.getSphereIndex)
+ }
+
val fftbox = new DeclaredVarSymbol[ArrayType[FloatType]]("fftbox", new ArrayType[FloatType](3))
+ val tightbox = (~(basis % FunctionBasis.tightBoxes)).readAt(sphereIndex)
+ val sphere = (~(basis % FunctionBasis.spheres)).readAt(sphereIndex)
+ val fftboxOffset = for(dim <- 0 to 2) yield new DeclaredVarSymbol[IntType]("fftbox_offset"+(dim+1))
def setup(context: GenerationContext) {
import OnetepTypes.FFTBoxInfo
context.addDeclaration(fftbox)
- context += new AllocateStatement(fftbox, for (dim <- 0 to 2) yield FFTBoxInfo.public % FFTBoxInfo.totalPts(dim))
+ fftboxOffset.map(context.addDeclaration(_))
+
+ val fftboxSize : Seq[Expression[IntType]] = for (dim <- 0 to 2) yield FFTBoxInfo.public % FFTBoxInfo.totalPts(dim)
+ context += new AllocateStatement(fftbox, fftboxSize)
+ context += new FunctionCallStatement(new FunctionCall(OnetepFunctions.basis_ket_start_wrt_fftbox,
+ fftboxOffset.map(new VarRef[IntType](_)) ++ fftboxSize))
+
+ var basisCopyParams : Seq[Expression[_]] = Nil
+ basisCopyParams :+= new VarRef[ArrayType[FloatType]](fftbox)
+ basisCopyParams ++= fftboxOffset.map(new VarRef[IntType](_))
+ basisCopyParams :+= tightbox
+ basisCopyParams :+= new VarRef[ArrayType[FloatType]](parent.data)
+ basisCopyParams :+= sphere
+
+ context += new FunctionCallStatement(new FunctionCall(OnetepFunctions.basis_copy_function_to_fftbox, basisCopyParams))
}
def teardown(context: GenerationContext) {
}
}
+ def getSphereIndex = indices.head
+
def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : FieldFragment =
- new LocalFragment(this)
+ new LocalFragment(this, indices)
}