}
class Expression[T <: Type] {
+ // Field Operations
def +(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] =
- new NumericOperator[T](NumericOperator.Add, this, rhs)
+ new NumericOperator[T](NumericOperations.Add, this, rhs)
def -(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] =
- new NumericOperator[T](NumericOperator.Sub, this, rhs)
+ new NumericOperator[T](NumericOperations.Sub, this, rhs)
def *(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] =
- new NumericOperator[T](NumericOperator.Mul, this, rhs)
+ new NumericOperator[T](NumericOperations.Mul, this, rhs)
def /(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] =
- new NumericOperator[T](NumericOperator.Div, this, rhs)
+ new NumericOperator[T](NumericOperations.Div, this, rhs)
def %(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] =
- new NumericOperator[T](NumericOperator.Mod, this, rhs)
+ new NumericOperator[T](NumericOperations.Mod, this, rhs)
- def ~>[FieldType <: Type](field: FieldSymbol[FieldType])(implicit witness: <:<[this.type, Expression[StructType]]) : Expression[FieldType] =
+ // Comparison Operations
+ def |<|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] =
+ new NumericComparison[T](NumericOperations.LT, this, rhs)
+
+ def |<=|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] =
+ new NumericComparison[T](NumericOperations.LTE, this, rhs)
+
+ def |==|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] =
+ new NumericComparison[T](NumericOperations.EQ, this, rhs)
+
+ def |!=|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] =
+ new NumericComparison[T](NumericOperations.NE, this, rhs)
+
+ def |>|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] =
+ new NumericComparison[T](NumericOperations.GT, this, rhs)
+
+ def |>=|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] =
+ new NumericComparison[T](NumericOperations.GTE, this, rhs)
+
+ def %[FieldType <: Type](field: FieldSymbol[FieldType])(implicit witness: <:<[this.type, Expression[StructType]]) : Expression[FieldType] =
new FieldAccess[FieldType](witness(this), field)
def readAt[T <: Type](index: Expression[IntType]*)(implicit witness: <:<[this.type, Expression[ArrayType[T]]]) : Expression[T] =
new ArrayRead(witness(this), index.toList)
- //def apply[T <: Type](index: Expression[IntType]*)(implicit witness: <:<[this.type, Expression[ArrayType[T]]]) : Expression[T] =
- // this.readAt(index.toList)
-
def unary_~[T <: Type]()(implicit witness: <:<[this.type, Expression[PointerType[T]]]) : Expression[T] =
new PointerDereference(witness(this))
}
package ofc.codegen
-object NumericOperator {
- sealed abstract class Operator
- case object Add extends Operator
- case object Sub extends Operator
- case object Mul extends Operator
- case object Div extends Operator
- case object Mod extends Operator
+object NumericOperations {
+ sealed abstract class FieldOp
+ case object Add extends FieldOp
+ case object Sub extends FieldOp
+ case object Mul extends FieldOp
+ case object Div extends FieldOp
+ case object Mod extends FieldOp
+
+ sealed abstract class CompareOp
+ object LT extends CompareOp
+ object LTE extends CompareOp
+ object EQ extends CompareOp
+ object NE extends CompareOp
+ object GT extends CompareOp
+ object GTE extends CompareOp
}
-class NumericOperator[T <: Type](op: NumericOperator.Operator, left: Expression[T], right: Expression[T]) extends Expression[T]
+class NumericOperator[T <: Type](op: NumericOperations.FieldOp, left: Expression[T], right: Expression[T]) extends Expression[T]
+class NumericComparison[T <: Type](op: NumericOperations.CompareOp, left: Expression[T], right: Expression[T]) extends Expression[BoolType]
package ofc.codegen
class ProducerStatement extends Statement {
- class VariableRange(symbol: Symbol, count: Expression[IntType])
+ class VariableRange(symbol: Symbol, first: Expression[IntType], last: Expression[IntType])
class Predicate
var statement = new NullStatement
symbol
}
- def addIteration(name: String, count: Expression[IntType]) : VarSymbol[IntType] = {
+ def addIteration(name: String, first: Expression[IntType], last: Expression[IntType]) : VarSymbol[IntType] = {
val symbol = new DeclaredVarSymbol[IntType](name)
- ranges +:= new VariableRange(symbol, count)
+ ranges +:= new VariableRange(symbol, first, last)
symbol
}
}
val data = NamedUnboundVarSymbol[ArrayType[FloatType]](dataName)
val pubCell = NamedUnboundVarSymbol[StructType]("pub_cell")
- val numSpheres = basis~>FieldSymbol[IntType]("num");
- val ppdWidths = for(dim <- 1 to 3) yield pubCell~>FieldSymbol[IntType]("n_pt"+dim)
- val cellWidthInPPDs = for(dim <- 1 to 3) yield pubCell~>FieldSymbol[IntType]("n_ppds_a"+dim)
+ val numSpheres = basis % FieldSymbol[IntType]("num");
+ val ppdWidths = for(dim <- 1 to 3) yield pubCell % FieldSymbol[IntType]("n_pt"+dim)
+ val cellWidthInPPDs = for(dim <- 1 to 3) yield pubCell % FieldSymbol[IntType]("n_ppds_a"+dim)
def getSuffixFragment = {
val producer = new ProducerStatement
- val sphereIndex = producer.addIteration("sphere_index", numSpheres)
- val numPPDs = (~(basis~>FieldSymbol[PointerType[ArrayType[IntType]]]("n_ppds_sphere"))).readAt(sphereIndex)
- val ppdIndex = producer.addIteration("ppd_index", numPPDs)
- val ppdGlobalCount = (~(basis~>FieldSymbol[PointerType[ArrayType[IntType]]]("ppd_list"))).readAt(ppdIndex, 1) - 1
+ val sphereIndex = producer.addIteration("sphere_index", 1, numSpheres)
+ val numPPDs = (~(basis % FieldSymbol[PointerType[ArrayType[IntType]]]("n_ppds_sphere"))).readAt(sphereIndex)
+ val ppdIndex = producer.addIteration("ppd_index", 1, numPPDs)
+ val ppdGlobalCount = (~(basis % FieldSymbol[PointerType[ArrayType[IntType]]]("ppd_list"))).readAt(ppdIndex, 1) - 1
// We need to calculate the integer co-ordinates of the PPD (0-based)
val a3pos = ppdGlobalCount / (cellWidthInPPDs(0)*cellWidthInPPDs(1))
val ppdPos = List(a1pos, a2pos, a3pos)
- val tightbox = (~(basis~>FieldSymbol[PointerType[ArrayType[StructType]]]("tight_boxes"))).readAt(sphereIndex)
- val ppdStartOffsets = for(dim <- 1 to 3) yield tightbox ~> FieldSymbol[IntType]("start_pts"+dim)
- val ppdFinishOffsets = for(dim <- 1 to 3) yield tightbox ~> FieldSymbol[IntType]("finish_pts"+dim)
+ val tightbox = (~(basis % FieldSymbol[PointerType[ArrayType[StructType]]]("tight_boxes"))).readAt(sphereIndex)
+ val ppdStartOffsets = for(dim <- 1 to 3) yield tightbox % FieldSymbol[IntType]("start_pts"+dim)
+ val ppdFinishOffsets = for(dim <- 1 to 3) yield tightbox % FieldSymbol[IntType]("finish_pts"+dim)
- val startPPDs = for(dim <- 0 to 2) yield (tightbox ~> FieldSymbol("start_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim)
- val finishPPDs = for(dim <- 0 to 2) yield (tightbox ~> FieldSymbol("finish_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim)
+ val startPPDs = for(dim <- 0 to 2) yield (tightbox % FieldSymbol[IntType]("start_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim)
+ val finishPPDs = for(dim <- 0 to 2) yield (tightbox % FieldSymbol[IntType]("finish_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim)
- //val ppdRanges = for(dim <- 0 to 2) yield producer.addIteration("point"+(dim+1), ppdWidths(dim))
+ val loopStarts = for(dim <- 0 to 2) yield new ConditionalValue[IntType](startPPDs(dim) |==| ppdPos(dim), ppdStartOffsets(dim), 1)
+ val loopEnds = for(dim <- 0 to 2) yield new ConditionalValue[IntType](finishPPDs(dim) |==| ppdPos(dim), ppdFinishOffsets(dim), ppdWidths(dim))
+
+ val ppdIndices = for(dim <- 0 to 2) yield producer.addIteration("point"+(dim+1), loopStarts(dim), loopEnds(dim))
producer
}