package ofc.codegen
+import ofc.LogicError
class ConditionalValue[T <: Type](predicate: Expression[BoolType], ifTrue: Expression[T], ifFalse: Expression[T]) extends Expression[T] {
def foreach[U](f: Expression[_] => U) = List(predicate, ifTrue, ifFalse).foreach(f)
def getPredicate = predicate
def getIfTrue = ifTrue
def getIfFalse = ifFalse
+ def getType = {
+ if (ifTrue.getType != ifFalse.getType)
+ throw new LogicError("Parameters to ternary operator have different types")
+ else
+ ifTrue.getType
+ }
}
package ofc.codegen
-trait TypeProperty
-trait Numeric extends TypeProperty
-
-class HasProperty[T <: Type, P <: TypeProperty]
-object HasProperty {
- implicit val intNumeric = new HasProperty[IntType, Numeric]()
- implicit val floatNumeric = new HasProperty[FloatType, Numeric]()
-}
-
-sealed abstract class Type
-sealed abstract class PrimitiveType extends Type
-class IntType extends PrimitiveType
-class FloatType extends PrimitiveType
-class BoolType extends PrimitiveType
-class ArrayType[ElementType <: Type] extends Type
-class PointerType[TargetType <: Type] extends Type
-abstract class StructType extends Type
-
-trait LeafExpression {
- def foreach[U](f: Expression[_] => U): Unit = ()
-}
-
object Expression {
implicit def fromInt(i: Int) : Expression[IntType] = new IntegerLiteral(i)
}
abstract class Expression[T <: Type] extends Traversable[Expression[_]] {
+ def getType : T
+
// Field Operations
def +(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] =
new NumericOperator[T](NumericOperations.Add, this, rhs)
new PointerDereference(witness(this))
}
+trait LeafExpression {
+ def foreach[U](f: Expression[_] => U): Unit = ()
+}
+
// Variable references
class VarRef[T <: Type](symbol: VarSymbol[T]) extends Expression[T] with LeafExpression {
def getSymbol = symbol
+ def getType = symbol.getType
}
// Struct and array accesses
def foreach[U](f: Expression[_] => U) = f(expression)
def getStructExpression = expression
def getField = field
+ def getType = field.getType
}
class ArrayRead[E <: Type](expression: Expression[ArrayType[E]], index: Seq[Expression[IntType]]) extends Expression[E] {
def foreach[U](f: Expression[_] => U) = (index :+ expression).foreach(f)
def getArrayExpression = expression
def getIndexExpressions = index
+ def getType = expression.getType.getElementType
}
class PointerDereference[E <: Type](expression: Expression[PointerType[E]]) extends Expression[E] {
def foreach[U](f: Expression[_] => U) = f(expression)
def getExpression = expression
+ def getType = expression.getType.getTargetType
}
// Literals
class IntegerLiteral(value: Int) extends Expression[IntType] with LeafExpression {
def getValue = value
+ def getType = new IntType
}
}
}
- private def buildConditionalValue(conditional: ConditionalValue[_]) : ExpHolder = {
- var symbol = new DeclaredVarSymbol[Type]("ternary")
+ private def buildConditionalValue(conditional: ConditionalValue[_ <: Type]) : ExpHolder = {
+ var symbol = new DeclaredVarSymbol[Type]("ternary", conditional.getType)
symbolManager.addSymbol(symbol)
val name = symbolManager.getName(symbol)
addLine("if (%s) then".format(buildExpression(conditional.getPredicate)))
package ofc.codegen
+import ofc.LogicError
object NumericOperations {
sealed abstract class FieldOp
}
class NumericOperator[T <: Type](op: NumericOperations.FieldOp, left: Expression[T], right: Expression[T]) extends Expression[T] {
- // TODO: Type check operators
def foreach[U](f: Expression[_] => U) = List(left, right).foreach(f)
def getOperation = op
def getLeft = left
def getRight = right
+ def getType = {
+ if (left.getType != right.getType)
+ throw new LogicError("Non-matching types for parameters to numeric comparison")
+ else
+ left.getType
+ }
}
class NumericComparison[T <: Type](op: NumericOperations.CompareOp, left: Expression[T], right: Expression[T]) extends Expression[BoolType] {
- // TODO: Type check operators
def foreach[U](f: Expression[_] => U) = List(left, right).foreach(f)
def getOperation = op
def getLeft = left
def getRight = right
+ def getType = new BoolType
}
var expressions : Seq[DerivedExpression] = Nil
def addExpression[T <: Type](name: String, expression: Expression[T]) : VarSymbol[T] = {
- val symbol = new DeclaredVarSymbol[T](name)
+ val symbol = new DeclaredVarSymbol[T](name, expression.getType)
expressions +:= new DerivedExpression(symbol, expression)
symbol
}
def getName : String
}
-case class FieldSymbol[T <: Type](name: String) extends Symbol {
+class FieldSymbol[T <: Type](name: String, fieldType: T) extends Symbol {
+ def this(name: String)(implicit builder: TypeBuilder[T]) = this(name, builder())
def getName = name
+ def getType = fieldType
}
-sealed abstract class VarSymbol[T <: Type](name: String) extends Symbol {
+sealed abstract class VarSymbol[T <: Type](name: String, varType: T) extends Symbol {
def getName = name
+ def getType : T = varType
}
object VarSymbol {
implicit def toRef[T <: Type](symbol: VarSymbol[T]) = new VarRef[T](symbol)
}
-class DeclaredVarSymbol[T <: Type](name: String) extends VarSymbol[T](name)
-class NamedUnboundVarSymbol[T <: Type](name: String) extends VarSymbol[T](name)
+class DeclaredVarSymbol[T <: Type](name: String, varType: T) extends VarSymbol[T](name, varType) {
+ def this(name: String)(implicit builder: TypeBuilder[T]) = this(name, builder())
+}
+
+class NamedUnboundVarSymbol[T <: Type](name: String, varType: T) extends VarSymbol[T](name, varType) {
+ def this(name: String)(implicit builder: TypeBuilder[T]) = this(name, builder())
+}
--- /dev/null
+package ofc.codegen
+
+sealed abstract class Type
+sealed abstract class PrimitiveType extends Type
+
+// These are case classes solely for the comparison operators
+final case class IntType() extends PrimitiveType
+final case class FloatType() extends PrimitiveType
+final case class BoolType() extends PrimitiveType
+
+final case class ArrayType[ElementType <: Type](rank: Int, eType: ElementType) extends Type {
+ def this(rank: Int)(implicit builder: TypeBuilder[ElementType]) = this(rank, builder())
+ def getElementType = eType
+}
+
+final case class PointerType[TargetType <: Type](tType: TargetType) extends Type {
+ def this()(implicit builder: TypeBuilder[TargetType]) = this(builder())
+ def getTargetType = tType
+}
+
+abstract class StructType extends Type
+
+trait TypeBuilder[T <: Type] {
+ def apply() : T
+}
+
+object TypeBuilder {
+ implicit val intBuilder = new TypeBuilder[IntType] { def apply() = new IntType }
+ implicit val floatBuilder = new TypeBuilder[FloatType] { def apply() = new FloatType }
+ implicit val boolBuilder = new TypeBuilder[BoolType] { def apply() = new BoolType }
+}
+
+trait TypeProperty
+trait Numeric extends TypeProperty
+
+class HasProperty[T <: Type, P <: TypeProperty]
+object HasProperty {
+ implicit val intNumeric = new HasProperty[IntType, Numeric]()
+ implicit val floatNumeric = new HasProperty[FloatType, Numeric]()
+}
+
+
--- /dev/null
+package ofc.generators.onetep
+import ofc.codegen._
+
+object OnetepTypes {
+ object FunctionBasis extends StructType {
+ val numPPDsInSphere = {
+ val fieldType = new PointerType[ArrayType[IntType]](new ArrayType[IntType](1))
+ new FieldSymbol[PointerType[ArrayType[IntType]]]("n_ppds_sphere", fieldType)
+ }
+
+ val ppdList = {
+ val fieldType = new PointerType[ArrayType[IntType]](new ArrayType[IntType](1))
+ new FieldSymbol[PointerType[ArrayType[IntType]]]("ppd_list", fieldType)
+ }
+
+ val num = new FieldSymbol[IntType]("num")
+
+ val tightBoxes = {
+ val fieldType = new PointerType[ArrayType[StructType]](new ArrayType(1, TightBox))
+ new FieldSymbol[PointerType[ArrayType[StructType]]]("tight_boxes", fieldType)
+ }
+ }
+
+ object CellInfo extends StructType {
+ val ppdWidth = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("n_pt"+dim)}.toSeq
+ val numPPDs = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("n_ppds_a"+dim)}.toSeq
+ }
+
+ object TightBox extends StructType {
+ val startPts = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("start_pts"+dim)}.toSeq
+ val finishPts = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("finish_pts"+dim)}.toSeq
+ val startPPD = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("start_ppd"+dim)}.toSeq
+ val finishPPD = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("finish_ppd"+dim)}.toSeq
+ }
+}
import ofc.codegen._
class PPDFunctionSet(val basisName: String, dataName: String) extends FunctionSet {
- val basis = new NamedUnboundVarSymbol[StructType](basisName)
- val data = new NamedUnboundVarSymbol[ArrayType[FloatType]](dataName)
- val pubCell = new NamedUnboundVarSymbol[StructType]("pub_cell")
+ import OnetepTypes._
- val numSpheres = basis % FieldSymbol[IntType]("num");
- val ppdWidths = for(dim <- 1 to 3) yield pubCell % FieldSymbol[IntType]("n_pt"+dim)
- val cellWidthInPPDs = for(dim <- 1 to 3) yield pubCell % FieldSymbol[IntType]("n_ppds_a"+dim)
+ val basis = new NamedUnboundVarSymbol[StructType](basisName, FunctionBasis)
+ val data = new NamedUnboundVarSymbol[ArrayType[FloatType]](dataName, new ArrayType[FloatType](1))
+ val pubCell = new NamedUnboundVarSymbol[StructType]("pub_cell", OnetepTypes.CellInfo)
+
+ val numSpheres = basis % FunctionBasis.num
+ val ppdWidths = for(dim <- 0 to 2) yield pubCell % CellInfo.ppdWidth(dim)
+ val cellWidthInPPDs = for(dim <- 0 to 2) yield pubCell % CellInfo.numPPDs(dim)
def getSuffixFragment = {
val producer = new ProducerStatement
val sphereIndex = producer.addIteration("sphere_index", 1, numSpheres)
- val numPPDs = (~(basis % FieldSymbol[PointerType[ArrayType[IntType]]]("n_ppds_sphere"))).readAt(sphereIndex)
+ val numPPDs = (~(basis % FunctionBasis.numPPDsInSphere)).readAt(sphereIndex)
val ppdIndex = producer.addIteration("ppd_index", 1, numPPDs)
- val ppdGlobalCount = (~(basis % FieldSymbol[PointerType[ArrayType[IntType]]]("ppd_list"))).readAt(ppdIndex, 1) - 1
+ val ppdGlobalCount = (~(basis % FunctionBasis.ppdList)).readAt(ppdIndex, 1) - 1
// The integer co-ordinates of the PPD (0-based)
val a3pos = producer.addExpression("ppd_pos1", ppdGlobalCount / (cellWidthInPPDs(0)*cellWidthInPPDs(1)))
val a1pos = producer.addExpression("ppd_pos3", ppdGlobalCount % cellWidthInPPDs(0))
val ppdPos = List(a1pos, a2pos, a3pos)
- val tightbox = (~(basis % FieldSymbol[PointerType[ArrayType[StructType]]]("tight_boxes"))).readAt(sphereIndex)
+ val tightbox = (~(basis % FunctionBasis.tightBoxes)).readAt(sphereIndex)
// The offsets into the PPDs for the edges of the tightbox
- val ppdStartOffsets = for(dim <- 1 to 3) yield tightbox % FieldSymbol[IntType]("start_pts"+dim)
- val ppdFinishOffsets = for(dim <- 1 to 3) yield tightbox % FieldSymbol[IntType]("finish_pts"+dim)
+ val ppdStartOffsets = for(dim <- 0 to 2) yield tightbox % TightBox.startPts(dim)
+ val ppdFinishOffsets = for(dim <- 0 to 2) yield tightbox % TightBox.finishPts(dim)
// The first and last PPDs in PPD co-ordinates (0-based, inside simulation cell)
val startPPDs = for(dim <- 0 to 2) yield
- producer.addExpression("start_ppd"+(dim+1), (tightbox % FieldSymbol[IntType]("start_ppd"+(dim+1)) + cellWidthInPPDs(dim)-1) % cellWidthInPPDs(dim))
+ producer.addExpression("start_ppd"+(dim+1), (tightbox % TightBox.startPPD(dim) + cellWidthInPPDs(dim)-1) % cellWidthInPPDs(dim))
val finishPPDs = for(dim <- 0 to 2) yield
- producer.addExpression("finish_ppd"+(dim+1),(tightbox % FieldSymbol[IntType]("finish_ppd"+(dim+1)) + cellWidthInPPDs(dim)-1) % cellWidthInPPDs(dim))
+ producer.addExpression("finish_ppd"+(dim+1),(tightbox % TightBox.finishPPD(dim) + cellWidthInPPDs(dim)-1) % cellWidthInPPDs(dim))
// Offsets for the current PPD being iterated over
val loopStarts = for(dim <- 0 to 2) yield