From 4bb31141fed11102492fcfe2437a6b35794e6859 Mon Sep 17 00:00:00 2001 From: Francis Russell Date: Fri, 6 Apr 2012 18:24:59 +0100 Subject: [PATCH] Compute dependencies between constructs in ProducerStatement. --- src/ofc/codegen/ConditionalValue.scala | 4 +- src/ofc/codegen/Expression.scala | 38 ++++++++++--- src/ofc/codegen/NumericOperator.scala | 9 +++- src/ofc/codegen/ProducerStatement.scala | 72 +++++++++++++++++++++++-- src/ofc/util/Ordering.scala | 26 +++++++++ 5 files changed, 135 insertions(+), 14 deletions(-) create mode 100644 src/ofc/util/Ordering.scala diff --git a/src/ofc/codegen/ConditionalValue.scala b/src/ofc/codegen/ConditionalValue.scala index fc07f8e..4e39ada 100644 --- a/src/ofc/codegen/ConditionalValue.scala +++ b/src/ofc/codegen/ConditionalValue.scala @@ -1,3 +1,5 @@ package ofc.codegen -class ConditionalValue[T <: Type](predicate: Expression[BoolType], ifTrue: Expression[T], ifFalse: Expression[T]) extends Expression[T] +class ConditionalValue[T <: Type](predicate: Expression[BoolType], ifTrue: Expression[T], ifFalse: Expression[T]) extends Expression[T] { + def foreach[U](f: Expression[_] => U) = List(predicate, ifTrue, ifFalse).foreach(f) +} diff --git a/src/ofc/codegen/Expression.scala b/src/ofc/codegen/Expression.scala index 71b8633..0279a96 100644 --- a/src/ofc/codegen/Expression.scala +++ b/src/ofc/codegen/Expression.scala @@ -18,11 +18,28 @@ class ArrayType[ElementType <: Type] extends Type class PointerType[TargetType <: Type] extends Type abstract class StructType extends Type +trait LeafExpression { + def foreach[U](f: Expression[_] => U): Unit = () +} + object Expression { implicit def fromInt(i: Int) : Expression[IntType] = new IntegerLiteral(i) + + def findReferencedVariables(expression: Expression[_]) : Set[VarSymbol[_]] = + findReferencedVariables(Some[Expression[_]](expression)) + + def findReferencedVariables(expressions: Traversable[Expression[_]]) : Set[VarSymbol[_]] = { + val vars = scala.collection.mutable.Set[VarSymbol[_]]() + expressions.foreach(_ match { + case (ref: VarRef[_]) => vars += ref.getSymbol + case x => vars ++= findReferencedVariables(x.toTraversable) + }) + + vars.toSet + } } -class Expression[T <: Type] { +abstract class Expression[T <: Type] extends Traversable[Expression[_]] { // Field Operations def +(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] = new NumericOperator[T](NumericOperations.Add, this, rhs) @@ -69,12 +86,21 @@ class Expression[T <: Type] { } // Variable references -class VarRef[T <: Type](symbol: VarSymbol[T]) extends Expression[T] +class VarRef[T <: Type](symbol: VarSymbol[T]) extends Expression[T] with LeafExpression { + def getSymbol = symbol +} // Struct and array accesses -class FieldAccess[T <: Type](expression: Expression[StructType], field: FieldSymbol[T]) extends Expression[T] -class ArrayRead[E <: Type](expression: Expression[ArrayType[E]], index: Seq[Expression[IntType]]) extends Expression[E] -class PointerDereference[E <: Type](expression: Expression[PointerType[E]]) extends Expression[E] +class FieldAccess[T <: Type](expression: Expression[StructType], field: FieldSymbol[T]) extends Expression[T] { + def foreach[U](f: Expression[_] => U) = f(expression) +} + +class ArrayRead[E <: Type](expression: Expression[ArrayType[E]], index: Seq[Expression[IntType]]) extends Expression[E] { + def foreach[U](f: Expression[_] => U) = (index :+ expression).foreach(f) +} +class PointerDereference[E <: Type](expression: Expression[PointerType[E]]) extends Expression[E] { + def foreach[U](f: Expression[_] => U) = f(expression) +} // Literals -class IntegerLiteral(value: Int) extends Expression[IntType] +class IntegerLiteral(value: Int) extends Expression[IntType] with LeafExpression diff --git a/src/ofc/codegen/NumericOperator.scala b/src/ofc/codegen/NumericOperator.scala index 8225cf7..ef23db1 100644 --- a/src/ofc/codegen/NumericOperator.scala +++ b/src/ofc/codegen/NumericOperator.scala @@ -17,5 +17,10 @@ object NumericOperations { object GTE extends CompareOp } -class NumericOperator[T <: Type](op: NumericOperations.FieldOp, left: Expression[T], right: Expression[T]) extends Expression[T] -class NumericComparison[T <: Type](op: NumericOperations.CompareOp, left: Expression[T], right: Expression[T]) extends Expression[BoolType] +class NumericOperator[T <: Type](op: NumericOperations.FieldOp, left: Expression[T], right: Expression[T]) extends Expression[T] { + def foreach[U](f: Expression[_] => U) = List(left, right).foreach(f) +} + +class NumericComparison[T <: Type](op: NumericOperations.CompareOp, left: Expression[T], right: Expression[T]) extends Expression[BoolType] { + def foreach[U](f: Expression[_] => U) = List(left, right).foreach(f) +} diff --git a/src/ofc/codegen/ProducerStatement.scala b/src/ofc/codegen/ProducerStatement.scala index 35687bb..52cbba3 100644 --- a/src/ofc/codegen/ProducerStatement.scala +++ b/src/ofc/codegen/ProducerStatement.scala @@ -1,17 +1,71 @@ package ofc.codegen +import ofc.util.Ordering class ProducerStatement extends Statement { - case class VariableRange(symbol: Symbol, first: Expression[IntType], last: Expression[IntType]) - case class Predicate(expression: Expression[BoolType]) + object Context { + def sort(contexts: Seq[Context]) : Seq[Context] = { + val ordering = scala.collection.mutable.Set[(Context, Context)]() + for(c1 <- contexts; c2 <- contexts) + (c1.tryCompare(c2)) match { + case Some(x) => { + if (x<0) + ordering += (c1 -> c2) + else if (x>0) + ordering += (c2 -> c1) + } + case None => () + } + + val totalOrdering = Ordering.transitiveClosure(ordering.toSet) + contexts.sortWith((a,b) => totalOrdering.contains(a,b)) + } + } + + trait Context { + def defines : Set[VarSymbol[_]] + def depends : Set[VarSymbol[_]] + def tryCompare(other: Context) : Option[Int] = { + val isChild = other.defines & depends + val isParent = defines & other.depends + assert(isChild.isEmpty || isParent.isEmpty) + + if (this == other) + Some(0) + else if (isParent.nonEmpty) + Some(-1) + else if (isChild.nonEmpty) + Some(1) + else + None + } + } + + class VariableRange(symbol: VarSymbol[_], first: Expression[IntType], last: Expression[IntType]) extends Context { + override def toString = "VariableRange("+symbol.toString+", "+first.toString+", "+last.toString+")" + def defines = Set(symbol) + def depends = Expression.findReferencedVariables(List(first, last)) + } + + class Predicate(expression: Expression[BoolType]) extends Context { + override def toString = "Predicate("+expression.toString+")" + def defines = Set.empty + def depends = Expression.findReferencedVariables(expression) + } + + class DerivedExpression(symbol: VarSymbol[_], expression: Expression[_]) extends Context { + override def toString = "DerivedExpression("+symbol.toString + " <- " + expression.toString+")" + def defines = Set(symbol) + def depends = Expression.findReferencedVariables(expression) + } var statement = new Comment("Placeholder statement for consumer.") var ranges : Seq[VariableRange] = Nil var predicates : Seq[Predicate] = Nil - var expressions : Map[VarSymbol[_], Expression[_]] = Map.empty + var expressions : Seq[DerivedExpression] = Nil def addExpression[T <: Type](name: String, expression: Expression[T]) : VarSymbol[T] = { val symbol = new DeclaredVarSymbol[T](name) - expressions += symbol -> expression + expressions +:= new DerivedExpression(symbol, expression) symbol } @@ -21,5 +75,13 @@ class ProducerStatement extends Statement { symbol } - def toConcrete = new Comment("Producer flattening unimplemented.") + def toConcrete : Statement = { + val contexts : Seq[Context] = ranges ++ predicates ++ expressions + val sortedContexts = Context.sort(contexts) + + for(c <- sortedContexts) + println(c.toString) + + new Comment("Producer flattening unimplemented.") + } } diff --git a/src/ofc/util/Ordering.scala b/src/ofc/util/Ordering.scala new file mode 100644 index 0000000..bd6fd46 --- /dev/null +++ b/src/ofc/util/Ordering.scala @@ -0,0 +1,26 @@ +package ofc.util +import scala.annotation.tailrec + +object Ordering { + + @tailrec + def transitiveClosure[T](ordering: Set[(T, T)]) : Set[(T, T)] = { + def step[T](ordering: Set[(T, T)]) = { + val newOrdering = scala.collection.mutable.Set[(T, T)]() + newOrdering ++= ordering + + for ((a1,a2) <- ordering; (b1, b2) <- ordering; if a2 == b1) + newOrdering += (a1 -> b2) + + assert(newOrdering.size >= ordering.size) + newOrdering.toSet + } + + val stepped = step(ordering) + + if (stepped.size == ordering.size) + ordering + else + transitiveClosure(stepped) + } +} -- 2.47.3