]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Work on typing expressions.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Wed, 4 Apr 2012 22:58:31 +0000 (23:58 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Wed, 4 Apr 2012 22:58:31 +0000 (23:58 +0100)
src/ofc/codegen/ConditionalValue.scala
src/ofc/codegen/Expression.scala
src/ofc/codegen/NumericOperator.scala [moved from src/ofc/codegen/BinaryOperator.scala with 60% similarity]
src/ofc/codegen/ProducerStatement.scala
src/ofc/codegen/Symbol.scala
src/ofc/generators/onetep/Index.scala
src/ofc/generators/onetep/PPDFunctionSet.scala

index 9b93e8200bd9d302299ae9cdddbf0ef0a2a21041..fc07f8e0ecd039fb5394a899627776ebcd04e0d3 100644 (file)
@@ -1,3 +1,3 @@
 package ofc.codegen
 
-class ConditionalValue(predicate: Expression, ifTrue: Expression, ifFalse: Expression) extends Expression
+class ConditionalValue[T <: Type](predicate: Expression[BoolType], ifTrue: Expression[T], ifFalse: Expression[T]) extends Expression[T]
index 0e93d92b6e60a4a1bab66c0f574e4980c2950c09..05ce29c9533560aeaa3f2c3cec6f3884b26d9a64 100644 (file)
@@ -1,25 +1,63 @@
 package ofc.codegen
 
-class Expression {
-  def ~>(field: FieldSymbol) : Expression = new FieldAccess(this, field)
-  def readAt(index: List[Expression]) : Expression = new ArrayRead(this, index)
-  def apply(index: Expression*) : Expression = this.readAt(index.toList)
-  def +(rhs: Expression) : Expression = new BinaryOperator(BinaryOperator.Add, this, rhs)
-  def -(rhs: Expression) : Expression = new BinaryOperator(BinaryOperator.Sub, this, rhs)
-  def *(rhs: Expression) : Expression = new BinaryOperator(BinaryOperator.Mul, this, rhs)
-  def /(rhs: Expression) : Expression = new BinaryOperator(BinaryOperator.Div, this, rhs)
-  def %(rhs: Expression) : Expression = new BinaryOperator(BinaryOperator.Mod, this, rhs)
+trait TypeProperty
+trait Numeric extends TypeProperty
+
+class HasProperty[T <: Type, P <: TypeProperty]
+object HasProperty {
+  implicit val intNumeric = new HasProperty[IntType, Numeric]()
+  implicit val floatNumeric = new HasProperty[FloatType, Numeric]()
+}
+
+sealed abstract class Type
+sealed abstract class PrimitiveType extends Type
+class IntType extends PrimitiveType
+class FloatType extends PrimitiveType
+class BoolType extends PrimitiveType
+class ArrayType[ElementType <: Type] extends Type
+class PointerType[TargetType <: Type] extends Type
+abstract class StructType extends Type
+
+object Expression {
+  implicit def fromInt(i: Int) : Expression[IntType] = new IntegerLiteral(i)
+}
+
+class Expression[T <: Type] {
+  def +(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] =
+    new NumericOperator[T](NumericOperator.Add, this, rhs)
+
+  def -(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] =
+    new NumericOperator[T](NumericOperator.Sub, this, rhs)
+
+  def *(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] =
+    new NumericOperator[T](NumericOperator.Mul, this, rhs)
+
+  def /(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] =
+    new NumericOperator[T](NumericOperator.Div, this, rhs)
+
+  def %(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] =
+    new NumericOperator[T](NumericOperator.Mod, this, rhs)
+
+  def ~>[FieldType <: Type](field: FieldSymbol[FieldType])(implicit witness: <:<[this.type, Expression[StructType]]) : Expression[FieldType] =
+    new FieldAccess[FieldType](witness(this), field)
+
+  def readAt[T <: Type](index: Expression[IntType]*)(implicit witness: <:<[this.type, Expression[ArrayType[T]]]) : Expression[T] = 
+    new ArrayRead(witness(this), index.toList)
+
+  //def apply[T <: Type](index: Expression[IntType]*)(implicit witness: <:<[this.type, Expression[ArrayType[T]]]) : Expression[T] = 
+  //  this.readAt(index.toList)
+
+  def unary_~[T <: Type]()(implicit witness: <:<[this.type, Expression[PointerType[T]]]) : Expression[T] =
+    new PointerDereference(witness(this))
 }
 
 // Variable references
-class VarRef(symbol: VarSymbol) extends Expression
+class VarRef[T <: Type](symbol: VarSymbol[T]) extends Expression[T]
 
 // Struct and array accesses
-class FieldAccess(expression: Expression, field: FieldSymbol) extends Expression
-class ArrayRead(expression: Expression, index: List[Expression]) extends Expression
+class FieldAccess[T <: Type](expression: Expression[StructType], field: FieldSymbol[T]) extends Expression[T]
+class ArrayRead[E <: Type](expression: Expression[ArrayType[E]], index: List[Expression[IntType]]) extends Expression[E]
+class PointerDereference[E <: Type](expression: Expression[PointerType[E]]) extends Expression[E]
 
 // Literals
-class IntegerLiteral(value: Int) extends Expression
-object IntegerLiteral {
-  implicit def fromInt(value: Int) : Expression = new IntegerLiteral(value)
-}
+class IntegerLiteral(value: Int) extends Expression[IntType]
similarity index 60%
rename from src/ofc/codegen/BinaryOperator.scala
rename to src/ofc/codegen/NumericOperator.scala
index 406f5073318ead011c2ce89c94e3258ed9571a1e..267ba1071166e19ca6d67588d446804c8f5c8956 100644 (file)
@@ -1,6 +1,6 @@
 package ofc.codegen
 
-object BinaryOperator {
+object NumericOperator {
   sealed abstract class Operator
   case object Add extends Operator
   case object Sub extends Operator
@@ -9,4 +9,4 @@ object BinaryOperator {
   case object Mod extends Operator
 }
 
-class BinaryOperator(op: BinaryOperator.Operator, left: Expression, right: Expression) extends Expression
+class NumericOperator[T <: Type](op: NumericOperator.Operator, left: Expression[T], right: Expression[T]) extends Expression[T]
index 96c119a52e778c7a26499ed13a2f0cad95600684..8c1af5edf587d5684923318d4d31fffd34123573 100644 (file)
@@ -1,22 +1,22 @@
 package ofc.codegen
 
 class ProducerStatement extends Statement {
-  class VariableRange(symbol: Symbol, count: Expression)
+  class VariableRange(symbol: Symbol, count: Expression[IntType])
   class Predicate
 
   var statement = new NullStatement
   var ranges : List[VariableRange] = List.empty
   var predicates : List[Predicate] = List.empty
-  var expressions : Map[Symbol, Expression] = Map.empty
+  var expressions : Map[Symbol, Expression[_]] = Map.empty
 
-  def addExpression(name: String, expression: Expression) : VarSymbol = {
-    val symbol = new DeclaredVarSymbol(name)
+  def addExpression[T <: Type](name: String, expression: Expression[T]) : VarSymbol[T] = {
+    val symbol = new DeclaredVarSymbol[T](name)
     expressions += symbol -> expression
     symbol
   }
 
-  def addIteration(name: String, count: Expression) : VarSymbol = {
-    val symbol = new DeclaredVarSymbol(name)
+  def addIteration(name: String, count: Expression[IntType]) : VarSymbol[IntType] = {
+    val symbol = new DeclaredVarSymbol[IntType](name)
     ranges +:= new VariableRange(symbol, count)
     symbol
   }
index b55caace0c7eb471b47f120586a61bb6c2c2cf15..a7c770cfccd665239395bc8ebfee5f4e2bf0a78e 100644 (file)
@@ -4,19 +4,18 @@ trait Symbol {
   def getName : String
 }
 
-case class FieldSymbol(name: String) extends Symbol {
+case class FieldSymbol[T <: Type](name: String) extends Symbol {
   def getName = name
 }
 
-abstract class VarSymbol(name: String) extends Symbol {
+abstract class VarSymbol[T <: Type](name: String) extends Symbol {
   def getName = name
 }
 
 object VarSymbol {
-  implicit def toRef(symbol: VarSymbol) = new VarRef(symbol)
+  implicit def toRef[T <: Type](symbol: VarSymbol[T]) = new VarRef[T](symbol)
 }
 
-case class DeclaredVarSymbol(name: String) extends VarSymbol(name)
-abstract class UnboundVarSymbol(name: String) extends VarSymbol(name)
-case class NamedUnboundVarSymbol(name: String) extends UnboundVarSymbol(name)
-class IterationSymbol(name: String, count: Expression) extends UnboundVarSymbol(name)
+case class DeclaredVarSymbol[T <: Type](name: String) extends VarSymbol[T](name)
+abstract class UnboundVarSymbol[T <: Type](name: String) extends VarSymbol[T](name)
+case class NamedUnboundVarSymbol[T <: Type](name: String) extends UnboundVarSymbol[T](name)
index 435d62e25af06095b0d69f5f8326a19c028c5a7c..caabe2b595eb23fbef229515e9ac5710c6fe1add 100644 (file)
@@ -1,5 +1,5 @@
 package ofc.generators.onetep
-import ofc.codegen.Expression
+import ofc.codegen.{Expression,IntType}
 
 /*
 object Index {
@@ -18,9 +18,8 @@ object Index {
 
 trait Index {
   def getName : String
-  def getMappingFunction : Expression
-  def getMinimumValue : Expression
-  def getLength : Expression
+  def getMinimumValue : Expression[IntType]
+  def getLength : Expression[IntType]
   def isRandomAccess : Boolean
 }
 
index f35aee03da4203caa1f904522cda1a68f8c54ee0..24c298aed72e7e6e667e61323814cb2079931d5e 100644 (file)
@@ -2,20 +2,20 @@ package ofc.generators.onetep
 import ofc.codegen._
 
 class PPDFunctionSet(val basisName: String, dataName: String) extends FunctionSet {
-  val basis = NamedUnboundVarSymbol(basisName)
-  val data =  NamedUnboundVarSymbol(dataName)
-  val pubCell = NamedUnboundVarSymbol("pub_cell")
+  val basis = NamedUnboundVarSymbol[StructType](basisName)
+  val data =  NamedUnboundVarSymbol[ArrayType[FloatType]](dataName)
+  val pubCell = NamedUnboundVarSymbol[StructType]("pub_cell")
 
-  val numSpheres = basis~>FieldSymbol("num");
-  val ppdWidths = for(dim <- 1 to 3) yield pubCell~>FieldSymbol("n_pt"+dim)
-  val cellWidthInPPDs = for(dim <- 1 to 3) yield pubCell~>FieldSymbol("n_ppds_a"+dim)
+  val numSpheres = basis~>FieldSymbol[IntType]("num");
+  val ppdWidths = for(dim <- 1 to 3) yield pubCell~>FieldSymbol[IntType]("n_pt"+dim)
+  val cellWidthInPPDs = for(dim <- 1 to 3) yield pubCell~>FieldSymbol[IntType]("n_ppds_a"+dim)
 
   def getSuffixFragment = {
     val producer = new ProducerStatement
     val sphereIndex = producer.addIteration("sphere_index", numSpheres)
-    val numPPDs = (basis~>FieldSymbol("n_ppds_sphere"))(sphereIndex)
+    val numPPDs = (~(basis~>FieldSymbol[PointerType[ArrayType[IntType]]]("n_ppds_sphere"))).readAt(sphereIndex)
     val ppdIndex = producer.addIteration("ppd_index", numPPDs)
-    val ppdGlobalCount = (basis~>FieldSymbol("ppd_list"))(ppdIndex, new IntegerLiteral(1)) - new IntegerLiteral(1)
+    val ppdGlobalCount = (~(basis~>FieldSymbol[PointerType[ArrayType[IntType]]]("ppd_list"))).readAt(ppdIndex, 1) - 1
 
     // We need to calculate the integer co-ordinates of the PPD (0-based)
     val a3pos = ppdGlobalCount / (cellWidthInPPDs(0)*cellWidthInPPDs(1))
@@ -24,12 +24,12 @@ class PPDFunctionSet(val basisName: String, dataName: String) extends FunctionSe
 
     val ppdPos = List(a1pos, a2pos, a3pos)
 
-    val tightbox = basis~>FieldSymbol("tight_boxes")(sphereIndex)
-    val ppdStartOffsets = for(dim <- 1 to 3) yield tightbox % FieldSymbol("start_pts"+dim)
-    val ppdFinishOffsets = for(dim <- 1 to 3) yield tightbox % FieldSymbol("finish_pts"+dim)
+    val tightbox = (~(basis~>FieldSymbol[PointerType[ArrayType[StructType]]]("tight_boxes"))).readAt(sphereIndex)
+    val ppdStartOffsets = for(dim <- 1 to 3) yield tightbox ~> FieldSymbol[IntType]("start_pts"+dim)
+    val ppdFinishOffsets = for(dim <- 1 to 3) yield tightbox ~> FieldSymbol[IntType]("finish_pts"+dim)
 
-    val startPPDs = for(dim <- 0 to 2) yield (tightbox % FieldSymbol("start_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim)
-    val finishPPDs = for(dim <- 0 to 2) yield (tightbox % FieldSymbol("finish_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim)
+    val startPPDs = for(dim <- 0 to 2) yield (tightbox ~> FieldSymbol("start_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim)
+    val finishPPDs = for(dim <- 0 to 2) yield (tightbox ~> FieldSymbol("finish_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim)
 
     //val ppdRanges = for(dim <- 0 to 2) yield producer.addIteration("point"+(dim+1), ppdWidths(dim))