From: Francis Russell Date: Tue, 31 Jan 2012 20:52:56 +0000 (+0000) Subject: Handle external indices differently. X-Git-Url: https://git.unchartedbackwaters.co.uk/w/?a=commitdiff_plain;h=811eaa150f042e10da026591146014620452e951;p=francis%2Fofc.git Handle external indices differently. --- diff --git a/src/ofc/generators/onetep/CodeGenerator.scala b/src/ofc/generators/onetep/CodeGenerator.scala index a628761..737c44f 100644 --- a/src/ofc/generators/onetep/CodeGenerator.scala +++ b/src/ofc/generators/onetep/CodeGenerator.scala @@ -25,12 +25,17 @@ class NameManager { } } +object CodeGenerator { + def getAllSpaces(term: IterationSpace) : Set[IterationSpace] = + term.getOperands.toSet.flatMap(getAllSpaces(_: IterationSpace)) + term +} + class CodeGenerator { val code = new StringBuilder() val nameManager = new NameManager() def collectDeclarations(term: IterationSpace) : Set[String] = { - val declarations = for(index <- term.getSpatialIndices ++ term.getDiscreteIndices; + val declarations = for(index <- term.getIndices; declaration <- index.getDeclarations(nameManager)) yield declaration var declarationsSet = declarations.toSet @@ -45,6 +50,20 @@ class CodeGenerator { } def generateCode(space: IterationSpace) { + val allSpaces = CodeGenerator.getAllSpaces(space) + val allIndices = allSpaces flatMap (_.getIndices) + + println("dumping operations") + for(op <- allSpaces) + println(op) + println("finished dumping operations\n\ndumping indices") + for (i <- allIndices) + println(i) + println("finished dumping indices") + + // Next: we dump all these things into a prefix map + System.exit(0) + val operands = space.getOperands for(operand <- operands) diff --git a/src/ofc/generators/onetep/Tree.scala b/src/ofc/generators/onetep/Tree.scala index 93ba521..6d16e5f 100644 --- a/src/ofc/generators/onetep/Tree.scala +++ b/src/ofc/generators/onetep/Tree.scala @@ -22,6 +22,9 @@ trait IterationSpace { def getOperands : List[IterationSpace] def getSpatialIndices : List[SpatialIndex] def getDiscreteIndices : List[DiscreteIndex] + def getExternalIndices : Set[Index] + def getInternalIndices : Set[Index] = (getSpatialIndices ++ getDiscreteIndices).toSet + def getIndices : Set[Index] = getInternalIndices ++ getExternalIndices } trait DataSpace extends IterationSpace { @@ -37,6 +40,7 @@ class Assignment(indexBindings: IndexBindings, lhs: DataSpace, rhs: IterationSpa def getOperands = List(lhs,rhs) def getSpatialIndices = Nil def getDiscreteIndices = Nil + def getExternalIndices = Set() def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed") } @@ -44,6 +48,7 @@ class Scalar(value: Double) extends IterationSpace { def getOperands = Nil def getSpatialIndices = Nil def getDiscreteIndices = Nil + def getExternalIndices = Set() def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed") } @@ -84,6 +89,7 @@ class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[In def getOperands = operands def getSpatialIndices = spatialIndices def getDiscreteIndices = discreteIndices + def getExternalIndices = Set() def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed") } @@ -103,6 +109,7 @@ class Reciprocal(op: IterationSpace) extends IterationSpace { def getOperands = List(op) def getSpatialIndices = spatialIndices.toList def getDiscreteIndices = op.getDiscreteIndices + def getExternalIndices = Set() def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed") } @@ -110,6 +117,7 @@ class Laplacian(op: IterationSpace) extends IterationSpace { def getOperands = List(op) def getSpatialIndices = op.getSpatialIndices def getDiscreteIndices = op.getDiscreteIndices + def getExternalIndices = Set() def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed") } @@ -129,6 +137,7 @@ class SpatialRestriction(op: IterationSpace) extends IterationSpace { def getOperands = List(op) def getSpatialIndices = spatialIndices.toList def getDiscreteIndices = op.getDiscreteIndices + def getExternalIndices = Set() def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed") } @@ -171,6 +180,7 @@ class SPAM3(name : String) extends Matrix { def getSpatialIndices = Nil def getDiscreteIndices = List(rowIndex, colIndex) + def getExternalIndices = Set() def getAccessExpression(indexNames: NameManager) = throw new UnimplementedException("Access failed") } @@ -241,6 +251,7 @@ class PPDFunctionSet(val basis : String, data : String) extends FunctionSet { def getSpatialIndices = spatialIndices.toList def getDiscreteIndices = List(getSphereIndex) + def getExternalIndices = Set(getPPDIndex) def getAccessExpression(indexNames: NameManager) = { val index = getSphere(indexNames)+"%offset + &\n" +