package ofc.codegen
-class ConditionalValue[T <: Type](predicate: Expression[BoolType], ifTrue: Expression[T], ifFalse: Expression[T]) extends Expression[T]
+class ConditionalValue[T <: Type](predicate: Expression[BoolType], ifTrue: Expression[T], ifFalse: Expression[T]) extends Expression[T] {
+ def foreach[U](f: Expression[_] => U) = List(predicate, ifTrue, ifFalse).foreach(f)
+}
class PointerType[TargetType <: Type] extends Type
abstract class StructType extends Type
+trait LeafExpression {
+ def foreach[U](f: Expression[_] => U): Unit = ()
+}
+
object Expression {
implicit def fromInt(i: Int) : Expression[IntType] = new IntegerLiteral(i)
+
+ def findReferencedVariables(expression: Expression[_]) : Set[VarSymbol[_]] =
+ findReferencedVariables(Some[Expression[_]](expression))
+
+ def findReferencedVariables(expressions: Traversable[Expression[_]]) : Set[VarSymbol[_]] = {
+ val vars = scala.collection.mutable.Set[VarSymbol[_]]()
+ expressions.foreach(_ match {
+ case (ref: VarRef[_]) => vars += ref.getSymbol
+ case x => vars ++= findReferencedVariables(x.toTraversable)
+ })
+
+ vars.toSet
+ }
}
-class Expression[T <: Type] {
+abstract class Expression[T <: Type] extends Traversable[Expression[_]] {
// Field Operations
def +(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] =
new NumericOperator[T](NumericOperations.Add, this, rhs)
}
// Variable references
-class VarRef[T <: Type](symbol: VarSymbol[T]) extends Expression[T]
+class VarRef[T <: Type](symbol: VarSymbol[T]) extends Expression[T] with LeafExpression {
+ def getSymbol = symbol
+}
// Struct and array accesses
-class FieldAccess[T <: Type](expression: Expression[StructType], field: FieldSymbol[T]) extends Expression[T]
-class ArrayRead[E <: Type](expression: Expression[ArrayType[E]], index: Seq[Expression[IntType]]) extends Expression[E]
-class PointerDereference[E <: Type](expression: Expression[PointerType[E]]) extends Expression[E]
+class FieldAccess[T <: Type](expression: Expression[StructType], field: FieldSymbol[T]) extends Expression[T] {
+ def foreach[U](f: Expression[_] => U) = f(expression)
+}
+
+class ArrayRead[E <: Type](expression: Expression[ArrayType[E]], index: Seq[Expression[IntType]]) extends Expression[E] {
+ def foreach[U](f: Expression[_] => U) = (index :+ expression).foreach(f)
+}
+class PointerDereference[E <: Type](expression: Expression[PointerType[E]]) extends Expression[E] {
+ def foreach[U](f: Expression[_] => U) = f(expression)
+}
// Literals
-class IntegerLiteral(value: Int) extends Expression[IntType]
+class IntegerLiteral(value: Int) extends Expression[IntType] with LeafExpression
object GTE extends CompareOp
}
-class NumericOperator[T <: Type](op: NumericOperations.FieldOp, left: Expression[T], right: Expression[T]) extends Expression[T]
-class NumericComparison[T <: Type](op: NumericOperations.CompareOp, left: Expression[T], right: Expression[T]) extends Expression[BoolType]
+class NumericOperator[T <: Type](op: NumericOperations.FieldOp, left: Expression[T], right: Expression[T]) extends Expression[T] {
+ def foreach[U](f: Expression[_] => U) = List(left, right).foreach(f)
+}
+
+class NumericComparison[T <: Type](op: NumericOperations.CompareOp, left: Expression[T], right: Expression[T]) extends Expression[BoolType] {
+ def foreach[U](f: Expression[_] => U) = List(left, right).foreach(f)
+}
package ofc.codegen
+import ofc.util.Ordering
class ProducerStatement extends Statement {
- case class VariableRange(symbol: Symbol, first: Expression[IntType], last: Expression[IntType])
- case class Predicate(expression: Expression[BoolType])
+ object Context {
+ def sort(contexts: Seq[Context]) : Seq[Context] = {
+ val ordering = scala.collection.mutable.Set[(Context, Context)]()
+ for(c1 <- contexts; c2 <- contexts)
+ (c1.tryCompare(c2)) match {
+ case Some(x) => {
+ if (x<0)
+ ordering += (c1 -> c2)
+ else if (x>0)
+ ordering += (c2 -> c1)
+ }
+ case None => ()
+ }
+
+ val totalOrdering = Ordering.transitiveClosure(ordering.toSet)
+ contexts.sortWith((a,b) => totalOrdering.contains(a,b))
+ }
+ }
+
+ trait Context {
+ def defines : Set[VarSymbol[_]]
+ def depends : Set[VarSymbol[_]]
+ def tryCompare(other: Context) : Option[Int] = {
+ val isChild = other.defines & depends
+ val isParent = defines & other.depends
+ assert(isChild.isEmpty || isParent.isEmpty)
+
+ if (this == other)
+ Some(0)
+ else if (isParent.nonEmpty)
+ Some(-1)
+ else if (isChild.nonEmpty)
+ Some(1)
+ else
+ None
+ }
+ }
+
+ class VariableRange(symbol: VarSymbol[_], first: Expression[IntType], last: Expression[IntType]) extends Context {
+ override def toString = "VariableRange("+symbol.toString+", "+first.toString+", "+last.toString+")"
+ def defines = Set(symbol)
+ def depends = Expression.findReferencedVariables(List(first, last))
+ }
+
+ class Predicate(expression: Expression[BoolType]) extends Context {
+ override def toString = "Predicate("+expression.toString+")"
+ def defines = Set.empty
+ def depends = Expression.findReferencedVariables(expression)
+ }
+
+ class DerivedExpression(symbol: VarSymbol[_], expression: Expression[_]) extends Context {
+ override def toString = "DerivedExpression("+symbol.toString + " <- " + expression.toString+")"
+ def defines = Set(symbol)
+ def depends = Expression.findReferencedVariables(expression)
+ }
var statement = new Comment("Placeholder statement for consumer.")
var ranges : Seq[VariableRange] = Nil
var predicates : Seq[Predicate] = Nil
- var expressions : Map[VarSymbol[_], Expression[_]] = Map.empty
+ var expressions : Seq[DerivedExpression] = Nil
def addExpression[T <: Type](name: String, expression: Expression[T]) : VarSymbol[T] = {
val symbol = new DeclaredVarSymbol[T](name)
- expressions += symbol -> expression
+ expressions +:= new DerivedExpression(symbol, expression)
symbol
}
symbol
}
- def toConcrete = new Comment("Producer flattening unimplemented.")
+ def toConcrete : Statement = {
+ val contexts : Seq[Context] = ranges ++ predicates ++ expressions
+ val sortedContexts = Context.sort(contexts)
+
+ for(c <- sortedContexts)
+ println(c.toString)
+
+ new Comment("Producer flattening unimplemented.")
+ }
}
--- /dev/null
+package ofc.util
+import scala.annotation.tailrec
+
+object Ordering {
+
+ @tailrec
+ def transitiveClosure[T](ordering: Set[(T, T)]) : Set[(T, T)] = {
+ def step[T](ordering: Set[(T, T)]) = {
+ val newOrdering = scala.collection.mutable.Set[(T, T)]()
+ newOrdering ++= ordering
+
+ for ((a1,a2) <- ordering; (b1, b2) <- ordering; if a2 == b1)
+ newOrdering += (a1 -> b2)
+
+ assert(newOrdering.size >= ordering.size)
+ newOrdering.toSet
+ }
+
+ val stepped = step(ordering)
+
+ if (stepped.size == ordering.size)
+ ordering
+ else
+ transitiveClosure(stepped)
+ }
+}