]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Generate 'use' statements.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Thu, 7 Jun 2012 16:39:04 +0000 (17:39 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Thu, 7 Jun 2012 16:39:04 +0000 (17:39 +0100)
src/ofc/codegen/FortranGenerator.scala

index 13d39a94a4988c13f8cecd647f84ee2003ba9630..42e73796ae649c5116704401bf16a4cad85dec97 100644 (file)
@@ -9,7 +9,8 @@ class SymbolManager {
     def getName = name
   }
 
-  private val symbols = mutable.Map[VarSymbol[_ <: Type], SymbolInfo]()
+  private val symbols = mutable.Set[Symbol]()
+  private val declaredSymbols = mutable.Map[VarSymbol[_ <: Type], SymbolInfo]()
   private val names = mutable.Set[String]()
 
   private def createNewName(sym: VarSymbol[_]) : String = {
@@ -25,28 +26,40 @@ class SymbolManager {
     helper(sym, 1)
   }
 
-  def addSymbol(sym: VarSymbol[_ <: Type]) {
+  def addSymbol(sym: Symbol) {
+    symbols += sym
+  }
+
+  def addDeclaration(sym: VarSymbol[_ <: Type]) {
+    addSymbol(sym)
+
     sym match {
-      case (s: DeclaredVarSymbol[_]) => if (!symbols.contains(s)) {
+      case (s: DeclaredVarSymbol[_]) => if (!declaredSymbols.contains(s)) {
         val name = createNewName(s)
         names += name
-        symbols += s -> new SymbolInfo(name)
+        declaredSymbols += s -> new SymbolInfo(name)
+
+        //FIXME: This is a hacky way to detect structures we need to import
+        s.getType match {
+          case (structType: StructType) => addSymbol(structType)
+          case _ => ()
+        }
       } else {
         throw new LogicError("Attempted to add symbol more than once (multiple declarations?): "+sym.getName)
       }
 
-      case (_: NamedUnboundVarSymbol[_]) => throw new LogicError("Attempted to add unbound symbol to SymbolManager.")
+      case _ => throw new LogicError("Attempted to add declaration not of type DeclaredVarSymbol.")
     }
   }
 
   def getName(sym: VarSymbol[_ <: Type]) =
-    symbols.get(sym) match {
+    declaredSymbols.get(sym) match {
       case None => throw new LogicError("Unknown symbol (missing declaration?): "+sym.getName)
       case Some(info) => info.getName
     }
 
   def getDeclarations : Seq[String] = {
-    for ((sym, info) <- symbols) yield {
+    for ((sym, info) <- declaredSymbols) yield {
       var attributeStrings : Seq[String] = Nil
 
       // It seems these properties need to go after the type-related attributes
@@ -59,6 +72,24 @@ class SymbolManager {
       attributeStrings.mkString(", ") + " :: " + info.getName
     }
   }.toSeq.sorted
+
+  def getUses : Seq[String] = {
+    var uses : Map[String, mutable.Set[String]] = Map.empty
+
+    for (sym  <- symbols) {
+      for(property <- sym.getProperties) property match {
+        case FortranModule(name) => {
+          val imported = uses.getOrElse(name, mutable.Set.empty)
+          uses += name -> imported
+          imported += sym.getName
+        }
+        case _ => ()
+      }
+    }
+
+    for((moduleName, symbolNames) <- uses) yield
+      "use "+moduleName+", only: " + symbolNames.mkString(", ")
+  }.toSeq.sorted
 }
 
 object FortranGenerator {
@@ -128,9 +159,9 @@ object FortranGenerator {
 class FortranGenerator {
   import FortranGenerator.{maxPrec, BinaryOpInfo, getBinaryOpInfo}
 
-  private var indentLevel = 0
   private val symbolManager = new SymbolManager
   private val buffer = scala.collection.mutable.Buffer[String]()
+  private var indentLevel = 0
 
   case class ExpHolder(prec: Int, exp: String) {
     override def toString = exp
@@ -151,7 +182,7 @@ class FortranGenerator {
     prependLines(symbolManager.getDeclarations)
     prependLine("implicit none")
     prependLine("")
-    prependLine("!use statments will go here")
+    prependLines(symbolManager.getUses)
     out
 
     // parameters are only named *after* processing the body
@@ -203,8 +234,11 @@ class FortranGenerator {
       case (i : CharLiteral) => ExpHolder(maxPrec, "'%s'".format(i.getValue.toString))
       case (a : FieldAccess[_]) => ExpHolder(maxPrec, "%s%%%s".format(buildExpression(a.getStructExpression), a.getField.getName))
       case (r : VarRef[_]) => r.getSymbol match {
-        case (s: NamedUnboundVarSymbol[_]) => ExpHolder(maxPrec, s.getName)
-        case s => ExpHolder(maxPrec, symbolManager.getName(s))
+        case (s: DeclaredVarSymbol[_]) => ExpHolder(maxPrec, symbolManager.getName(s))
+        case (s: Symbol) => {
+          symbolManager.addSymbol(s)
+          ExpHolder(maxPrec, s.getName)
+        }
       }
       case (r: ArrayAccess[_]) => 
         ExpHolder(maxPrec, buildExpression(r.getArrayExpression) + r.getIndexExpressions.map(buildExpression(_)).mkString("(",", ",")"))
@@ -214,14 +248,17 @@ class FortranGenerator {
       case (c: NumericOperator[_]) => buildNumericOperator(c)
       case (c: Conversion[_,_]) => buildConversion(c)
       case (i: Intrinsic[_]) => buildIntrinsic(i)
-      case (f: FunctionCall[_]) => buildFunctionCall(f)
+      case (f: FunctionCall[_]) => {
+        symbolManager.addSymbol(f.getSignature)
+        buildFunctionCall(f)
+      }
       case x => throw new UnimplementedException("Unknown expression type in FORTRAN generator: " + x.toString)
     }
   }
 
   private def buildConditionalValue(conditional: ConditionalValue[_ <: Type]) : ExpHolder = {
     var symbol = new DeclaredVarSymbol[Type]("ternary", conditional.getType)
-    symbolManager.addSymbol(symbol)
+    symbolManager.addDeclaration(symbol)
     val name = symbolManager.getName(symbol)
     addLine("if (%s) then".format(buildExpression(conditional.getPredicate)))
     in
@@ -320,7 +357,7 @@ class FortranGenerator {
 
   private def processScope(scope: ScopeStatement) {
     for (sym <- scope.getDeclarations) {
-      symbolManager.addSymbol(sym)
+      symbolManager.addDeclaration(sym)
     }
     for(stat <- scope.getStatements) {
       processStatement(stat)