From: Francis Russell Date: Sat, 7 Apr 2012 23:15:10 +0000 (+0100) Subject: Implement precedence-based expression bracketing. X-Git-Url: https://git.unchartedbackwaters.co.uk/w/?a=commitdiff_plain;h=0b7cf1a2b71485659d477d7a2e078bcb8f9f40aa;p=francis%2Fofc.git Implement precedence-based expression bracketing. --- diff --git a/src/ofc/codegen/FortranGenerator.scala b/src/ofc/codegen/FortranGenerator.scala index e852c21..fea6c97 100644 --- a/src/ofc/codegen/FortranGenerator.scala +++ b/src/ofc/codegen/FortranGenerator.scala @@ -48,9 +48,23 @@ class SymbolManager { class FortranGenerator { var indentLevel = 0 + val maxPrec = 30 val symbolManager = new SymbolManager val buffer = scala.collection.mutable.Buffer[String]() + object BinaryOpInfo { + sealed abstract class Associativity + object LEFT extends Associativity + object RIGHT extends Associativity + object FUNCTION extends Associativity + } + + case class BinaryOpInfo(template: String, precedence: Int, assoc: BinaryOpInfo.Associativity) + case class ExpHolder(prec: Int, exp: String) { + override def toString = exp + } + + def processStatement(stat: Statement) : String = { stat match { case (x : NullStatement) => () @@ -74,16 +88,16 @@ class FortranGenerator { if (indentLevel < 0) throw new LogicError("Indentation level dropped below 0 in FORTRAN generator.") } - private def buildExpression(expression: Expression[_]) : String = { + private def buildExpression(expression: Expression[_]) : ExpHolder = { expression match { - case (i : IntegerLiteral) => i.getValue.toString - case (a : FieldAccess[_]) => "%s%%%s".format(buildExpression(a.getStructExpression), a.getField.getName) + case (i : IntegerLiteral) => ExpHolder(maxPrec, i.getValue.toString) + case (a : FieldAccess[_]) => ExpHolder(maxPrec, "%s%%%s".format(buildExpression(a.getStructExpression), a.getField.getName)) case (r : VarRef[_]) => r.getSymbol match { - case (s: NamedUnboundVarSymbol[_]) => s.getName - case s => symbolManager.getName(s) + case (s: NamedUnboundVarSymbol[_]) => ExpHolder(maxPrec, s.getName) + case s => ExpHolder(maxPrec, symbolManager.getName(s)) } case (r: ArrayRead[_]) => - buildExpression(r.getArrayExpression) + r.getIndexExpressions.map(buildExpression(_)).mkString("(",", ",")") + ExpHolder(maxPrec, buildExpression(r.getArrayExpression) + r.getIndexExpressions.map(buildExpression(_)).mkString("(",", ",")")) case (d: PointerDereference[_]) => buildExpression(d.getExpression) case (c: ConditionalValue[_]) => buildConditionalValue(c) case (c: NumericComparison[_]) => buildNumericComparison(c) @@ -92,7 +106,7 @@ class FortranGenerator { } } - private def buildConditionalValue(conditional: ConditionalValue[_]) : String = { + private def buildConditionalValue(conditional: ConditionalValue[_]) : ExpHolder = { var symbol = new DeclaredVarSymbol[Type]("ternary") symbolManager.addSymbol(symbol) val name = symbolManager.getName(symbol) @@ -106,38 +120,62 @@ class FortranGenerator { out addLine("endif") - name + ExpHolder(maxPrec, name) } - - private def buildNumericComparison(comparison: NumericComparison[_]) : String = { + + private def getBinaryOpInfo(op: NumericOperations.CompareOp) : BinaryOpInfo = { import NumericOperations._ - val opString = comparison.getOperation match { - case LT => ".lt." - case LE => ".le." - case EQ => ".eq." - case NE => ".ne." - case GT => ".gt." - case GE => ".ge." + import BinaryOpInfo._ + op match { + case LT => BinaryOpInfo("%s .lt. %s", 16, LEFT) + case LE => BinaryOpInfo("%s .le. %s", 16, LEFT) + case EQ => BinaryOpInfo("%s .eq. %s", 16, LEFT) + case NE => BinaryOpInfo("%s .ne. %s", 16, LEFT) + case GT => BinaryOpInfo("%s .gt. %s", 16, LEFT) + case GE => BinaryOpInfo("%s .ge. %s", 16, LEFT) case x => throw new UnimplementedException("Unknown comparison type in FORTRAN generator: "+x.toString) } - - buildExpression(comparison.getLeft) + opString + buildExpression(comparison.getRight) } - private def buildNumericOperator(comparison: NumericOperator[_]) : String = { + private def getBinaryOpInfo(op: NumericOperations.FieldOp) : BinaryOpInfo = { import NumericOperations._ - val opString = comparison.getOperation match { - case Add => "+" - case Sub => "-" - case Mul => "*" - case Div => "/" - case Mod => return "mod(%s, %s)".format(buildExpression(comparison.getLeft), buildExpression(comparison.getRight)) + import BinaryOpInfo._ + op match { + case Add => BinaryOpInfo("%s + %s", 22, LEFT) + case Sub => BinaryOpInfo("%s - %s", 22, LEFT) + case Mul => BinaryOpInfo("%s * %s", 26, LEFT) + case Div => BinaryOpInfo("%s / %s", 26, LEFT) + case Mod => BinaryOpInfo("mod(%s, %s)", maxPrec, FUNCTION) case x => throw new UnimplementedException("Unknown numeric operator in FORTRAN generator: "+x.toString) } + } + + private def buildBinaryOperation(opInfo: BinaryOpInfo, left: ExpHolder, right: ExpHolder) : ExpHolder = { + import BinaryOpInfo._ + + def bracket(opInfo: BinaryOpInfo, exp: ExpHolder, assoc: Associativity) = + opInfo.assoc != FUNCTION && + (opInfo.precedence > exp.prec || (opInfo.precedence == exp.prec && opInfo.assoc != assoc)) + + val lhs = if (bracket(opInfo, left, LEFT)) + "(" + left.exp + ")" + else + left.exp - buildExpression(comparison.getLeft) + opString + buildExpression(comparison.getRight) + val rhs = if (bracket(opInfo, right, RIGHT)) + "(" + right.exp + ")" + else + right.exp + + ExpHolder(opInfo.precedence, opInfo.template.format(lhs, rhs)) } + private def buildNumericComparison(c: NumericComparison[_]) : ExpHolder = + buildBinaryOperation(getBinaryOpInfo(c.getOperation), buildExpression(c.getLeft), buildExpression(c.getRight)) + + private def buildNumericOperator(o: NumericOperator[_]) : ExpHolder = + buildBinaryOperation(getBinaryOpInfo(o.getOperation), buildExpression(o.getLeft), buildExpression(o.getRight)) + private def processForLoop(stat: ForLoop) { val index = stat.getIndex val name = symbolManager.getName(index)