From a6b739a8ae77d000db96b6d434e6734fe609714e Mon Sep 17 00:00:00 2001 From: Francis Russell Date: Thu, 5 Apr 2012 16:25:35 +0100 Subject: [PATCH] Construct PPDFunctionSet producer iteration space. --- src/ofc/codegen/Expression.scala | 35 ++++++++++++++----- src/ofc/codegen/NumericOperator.scala | 25 ++++++++----- src/ofc/codegen/ProducerStatement.scala | 6 ++-- .../generators/onetep/PPDFunctionSet.scala | 29 ++++++++------- 4 files changed, 62 insertions(+), 33 deletions(-) diff --git a/src/ofc/codegen/Expression.scala b/src/ofc/codegen/Expression.scala index 05ce29c..fcd3901 100644 --- a/src/ofc/codegen/Expression.scala +++ b/src/ofc/codegen/Expression.scala @@ -23,30 +23,47 @@ object Expression { } class Expression[T <: Type] { + // Field Operations def +(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] = - new NumericOperator[T](NumericOperator.Add, this, rhs) + new NumericOperator[T](NumericOperations.Add, this, rhs) def -(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] = - new NumericOperator[T](NumericOperator.Sub, this, rhs) + new NumericOperator[T](NumericOperations.Sub, this, rhs) def *(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] = - new NumericOperator[T](NumericOperator.Mul, this, rhs) + new NumericOperator[T](NumericOperations.Mul, this, rhs) def /(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] = - new NumericOperator[T](NumericOperator.Div, this, rhs) + new NumericOperator[T](NumericOperations.Div, this, rhs) def %(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] = - new NumericOperator[T](NumericOperator.Mod, this, rhs) + new NumericOperator[T](NumericOperations.Mod, this, rhs) - def ~>[FieldType <: Type](field: FieldSymbol[FieldType])(implicit witness: <:<[this.type, Expression[StructType]]) : Expression[FieldType] = + // Comparison Operations + def |<|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] = + new NumericComparison[T](NumericOperations.LT, this, rhs) + + def |<=|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] = + new NumericComparison[T](NumericOperations.LTE, this, rhs) + + def |==|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] = + new NumericComparison[T](NumericOperations.EQ, this, rhs) + + def |!=|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] = + new NumericComparison[T](NumericOperations.NE, this, rhs) + + def |>|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] = + new NumericComparison[T](NumericOperations.GT, this, rhs) + + def |>=|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] = + new NumericComparison[T](NumericOperations.GTE, this, rhs) + + def %[FieldType <: Type](field: FieldSymbol[FieldType])(implicit witness: <:<[this.type, Expression[StructType]]) : Expression[FieldType] = new FieldAccess[FieldType](witness(this), field) def readAt[T <: Type](index: Expression[IntType]*)(implicit witness: <:<[this.type, Expression[ArrayType[T]]]) : Expression[T] = new ArrayRead(witness(this), index.toList) - //def apply[T <: Type](index: Expression[IntType]*)(implicit witness: <:<[this.type, Expression[ArrayType[T]]]) : Expression[T] = - // this.readAt(index.toList) - def unary_~[T <: Type]()(implicit witness: <:<[this.type, Expression[PointerType[T]]]) : Expression[T] = new PointerDereference(witness(this)) } diff --git a/src/ofc/codegen/NumericOperator.scala b/src/ofc/codegen/NumericOperator.scala index 267ba10..8225cf7 100644 --- a/src/ofc/codegen/NumericOperator.scala +++ b/src/ofc/codegen/NumericOperator.scala @@ -1,12 +1,21 @@ package ofc.codegen -object NumericOperator { - sealed abstract class Operator - case object Add extends Operator - case object Sub extends Operator - case object Mul extends Operator - case object Div extends Operator - case object Mod extends Operator +object NumericOperations { + sealed abstract class FieldOp + case object Add extends FieldOp + case object Sub extends FieldOp + case object Mul extends FieldOp + case object Div extends FieldOp + case object Mod extends FieldOp + + sealed abstract class CompareOp + object LT extends CompareOp + object LTE extends CompareOp + object EQ extends CompareOp + object NE extends CompareOp + object GT extends CompareOp + object GTE extends CompareOp } -class NumericOperator[T <: Type](op: NumericOperator.Operator, left: Expression[T], right: Expression[T]) extends Expression[T] +class NumericOperator[T <: Type](op: NumericOperations.FieldOp, left: Expression[T], right: Expression[T]) extends Expression[T] +class NumericComparison[T <: Type](op: NumericOperations.CompareOp, left: Expression[T], right: Expression[T]) extends Expression[BoolType] diff --git a/src/ofc/codegen/ProducerStatement.scala b/src/ofc/codegen/ProducerStatement.scala index 8c1af5e..99d858e 100644 --- a/src/ofc/codegen/ProducerStatement.scala +++ b/src/ofc/codegen/ProducerStatement.scala @@ -1,7 +1,7 @@ package ofc.codegen class ProducerStatement extends Statement { - class VariableRange(symbol: Symbol, count: Expression[IntType]) + class VariableRange(symbol: Symbol, first: Expression[IntType], last: Expression[IntType]) class Predicate var statement = new NullStatement @@ -15,9 +15,9 @@ class ProducerStatement extends Statement { symbol } - def addIteration(name: String, count: Expression[IntType]) : VarSymbol[IntType] = { + def addIteration(name: String, first: Expression[IntType], last: Expression[IntType]) : VarSymbol[IntType] = { val symbol = new DeclaredVarSymbol[IntType](name) - ranges +:= new VariableRange(symbol, count) + ranges +:= new VariableRange(symbol, first, last) symbol } } diff --git a/src/ofc/generators/onetep/PPDFunctionSet.scala b/src/ofc/generators/onetep/PPDFunctionSet.scala index 24c298a..622308a 100644 --- a/src/ofc/generators/onetep/PPDFunctionSet.scala +++ b/src/ofc/generators/onetep/PPDFunctionSet.scala @@ -6,16 +6,16 @@ class PPDFunctionSet(val basisName: String, dataName: String) extends FunctionSe val data = NamedUnboundVarSymbol[ArrayType[FloatType]](dataName) val pubCell = NamedUnboundVarSymbol[StructType]("pub_cell") - val numSpheres = basis~>FieldSymbol[IntType]("num"); - val ppdWidths = for(dim <- 1 to 3) yield pubCell~>FieldSymbol[IntType]("n_pt"+dim) - val cellWidthInPPDs = for(dim <- 1 to 3) yield pubCell~>FieldSymbol[IntType]("n_ppds_a"+dim) + val numSpheres = basis % FieldSymbol[IntType]("num"); + val ppdWidths = for(dim <- 1 to 3) yield pubCell % FieldSymbol[IntType]("n_pt"+dim) + val cellWidthInPPDs = for(dim <- 1 to 3) yield pubCell % FieldSymbol[IntType]("n_ppds_a"+dim) def getSuffixFragment = { val producer = new ProducerStatement - val sphereIndex = producer.addIteration("sphere_index", numSpheres) - val numPPDs = (~(basis~>FieldSymbol[PointerType[ArrayType[IntType]]]("n_ppds_sphere"))).readAt(sphereIndex) - val ppdIndex = producer.addIteration("ppd_index", numPPDs) - val ppdGlobalCount = (~(basis~>FieldSymbol[PointerType[ArrayType[IntType]]]("ppd_list"))).readAt(ppdIndex, 1) - 1 + val sphereIndex = producer.addIteration("sphere_index", 1, numSpheres) + val numPPDs = (~(basis % FieldSymbol[PointerType[ArrayType[IntType]]]("n_ppds_sphere"))).readAt(sphereIndex) + val ppdIndex = producer.addIteration("ppd_index", 1, numPPDs) + val ppdGlobalCount = (~(basis % FieldSymbol[PointerType[ArrayType[IntType]]]("ppd_list"))).readAt(ppdIndex, 1) - 1 // We need to calculate the integer co-ordinates of the PPD (0-based) val a3pos = ppdGlobalCount / (cellWidthInPPDs(0)*cellWidthInPPDs(1)) @@ -24,14 +24,17 @@ class PPDFunctionSet(val basisName: String, dataName: String) extends FunctionSe val ppdPos = List(a1pos, a2pos, a3pos) - val tightbox = (~(basis~>FieldSymbol[PointerType[ArrayType[StructType]]]("tight_boxes"))).readAt(sphereIndex) - val ppdStartOffsets = for(dim <- 1 to 3) yield tightbox ~> FieldSymbol[IntType]("start_pts"+dim) - val ppdFinishOffsets = for(dim <- 1 to 3) yield tightbox ~> FieldSymbol[IntType]("finish_pts"+dim) + val tightbox = (~(basis % FieldSymbol[PointerType[ArrayType[StructType]]]("tight_boxes"))).readAt(sphereIndex) + val ppdStartOffsets = for(dim <- 1 to 3) yield tightbox % FieldSymbol[IntType]("start_pts"+dim) + val ppdFinishOffsets = for(dim <- 1 to 3) yield tightbox % FieldSymbol[IntType]("finish_pts"+dim) - val startPPDs = for(dim <- 0 to 2) yield (tightbox ~> FieldSymbol("start_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim) - val finishPPDs = for(dim <- 0 to 2) yield (tightbox ~> FieldSymbol("finish_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim) + val startPPDs = for(dim <- 0 to 2) yield (tightbox % FieldSymbol[IntType]("start_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim) + val finishPPDs = for(dim <- 0 to 2) yield (tightbox % FieldSymbol[IntType]("finish_ppd"+(dim+1)) + cellWidthInPPDs(dim)) % cellWidthInPPDs(dim) - //val ppdRanges = for(dim <- 0 to 2) yield producer.addIteration("point"+(dim+1), ppdWidths(dim)) + val loopStarts = for(dim <- 0 to 2) yield new ConditionalValue[IntType](startPPDs(dim) |==| ppdPos(dim), ppdStartOffsets(dim), 1) + val loopEnds = for(dim <- 0 to 2) yield new ConditionalValue[IntType](finishPPDs(dim) |==| ppdPos(dim), ppdFinishOffsets(dim), ppdWidths(dim)) + + val ppdIndices = for(dim <- 0 to 2) yield producer.addIteration("point"+(dim+1), loopStarts(dim), loopEnds(dim)) producer } -- 2.47.3