]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Initial work on generating function boiler-plate.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Sun, 20 May 2012 13:48:03 +0000 (14:48 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Sun, 20 May 2012 13:48:03 +0000 (14:48 +0100)
src/ofc/codegen/FortranGenerator.scala
src/ofc/codegen/Function.scala [new file with mode: 0644]
src/ofc/codegen/FunctionSignature.scala
src/ofc/generators/Onetep.scala
src/ofc/generators/onetep/CodeGenerator.scala
src/ofc/generators/onetep/PPDFunctionSet.scala
src/ofc/generators/onetep/SPAM3.scala
src/ofc/parser/Statement.scala

index 86c5b07c573661bb63eca1c22e2c978dcdd98af5..b65eb94a32ff3abc0736c9f57df97ad9d29f101a 100644 (file)
@@ -129,11 +129,35 @@ class FortranGenerator {
   def apply(stat: Statement) : String = {
     processStatement(stat)
 
-    buffer.prepend("\n")
-    buffer.prependAll(symbolManager.getDeclarations)
+    prependLine("\n")
+    prependLines(symbolManager.getDeclarations)
     FortranGenerator.wrapLines(buffer).mkString("\n")
   }
 
+  def apply(func: Function[_ <: Type]) : String = {
+    in
+    processStatement(func.getBlock)
+    prependLine("\n")
+    prependLines(symbolManager.getDeclarations)
+    out
+
+    // parameters are only named after processing the body
+    val paramNames = func.getParameters.map(symbolManager.getName(_)) 
+    val (header, footer) = func.getReturnType match {
+      case (_: VoidType) => {
+        val header = "subroutine " + func.getName + paramNames.mkString("(", ", ", ")")
+        val footer = "end subroutine"
+        (header, footer)
+      }
+      case _ => throw new UnimplementedException("Fortran function code generation not implemented.")
+    }
+
+    prependLine(header)
+    addLine(footer)
+    FortranGenerator.wrapLines(buffer).mkString("\n")
+  }
+
+
   private def processStatement(stat: Statement) {
     stat match {
       case (x : NullStatement) => ()
@@ -313,4 +337,14 @@ class FortranGenerator {
   private def addLine(line: String) {
     buffer += "  "*indentLevel + line
   }
+
+  private def prependLine(line: String) {
+    buffer.prepend("  "*indentLevel + line)
+  }
+
+  private def prependLines(lines: Seq[String]) {
+    for(line <- lines.reverse)
+      prependLine(line)
+  }
+
 }
diff --git a/src/ofc/codegen/Function.scala b/src/ofc/codegen/Function.scala
new file mode 100644 (file)
index 0000000..fa8e517
--- /dev/null
@@ -0,0 +1,19 @@
+package ofc.codegen
+
+class Function[R <: Type](name: String, retType: R) {
+  val block = new BlockStatement
+  var params : Seq[VarSymbol[_ <: Type]] = Nil
+
+  def addParameter(param: VarSymbol[_ <: Type]) {
+    params :+= param
+    block.addDeclaration(param)
+  }
+
+  def getName = name
+
+  def getBlock = block
+
+  def getParameters : Seq[VarSymbol[_ <: Type]] = params
+
+  def getReturnType : Type = retType
+}
index 75f1c874c9e1772195a83d453a2e1941d9c8c872..75b5b4e0adab2aa20f7e553a40c73321b331b430 100644 (file)
@@ -17,8 +17,10 @@ class FortranSubroutineSignature(name: String,
 class FortranFunctionSignature[R <: Type](name: String, 
   params: Seq[(String, Type)], retType: R) extends FunctionSignature[R] {
 
-  def this(name: String, params: Seq[(String, Type)])(implicit builder: TypeBuilder[R]) = this(name, params, builder())
+  def this(name: String, params: Seq[(String, Type)])(implicit builder: TypeBuilder[R]) = 
+    this(name, params, builder())
+
   def getName = name
-  def getReturnType =  retType
+  def getReturnType = retType
   def getParams = params
 }
index 22b24401f52a4d4ddfeb2952dba842d00fa9712d..38d8e60c1920f5d9f4cd0108e6ffe72444d39b9d 100644 (file)
@@ -2,19 +2,27 @@ package ofc.generators
 
 import ofc.InvalidInputException
 import ofc.parser
+import ofc.codegen
 import ofc.expression
 import ofc.generators.onetep._
 
 class Onetep extends Generator {
   val dictionary = new Dictionary
+  var parameters : Map[String, codegen.VarSymbol[_ <: codegen.Type]] = Map.empty
+  var functionIdentifiers : Option[(String, Seq[String])] = None
 
   def acceptInput(exprDictionary: expression.Dictionary, exprAssignment: 
     expression.Assignment, targetSpecific : Seq[parser.TargetAssignment]) {
 
     buildDictionary(exprDictionary, targetSpecific)
+    val function = buildFunction(targetSpecific)
     val assignment = new Assignment(buildScalarExpression(exprAssignment.lhs), buildScalarExpression(exprAssignment.rhs))
-    val codeGenerator = new CodeGenerator(dictionary)
+    val codeGenerator = new CodeGenerator(dictionary, function.getBlock)
     codeGenerator(assignment)
+
+    val generator = new codegen.FortranGenerator
+    val code = generator(function)
+    println(code)
   }
 
   private def buildDictionary(exprDictionary: expression.Dictionary, targetSpecific : Seq[parser.TargetAssignment]) {
@@ -37,6 +45,37 @@ class Onetep extends Generator {
     }
   }
 
+  private def buildFunction(targetSpecific : Seq[parser.TargetAssignment]) = {
+    import parser._
+    import codegen._
+
+    val outputCall = targetSpecific.filter(_.id == Identifier("output")) match {
+      case Seq(x) => x.value
+      case Seq(_,_,_*) => throw new InvalidInputException("Too many output function specifications.")
+      case Nil => throw new InvalidInputException("No output function specification found.")
+    }
+
+    outputCall match {
+      case FunctionCall(Identifier("FortranFunction"), callInfo) => callInfo match {
+        case ParameterList(StringParameter(funcName), funcParams: ParameterList) => {
+          val function = new Function(funcName, new VoidType)
+
+          for(funcParam <- funcParams.toSeq) funcParam match {
+            case StringParameter(paramName) => parameters.get(paramName) match {
+              case Some(symbol) => function.addParameter(symbol)
+              case None => throw new InvalidInputException("Unable to find definition of parameter "+paramName)
+            }
+            case _ => throw new InvalidInputException("FortranFunction only takes string parameters")
+          }
+
+          function
+        }
+        case _ => throw new InvalidInputException("FortranFunction takes a name and a parameter list.")
+      }
+      case _ => throw new InvalidInputException("Unknown output type "+outputCall.name)
+    }
+  }
+
   private def getIndex(exprIndex: Seq[expression.Index]) : Seq[NamedIndex] = {
     for(index <- exprIndex) yield
       dictionary.getIndex(index.getIdentifier)
@@ -81,11 +120,15 @@ class Onetep extends Generator {
 
   private def buildMatrix(id: parser.Identifier, call : Option[parser.FunctionCall]) {
     import parser._
+    import codegen._
 
     call match {
       case Some(FunctionCall(matType, params)) => (matType, params) match {
-        case (Identifier("SPAM3"), ParameterList(StringParameter(name))) => 
-          dictionary.addScalar(id, new SPAM3(name, _: Seq[NamedIndex]))
+        case (Identifier("SPAM3"), ParameterList(StringParameter(name))) => {
+          val mat = new DeclaredVarSymbol[StructType](name, OnetepTypes.SPAM3)
+          parameters += (name -> mat)
+          dictionary.addScalar(id, new SPAM3(mat, _: Seq[NamedIndex]))
+        }
         case _ => throw new InvalidInputException("Unknown usage of type: "+matType.name)
       }
       case _ => throw new InvalidInputException("Undefined concrete type for matrix: "+id.name)
@@ -94,38 +137,20 @@ class Onetep extends Generator {
 
   private def buildFunctionSet(id: parser.Identifier, call : Option[parser.FunctionCall]) {
     import parser._
+    import codegen._
 
     call match {
       case Some(FunctionCall(fSetType, params)) => (fSetType, params) match {
-        case (Identifier("PPDFunctionSet"), ParameterList(StringParameter(basis), StringParameter(data))) => 
+        case (Identifier("PPDFunctionSet"), ParameterList(StringParameter(basisName), StringParameter(dataName))) => {
+          val basis = new DeclaredVarSymbol[StructType](basisName, OnetepTypes.FunctionBasis)
+          val data =  new DeclaredVarSymbol[ArrayType[FloatType]](dataName, new ArrayType[FloatType](1))
+          parameters += (basisName -> basis)
+          parameters += (dataName -> data)
           dictionary.addField(id, new PPDFunctionSet(basis, data, _: Seq[NamedIndex]))
+        }
         case _ => throw new InvalidInputException("Unknown usage of type: "+fSetType.name)
       }
       case _ => throw new InvalidInputException("Undefined concrete type for function set: "+id.name)
     }
   }
-
-  /*
-  def buildBindingIndex(id: parser.Identifier, call : Option[parser.FunctionCall]) {
-    call match {
-      case Some(_) => throw new InvalidInputException("Index "+id.name+" cannot have concrete type.")
-      case None => dictionary.indices += (id -> new BindingIndex(id.name))
-    }
-  }
-
-  def buildDefinition(definition : parser.Definition) {
-    val builder = new TreeBuilder(dictionary)
-    val assignment = builder(definition.term, definition.expr)
-    val codeGenerator = new CodeGenerator(builder.getIndexBindings)
-    codeGenerator(assignment)
-  }
-
-  def buildDefinitions(statements : Seq[parser.Statement]) {
-    val definitions = filterStatements[parser.Definition](statements)
-    if (definitions.size != 1)
-      throw new InvalidInputException("Input file should only contain a single definition.")
-    else 
-      buildDefinition(definitions.head)
-  }
-  */
 }
index d93d0c0482e57399ffa0d9a408b2cbf10f96a75f..1d47395c064510fde958929bc73c7482c9460793 100644 (file)
@@ -1,19 +1,7 @@
 package ofc.generators.onetep
 import ofc.codegen._
 
-class CodeGenerator(dictionary: Dictionary) {
-  /*
-  val indexSyms : Map[NamedIndex, DeclaredVarSymbol[IntType]] = {
-    for(index <- dictionary.getIndices) yield
-      (index, new DeclaredVarSymbol[IntType](index.getName))
-  }.toMap
-
-  val indexMap : Map[NamedIndex, Expression[IntType]] = {
-    for((index, sym) <- indexSyms) yield
-      (index, sym: Expression[IntType])
-  }.toMap
-  */
-
+class CodeGenerator(dictionary: Dictionary, scope: ScopeStatement) {
   class Context extends GenerationContext {
     val block = new BlockStatement
 
@@ -43,9 +31,7 @@ class CodeGenerator(dictionary: Dictionary) {
     lhsFragment.setValue(context, rhsFragment.getValue)
     rhsFragment.teardown(context)
 
-    val generator = new FortranGenerator
     val iterated = iterationInfo.getContext.toConcrete(context.getStatement)
-    val code = generator(iterated)
-    println(code)
+    scope += iterated;
   }
 }
index 24b68f39238f6cedcfb93bac26b0e6282f2837c5..cabffb6aa09e8bd7833884d86839d0e9ee66c0d7 100644 (file)
@@ -98,9 +98,7 @@ object PPDFunctionSet {
 }
 */
 
-class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedIndex]) extends Field {
-  val basis = new NamedUnboundVarSymbol[StructType](basisName, OnetepTypes.FunctionBasis)
-  val data =  new NamedUnboundVarSymbol[ArrayType[FloatType]](dataName, new ArrayType[FloatType](1))
+class PPDFunctionSet(val basis: Expression[StructType], val data: Expression[ArrayType[FloatType]], indices: Seq[NamedIndex]) extends Field {
 
   class LocalFragment(parent: PPDFunctionSet, indices: Map[NamedIndex, Expression[IntType]]) extends PsincFragment {
     def setup(context: GenerationContext) {}
index 2384f4bceb873244634523a5a20f91f0598cdb73..a76a7b7e6125c3d311a234f9dac55c03368900c0 100644 (file)
@@ -1,8 +1,7 @@
 package ofc.generators.onetep
 import ofc.codegen._
 
-class SPAM3(name : String, position: Seq[NamedIndex]) extends Scalar {
-  val mat = new NamedUnboundVarSymbol[StructType](name, OnetepTypes.SPAM3)
+class SPAM3(mat: Expression[StructType], position: Seq[NamedIndex]) extends Scalar {
 
   class LocalFragment(row: Expression[IntType], col: Expression[IntType]) extends ScalarFragment {
     def setup(context: GenerationContext) {
index 94e6487bf9215eab063f422d8b4747da89edbf0c..6c6414f14e0ccc128e8be95f79a6ba27b8a3ace5 100644 (file)
@@ -59,10 +59,11 @@ case class FunctionCall(name: Identifier, params: ParameterList) {
 sealed abstract class Parameter
 case class ParameterList(params: Parameter*) extends Parameter {
   override def toString : String = params.mkString("[", ", ", "]")
+  def toSeq : Seq[Parameter] = params.toSeq
 }
-case class StringParameter(s: String)  extends Parameter {
-  override def toString : String = "\""+s+"\""
+case class StringParameter(value: String)  extends Parameter {
+  override def toString : String = "\""+value+"\""
 }
-case class NumericParameter(s: Double)  extends Parameter {
-  override def toString : String = s.toString
+case class NumericParameter(value: Double)  extends Parameter {
+  override def toString : String = value.toString
 }