From: Francis Russell Date: Sun, 8 Apr 2012 01:14:57 +0000 (+0100) Subject: Make expressions and variables carry their types. X-Git-Url: https://git.unchartedbackwaters.co.uk/w/?a=commitdiff_plain;h=e3484e5513c699f7d6f80f5eea2ff3a20cb3fc6a;p=francis%2Fofc.git Make expressions and variables carry their types. --- diff --git a/src/ofc/codegen/ConditionalValue.scala b/src/ofc/codegen/ConditionalValue.scala index f35c3d3..84f4de5 100644 --- a/src/ofc/codegen/ConditionalValue.scala +++ b/src/ofc/codegen/ConditionalValue.scala @@ -1,8 +1,15 @@ package ofc.codegen +import ofc.LogicError class ConditionalValue[T <: Type](predicate: Expression[BoolType], ifTrue: Expression[T], ifFalse: Expression[T]) extends Expression[T] { def foreach[U](f: Expression[_] => U) = List(predicate, ifTrue, ifFalse).foreach(f) def getPredicate = predicate def getIfTrue = ifTrue def getIfFalse = ifFalse + def getType = { + if (ifTrue.getType != ifFalse.getType) + throw new LogicError("Parameters to ternary operator have different types") + else + ifTrue.getType + } } diff --git a/src/ofc/codegen/Expression.scala b/src/ofc/codegen/Expression.scala index a28f219..7325ce4 100644 --- a/src/ofc/codegen/Expression.scala +++ b/src/ofc/codegen/Expression.scala @@ -1,27 +1,5 @@ package ofc.codegen -trait TypeProperty -trait Numeric extends TypeProperty - -class HasProperty[T <: Type, P <: TypeProperty] -object HasProperty { - implicit val intNumeric = new HasProperty[IntType, Numeric]() - implicit val floatNumeric = new HasProperty[FloatType, Numeric]() -} - -sealed abstract class Type -sealed abstract class PrimitiveType extends Type -class IntType extends PrimitiveType -class FloatType extends PrimitiveType -class BoolType extends PrimitiveType -class ArrayType[ElementType <: Type] extends Type -class PointerType[TargetType <: Type] extends Type -abstract class StructType extends Type - -trait LeafExpression { - def foreach[U](f: Expression[_] => U): Unit = () -} - object Expression { implicit def fromInt(i: Int) : Expression[IntType] = new IntegerLiteral(i) @@ -40,6 +18,8 @@ object Expression { } abstract class Expression[T <: Type] extends Traversable[Expression[_]] { + def getType : T + // Field Operations def +(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] = new NumericOperator[T](NumericOperations.Add, this, rhs) @@ -85,9 +65,14 @@ abstract class Expression[T <: Type] extends Traversable[Expression[_]] { new PointerDereference(witness(this)) } +trait LeafExpression { + def foreach[U](f: Expression[_] => U): Unit = () +} + // Variable references class VarRef[T <: Type](symbol: VarSymbol[T]) extends Expression[T] with LeafExpression { def getSymbol = symbol + def getType = symbol.getType } // Struct and array accesses @@ -95,19 +80,23 @@ class FieldAccess[T <: Type](expression: Expression[StructType], field: FieldSym def foreach[U](f: Expression[_] => U) = f(expression) def getStructExpression = expression def getField = field + def getType = field.getType } class ArrayRead[E <: Type](expression: Expression[ArrayType[E]], index: Seq[Expression[IntType]]) extends Expression[E] { def foreach[U](f: Expression[_] => U) = (index :+ expression).foreach(f) def getArrayExpression = expression def getIndexExpressions = index + def getType = expression.getType.getElementType } class PointerDereference[E <: Type](expression: Expression[PointerType[E]]) extends Expression[E] { def foreach[U](f: Expression[_] => U) = f(expression) def getExpression = expression + def getType = expression.getType.getTargetType } // Literals class IntegerLiteral(value: Int) extends Expression[IntType] with LeafExpression { def getValue = value + def getType = new IntType } diff --git a/src/ofc/codegen/FortranGenerator.scala b/src/ofc/codegen/FortranGenerator.scala index fea6c97..fe92476 100644 --- a/src/ofc/codegen/FortranGenerator.scala +++ b/src/ofc/codegen/FortranGenerator.scala @@ -106,8 +106,8 @@ class FortranGenerator { } } - private def buildConditionalValue(conditional: ConditionalValue[_]) : ExpHolder = { - var symbol = new DeclaredVarSymbol[Type]("ternary") + private def buildConditionalValue(conditional: ConditionalValue[_ <: Type]) : ExpHolder = { + var symbol = new DeclaredVarSymbol[Type]("ternary", conditional.getType) symbolManager.addSymbol(symbol) val name = symbolManager.getName(symbol) addLine("if (%s) then".format(buildExpression(conditional.getPredicate))) diff --git a/src/ofc/codegen/NumericOperator.scala b/src/ofc/codegen/NumericOperator.scala index d119783..58213a4 100644 --- a/src/ofc/codegen/NumericOperator.scala +++ b/src/ofc/codegen/NumericOperator.scala @@ -1,4 +1,5 @@ package ofc.codegen +import ofc.LogicError object NumericOperations { sealed abstract class FieldOp @@ -18,17 +19,22 @@ object NumericOperations { } class NumericOperator[T <: Type](op: NumericOperations.FieldOp, left: Expression[T], right: Expression[T]) extends Expression[T] { - // TODO: Type check operators def foreach[U](f: Expression[_] => U) = List(left, right).foreach(f) def getOperation = op def getLeft = left def getRight = right + def getType = { + if (left.getType != right.getType) + throw new LogicError("Non-matching types for parameters to numeric comparison") + else + left.getType + } } class NumericComparison[T <: Type](op: NumericOperations.CompareOp, left: Expression[T], right: Expression[T]) extends Expression[BoolType] { - // TODO: Type check operators def foreach[U](f: Expression[_] => U) = List(left, right).foreach(f) def getOperation = op def getLeft = left def getRight = right + def getType = new BoolType } diff --git a/src/ofc/codegen/ProducerStatement.scala b/src/ofc/codegen/ProducerStatement.scala index a1a301d..e564ca7 100644 --- a/src/ofc/codegen/ProducerStatement.scala +++ b/src/ofc/codegen/ProducerStatement.scala @@ -85,7 +85,7 @@ class ProducerStatement extends Statement { var expressions : Seq[DerivedExpression] = Nil def addExpression[T <: Type](name: String, expression: Expression[T]) : VarSymbol[T] = { - val symbol = new DeclaredVarSymbol[T](name) + val symbol = new DeclaredVarSymbol[T](name, expression.getType) expressions +:= new DerivedExpression(symbol, expression) symbol } diff --git a/src/ofc/codegen/Symbol.scala b/src/ofc/codegen/Symbol.scala index 6d463c6..905920d 100644 --- a/src/ofc/codegen/Symbol.scala +++ b/src/ofc/codegen/Symbol.scala @@ -4,17 +4,25 @@ trait Symbol { def getName : String } -case class FieldSymbol[T <: Type](name: String) extends Symbol { +class FieldSymbol[T <: Type](name: String, fieldType: T) extends Symbol { + def this(name: String)(implicit builder: TypeBuilder[T]) = this(name, builder()) def getName = name + def getType = fieldType } -sealed abstract class VarSymbol[T <: Type](name: String) extends Symbol { +sealed abstract class VarSymbol[T <: Type](name: String, varType: T) extends Symbol { def getName = name + def getType : T = varType } object VarSymbol { implicit def toRef[T <: Type](symbol: VarSymbol[T]) = new VarRef[T](symbol) } -class DeclaredVarSymbol[T <: Type](name: String) extends VarSymbol[T](name) -class NamedUnboundVarSymbol[T <: Type](name: String) extends VarSymbol[T](name) +class DeclaredVarSymbol[T <: Type](name: String, varType: T) extends VarSymbol[T](name, varType) { + def this(name: String)(implicit builder: TypeBuilder[T]) = this(name, builder()) +} + +class NamedUnboundVarSymbol[T <: Type](name: String, varType: T) extends VarSymbol[T](name, varType) { + def this(name: String)(implicit builder: TypeBuilder[T]) = this(name, builder()) +} diff --git a/src/ofc/codegen/Type.scala b/src/ofc/codegen/Type.scala new file mode 100644 index 0000000..76db372 --- /dev/null +++ b/src/ofc/codegen/Type.scala @@ -0,0 +1,42 @@ +package ofc.codegen + +sealed abstract class Type +sealed abstract class PrimitiveType extends Type + +// These are case classes solely for the comparison operators +final case class IntType() extends PrimitiveType +final case class FloatType() extends PrimitiveType +final case class BoolType() extends PrimitiveType + +final case class ArrayType[ElementType <: Type](rank: Int, eType: ElementType) extends Type { + def this(rank: Int)(implicit builder: TypeBuilder[ElementType]) = this(rank, builder()) + def getElementType = eType +} + +final case class PointerType[TargetType <: Type](tType: TargetType) extends Type { + def this()(implicit builder: TypeBuilder[TargetType]) = this(builder()) + def getTargetType = tType +} + +abstract class StructType extends Type + +trait TypeBuilder[T <: Type] { + def apply() : T +} + +object TypeBuilder { + implicit val intBuilder = new TypeBuilder[IntType] { def apply() = new IntType } + implicit val floatBuilder = new TypeBuilder[FloatType] { def apply() = new FloatType } + implicit val boolBuilder = new TypeBuilder[BoolType] { def apply() = new BoolType } +} + +trait TypeProperty +trait Numeric extends TypeProperty + +class HasProperty[T <: Type, P <: TypeProperty] +object HasProperty { + implicit val intNumeric = new HasProperty[IntType, Numeric]() + implicit val floatNumeric = new HasProperty[FloatType, Numeric]() +} + + diff --git a/src/ofc/generators/onetep/OnetepTypes.scala b/src/ofc/generators/onetep/OnetepTypes.scala new file mode 100644 index 0000000..8038609 --- /dev/null +++ b/src/ofc/generators/onetep/OnetepTypes.scala @@ -0,0 +1,35 @@ +package ofc.generators.onetep +import ofc.codegen._ + +object OnetepTypes { + object FunctionBasis extends StructType { + val numPPDsInSphere = { + val fieldType = new PointerType[ArrayType[IntType]](new ArrayType[IntType](1)) + new FieldSymbol[PointerType[ArrayType[IntType]]]("n_ppds_sphere", fieldType) + } + + val ppdList = { + val fieldType = new PointerType[ArrayType[IntType]](new ArrayType[IntType](1)) + new FieldSymbol[PointerType[ArrayType[IntType]]]("ppd_list", fieldType) + } + + val num = new FieldSymbol[IntType]("num") + + val tightBoxes = { + val fieldType = new PointerType[ArrayType[StructType]](new ArrayType(1, TightBox)) + new FieldSymbol[PointerType[ArrayType[StructType]]]("tight_boxes", fieldType) + } + } + + object CellInfo extends StructType { + val ppdWidth = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("n_pt"+dim)}.toSeq + val numPPDs = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("n_ppds_a"+dim)}.toSeq + } + + object TightBox extends StructType { + val startPts = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("start_pts"+dim)}.toSeq + val finishPts = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("finish_pts"+dim)}.toSeq + val startPPD = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("start_ppd"+dim)}.toSeq + val finishPPD = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("finish_ppd"+dim)}.toSeq + } +} diff --git a/src/ofc/generators/onetep/PPDFunctionSet.scala b/src/ofc/generators/onetep/PPDFunctionSet.scala index 2a6239a..827a02e 100644 --- a/src/ofc/generators/onetep/PPDFunctionSet.scala +++ b/src/ofc/generators/onetep/PPDFunctionSet.scala @@ -2,20 +2,22 @@ package ofc.generators.onetep import ofc.codegen._ class PPDFunctionSet(val basisName: String, dataName: String) extends FunctionSet { - val basis = new NamedUnboundVarSymbol[StructType](basisName) - val data = new NamedUnboundVarSymbol[ArrayType[FloatType]](dataName) - val pubCell = new NamedUnboundVarSymbol[StructType]("pub_cell") + import OnetepTypes._ - val numSpheres = basis % FieldSymbol[IntType]("num"); - val ppdWidths = for(dim <- 1 to 3) yield pubCell % FieldSymbol[IntType]("n_pt"+dim) - val cellWidthInPPDs = for(dim <- 1 to 3) yield pubCell % FieldSymbol[IntType]("n_ppds_a"+dim) + val basis = new NamedUnboundVarSymbol[StructType](basisName, FunctionBasis) + val data = new NamedUnboundVarSymbol[ArrayType[FloatType]](dataName, new ArrayType[FloatType](1)) + val pubCell = new NamedUnboundVarSymbol[StructType]("pub_cell", OnetepTypes.CellInfo) + + val numSpheres = basis % FunctionBasis.num + val ppdWidths = for(dim <- 0 to 2) yield pubCell % CellInfo.ppdWidth(dim) + val cellWidthInPPDs = for(dim <- 0 to 2) yield pubCell % CellInfo.numPPDs(dim) def getSuffixFragment = { val producer = new ProducerStatement val sphereIndex = producer.addIteration("sphere_index", 1, numSpheres) - val numPPDs = (~(basis % FieldSymbol[PointerType[ArrayType[IntType]]]("n_ppds_sphere"))).readAt(sphereIndex) + val numPPDs = (~(basis % FunctionBasis.numPPDsInSphere)).readAt(sphereIndex) val ppdIndex = producer.addIteration("ppd_index", 1, numPPDs) - val ppdGlobalCount = (~(basis % FieldSymbol[PointerType[ArrayType[IntType]]]("ppd_list"))).readAt(ppdIndex, 1) - 1 + val ppdGlobalCount = (~(basis % FunctionBasis.ppdList)).readAt(ppdIndex, 1) - 1 // The integer co-ordinates of the PPD (0-based) val a3pos = producer.addExpression("ppd_pos1", ppdGlobalCount / (cellWidthInPPDs(0)*cellWidthInPPDs(1))) @@ -23,17 +25,17 @@ class PPDFunctionSet(val basisName: String, dataName: String) extends FunctionSe val a1pos = producer.addExpression("ppd_pos3", ppdGlobalCount % cellWidthInPPDs(0)) val ppdPos = List(a1pos, a2pos, a3pos) - val tightbox = (~(basis % FieldSymbol[PointerType[ArrayType[StructType]]]("tight_boxes"))).readAt(sphereIndex) + val tightbox = (~(basis % FunctionBasis.tightBoxes)).readAt(sphereIndex) // The offsets into the PPDs for the edges of the tightbox - val ppdStartOffsets = for(dim <- 1 to 3) yield tightbox % FieldSymbol[IntType]("start_pts"+dim) - val ppdFinishOffsets = for(dim <- 1 to 3) yield tightbox % FieldSymbol[IntType]("finish_pts"+dim) + val ppdStartOffsets = for(dim <- 0 to 2) yield tightbox % TightBox.startPts(dim) + val ppdFinishOffsets = for(dim <- 0 to 2) yield tightbox % TightBox.finishPts(dim) // The first and last PPDs in PPD co-ordinates (0-based, inside simulation cell) val startPPDs = for(dim <- 0 to 2) yield - producer.addExpression("start_ppd"+(dim+1), (tightbox % FieldSymbol[IntType]("start_ppd"+(dim+1)) + cellWidthInPPDs(dim)-1) % cellWidthInPPDs(dim)) + producer.addExpression("start_ppd"+(dim+1), (tightbox % TightBox.startPPD(dim) + cellWidthInPPDs(dim)-1) % cellWidthInPPDs(dim)) val finishPPDs = for(dim <- 0 to 2) yield - producer.addExpression("finish_ppd"+(dim+1),(tightbox % FieldSymbol[IntType]("finish_ppd"+(dim+1)) + cellWidthInPPDs(dim)-1) % cellWidthInPPDs(dim)) + producer.addExpression("finish_ppd"+(dim+1),(tightbox % TightBox.finishPPD(dim) + cellWidthInPPDs(dim)-1) % cellWidthInPPDs(dim)) // Offsets for the current PPD being iterated over val loopStarts = for(dim <- 0 to 2) yield