class FortranGenerator {
var indentLevel = 0
+ val maxPrec = 30
val symbolManager = new SymbolManager
val buffer = scala.collection.mutable.Buffer[String]()
+ object BinaryOpInfo {
+ sealed abstract class Associativity
+ object LEFT extends Associativity
+ object RIGHT extends Associativity
+ object FUNCTION extends Associativity
+ }
+
+ case class BinaryOpInfo(template: String, precedence: Int, assoc: BinaryOpInfo.Associativity)
+ case class ExpHolder(prec: Int, exp: String) {
+ override def toString = exp
+ }
+
+
def processStatement(stat: Statement) : String = {
stat match {
case (x : NullStatement) => ()
if (indentLevel < 0) throw new LogicError("Indentation level dropped below 0 in FORTRAN generator.")
}
- private def buildExpression(expression: Expression[_]) : String = {
+ private def buildExpression(expression: Expression[_]) : ExpHolder = {
expression match {
- case (i : IntegerLiteral) => i.getValue.toString
- case (a : FieldAccess[_]) => "%s%%%s".format(buildExpression(a.getStructExpression), a.getField.getName)
+ case (i : IntegerLiteral) => ExpHolder(maxPrec, i.getValue.toString)
+ case (a : FieldAccess[_]) => ExpHolder(maxPrec, "%s%%%s".format(buildExpression(a.getStructExpression), a.getField.getName))
case (r : VarRef[_]) => r.getSymbol match {
- case (s: NamedUnboundVarSymbol[_]) => s.getName
- case s => symbolManager.getName(s)
+ case (s: NamedUnboundVarSymbol[_]) => ExpHolder(maxPrec, s.getName)
+ case s => ExpHolder(maxPrec, symbolManager.getName(s))
}
case (r: ArrayRead[_]) =>
- buildExpression(r.getArrayExpression) + r.getIndexExpressions.map(buildExpression(_)).mkString("(",", ",")")
+ ExpHolder(maxPrec, buildExpression(r.getArrayExpression) + r.getIndexExpressions.map(buildExpression(_)).mkString("(",", ",")"))
case (d: PointerDereference[_]) => buildExpression(d.getExpression)
case (c: ConditionalValue[_]) => buildConditionalValue(c)
case (c: NumericComparison[_]) => buildNumericComparison(c)
}
}
- private def buildConditionalValue(conditional: ConditionalValue[_]) : String = {
+ private def buildConditionalValue(conditional: ConditionalValue[_]) : ExpHolder = {
var symbol = new DeclaredVarSymbol[Type]("ternary")
symbolManager.addSymbol(symbol)
val name = symbolManager.getName(symbol)
out
addLine("endif")
- name
+ ExpHolder(maxPrec, name)
}
-
- private def buildNumericComparison(comparison: NumericComparison[_]) : String = {
+
+ private def getBinaryOpInfo(op: NumericOperations.CompareOp) : BinaryOpInfo = {
import NumericOperations._
- val opString = comparison.getOperation match {
- case LT => ".lt."
- case LE => ".le."
- case EQ => ".eq."
- case NE => ".ne."
- case GT => ".gt."
- case GE => ".ge."
+ import BinaryOpInfo._
+ op match {
+ case LT => BinaryOpInfo("%s .lt. %s", 16, LEFT)
+ case LE => BinaryOpInfo("%s .le. %s", 16, LEFT)
+ case EQ => BinaryOpInfo("%s .eq. %s", 16, LEFT)
+ case NE => BinaryOpInfo("%s .ne. %s", 16, LEFT)
+ case GT => BinaryOpInfo("%s .gt. %s", 16, LEFT)
+ case GE => BinaryOpInfo("%s .ge. %s", 16, LEFT)
case x => throw new UnimplementedException("Unknown comparison type in FORTRAN generator: "+x.toString)
}
-
- buildExpression(comparison.getLeft) + opString + buildExpression(comparison.getRight)
}
- private def buildNumericOperator(comparison: NumericOperator[_]) : String = {
+ private def getBinaryOpInfo(op: NumericOperations.FieldOp) : BinaryOpInfo = {
import NumericOperations._
- val opString = comparison.getOperation match {
- case Add => "+"
- case Sub => "-"
- case Mul => "*"
- case Div => "/"
- case Mod => return "mod(%s, %s)".format(buildExpression(comparison.getLeft), buildExpression(comparison.getRight))
+ import BinaryOpInfo._
+ op match {
+ case Add => BinaryOpInfo("%s + %s", 22, LEFT)
+ case Sub => BinaryOpInfo("%s - %s", 22, LEFT)
+ case Mul => BinaryOpInfo("%s * %s", 26, LEFT)
+ case Div => BinaryOpInfo("%s / %s", 26, LEFT)
+ case Mod => BinaryOpInfo("mod(%s, %s)", maxPrec, FUNCTION)
case x => throw new UnimplementedException("Unknown numeric operator in FORTRAN generator: "+x.toString)
}
+ }
+
+ private def buildBinaryOperation(opInfo: BinaryOpInfo, left: ExpHolder, right: ExpHolder) : ExpHolder = {
+ import BinaryOpInfo._
+
+ def bracket(opInfo: BinaryOpInfo, exp: ExpHolder, assoc: Associativity) =
+ opInfo.assoc != FUNCTION &&
+ (opInfo.precedence > exp.prec || (opInfo.precedence == exp.prec && opInfo.assoc != assoc))
+
+ val lhs = if (bracket(opInfo, left, LEFT))
+ "(" + left.exp + ")"
+ else
+ left.exp
- buildExpression(comparison.getLeft) + opString + buildExpression(comparison.getRight)
+ val rhs = if (bracket(opInfo, right, RIGHT))
+ "(" + right.exp + ")"
+ else
+ right.exp
+
+ ExpHolder(opInfo.precedence, opInfo.template.format(lhs, rhs))
}
+ private def buildNumericComparison(c: NumericComparison[_]) : ExpHolder =
+ buildBinaryOperation(getBinaryOpInfo(c.getOperation), buildExpression(c.getLeft), buildExpression(c.getRight))
+
+ private def buildNumericOperator(o: NumericOperator[_]) : ExpHolder =
+ buildBinaryOperation(getBinaryOpInfo(o.getOperation), buildExpression(o.getLeft), buildExpression(o.getRight))
+
private def processForLoop(stat: ForLoop) {
val index = stat.getIndex
val name = symbolManager.getName(index)