From ef0482433b99247754844b2a474eb3dfd81f9e5d Mon Sep 17 00:00:00 2001 From: Francis Russell Date: Fri, 6 Apr 2012 23:57:21 +0100 Subject: [PATCH] Generate basic loop structures. --- src/ofc/codegen/Assignment.scala | 5 + src/ofc/codegen/ConditionalValue.scala | 3 + src/ofc/codegen/Expression.scala | 13 +- src/ofc/codegen/ForLoop.scala | 7 + src/ofc/codegen/FortranGenerator.scala | 141 +++++++++++++++++- src/ofc/codegen/IfStatement.scala | 3 + src/ofc/codegen/NumericOperator.scala | 12 +- src/ofc/codegen/ProducerStatement.scala | 35 ++++- src/ofc/codegen/Symbol.scala | 7 +- .../generators/onetep/PPDFunctionSet.scala | 6 +- 10 files changed, 208 insertions(+), 24 deletions(-) create mode 100644 src/ofc/codegen/Assignment.scala create mode 100644 src/ofc/codegen/ForLoop.scala create mode 100644 src/ofc/codegen/IfStatement.scala diff --git a/src/ofc/codegen/Assignment.scala b/src/ofc/codegen/Assignment.scala new file mode 100644 index 0000000..f729988 --- /dev/null +++ b/src/ofc/codegen/Assignment.scala @@ -0,0 +1,5 @@ +package ofc.codegen + +class Assignment(symbol: VarSymbol[_ <: Type], expression: Expression[_ <: Type]) extends Statement { + // TODO: type check assignment +} diff --git a/src/ofc/codegen/ConditionalValue.scala b/src/ofc/codegen/ConditionalValue.scala index 4e39ada..f35c3d3 100644 --- a/src/ofc/codegen/ConditionalValue.scala +++ b/src/ofc/codegen/ConditionalValue.scala @@ -2,4 +2,7 @@ package ofc.codegen class ConditionalValue[T <: Type](predicate: Expression[BoolType], ifTrue: Expression[T], ifFalse: Expression[T]) extends Expression[T] { def foreach[U](f: Expression[_] => U) = List(predicate, ifTrue, ifFalse).foreach(f) + def getPredicate = predicate + def getIfTrue = ifTrue + def getIfFalse = ifFalse } diff --git a/src/ofc/codegen/Expression.scala b/src/ofc/codegen/Expression.scala index 0279a96..a28f219 100644 --- a/src/ofc/codegen/Expression.scala +++ b/src/ofc/codegen/Expression.scala @@ -61,7 +61,7 @@ abstract class Expression[T <: Type] extends Traversable[Expression[_]] { new NumericComparison[T](NumericOperations.LT, this, rhs) def |<=|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] = - new NumericComparison[T](NumericOperations.LTE, this, rhs) + new NumericComparison[T](NumericOperations.LE, this, rhs) def |==|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] = new NumericComparison[T](NumericOperations.EQ, this, rhs) @@ -73,7 +73,7 @@ abstract class Expression[T <: Type] extends Traversable[Expression[_]] { new NumericComparison[T](NumericOperations.GT, this, rhs) def |>=|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] = - new NumericComparison[T](NumericOperations.GTE, this, rhs) + new NumericComparison[T](NumericOperations.GE, this, rhs) def %[FieldType <: Type](field: FieldSymbol[FieldType])(implicit witness: <:<[this.type, Expression[StructType]]) : Expression[FieldType] = new FieldAccess[FieldType](witness(this), field) @@ -93,14 +93,21 @@ class VarRef[T <: Type](symbol: VarSymbol[T]) extends Expression[T] with LeafExp // Struct and array accesses class FieldAccess[T <: Type](expression: Expression[StructType], field: FieldSymbol[T]) extends Expression[T] { def foreach[U](f: Expression[_] => U) = f(expression) + def getStructExpression = expression + def getField = field } class ArrayRead[E <: Type](expression: Expression[ArrayType[E]], index: Seq[Expression[IntType]]) extends Expression[E] { def foreach[U](f: Expression[_] => U) = (index :+ expression).foreach(f) + def getArrayExpression = expression + def getIndexExpressions = index } class PointerDereference[E <: Type](expression: Expression[PointerType[E]]) extends Expression[E] { def foreach[U](f: Expression[_] => U) = f(expression) + def getExpression = expression } // Literals -class IntegerLiteral(value: Int) extends Expression[IntType] with LeafExpression +class IntegerLiteral(value: Int) extends Expression[IntType] with LeafExpression { + def getValue = value +} diff --git a/src/ofc/codegen/ForLoop.scala b/src/ofc/codegen/ForLoop.scala new file mode 100644 index 0000000..1468bed --- /dev/null +++ b/src/ofc/codegen/ForLoop.scala @@ -0,0 +1,7 @@ +package ofc.codegen + +class ForLoop(index: VarSymbol[IntType], begin: Expression[IntType], end: Expression[IntType]) extends ScopeStatement { + def getIndex = index + def getBegin = begin + def getEnd = end +} diff --git a/src/ofc/codegen/FortranGenerator.scala b/src/ofc/codegen/FortranGenerator.scala index 69a9aec..fdc6e89 100644 --- a/src/ofc/codegen/FortranGenerator.scala +++ b/src/ofc/codegen/FortranGenerator.scala @@ -1,9 +1,52 @@ package ofc.codegen -import ofc.UnimplementedException -import scala.collection.mutable.ArrayBuffer +import scala.annotation.tailrec +import ofc.{UnimplementedException,LogicError} + +class SymbolManager { + import scala.collection.mutable + + class SymbolInfo(name: String) { + def getName = name + } + + val symbols = mutable.Map[VarSymbol[_], SymbolInfo]() + val names = mutable.Set[String]() + + private def createNewName(sym: VarSymbol[_]) : String = { + @tailrec + def helper(sym: VarSymbol[_], suffix: Int) : String = { + val candidate = sym.getName + "_" + suffix + if (names.contains(candidate)) + helper(sym, suffix + 1) + else + candidate + } + + helper(sym, 1) + } + + def addSymbol(sym: VarSymbol[_ <: Type]) { + sym match { + case (s: DeclaredVarSymbol[_]) => if (!symbols.contains(s)) { + val name = createNewName(s) + names += name + symbols += s -> new SymbolInfo(name) + } + case (_: NamedUnboundVarSymbol[_]) => throw new LogicError("Attempted to add unbound symbol to SymbolManager.") + } + } + + def getName(sym: VarSymbol[_]) = + symbols.get(sym) match { + case None => throw new LogicError("Unknown symbol "+sym.toString) + case Some(info) => info.getName + } +} class FortranGenerator { - val buffer = new ArrayBuffer[String] + var indentLevel = 0 + val symbolManager = new SymbolManager + val buffer = scala.collection.mutable.Buffer[String]() def processStatement(stat: Statement) : String = { stat match { @@ -11,12 +54,102 @@ class FortranGenerator { case (x : Comment) => addLine("!" + x.getValue) case (x : BlockStatement) => processScope(x) case (x : ProducerStatement) => processStatement(x.toConcrete) + case (x : ForLoop) => processForLoop(x) case x => throw new UnimplementedException("Unknown statement type in FORTRAN generator: " + x.toString) } buffer.mkString("\n") } + private def in() { + indentLevel += 1 + } + + private def out() { + indentLevel -=1 + if (indentLevel < 0) throw new LogicError("Indentation level dropped below 0 in FORTRAN generator.") + } + + private def buildExpression(expression: Expression[_]) : String = { + expression match { + case (i : IntegerLiteral) => i.getValue.toString + case (a : FieldAccess[_]) => "%s%%%s".format(buildExpression(a.getStructExpression), a.getField.getName) + case (r : VarRef[_]) => r.getSymbol match { + case (s: NamedUnboundVarSymbol[_]) => s.getName + case s => symbolManager.getName(s) + } + case (r: ArrayRead[_]) => + buildExpression(r.getArrayExpression) + r.getIndexExpressions.map(buildExpression(_)).mkString("(",", ",")") + case (d: PointerDereference[_]) => buildExpression(d.getExpression) + case (c: ConditionalValue[_]) => buildConditionalValue(c) + case (c: NumericComparison[_]) => buildNumericComparison(c) + case (c: NumericOperator[_]) => buildNumericOperator(c) + case x => throw new UnimplementedException("Unknown expression type in FORTRAN generator: " + x.toString) + } + } + + private def buildConditionalValue(conditional: ConditionalValue[_]) : String = { + var symbol = new DeclaredVarSymbol[Type]("ternary") + symbolManager.addSymbol(symbol) + val name = symbolManager.getName(symbol) + addLine("if (%s) then".format(buildExpression(conditional.getPredicate))) + in + addLine("%s = %s".format(name, buildExpression(conditional.getIfTrue))) + out + addLine("else") + in + addLine("%s = %s".format(name, buildExpression(conditional.getIfFalse))) + out + addLine("endif") + + name + } + + private def buildNumericComparison(comparison: NumericComparison[_]) : String = { + import NumericOperations._ + val opString = comparison.getOperation match { + case LT => ".lt." + case LE => ".le." + case EQ => ".eq." + case NE => ".ne." + case GT => ".gt." + case GE => ".ge." + case x => throw new UnimplementedException("Unknown comparison type in FORTRAN generator: "+x.toString) + } + + buildExpression(comparison.getLeft) + opString + buildExpression(comparison.getRight) + } + + private def buildNumericOperator(comparison: NumericOperator[_]) : String = { + import NumericOperations._ + val opString = comparison.getOperation match { + case Add => "+" + case Sub => "-" + case Mul => "*" + case Div => "/" + case Mod => return "mod(%s, %s)".format(buildExpression(comparison.getLeft), buildExpression(comparison.getRight)) + case x => throw new UnimplementedException("Unknown numeric operator in FORTRAN generator: "+x.toString) + } + + buildExpression(comparison.getLeft) + opString + buildExpression(comparison.getRight) + } + + private def processForLoop(stat: ForLoop) { + val index = stat.getIndex + symbolManager.addSymbol(index) + val name = symbolManager.getName(index) + val begin = buildExpression(stat.getBegin) + val end = buildExpression(stat.getEnd) + + val header = "do %s = %s, %s".format(name, begin, end) + val footer = "end do" + addLine(header) + in + processScope(stat) + out + addLine(footer) + } + private def processScope(scope: ScopeStatement) { for(stat <- scope.getStatements) { processStatement(stat) @@ -24,6 +157,6 @@ class FortranGenerator { } private def addLine(line: String) { - buffer += line + buffer += " "*indentLevel + line } } diff --git a/src/ofc/codegen/IfStatement.scala b/src/ofc/codegen/IfStatement.scala new file mode 100644 index 0000000..84ae5b0 --- /dev/null +++ b/src/ofc/codegen/IfStatement.scala @@ -0,0 +1,3 @@ +package ofc.codegen + +class IfStatement(predicate: Expression[BoolType]) extends ScopeStatement diff --git a/src/ofc/codegen/NumericOperator.scala b/src/ofc/codegen/NumericOperator.scala index ef23db1..d119783 100644 --- a/src/ofc/codegen/NumericOperator.scala +++ b/src/ofc/codegen/NumericOperator.scala @@ -10,17 +10,25 @@ object NumericOperations { sealed abstract class CompareOp object LT extends CompareOp - object LTE extends CompareOp + object LE extends CompareOp object EQ extends CompareOp object NE extends CompareOp object GT extends CompareOp - object GTE extends CompareOp + object GE extends CompareOp } class NumericOperator[T <: Type](op: NumericOperations.FieldOp, left: Expression[T], right: Expression[T]) extends Expression[T] { + // TODO: Type check operators def foreach[U](f: Expression[_] => U) = List(left, right).foreach(f) + def getOperation = op + def getLeft = left + def getRight = right } class NumericComparison[T <: Type](op: NumericOperations.CompareOp, left: Expression[T], right: Expression[T]) extends Expression[BoolType] { + // TODO: Type check operators def foreach[U](f: Expression[_] => U) = List(left, right).foreach(f) + def getOperation = op + def getLeft = left + def getRight = right } diff --git a/src/ofc/codegen/ProducerStatement.scala b/src/ofc/codegen/ProducerStatement.scala index c79e169..c9d3acf 100644 --- a/src/ofc/codegen/ProducerStatement.scala +++ b/src/ofc/codegen/ProducerStatement.scala @@ -14,7 +14,7 @@ class ProducerStatement extends Statement { } } - trait Context { + sealed trait Context { def defines : Set[VarSymbol[_]] def depends : Set[VarSymbol[_]] def tryCompare(other: Context) : Option[Int] = { @@ -33,19 +33,19 @@ class ProducerStatement extends Statement { } } - class VariableRange(symbol: VarSymbol[_], first: Expression[IntType], last: Expression[IntType]) extends Context { + case class VariableRange(symbol: VarSymbol[IntType], first: Expression[IntType], last: Expression[IntType]) extends Context { override def toString = "VariableRange("+symbol.toString+", "+first.toString+", "+last.toString+")" def defines = Set(symbol) def depends = Expression.findReferencedVariables(List(first, last)) } - class Predicate(expression: Expression[BoolType]) extends Context { + case class Predicate(expression: Expression[BoolType]) extends Context { override def toString = "Predicate("+expression.toString+")" def defines = Set.empty def depends = Expression.findReferencedVariables(expression) } - class DerivedExpression(symbol: VarSymbol[_], expression: Expression[_]) extends Context { + case class DerivedExpression(symbol: VarSymbol[_ <: Type], expression: Expression[_ <: Type]) extends Context { override def toString = "DerivedExpression("+symbol.toString + " <- " + expression.toString+")" def defines = Set(symbol) def depends = Expression.findReferencedVariables(expression) @@ -69,12 +69,31 @@ class ProducerStatement extends Statement { } def toConcrete : Statement = { - val contexts : Seq[Context] = ranges ++ predicates ++ expressions + val contexts = ranges ++ predicates ++ expressions val sortedContexts = Context.sort(contexts) - for(c <- sortedContexts) - println(c.toString) + val block = new BlockStatement + var scope : ScopeStatement = block - new Comment("Producer flattening unimplemented.") + for (context <- sortedContexts) { + context match { + case VariableRange(sym, first, last) => { + val loop = new ForLoop(sym, first, last) + scope += loop + scope = loop + } + case Predicate(expression) => { + val ifStat = new IfStatement(expression) + scope += ifStat + scope = ifStat + } + case DerivedExpression(sym, expression) => { + val assignment = new Assignment(sym, expression) + scope += assignment + } + } + } + + block } } diff --git a/src/ofc/codegen/Symbol.scala b/src/ofc/codegen/Symbol.scala index a7c770c..6d463c6 100644 --- a/src/ofc/codegen/Symbol.scala +++ b/src/ofc/codegen/Symbol.scala @@ -8,7 +8,7 @@ case class FieldSymbol[T <: Type](name: String) extends Symbol { def getName = name } -abstract class VarSymbol[T <: Type](name: String) extends Symbol { +sealed abstract class VarSymbol[T <: Type](name: String) extends Symbol { def getName = name } @@ -16,6 +16,5 @@ object VarSymbol { implicit def toRef[T <: Type](symbol: VarSymbol[T]) = new VarRef[T](symbol) } -case class DeclaredVarSymbol[T <: Type](name: String) extends VarSymbol[T](name) -abstract class UnboundVarSymbol[T <: Type](name: String) extends VarSymbol[T](name) -case class NamedUnboundVarSymbol[T <: Type](name: String) extends UnboundVarSymbol[T](name) +class DeclaredVarSymbol[T <: Type](name: String) extends VarSymbol[T](name) +class NamedUnboundVarSymbol[T <: Type](name: String) extends VarSymbol[T](name) diff --git a/src/ofc/generators/onetep/PPDFunctionSet.scala b/src/ofc/generators/onetep/PPDFunctionSet.scala index 622308a..e4a8f50 100644 --- a/src/ofc/generators/onetep/PPDFunctionSet.scala +++ b/src/ofc/generators/onetep/PPDFunctionSet.scala @@ -2,9 +2,9 @@ package ofc.generators.onetep import ofc.codegen._ class PPDFunctionSet(val basisName: String, dataName: String) extends FunctionSet { - val basis = NamedUnboundVarSymbol[StructType](basisName) - val data = NamedUnboundVarSymbol[ArrayType[FloatType]](dataName) - val pubCell = NamedUnboundVarSymbol[StructType]("pub_cell") + val basis = new NamedUnboundVarSymbol[StructType](basisName) + val data = new NamedUnboundVarSymbol[ArrayType[FloatType]](dataName) + val pubCell = new NamedUnboundVarSymbol[StructType]("pub_cell") val numSpheres = basis % FieldSymbol[IntType]("num"); val ppdWidths = for(dim <- 1 to 3) yield pubCell % FieldSymbol[IntType]("n_pt"+dim) -- 2.47.3