]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Compute dependencies between constructs in ProducerStatement.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Fri, 6 Apr 2012 17:24:59 +0000 (18:24 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Fri, 6 Apr 2012 17:43:16 +0000 (18:43 +0100)
src/ofc/codegen/ConditionalValue.scala
src/ofc/codegen/Expression.scala
src/ofc/codegen/NumericOperator.scala
src/ofc/codegen/ProducerStatement.scala
src/ofc/util/Ordering.scala [new file with mode: 0644]

index fc07f8e0ecd039fb5394a899627776ebcd04e0d3..4e39ada6d2fa658e1bc8b32f554e03ebe889a3ef 100644 (file)
@@ -1,3 +1,5 @@
 package ofc.codegen
 
-class ConditionalValue[T <: Type](predicate: Expression[BoolType], ifTrue: Expression[T], ifFalse: Expression[T]) extends Expression[T]
+class ConditionalValue[T <: Type](predicate: Expression[BoolType], ifTrue: Expression[T], ifFalse: Expression[T]) extends Expression[T] {
+  def foreach[U](f: Expression[_] => U) = List(predicate, ifTrue, ifFalse).foreach(f)
+}
index 71b863353ba36fb3a94555c81b87cade7c03a189..0279a9692d76be4054f26be5727fa1757617ea43 100644 (file)
@@ -18,11 +18,28 @@ class ArrayType[ElementType <: Type] extends Type
 class PointerType[TargetType <: Type] extends Type
 abstract class StructType extends Type
 
+trait LeafExpression {
+  def foreach[U](f: Expression[_] => U): Unit = ()
+}
+
 object Expression {
   implicit def fromInt(i: Int) : Expression[IntType] = new IntegerLiteral(i)
+
+  def findReferencedVariables(expression: Expression[_]) : Set[VarSymbol[_]] =
+    findReferencedVariables(Some[Expression[_]](expression))
+
+  def findReferencedVariables(expressions: Traversable[Expression[_]]) : Set[VarSymbol[_]] = {
+    val vars = scala.collection.mutable.Set[VarSymbol[_]]()
+    expressions.foreach(_ match {
+      case (ref: VarRef[_]) => vars += ref.getSymbol
+      case x => vars ++= findReferencedVariables(x.toTraversable)
+    })
+
+    vars.toSet
+  }
 }
 
-class Expression[T <: Type] {
+abstract class Expression[T <: Type] extends Traversable[Expression[_]] {
   // Field Operations
   def +(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] =
     new NumericOperator[T](NumericOperations.Add, this, rhs)
@@ -69,12 +86,21 @@ class Expression[T <: Type] {
 }
 
 // Variable references
-class VarRef[T <: Type](symbol: VarSymbol[T]) extends Expression[T]
+class VarRef[T <: Type](symbol: VarSymbol[T]) extends Expression[T] with LeafExpression {
+  def getSymbol = symbol
+}
 
 // Struct and array accesses
-class FieldAccess[T <: Type](expression: Expression[StructType], field: FieldSymbol[T]) extends Expression[T]
-class ArrayRead[E <: Type](expression: Expression[ArrayType[E]], index: Seq[Expression[IntType]]) extends Expression[E]
-class PointerDereference[E <: Type](expression: Expression[PointerType[E]]) extends Expression[E]
+class FieldAccess[T <: Type](expression: Expression[StructType], field: FieldSymbol[T]) extends Expression[T] {
+  def foreach[U](f: Expression[_] => U) = f(expression)
+}
+
+class ArrayRead[E <: Type](expression: Expression[ArrayType[E]], index: Seq[Expression[IntType]]) extends Expression[E] {
+  def foreach[U](f: Expression[_] => U) = (index :+ expression).foreach(f)
+}
+class PointerDereference[E <: Type](expression: Expression[PointerType[E]]) extends Expression[E] {
+  def foreach[U](f: Expression[_] => U) = f(expression)
+}
 
 // Literals
-class IntegerLiteral(value: Int) extends Expression[IntType]
+class IntegerLiteral(value: Int) extends Expression[IntType] with LeafExpression
index 8225cf705dbf3f93dda2443a8ceb75a46971d5dc..ef23db1675a056af0289f11f3ddb041d09792d18 100644 (file)
@@ -17,5 +17,10 @@ object NumericOperations {
   object GTE extends CompareOp
 }
 
-class NumericOperator[T <: Type](op: NumericOperations.FieldOp, left: Expression[T], right: Expression[T]) extends Expression[T]
-class NumericComparison[T <: Type](op: NumericOperations.CompareOp, left: Expression[T], right: Expression[T]) extends Expression[BoolType]
+class NumericOperator[T <: Type](op: NumericOperations.FieldOp, left: Expression[T], right: Expression[T]) extends Expression[T] {
+  def foreach[U](f: Expression[_] => U) = List(left, right).foreach(f)
+}
+
+class NumericComparison[T <: Type](op: NumericOperations.CompareOp, left: Expression[T], right: Expression[T]) extends Expression[BoolType] {
+  def foreach[U](f: Expression[_] => U) = List(left, right).foreach(f)
+}
index 35687bb8901a7f8887c748f2e0f684be6670d9d3..52cbba321a9eaaab1045a7ee875b8c092d2a45fe 100644 (file)
@@ -1,17 +1,71 @@
 package ofc.codegen
+import ofc.util.Ordering
 
 class ProducerStatement extends Statement {
-  case class VariableRange(symbol: Symbol, first: Expression[IntType], last: Expression[IntType])
-  case class Predicate(expression: Expression[BoolType])
+  object Context {
+    def sort(contexts: Seq[Context]) : Seq[Context] = {
+      val ordering = scala.collection.mutable.Set[(Context, Context)]()
+      for(c1 <- contexts; c2 <- contexts)
+        (c1.tryCompare(c2)) match {
+          case Some(x) => {
+            if (x<0) 
+              ordering += (c1 -> c2)
+            else if (x>0)
+              ordering += (c2 -> c1)
+          }
+          case None => ()
+        }
+
+      val totalOrdering = Ordering.transitiveClosure(ordering.toSet)
+      contexts.sortWith((a,b) => totalOrdering.contains(a,b))
+    }
+  }
+
+  trait Context {
+    def defines : Set[VarSymbol[_]]
+    def depends : Set[VarSymbol[_]]
+    def tryCompare(other: Context) : Option[Int] = {
+      val isChild = other.defines & depends
+      val isParent = defines & other.depends
+      assert(isChild.isEmpty || isParent.isEmpty)
+
+      if (this == other)
+        Some(0)
+      else if (isParent.nonEmpty)
+        Some(-1)
+      else if (isChild.nonEmpty)
+        Some(1)
+      else
+        None
+    }
+  }
+
+  class VariableRange(symbol: VarSymbol[_], first: Expression[IntType], last: Expression[IntType]) extends Context {
+    override def toString = "VariableRange("+symbol.toString+", "+first.toString+", "+last.toString+")"
+    def defines = Set(symbol)
+    def depends = Expression.findReferencedVariables(List(first, last))
+  }
+
+  class Predicate(expression: Expression[BoolType]) extends Context {
+    override def toString = "Predicate("+expression.toString+")"
+    def defines = Set.empty
+    def depends = Expression.findReferencedVariables(expression)
+  }
+
+  class DerivedExpression(symbol: VarSymbol[_], expression: Expression[_]) extends Context {
+    override def toString = "DerivedExpression("+symbol.toString + " <- " + expression.toString+")"
+    def defines = Set(symbol)
+    def depends = Expression.findReferencedVariables(expression)
+  }
 
   var statement = new Comment("Placeholder statement for consumer.")
   var ranges : Seq[VariableRange] = Nil
   var predicates : Seq[Predicate] = Nil
-  var expressions : Map[VarSymbol[_], Expression[_]] = Map.empty
+  var expressions : Seq[DerivedExpression] = Nil
 
   def addExpression[T <: Type](name: String, expression: Expression[T]) : VarSymbol[T] = {
     val symbol = new DeclaredVarSymbol[T](name)
-    expressions += symbol -> expression
+    expressions +:= new DerivedExpression(symbol, expression)
     symbol
   }
 
@@ -21,5 +75,13 @@ class ProducerStatement extends Statement {
     symbol
   }
 
-  def toConcrete = new Comment("Producer flattening unimplemented.")
+  def toConcrete : Statement = {
+    val contexts : Seq[Context] = ranges ++ predicates ++ expressions
+    val sortedContexts = Context.sort(contexts)
+
+    for(c <- sortedContexts)
+     println(c.toString)
+
+    new Comment("Producer flattening unimplemented.")
+  }
 }
diff --git a/src/ofc/util/Ordering.scala b/src/ofc/util/Ordering.scala
new file mode 100644 (file)
index 0000000..bd6fd46
--- /dev/null
@@ -0,0 +1,26 @@
+package ofc.util
+import scala.annotation.tailrec
+
+object Ordering {
+  
+  @tailrec
+  def transitiveClosure[T](ordering: Set[(T, T)]) : Set[(T, T)] = {
+    def step[T](ordering: Set[(T, T)]) = {
+      val newOrdering = scala.collection.mutable.Set[(T, T)]()
+      newOrdering ++= ordering
+
+      for ((a1,a2) <- ordering; (b1, b2) <- ordering; if a2 == b1)
+        newOrdering += (a1 -> b2)
+
+      assert(newOrdering.size >= ordering.size)
+      newOrdering.toSet
+    }
+
+    val stepped = step(ordering)
+
+    if (stepped.size == ordering.size)
+      ordering
+    else
+      transitiveClosure(stepped)
+  }
+}