From 8bf6650b8a44224c2ebc9487b07551cb4142ab4e Mon Sep 17 00:00:00 2001 From: Francis Russell Date: Tue, 10 Apr 2012 10:57:03 +0100 Subject: [PATCH] Work on SpatialRestriction consumer/producer. --- src/ofc/codegen/ProducerStatement.scala | 4 +++ .../onetep/SpatialRestriction.scala | 29 +++++++++++++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/ofc/codegen/ProducerStatement.scala b/src/ofc/codegen/ProducerStatement.scala index e564ca7..b651c96 100644 --- a/src/ofc/codegen/ProducerStatement.scala +++ b/src/ofc/codegen/ProducerStatement.scala @@ -96,6 +96,10 @@ class ProducerStatement extends Statement { symbol } + def addPredicate(condition: Expression[BoolType]) { + predicates +:= new Predicate(condition) + } + def toConcrete : Statement = { val contexts = ranges ++ predicates ++ expressions val sortedContexts = Context.sort(contexts) diff --git a/src/ofc/generators/onetep/SpatialRestriction.scala b/src/ofc/generators/onetep/SpatialRestriction.scala index 27903a6..98a7c38 100644 --- a/src/ofc/generators/onetep/SpatialRestriction.scala +++ b/src/ofc/generators/onetep/SpatialRestriction.scala @@ -1,4 +1,5 @@ package ofc.generators.onetep +import ofc.InvalidInputException import ofc.codegen._ object SpatialRestriction { @@ -14,8 +15,32 @@ class SpatialRestriction(op: IterationSpace) extends IterationSpace { def getSuffixFragment = new NullStatement def getDataValue = op.getDataValue def getReaderFragment = { - //TODO: Implement me! - new NullStatement + import OnetepTypes._ + import SpatialRestriction._ + + val inputIndices = for(index <- op.getSpatialIndices) yield + index match { + case (f: FunctionSpatialIndex) => f + case _ => throw new InvalidInputException("Input to SpatialRestriction must be a function") + } + + val ppdWidths = for(dim <- 0 to 2) yield CellInfo.public % CellInfo.ppdWidth(dim) + val cellWidthPPDs = for(dim <- 0 to 2) yield CellInfo.public % CellInfo.numPPDs(dim) + val cellWidthPts = for(dim <- 0 to 2) yield cellWidthPPDs(dim) * ppdWidths(dim) + val functionCentre = inputIndices.map(_.getFunctionCentre) + + val producer = new ProducerStatement + val origin = for (dim <- 0 to 2) + yield producer.addExpression("fftbox_origin_pt"+(dim+1), functionCentre(dim) - pubFFTBoxWidth(dim)/2) + + val offset = for (dim <- 0 to 2) + yield producer.addExpression("fftbox_offset_pt"+(dim+1), + (inputIndices(dim).getValue - origin(dim) + cellWidthPts(dim)) % cellWidthPts(dim)) + + for (dim <- 0 to 2) + producer.addPredicate(offset(dim) |<| pubFFTBoxWidth(dim)) + + producer } def getSpatialIndices = { -- 2.47.3