From afc356bc0765b5208219997086d14727177f7315 Mon Sep 17 00:00:00 2001 From: Francis Russell Date: Wed, 2 May 2012 16:13:45 +0100 Subject: [PATCH] Generate code to transform into reciprocal space. --- src/ofc/codegen/Expression.scala | 5 +++++ src/ofc/codegen/FortranGenerator.scala | 1 + src/ofc/codegen/Type.scala | 12 ++++++++++-- src/ofc/generators/onetep/OnetepFunctions.scala | 7 +++++++ src/ofc/generators/onetep/PPDFunctionSet.scala | 10 +++++++++- 5 files changed, 32 insertions(+), 3 deletions(-) diff --git a/src/ofc/codegen/Expression.scala b/src/ofc/codegen/Expression.scala index 1e3e9ff..5111d7b 100644 --- a/src/ofc/codegen/Expression.scala +++ b/src/ofc/codegen/Expression.scala @@ -114,3 +114,8 @@ class FloatLiteral(value: Double) extends Expression[FloatType] with LeafExpress def getValue = value def getType = new FloatType } + +class CharLiteral(value: Char) extends Expression[CharType] with LeafExpression { + def getValue = value + def getType = new CharType +} diff --git a/src/ofc/codegen/FortranGenerator.scala b/src/ofc/codegen/FortranGenerator.scala index 69b2568..2f8f8b9 100644 --- a/src/ofc/codegen/FortranGenerator.scala +++ b/src/ofc/codegen/FortranGenerator.scala @@ -163,6 +163,7 @@ class FortranGenerator { expression match { case (i : IntegerLiteral) => ExpHolder(maxPrec, i.getValue.toString) case (i : FloatLiteral) => ExpHolder(maxPrec, i.getValue.toString) + case (i : CharLiteral) => ExpHolder(maxPrec, "'%s'".format(i.getValue.toString)) case (a : FieldAccess[_]) => ExpHolder(maxPrec, "%s%%%s".format(buildExpression(a.getStructExpression), a.getField.getName)) case (r : VarRef[_]) => r.getSymbol match { case (s: NamedUnboundVarSymbol[_]) => ExpHolder(maxPrec, s.getName) diff --git a/src/ofc/codegen/Type.scala b/src/ofc/codegen/Type.scala index a9fe75f..cf94a6e 100644 --- a/src/ofc/codegen/Type.scala +++ b/src/ofc/codegen/Type.scala @@ -23,6 +23,14 @@ final case class VoidType() extends PrimitiveType { def getFortranAttributes = throw new LogicError("void type does not exist in Fortran.") } +final case class CharType() extends PrimitiveType { + def getFortranAttributes = Set("character") +} + +final case class ComplexType() extends PrimitiveType { + def getFortranAttributes = Set("complex(kind=DP)") +} + final case class ArrayType[ElementType <: Type](rank: Int, eType: ElementType) extends Type { def this(rank: Int)(implicit builder: TypeBuilder[ElementType]) = this(rank, builder()) def getElementType = eType @@ -43,8 +51,10 @@ trait TypeBuilder[T <: Type] { } object TypeBuilder { + implicit val charBuilder = new TypeBuilder[CharType] { def apply() = new CharType } implicit val intBuilder = new TypeBuilder[IntType] { def apply() = new IntType } implicit val floatBuilder = new TypeBuilder[FloatType] { def apply() = new FloatType } + implicit val complexBuilder = new TypeBuilder[ComplexType] { def apply() = new ComplexType } implicit val boolBuilder = new TypeBuilder[BoolType] { def apply() = new BoolType } } @@ -56,5 +66,3 @@ object HasProperty { implicit val intNumeric = new HasProperty[IntType, Numeric]() implicit val floatNumeric = new HasProperty[FloatType, Numeric]() } - - diff --git a/src/ofc/generators/onetep/OnetepFunctions.scala b/src/ofc/generators/onetep/OnetepFunctions.scala index 0e83e04..c056ebb 100644 --- a/src/ofc/generators/onetep/OnetepFunctions.scala +++ b/src/ofc/generators/onetep/OnetepFunctions.scala @@ -18,4 +18,11 @@ object OnetepFunctions { ("n1", new IntType), ("n2", new IntType), ("n3", new IntType))) + + val fourier_apply_box_pair = new FortranSubroutineSignature("fourier_apply_box_pair", + Seq(("grid", new CharType), + ("dir", new CharType), + ("rspc1", new ArrayType[FloatType](3)), + ("rspc2", new ArrayType[FloatType](3)), + ("gspc", new ArrayType[ComplexType](3)))) } diff --git a/src/ofc/generators/onetep/PPDFunctionSet.scala b/src/ofc/generators/onetep/PPDFunctionSet.scala index cf27b21..3bf5480 100644 --- a/src/ofc/generators/onetep/PPDFunctionSet.scala +++ b/src/ofc/generators/onetep/PPDFunctionSet.scala @@ -120,15 +120,18 @@ class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedInde val tightbox = (~(basis % FunctionBasis.tightBoxes)).readAt(sphereIndex) val sphere = (~(basis % FunctionBasis.spheres)).readAt(sphereIndex) val fftboxOffset = for(dim <- 0 to 2) yield new DeclaredVarSymbol[IntType]("fftbox_offset"+(dim+1)) + val reciprocalBox = new DeclaredVarSymbol[ArrayType[ComplexType]]("reciprocal_box", new ArrayType[ComplexType](3)) def setup(context: GenerationContext) { import OnetepTypes.FFTBoxInfo context.addDeclaration(fftbox) + context.addDeclaration(reciprocalBox) fftboxOffset.map(context.addDeclaration(_)) val fftboxSize : Seq[Expression[IntType]] = for (dim <- 0 to 2) yield FFTBoxInfo.public % FFTBoxInfo.totalPts(dim) context += new AllocateStatement(fftbox, fftboxSize) + context += new AllocateStatement(reciprocalBox, fftboxSize) context += new FunctionCallStatement(new FunctionCall(OnetepFunctions.basis_ket_start_wrt_fftbox, fftboxOffset.map(new VarRef[IntType](_)) ++ fftboxSize)) @@ -140,10 +143,15 @@ class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedInde basisCopyParams :+= sphere context += new FunctionCallStatement(new FunctionCall(OnetepFunctions.basis_copy_function_to_fftbox, basisCopyParams)) + + val fourierParams : Seq[Expression[_]] = Seq(new CharLiteral('C'), new CharLiteral('F'), fftbox, fftbox, reciprocalBox) + context += new FunctionCallStatement(new FunctionCall(OnetepFunctions.fourier_apply_box_pair, fourierParams)) + + context += new DeallocateStatement(fftbox) } def teardown(context: GenerationContext) { - context += new DeallocateStatement(fftbox) + context += new DeallocateStatement(reciprocalBox) } } -- 2.47.3