]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Generate SPAM3 iteration code.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Tue, 8 May 2012 15:46:03 +0000 (16:46 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Tue, 8 May 2012 15:46:03 +0000 (16:46 +0100)
src/ofc/codegen/FortranGenerator.scala
src/ofc/codegen/FunctionSignature.scala
src/ofc/codegen/IterationContext.scala
src/ofc/generators/onetep/CodeGenerator.scala
src/ofc/generators/onetep/OnetepFunctions.scala
src/ofc/generators/onetep/OnetepVariables.scala
src/ofc/generators/onetep/Operand.scala
src/ofc/generators/onetep/SPAM3.scala

index aa73dd16e95857ebb91b2c8752a388872a2d9a6c..240123eca729e0eae9e04e5adcf69ae9d97770b3 100644 (file)
@@ -177,6 +177,7 @@ class FortranGenerator {
       case (c: NumericOperator[_]) => buildNumericOperator(c)
       case (c: Conversion[_,_]) => buildConversion(c)
       case (i: Intrinsic[_]) => buildIntrinsic(i)
+      case (f: FunctionCall[_]) => buildFunctionCall(f)
       case x => throw new UnimplementedException("Unknown expression type in FORTRAN generator: " + x.toString)
     }
   }
@@ -256,10 +257,16 @@ class FortranGenerator {
     call.getSignature match {
       case (fortSub: FortranSubroutineSignature) => 
         addLine("call %s(%s)".format(fortSub.getName, call.getParams.map(buildExpression(_)).mkString(", ")))
-      case _ => throw new LogicError("Fortran generator only knows how to call Fortran functions.")
+      case _ => throw new LogicError("Fortran generator only knows how to call Fortran sub-routines.")
     }
   }
 
+  private def buildFunctionCall(call: FunctionCall[_]) : ExpHolder = call.getSignature match {
+    case (fortFunc: FortranFunctionSignature[_]) => 
+      new ExpHolder(maxPrec, "%s(%s)".format(fortFunc.getName, call.getParams.map(buildExpression(_)).mkString(", ")))
+    case _ => throw new LogicError("Fortran generator only knows how to call Fortran functions.")
+  }
+
   private def processScope(scope: ScopeStatement) {
     for (sym <- scope.getDeclarations) {
       symbolManager.addSymbol(sym)
index b27d8a8741da472e2e8b593beedf54f4251c2c5a..75f1c874c9e1772195a83d453a2e1941d9c8c872 100644 (file)
@@ -13,3 +13,12 @@ class FortranSubroutineSignature(name: String,
   def getReturnType = new VoidType
   def getParams = params
 }
+
+class FortranFunctionSignature[R <: Type](name: String, 
+  params: Seq[(String, Type)], retType: R) extends FunctionSignature[R] {
+
+  def this(name: String, params: Seq[(String, Type)])(implicit builder: TypeBuilder[R]) = this(name, params, builder())
+  def getName = name
+  def getReturnType =  retType
+  def getParams = params
+}
index c6876aa0e87be8cb95aaaa3e02405413cb5ca350..355576f83475f588e69a51611bdcd35de261463c 100644 (file)
@@ -83,11 +83,17 @@ object IterationContext {
 class IterationContext extends Statement {
   import IterationContext._
 
-  var statement = new Comment("Placeholder statement for consumer.")
+  var declarations : Seq[VarSymbol[_ <: Type]] = Nil
+  var headers : Seq[Statement] = Nil
+  var footers : Seq[Statement] = Nil
   var ranges : Seq[VariableRange] = Nil
   var predicates : Seq[Predicate] = Nil
   var expressions : Seq[DerivedExpression] = Nil
 
+  def addDeclaration(declaration: VarSymbol[_ <: Type]) {
+    declarations +:= declaration
+  }
+
   def addExpression[T <: Type](name: String, expression: Expression[T]) : VarSymbol[T] = {
     val symbol = new DeclaredVarSymbol[T](name, expression.getType)
     expressions +:= new DerivedExpression(symbol, expression)
@@ -103,22 +109,39 @@ class IterationContext extends Statement {
   def addPredicate(condition: Expression[BoolType]) {
     predicates +:= new Predicate(condition)
   }
+  
+  def addHeader(stat: Statement) {
+    headers +:= stat
+  }
+
+  def addFooter(stat: Statement) {
+    footers +:= stat
+  }
 
   def merge(statement: IterationContext) : IterationContext = {
     val result = new IterationContext
+    result.declarations = declarations ++ statement.declarations
     result.ranges = ranges ++ statement.ranges
     result.predicates = predicates ++ statement.predicates
     result.expressions = expressions ++ statement.expressions
+    result.headers = headers ++ statement.headers
+    result.footers = footers ++ statement.footers
     result
   }
 
-  def toConcrete : Statement = {
+  def toConcrete(statement: Statement) : Statement = {
     val contexts = ranges ++ predicates ++ expressions
     val sortedContexts = Context.sort(contexts)
 
     val block = new BlockStatement
     var scope : ScopeStatement = block
 
+    for(declaration <- declarations)
+      block.addDeclaration(declaration)
+
+    for(header <- headers)
+      block += header
+
     for (context <- sortedContexts) {
       context match {
         case VariableRange(sym, first, last) => {
@@ -141,6 +164,13 @@ class IterationContext extends Statement {
     }
     
     scope += statement
+
+    for(footer <- footers)
+      block += footer
+
     block
   }
+
+  def toConcrete : Statement = 
+    toConcrete(new Comment("Placeholder statement for consumer."))
 }
index 3b6a6843bfe9a72a5cdebafd245d70feca736588..e470d2ad073a8407c0ce3a47a9858945565703dc 100644 (file)
@@ -2,6 +2,7 @@ package ofc.generators.onetep
 import ofc.codegen._
 
 class CodeGenerator(dictionary: Dictionary) {
+  /*
   val indexSyms : Map[NamedIndex, DeclaredVarSymbol[IntType]] = {
     for(index <- dictionary.getIndices) yield
       (index, new DeclaredVarSymbol[IntType](index.getName))
@@ -11,6 +12,7 @@ class CodeGenerator(dictionary: Dictionary) {
     for((index, sym) <- indexSyms) yield
       (index, sym: Expression[IntType])
   }.toMap
+  */
 
   class Context extends GenerationContext {
     val block = new BlockStatement
@@ -30,19 +32,18 @@ class CodeGenerator(dictionary: Dictionary) {
     val lhs = assignment.lhs
     val rhs = assignment.rhs
 
+    val iterationInfo = lhs.getIterationInfo
     val context = new Context
 
-    //TODO: Remove me when symbols are created properly
-    for((index, sym) <- indexSyms)
-      context.addDeclaration(sym)
-
+    val indexMap = iterationInfo.getIndexMappings
     val rhsFragment = rhs.getFragment(indexMap)
 
     rhsFragment.setup(context)
     rhsFragment.teardown(context)
 
     val generator = new FortranGenerator
-    val code = generator(context.getStatement)
+    val iterated = iterationInfo.getContext.toConcrete(context.getStatement)
+    val code = generator(iterated)
     println(code)
   }
 }
index c056ebb74d00a514bada090deb4d4b5a15c81369..f862e3ccbdffa510e52f189750409867e11693e8 100644 (file)
@@ -25,4 +25,21 @@ object OnetepFunctions {
         ("rspc1", new ArrayType[FloatType](3)),
         ("rspc2", new ArrayType[FloatType](3)),
         ("gspc", new ArrayType[ComplexType](3))))
+
+  val sparse_first_elem_on_node = new FortranFunctionSignature[IntType]("sparse_first_elem_on_node",
+    Seq(("node", new IntType),
+        ("mat", OnetepTypes.SPAM3),
+        ("rowcol", new CharType)))
+
+  val sparse_index_length = new FortranFunctionSignature[IntType]("sparse_index_length",
+    Seq(("mat", OnetepTypes.SPAM3)))
+
+  val sparse_generate_index = new FortranSubroutineSignature("sparse_generate_index",
+    Seq(("idx", new ArrayType[IntType](1)),
+        ("mat", OnetepTypes.SPAM3)))
+
+  val sparse_atom_of_elem = new FortranFunctionSignature[IntType]("sparse_atom_of_elem",
+    Seq(("elem", new IntType),
+        ("mat", OnetepTypes.SPAM3),
+        ("rowcol", new CharType)))
 }
index f08d80ce1351d8806c9595858d0290db38532e21..7a38733e8a054983ab63a44465f581d30687247d 100644 (file)
@@ -3,6 +3,7 @@ import ofc.codegen._
 
 object OnetepVariables {
   // parallel_strategy
+  val pub_my_node_id = new  NamedUnboundVarSymbol[IntType]("pub_my_node_id")
   val pub_first_atom_on_node = new NamedUnboundVarSymbol[ArrayType[IntType]]("pub_first_atom_on_node", new ArrayType[IntType](1))
   val pub_num_atoms_on_node = new NamedUnboundVarSymbol[ArrayType[IntType]]("pub_num_atoms_on_node", new ArrayType[IntType](1))
 }
index 4f0d428e1d980b9437e6cfb5ee582fa64f9ec96c..d31822f438fb67580a6a4622479b9913d23e762e 100644 (file)
@@ -4,6 +4,8 @@ import ofc.codegen._
 class IterationInfo(val context: IterationContext, val indexMappings: Map[NamedIndex, Expression[IntType]]) {
   def merge(other: IterationInfo) : IterationInfo = 
     new IterationInfo(context merge other.context, indexMappings ++ other.indexMappings)
+  def getContext = context
+  def getIndexMappings = indexMappings
 }
 
 trait Operand {
index 38be25385d70544d7ed3150faf772d246966e008..33235d4c1cd200139d73a8371f3dbd441583b885 100644 (file)
@@ -19,6 +19,57 @@ class SPAM3(name : String, indices: Seq[NamedIndex]) extends Scalar {
 
   def getIterationInfo : IterationInfo = {
     val context = new IterationContext
-    throw new ofc.UnimplementedException("not yet implemented.")
+
+    // Create sparse index
+    val header = new BlockStatement
+    val indexLength = new FunctionCall(OnetepFunctions.sparse_index_length, Seq(mat))
+    val index = new DeclaredVarSymbol[ArrayType[IntType]]("sparse_idx", new ArrayType[IntType](1))
+    header += new AllocateStatement(index, Seq(indexLength))
+    header += new FunctionCallStatement(new FunctionCall(OnetepFunctions.sparse_generate_index, Seq(index, mat)))
+    context.addHeader(header)
+    context.addDeclaration(index)
+
+    val footer = new DeallocateStatement(index)
+    context.addFooter(footer)
+
+    val firstCol = new FunctionCall(OnetepFunctions.sparse_first_elem_on_node, 
+      Seq(OnetepVariables.pub_first_atom_on_node.at(OnetepVariables.pub_my_node_id),
+          mat,
+          new CharLiteral('C')))
+
+    val lastCol = new FunctionCall(OnetepFunctions.sparse_first_elem_on_node, 
+      Seq(OnetepVariables.pub_first_atom_on_node.at((OnetepVariables.pub_my_node_id: Expression[IntType])+1),
+          mat,
+          new CharLiteral('C'))) - 1
+
+    val firstRow = new FunctionCall(OnetepFunctions.sparse_first_elem_on_node, 
+      Seq(OnetepVariables.pub_first_atom_on_node.at(OnetepVariables.pub_my_node_id),
+          mat,
+          new CharLiteral('R')))
+
+    val lastRow = new FunctionCall(OnetepFunctions.sparse_first_elem_on_node, 
+      Seq(OnetepVariables.pub_first_atom_on_node.at((OnetepVariables.pub_my_node_id: Expression[IntType])+1),
+          mat,
+          new CharLiteral('R'))) - 1
+
+    val col = context.addIteration("col", firstCol, lastCol)
+    val colAtom = context.addExpression("col_atom", new FunctionCall(OnetepFunctions.sparse_atom_of_elem, 
+      Seq(col, mat, new CharLiteral('C'))))
+    val localColAtom = context.addExpression("local_col_atom", 
+      colAtom - OnetepVariables.pub_first_atom_on_node.at(OnetepVariables.pub_my_node_id))
+
+    val row = context.addIteration("row", firstRow, lastRow)
+    val rowAtom = context.addExpression("row_atom", new FunctionCall(OnetepFunctions.sparse_atom_of_elem, 
+      Seq(row, mat, new CharLiteral('R'))))
+
+    val rowIdx = context.addIteration("row_idx", index.at(localColAtom), 
+      index.at((localColAtom: Expression[IntType])+1)-1)
+    context.addPredicate(index.at(rowIdx) |==| rowAtom)
+
+    var indexMappings : Map[NamedIndex, Expression[IntType]] = Map.empty
+    indexMappings += indices(0) -> row
+    indexMappings += indices(1) -> col
+
+    new IterationInfo(context, indexMappings)
   }
 }