]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Make expressions and variables carry their types.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Sun, 8 Apr 2012 01:14:57 +0000 (02:14 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Sun, 8 Apr 2012 01:14:57 +0000 (02:14 +0100)
src/ofc/codegen/ConditionalValue.scala
src/ofc/codegen/Expression.scala
src/ofc/codegen/FortranGenerator.scala
src/ofc/codegen/NumericOperator.scala
src/ofc/codegen/ProducerStatement.scala
src/ofc/codegen/Symbol.scala
src/ofc/codegen/Type.scala [new file with mode: 0644]
src/ofc/generators/onetep/OnetepTypes.scala [new file with mode: 0644]
src/ofc/generators/onetep/PPDFunctionSet.scala

index f35c3d31c9280c3715ec5f633b49065936e9b215..84f4de50a092222b013ae66f819e3b8803c9c308 100644 (file)
@@ -1,8 +1,15 @@
 package ofc.codegen
+import ofc.LogicError
 
 class ConditionalValue[T <: Type](predicate: Expression[BoolType], ifTrue: Expression[T], ifFalse: Expression[T]) extends Expression[T] {
   def foreach[U](f: Expression[_] => U) = List(predicate, ifTrue, ifFalse).foreach(f)
   def getPredicate = predicate
   def getIfTrue = ifTrue
   def getIfFalse = ifFalse
+  def getType = {
+    if (ifTrue.getType != ifFalse.getType)
+      throw new LogicError("Parameters to ternary operator have different types")
+    else
+      ifTrue.getType
+  }
 }
index a28f219391bb48a6176f6df03431bfe615d3b045..7325ce4ebf5679623ff924d30f999e287dbc32e6 100644 (file)
@@ -1,27 +1,5 @@
 package ofc.codegen
 
-trait TypeProperty
-trait Numeric extends TypeProperty
-
-class HasProperty[T <: Type, P <: TypeProperty]
-object HasProperty {
-  implicit val intNumeric = new HasProperty[IntType, Numeric]()
-  implicit val floatNumeric = new HasProperty[FloatType, Numeric]()
-}
-
-sealed abstract class Type
-sealed abstract class PrimitiveType extends Type
-class IntType extends PrimitiveType
-class FloatType extends PrimitiveType
-class BoolType extends PrimitiveType
-class ArrayType[ElementType <: Type] extends Type
-class PointerType[TargetType <: Type] extends Type
-abstract class StructType extends Type
-
-trait LeafExpression {
-  def foreach[U](f: Expression[_] => U): Unit = ()
-}
-
 object Expression {
   implicit def fromInt(i: Int) : Expression[IntType] = new IntegerLiteral(i)
 
@@ -40,6 +18,8 @@ object Expression {
 }
 
 abstract class Expression[T <: Type] extends Traversable[Expression[_]] {
+  def getType : T
+
   // Field Operations
   def +(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] =
     new NumericOperator[T](NumericOperations.Add, this, rhs)
@@ -85,9 +65,14 @@ abstract class Expression[T <: Type] extends Traversable[Expression[_]] {
     new PointerDereference(witness(this))
 }
 
+trait LeafExpression {
+  def foreach[U](f: Expression[_] => U): Unit = ()
+}
+
 // Variable references
 class VarRef[T <: Type](symbol: VarSymbol[T]) extends Expression[T] with LeafExpression {
   def getSymbol = symbol
+  def getType = symbol.getType
 }
 
 // Struct and array accesses
@@ -95,19 +80,23 @@ class FieldAccess[T <: Type](expression: Expression[StructType], field: FieldSym
   def foreach[U](f: Expression[_] => U) = f(expression)
   def getStructExpression = expression
   def getField = field
+  def getType = field.getType
 }
 
 class ArrayRead[E <: Type](expression: Expression[ArrayType[E]], index: Seq[Expression[IntType]]) extends Expression[E] {
   def foreach[U](f: Expression[_] => U) = (index :+ expression).foreach(f)
   def getArrayExpression = expression
   def getIndexExpressions = index
+  def getType = expression.getType.getElementType
 }
 class PointerDereference[E <: Type](expression: Expression[PointerType[E]]) extends Expression[E] {
   def foreach[U](f: Expression[_] => U) = f(expression)
   def getExpression = expression
+  def getType = expression.getType.getTargetType
 }
 
 // Literals
 class IntegerLiteral(value: Int) extends Expression[IntType] with LeafExpression {
   def getValue = value
+  def getType = new IntType
 }
index fea6c970b326d8b2635f148678028a229206589d..fe92476ec11d09b4a5ae7134027a8091d7c002cf 100644 (file)
@@ -106,8 +106,8 @@ class FortranGenerator {
     }
   }
 
-  private def buildConditionalValue(conditional: ConditionalValue[_]) : ExpHolder = {
-    var symbol = new DeclaredVarSymbol[Type]("ternary")
+  private def buildConditionalValue(conditional: ConditionalValue[_ <: Type]) : ExpHolder = {
+    var symbol = new DeclaredVarSymbol[Type]("ternary", conditional.getType)
     symbolManager.addSymbol(symbol)
     val name = symbolManager.getName(symbol)
     addLine("if (%s) then".format(buildExpression(conditional.getPredicate)))
index d119783beeb5b5e7b9241f4508f4cc22f8197a70..58213a461a87bcfd939998afea7fcf8eab21626b 100644 (file)
@@ -1,4 +1,5 @@
 package ofc.codegen
+import ofc.LogicError
 
 object NumericOperations {
   sealed abstract class FieldOp
@@ -18,17 +19,22 @@ object NumericOperations {
 }
 
 class NumericOperator[T <: Type](op: NumericOperations.FieldOp, left: Expression[T], right: Expression[T]) extends Expression[T] {
-  // TODO: Type check operators
   def foreach[U](f: Expression[_] => U) = List(left, right).foreach(f)
   def getOperation = op
   def getLeft = left
   def getRight = right
+  def getType = {
+    if (left.getType != right.getType)
+      throw new LogicError("Non-matching types for parameters to numeric comparison")
+    else
+      left.getType
+  }
 }
 
 class NumericComparison[T <: Type](op: NumericOperations.CompareOp, left: Expression[T], right: Expression[T]) extends Expression[BoolType] {
-  // TODO: Type check operators
   def foreach[U](f: Expression[_] => U) = List(left, right).foreach(f)
   def getOperation = op
   def getLeft = left
   def getRight = right
+  def getType = new BoolType
 }
index a1a301d94ea5da066d2a93ded31ddcd85e4f849f..e564ca7d16edc4e191cb8b52df417be6d13f6a77 100644 (file)
@@ -85,7 +85,7 @@ class ProducerStatement extends Statement {
   var expressions : Seq[DerivedExpression] = Nil
 
   def addExpression[T <: Type](name: String, expression: Expression[T]) : VarSymbol[T] = {
-    val symbol = new DeclaredVarSymbol[T](name)
+    val symbol = new DeclaredVarSymbol[T](name, expression.getType)
     expressions +:= new DerivedExpression(symbol, expression)
     symbol
   }
index 6d463c6eda2b2db16b44fe95837d7e1145c8b3c0..905920d855343ae532e9e978995d52681650cb0a 100644 (file)
@@ -4,17 +4,25 @@ trait Symbol {
   def getName : String
 }
 
-case class FieldSymbol[T <: Type](name: String) extends Symbol {
+class FieldSymbol[T <: Type](name: String, fieldType: T) extends Symbol {
+  def this(name: String)(implicit builder: TypeBuilder[T]) = this(name, builder())
   def getName = name
+  def getType = fieldType
 }
 
-sealed abstract class VarSymbol[T <: Type](name: String) extends Symbol {
+sealed abstract class VarSymbol[T <: Type](name: String, varType: T) extends Symbol {
   def getName = name
+  def getType : T = varType
 }
 
 object VarSymbol {
   implicit def toRef[T <: Type](symbol: VarSymbol[T]) = new VarRef[T](symbol)
 }
 
-class DeclaredVarSymbol[T <: Type](name: String) extends VarSymbol[T](name)
-class NamedUnboundVarSymbol[T <: Type](name: String) extends VarSymbol[T](name)
+class DeclaredVarSymbol[T <: Type](name: String, varType: T) extends VarSymbol[T](name, varType) {
+  def this(name: String)(implicit builder: TypeBuilder[T]) = this(name, builder())
+}
+
+class NamedUnboundVarSymbol[T <: Type](name: String, varType: T) extends VarSymbol[T](name, varType) {
+  def this(name: String)(implicit builder: TypeBuilder[T]) = this(name, builder())
+}
diff --git a/src/ofc/codegen/Type.scala b/src/ofc/codegen/Type.scala
new file mode 100644 (file)
index 0000000..76db372
--- /dev/null
@@ -0,0 +1,42 @@
+package ofc.codegen
+
+sealed abstract class Type
+sealed abstract class PrimitiveType extends Type
+
+// These are case classes solely for the comparison operators
+final case class IntType() extends PrimitiveType
+final case class FloatType() extends PrimitiveType
+final case class BoolType() extends PrimitiveType
+
+final case class ArrayType[ElementType <: Type](rank: Int, eType: ElementType) extends Type {
+  def this(rank: Int)(implicit builder: TypeBuilder[ElementType]) = this(rank, builder())
+  def getElementType = eType
+}
+
+final case class PointerType[TargetType <: Type](tType: TargetType) extends Type {
+  def this()(implicit builder: TypeBuilder[TargetType]) = this(builder())
+  def getTargetType = tType
+}
+
+abstract class StructType extends Type
+
+trait TypeBuilder[T <: Type] {
+  def apply() : T
+}
+
+object TypeBuilder {
+  implicit val intBuilder = new TypeBuilder[IntType] { def apply() = new IntType }
+  implicit val floatBuilder = new TypeBuilder[FloatType] { def apply() = new FloatType }
+  implicit val boolBuilder = new TypeBuilder[BoolType] { def apply() = new BoolType }
+}
+
+trait TypeProperty
+trait Numeric extends TypeProperty
+
+class HasProperty[T <: Type, P <: TypeProperty]
+object HasProperty {
+  implicit val intNumeric = new HasProperty[IntType, Numeric]()
+  implicit val floatNumeric = new HasProperty[FloatType, Numeric]()
+}
+
+
diff --git a/src/ofc/generators/onetep/OnetepTypes.scala b/src/ofc/generators/onetep/OnetepTypes.scala
new file mode 100644 (file)
index 0000000..8038609
--- /dev/null
@@ -0,0 +1,35 @@
+package ofc.generators.onetep
+import ofc.codegen._
+
+object OnetepTypes {
+  object FunctionBasis extends StructType {
+    val numPPDsInSphere = {
+      val fieldType = new PointerType[ArrayType[IntType]](new ArrayType[IntType](1))
+      new FieldSymbol[PointerType[ArrayType[IntType]]]("n_ppds_sphere", fieldType)
+    }
+
+    val ppdList = {
+      val fieldType = new PointerType[ArrayType[IntType]](new ArrayType[IntType](1))
+      new FieldSymbol[PointerType[ArrayType[IntType]]]("ppd_list", fieldType)
+    }
+
+    val num = new FieldSymbol[IntType]("num")
+    
+    val tightBoxes = {
+      val fieldType = new PointerType[ArrayType[StructType]](new ArrayType(1, TightBox))
+      new FieldSymbol[PointerType[ArrayType[StructType]]]("tight_boxes", fieldType)
+    }
+  }
+
+  object CellInfo extends StructType {
+    val ppdWidth = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("n_pt"+dim)}.toSeq
+    val numPPDs = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("n_ppds_a"+dim)}.toSeq
+  }
+
+  object TightBox extends StructType {
+    val startPts = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("start_pts"+dim)}.toSeq
+    val finishPts = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("finish_pts"+dim)}.toSeq
+    val startPPD = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("start_ppd"+dim)}.toSeq
+    val finishPPD = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("finish_ppd"+dim)}.toSeq
+  }
+}
index 2a6239a93199058b3fb53e7c36ddbba9607525ba..827a02e22b6b438d30ae28e6ab82a9356a6eeb1b 100644 (file)
@@ -2,20 +2,22 @@ package ofc.generators.onetep
 import ofc.codegen._
 
 class PPDFunctionSet(val basisName: String, dataName: String) extends FunctionSet {
-  val basis = new NamedUnboundVarSymbol[StructType](basisName)
-  val data =  new NamedUnboundVarSymbol[ArrayType[FloatType]](dataName)
-  val pubCell = new NamedUnboundVarSymbol[StructType]("pub_cell")
+  import OnetepTypes._
 
-  val numSpheres = basis % FieldSymbol[IntType]("num");
-  val ppdWidths = for(dim <- 1 to 3) yield pubCell % FieldSymbol[IntType]("n_pt"+dim)
-  val cellWidthInPPDs = for(dim <- 1 to 3) yield pubCell % FieldSymbol[IntType]("n_ppds_a"+dim)
+  val basis = new NamedUnboundVarSymbol[StructType](basisName, FunctionBasis)
+  val data =  new NamedUnboundVarSymbol[ArrayType[FloatType]](dataName, new ArrayType[FloatType](1))
+  val pubCell = new NamedUnboundVarSymbol[StructType]("pub_cell", OnetepTypes.CellInfo)
+
+  val numSpheres = basis % FunctionBasis.num
+  val ppdWidths = for(dim <- 0 to 2) yield pubCell % CellInfo.ppdWidth(dim)
+  val cellWidthInPPDs = for(dim <- 0 to 2) yield pubCell % CellInfo.numPPDs(dim)
 
   def getSuffixFragment = {
     val producer = new ProducerStatement
     val sphereIndex = producer.addIteration("sphere_index", 1, numSpheres)
-    val numPPDs = (~(basis % FieldSymbol[PointerType[ArrayType[IntType]]]("n_ppds_sphere"))).readAt(sphereIndex)
+    val numPPDs = (~(basis % FunctionBasis.numPPDsInSphere)).readAt(sphereIndex)
     val ppdIndex = producer.addIteration("ppd_index", 1, numPPDs)
-    val ppdGlobalCount = (~(basis % FieldSymbol[PointerType[ArrayType[IntType]]]("ppd_list"))).readAt(ppdIndex, 1) - 1
+    val ppdGlobalCount = (~(basis % FunctionBasis.ppdList)).readAt(ppdIndex, 1) - 1
 
     // The integer co-ordinates of the PPD (0-based)
     val a3pos = producer.addExpression("ppd_pos1", ppdGlobalCount / (cellWidthInPPDs(0)*cellWidthInPPDs(1)))
@@ -23,17 +25,17 @@ class PPDFunctionSet(val basisName: String, dataName: String) extends FunctionSe
     val a1pos = producer.addExpression("ppd_pos3", ppdGlobalCount % cellWidthInPPDs(0))
     val ppdPos = List(a1pos, a2pos, a3pos)
 
-    val tightbox = (~(basis % FieldSymbol[PointerType[ArrayType[StructType]]]("tight_boxes"))).readAt(sphereIndex)
+    val tightbox = (~(basis % FunctionBasis.tightBoxes)).readAt(sphereIndex)
 
     // The offsets into the PPDs for the edges of the tightbox
-    val ppdStartOffsets = for(dim <- 1 to 3) yield tightbox % FieldSymbol[IntType]("start_pts"+dim)
-    val ppdFinishOffsets = for(dim <- 1 to 3) yield tightbox % FieldSymbol[IntType]("finish_pts"+dim)
+    val ppdStartOffsets = for(dim <- 0 to 2) yield tightbox % TightBox.startPts(dim)
+    val ppdFinishOffsets = for(dim <- 0 to 2) yield tightbox % TightBox.finishPts(dim)
 
     // The first and last PPDs in PPD co-ordinates (0-based, inside simulation cell)
     val startPPDs = for(dim <- 0 to 2) yield 
-      producer.addExpression("start_ppd"+(dim+1), (tightbox % FieldSymbol[IntType]("start_ppd"+(dim+1)) + cellWidthInPPDs(dim)-1) % cellWidthInPPDs(dim))
+      producer.addExpression("start_ppd"+(dim+1), (tightbox % TightBox.startPPD(dim) + cellWidthInPPDs(dim)-1) % cellWidthInPPDs(dim))
     val finishPPDs = for(dim <- 0 to 2) yield 
-      producer.addExpression("finish_ppd"+(dim+1),(tightbox % FieldSymbol[IntType]("finish_ppd"+(dim+1)) + cellWidthInPPDs(dim)-1) % cellWidthInPPDs(dim))
+      producer.addExpression("finish_ppd"+(dim+1),(tightbox % TightBox.finishPPD(dim) + cellWidthInPPDs(dim)-1) % cellWidthInPPDs(dim))
 
     // Offsets for the current PPD being iterated over
     val loopStarts = for(dim <- 0 to 2) yield