]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Implement precedence-based expression bracketing.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Sat, 7 Apr 2012 23:15:10 +0000 (00:15 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Sat, 7 Apr 2012 23:15:10 +0000 (00:15 +0100)
src/ofc/codegen/FortranGenerator.scala

index e852c2198c9d9a24c7fdd3841c8d1e55f816efd5..fea6c970b326d8b2635f148678028a229206589d 100644 (file)
@@ -48,9 +48,23 @@ class SymbolManager {
 
 class FortranGenerator {
   var indentLevel = 0
+  val maxPrec = 30
   val symbolManager = new SymbolManager
   val buffer = scala.collection.mutable.Buffer[String]()
 
+  object BinaryOpInfo {
+    sealed abstract class Associativity
+    object LEFT extends Associativity
+    object RIGHT extends Associativity
+    object FUNCTION extends Associativity
+  }
+
+  case class BinaryOpInfo(template: String, precedence: Int, assoc: BinaryOpInfo.Associativity)
+  case class ExpHolder(prec: Int, exp: String) {
+    override def toString = exp
+  }
+
+
   def processStatement(stat: Statement) : String = {
     stat match {
       case (x : NullStatement) => ()
@@ -74,16 +88,16 @@ class FortranGenerator {
     if (indentLevel < 0) throw new LogicError("Indentation level dropped below 0 in FORTRAN generator.")
   }
 
-  private def buildExpression(expression: Expression[_]) : String = {
+  private def buildExpression(expression: Expression[_]) : ExpHolder = {
     expression match {
-      case (i : IntegerLiteral) => i.getValue.toString
-      case (a : FieldAccess[_]) => "%s%%%s".format(buildExpression(a.getStructExpression), a.getField.getName)
+      case (i : IntegerLiteral) => ExpHolder(maxPrec, i.getValue.toString)
+      case (a : FieldAccess[_]) => ExpHolder(maxPrec, "%s%%%s".format(buildExpression(a.getStructExpression), a.getField.getName))
       case (r : VarRef[_]) => r.getSymbol match {
-        case (s: NamedUnboundVarSymbol[_]) => s.getName
-        case s => symbolManager.getName(s)
+        case (s: NamedUnboundVarSymbol[_]) => ExpHolder(maxPrec, s.getName)
+        case s => ExpHolder(maxPrec, symbolManager.getName(s))
       }
       case (r: ArrayRead[_]) => 
-        buildExpression(r.getArrayExpression) + r.getIndexExpressions.map(buildExpression(_)).mkString("(",", ",")")
+        ExpHolder(maxPrec, buildExpression(r.getArrayExpression) + r.getIndexExpressions.map(buildExpression(_)).mkString("(",", ",")"))
       case (d: PointerDereference[_]) => buildExpression(d.getExpression)
       case (c: ConditionalValue[_]) => buildConditionalValue(c)
       case (c: NumericComparison[_]) => buildNumericComparison(c)
@@ -92,7 +106,7 @@ class FortranGenerator {
     }
   }
 
-  private def buildConditionalValue(conditional: ConditionalValue[_]) : String = {
+  private def buildConditionalValue(conditional: ConditionalValue[_]) : ExpHolder = {
     var symbol = new DeclaredVarSymbol[Type]("ternary")
     symbolManager.addSymbol(symbol)
     val name = symbolManager.getName(symbol)
@@ -106,38 +120,62 @@ class FortranGenerator {
     out
     addLine("endif")
 
-    name
+    ExpHolder(maxPrec, name)
   }
-
-  private def buildNumericComparison(comparison: NumericComparison[_]) : String = {
+  
+  private def getBinaryOpInfo(op: NumericOperations.CompareOp) : BinaryOpInfo = {
     import NumericOperations._
-    val opString = comparison.getOperation match {
-      case LT => ".lt."
-      case LE => ".le."
-      case EQ => ".eq."
-      case NE => ".ne."
-      case GT => ".gt."
-      case GE => ".ge."
+    import BinaryOpInfo._
+    op match {
+      case LT => BinaryOpInfo("%s .lt. %s", 16, LEFT)
+      case LE => BinaryOpInfo("%s .le. %s", 16, LEFT)
+      case EQ => BinaryOpInfo("%s .eq. %s", 16, LEFT)
+      case NE => BinaryOpInfo("%s .ne. %s", 16, LEFT)
+      case GT => BinaryOpInfo("%s .gt. %s", 16, LEFT)
+      case GE => BinaryOpInfo("%s .ge. %s", 16, LEFT)
       case x => throw new UnimplementedException("Unknown comparison type in FORTRAN generator: "+x.toString)
     }
-
-    buildExpression(comparison.getLeft) + opString + buildExpression(comparison.getRight)
   }
 
-  private def buildNumericOperator(comparison: NumericOperator[_]) : String = {
+  private def getBinaryOpInfo(op: NumericOperations.FieldOp) : BinaryOpInfo = {
     import NumericOperations._
-    val opString = comparison.getOperation match {
-      case Add => "+"
-      case Sub => "-"
-      case Mul => "*"
-      case Div => "/"
-      case Mod => return "mod(%s, %s)".format(buildExpression(comparison.getLeft), buildExpression(comparison.getRight))
+    import BinaryOpInfo._
+    op match {
+      case Add => BinaryOpInfo("%s + %s", 22, LEFT)
+      case Sub => BinaryOpInfo("%s - %s", 22, LEFT)
+      case Mul => BinaryOpInfo("%s * %s", 26, LEFT)
+      case Div => BinaryOpInfo("%s / %s", 26, LEFT)
+      case Mod => BinaryOpInfo("mod(%s, %s)", maxPrec, FUNCTION)
       case x => throw new UnimplementedException("Unknown numeric operator in FORTRAN generator: "+x.toString)
     }
+  }
+
+  private def buildBinaryOperation(opInfo: BinaryOpInfo, left: ExpHolder, right: ExpHolder) : ExpHolder = {
+    import BinaryOpInfo._
+
+    def bracket(opInfo: BinaryOpInfo, exp: ExpHolder, assoc: Associativity) =
+      opInfo.assoc != FUNCTION && 
+      (opInfo.precedence > exp.prec || (opInfo.precedence == exp.prec && opInfo.assoc != assoc))
+
+    val lhs = if (bracket(opInfo, left, LEFT))
+     "(" + left.exp + ")"
+     else
+       left.exp
 
-    buildExpression(comparison.getLeft) + opString + buildExpression(comparison.getRight)
+    val rhs = if (bracket(opInfo, right, RIGHT))
+     "(" + right.exp + ")"
+     else
+       right.exp
+
+    ExpHolder(opInfo.precedence, opInfo.template.format(lhs, rhs))
   }
 
+  private def buildNumericComparison(c: NumericComparison[_]) : ExpHolder =
+    buildBinaryOperation(getBinaryOpInfo(c.getOperation), buildExpression(c.getLeft), buildExpression(c.getRight))
+
+  private def buildNumericOperator(o: NumericOperator[_]) : ExpHolder =
+    buildBinaryOperation(getBinaryOpInfo(o.getOperation), buildExpression(o.getLeft), buildExpression(o.getRight))
+
   private def processForLoop(stat: ForLoop) {
     val index = stat.getIndex
     val name = symbolManager.getName(index)