From 19c412704022f08f94b6335a9ff5fcf69b29f77e Mon Sep 17 00:00:00 2001 From: Francis Russell Date: Wed, 4 Apr 2012 23:58:31 +0100 Subject: [PATCH] Work on typing expressions. --- src/ofc/codegen/ConditionalValue.scala | 2 +- src/ofc/codegen/Expression.scala | 70 ++++++++++++++----- ...ryOperator.scala => NumericOperator.scala} | 4 +- src/ofc/codegen/ProducerStatement.scala | 12 ++-- src/ofc/codegen/Symbol.scala | 13 ++-- src/ofc/generators/onetep/Index.scala | 7 +- .../generators/onetep/PPDFunctionSet.scala | 26 +++---- 7 files changed, 85 insertions(+), 49 deletions(-) rename src/ofc/codegen/{BinaryOperator.scala => NumericOperator.scala} (60%) diff --git a/src/ofc/codegen/ConditionalValue.scala b/src/ofc/codegen/ConditionalValue.scala index 9b93e82..fc07f8e 100644 --- a/src/ofc/codegen/ConditionalValue.scala +++ b/src/ofc/codegen/ConditionalValue.scala @@ -1,3 +1,3 @@ package ofc.codegen -class ConditionalValue(predicate: Expression, ifTrue: Expression, ifFalse: Expression) extends Expression +class ConditionalValue[T <: Type](predicate: Expression[BoolType], ifTrue: Expression[T], ifFalse: Expression[T]) extends Expression[T] diff --git a/src/ofc/codegen/Expression.scala b/src/ofc/codegen/Expression.scala index 0e93d92..05ce29c 100644 --- a/src/ofc/codegen/Expression.scala +++ b/src/ofc/codegen/Expression.scala @@ -1,25 +1,63 @@ package ofc.codegen -class Expression { - def ~>(field: FieldSymbol) : Expression = new FieldAccess(this, field) - def readAt(index: List[Expression]) : Expression = new ArrayRead(this, index) - def apply(index: Expression*) : Expression = this.readAt(index.toList) - def +(rhs: Expression) : Expression = new BinaryOperator(BinaryOperator.Add, this, rhs) - def -(rhs: Expression) : Expression = new BinaryOperator(BinaryOperator.Sub, this, rhs) - def *(rhs: Expression) : Expression = new BinaryOperator(BinaryOperator.Mul, this, rhs) - def /(rhs: Expression) : Expression = new BinaryOperator(BinaryOperator.Div, this, rhs) - def %(rhs: Expression) : Expression = new BinaryOperator(BinaryOperator.Mod, this, rhs) +trait TypeProperty +trait Numeric extends TypeProperty + +class HasProperty[T <: Type, P <: TypeProperty] +object HasProperty { + implicit val intNumeric = new HasProperty[IntType, Numeric]() + implicit val floatNumeric = new HasProperty[FloatType, Numeric]() +} + +sealed abstract class Type +sealed abstract class PrimitiveType extends Type +class IntType extends PrimitiveType +class FloatType extends PrimitiveType +class BoolType extends PrimitiveType +class ArrayType[ElementType <: Type] extends Type +class PointerType[TargetType <: Type] extends Type +abstract class StructType extends Type + +object Expression { + implicit def fromInt(i: Int) : Expression[IntType] = new IntegerLiteral(i) +} + +class Expression[T <: Type] { + def +(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] = + new NumericOperator[T](NumericOperator.Add, this, rhs) + + def -(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] = + new NumericOperator[T](NumericOperator.Sub, this, rhs) + + def *(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] = + new NumericOperator[T](NumericOperator.Mul, this, rhs) + + def /(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] = + new NumericOperator[T](NumericOperator.Div, this, rhs) + + def %(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] = + new NumericOperator[T](NumericOperator.Mod, this, rhs) + + def ~>[FieldType <: Type](field: FieldSymbol[FieldType])(implicit witness: <:<[this.type, Expression[StructType]]) : Expression[FieldType] = + new FieldAccess[FieldType](witness(this), field) + + def readAt[T <: Type](index: Expression[IntType]*)(implicit witness: <:<[this.type, Expression[ArrayType[T]]]) : Expression[T] = + new ArrayRead(witness(this), index.toList) + + //def apply[T <: Type](index: Expression[IntType]*)(implicit witness: <:<[this.type, Expression[ArrayType[T]]]) : Expression[T] = + // this.readAt(index.toList) + + def unary_~[T <: Type]()(implicit witness: <:<[this.type, Expression[PointerType[T]]]) : Expression[T] = + new PointerDereference(witness(this)) } // Variable references -class VarRef(symbol: VarSymbol) extends Expression +class VarRef[T <: Type](symbol: VarSymbol[T]) extends Expression[T] // Struct and array accesses -class FieldAccess(expression: Expression, field: FieldSymbol) extends Expression -class ArrayRead(expression: Expression, index: List[Expression]) extends Expression +class FieldAccess[T <: Type](expression: Expression[StructType], field: FieldSymbol[T]) extends Expression[T] +class ArrayRead[E <: Type](expression: Expression[ArrayType[E]], index: List[Expression[IntType]]) extends Expression[E] +class PointerDereference[E <: Type](expression: Expression[PointerType[E]]) extends Expression[E] // Literals -class IntegerLiteral(value: Int) extends Expression -object IntegerLiteral { - implicit def fromInt(value: Int) : Expression = new IntegerLiteral(value) -} +class IntegerLiteral(value: Int) extends Expression[IntType] diff --git a/src/ofc/codegen/BinaryOperator.scala b/src/ofc/codegen/NumericOperator.scala similarity index 60% rename from src/ofc/codegen/BinaryOperator.scala rename to src/ofc/codegen/NumericOperator.scala index 406f507..267ba10 100644 --- a/src/ofc/codegen/BinaryOperator.scala +++ b/src/ofc/codegen/NumericOperator.scala @@ -1,6 +1,6 @@ package ofc.codegen -object BinaryOperator { +object NumericOperator { sealed abstract class Operator case object Add extends Operator case object Sub extends Operator @@ -9,4 +9,4 @@ object BinaryOperator { case object Mod extends Operator } -class BinaryOperator(op: BinaryOperator.Operator, left: Expression, right: Expression) extends Expression +class NumericOperator[T <: Type](op: NumericOperator.Operator, left: Expression[T], right: Expression[T]) extends Expression[T] diff --git a/src/ofc/codegen/ProducerStatement.scala b/src/ofc/codegen/ProducerStatement.scala index 96c119a..8c1af5e 100644 --- a/src/ofc/codegen/ProducerStatement.scala +++ b/src/ofc/codegen/ProducerStatement.scala @@ -1,22 +1,22 @@ package ofc.codegen class ProducerStatement extends Statement { - class VariableRange(symbol: Symbol, count: Expression) + class VariableRange(symbol: Symbol, count: Expression[IntType]) class Predicate var statement = new NullStatement var ranges : List[VariableRange] = List.empty var predicates : List[Predicate] = List.empty - var expressions : Map[Symbol, Expression] = Map.empty + var expressions : Map[Symbol, Expression[_]] = Map.empty - def addExpression(name: String, expression: Expression) : VarSymbol = { - val symbol = new DeclaredVarSymbol(name) + def addExpression[T <: Type](name: String, expression: Expression[T]) : VarSymbol[T] = { + val symbol = new DeclaredVarSymbol[T](name) expressions += symbol -> expression symbol } - def addIteration(name: String, count: Expression) : VarSymbol = { - val symbol = new DeclaredVarSymbol(name) + def addIteration(name: String, count: Expression[IntType]) : VarSymbol[IntType] = { + val symbol = new DeclaredVarSymbol[IntType](name) ranges +:= new VariableRange(symbol, count) symbol } diff --git a/src/ofc/codegen/Symbol.scala b/src/ofc/codegen/Symbol.scala index b55caac..a7c770c 100644 --- a/src/ofc/codegen/Symbol.scala +++ b/src/ofc/codegen/Symbol.scala @@ -4,19 +4,18 @@ trait Symbol { def getName : String } -case class FieldSymbol(name: String) extends Symbol { +case class FieldSymbol[T <: Type](name: String) extends Symbol { def getName = name } -abstract class VarSymbol(name: String) extends Symbol { +abstract class VarSymbol[T <: Type](name: String) extends Symbol { def getName = name } object VarSymbol { - implicit def toRef(symbol: VarSymbol) = new VarRef(symbol) + implicit def toRef[T <: Type](symbol: VarSymbol[T]) = new VarRef[T](symbol) } -case class DeclaredVarSymbol(name: String) extends VarSymbol(name) -abstract class UnboundVarSymbol(name: String) extends VarSymbol(name) -case class NamedUnboundVarSymbol(name: String) extends UnboundVarSymbol(name) -class IterationSymbol(name: String, count: Expression) extends UnboundVarSymbol(name) +case class DeclaredVarSymbol[T <: Type](name: String) extends VarSymbol[T](name) +abstract class UnboundVarSymbol[T <: Type](name: String) extends VarSymbol[T](name) +case class NamedUnboundVarSymbol[T <: Type](name: String) extends UnboundVarSymbol[T](name) diff --git a/src/ofc/generators/onetep/Index.scala b/src/ofc/generators/onetep/Index.scala index 435d62e..caabe2b 100644 --- a/src/ofc/generators/onetep/Index.scala +++ b/src/ofc/generators/onetep/Index.scala @@ -1,5 +1,5 @@ package ofc.generators.onetep -import ofc.codegen.Expression +import ofc.codegen.{Expression,IntType} /* object Index { @@ -18,9 +18,8 @@ object Index { trait Index { def getName : String - def getMappingFunction : Expression - def getMinimumValue : Expression - def getLength : Expression + def getMinimumValue : Expression[IntType] + def getLength : Expression[IntType] def isRandomAccess : Boolean } diff --git a/src/ofc/generators/onetep/PPDFunctionSet.scala b/src/ofc/generators/onetep/PPDFunctionSet.scala index f35aee0..24c298a 100644 --- a/src/ofc/generators/onetep/PPDFunctionSet.scala +++ b/src/ofc/generators/onetep/PPDFunctionSet.scala @@ -2,20 +2,20 @@ package ofc.generators.onetep import ofc.codegen._ class PPDFunctionSet(val basisName: String, dataName: String) extends FunctionSet { - val basis = NamedUnboundVarSymbol(basisName) - val data = NamedUnboundVarSymbol(dataName) - val pubCell = NamedUnboundVarSymbol("pub_cell") + val basis = NamedUnboundVarSymbol[StructType](basisName) + val data = NamedUnboundVarSymbol[ArrayType[FloatType]](dataName) + val pubCell = NamedUnboundVarSymbol[StructType]("pub_cell") - val numSpheres = basis~>FieldSymbol("num"); - val ppdWidths = for(dim <- 1 to 3) yield pubCell~>FieldSymbol("n_pt"+dim) - val cellWidthInPPDs = for(dim <- 1 to 3) yield pubCell~>FieldSymbol("n_ppds_a"+dim) + val numSpheres = basis~>FieldSymbol[IntType]("num"); + val ppdWidths = for(dim <- 1 to 3) yield pubCell~>FieldSymbol[IntType]("n_pt"+dim) + val cellWidthInPPDs = for(dim <- 1 to 3) yield pubCell~>FieldSymbol[IntType]("n_ppds_a"+dim) def getSuffixFragment = { val producer = new ProducerStatement val sphereIndex = producer.addIteration("sphere_index", numSpheres) - val numPPDs = (basis~>FieldSymbol("n_ppds_sphere"))(sphereIndex) + val numPPDs = (~(basis~>FieldSymbol[PointerType[ArrayType[IntType]]]("n_ppds_sphere"))).readAt(sphereIndex) val ppdIndex = producer.addIteration("ppd_index", numPPDs) - val ppdGlobalCount = (basis~>FieldSymbol("ppd_list"))(ppdIndex, new IntegerLiteral(1)) - new IntegerLiteral(1) + val ppdGlobalCount = (~(basis~>FieldSymbol[PointerType[ArrayType[IntType]]]("ppd_list"))).readAt(ppdIndex, 1) - 1 // We need to calculate the integer co-ordinates of the PPD (0-based) val a3pos = ppdGlobalCount / (cellWidthInPPDs(0)*cellWidthInPPDs(1)) @@ -24,12 +24,12 @@ class PPDFunctionSet(val basisName: String, dataName: String) extends FunctionSe val ppdPos = List(a1pos, a2pos, a3pos) - val tightbox = basis~>FieldSymbol("tight_boxes")(sphereIndex) - val ppdStartOffsets = for(dim <- 1 to 3) yield tightbox % FieldSymbol("start_pts"+dim) - val ppdFinishOffsets = for(dim <- 1 to 3) yield tightbox % FieldSymbol("finish_pts"+dim) + val tightbox = (~(basis~>FieldSymbol[PointerType[ArrayType[StructType]]]("tight_boxes"))).readAt(sphereIndex) + val ppdStartOffsets = for(dim <- 1 to 3) yield tightbox ~> FieldSymbol[IntType]("start_pts"+dim) + val ppdFinishOffsets = for(dim <- 1 to 3) yield tightbox ~> FieldSymbol[IntType]("finish_pts"+dim) - val startPPDs = for(dim <- 0 to 2) yield (tightbox % FieldSymbol("start_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim) - val finishPPDs = for(dim <- 0 to 2) yield (tightbox % FieldSymbol("finish_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim) + val startPPDs = for(dim <- 0 to 2) yield (tightbox ~> FieldSymbol("start_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim) + val finishPPDs = for(dim <- 0 to 2) yield (tightbox ~> FieldSymbol("finish_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim) //val ppdRanges = for(dim <- 0 to 2) yield producer.addIteration("point"+(dim+1), ppdWidths(dim)) -- 2.47.3