From: Francis Russell Date: Tue, 8 May 2012 15:46:03 +0000 (+0100) Subject: Generate SPAM3 iteration code. X-Git-Url: https://git.unchartedbackwaters.co.uk/w/?a=commitdiff_plain;h=7c3b8de84ea6f9adee09d75871802a06f7fd27f9;p=francis%2Fofc.git Generate SPAM3 iteration code. --- diff --git a/src/ofc/codegen/FortranGenerator.scala b/src/ofc/codegen/FortranGenerator.scala index aa73dd1..240123e 100644 --- a/src/ofc/codegen/FortranGenerator.scala +++ b/src/ofc/codegen/FortranGenerator.scala @@ -177,6 +177,7 @@ class FortranGenerator { case (c: NumericOperator[_]) => buildNumericOperator(c) case (c: Conversion[_,_]) => buildConversion(c) case (i: Intrinsic[_]) => buildIntrinsic(i) + case (f: FunctionCall[_]) => buildFunctionCall(f) case x => throw new UnimplementedException("Unknown expression type in FORTRAN generator: " + x.toString) } } @@ -256,10 +257,16 @@ class FortranGenerator { call.getSignature match { case (fortSub: FortranSubroutineSignature) => addLine("call %s(%s)".format(fortSub.getName, call.getParams.map(buildExpression(_)).mkString(", "))) - case _ => throw new LogicError("Fortran generator only knows how to call Fortran functions.") + case _ => throw new LogicError("Fortran generator only knows how to call Fortran sub-routines.") } } + private def buildFunctionCall(call: FunctionCall[_]) : ExpHolder = call.getSignature match { + case (fortFunc: FortranFunctionSignature[_]) => + new ExpHolder(maxPrec, "%s(%s)".format(fortFunc.getName, call.getParams.map(buildExpression(_)).mkString(", "))) + case _ => throw new LogicError("Fortran generator only knows how to call Fortran functions.") + } + private def processScope(scope: ScopeStatement) { for (sym <- scope.getDeclarations) { symbolManager.addSymbol(sym) diff --git a/src/ofc/codegen/FunctionSignature.scala b/src/ofc/codegen/FunctionSignature.scala index b27d8a8..75f1c87 100644 --- a/src/ofc/codegen/FunctionSignature.scala +++ b/src/ofc/codegen/FunctionSignature.scala @@ -13,3 +13,12 @@ class FortranSubroutineSignature(name: String, def getReturnType = new VoidType def getParams = params } + +class FortranFunctionSignature[R <: Type](name: String, + params: Seq[(String, Type)], retType: R) extends FunctionSignature[R] { + + def this(name: String, params: Seq[(String, Type)])(implicit builder: TypeBuilder[R]) = this(name, params, builder()) + def getName = name + def getReturnType = retType + def getParams = params +} diff --git a/src/ofc/codegen/IterationContext.scala b/src/ofc/codegen/IterationContext.scala index c6876aa..355576f 100644 --- a/src/ofc/codegen/IterationContext.scala +++ b/src/ofc/codegen/IterationContext.scala @@ -83,11 +83,17 @@ object IterationContext { class IterationContext extends Statement { import IterationContext._ - var statement = new Comment("Placeholder statement for consumer.") + var declarations : Seq[VarSymbol[_ <: Type]] = Nil + var headers : Seq[Statement] = Nil + var footers : Seq[Statement] = Nil var ranges : Seq[VariableRange] = Nil var predicates : Seq[Predicate] = Nil var expressions : Seq[DerivedExpression] = Nil + def addDeclaration(declaration: VarSymbol[_ <: Type]) { + declarations +:= declaration + } + def addExpression[T <: Type](name: String, expression: Expression[T]) : VarSymbol[T] = { val symbol = new DeclaredVarSymbol[T](name, expression.getType) expressions +:= new DerivedExpression(symbol, expression) @@ -103,22 +109,39 @@ class IterationContext extends Statement { def addPredicate(condition: Expression[BoolType]) { predicates +:= new Predicate(condition) } + + def addHeader(stat: Statement) { + headers +:= stat + } + + def addFooter(stat: Statement) { + footers +:= stat + } def merge(statement: IterationContext) : IterationContext = { val result = new IterationContext + result.declarations = declarations ++ statement.declarations result.ranges = ranges ++ statement.ranges result.predicates = predicates ++ statement.predicates result.expressions = expressions ++ statement.expressions + result.headers = headers ++ statement.headers + result.footers = footers ++ statement.footers result } - def toConcrete : Statement = { + def toConcrete(statement: Statement) : Statement = { val contexts = ranges ++ predicates ++ expressions val sortedContexts = Context.sort(contexts) val block = new BlockStatement var scope : ScopeStatement = block + for(declaration <- declarations) + block.addDeclaration(declaration) + + for(header <- headers) + block += header + for (context <- sortedContexts) { context match { case VariableRange(sym, first, last) => { @@ -141,6 +164,13 @@ class IterationContext extends Statement { } scope += statement + + for(footer <- footers) + block += footer + block } + + def toConcrete : Statement = + toConcrete(new Comment("Placeholder statement for consumer.")) } diff --git a/src/ofc/generators/onetep/CodeGenerator.scala b/src/ofc/generators/onetep/CodeGenerator.scala index 3b6a684..e470d2a 100644 --- a/src/ofc/generators/onetep/CodeGenerator.scala +++ b/src/ofc/generators/onetep/CodeGenerator.scala @@ -2,6 +2,7 @@ package ofc.generators.onetep import ofc.codegen._ class CodeGenerator(dictionary: Dictionary) { + /* val indexSyms : Map[NamedIndex, DeclaredVarSymbol[IntType]] = { for(index <- dictionary.getIndices) yield (index, new DeclaredVarSymbol[IntType](index.getName)) @@ -11,6 +12,7 @@ class CodeGenerator(dictionary: Dictionary) { for((index, sym) <- indexSyms) yield (index, sym: Expression[IntType]) }.toMap + */ class Context extends GenerationContext { val block = new BlockStatement @@ -30,19 +32,18 @@ class CodeGenerator(dictionary: Dictionary) { val lhs = assignment.lhs val rhs = assignment.rhs + val iterationInfo = lhs.getIterationInfo val context = new Context - //TODO: Remove me when symbols are created properly - for((index, sym) <- indexSyms) - context.addDeclaration(sym) - + val indexMap = iterationInfo.getIndexMappings val rhsFragment = rhs.getFragment(indexMap) rhsFragment.setup(context) rhsFragment.teardown(context) val generator = new FortranGenerator - val code = generator(context.getStatement) + val iterated = iterationInfo.getContext.toConcrete(context.getStatement) + val code = generator(iterated) println(code) } } diff --git a/src/ofc/generators/onetep/OnetepFunctions.scala b/src/ofc/generators/onetep/OnetepFunctions.scala index c056ebb..f862e3c 100644 --- a/src/ofc/generators/onetep/OnetepFunctions.scala +++ b/src/ofc/generators/onetep/OnetepFunctions.scala @@ -25,4 +25,21 @@ object OnetepFunctions { ("rspc1", new ArrayType[FloatType](3)), ("rspc2", new ArrayType[FloatType](3)), ("gspc", new ArrayType[ComplexType](3)))) + + val sparse_first_elem_on_node = new FortranFunctionSignature[IntType]("sparse_first_elem_on_node", + Seq(("node", new IntType), + ("mat", OnetepTypes.SPAM3), + ("rowcol", new CharType))) + + val sparse_index_length = new FortranFunctionSignature[IntType]("sparse_index_length", + Seq(("mat", OnetepTypes.SPAM3))) + + val sparse_generate_index = new FortranSubroutineSignature("sparse_generate_index", + Seq(("idx", new ArrayType[IntType](1)), + ("mat", OnetepTypes.SPAM3))) + + val sparse_atom_of_elem = new FortranFunctionSignature[IntType]("sparse_atom_of_elem", + Seq(("elem", new IntType), + ("mat", OnetepTypes.SPAM3), + ("rowcol", new CharType))) } diff --git a/src/ofc/generators/onetep/OnetepVariables.scala b/src/ofc/generators/onetep/OnetepVariables.scala index f08d80c..7a38733 100644 --- a/src/ofc/generators/onetep/OnetepVariables.scala +++ b/src/ofc/generators/onetep/OnetepVariables.scala @@ -3,6 +3,7 @@ import ofc.codegen._ object OnetepVariables { // parallel_strategy + val pub_my_node_id = new NamedUnboundVarSymbol[IntType]("pub_my_node_id") val pub_first_atom_on_node = new NamedUnboundVarSymbol[ArrayType[IntType]]("pub_first_atom_on_node", new ArrayType[IntType](1)) val pub_num_atoms_on_node = new NamedUnboundVarSymbol[ArrayType[IntType]]("pub_num_atoms_on_node", new ArrayType[IntType](1)) } diff --git a/src/ofc/generators/onetep/Operand.scala b/src/ofc/generators/onetep/Operand.scala index 4f0d428..d31822f 100644 --- a/src/ofc/generators/onetep/Operand.scala +++ b/src/ofc/generators/onetep/Operand.scala @@ -4,6 +4,8 @@ import ofc.codegen._ class IterationInfo(val context: IterationContext, val indexMappings: Map[NamedIndex, Expression[IntType]]) { def merge(other: IterationInfo) : IterationInfo = new IterationInfo(context merge other.context, indexMappings ++ other.indexMappings) + def getContext = context + def getIndexMappings = indexMappings } trait Operand { diff --git a/src/ofc/generators/onetep/SPAM3.scala b/src/ofc/generators/onetep/SPAM3.scala index 38be253..33235d4 100644 --- a/src/ofc/generators/onetep/SPAM3.scala +++ b/src/ofc/generators/onetep/SPAM3.scala @@ -19,6 +19,57 @@ class SPAM3(name : String, indices: Seq[NamedIndex]) extends Scalar { def getIterationInfo : IterationInfo = { val context = new IterationContext - throw new ofc.UnimplementedException("not yet implemented.") + + // Create sparse index + val header = new BlockStatement + val indexLength = new FunctionCall(OnetepFunctions.sparse_index_length, Seq(mat)) + val index = new DeclaredVarSymbol[ArrayType[IntType]]("sparse_idx", new ArrayType[IntType](1)) + header += new AllocateStatement(index, Seq(indexLength)) + header += new FunctionCallStatement(new FunctionCall(OnetepFunctions.sparse_generate_index, Seq(index, mat))) + context.addHeader(header) + context.addDeclaration(index) + + val footer = new DeallocateStatement(index) + context.addFooter(footer) + + val firstCol = new FunctionCall(OnetepFunctions.sparse_first_elem_on_node, + Seq(OnetepVariables.pub_first_atom_on_node.at(OnetepVariables.pub_my_node_id), + mat, + new CharLiteral('C'))) + + val lastCol = new FunctionCall(OnetepFunctions.sparse_first_elem_on_node, + Seq(OnetepVariables.pub_first_atom_on_node.at((OnetepVariables.pub_my_node_id: Expression[IntType])+1), + mat, + new CharLiteral('C'))) - 1 + + val firstRow = new FunctionCall(OnetepFunctions.sparse_first_elem_on_node, + Seq(OnetepVariables.pub_first_atom_on_node.at(OnetepVariables.pub_my_node_id), + mat, + new CharLiteral('R'))) + + val lastRow = new FunctionCall(OnetepFunctions.sparse_first_elem_on_node, + Seq(OnetepVariables.pub_first_atom_on_node.at((OnetepVariables.pub_my_node_id: Expression[IntType])+1), + mat, + new CharLiteral('R'))) - 1 + + val col = context.addIteration("col", firstCol, lastCol) + val colAtom = context.addExpression("col_atom", new FunctionCall(OnetepFunctions.sparse_atom_of_elem, + Seq(col, mat, new CharLiteral('C')))) + val localColAtom = context.addExpression("local_col_atom", + colAtom - OnetepVariables.pub_first_atom_on_node.at(OnetepVariables.pub_my_node_id)) + + val row = context.addIteration("row", firstRow, lastRow) + val rowAtom = context.addExpression("row_atom", new FunctionCall(OnetepFunctions.sparse_atom_of_elem, + Seq(row, mat, new CharLiteral('R')))) + + val rowIdx = context.addIteration("row_idx", index.at(localColAtom), + index.at((localColAtom: Expression[IntType])+1)-1) + context.addPredicate(index.at(rowIdx) |==| rowAtom) + + var indexMappings : Map[NamedIndex, Expression[IntType]] = Map.empty + indexMappings += indices(0) -> row + indexMappings += indices(1) -> col + + new IterationInfo(context, indexMappings) } }