]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Generate variable declarations.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Sun, 8 Apr 2012 02:14:49 +0000 (03:14 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Sun, 8 Apr 2012 02:14:49 +0000 (03:14 +0100)
src/ofc/codegen/FortranGenerator.scala
src/ofc/codegen/Type.scala
src/ofc/generators/onetep/CodeGenerator.scala
src/ofc/generators/onetep/OnetepTypes.scala

index fe92476ec11d09b4a5ae7134027a8091d7c002cf..5eca5e657866a6387bcc6c2e5bf058ec830b8b87 100644 (file)
@@ -9,8 +9,8 @@ class SymbolManager {
     def getName = name
   }
 
-  val symbols = mutable.Map[VarSymbol[_], SymbolInfo]()
-  val names = mutable.Set[String]()
+  private val symbols = mutable.Map[VarSymbol[_ <: Type], SymbolInfo]()
+  private val names = mutable.Set[String]()
 
   private def createNewName(sym: VarSymbol[_]) : String = {
     @tailrec
@@ -39,18 +39,23 @@ class SymbolManager {
     }
   }
 
-  def getName(sym: VarSymbol[_]) =
+  def getName(sym: VarSymbol[_ <: Type]) =
     symbols.get(sym) match {
       case None => throw new LogicError("Unknown symbol (missing declaration?): "+sym.toString)
       case Some(info) => info.getName
     }
+
+  def getDeclarations : Seq[String] = {
+    for ((sym, info) <- symbols) yield
+     sym.getType.getFortranAttributes.mkString(", ") + " :: " + info.getName
+  }.toSeq.sorted
 }
 
 class FortranGenerator {
-  var indentLevel = 0
-  val maxPrec = 30
-  val symbolManager = new SymbolManager
-  val buffer = scala.collection.mutable.Buffer[String]()
+  private var indentLevel = 0
+  private val maxPrec = 30
+  private val symbolManager = new SymbolManager
+  private val buffer = scala.collection.mutable.Buffer[String]()
 
   object BinaryOpInfo {
     sealed abstract class Associativity
@@ -64,8 +69,15 @@ class FortranGenerator {
     override def toString = exp
   }
 
+  def apply(stat: Statement) : String = {
+    processStatement(stat)
+
+    buffer.prepend("\n")
+    buffer.prependAll(symbolManager.getDeclarations)
+    buffer.mkString("\n")
+  }
 
-  def processStatement(stat: Statement) : String = {
+  private def processStatement(stat: Statement) {
     stat match {
       case (x : NullStatement) => ()
       case (x : Comment) => addLine("!" + x.getValue)
@@ -75,8 +87,6 @@ class FortranGenerator {
       case (a : Assignment) => processAssignment(a)
       case x => throw new UnimplementedException("Unknown statement type in FORTRAN generator: " + x.toString)
     }
-
-    buffer.mkString("\n")
   }
 
   private def in() {
index 76db3723283f35f69008f084ccc179c64b51a1d1..4d07aec5d22d5717c6dd6f5e6b02362216e77afb 100644 (file)
@@ -1,21 +1,33 @@
 package ofc.codegen
 
-sealed abstract class Type
+sealed abstract class Type {
+  def getFortranAttributes : Set[String]
+}
 sealed abstract class PrimitiveType extends Type
 
 // These are case classes solely for the comparison operators
-final case class IntType() extends PrimitiveType
-final case class FloatType() extends PrimitiveType
-final case class BoolType() extends PrimitiveType
+final case class IntType() extends PrimitiveType {
+  def getFortranAttributes = Set("integer")
+}
+
+final case class FloatType() extends PrimitiveType {
+  def getFortranAttributes = Set("real(kind=DP")
+}
+
+final case class BoolType() extends PrimitiveType {
+  def getFortranAttributes = Set("logical")
+}
 
 final case class ArrayType[ElementType <: Type](rank: Int, eType: ElementType) extends Type {
   def this(rank: Int)(implicit builder: TypeBuilder[ElementType]) = this(rank, builder())
   def getElementType = eType
+  def getFortranAttributes = eType.getFortranAttributes ++ Set("allocatable", (":"*rank).mkString("dimension(",",",")"))
 }
 
 final case class PointerType[TargetType <: Type](tType: TargetType) extends Type {
   def this()(implicit builder: TypeBuilder[TargetType]) = this(builder())
   def getTargetType = tType
+  def getFortranAttributes = tType.getFortranAttributes + "pointer"
 }
 
 abstract class StructType extends Type
index 23f24ca064d5af307cbda365d8aed69eb64a305a..262ac1750d5c7f6a30d1e40eb062e0fb0f3ee46b 100644 (file)
@@ -69,7 +69,7 @@ class CodeGenerator {
     }
 
     val fortranGenerator = new FortranGenerator
-    val code = fortranGenerator.processStatement(statements)
+    val code = fortranGenerator(statements)
     println(code)
   }
 }
index 8038609670bd66958115f647a350ecd014a0d011..9770d65c27c63a6022d57bd5b3b762058f7ddf88 100644 (file)
@@ -19,11 +19,14 @@ object OnetepTypes {
       val fieldType = new PointerType[ArrayType[StructType]](new ArrayType(1, TightBox))
       new FieldSymbol[PointerType[ArrayType[StructType]]]("tight_boxes", fieldType)
     }
+
+    def getFortranAttributes = Set("type(FUNC_BASIS)")
   }
 
   object CellInfo extends StructType {
     val ppdWidth = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("n_pt"+dim)}.toSeq
     val numPPDs = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("n_ppds_a"+dim)}.toSeq
+    def getFortranAttributes = Set("type(CELL_INFO)")
   }
 
   object TightBox extends StructType {
@@ -31,5 +34,6 @@ object OnetepTypes {
     val finishPts = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("finish_pts"+dim)}.toSeq
     val startPPD = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("start_ppd"+dim)}.toSeq
     val finishPPD = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("finish_ppd"+dim)}.toSeq
+    def getFortranAttributes = Set("type(FUNCTION_TIGHT_BOX)")
   }
 }