From: Francis Russell Date: Wed, 25 Jan 2012 05:24:29 +0000 (+0000) Subject: Allocate storage for operations that remove indices. X-Git-Url: https://git.unchartedbackwaters.co.uk/w/?a=commitdiff_plain;h=84ac25c1656edce90723e71768ed3eb9fac69c12;p=francis%2Fofc.git Allocate storage for operations that remove indices. --- diff --git a/src/ofc/generators/onetep/CodeGenerator.scala b/src/ofc/generators/onetep/CodeGenerator.scala index 934090d..3d2f9db 100644 --- a/src/ofc/generators/onetep/CodeGenerator.scala +++ b/src/ofc/generators/onetep/CodeGenerator.scala @@ -15,7 +15,15 @@ class IndexNames { } class CodeGenerator { + val code = new StringBuilder() val indexNames = new IndexNames() + var nextVarIndex = 0 + + def newLocalVar : String = { + val name = "var_" + nextVarIndex + nextVarIndex += 1 + name + } def apply(assignment: Assignment) { generateCode(assignment) @@ -40,6 +48,17 @@ class CodeGenerator { // We search for all indices bound to the one being destroyed // We generate a composite iteration over those loops // If GeneralInnerProduct rebuilds derived indices, we need to be able to construct a valid size + val concreteIndexList = destroyedIndices.toList + val storageName = newLocalVar + code append "real(kind=DP), allocatable, dimension" + (":"*concreteIndexList.size).mkString("(",",",")") + " :: " + + storageName + "\n" + code append "allocate("+ storageName + + (concreteIndexList map ((x : Index) => x.getDenseWidth)).mkString("(",",",")") + ", stat=ierr)" + + // We've declared temporary storage, now create the loops to populate it + //for (index <- concreteIndexList) code append index.generateIterationHeader(indexNames) + + println(code.mkString) System.exit(0) } diff --git a/src/ofc/generators/onetep/Tree.scala b/src/ofc/generators/onetep/Tree.scala index 02c5145..cb758a2 100644 --- a/src/ofc/generators/onetep/Tree.scala +++ b/src/ofc/generators/onetep/Tree.scala @@ -7,7 +7,10 @@ import ofc.{InvalidInputException,UnimplementedException} trait Index { def getName : String def getDependencies : Set[Index] + def getDenseWidth : String + //def generateIterationHeader(names: IndexNames) : String } + trait SpatialIndex extends Index trait DiscreteIndex extends Index @@ -43,14 +46,16 @@ class Scalar(value: Double) extends IterationSpace { class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[Index]) extends IterationSpace { - class DenseSpatialIndex(parent: GeneralInnerProduct) extends SpatialIndex{ + class DenseSpatialIndex(parent: GeneralInnerProduct, original: SpatialIndex) extends SpatialIndex{ def getDependencies = Set() def getName = "dense_spatial_index" + def getDenseWidth = original.getDenseWidth } - class DenseDiscreteIndex(parent: GeneralInnerProduct) extends DiscreteIndex { + class DenseDiscreteIndex(parent: GeneralInnerProduct, original: DiscreteIndex) extends DiscreteIndex { def getDependencies = Set() def getName = "dense_discrete_index" + def getDenseWidth = original.getDenseWidth } val spatialIndices = @@ -58,14 +63,14 @@ class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[In if (index.getDependencies.intersect(removedIndices).isEmpty) index else - new DenseSpatialIndex(this) + new DenseSpatialIndex(this, index) val discreteIndices = for(op <- operands; index <- op.getDiscreteIndices; if (!removedIndices.contains(index))) yield if (index.getDependencies.intersect(removedIndices).isEmpty) index else - new DenseDiscreteIndex(this) + new DenseDiscreteIndex(this, index) def getOperands = operands def getSpatialIndices() = spatialIndices @@ -74,11 +79,14 @@ class GeneralInnerProduct(operands: List[IterationSpace], removedIndices: Set[In } class Reciprocal(op: IterationSpace) extends IterationSpace { - class BlockIndex(parent: Reciprocal, dimension: Int) extends SpatialIndex { + class BlockIndex(parent: Reciprocal, dimension: Int, original: SpatialIndex) extends SpatialIndex { def getName = "reciprocal_index_" + dimension def getDependencies = Set() + def getDenseWidth = original.getDenseWidth } - val spatialIndices = for (dimension <- 0 to op.getSpatialIndices.size) yield new BlockIndex(this, dimension) + + val spatialIndices = for (dimension <- 0 until op.getSpatialIndices.size) yield + new BlockIndex(this, dimension, op.getSpatialIndices()(dimension)) def getOperands = List(op) def getSpatialIndices() = spatialIndices.toList @@ -97,9 +105,10 @@ class SpatialRestriction(op: IterationSpace) extends IterationSpace { class RestrictedIndex(parent: SpatialRestriction, dimension: Int) extends SpatialIndex { def getName = "restriction_index_" + dimension def getDependencies = Set() + def getDenseWidth = throw new UnimplementedException("Restriction unimplemnted") } - val spatialIndices = for (dimension <- 0 to op.getSpatialIndices.size) yield new RestrictedIndex(this, dimension) + val spatialIndices = for (dimension <- 0 until op.getSpatialIndices.size) yield new RestrictedIndex(this, dimension) def getOperands() = List(op) def getSpatialIndices() = spatialIndices.toList @@ -114,11 +123,13 @@ class SPAM3(name : String) extends Matrix { override def toString = parent + ".row" def getName = "row_index" def getDependencies = Set() + def getDenseWidth = throw new UnimplementedException("Matrix size unimplemented") } class ColIndex(parent: SPAM3) extends DiscreteIndex { override def toString = parent + ".col" def getName = "row_index" def getDependencies = Set() + def getDenseWidth = throw new UnimplementedException("Matrix size unimplemented") } val rowIndex = new RowIndex(this) @@ -129,20 +140,25 @@ class SPAM3(name : String) extends Matrix { def getAccessExpression(indexNames: IndexNames) = throw new UnimplementedException("Access failed") } -class PPDFunctionSet(basis : String, data : String) extends FunctionSet { +class PPDFunctionSet(val basis : String, data : String) extends FunctionSet { class SphereIndex(parent: PPDFunctionSet) extends DiscreteIndex { def getName = "sphere_index" def getDependencies = Set() + def getDenseWidth = throw new UnimplementedException("Sphere count unimplemented") + //TODO: def getSphere = parent.basis + "%spheres(????)" } class PPDIndex(parent: PPDFunctionSet) extends DiscreteIndex { def getName = "ppd_index" def getDependencies = Set[Index](parent.getSphereIndex()) + //TODO: def getDenseWidth = parent.getSphereIndex.getSphere + "%n_ppds_sphere" + def getDenseWidth = parent.basis+"%max_n_ppds_sphere" } class IntraPPDIndex(parent: PPDFunctionSet, dimension: Int) extends SpatialIndex { def getName = "intra_ppd_index_" + dimension def getDependencies = Set[Index](parent.getPPDIndex) + def getDenseWidth = "pub_cell%total_pt"+(dimension+1) } val ppdIndex = new PPDIndex(this)