case (c: NumericOperator[_]) => buildNumericOperator(c)
case (c: Conversion[_,_]) => buildConversion(c)
case (i: Intrinsic[_]) => buildIntrinsic(i)
+ case (f: FunctionCall[_]) => buildFunctionCall(f)
case x => throw new UnimplementedException("Unknown expression type in FORTRAN generator: " + x.toString)
}
}
call.getSignature match {
case (fortSub: FortranSubroutineSignature) =>
addLine("call %s(%s)".format(fortSub.getName, call.getParams.map(buildExpression(_)).mkString(", ")))
- case _ => throw new LogicError("Fortran generator only knows how to call Fortran functions.")
+ case _ => throw new LogicError("Fortran generator only knows how to call Fortran sub-routines.")
}
}
+ private def buildFunctionCall(call: FunctionCall[_]) : ExpHolder = call.getSignature match {
+ case (fortFunc: FortranFunctionSignature[_]) =>
+ new ExpHolder(maxPrec, "%s(%s)".format(fortFunc.getName, call.getParams.map(buildExpression(_)).mkString(", ")))
+ case _ => throw new LogicError("Fortran generator only knows how to call Fortran functions.")
+ }
+
private def processScope(scope: ScopeStatement) {
for (sym <- scope.getDeclarations) {
symbolManager.addSymbol(sym)
def getReturnType = new VoidType
def getParams = params
}
+
+class FortranFunctionSignature[R <: Type](name: String,
+ params: Seq[(String, Type)], retType: R) extends FunctionSignature[R] {
+
+ def this(name: String, params: Seq[(String, Type)])(implicit builder: TypeBuilder[R]) = this(name, params, builder())
+ def getName = name
+ def getReturnType = retType
+ def getParams = params
+}
class IterationContext extends Statement {
import IterationContext._
- var statement = new Comment("Placeholder statement for consumer.")
+ var declarations : Seq[VarSymbol[_ <: Type]] = Nil
+ var headers : Seq[Statement] = Nil
+ var footers : Seq[Statement] = Nil
var ranges : Seq[VariableRange] = Nil
var predicates : Seq[Predicate] = Nil
var expressions : Seq[DerivedExpression] = Nil
+ def addDeclaration(declaration: VarSymbol[_ <: Type]) {
+ declarations +:= declaration
+ }
+
def addExpression[T <: Type](name: String, expression: Expression[T]) : VarSymbol[T] = {
val symbol = new DeclaredVarSymbol[T](name, expression.getType)
expressions +:= new DerivedExpression(symbol, expression)
def addPredicate(condition: Expression[BoolType]) {
predicates +:= new Predicate(condition)
}
+
+ def addHeader(stat: Statement) {
+ headers +:= stat
+ }
+
+ def addFooter(stat: Statement) {
+ footers +:= stat
+ }
def merge(statement: IterationContext) : IterationContext = {
val result = new IterationContext
+ result.declarations = declarations ++ statement.declarations
result.ranges = ranges ++ statement.ranges
result.predicates = predicates ++ statement.predicates
result.expressions = expressions ++ statement.expressions
+ result.headers = headers ++ statement.headers
+ result.footers = footers ++ statement.footers
result
}
- def toConcrete : Statement = {
+ def toConcrete(statement: Statement) : Statement = {
val contexts = ranges ++ predicates ++ expressions
val sortedContexts = Context.sort(contexts)
val block = new BlockStatement
var scope : ScopeStatement = block
+ for(declaration <- declarations)
+ block.addDeclaration(declaration)
+
+ for(header <- headers)
+ block += header
+
for (context <- sortedContexts) {
context match {
case VariableRange(sym, first, last) => {
}
scope += statement
+
+ for(footer <- footers)
+ block += footer
+
block
}
+
+ def toConcrete : Statement =
+ toConcrete(new Comment("Placeholder statement for consumer."))
}
import ofc.codegen._
class CodeGenerator(dictionary: Dictionary) {
+ /*
val indexSyms : Map[NamedIndex, DeclaredVarSymbol[IntType]] = {
for(index <- dictionary.getIndices) yield
(index, new DeclaredVarSymbol[IntType](index.getName))
for((index, sym) <- indexSyms) yield
(index, sym: Expression[IntType])
}.toMap
+ */
class Context extends GenerationContext {
val block = new BlockStatement
val lhs = assignment.lhs
val rhs = assignment.rhs
+ val iterationInfo = lhs.getIterationInfo
val context = new Context
- //TODO: Remove me when symbols are created properly
- for((index, sym) <- indexSyms)
- context.addDeclaration(sym)
-
+ val indexMap = iterationInfo.getIndexMappings
val rhsFragment = rhs.getFragment(indexMap)
rhsFragment.setup(context)
rhsFragment.teardown(context)
val generator = new FortranGenerator
- val code = generator(context.getStatement)
+ val iterated = iterationInfo.getContext.toConcrete(context.getStatement)
+ val code = generator(iterated)
println(code)
}
}
("rspc1", new ArrayType[FloatType](3)),
("rspc2", new ArrayType[FloatType](3)),
("gspc", new ArrayType[ComplexType](3))))
+
+ val sparse_first_elem_on_node = new FortranFunctionSignature[IntType]("sparse_first_elem_on_node",
+ Seq(("node", new IntType),
+ ("mat", OnetepTypes.SPAM3),
+ ("rowcol", new CharType)))
+
+ val sparse_index_length = new FortranFunctionSignature[IntType]("sparse_index_length",
+ Seq(("mat", OnetepTypes.SPAM3)))
+
+ val sparse_generate_index = new FortranSubroutineSignature("sparse_generate_index",
+ Seq(("idx", new ArrayType[IntType](1)),
+ ("mat", OnetepTypes.SPAM3)))
+
+ val sparse_atom_of_elem = new FortranFunctionSignature[IntType]("sparse_atom_of_elem",
+ Seq(("elem", new IntType),
+ ("mat", OnetepTypes.SPAM3),
+ ("rowcol", new CharType)))
}
object OnetepVariables {
// parallel_strategy
+ val pub_my_node_id = new NamedUnboundVarSymbol[IntType]("pub_my_node_id")
val pub_first_atom_on_node = new NamedUnboundVarSymbol[ArrayType[IntType]]("pub_first_atom_on_node", new ArrayType[IntType](1))
val pub_num_atoms_on_node = new NamedUnboundVarSymbol[ArrayType[IntType]]("pub_num_atoms_on_node", new ArrayType[IntType](1))
}
class IterationInfo(val context: IterationContext, val indexMappings: Map[NamedIndex, Expression[IntType]]) {
def merge(other: IterationInfo) : IterationInfo =
new IterationInfo(context merge other.context, indexMappings ++ other.indexMappings)
+ def getContext = context
+ def getIndexMappings = indexMappings
}
trait Operand {
def getIterationInfo : IterationInfo = {
val context = new IterationContext
- throw new ofc.UnimplementedException("not yet implemented.")
+
+ // Create sparse index
+ val header = new BlockStatement
+ val indexLength = new FunctionCall(OnetepFunctions.sparse_index_length, Seq(mat))
+ val index = new DeclaredVarSymbol[ArrayType[IntType]]("sparse_idx", new ArrayType[IntType](1))
+ header += new AllocateStatement(index, Seq(indexLength))
+ header += new FunctionCallStatement(new FunctionCall(OnetepFunctions.sparse_generate_index, Seq(index, mat)))
+ context.addHeader(header)
+ context.addDeclaration(index)
+
+ val footer = new DeallocateStatement(index)
+ context.addFooter(footer)
+
+ val firstCol = new FunctionCall(OnetepFunctions.sparse_first_elem_on_node,
+ Seq(OnetepVariables.pub_first_atom_on_node.at(OnetepVariables.pub_my_node_id),
+ mat,
+ new CharLiteral('C')))
+
+ val lastCol = new FunctionCall(OnetepFunctions.sparse_first_elem_on_node,
+ Seq(OnetepVariables.pub_first_atom_on_node.at((OnetepVariables.pub_my_node_id: Expression[IntType])+1),
+ mat,
+ new CharLiteral('C'))) - 1
+
+ val firstRow = new FunctionCall(OnetepFunctions.sparse_first_elem_on_node,
+ Seq(OnetepVariables.pub_first_atom_on_node.at(OnetepVariables.pub_my_node_id),
+ mat,
+ new CharLiteral('R')))
+
+ val lastRow = new FunctionCall(OnetepFunctions.sparse_first_elem_on_node,
+ Seq(OnetepVariables.pub_first_atom_on_node.at((OnetepVariables.pub_my_node_id: Expression[IntType])+1),
+ mat,
+ new CharLiteral('R'))) - 1
+
+ val col = context.addIteration("col", firstCol, lastCol)
+ val colAtom = context.addExpression("col_atom", new FunctionCall(OnetepFunctions.sparse_atom_of_elem,
+ Seq(col, mat, new CharLiteral('C'))))
+ val localColAtom = context.addExpression("local_col_atom",
+ colAtom - OnetepVariables.pub_first_atom_on_node.at(OnetepVariables.pub_my_node_id))
+
+ val row = context.addIteration("row", firstRow, lastRow)
+ val rowAtom = context.addExpression("row_atom", new FunctionCall(OnetepFunctions.sparse_atom_of_elem,
+ Seq(row, mat, new CharLiteral('R'))))
+
+ val rowIdx = context.addIteration("row_idx", index.at(localColAtom),
+ index.at((localColAtom: Expression[IntType])+1)-1)
+ context.addPredicate(index.at(rowIdx) |==| rowAtom)
+
+ var indexMappings : Map[NamedIndex, Expression[IntType]] = Map.empty
+ indexMappings += indices(0) -> row
+ indexMappings += indices(1) -> col
+
+ new IterationInfo(context, indexMappings)
}
}