]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Generate basic loop structures.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Fri, 6 Apr 2012 22:57:21 +0000 (23:57 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Fri, 6 Apr 2012 23:32:19 +0000 (00:32 +0100)
src/ofc/codegen/Assignment.scala [new file with mode: 0644]
src/ofc/codegen/ConditionalValue.scala
src/ofc/codegen/Expression.scala
src/ofc/codegen/ForLoop.scala [new file with mode: 0644]
src/ofc/codegen/FortranGenerator.scala
src/ofc/codegen/IfStatement.scala [new file with mode: 0644]
src/ofc/codegen/NumericOperator.scala
src/ofc/codegen/ProducerStatement.scala
src/ofc/codegen/Symbol.scala
src/ofc/generators/onetep/PPDFunctionSet.scala

diff --git a/src/ofc/codegen/Assignment.scala b/src/ofc/codegen/Assignment.scala
new file mode 100644 (file)
index 0000000..f729988
--- /dev/null
@@ -0,0 +1,5 @@
+package ofc.codegen
+
+class Assignment(symbol: VarSymbol[_ <: Type], expression: Expression[_ <: Type]) extends Statement {
+  // TODO: type check assignment
+}
index 4e39ada6d2fa658e1bc8b32f554e03ebe889a3ef..f35c3d31c9280c3715ec5f633b49065936e9b215 100644 (file)
@@ -2,4 +2,7 @@ package ofc.codegen
 
 class ConditionalValue[T <: Type](predicate: Expression[BoolType], ifTrue: Expression[T], ifFalse: Expression[T]) extends Expression[T] {
   def foreach[U](f: Expression[_] => U) = List(predicate, ifTrue, ifFalse).foreach(f)
+  def getPredicate = predicate
+  def getIfTrue = ifTrue
+  def getIfFalse = ifFalse
 }
index 0279a9692d76be4054f26be5727fa1757617ea43..a28f219391bb48a6176f6df03431bfe615d3b045 100644 (file)
@@ -61,7 +61,7 @@ abstract class Expression[T <: Type] extends Traversable[Expression[_]] {
     new NumericComparison[T](NumericOperations.LT, this, rhs)
 
   def |<=|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] =
-    new NumericComparison[T](NumericOperations.LTE, this, rhs)
+    new NumericComparison[T](NumericOperations.LE, this, rhs)
 
   def |==|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] =
     new NumericComparison[T](NumericOperations.EQ, this, rhs)
@@ -73,7 +73,7 @@ abstract class Expression[T <: Type] extends Traversable[Expression[_]] {
     new NumericComparison[T](NumericOperations.GT, this, rhs)
 
   def |>=|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] =
-    new NumericComparison[T](NumericOperations.GTE, this, rhs)
+    new NumericComparison[T](NumericOperations.GE, this, rhs)
 
   def %[FieldType <: Type](field: FieldSymbol[FieldType])(implicit witness: <:<[this.type, Expression[StructType]]) : Expression[FieldType] =
     new FieldAccess[FieldType](witness(this), field)
@@ -93,14 +93,21 @@ class VarRef[T <: Type](symbol: VarSymbol[T]) extends Expression[T] with LeafExp
 // Struct and array accesses
 class FieldAccess[T <: Type](expression: Expression[StructType], field: FieldSymbol[T]) extends Expression[T] {
   def foreach[U](f: Expression[_] => U) = f(expression)
+  def getStructExpression = expression
+  def getField = field
 }
 
 class ArrayRead[E <: Type](expression: Expression[ArrayType[E]], index: Seq[Expression[IntType]]) extends Expression[E] {
   def foreach[U](f: Expression[_] => U) = (index :+ expression).foreach(f)
+  def getArrayExpression = expression
+  def getIndexExpressions = index
 }
 class PointerDereference[E <: Type](expression: Expression[PointerType[E]]) extends Expression[E] {
   def foreach[U](f: Expression[_] => U) = f(expression)
+  def getExpression = expression
 }
 
 // Literals
-class IntegerLiteral(value: Int) extends Expression[IntType] with LeafExpression
+class IntegerLiteral(value: Int) extends Expression[IntType] with LeafExpression {
+  def getValue = value
+}
diff --git a/src/ofc/codegen/ForLoop.scala b/src/ofc/codegen/ForLoop.scala
new file mode 100644 (file)
index 0000000..1468bed
--- /dev/null
@@ -0,0 +1,7 @@
+package ofc.codegen
+
+class ForLoop(index: VarSymbol[IntType], begin: Expression[IntType], end: Expression[IntType]) extends ScopeStatement {
+  def getIndex = index
+  def getBegin = begin
+  def getEnd = end
+}
index 69a9aec50a092f8cfeb844ffe5d0d121fd72a076..fdc6e89bf065e44736f72aa0f92746f039b02f3a 100644 (file)
@@ -1,9 +1,52 @@
 package ofc.codegen
-import ofc.UnimplementedException
-import scala.collection.mutable.ArrayBuffer
+import scala.annotation.tailrec
+import ofc.{UnimplementedException,LogicError}
+
+class SymbolManager {
+  import scala.collection.mutable
+
+  class SymbolInfo(name: String) {
+    def getName = name
+  }
+
+  val symbols = mutable.Map[VarSymbol[_], SymbolInfo]()
+  val names = mutable.Set[String]()
+
+  private def createNewName(sym: VarSymbol[_]) : String = {
+    @tailrec
+    def helper(sym: VarSymbol[_], suffix: Int) : String = {
+      val candidate = sym.getName + "_" + suffix
+      if (names.contains(candidate))
+        helper(sym, suffix + 1)
+      else
+        candidate
+    }
+    
+    helper(sym, 1)
+  }
+
+  def addSymbol(sym: VarSymbol[_ <: Type]) {
+    sym match {
+      case (s: DeclaredVarSymbol[_]) => if (!symbols.contains(s)) {
+        val name = createNewName(s)
+        names += name
+        symbols += s -> new SymbolInfo(name)
+      }
+      case (_: NamedUnboundVarSymbol[_]) => throw new LogicError("Attempted to add unbound symbol to SymbolManager.")
+    }
+  }
+
+  def getName(sym: VarSymbol[_]) =
+    symbols.get(sym) match {
+      case None => throw new LogicError("Unknown symbol "+sym.toString)
+      case Some(info) => info.getName
+    }
+}
 
 class FortranGenerator {
-  val buffer = new ArrayBuffer[String]
+  var indentLevel = 0
+  val symbolManager = new SymbolManager
+  val buffer = scala.collection.mutable.Buffer[String]()
 
   def processStatement(stat: Statement) : String = {
     stat match {
@@ -11,12 +54,102 @@ class FortranGenerator {
       case (x : Comment) => addLine("!" + x.getValue)
       case (x : BlockStatement) => processScope(x)
       case (x : ProducerStatement) => processStatement(x.toConcrete)
+      case (x : ForLoop) => processForLoop(x)
       case x => throw new UnimplementedException("Unknown statement type in FORTRAN generator: " + x.toString)
     }
 
     buffer.mkString("\n")
   }
 
+  private def in() {
+    indentLevel += 1
+  }
+
+  private def out() {
+    indentLevel -=1
+    if (indentLevel < 0) throw new LogicError("Indentation level dropped below 0 in FORTRAN generator.")
+  }
+
+  private def buildExpression(expression: Expression[_]) : String = {
+    expression match {
+      case (i : IntegerLiteral) => i.getValue.toString
+      case (a : FieldAccess[_]) => "%s%%%s".format(buildExpression(a.getStructExpression), a.getField.getName)
+      case (r : VarRef[_]) => r.getSymbol match {
+        case (s: NamedUnboundVarSymbol[_]) => s.getName
+        case s => symbolManager.getName(s)
+      }
+      case (r: ArrayRead[_]) => 
+        buildExpression(r.getArrayExpression) + r.getIndexExpressions.map(buildExpression(_)).mkString("(",", ",")")
+      case (d: PointerDereference[_]) => buildExpression(d.getExpression)
+      case (c: ConditionalValue[_]) => buildConditionalValue(c)
+      case (c: NumericComparison[_]) => buildNumericComparison(c)
+      case (c: NumericOperator[_]) => buildNumericOperator(c)
+      case x => throw new UnimplementedException("Unknown expression type in FORTRAN generator: " + x.toString)
+    }
+  }
+
+  private def buildConditionalValue(conditional: ConditionalValue[_]) : String = {
+    var symbol = new DeclaredVarSymbol[Type]("ternary")
+    symbolManager.addSymbol(symbol)
+    val name = symbolManager.getName(symbol)
+    addLine("if (%s) then".format(buildExpression(conditional.getPredicate)))
+    in
+    addLine("%s = %s".format(name, buildExpression(conditional.getIfTrue)))
+    out
+    addLine("else")
+    in
+    addLine("%s = %s".format(name, buildExpression(conditional.getIfFalse)))
+    out
+    addLine("endif")
+
+    name
+  }
+
+  private def buildNumericComparison(comparison: NumericComparison[_]) : String = {
+    import NumericOperations._
+    val opString = comparison.getOperation match {
+      case LT => ".lt."
+      case LE => ".le."
+      case EQ => ".eq."
+      case NE => ".ne."
+      case GT => ".gt."
+      case GE => ".ge."
+      case x => throw new UnimplementedException("Unknown comparison type in FORTRAN generator: "+x.toString)
+    }
+
+    buildExpression(comparison.getLeft) + opString + buildExpression(comparison.getRight)
+  }
+
+  private def buildNumericOperator(comparison: NumericOperator[_]) : String = {
+    import NumericOperations._
+    val opString = comparison.getOperation match {
+      case Add => "+"
+      case Sub => "-"
+      case Mul => "*"
+      case Div => "/"
+      case Mod => return "mod(%s, %s)".format(buildExpression(comparison.getLeft), buildExpression(comparison.getRight))
+      case x => throw new UnimplementedException("Unknown numeric operator in FORTRAN generator: "+x.toString)
+    }
+
+    buildExpression(comparison.getLeft) + opString + buildExpression(comparison.getRight)
+  }
+
+  private def processForLoop(stat: ForLoop) {
+    val index = stat.getIndex
+    symbolManager.addSymbol(index)
+    val name = symbolManager.getName(index)
+    val begin = buildExpression(stat.getBegin)
+    val end = buildExpression(stat.getEnd)
+
+    val header = "do %s = %s, %s".format(name, begin, end)
+    val footer = "end do"
+    addLine(header)
+    in
+    processScope(stat)
+    out
+    addLine(footer)
+  }
+
   private def processScope(scope: ScopeStatement) {
     for(stat <- scope.getStatements) {
       processStatement(stat)
@@ -24,6 +157,6 @@ class FortranGenerator {
   }
 
   private def addLine(line: String) {
-    buffer += line
+    buffer += "  "*indentLevel + line
   }
 }
diff --git a/src/ofc/codegen/IfStatement.scala b/src/ofc/codegen/IfStatement.scala
new file mode 100644 (file)
index 0000000..84ae5b0
--- /dev/null
@@ -0,0 +1,3 @@
+package ofc.codegen
+
+class IfStatement(predicate: Expression[BoolType]) extends ScopeStatement
index ef23db1675a056af0289f11f3ddb041d09792d18..d119783beeb5b5e7b9241f4508f4cc22f8197a70 100644 (file)
@@ -10,17 +10,25 @@ object NumericOperations {
 
   sealed abstract class CompareOp
   object LT  extends CompareOp
-  object LTE extends CompareOp
+  object LE extends CompareOp
   object EQ  extends CompareOp
   object NE  extends CompareOp
   object GT  extends CompareOp
-  object GTE extends CompareOp
+  object GE extends CompareOp
 }
 
 class NumericOperator[T <: Type](op: NumericOperations.FieldOp, left: Expression[T], right: Expression[T]) extends Expression[T] {
+  // TODO: Type check operators
   def foreach[U](f: Expression[_] => U) = List(left, right).foreach(f)
+  def getOperation = op
+  def getLeft = left
+  def getRight = right
 }
 
 class NumericComparison[T <: Type](op: NumericOperations.CompareOp, left: Expression[T], right: Expression[T]) extends Expression[BoolType] {
+  // TODO: Type check operators
   def foreach[U](f: Expression[_] => U) = List(left, right).foreach(f)
+  def getOperation = op
+  def getLeft = left
+  def getRight = right
 }
index c79e169ad06f1147885fe980d196efa0c3355af0..c9d3acf0eb06390c6854678531fc63d11ef805ac 100644 (file)
@@ -14,7 +14,7 @@ class ProducerStatement extends Statement {
     }
   }
 
-  trait Context {
+  sealed trait Context {
     def defines : Set[VarSymbol[_]]
     def depends : Set[VarSymbol[_]]
     def tryCompare(other: Context) : Option[Int] = {
@@ -33,19 +33,19 @@ class ProducerStatement extends Statement {
     }
   }
 
-  class VariableRange(symbol: VarSymbol[_], first: Expression[IntType], last: Expression[IntType]) extends Context {
+  case class VariableRange(symbol: VarSymbol[IntType], first: Expression[IntType], last: Expression[IntType]) extends Context {
     override def toString = "VariableRange("+symbol.toString+", "+first.toString+", "+last.toString+")"
     def defines = Set(symbol)
     def depends = Expression.findReferencedVariables(List(first, last))
   }
 
-  class Predicate(expression: Expression[BoolType]) extends Context {
+  case class Predicate(expression: Expression[BoolType]) extends Context {
     override def toString = "Predicate("+expression.toString+")"
     def defines = Set.empty
     def depends = Expression.findReferencedVariables(expression)
   }
 
-  class DerivedExpression(symbol: VarSymbol[_], expression: Expression[_]) extends Context {
+  case class DerivedExpression(symbol: VarSymbol[_ <: Type], expression: Expression[_ <: Type]) extends Context {
     override def toString = "DerivedExpression("+symbol.toString + " <- " + expression.toString+")"
     def defines = Set(symbol)
     def depends = Expression.findReferencedVariables(expression)
@@ -69,12 +69,31 @@ class ProducerStatement extends Statement {
   }
 
   def toConcrete : Statement = {
-    val contexts : Seq[Context] = ranges ++ predicates ++ expressions
+    val contexts = ranges ++ predicates ++ expressions
     val sortedContexts = Context.sort(contexts)
 
-    for(c <- sortedContexts)
-     println(c.toString)
+    val block = new BlockStatement
+    var scope : ScopeStatement = block
 
-    new Comment("Producer flattening unimplemented.")
+    for (context <- sortedContexts) {
+      context match {
+        case VariableRange(sym, first, last) => {
+          val loop = new ForLoop(sym, first, last)
+          scope += loop
+          scope = loop
+        }
+        case Predicate(expression) => {
+          val ifStat = new IfStatement(expression)
+          scope += ifStat
+          scope = ifStat
+        }
+        case DerivedExpression(sym, expression) => {
+          val assignment = new Assignment(sym, expression)
+          scope += assignment
+        }
+      }
+    }
+
+    block
   }
 }
index a7c770cfccd665239395bc8ebfee5f4e2bf0a78e..6d463c6eda2b2db16b44fe95837d7e1145c8b3c0 100644 (file)
@@ -8,7 +8,7 @@ case class FieldSymbol[T <: Type](name: String) extends Symbol {
   def getName = name
 }
 
-abstract class VarSymbol[T <: Type](name: String) extends Symbol {
+sealed abstract class VarSymbol[T <: Type](name: String) extends Symbol {
   def getName = name
 }
 
@@ -16,6 +16,5 @@ object VarSymbol {
   implicit def toRef[T <: Type](symbol: VarSymbol[T]) = new VarRef[T](symbol)
 }
 
-case class DeclaredVarSymbol[T <: Type](name: String) extends VarSymbol[T](name)
-abstract class UnboundVarSymbol[T <: Type](name: String) extends VarSymbol[T](name)
-case class NamedUnboundVarSymbol[T <: Type](name: String) extends UnboundVarSymbol[T](name)
+class DeclaredVarSymbol[T <: Type](name: String) extends VarSymbol[T](name)
+class NamedUnboundVarSymbol[T <: Type](name: String) extends VarSymbol[T](name)
index 622308a1205c77c031f7f1cccfadeedbdd619155..e4a8f50e01d1a1fa1502ad68b9f86e350e0d15e6 100644 (file)
@@ -2,9 +2,9 @@ package ofc.generators.onetep
 import ofc.codegen._
 
 class PPDFunctionSet(val basisName: String, dataName: String) extends FunctionSet {
-  val basis = NamedUnboundVarSymbol[StructType](basisName)
-  val data =  NamedUnboundVarSymbol[ArrayType[FloatType]](dataName)
-  val pubCell = NamedUnboundVarSymbol[StructType]("pub_cell")
+  val basis = new NamedUnboundVarSymbol[StructType](basisName)
+  val data =  new NamedUnboundVarSymbol[ArrayType[FloatType]](dataName)
+  val pubCell = new NamedUnboundVarSymbol[StructType]("pub_cell")
 
   val numSpheres = basis % FieldSymbol[IntType]("num");
   val ppdWidths = for(dim <- 1 to 3) yield pubCell % FieldSymbol[IntType]("n_pt"+dim)