From 520fdb0c8a13c0b069983f69977da13e49e5bbbf Mon Sep 17 00:00:00 2001 From: Francis Russell Date: Sun, 8 Apr 2012 03:14:49 +0100 Subject: [PATCH] Generate variable declarations. --- src/ofc/codegen/FortranGenerator.scala | 30 ++++++++++++------- src/ofc/codegen/Type.scala | 20 ++++++++++--- src/ofc/generators/onetep/CodeGenerator.scala | 2 +- src/ofc/generators/onetep/OnetepTypes.scala | 4 +++ 4 files changed, 41 insertions(+), 15 deletions(-) diff --git a/src/ofc/codegen/FortranGenerator.scala b/src/ofc/codegen/FortranGenerator.scala index fe92476..5eca5e6 100644 --- a/src/ofc/codegen/FortranGenerator.scala +++ b/src/ofc/codegen/FortranGenerator.scala @@ -9,8 +9,8 @@ class SymbolManager { def getName = name } - val symbols = mutable.Map[VarSymbol[_], SymbolInfo]() - val names = mutable.Set[String]() + private val symbols = mutable.Map[VarSymbol[_ <: Type], SymbolInfo]() + private val names = mutable.Set[String]() private def createNewName(sym: VarSymbol[_]) : String = { @tailrec @@ -39,18 +39,23 @@ class SymbolManager { } } - def getName(sym: VarSymbol[_]) = + def getName(sym: VarSymbol[_ <: Type]) = symbols.get(sym) match { case None => throw new LogicError("Unknown symbol (missing declaration?): "+sym.toString) case Some(info) => info.getName } + + def getDeclarations : Seq[String] = { + for ((sym, info) <- symbols) yield + sym.getType.getFortranAttributes.mkString(", ") + " :: " + info.getName + }.toSeq.sorted } class FortranGenerator { - var indentLevel = 0 - val maxPrec = 30 - val symbolManager = new SymbolManager - val buffer = scala.collection.mutable.Buffer[String]() + private var indentLevel = 0 + private val maxPrec = 30 + private val symbolManager = new SymbolManager + private val buffer = scala.collection.mutable.Buffer[String]() object BinaryOpInfo { sealed abstract class Associativity @@ -64,8 +69,15 @@ class FortranGenerator { override def toString = exp } + def apply(stat: Statement) : String = { + processStatement(stat) + + buffer.prepend("\n") + buffer.prependAll(symbolManager.getDeclarations) + buffer.mkString("\n") + } - def processStatement(stat: Statement) : String = { + private def processStatement(stat: Statement) { stat match { case (x : NullStatement) => () case (x : Comment) => addLine("!" + x.getValue) @@ -75,8 +87,6 @@ class FortranGenerator { case (a : Assignment) => processAssignment(a) case x => throw new UnimplementedException("Unknown statement type in FORTRAN generator: " + x.toString) } - - buffer.mkString("\n") } private def in() { diff --git a/src/ofc/codegen/Type.scala b/src/ofc/codegen/Type.scala index 76db372..4d07aec 100644 --- a/src/ofc/codegen/Type.scala +++ b/src/ofc/codegen/Type.scala @@ -1,21 +1,33 @@ package ofc.codegen -sealed abstract class Type +sealed abstract class Type { + def getFortranAttributes : Set[String] +} sealed abstract class PrimitiveType extends Type // These are case classes solely for the comparison operators -final case class IntType() extends PrimitiveType -final case class FloatType() extends PrimitiveType -final case class BoolType() extends PrimitiveType +final case class IntType() extends PrimitiveType { + def getFortranAttributes = Set("integer") +} + +final case class FloatType() extends PrimitiveType { + def getFortranAttributes = Set("real(kind=DP") +} + +final case class BoolType() extends PrimitiveType { + def getFortranAttributes = Set("logical") +} final case class ArrayType[ElementType <: Type](rank: Int, eType: ElementType) extends Type { def this(rank: Int)(implicit builder: TypeBuilder[ElementType]) = this(rank, builder()) def getElementType = eType + def getFortranAttributes = eType.getFortranAttributes ++ Set("allocatable", (":"*rank).mkString("dimension(",",",")")) } final case class PointerType[TargetType <: Type](tType: TargetType) extends Type { def this()(implicit builder: TypeBuilder[TargetType]) = this(builder()) def getTargetType = tType + def getFortranAttributes = tType.getFortranAttributes + "pointer" } abstract class StructType extends Type diff --git a/src/ofc/generators/onetep/CodeGenerator.scala b/src/ofc/generators/onetep/CodeGenerator.scala index 23f24ca..262ac17 100644 --- a/src/ofc/generators/onetep/CodeGenerator.scala +++ b/src/ofc/generators/onetep/CodeGenerator.scala @@ -69,7 +69,7 @@ class CodeGenerator { } val fortranGenerator = new FortranGenerator - val code = fortranGenerator.processStatement(statements) + val code = fortranGenerator(statements) println(code) } } diff --git a/src/ofc/generators/onetep/OnetepTypes.scala b/src/ofc/generators/onetep/OnetepTypes.scala index 8038609..9770d65 100644 --- a/src/ofc/generators/onetep/OnetepTypes.scala +++ b/src/ofc/generators/onetep/OnetepTypes.scala @@ -19,11 +19,14 @@ object OnetepTypes { val fieldType = new PointerType[ArrayType[StructType]](new ArrayType(1, TightBox)) new FieldSymbol[PointerType[ArrayType[StructType]]]("tight_boxes", fieldType) } + + def getFortranAttributes = Set("type(FUNC_BASIS)") } object CellInfo extends StructType { val ppdWidth = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("n_pt"+dim)}.toSeq val numPPDs = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("n_ppds_a"+dim)}.toSeq + def getFortranAttributes = Set("type(CELL_INFO)") } object TightBox extends StructType { @@ -31,5 +34,6 @@ object OnetepTypes { val finishPts = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("finish_pts"+dim)}.toSeq val startPPD = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("start_ppd"+dim)}.toSeq val finishPPD = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("finish_ppd"+dim)}.toSeq + def getFortranAttributes = Set("type(FUNCTION_TIGHT_BOX)") } } -- 2.47.3