--- /dev/null
+package ofc.codegen
+
+class Assignment(symbol: VarSymbol[_ <: Type], expression: Expression[_ <: Type]) extends Statement {
+ // TODO: type check assignment
+}
class ConditionalValue[T <: Type](predicate: Expression[BoolType], ifTrue: Expression[T], ifFalse: Expression[T]) extends Expression[T] {
def foreach[U](f: Expression[_] => U) = List(predicate, ifTrue, ifFalse).foreach(f)
+ def getPredicate = predicate
+ def getIfTrue = ifTrue
+ def getIfFalse = ifFalse
}
new NumericComparison[T](NumericOperations.LT, this, rhs)
def |<=|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] =
- new NumericComparison[T](NumericOperations.LTE, this, rhs)
+ new NumericComparison[T](NumericOperations.LE, this, rhs)
def |==|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] =
new NumericComparison[T](NumericOperations.EQ, this, rhs)
new NumericComparison[T](NumericOperations.GT, this, rhs)
def |>=|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] =
- new NumericComparison[T](NumericOperations.GTE, this, rhs)
+ new NumericComparison[T](NumericOperations.GE, this, rhs)
def %[FieldType <: Type](field: FieldSymbol[FieldType])(implicit witness: <:<[this.type, Expression[StructType]]) : Expression[FieldType] =
new FieldAccess[FieldType](witness(this), field)
// Struct and array accesses
class FieldAccess[T <: Type](expression: Expression[StructType], field: FieldSymbol[T]) extends Expression[T] {
def foreach[U](f: Expression[_] => U) = f(expression)
+ def getStructExpression = expression
+ def getField = field
}
class ArrayRead[E <: Type](expression: Expression[ArrayType[E]], index: Seq[Expression[IntType]]) extends Expression[E] {
def foreach[U](f: Expression[_] => U) = (index :+ expression).foreach(f)
+ def getArrayExpression = expression
+ def getIndexExpressions = index
}
class PointerDereference[E <: Type](expression: Expression[PointerType[E]]) extends Expression[E] {
def foreach[U](f: Expression[_] => U) = f(expression)
+ def getExpression = expression
}
// Literals
-class IntegerLiteral(value: Int) extends Expression[IntType] with LeafExpression
+class IntegerLiteral(value: Int) extends Expression[IntType] with LeafExpression {
+ def getValue = value
+}
--- /dev/null
+package ofc.codegen
+
+class ForLoop(index: VarSymbol[IntType], begin: Expression[IntType], end: Expression[IntType]) extends ScopeStatement {
+ def getIndex = index
+ def getBegin = begin
+ def getEnd = end
+}
package ofc.codegen
-import ofc.UnimplementedException
-import scala.collection.mutable.ArrayBuffer
+import scala.annotation.tailrec
+import ofc.{UnimplementedException,LogicError}
+
+class SymbolManager {
+ import scala.collection.mutable
+
+ class SymbolInfo(name: String) {
+ def getName = name
+ }
+
+ val symbols = mutable.Map[VarSymbol[_], SymbolInfo]()
+ val names = mutable.Set[String]()
+
+ private def createNewName(sym: VarSymbol[_]) : String = {
+ @tailrec
+ def helper(sym: VarSymbol[_], suffix: Int) : String = {
+ val candidate = sym.getName + "_" + suffix
+ if (names.contains(candidate))
+ helper(sym, suffix + 1)
+ else
+ candidate
+ }
+
+ helper(sym, 1)
+ }
+
+ def addSymbol(sym: VarSymbol[_ <: Type]) {
+ sym match {
+ case (s: DeclaredVarSymbol[_]) => if (!symbols.contains(s)) {
+ val name = createNewName(s)
+ names += name
+ symbols += s -> new SymbolInfo(name)
+ }
+ case (_: NamedUnboundVarSymbol[_]) => throw new LogicError("Attempted to add unbound symbol to SymbolManager.")
+ }
+ }
+
+ def getName(sym: VarSymbol[_]) =
+ symbols.get(sym) match {
+ case None => throw new LogicError("Unknown symbol "+sym.toString)
+ case Some(info) => info.getName
+ }
+}
class FortranGenerator {
- val buffer = new ArrayBuffer[String]
+ var indentLevel = 0
+ val symbolManager = new SymbolManager
+ val buffer = scala.collection.mutable.Buffer[String]()
def processStatement(stat: Statement) : String = {
stat match {
case (x : Comment) => addLine("!" + x.getValue)
case (x : BlockStatement) => processScope(x)
case (x : ProducerStatement) => processStatement(x.toConcrete)
+ case (x : ForLoop) => processForLoop(x)
case x => throw new UnimplementedException("Unknown statement type in FORTRAN generator: " + x.toString)
}
buffer.mkString("\n")
}
+ private def in() {
+ indentLevel += 1
+ }
+
+ private def out() {
+ indentLevel -=1
+ if (indentLevel < 0) throw new LogicError("Indentation level dropped below 0 in FORTRAN generator.")
+ }
+
+ private def buildExpression(expression: Expression[_]) : String = {
+ expression match {
+ case (i : IntegerLiteral) => i.getValue.toString
+ case (a : FieldAccess[_]) => "%s%%%s".format(buildExpression(a.getStructExpression), a.getField.getName)
+ case (r : VarRef[_]) => r.getSymbol match {
+ case (s: NamedUnboundVarSymbol[_]) => s.getName
+ case s => symbolManager.getName(s)
+ }
+ case (r: ArrayRead[_]) =>
+ buildExpression(r.getArrayExpression) + r.getIndexExpressions.map(buildExpression(_)).mkString("(",", ",")")
+ case (d: PointerDereference[_]) => buildExpression(d.getExpression)
+ case (c: ConditionalValue[_]) => buildConditionalValue(c)
+ case (c: NumericComparison[_]) => buildNumericComparison(c)
+ case (c: NumericOperator[_]) => buildNumericOperator(c)
+ case x => throw new UnimplementedException("Unknown expression type in FORTRAN generator: " + x.toString)
+ }
+ }
+
+ private def buildConditionalValue(conditional: ConditionalValue[_]) : String = {
+ var symbol = new DeclaredVarSymbol[Type]("ternary")
+ symbolManager.addSymbol(symbol)
+ val name = symbolManager.getName(symbol)
+ addLine("if (%s) then".format(buildExpression(conditional.getPredicate)))
+ in
+ addLine("%s = %s".format(name, buildExpression(conditional.getIfTrue)))
+ out
+ addLine("else")
+ in
+ addLine("%s = %s".format(name, buildExpression(conditional.getIfFalse)))
+ out
+ addLine("endif")
+
+ name
+ }
+
+ private def buildNumericComparison(comparison: NumericComparison[_]) : String = {
+ import NumericOperations._
+ val opString = comparison.getOperation match {
+ case LT => ".lt."
+ case LE => ".le."
+ case EQ => ".eq."
+ case NE => ".ne."
+ case GT => ".gt."
+ case GE => ".ge."
+ case x => throw new UnimplementedException("Unknown comparison type in FORTRAN generator: "+x.toString)
+ }
+
+ buildExpression(comparison.getLeft) + opString + buildExpression(comparison.getRight)
+ }
+
+ private def buildNumericOperator(comparison: NumericOperator[_]) : String = {
+ import NumericOperations._
+ val opString = comparison.getOperation match {
+ case Add => "+"
+ case Sub => "-"
+ case Mul => "*"
+ case Div => "/"
+ case Mod => return "mod(%s, %s)".format(buildExpression(comparison.getLeft), buildExpression(comparison.getRight))
+ case x => throw new UnimplementedException("Unknown numeric operator in FORTRAN generator: "+x.toString)
+ }
+
+ buildExpression(comparison.getLeft) + opString + buildExpression(comparison.getRight)
+ }
+
+ private def processForLoop(stat: ForLoop) {
+ val index = stat.getIndex
+ symbolManager.addSymbol(index)
+ val name = symbolManager.getName(index)
+ val begin = buildExpression(stat.getBegin)
+ val end = buildExpression(stat.getEnd)
+
+ val header = "do %s = %s, %s".format(name, begin, end)
+ val footer = "end do"
+ addLine(header)
+ in
+ processScope(stat)
+ out
+ addLine(footer)
+ }
+
private def processScope(scope: ScopeStatement) {
for(stat <- scope.getStatements) {
processStatement(stat)
}
private def addLine(line: String) {
- buffer += line
+ buffer += " "*indentLevel + line
}
}
--- /dev/null
+package ofc.codegen
+
+class IfStatement(predicate: Expression[BoolType]) extends ScopeStatement
sealed abstract class CompareOp
object LT extends CompareOp
- object LTE extends CompareOp
+ object LE extends CompareOp
object EQ extends CompareOp
object NE extends CompareOp
object GT extends CompareOp
- object GTE extends CompareOp
+ object GE extends CompareOp
}
class NumericOperator[T <: Type](op: NumericOperations.FieldOp, left: Expression[T], right: Expression[T]) extends Expression[T] {
+ // TODO: Type check operators
def foreach[U](f: Expression[_] => U) = List(left, right).foreach(f)
+ def getOperation = op
+ def getLeft = left
+ def getRight = right
}
class NumericComparison[T <: Type](op: NumericOperations.CompareOp, left: Expression[T], right: Expression[T]) extends Expression[BoolType] {
+ // TODO: Type check operators
def foreach[U](f: Expression[_] => U) = List(left, right).foreach(f)
+ def getOperation = op
+ def getLeft = left
+ def getRight = right
}
}
}
- trait Context {
+ sealed trait Context {
def defines : Set[VarSymbol[_]]
def depends : Set[VarSymbol[_]]
def tryCompare(other: Context) : Option[Int] = {
}
}
- class VariableRange(symbol: VarSymbol[_], first: Expression[IntType], last: Expression[IntType]) extends Context {
+ case class VariableRange(symbol: VarSymbol[IntType], first: Expression[IntType], last: Expression[IntType]) extends Context {
override def toString = "VariableRange("+symbol.toString+", "+first.toString+", "+last.toString+")"
def defines = Set(symbol)
def depends = Expression.findReferencedVariables(List(first, last))
}
- class Predicate(expression: Expression[BoolType]) extends Context {
+ case class Predicate(expression: Expression[BoolType]) extends Context {
override def toString = "Predicate("+expression.toString+")"
def defines = Set.empty
def depends = Expression.findReferencedVariables(expression)
}
- class DerivedExpression(symbol: VarSymbol[_], expression: Expression[_]) extends Context {
+ case class DerivedExpression(symbol: VarSymbol[_ <: Type], expression: Expression[_ <: Type]) extends Context {
override def toString = "DerivedExpression("+symbol.toString + " <- " + expression.toString+")"
def defines = Set(symbol)
def depends = Expression.findReferencedVariables(expression)
}
def toConcrete : Statement = {
- val contexts : Seq[Context] = ranges ++ predicates ++ expressions
+ val contexts = ranges ++ predicates ++ expressions
val sortedContexts = Context.sort(contexts)
- for(c <- sortedContexts)
- println(c.toString)
+ val block = new BlockStatement
+ var scope : ScopeStatement = block
- new Comment("Producer flattening unimplemented.")
+ for (context <- sortedContexts) {
+ context match {
+ case VariableRange(sym, first, last) => {
+ val loop = new ForLoop(sym, first, last)
+ scope += loop
+ scope = loop
+ }
+ case Predicate(expression) => {
+ val ifStat = new IfStatement(expression)
+ scope += ifStat
+ scope = ifStat
+ }
+ case DerivedExpression(sym, expression) => {
+ val assignment = new Assignment(sym, expression)
+ scope += assignment
+ }
+ }
+ }
+
+ block
}
}
def getName = name
}
-abstract class VarSymbol[T <: Type](name: String) extends Symbol {
+sealed abstract class VarSymbol[T <: Type](name: String) extends Symbol {
def getName = name
}
implicit def toRef[T <: Type](symbol: VarSymbol[T]) = new VarRef[T](symbol)
}
-case class DeclaredVarSymbol[T <: Type](name: String) extends VarSymbol[T](name)
-abstract class UnboundVarSymbol[T <: Type](name: String) extends VarSymbol[T](name)
-case class NamedUnboundVarSymbol[T <: Type](name: String) extends UnboundVarSymbol[T](name)
+class DeclaredVarSymbol[T <: Type](name: String) extends VarSymbol[T](name)
+class NamedUnboundVarSymbol[T <: Type](name: String) extends VarSymbol[T](name)
import ofc.codegen._
class PPDFunctionSet(val basisName: String, dataName: String) extends FunctionSet {
- val basis = NamedUnboundVarSymbol[StructType](basisName)
- val data = NamedUnboundVarSymbol[ArrayType[FloatType]](dataName)
- val pubCell = NamedUnboundVarSymbol[StructType]("pub_cell")
+ val basis = new NamedUnboundVarSymbol[StructType](basisName)
+ val data = new NamedUnboundVarSymbol[ArrayType[FloatType]](dataName)
+ val pubCell = new NamedUnboundVarSymbol[StructType]("pub_cell")
val numSpheres = basis % FieldSymbol[IntType]("num");
val ppdWidths = for(dim <- 1 to 3) yield pubCell % FieldSymbol[IntType]("n_pt"+dim)