From 0fb9fdf4471991b4e3d1e56b1230b2afcac36aa5 Mon Sep 17 00:00:00 2001 From: Francis Russell Date: Sun, 20 May 2012 14:48:03 +0100 Subject: [PATCH] Initial work on generating function boiler-plate. --- src/ofc/codegen/FortranGenerator.scala | 38 ++++++++- src/ofc/codegen/Function.scala | 19 +++++ src/ofc/codegen/FunctionSignature.scala | 6 +- src/ofc/generators/Onetep.scala | 81 ++++++++++++------- src/ofc/generators/onetep/CodeGenerator.scala | 18 +---- .../generators/onetep/PPDFunctionSet.scala | 4 +- src/ofc/generators/onetep/SPAM3.scala | 3 +- src/ofc/parser/Statement.scala | 9 ++- 8 files changed, 121 insertions(+), 57 deletions(-) create mode 100644 src/ofc/codegen/Function.scala diff --git a/src/ofc/codegen/FortranGenerator.scala b/src/ofc/codegen/FortranGenerator.scala index 86c5b07..b65eb94 100644 --- a/src/ofc/codegen/FortranGenerator.scala +++ b/src/ofc/codegen/FortranGenerator.scala @@ -129,11 +129,35 @@ class FortranGenerator { def apply(stat: Statement) : String = { processStatement(stat) - buffer.prepend("\n") - buffer.prependAll(symbolManager.getDeclarations) + prependLine("\n") + prependLines(symbolManager.getDeclarations) FortranGenerator.wrapLines(buffer).mkString("\n") } + def apply(func: Function[_ <: Type]) : String = { + in + processStatement(func.getBlock) + prependLine("\n") + prependLines(symbolManager.getDeclarations) + out + + // parameters are only named after processing the body + val paramNames = func.getParameters.map(symbolManager.getName(_)) + val (header, footer) = func.getReturnType match { + case (_: VoidType) => { + val header = "subroutine " + func.getName + paramNames.mkString("(", ", ", ")") + val footer = "end subroutine" + (header, footer) + } + case _ => throw new UnimplementedException("Fortran function code generation not implemented.") + } + + prependLine(header) + addLine(footer) + FortranGenerator.wrapLines(buffer).mkString("\n") + } + + private def processStatement(stat: Statement) { stat match { case (x : NullStatement) => () @@ -313,4 +337,14 @@ class FortranGenerator { private def addLine(line: String) { buffer += " "*indentLevel + line } + + private def prependLine(line: String) { + buffer.prepend(" "*indentLevel + line) + } + + private def prependLines(lines: Seq[String]) { + for(line <- lines.reverse) + prependLine(line) + } + } diff --git a/src/ofc/codegen/Function.scala b/src/ofc/codegen/Function.scala new file mode 100644 index 0000000..fa8e517 --- /dev/null +++ b/src/ofc/codegen/Function.scala @@ -0,0 +1,19 @@ +package ofc.codegen + +class Function[R <: Type](name: String, retType: R) { + val block = new BlockStatement + var params : Seq[VarSymbol[_ <: Type]] = Nil + + def addParameter(param: VarSymbol[_ <: Type]) { + params :+= param + block.addDeclaration(param) + } + + def getName = name + + def getBlock = block + + def getParameters : Seq[VarSymbol[_ <: Type]] = params + + def getReturnType : Type = retType +} diff --git a/src/ofc/codegen/FunctionSignature.scala b/src/ofc/codegen/FunctionSignature.scala index 75f1c87..75b5b4e 100644 --- a/src/ofc/codegen/FunctionSignature.scala +++ b/src/ofc/codegen/FunctionSignature.scala @@ -17,8 +17,10 @@ class FortranSubroutineSignature(name: String, class FortranFunctionSignature[R <: Type](name: String, params: Seq[(String, Type)], retType: R) extends FunctionSignature[R] { - def this(name: String, params: Seq[(String, Type)])(implicit builder: TypeBuilder[R]) = this(name, params, builder()) + def this(name: String, params: Seq[(String, Type)])(implicit builder: TypeBuilder[R]) = + this(name, params, builder()) + def getName = name - def getReturnType = retType + def getReturnType = retType def getParams = params } diff --git a/src/ofc/generators/Onetep.scala b/src/ofc/generators/Onetep.scala index 22b2440..38d8e60 100644 --- a/src/ofc/generators/Onetep.scala +++ b/src/ofc/generators/Onetep.scala @@ -2,19 +2,27 @@ package ofc.generators import ofc.InvalidInputException import ofc.parser +import ofc.codegen import ofc.expression import ofc.generators.onetep._ class Onetep extends Generator { val dictionary = new Dictionary + var parameters : Map[String, codegen.VarSymbol[_ <: codegen.Type]] = Map.empty + var functionIdentifiers : Option[(String, Seq[String])] = None def acceptInput(exprDictionary: expression.Dictionary, exprAssignment: expression.Assignment, targetSpecific : Seq[parser.TargetAssignment]) { buildDictionary(exprDictionary, targetSpecific) + val function = buildFunction(targetSpecific) val assignment = new Assignment(buildScalarExpression(exprAssignment.lhs), buildScalarExpression(exprAssignment.rhs)) - val codeGenerator = new CodeGenerator(dictionary) + val codeGenerator = new CodeGenerator(dictionary, function.getBlock) codeGenerator(assignment) + + val generator = new codegen.FortranGenerator + val code = generator(function) + println(code) } private def buildDictionary(exprDictionary: expression.Dictionary, targetSpecific : Seq[parser.TargetAssignment]) { @@ -37,6 +45,37 @@ class Onetep extends Generator { } } + private def buildFunction(targetSpecific : Seq[parser.TargetAssignment]) = { + import parser._ + import codegen._ + + val outputCall = targetSpecific.filter(_.id == Identifier("output")) match { + case Seq(x) => x.value + case Seq(_,_,_*) => throw new InvalidInputException("Too many output function specifications.") + case Nil => throw new InvalidInputException("No output function specification found.") + } + + outputCall match { + case FunctionCall(Identifier("FortranFunction"), callInfo) => callInfo match { + case ParameterList(StringParameter(funcName), funcParams: ParameterList) => { + val function = new Function(funcName, new VoidType) + + for(funcParam <- funcParams.toSeq) funcParam match { + case StringParameter(paramName) => parameters.get(paramName) match { + case Some(symbol) => function.addParameter(symbol) + case None => throw new InvalidInputException("Unable to find definition of parameter "+paramName) + } + case _ => throw new InvalidInputException("FortranFunction only takes string parameters") + } + + function + } + case _ => throw new InvalidInputException("FortranFunction takes a name and a parameter list.") + } + case _ => throw new InvalidInputException("Unknown output type "+outputCall.name) + } + } + private def getIndex(exprIndex: Seq[expression.Index]) : Seq[NamedIndex] = { for(index <- exprIndex) yield dictionary.getIndex(index.getIdentifier) @@ -81,11 +120,15 @@ class Onetep extends Generator { private def buildMatrix(id: parser.Identifier, call : Option[parser.FunctionCall]) { import parser._ + import codegen._ call match { case Some(FunctionCall(matType, params)) => (matType, params) match { - case (Identifier("SPAM3"), ParameterList(StringParameter(name))) => - dictionary.addScalar(id, new SPAM3(name, _: Seq[NamedIndex])) + case (Identifier("SPAM3"), ParameterList(StringParameter(name))) => { + val mat = new DeclaredVarSymbol[StructType](name, OnetepTypes.SPAM3) + parameters += (name -> mat) + dictionary.addScalar(id, new SPAM3(mat, _: Seq[NamedIndex])) + } case _ => throw new InvalidInputException("Unknown usage of type: "+matType.name) } case _ => throw new InvalidInputException("Undefined concrete type for matrix: "+id.name) @@ -94,38 +137,20 @@ class Onetep extends Generator { private def buildFunctionSet(id: parser.Identifier, call : Option[parser.FunctionCall]) { import parser._ + import codegen._ call match { case Some(FunctionCall(fSetType, params)) => (fSetType, params) match { - case (Identifier("PPDFunctionSet"), ParameterList(StringParameter(basis), StringParameter(data))) => + case (Identifier("PPDFunctionSet"), ParameterList(StringParameter(basisName), StringParameter(dataName))) => { + val basis = new DeclaredVarSymbol[StructType](basisName, OnetepTypes.FunctionBasis) + val data = new DeclaredVarSymbol[ArrayType[FloatType]](dataName, new ArrayType[FloatType](1)) + parameters += (basisName -> basis) + parameters += (dataName -> data) dictionary.addField(id, new PPDFunctionSet(basis, data, _: Seq[NamedIndex])) + } case _ => throw new InvalidInputException("Unknown usage of type: "+fSetType.name) } case _ => throw new InvalidInputException("Undefined concrete type for function set: "+id.name) } } - - /* - def buildBindingIndex(id: parser.Identifier, call : Option[parser.FunctionCall]) { - call match { - case Some(_) => throw new InvalidInputException("Index "+id.name+" cannot have concrete type.") - case None => dictionary.indices += (id -> new BindingIndex(id.name)) - } - } - - def buildDefinition(definition : parser.Definition) { - val builder = new TreeBuilder(dictionary) - val assignment = builder(definition.term, definition.expr) - val codeGenerator = new CodeGenerator(builder.getIndexBindings) - codeGenerator(assignment) - } - - def buildDefinitions(statements : Seq[parser.Statement]) { - val definitions = filterStatements[parser.Definition](statements) - if (definitions.size != 1) - throw new InvalidInputException("Input file should only contain a single definition.") - else - buildDefinition(definitions.head) - } - */ } diff --git a/src/ofc/generators/onetep/CodeGenerator.scala b/src/ofc/generators/onetep/CodeGenerator.scala index d93d0c0..1d47395 100644 --- a/src/ofc/generators/onetep/CodeGenerator.scala +++ b/src/ofc/generators/onetep/CodeGenerator.scala @@ -1,19 +1,7 @@ package ofc.generators.onetep import ofc.codegen._ -class CodeGenerator(dictionary: Dictionary) { - /* - val indexSyms : Map[NamedIndex, DeclaredVarSymbol[IntType]] = { - for(index <- dictionary.getIndices) yield - (index, new DeclaredVarSymbol[IntType](index.getName)) - }.toMap - - val indexMap : Map[NamedIndex, Expression[IntType]] = { - for((index, sym) <- indexSyms) yield - (index, sym: Expression[IntType]) - }.toMap - */ - +class CodeGenerator(dictionary: Dictionary, scope: ScopeStatement) { class Context extends GenerationContext { val block = new BlockStatement @@ -43,9 +31,7 @@ class CodeGenerator(dictionary: Dictionary) { lhsFragment.setValue(context, rhsFragment.getValue) rhsFragment.teardown(context) - val generator = new FortranGenerator val iterated = iterationInfo.getContext.toConcrete(context.getStatement) - val code = generator(iterated) - println(code) + scope += iterated; } } diff --git a/src/ofc/generators/onetep/PPDFunctionSet.scala b/src/ofc/generators/onetep/PPDFunctionSet.scala index 24b68f3..cabffb6 100644 --- a/src/ofc/generators/onetep/PPDFunctionSet.scala +++ b/src/ofc/generators/onetep/PPDFunctionSet.scala @@ -98,9 +98,7 @@ object PPDFunctionSet { } */ -class PPDFunctionSet(basisName: String, dataName: String, indices: Seq[NamedIndex]) extends Field { - val basis = new NamedUnboundVarSymbol[StructType](basisName, OnetepTypes.FunctionBasis) - val data = new NamedUnboundVarSymbol[ArrayType[FloatType]](dataName, new ArrayType[FloatType](1)) +class PPDFunctionSet(val basis: Expression[StructType], val data: Expression[ArrayType[FloatType]], indices: Seq[NamedIndex]) extends Field { class LocalFragment(parent: PPDFunctionSet, indices: Map[NamedIndex, Expression[IntType]]) extends PsincFragment { def setup(context: GenerationContext) {} diff --git a/src/ofc/generators/onetep/SPAM3.scala b/src/ofc/generators/onetep/SPAM3.scala index 2384f4b..a76a7b7 100644 --- a/src/ofc/generators/onetep/SPAM3.scala +++ b/src/ofc/generators/onetep/SPAM3.scala @@ -1,8 +1,7 @@ package ofc.generators.onetep import ofc.codegen._ -class SPAM3(name : String, position: Seq[NamedIndex]) extends Scalar { - val mat = new NamedUnboundVarSymbol[StructType](name, OnetepTypes.SPAM3) +class SPAM3(mat: Expression[StructType], position: Seq[NamedIndex]) extends Scalar { class LocalFragment(row: Expression[IntType], col: Expression[IntType]) extends ScalarFragment { def setup(context: GenerationContext) { diff --git a/src/ofc/parser/Statement.scala b/src/ofc/parser/Statement.scala index 94e6487..6c6414f 100644 --- a/src/ofc/parser/Statement.scala +++ b/src/ofc/parser/Statement.scala @@ -59,10 +59,11 @@ case class FunctionCall(name: Identifier, params: ParameterList) { sealed abstract class Parameter case class ParameterList(params: Parameter*) extends Parameter { override def toString : String = params.mkString("[", ", ", "]") + def toSeq : Seq[Parameter] = params.toSeq } -case class StringParameter(s: String) extends Parameter { - override def toString : String = "\""+s+"\"" +case class StringParameter(value: String) extends Parameter { + override def toString : String = "\""+value+"\"" } -case class NumericParameter(s: Double) extends Parameter { - override def toString : String = s.toString +case class NumericParameter(value: Double) extends Parameter { + override def toString : String = value.toString } -- 2.47.3