+++ /dev/null
-package ofc
-
-import scala.reflect.Manifest
-import scala.reflect.Manifest.singleType
-import java.io.FileReader
-import parser.{Parser,Statement,Target,TargetAssignment,Identifier,ParseException,Definition}
-import expression.{Dictionary,TreeBuilder}
-import generators.Generator
-
-class InvalidInputException(s: String) extends Exception(s)
-class UnimplementedException(s: String) extends Exception(s)
-class LogicError(s: String) extends Exception(s)
-
-object OFC extends Parser {
-
- def main(args: Array[String]) {
- if (args.isEmpty) {
- Console.err.println("Specify input file.")
- System.exit(1)
- }
-
- val reader = new FileReader(args(0))
-
- try {
- val program = parseProgram(reader)
- processAST(program)
- } catch {
- case e: ParseException => {
- Console.err.println("Parse failure: "+e)
- System.exit(1)
- }
- case e: InvalidInputException => {
- Console.err.println("Semantic Error: "+e)
- System.exit(1)
- }
- } finally {
- reader.close
- }
- }
-
- private def filterStatements[T <: parser.Statement](statements : Seq[parser.Statement])(implicit m: Manifest[T]) =
- statements.foldLeft(List[T]())((list, item) => item match {
- case s if (singleType(s) <:< m) => s.asInstanceOf[T] :: list
- case _ => list
- })
-
- private def getDeclarations(statements : Seq[parser.Statement]) : Map[parser.Identifier, parser.OFLType] = {
- def getMappings(dl : parser.DeclarationList) =
- for (name <- dl.names) yield
- (name, dl.oflType)
-
- filterStatements[parser.DeclarationList](statements).flatMap(getMappings(_)).toMap
- }
-
- private def buildDictionary(declarations : Map[parser.Identifier, parser.OFLType]) : Dictionary = {
- import expression.{Matrix,FunctionSet,Index}
- val dictionary = new Dictionary
-
- for(d <- declarations) {
- // Find corresponding target-specific declaration if it exists.
- d match {
- case (id, parser.Matrix()) => dictionary.add(new Matrix(id))
- case (id, parser.FunctionSet()) => dictionary.add(new FunctionSet(id))
- case (id, parser.Index()) => dictionary.add(new Index(id))
- }
- }
-
- dictionary
- }
-
- private def processAST(statements : Seq[Statement]) = {
- val declarations = getDeclarations(statements)
- val dictionary = buildDictionary(declarations)
- val treeBuilder = new TreeBuilder(dictionary)
-
- val definitions = filterStatements[Definition](statements)
-
- val definition = definitions match {
- case Seq(singleDef) => singleDef
- case _ => throw new InvalidInputException("OFL file should only have one definition.")
- }
-
- val expressionTree = treeBuilder(definition)
- val targetStatements = filterStatements[Target](statements)
-
- val generator : Generator = targetStatements match {
- case Seq(Target(Identifier("ONETEP"))) => new generators.Onetep
- case Seq(Target(Identifier(x))) => throw new InvalidInputException("Unknown target: " + x)
- case _ => throw new InvalidInputException("OFL file should have single target statement.")
- }
-
- val targetAssignments = filterStatements[TargetAssignment](statements)
- generator.acceptInput(dictionary, expressionTree, targetAssignments)
- }
-}
+++ /dev/null
-package ofc.codegen
-import ofc.LogicError
-
-class AllocateStatement(array: Expression[_], size: Seq[Expression[IntType]]) extends Statement {
- val arrayType = array.getType match {
- case (at: ArrayType[_]) => at
- case _ => throw new LogicError("Can only allocate an array expression.")
- }
-
- if (size.length != arrayType.getRank)
- throw new LogicError("Incorrect rank for array allocation.")
-
- def getArray : Expression[_] = array
- def getSize : Seq[Expression[IntType]] = size
- def getExpressions = array +: size
-}
+++ /dev/null
-package ofc.codegen
-import ofc.LogicError
-
-class AssignStatement(lhs: Expression[_ <: Type], rhs: Expression[_ <: Type]) extends Statement {
- if (lhs.getType != rhs.getType)
- throw new LogicError("Assignment from incompatible type.")
-
- def getLHS : Expression[_] = lhs
- def getRHS : Expression[_] = rhs
- def getExpressions = List(lhs, rhs)
-}
+++ /dev/null
-package ofc.codegen
-
-class Comment(value: String) extends Statement {
- def getValue = value
- def getExpressions = Nil
-}
+++ /dev/null
-package ofc.codegen
-import ofc.LogicError
-
-class ConditionalValue[T <: Type](predicate: Expression[BoolType], ifTrue: Expression[T], ifFalse: Expression[T]) extends Expression[T] {
- def foreach[U](f: Expression[_] => U) = List(predicate, ifTrue, ifFalse).foreach(f)
- def getPredicate = predicate
- def getIfTrue = ifTrue
- def getIfFalse = ifFalse
- def getType = {
- if (ifTrue.getType != ifFalse.getType)
- throw new LogicError("Parameters to ternary operator have different types")
- else
- ifTrue.getType
- }
-}
+++ /dev/null
-package ofc.codegen
-import ofc.LogicError
-
-class DeallocateStatement(array: Expression[_ <: Type]) extends Statement {
- array.getType match {
- case (_: ArrayType[_]) => ()
- case _ => throw new LogicError("Can only deallocate an array expression.")
- }
-
- def getArray : Expression[_] = array
- def getExpressions = List(array)
-}
+++ /dev/null
-package ofc.codegen
-import ofc.LogicError
-
-object Expression {
- implicit def fromInt(i: Int) : Expression[IntType] = new IntegerLiteral(i)
-
- def findReferencedVariables(expression: Expression[_]) : Set[VarSymbol[_]] =
- findReferencedVariables(Some[Expression[_]](expression))
-
- def findReferencedVariables(expressions: Traversable[Expression[_]]) : Set[VarSymbol[_]] = {
- val vars = scala.collection.mutable.Set[VarSymbol[_]]()
- expressions.foreach(_ match {
- case (ref: VarRef[_]) => vars += ref.getSymbol
- case x => vars ++= findReferencedVariables(x.toTraversable)
- })
-
- vars.toSet
- }
-}
-
-abstract class Expression[T <: Type] extends Traversable[Expression[_]] {
- def getType : T
-
- // Field Operations
- def +(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] =
- new NumericOperator[T](NumericOperations.Add, this, rhs)
-
- def -(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] =
- new NumericOperator[T](NumericOperations.Sub, this, rhs)
-
- def *(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] =
- new NumericOperator[T](NumericOperations.Mul, this, rhs)
-
- def /(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] =
- new NumericOperator[T](NumericOperations.Div, this, rhs)
-
- def %(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[T] =
- new NumericOperator[T](NumericOperations.Mod, this, rhs)
-
- // Comparison Operations
- def |<|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] =
- new NumericComparison[T](NumericOperations.LT, this, rhs)
-
- def |<=|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] =
- new NumericComparison[T](NumericOperations.LE, this, rhs)
-
- def |==|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] =
- new NumericComparison[T](NumericOperations.EQ, this, rhs)
-
- def |!=|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] =
- new NumericComparison[T](NumericOperations.NE, this, rhs)
-
- def |>|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] =
- new NumericComparison[T](NumericOperations.GT, this, rhs)
-
- def |>=|(rhs: Expression[T])(implicit arg: HasProperty[T, Numeric]) : Expression[BoolType] =
- new NumericComparison[T](NumericOperations.GE, this, rhs)
-
- def %[FieldType <: Type](field: FieldSymbol[FieldType])(implicit witness: <:<[this.type, Expression[StructType]]) : Expression[FieldType] =
- new FieldAccess[FieldType](witness(this), field)
-
- def at[T <: Type](index: Expression[IntType]*)(implicit witness: <:<[this.type, Expression[ArrayType[T]]]) : Expression[T] =
- new ArrayAccess(witness(this), index.toList)
-
- def unary_~[T <: Type]()(implicit witness: <:<[this.type, Expression[PointerType[T]]]) : Expression[T] =
- new PointerDereference(witness(this))
-}
-
-trait LeafExpression {
- def foreach[U](f: Expression[_] => U): Unit = ()
-}
-
-// Variable references
-class VarRef[T <: Type](symbol: VarSymbol[T]) extends Expression[T] with LeafExpression {
- def getSymbol = symbol
- def getType = symbol.getType
-}
-
-class FieldRef[T <: Type](symbol: FieldSymbol[T]) extends Expression[T] with LeafExpression {
- def getSymbol = symbol
- def getType = symbol.getType
-}
-
-// Struct and array accesses
-class FieldAccess[T <: Type](expression: Expression[StructType], field: FieldSymbol[T]) extends Expression[T] {
- def foreach[U](f: Expression[_] => U) = f(expression)
- def getStructExpression = expression
- def getField = field
- def getType = field.getType
-}
-
-class ArrayAccess[E <: Type](expression: Expression[ArrayType[E]], index: Seq[Expression[IntType]]) extends Expression[E] {
- if (index.size != expression.getType.getRank)
- throw new LogicError("Array of rank "+expression.getType.getRank+" indexed with rank "+index.size+" index.")
-
- def foreach[U](f: Expression[_] => U) = (index :+ expression).foreach(f)
- def getArrayExpression = expression
- def getIndexExpressions = index
- def getType = expression.getType.getElementType
-}
-class PointerDereference[E <: Type](expression: Expression[PointerType[E]]) extends Expression[E] {
- def foreach[U](f: Expression[_] => U) = f(expression)
- def getExpression = expression
- def getType = expression.getType.getTargetType
-}
-
-// Conversation
-class Conversion[F <: Type, T <: Type](expression: Expression[F], toType: T)(implicit convertible: IsConvertible[F,T]) extends Expression[T] {
- def this(expression: Expression[F])(implicit builder: TypeBuilder[T], convertible: IsConvertible[F,T])
- = this(expression, builder())(convertible)
-
- def getType = toType
- def getExpression = expression
- def foreach[U](f: Expression[_] => U) = expression.foreach(f)
-}
-
-// Literals
-class IntegerLiteral(value: Int) extends Expression[IntType] with LeafExpression {
- def getValue = value
- def getType = new IntType
-}
-
-class FloatLiteral(value: Double) extends Expression[FloatType] with LeafExpression {
- def getValue = value
- def getType = new FloatType
-}
-
-class CharLiteral(value: Char) extends Expression[CharType] with LeafExpression {
- def getValue = value
- def getType = new CharType
-}
+++ /dev/null
-package ofc.codegen
-
-class ForLoop(index: VarSymbol[IntType], begin: Expression[IntType], end: Expression[IntType]) extends ScopeStatement {
- def getIndex = index
- def getBegin = begin
- def getEnd = end
- def getExpressions = List(begin, end)
-}
+++ /dev/null
-package ofc.codegen
-import scala.annotation.tailrec
-import ofc.{UnimplementedException,LogicError}
-
-class SymbolManager {
- import scala.collection.mutable
-
- class SymbolInfo(name: String) {
- def getName = name
- }
-
- private val symbols = mutable.Set[Symbol]()
- private val declaredSymbols = mutable.Map[VarSymbol[_ <: Type], SymbolInfo]()
- private val names = mutable.Set[String]()
-
- private def createNewName(sym: VarSymbol[_]) : String = {
- @tailrec
- def helper(sym: VarSymbol[_], suffix: Int) : String = {
- val candidate = sym.getName + "_" + suffix
- if (names.contains(candidate))
- helper(sym, suffix + 1)
- else
- candidate
- }
-
- helper(sym, 1)
- }
-
- def addSymbol(sym: Symbol) {
- symbols += sym
- }
-
- def addDeclaration(sym: VarSymbol[_ <: Type]) {
- addSymbol(sym)
-
- sym match {
- case (s: DeclaredVarSymbol[_]) => if (!declaredSymbols.contains(s)) {
- val name = createNewName(s)
- names += name
- declaredSymbols += s -> new SymbolInfo(name)
-
- //FIXME: This is a hacky way to detect structures we need to import
- s.getType match {
- case (structType: StructType) => addSymbol(structType)
- case _ => ()
- }
- } else {
- throw new LogicError("Attempted to add symbol more than once (multiple declarations?): "+sym.getName)
- }
-
- case _ => throw new LogicError("Attempted to add declaration not of type DeclaredVarSymbol.")
- }
- }
-
- def getName(sym: VarSymbol[_ <: Type]) =
- declaredSymbols.get(sym) match {
- case None => throw new LogicError("Unknown symbol (missing declaration?): "+sym.getName)
- case Some(info) => info.getName
- }
-
- def getDeclarations : Seq[String] = {
- for ((sym, info) <- declaredSymbols) yield {
- var attributeStrings : Seq[String] = Nil
-
- // It seems these properties need to go after the type-related attributes
- for(property <- sym.getProperties) property match {
- case (p: FortranAttribute) => attributeStrings +:= p.getName
- case _ => ()
- }
-
- attributeStrings ++:= sym.getType.getFortranAttributes
- attributeStrings.mkString(", ") + " :: " + info.getName
- }
- }.toSeq.sorted
-
- def getUses : Seq[String] = {
- var uses : Map[String, mutable.Set[String]] = Map.empty
-
- for (sym <- symbols) {
- for(property <- sym.getProperties) property match {
- case FortranModule(name) => {
- val imported = uses.getOrElse(name, mutable.Set.empty)
- uses += name -> imported
- imported += sym.getName
- }
- case _ => ()
- }
- }
-
- for((moduleName, symbolNames) <- uses) yield
- "use "+moduleName+", only: " + symbolNames.mkString(", ")
- }.toSeq.sorted
-}
-
-object FortranGenerator {
- private val maxPrec = 30
-
- case class BinaryOpInfo(template: String, precedence: Int, assoc: BinaryOpInfo.Associativity)
-
- object BinaryOpInfo {
- sealed abstract class Associativity
- object LEFT extends Associativity
- object RIGHT extends Associativity
- object FUNCTION extends Associativity
- }
-
- private def getBinaryOpInfo(op: NumericOperations.CompareOp) : BinaryOpInfo = {
- import NumericOperations._
- import BinaryOpInfo._
- op match {
- case LT => BinaryOpInfo("%s .lt. %s", 16, LEFT)
- case LE => BinaryOpInfo("%s .le. %s", 16, LEFT)
- case EQ => BinaryOpInfo("%s .eq. %s", 16, LEFT)
- case NE => BinaryOpInfo("%s .ne. %s", 16, LEFT)
- case GT => BinaryOpInfo("%s .gt. %s", 16, LEFT)
- case GE => BinaryOpInfo("%s .ge. %s", 16, LEFT)
- case x => throw new UnimplementedException("Unknown comparison type in FORTRAN generator: "+x.toString)
- }
- }
-
- private def getBinaryOpInfo(op: NumericOperations.FieldOp) : BinaryOpInfo = {
- import NumericOperations._
- import BinaryOpInfo._
- op match {
- case Add => BinaryOpInfo("%s + %s", 22, LEFT)
- case Sub => BinaryOpInfo("%s - %s", 22, LEFT)
- case Mul => BinaryOpInfo("%s * %s", 26, LEFT)
- case Div => BinaryOpInfo("%s / %s", 26, LEFT)
- case Mod => BinaryOpInfo("mod(%s, %s)", maxPrec, FUNCTION)
- case x => throw new UnimplementedException("Unknown numeric operator in FORTRAN generator: "+x.toString)
- }
- }
-
- private def wrapLine(line: String) : Seq[String] = {
- // Fortran95 maximum line length is 132 characters, but let's assume people
- // might want to indent further.
- val maxLineLength = 120
- val buffer = scala.collection.mutable.Buffer[String]()
- val marginMatch = "^\\s*".r
- val margin = marginMatch.findFirstIn(line) match {
- case Some(string) => string
- case _ => ""
- }
-
- var remaining = line.drop(margin.length)
- while (remaining.length + margin.length > maxLineLength) {
- val takeLength = maxLineLength - margin.length + 1
- val nextSubLine = margin + remaining.take(takeLength) + "&"
- remaining = "&"+remaining.drop(takeLength)
- buffer.append(nextSubLine)
- }
- buffer.append(margin + remaining)
- buffer.toSeq
- }
-
- private def wrapLines(lines: Seq[String]) : Seq[String] = lines.flatMap(wrapLine(_))
-}
-
-class FortranGenerator {
- import FortranGenerator.{maxPrec, BinaryOpInfo, getBinaryOpInfo}
-
- private val symbolManager = new SymbolManager
- private val buffer = scala.collection.mutable.Buffer[String]()
- private var indentLevel = 0
-
- case class ExpHolder(prec: Int, exp: String) {
- override def toString = exp
- }
-
- def apply(stat: Statement) : String = {
- processStatement(stat)
-
- prependLine("\n")
- prependLines(symbolManager.getDeclarations)
- FortranGenerator.wrapLines(buffer).mkString("\n")
- }
-
- def apply(func: Function[_ <: Type]) : String = {
- in
- processStatement(func.getBlock)
- prependLine("")
- prependLines(symbolManager.getDeclarations)
- prependLine("implicit none")
- prependLine("")
- prependLines(symbolManager.getUses)
- out
-
- // parameters are only named *after* processing the body
- val paramNames = func.getParameters.map(symbolManager.getName(_))
- val (header, footer) = func.getReturnType match {
- case (_: VoidType) => {
- val header = "subroutine " + func.getName + paramNames.mkString("(", ", ", ")")
- val footer = "end subroutine"
- (header, footer)
- }
- case _ => throw new UnimplementedException("Fortran function code generation not implemented.")
- }
-
- prependLine(header)
- addLine(footer)
- FortranGenerator.wrapLines(buffer).mkString("\n")
- }
-
-
- private def processStatement(stat: Statement) {
- stat match {
- case (x : NullStatement) => ()
- case (x : Comment) => addLine("!" + x.getValue)
- case (x : BlockStatement) => processScope(x)
- case (x : IterationContext) => processStatement(x.toConcrete)
- case (x : ForLoop) => processForLoop(x)
- case (a : AssignStatement) => processAssignment(a)
- case (i : IfStatement) => processIf(i)
- case (a : AllocateStatement) => processAllocate(a)
- case (d : DeallocateStatement) => processDeallocate(d)
- case (f : FunctionCallStatement) => {
- symbolManager.addSymbol(f.getCall.getSignature)
- processFunctionCallStatement(f)
- }
- case x => throw new UnimplementedException("Unknown statement type in FORTRAN generator: " + x.toString)
- }
- }
-
- private def in() {
- indentLevel += 1
- }
-
- private def out() {
- indentLevel -=1
- if (indentLevel < 0) throw new LogicError("Indentation level dropped below 0 in FORTRAN generator.")
- }
-
- private def buildExpression(expression: Expression[_]) : ExpHolder = {
- expression match {
- case (i : IntegerLiteral) => ExpHolder(maxPrec, i.getValue.toString)
- case (i : FloatLiteral) => ExpHolder(maxPrec, i.getValue.toString)
- case (i : CharLiteral) => ExpHolder(maxPrec, "'%s'".format(i.getValue.toString))
- case (a : FieldAccess[_]) => ExpHolder(maxPrec, "%s%%%s".format(buildExpression(a.getStructExpression), a.getField.getName))
- case (r : VarRef[_]) => r.getSymbol match {
- case (s: DeclaredVarSymbol[_]) => ExpHolder(maxPrec, symbolManager.getName(s))
- case (s: Symbol) => {
- symbolManager.addSymbol(s)
- ExpHolder(maxPrec, s.getName)
- }
- }
- case (r: ArrayAccess[_]) =>
- ExpHolder(maxPrec, buildExpression(r.getArrayExpression) + r.getIndexExpressions.map(buildExpression(_)).mkString("(",", ",")"))
- case (d: PointerDereference[_]) => buildExpression(d.getExpression)
- case (c: ConditionalValue[_]) => buildConditionalValue(c)
- case (c: NumericComparison[_]) => buildNumericComparison(c)
- case (c: NumericOperator[_]) => buildNumericOperator(c)
- case (c: Conversion[_,_]) => buildConversion(c)
- case (i: Intrinsic[_]) => buildIntrinsic(i)
- case (f: FunctionCall[_]) => {
- symbolManager.addSymbol(f.getSignature)
- buildFunctionCall(f)
- }
- case x => throw new UnimplementedException("Unknown expression type in FORTRAN generator: " + x.toString)
- }
- }
-
- private def buildConditionalValue(conditional: ConditionalValue[_ <: Type]) : ExpHolder = {
- var symbol = new DeclaredVarSymbol[Type]("ternary", conditional.getType)
- symbolManager.addDeclaration(symbol)
- val name = symbolManager.getName(symbol)
- addLine("if (%s) then".format(buildExpression(conditional.getPredicate)))
- in
- addLine("%s = %s".format(name, buildExpression(conditional.getIfTrue)))
- out
- addLine("else")
- in
- addLine("%s = %s".format(name, buildExpression(conditional.getIfFalse)))
- out
- addLine("endif")
-
- ExpHolder(maxPrec, name)
- }
-
- private def buildBinaryOperation(opInfo: BinaryOpInfo, left: ExpHolder, right: ExpHolder) : ExpHolder = {
- import BinaryOpInfo._
-
- def bracket(opInfo: BinaryOpInfo, exp: ExpHolder, assoc: Associativity) =
- opInfo.assoc != FUNCTION &&
- (opInfo.precedence > exp.prec || (opInfo.precedence == exp.prec && opInfo.assoc != assoc))
-
- val lhs = if (bracket(opInfo, left, LEFT))
- "(" + left.exp + ")"
- else
- left.exp
-
- val rhs = if (bracket(opInfo, right, RIGHT))
- "(" + right.exp + ")"
- else
- right.exp
-
- ExpHolder(opInfo.precedence, opInfo.template.format(lhs, rhs))
- }
-
- private def buildConversion(c: Conversion[_,_]) : ExpHolder = c.getType match {
- case (_: FloatType) => ExpHolder(maxPrec, "real(%s)".format(buildExpression(c.getExpression)))
- case (_: ComplexType) => ExpHolder(maxPrec, "cmplx(%s)".format(buildExpression(c.getExpression)))
- case _ => throw new UnimplementedException("Fortran generator cannot handle conversion.")
- }
-
- private def buildIntrinsic(i: Intrinsic[_]) : ExpHolder = i match {
- case (m: Min[_]) => ExpHolder(maxPrec, "min(%s, %s)".format(buildExpression(m.getLeft), buildExpression(m.getRight)))
- case (m: Max[_]) => ExpHolder(maxPrec, "max(%s, %s)".format(buildExpression(m.getLeft), buildExpression(m.getRight)))
- case _ => throw new UnimplementedException("Unknown intrinsic in Fortran generator: "+i)
- }
-
- private def buildNumericComparison(c: NumericComparison[_]) : ExpHolder =
- buildBinaryOperation(getBinaryOpInfo(c.getOperation), buildExpression(c.getLeft), buildExpression(c.getRight))
-
- private def buildNumericOperator(o: NumericOperator[_]) : ExpHolder = {
- val left = buildExpression(o.getLeft)
- var right = buildExpression(o.getRight)
-
- // Fortran has stupid rules about unary negation on the right hand side of an arithmetic operator.
- def isUnaryNegate(expression: Expression[_]) = expression match {
- case (l: FloatLiteral) if l.getValue < 0 => true
- case (l: IntegerLiteral) if l.getValue < 0 => true
- case _ => false
- }
-
- if (isUnaryNegate(o.getRight))
- right = ExpHolder(maxPrec, "(%s)".format(right.exp))
-
- buildBinaryOperation(getBinaryOpInfo(o.getOperation), left, right)
- }
-
- private def processForLoop(stat: ForLoop) {
- val index = stat.getIndex
- val name = symbolManager.getName(index)
- val begin = buildExpression(stat.getBegin)
- val end = buildExpression(stat.getEnd)
-
- val header = "do %s = %s, %s".format(name, begin, end)
- val footer = "end do"
- addLine(header)
- in
- processScope(stat)
- out
- addLine(footer)
- }
-
- private def processFunctionCallStatement(stat: FunctionCallStatement) {
- val call = stat.getCall
- call.getSignature match {
- case (fortSub: FortranSubroutineSignature) =>
- addLine("call %s(%s)".format(fortSub.getName, call.getParams.map(buildExpression(_)).mkString(", ")))
- case _ => throw new LogicError("Fortran generator only knows how to call Fortran sub-routines.")
- }
- }
-
- private def buildFunctionCall(call: FunctionCall[_]) : ExpHolder = call.getSignature match {
- case (fortFunc: FortranFunctionSignature[_]) =>
- new ExpHolder(maxPrec, "%s(%s)".format(fortFunc.getName, call.getParams.map(buildExpression(_)).mkString(", ")))
- case _ => throw new LogicError("Fortran generator only knows how to call Fortran functions.")
- }
-
- private def processScope(scope: ScopeStatement) {
- for (sym <- scope.getDeclarations) {
- symbolManager.addDeclaration(sym)
- }
- for(stat <- scope.getStatements) {
- processStatement(stat)
- }
- }
-
- private def processAssignment(assignment: AssignStatement) {
- addLine("%s = %s".format(buildExpression(assignment.getLHS), buildExpression(assignment.getRHS)))
- }
-
- private def processIf(ifStatement: IfStatement) {
- addLine("if (%s) then".format(buildExpression(ifStatement.getPredicate)))
- in
- processScope(ifStatement)
- out
- addLine("endif")
- }
-
- private def processAllocate(allocate: AllocateStatement) {
- addLine("allocate(%s(%s))".format(buildExpression(allocate.getArray), allocate.getSize.map(buildExpression(_)).mkString(",")))
- }
-
- private def processDeallocate(deallocate: DeallocateStatement) {
- addLine("deallocate(%s)".format(buildExpression(deallocate.getArray)))
- }
-
- private def addLine(line: String) {
- buffer += " "*indentLevel + line
- }
-
- private def prependLine(line: String) {
- buffer.prepend(" "*indentLevel + line)
- }
-
- private def prependLines(lines: Seq[String]) {
- for(line <- lines.reverse)
- prependLine(line)
- }
-}
+++ /dev/null
-package ofc.codegen
-
-trait FortranProperty extends SymbolProperty {
-}
-
-trait FortranAttribute extends FortranProperty {
- def getName : String
-}
-
-case class AllocatableProperty() extends FortranAttribute {
- def getName = "allocatable"
-}
-
-case class FortranModule(module: String) extends FortranProperty {
- def getName = module
- def getModuleName : String = module
-}
+++ /dev/null
-package ofc.codegen
-
-class Function[R <: Type](name: String, retType: R) {
- val block = new BlockStatement
- var params : Seq[VarSymbol[_ <: Type]] = Nil
-
- def addParameter(param: VarSymbol[_ <: Type]) {
- params :+= param
- block.addDeclaration(param)
- }
-
- def getName = name
-
- def getBlock = block
-
- def getParameters : Seq[VarSymbol[_ <: Type]] = params
-
- def getReturnType : Type = retType
-}
+++ /dev/null
-package ofc.codegen
-import ofc.LogicError
-
-class FunctionCall[R <: Type](signature: FunctionSignature[R], params: Seq[Expression[_]]) extends Expression[R] {
- if (params.size != signature.getParams.size)
- throw new LogicError("Function "+signature.getName+" called with incorrect number of parameters.")
-
- for((param, (name, paramType)) <- params.zip(signature.getParams))
- if (param.getType != paramType)
- throw new LogicError("Type mismatch on parameter "+name+" when calling function "+signature.getName)
-
- def getType = signature.getReturnType
- def foreach[U](f: Expression[_] => U) = params.foreach(f)
- def getSignature = signature
- def getParams = params
-}
+++ /dev/null
-package ofc.codegen
-
-class FunctionCallStatement(call: FunctionCall[VoidType]) extends Statement {
- def getCall : FunctionCall[VoidType] = call
- def getExpressions = call.getParams
-}
+++ /dev/null
-package ofc.codegen
-
-trait FunctionSignature[R <: Type] extends Symbol {
- def getReturnType: R
- def getParams: Seq[(String, Type)]
-}
-
-class FortranSubroutineSignature(name: String,
- params: Seq[(String, Type)]) extends FunctionSignature[VoidType] {
-
- def getName = name
- def getReturnType = new VoidType
- def getParams = params
-}
-
-class FortranFunctionSignature[R <: Type](name: String,
- params: Seq[(String, Type)], retType: R) extends FunctionSignature[R] {
-
- def this(name: String, params: Seq[(String, Type)])(implicit builder: TypeBuilder[R]) =
- this(name, params, builder())
-
- def getName = name
- def getReturnType = retType
- def getParams = params
-}
+++ /dev/null
-package ofc.codegen
-
-class IfStatement(predicate: Expression[BoolType]) extends ScopeStatement {
- def getPredicate : Expression[BoolType] = predicate
- def getExpressions = predicate
-}
+++ /dev/null
-package ofc.codegen
-import ofc.LogicError
-
-abstract class Intrinsic[T <: Type] extends Expression[T]
-
-class Min[T <: Type](a: Expression[T], b: Expression[T])(implicit isNumeric: HasProperty[T, Numeric]) extends Intrinsic[T] {
- if (a.getType != b.getType)
- throw new LogicError("Parameters to min must have matching types.")
-
- def getType = a.getType
- def foreach[U](f: Expression[_] => U) = List(a, b).foreach(f)
- def getLeft = a
- def getRight = b
-}
-
-class Max[T <: Type](a: Expression[T], b: Expression[T])(implicit isNumeric: HasProperty[T, Numeric]) extends Intrinsic[T] {
- if (a.getType != b.getType)
- throw new LogicError("Parameters to max must have matching types.")
-
- def getType = a.getType
- def foreach[U](f: Expression[_] => U) = List(a, b).foreach(f)
- def getLeft = a
- def getRight = b
-}
+++ /dev/null
-package ofc.codegen
-import ofc.LogicError
-import ofc.util.DirectedGraph
-
-object IterationContext {
- object Context {
- private def priority(context: Context) : Int = {
- // This ensures that the nesting ordering is Predicate, DerivedExpression, VariableRange
- // when no other dependencies exist.
- context match {
- case (_: Predicate) => 1
- case (_: DerivedExpression) => 2
- case (_: VariableRange) => 3
- case _ => throw new LogicError("Unknown context type.")
- }
- }
-
- def sort(contexts: Seq[Context]) : Seq[Context] = {
- def pathFunction(c1: Context, c2: Context) = c1.tryCompare(c2) match {
- case Some(x) if x<0 => true
- case _ => false
- }
-
- val graph = new DirectedGraph
- val contextMapping = scala.collection.mutable.Map[Context, DirectedGraph#Vertex]()
-
- for(context <- contexts) {
- val vertex = graph.addVertex
- contextMapping += (context -> vertex)
- }
-
- for(c1 <- contexts; c2 <- contexts; if pathFunction(c1, c2)) {
- graph.addEdge(contextMapping.get(c1).get, contextMapping.get(c2).get)
- }
-
- val flippedMapping = contextMapping.map(_.swap).toMap
- val vertexPriorityFunction = (v: DirectedGraph#Vertex) => priority(flippedMapping.get(v).get)
- val sortedVertices = DirectedGraph.topoSort(graph, vertexPriorityFunction)
- val sortedContexts = { for (v <- sortedVertices) yield flippedMapping.get(v).get }
-
- sortedContexts.toSeq
- }
- }
-
- sealed trait Context {
- def defines : Set[VarSymbol[_]]
- def depends : Set[VarSymbol[_]]
- def tryCompare(other: Context) : Option[Int] = {
- val isChild = other.defines & depends
- val isParent = defines & other.depends
- assert(isChild.isEmpty || isParent.isEmpty)
-
- if (this == other)
- Some(0)
- else if (isParent.nonEmpty)
- Some(-1)
- else if (isChild.nonEmpty)
- Some(1)
- else
- None
- }
- }
-
- case class VariableRange(symbol: VarSymbol[IntType], first: Expression[IntType], last: Expression[IntType]) extends Context {
- override def toString = "VariableRange("+symbol.toString+", "+first.toString+", "+last.toString+")"
- def defines = Set(symbol)
- def depends = Expression.findReferencedVariables(List(first, last))
- }
-
- case class Predicate(expression: Expression[BoolType]) extends Context {
- override def toString = "Predicate("+expression.toString+")"
- def defines = Set.empty
- def depends = Expression.findReferencedVariables(expression)
- }
-
- case class DerivedExpression(symbol: VarSymbol[_ <: Type], expression: Expression[_ <: Type]) extends Context {
- override def toString = "DerivedExpression("+symbol.toString + " <- " + expression.toString+")"
- def defines = Set(symbol)
- def depends = Expression.findReferencedVariables(expression)
- }
-}
-
-class IterationContext extends Statement {
- import IterationContext._
-
- var declarations : Seq[VarSymbol[_ <: Type]] = Nil
- var headers : Seq[Statement] = Nil
- var footers : Seq[Statement] = Nil
- var ranges : Seq[VariableRange] = Nil
- var predicates : Seq[Predicate] = Nil
- var expressions : Seq[DerivedExpression] = Nil
-
- def addDeclaration(declaration: VarSymbol[_ <: Type]) {
- declarations +:= declaration
- }
-
- def addExpression[T <: Type](name: String, expression: Expression[T]) : VarSymbol[T] = {
- val symbol = new DeclaredVarSymbol[T](name, expression.getType)
- expressions +:= new DerivedExpression(symbol, expression)
- symbol
- }
-
- def addIteration(name: String, first: Expression[IntType], last: Expression[IntType]) : VarSymbol[IntType] = {
- val symbol = new DeclaredVarSymbol[IntType](name)
- ranges +:= new VariableRange(symbol, first, last)
- symbol
- }
-
- def addPredicate(condition: Expression[BoolType]) {
- predicates +:= new Predicate(condition)
- }
-
- def addHeader(stat: Statement) {
- headers +:= stat
- }
-
- def addFooter(stat: Statement) {
- footers +:= stat
- }
-
- def merge(statement: IterationContext) : IterationContext = {
- val result = new IterationContext
- result.declarations = declarations ++ statement.declarations
- result.ranges = ranges ++ statement.ranges
- result.predicates = predicates ++ statement.predicates
- result.expressions = expressions ++ statement.expressions
- result.headers = headers ++ statement.headers
- result.footers = footers ++ statement.footers
- result
- }
-
- def toConcrete(statement: Statement) : Statement = {
- val contexts = ranges ++ predicates ++ expressions
- val sortedContexts = Context.sort(contexts)
-
- val block = new BlockStatement
- var scope : ScopeStatement = block
-
- for(declaration <- declarations)
- block.addDeclaration(declaration)
-
- for(header <- headers)
- block += header
-
- for (context <- sortedContexts) {
- context match {
- case VariableRange(sym, first, last) => {
- val loop = new ForLoop(sym, first, last)
- scope.addDeclaration(sym)
- scope += loop
- scope = loop
- }
- case Predicate(expression) => {
- val ifStat = new IfStatement(expression)
- scope += ifStat
- scope = ifStat
- }
- case DerivedExpression(sym, expression) => {
- val assignment = new AssignStatement(sym, expression)
- scope.addDeclaration(sym)
- scope += assignment
- }
- }
- }
-
- scope += statement
-
- for(footer <- footers)
- block += footer
-
- block
- }
-
- def toConcrete : Statement =
- toConcrete(new Comment("Placeholder statement for consumer."))
-
- def getExpressions =
- throw new LogicError("Call IterationContext::toConcrete() before trying to access expressions.")
-}
+++ /dev/null
-package ofc.codegen
-import ofc.LogicError
-
-object NumericOperations {
- sealed abstract class FieldOp
- case object Add extends FieldOp
- case object Sub extends FieldOp
- case object Mul extends FieldOp
- case object Div extends FieldOp
- case object Mod extends FieldOp
-
- sealed abstract class CompareOp
- object LT extends CompareOp
- object LE extends CompareOp
- object EQ extends CompareOp
- object NE extends CompareOp
- object GT extends CompareOp
- object GE extends CompareOp
-}
-
-class NumericOperator[T <: Type](op: NumericOperations.FieldOp, left: Expression[T], right: Expression[T]) extends Expression[T] {
- def foreach[U](f: Expression[_] => U) = List(left, right).foreach(f)
- def getOperation = op
- def getLeft = left
- def getRight = right
- def getType = {
- if (left.getType != right.getType)
- throw new LogicError("Non-matching types for parameters to numeric comparison")
- else
- left.getType
- }
-}
-
-class NumericComparison[T <: Type](op: NumericOperations.CompareOp, left: Expression[T], right: Expression[T]) extends Expression[BoolType] {
- def foreach[U](f: Expression[_] => U) = List(left, right).foreach(f)
- def getOperation = op
- def getLeft = left
- def getRight = right
- def getType = new BoolType
-}
+++ /dev/null
-package ofc.codegen
-
-abstract class ScopeStatement(initialStatements: Seq[Statement] = Nil) extends Statement {
- val declarations = scala.collection.mutable.Set[VarSymbol[_ <: Type]]()
- val statements = initialStatements.toBuffer
-
- def +=(stat: Statement) {
- statements += stat
- }
-
- def getStatements : Seq[Statement] = statements.toSeq
-
- def addDeclaration(sym: VarSymbol[_ <: Type]) {
- declarations += sym
- }
-
- def getDeclarations : Seq[VarSymbol[_ <: Type]] = declarations.toSeq
-}
-
-class BlockStatement(initialStatements: Seq[Statement] = Nil) extends ScopeStatement(initialStatements) {
- def getExpressions = Nil
-}
+++ /dev/null
-package ofc.codegen
-
-trait Statement {
- def getExpressions: Traversable[Expression[_]]
-}
-
-class NullStatement extends Statement {
- def getExpressions = Nil
-}
+++ /dev/null
-package ofc.codegen
-
-trait Symbol {
- private var properties : Seq[SymbolProperty] = Nil
-
- def getName : String
- def addProperty(property: SymbolProperty) {
- properties +:= property
- }
- def getProperties : Seq[SymbolProperty] = properties
-}
-
-trait SymbolProperty {
- def getName : String
-}
-
-class FieldSymbol[T <: Type](name: String, fieldType: T) extends Symbol {
- def this(name: String)(implicit builder: TypeBuilder[T]) = this(name, builder())
- def getName = name
- def getType = fieldType
-}
-
-sealed abstract class VarSymbol[T <: Type](name: String, varType: T) extends Symbol {
- def getName = name
- def getType : T = varType
-}
-
-object VarSymbol {
- implicit def toRef[T <: Type](symbol: VarSymbol[T]) = new VarRef[T](symbol)
-}
-
-class DeclaredVarSymbol[T <: Type](name: String, varType: T) extends VarSymbol[T](name, varType) {
- def this(name: String)(implicit builder: TypeBuilder[T]) = this(name, builder())
-}
-
-class NamedUnboundVarSymbol[T <: Type](name: String, varType: T) extends VarSymbol[T](name, varType) {
- def this(name: String)(implicit builder: TypeBuilder[T]) = this(name, builder())
-}
+++ /dev/null
-package ofc.codegen
-import ofc.LogicError
-
-sealed abstract class Type {
- def getFortranAttributes : Set[String]
-}
-sealed abstract class PrimitiveType extends Type
-
-// These are case classes solely for the comparison operators
-final case class IntType() extends PrimitiveType {
- def getFortranAttributes = Set("integer")
-}
-
-final case class FloatType() extends PrimitiveType {
- def getFortranAttributes = Set("real(kind=DP)")
-}
-
-final case class BoolType() extends PrimitiveType {
- def getFortranAttributes = Set("logical")
-}
-
-final case class VoidType() extends PrimitiveType {
- def getFortranAttributes = throw new LogicError("void type does not exist in Fortran.")
-}
-
-final case class CharType() extends PrimitiveType {
- def getFortranAttributes = Set("character")
-}
-
-final case class ComplexType() extends PrimitiveType {
- def getFortranAttributes = Set("complex(kind=DP)")
-}
-
-final case class ArrayType[ElementType <: Type](rank: Int, eType: ElementType) extends Type {
- def this(rank: Int)(implicit builder: TypeBuilder[ElementType]) = this(rank, builder())
- def getElementType = eType
- def getFortranAttributes = eType.getFortranAttributes ++ Set((":"*rank).mkString("dimension(",",",")"))
- def getRank = rank
-}
-
-final case class PointerType[TargetType <: Type](tType: TargetType) extends Type {
- def this()(implicit builder: TypeBuilder[TargetType]) = this(builder())
- def getTargetType = tType
- def getFortranAttributes = tType.getFortranAttributes + "pointer"
-}
-
-abstract class StructType extends Type with Symbol {
- def getFortranAttributes = Set("type(" + getName + ")")
-}
-
-trait TypeBuilder[T <: Type] {
- def apply() : T
-}
-
-object TypeBuilder {
- implicit val charBuilder = new TypeBuilder[CharType] { def apply() = new CharType }
- implicit val intBuilder = new TypeBuilder[IntType] { def apply() = new IntType }
- implicit val floatBuilder = new TypeBuilder[FloatType] { def apply() = new FloatType }
- implicit val complexBuilder = new TypeBuilder[ComplexType] { def apply() = new ComplexType }
- implicit val boolBuilder = new TypeBuilder[BoolType] { def apply() = new BoolType }
-}
-
-trait TypeProperty
-trait Numeric extends TypeProperty
-
-class HasProperty[T <: Type, P <: TypeProperty]
-object HasProperty {
- implicit val intNumeric = new HasProperty[IntType, Numeric]()
- implicit val floatNumeric = new HasProperty[FloatType, Numeric]()
- implicit val complexNumeric = new HasProperty[ComplexType, Numeric]()
-}
-
-class IsConvertible[From <: Type, To <: Type]
-object IsConvertible {
- implicit val intToFloat = new IsConvertible[IntType, FloatType]()
- implicit val floatToFloat = new IsConvertible[FloatType, ComplexType]()
-}
+++ /dev/null
-package ofc.expression
-
-case class Assignment(lhs: Scalar, rhs: Scalar) {
- override def toString = lhs.toString + " = " + rhs.toString
-}
+++ /dev/null
-package ofc.expression
-import ofc.parser.Identifier
-
-class Dictionary {
- import scala.collection.mutable.HashMap
-
- var matrices = new HashMap[Identifier, Matrix]
- var functionSets = new HashMap[Identifier, FunctionSet]
- var indices = new HashMap[Identifier, Index]
-
- def add(matrix: Matrix) {
- matrices += matrix.getIdentifier -> matrix
- }
-
- def add(functionSet: FunctionSet) {
- functionSets += functionSet.getIdentifier -> functionSet
- }
-
- def add(index: Index) {
- indices += index.getIdentifier -> index
- }
-
- def getMatrix(id: Identifier) : Option[Matrix] = matrices.get(id)
-
- def getFunctionSet(id: Identifier) : Option[FunctionSet] = functionSets.get(id)
-
- def getIndex(id: Identifier) : Option[Index] = indices.get(id)
-
- def getOperands = matrices.values ++ functionSets.values
-
- def getIndices = indices.values
-}
+++ /dev/null
-package ofc.expression
-import ofc.parser.Identifier
-
-case class Index(id: Identifier) {
- def getIdentifier = id
- def getName = id.getName
- override def toString() = id.getName
-}
-
-sealed trait Expression {
- def isAssignable : Boolean
- def numIndices : Int
- def getDependentIndices : Set[Index]
-}
-
-sealed trait NamedOperand {
- val id: Identifier
- def getIdentifier = id
- def isAssignable = true
- def getDependentIndices : Set[Index] = Set.empty
- override def toString = id.getName
-}
-
-sealed trait Scalar extends Expression
-sealed trait Field extends Expression
-
-trait IndexingOperation {
- val op: Expression
- def getIndices : List[Index]
- def isAssignable = op.isAssignable
- def numIndices = 0
- def getDependentIndices = op.getDependentIndices ++ getIndices
- override def toString = op.toString + getIndices.map(_.getName).mkString("[",",","]")
-}
-
-case class ScalarIndexingOperation(val op: Scalar, indices: List[Index]) extends IndexingOperation with Scalar {
- def getIndices = indices
-}
-
-case class FieldIndexingOperation(val op: Field, indices: List[Index]) extends IndexingOperation with Field {
- def getIndices = indices
-}
-
-case class InnerProduct(left: Field, right: Field) extends Scalar {
- override def toString = "inner(" + left.toString + ", " + right.toString+")"
- def isAssignable = false
- def numIndices = 0
- def getDependentIndices = left.getDependentIndices ++ right.getDependentIndices
-}
-
-case class Laplacian(op: Field) extends Field {
- override def toString = "laplacian("+op.toString+")"
- def isAssignable = false
- def numIndices = 0
- def getDependentIndices = op.getDependentIndices
-}
-
-case class FieldScaling(op: Field, scale: Scalar) extends Field {
- override def toString = "(" + op.toString + "*" + scale.toString + ")"
- def isAssignable = false
- def numIndices = 0
- def getDependentIndices = op.getDependentIndices ++ scale.getDependentIndices
-}
-
-case class ScalarLiteral(literal: Double) extends Scalar {
- override def toString = literal.toString
- def isAssignable = false
- def numIndices = 0
- def getDependentIndices = Set.empty
-}
-
-class FunctionSet(val id: Identifier) extends Field with NamedOperand {
- def numIndices = 1
-}
-
-class Matrix(val id: Identifier) extends Scalar with NamedOperand {
- def numIndices = 2
-}
+++ /dev/null
-package ofc.expression
-
-import ofc.parser
-import ofc.parser.Identifier
-import ofc.{InvalidInputException,UnimplementedException}
-
-class TreeBuilder(dictionary : Dictionary) {
- def apply(definition: parser.Definition) : Assignment = {
- val lhsTree = buildExpression(definition.term)
- val rhsTree = buildExpression(definition.expr)
-
- if (!lhsTree.isAssignable)
- throw new InvalidInputException("Non-assignable expression on LHS of assignment.")
- else (lhsTree, rhsTree) match {
- case (lhs: Scalar, rhs: Scalar) => new Assignment(lhs, rhs)
- case _ => throw new InvalidInputException("Assignment must be of scalar type.")
- }
- }
-
- private def buildIndexedOperand(term: parser.IndexedIdentifier) : Expression = {
- val indices = for(id <- term.indices) yield buildIndex(id)
-
- dictionary.getMatrix(term.id) match {
- case Some(matrix) => new ScalarIndexingOperation(matrix, indices)
- case None => dictionary.getFunctionSet(term.id) match {
- case Some(functionSet) => new FieldIndexingOperation(functionSet, indices)
- case None => throw new UnimplementedException("No idea how to index "+term.id)
- }
- }
- }
-
- private def buildIndex(id: parser.Identifier) : Index = dictionary.getIndex(id) match {
- case Some(index) => index
- case None => throw new InvalidInputException("Unknown index "+id)
- }
-
- private def buildIndex(term: parser.Expression) : Index = term match {
- case (indexedID: parser.IndexedIdentifier) => {
- if (indexedID.indices.nonEmpty)
- throw new InvalidInputException("Tried to parse expression "+term+" as index but it is indexed.")
- else
- buildIndex(indexedID.id)
- }
- case other => throw new InvalidInputException("Cannot parse expression "+other+" as index.")
- }
-
- private def buildExpression(term: parser.Expression) : Expression = {
- import parser._
-
- term match {
- case (t: IndexedIdentifier) => buildIndexedOperand(t)
- case ScalarConstant(s) => new ScalarLiteral(s)
- case Division(a, b) =>
- throw new UnimplementedException("Semantics of division not yet defined, or implemented.")
- case Multiplication(left, right) => (buildExpression(left), buildExpression(right)) match {
- case (field: Field, factor: Scalar) => new FieldScaling(field, factor)
- case (factor: Scalar, field: Field) => new FieldScaling(field, factor)
- case _ => throw new InvalidInputException("Cannot multiply "+left+" and "+right+".")
- }
- case Operator(Identifier("inner"), List(a,b)) => (buildExpression(a), buildExpression(b)) match {
- case (left: Field, right: Field) => new InnerProduct(left, right)
- case _ => throw new InvalidInputException("inner requires both operands to be fields.")
- }
- case Operator(Identifier("laplacian"), List(op)) => buildExpression(op) match {
- case (field: Field) => new Laplacian(field)
- case _ => throw new InvalidInputException("laplacian can only be applied to a field.")
- }
- case Operator(Identifier(name), _) => throw new UnimplementedException("Unknown or incorrectly called operator: "+name)
- }
- }
-}
+++ /dev/null
-package ofc.generators
-import ofc.parser.TargetAssignment
-import ofc.expression.{Dictionary,Assignment}
-
-trait Generator {
- def acceptInput(dictionary: Dictionary, expression: Assignment, targetSpecific : Seq[TargetAssignment])
-}
+++ /dev/null
-package ofc.generators
-
-import ofc.InvalidInputException
-import ofc.parser
-import ofc.codegen
-import ofc.expression
-import ofc.generators.onetep._
-
-class Onetep extends Generator {
- val dictionary = new Dictionary
- var parameters : Map[String, codegen.VarSymbol[_ <: codegen.Type]] = Map.empty
- var functionIdentifiers : Option[(String, Seq[String])] = None
-
- def acceptInput(exprDictionary: expression.Dictionary, exprAssignment:
- expression.Assignment, targetSpecific : Seq[parser.TargetAssignment]) {
-
- buildDictionary(exprDictionary, targetSpecific)
- val function = buildFunction(targetSpecific)
- val assignment = new Assignment(buildScalarExpression(exprAssignment.lhs), buildScalarExpression(exprAssignment.rhs))
- val codeGenerator = new CodeGenerator(dictionary, function.getBlock)
- codeGenerator(assignment)
-
- val generator = new codegen.FortranGenerator
- val code = generator(function)
- println(code)
- }
-
- private def buildDictionary(exprDictionary: expression.Dictionary, targetSpecific : Seq[parser.TargetAssignment]) {
- for(operand <- exprDictionary.getOperands) {
- // Find corresponding target-specific declaration if it exists.
- val targetDeclarationCall = targetSpecific.filter(_.id == operand.getIdentifier) match {
- case Seq(x) => Some(x.value)
- case Seq(_,_,_*) => throw new InvalidInputException("Invalid multiple target declarations for symbol " + operand.getIdentifier + ".")
- case Nil => None
- }
-
- operand match {
- case (m: expression.Matrix) => buildMatrix(operand.getIdentifier, targetDeclarationCall)
- case (f: expression.FunctionSet) => buildFunctionSet(operand.getIdentifier, targetDeclarationCall)
- }
- }
-
- for(index <- exprDictionary.getIndices) {
- dictionary.addIndex(index.getIdentifier, new NamedIndex(index.getName))
- }
- }
-
- private def buildFunction(targetSpecific : Seq[parser.TargetAssignment]) = {
- import parser._
- import codegen._
-
- val outputCall = targetSpecific.filter(_.id == Identifier("output")) match {
- case Seq(x) => x.value
- case Seq(_,_,_*) => throw new InvalidInputException("Too many output function specifications.")
- case Nil => throw new InvalidInputException("No output function specification found.")
- }
-
- outputCall match {
- case FunctionCall(Identifier("FortranFunction"), callInfo) => callInfo match {
- case ParameterList(StringParameter(funcName), funcParams: ParameterList) => {
- val function = new Function(funcName, new VoidType)
-
- for(funcParam <- funcParams.toSeq) funcParam match {
- case StringParameter(paramName) => parameters.get(paramName) match {
- case Some(symbol) => function.addParameter(symbol)
- case None => throw new InvalidInputException("Unable to find definition of parameter "+paramName)
- }
- case _ => throw new InvalidInputException("FortranFunction only takes string parameters")
- }
-
- function
- }
- case _ => throw new InvalidInputException("FortranFunction takes a name and a parameter list.")
- }
- case _ => throw new InvalidInputException("Unknown output type "+outputCall.name)
- }
- }
-
- private def getIndex(exprIndex: Seq[expression.Index]) : Seq[NamedIndex] = {
- for(index <- exprIndex) yield
- dictionary.getIndex(index.getIdentifier)
- }
-
- private def matchLHS(lhs: expression.Scalar) : Boolean = {
- lhs match {
- case expression.ScalarIndexingOperation(_: expression.Matrix, List(bra, ket)) => true
- case _ => false
- }
- }
-
- private def buildScalarExpression(scalar: expression.Scalar) : Scalar = {
- scalar match {
- case expression.ScalarLiteral(s) => new ScalarLiteral(s)
- case expression.InnerProduct(l, r) => new InnerProduct(buildFieldExpression(l), buildFieldExpression(r))
- case expression.ScalarIndexingOperation(op, indices) => buildScalarAccess(op, indices)
- case (_: expression.Matrix) => throw new InvalidInputException("Cannot handle un-indexed matrices.")
- }
- }
-
- private def buildScalarAccess(op: expression.Scalar, indices: Seq[expression.Index]) : Scalar =
- op match {
- case (matrix: expression.Matrix) => dictionary.getScalar(matrix.getIdentifier)(getIndex(indices))
- case _ => throw new InvalidInputException("Can only index leaf-matrices.")
- }
-
- private def buildFieldAccess(op: expression.Field, indices: Seq[expression.Index]) : Field =
- op match {
- case (functionSet: expression.FunctionSet) => dictionary.getField(functionSet.getIdentifier)(getIndex(indices))
- case _ => throw new InvalidInputException("Can only index function-sets.")
- }
-
- private def buildFieldExpression(field: expression.Field) : Field = {
- field match {
- case expression.Laplacian(op) => new Laplacian(buildFieldExpression(op))
- case expression.FieldScaling(op, scale) => new ScaledField(buildFieldExpression(op), buildScalarExpression(scale))
- case expression.FieldIndexingOperation(op, indices) => buildFieldAccess(op, indices)
- case (_: expression.FunctionSet) => throw new InvalidInputException("Cannot handle un-indexed function sets.")
- }
- }
-
- private def buildMatrix(id: parser.Identifier, call : Option[parser.FunctionCall]) {
- import parser._
- import codegen._
-
- call match {
- case Some(FunctionCall(matType, params)) => (matType, params) match {
- case (Identifier("SPAM3"), ParameterList(StringParameter(name))) => {
- val mat = new DeclaredVarSymbol[StructType](name, OnetepTypes.SPAM3)
- parameters += (name -> mat)
- dictionary.addScalar(id, new SPAM3(mat, _: Seq[NamedIndex]))
- }
- case _ => throw new InvalidInputException("Unknown usage of type: "+matType.name)
- }
- case _ => throw new InvalidInputException("Undefined concrete type for matrix: "+id.name)
- }
- }
-
- private def buildFunctionSet(id: parser.Identifier, call : Option[parser.FunctionCall]) {
- import parser._
- import codegen._
-
- call match {
- case Some(FunctionCall(fSetType, params)) => (fSetType, params) match {
- case (Identifier("PPDFunctionSet"), ParameterList(StringParameter(basisName), StringParameter(dataName))) => {
- val basis = new DeclaredVarSymbol[StructType](basisName, OnetepTypes.FunctionBasis)
- val data = new DeclaredVarSymbol[ArrayType[FloatType]](dataName, new ArrayType[FloatType](1))
- parameters += (basisName -> basis)
- parameters += (dataName -> data)
- dictionary.addField(id, new PPDFunctionSet(basis, data, _: Seq[NamedIndex]))
- }
- case _ => throw new InvalidInputException("Unknown usage of type: "+fSetType.name)
- }
- case _ => throw new InvalidInputException("Undefined concrete type for function set: "+id.name)
- }
- }
-}
+++ /dev/null
-package ofc.generators.onetep
-
-class Assignment(val lhs: Scalar, val rhs: Scalar)
+++ /dev/null
-package ofc.generators.onetep
-import ofc.codegen._
-
-class CodeGenerator(dictionary: Dictionary, scope: ScopeStatement) {
- class Context extends GenerationContext {
- val block = new BlockStatement
-
- def addDeclaration(sym: VarSymbol[_ <: Type]) {
- block.addDeclaration(sym)
- }
-
- def +=(stat: Statement) {
- block += stat
- }
-
- def getStatement: Statement = block
- }
-
- def apply(assignment: Assignment) {
- val lhs = assignment.lhs
- val rhs = assignment.rhs
-
- val iterationInfo = lhs.getIterationInfo
- val context = new Context
-
- val indexMap = iterationInfo.getIndexMappings
- val lhsFragment = lhs.getFragment(indexMap)
- val rhsFragment = rhs.getFragment(indexMap)
-
- rhsFragment.setup(context)
- lhsFragment.setValue(context, rhsFragment.getValue)
- rhsFragment.teardown(context)
-
- val iterated = iterationInfo.getContext.toConcrete(context.getStatement)
- scope += iterated;
- }
-}
+++ /dev/null
-package ofc.generators.onetep
-import ofc.codegen._
-
-class DensePsincToReciprocal(op: DensePsincFragment, indices: Map[NamedIndex, Expression[IntType]]) extends ReciprocalFragment {
- import OnetepTypes.FunctionBasis
- val reciprocalBox = new DeclaredVarSymbol[ArrayType[ComplexType]]("reciprocal_box", new ArrayType[ComplexType](3))
- reciprocalBox.addProperty(new AllocatableProperty)
-
- def setup(context: GenerationContext) {
- import OnetepTypes.FFTBoxInfo
-
- op.setup(context)
- context.addDeclaration(reciprocalBox)
-
- val fftboxSize : Seq[Expression[IntType]] = op.getSize
- context += new AllocateStatement(reciprocalBox, fftboxSize)
-
- val fourierParams : Seq[Expression[_]] = Seq(new CharLiteral('C'), new CharLiteral('F'), op.getBuffer, op.getBuffer, reciprocalBox)
- context += new FunctionCallStatement(new FunctionCall(OnetepFunctions.fourier_apply_box_pair, fourierParams))
- op.teardown(context)
- }
-
- def teardown(context: GenerationContext) {
- context += new DeallocateStatement(reciprocalBox)
- }
-
- def getSize = for (dim <- 0 to 2) yield OnetepTypes.FFTBoxInfo.public % OnetepTypes.FFTBoxInfo.totalPts(dim)
-
- def getOrigin = op.getOrigin
-
- def getBuffer = reciprocalBox
-
- def toPsinc = new ReciprocalToPsinc(this)
-
- def toDensePsinc = new ReciprocalToPsinc(this)
-}
+++ /dev/null
-package ofc.generators.onetep
-import ofc.parser.Identifier
-import ofc.InvalidInputException
-
-class Dictionary {
- import scala.collection.mutable.HashMap
-
- var scalars = new HashMap[Identifier, Seq[NamedIndex] => Scalar]
- var fields = new HashMap[Identifier, Seq[NamedIndex] => Field]
- var indices = new HashMap[Identifier, NamedIndex]
-
- def addScalar(id: Identifier, scalarGenerator: Seq[NamedIndex] => Scalar) {
- scalars += id -> scalarGenerator
- }
-
- def addField(id: Identifier, fieldGenerator: Seq[NamedIndex] => Field) {
- fields += id -> fieldGenerator
- }
-
- def addIndex(id: Identifier, index: NamedIndex) {
- indices += id -> index
- }
-
- def getScalar(id: Identifier) = scalars.get(id) match {
- case Some(s) => s
- case None => throw new InvalidInputException("Unknown scalar operand "+id.getName)
- }
-
- def getField(id: Identifier) = fields.get(id) match {
- case Some(f) => f
- case None => throw new InvalidInputException("Unknown field operand "+id.getName)
- }
-
- def getIndex(id: Identifier) : NamedIndex = indices.get(id) match {
- case Some(i) => i
- case None => throw new InvalidInputException("Unknown index operand "+id.getName)
- }
-
- def getIndices = indices.values
-}
+++ /dev/null
-package ofc.generators.onetep
-import ofc.codegen._
-
-trait Field extends Operand {
- def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : FieldFragment
-}
+++ /dev/null
-package ofc.generators.onetep
-import ofc.codegen._
-
-trait FieldFragment extends Fragment {
- def toReciprocal : ReciprocalFragment
- def toPsinc : PsincFragment
- def toDensePsinc : DensePsincFragment
-}
-
-trait PsincFragment extends FieldFragment {
- def toPsinc = this
-}
-
-trait DensePsincFragment extends PsincFragment {
- def getBuffer : Expression[ArrayType[FloatType]]
- def toDensePsinc = this
- def getSize : Seq[Expression[IntType]]
- def getOrigin : Seq[Expression[IntType]]
-}
-
-trait ReciprocalFragment extends FieldFragment {
- def toReciprocal = this
- def getSize : Seq[Expression[IntType]]
- def getOrigin : Seq[Expression[IntType]]
- def getBuffer : Expression[ArrayType[ComplexType]]
-}
+++ /dev/null
-package ofc.generators.onetep
-
-trait Fragment {
- def setup(context: GenerationContext)
- def teardown(context: GenerationContext)
-}
+++ /dev/null
-package ofc.generators.onetep
-import ofc.codegen._
-
-trait GenerationContext {
- def addDeclaration(sym: VarSymbol[_ <: Type])
- def +=(stat: Statement)
-}
+++ /dev/null
-package ofc.generators.onetep
-import ofc.codegen._
-
-class InnerProduct(left: Field, right: Field) extends Scalar {
- class LocalFragment(left: FieldFragment, right: FieldFragment) extends ScalarFragment with NonAssignableScalarFragment {
- val result = new DeclaredVarSymbol[FloatType]("inner_product_result")
- val leftDense = left.toDensePsinc
- val rightDense = right.toDensePsinc
-
- def setup(context: GenerationContext) {
- context.addDeclaration(result)
- leftDense.setup(context)
- rightDense.setup(context)
-
- val leftOrigin = leftDense.getOrigin
- val leftSize = leftDense.getSize
-
- val rightOrigin = rightDense.getOrigin
- val rightSize = rightDense.getSize
-
- val topLeft : Seq[Expression[IntType]] =
- for (dim <- 0 to 2) yield new Max[IntType](leftOrigin(dim), rightOrigin(dim))
-
- val bottomRight : Seq[Expression[IntType]] =
- for (dim <- 0 to 2) yield new Min[IntType](leftOrigin(dim) + leftSize(dim), rightOrigin(dim) + rightSize(dim)) - 1
-
- val indices = for(dim <- 0 to 2) yield {
- val index = new DeclaredVarSymbol[IntType]("i"+(dim+1))
- context.addDeclaration(index)
- index
- }
-
- val loops = for(dim <- 0 to 2) yield new ForLoop(indices(dim), topLeft(dim), bottomRight(dim))
- for(dim <- 1 to 2) loops(dim) += loops(dim-1)
-
- context += new AssignStatement(result, new FloatLiteral(0.0))
- context += loops(2)
-
- val leftIndex = for (dim <- 0 to 2) yield indices(dim) - leftOrigin(dim) + 1
- val rightIndex = for (dim <- 0 to 2) yield indices(dim) - rightOrigin(dim) + 1
-
- loops(0) += new AssignStatement(result, (result : Expression[FloatType]) +
- leftDense.getBuffer.at(leftIndex: _*) *
- rightDense.getBuffer.at(rightIndex: _*) *
- (OnetepTypes.CellInfo.public % OnetepTypes.CellInfo.weight))
-
- leftDense.teardown(context)
- rightDense.teardown(context)
- }
-
- def getValue = result
-
- def teardown(context: GenerationContext) {
- }
- }
-
- def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : ScalarFragment =
- new LocalFragment(left.getFragment(indices), right.getFragment(indices))
-
- def getIterationInfo : IterationInfo = {
- var leftInfo = left.getIterationInfo
- var rightInfo = right.getIterationInfo
- leftInfo merge rightInfo
- }
-}
+++ /dev/null
-package ofc.generators.onetep
-import ofc.codegen._
-
-class Laplacian(op: Field) extends Field {
- class LocalFragment(parent: Laplacian, indices: Map[NamedIndex, Expression[IntType]]) extends ReciprocalFragment {
- val transformed = new DeclaredVarSymbol[ArrayType[ComplexType]]("transformed", new ArrayType[ComplexType](3))
- transformed.addProperty(new AllocatableProperty)
-
- val opFragment = parent.getOperand.getFragment(indices).toReciprocal
-
- def setup(context: GenerationContext) {
- context.addDeclaration(transformed)
- opFragment.setup(context)
-
- context += new AllocateStatement(transformed, opFragment.getSize)
-
- val indices = for(dim <- 0 to 2) yield {
- val index = new DeclaredVarSymbol[IntType]("i"+(dim+1))
- context.addDeclaration(index)
- index
- }
-
- // Construct loops
- val loops = for(dim <- 0 to 2) yield new ForLoop(indices(dim), 1, getSize(dim))
-
- val frequencies = for(dim <- 0 to 2) yield {
- val index = indices(dim)
- val frequency = new DeclaredVarSymbol[IntType]("freq_"+(dim+1))
- val halfWidth = getSize(dim)/2 + 1;
- context.addDeclaration(frequency)
- loops(dim) += new AssignStatement(frequency,
- new ConditionalValue[IntType](index |>| halfWidth, index - getSize(dim) - 1, index - 1))
- frequency
- }
-
- // Nest loops and add outer to context
- for(dim <- 1 to 2) loops(dim) += loops(dim-1)
- context += loops(2)
-
- val reciprocalVector = for(dim <- 0 to 2) yield {
- val component = new DeclaredVarSymbol[FloatType]("reciprocal_vector"+(dim+1))
- context.addDeclaration(component)
- (component : Expression[FloatType])
- }
-
- for(dim <- 0 to 2) {
- var component : Expression[FloatType] = new FloatLiteral(0.0)
- for(vec <- 0 to 2) {
- val vector = OnetepTypes.FFTBoxInfo.public % OnetepTypes.FFTBoxInfo.latticeReciprocal(vec)
- component = component + vector % OnetepTypes.Point.coord(dim) * new Conversion[IntType, FloatType](frequencies(vec))
- }
- loops(0) += new AssignStatement(reciprocalVector(dim), component)
- }
-
- val reciprocalIndex = indices.map(new VarRef[IntType](_))
-
- //TODO: Use a unary negation instead of multiplication by -1.0.
- loops(0) += new AssignStatement(transformed.at(reciprocalIndex: _*),
- opFragment.getBuffer.at(reciprocalIndex: _*) *
- new Conversion[FloatType, ComplexType](magnitude(reciprocalVector) * new FloatLiteral(-1.0)))
-
- opFragment.teardown(context)
- }
-
- private def magnitude(vector: Seq[Expression[FloatType]]) = {
- var result : Expression[FloatType] = new FloatLiteral(0.0)
- for(element <- vector) result += element * element
- result
- }
-
- def teardown(context: GenerationContext) {
- context += new DeallocateStatement(transformed)
- }
-
- def getSize = opFragment.getSize
-
- def getOrigin = opFragment.getOrigin
-
- def getBuffer = transformed
-
- def toPsinc = new ReciprocalToPsinc(this)
-
- def toDensePsinc = new ReciprocalToPsinc(this)
- }
-
- private def getOperand = op
-
- def getFragment(indices: Map[NamedIndex, Expression[IntType]]) =
- new LocalFragment(this, indices)
-
- def getIterationInfo : IterationInfo =
- op.getIterationInfo
-}
+++ /dev/null
-package ofc.generators.onetep
-
-class NamedIndex(name: String) {
- def getName = name
-}
+++ /dev/null
-package ofc.generators.onetep
-import ofc.codegen._
-
-object OnetepFunctions {
-
- // module basis
-
- val basis_copy_function_to_box = new FortranSubroutineSignature("basis_copy_function_to_box",
- Seq(("fa_box", new ArrayType[FloatType](3)),
- ("box_n1", new IntType),
- ("box_n2", new IntType),
- ("box_n3", new IntType),
- ("offset1", new IntType),
- ("offset2", new IntType),
- ("offset3", new IntType),
- ("fa_tightbox", OnetepTypes.TightBox),
- ("fa_on_grid", new ArrayType[FloatType](1)),
- ("fa_sphere", OnetepTypes.Sphere)))
-
- val basis_ket_start_wrt_fftbox = new FortranSubroutineSignature("basis_ket_start_wrt_fftbox",
- Seq(("row_start1", new IntType),
- ("row_start1", new IntType),
- ("row_start3", new IntType),
- ("n1", new IntType),
- ("n2", new IntType),
- ("n3", new IntType)))
-
- List(basis_copy_function_to_box, basis_ket_start_wrt_fftbox).map(_.addProperty(new FortranModule("basis")))
-
- // module fourier
-
- val fourier_apply_box_pair = new FortranSubroutineSignature("fourier_apply_box_pair",
- Seq(("grid", new CharType),
- ("dir", new CharType),
- ("rspc1", new ArrayType[FloatType](3)),
- ("rspc2", new ArrayType[FloatType](3)),
- ("gspc", new ArrayType[ComplexType](3))))
-
- List(fourier_apply_box_pair).map(_.addProperty(new FortranModule("fourier")))
-
- // module sparse
-
- val sparse_first_elem_on_node = new FortranFunctionSignature[IntType]("sparse_first_elem_on_node",
- Seq(("node", new IntType),
- ("mat", OnetepTypes.SPAM3),
- ("rowcol", new CharType)))
-
- val sparse_index_length = new FortranFunctionSignature[IntType]("sparse_index_length",
- Seq(("mat", OnetepTypes.SPAM3)))
-
- val sparse_generate_index = new FortranSubroutineSignature("sparse_generate_index",
- Seq(("idx", new ArrayType[IntType](1)),
- ("mat", OnetepTypes.SPAM3)))
-
- val sparse_atom_of_elem = new FortranFunctionSignature[IntType]("sparse_atom_of_elem",
- Seq(("elem", new IntType),
- ("mat", OnetepTypes.SPAM3),
- ("rowcol", new CharType)))
-
- val sparse_put_element_real = new FortranSubroutineSignature("sparse_put_element",
- Seq(("el", new FloatType),
- ("mat", OnetepTypes.SPAM3),
- ("jrow", new IntType),
- ("jcol", new IntType)))
-
- List(sparse_first_elem_on_node,
- sparse_index_length,
- sparse_generate_index,
- sparse_atom_of_elem,
- sparse_put_element_real).map(_.addProperty(new FortranModule("sparse")))
-}
+++ /dev/null
-package ofc.generators.onetep
-import ofc.codegen._
-
-object OnetepTypes {
- object FunctionBasis extends StructType {
- addProperty(new FortranModule("function_basis"))
- def getName = "FUNC_BASIS"
-
- val numPPDsInSphere = {
- val fieldType = new PointerType[ArrayType[IntType]](new ArrayType[IntType](1))
- new FieldSymbol[PointerType[ArrayType[IntType]]]("n_ppds_sphere", fieldType)
- }
-
- val num = new FieldSymbol[IntType]("num")
-
- val tightBoxes = {
- val fieldType = new PointerType[ArrayType[StructType]](new ArrayType(1, TightBox))
- new FieldSymbol[PointerType[ArrayType[StructType]]]("tight_boxes", fieldType)
- }
-
- val spheres = {
- val fieldType = new PointerType[ArrayType[StructType]](new ArrayType[StructType](1, Sphere))
- new FieldSymbol[PointerType[ArrayType[StructType]]]("spheres", fieldType)
- }
- }
-
- object Sphere extends StructType {
- addProperty(new FortranModule("basis"))
- def getName = "SPHERE"
-
- val ppdList = {
- val fieldType = new PointerType[ArrayType[IntType]](new ArrayType[IntType](2))
- new FieldSymbol[PointerType[ArrayType[IntType]]]("ppd_list", fieldType)
- }
-
- val offset = new FieldSymbol[IntType]("offset")
- }
-
- object Point extends StructType {
- addProperty(new FortranModule("geometry"))
- def getName = "POINT"
-
- val x = new FieldSymbol[FloatType]("X")
- val y = new FieldSymbol[FloatType]("Y")
- val z = new FieldSymbol[FloatType]("Z")
- val coord = List(x,y,z)
- }
-
- object CellInfo extends StructType {
- private val module = new FortranModule("simulation_cell")
- addProperty(module)
- def getName = "CELL_INFO"
-
- val public = new NamedUnboundVarSymbol[StructType]("pub_cell", OnetepTypes.CellInfo)
- public.addProperty(module)
-
- val ppdWidth = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("n_pt"+dim)}.toSeq
- val numPPDs = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("n_ppds_a"+dim)}.toSeq
- val pointsInPPD = new FieldSymbol[IntType]("n_pts")
- val latticeReciprocal = for(dim <- 1 to 3) yield new FieldSymbol[StructType]("b"+dim, Point)
- val weight = new FieldSymbol[FloatType]("weight")
- }
-
- object TightBox extends StructType {
- addProperty(new FortranModule("basis"))
- def getName = "FUNCTION_TIGHT_BOX"
-
- val startPts = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("start_pts"+dim)}.toSeq
- val finishPts = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("finish_pts"+dim)}.toSeq
- val startPPD = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("start_ppds"+dim)}.toSeq
- val finishPPD = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("finish_ppds"+dim)}.toSeq
- }
-
- object FFTBoxInfo extends StructType {
- addProperty(new FortranModule("fourier"))
- def getName = "FFTBOX_INFO"
-
- val public = new NamedUnboundVarSymbol[StructType]("pub_fftbox", FFTBoxInfo)
- public.addProperty(new FortranModule("simulation_cell"))
-
- val latticeReciprocal = for(dim <- 1 to 3) yield new FieldSymbol[StructType]("b"+dim, Point)
- val totalPts = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("total_pt"+dim)}.toSeq
- }
-
- object SPAM3 extends StructType {
- addProperty(new FortranModule("sparse"))
- def getName = "SPAM3"
- }
-}
+++ /dev/null
-package ofc.generators.onetep
-import ofc.codegen._
-
-object OnetepVariables {
- // comms
- val pub_my_node_id = new NamedUnboundVarSymbol[IntType]("pub_my_node_id")
- pub_my_node_id.addProperty(new FortranModule("comms"))
-
- // parallel_strategy
- val pub_first_atom_on_node = new NamedUnboundVarSymbol[ArrayType[IntType]]("pub_first_atom_on_node", new ArrayType[IntType](1))
- val pub_num_atoms_on_node = new NamedUnboundVarSymbol[ArrayType[IntType]]("pub_num_atoms_on_node", new ArrayType[IntType](1))
-
- List(pub_first_atom_on_node,
- pub_num_atoms_on_node).map(_.addProperty(new FortranModule("parallel_strategy")))
-}
+++ /dev/null
-package ofc.generators.onetep
-import ofc.codegen._
-
-class IterationInfo(val context: IterationContext, val indexMappings: Map[NamedIndex, Expression[IntType]]) {
- def merge(other: IterationInfo) : IterationInfo =
- new IterationInfo(context merge other.context, indexMappings ++ other.indexMappings)
- def getContext = context
- def getIndexMappings = indexMappings
-}
-
-trait Operand {
- def getIterationInfo : IterationInfo
-}
+++ /dev/null
-package ofc.generators.onetep
-import ofc.codegen._
-import ofc.{LogicError,UnimplementedException}
-/*
-object PPDFunctionSet {
- private class SphereIndex(name: String, value: Expression[IntType]) extends DiscreteIndex {
- def getName = name
- def getValue = value
- }
-
- private class PositionIndex(name: String, value: Expression[IntType], centre: Expression[IntType]) extends FunctionSpatialIndex {
- def getName = name
- def getValue = value
- def getFunctionCentre = centre
- }
-
- def apply(basisName: String, dataName: String) : PPDFunctionSet = {
- import OnetepTypes._
-
- val basis = new NamedUnboundVarSymbol[StructType](basisName, FunctionBasis)
- val data = new NamedUnboundVarSymbol[ArrayType[FloatType]](dataName, new ArrayType[FloatType](1))
-
- val numSpheres = basis % FunctionBasis.num
- val ppdWidths = for(dim <- 0 to 2) yield CellInfo.public % CellInfo.ppdWidth(dim)
- val cellWidthPPDs = for(dim <- 0 to 2) yield CellInfo.public % CellInfo.numPPDs(dim)
- val cellWidthPts = for(dim <- 0 to 2) yield cellWidthPPDs(dim) * ppdWidths(dim)
-
- val producer = new ProducerStatement
- val sphereIndex = producer.addIteration("sphere_index", 1, numSpheres)
- val numPPDs = (~(basis % FunctionBasis.numPPDsInSphere)).at(sphereIndex)
- val ppdIndex = producer.addIteration("ppd_index", 1, numPPDs)
- val sphere = (~(basis % FunctionBasis.spheres)).at(sphereIndex)
- val ppdGlobalCount = (~(sphere % Sphere.ppdList)).at(ppdIndex, 1) - 1
-
- // The integer co-ordinates of the PPD (0-based)
- val a3pos = producer.addExpression("ppd_pos1", ppdGlobalCount / (cellWidthPPDs(0)*cellWidthPPDs(1)))
- val a2pos = producer.addExpression("ppd_pos2", (ppdGlobalCount % (cellWidthPPDs(0)*cellWidthPPDs(1)))/cellWidthPPDs(0))
- val a1pos = producer.addExpression("ppd_pos3", ppdGlobalCount % cellWidthPPDs(0))
- val ppdPos = List(a1pos, a2pos, a3pos)
-
- val tightbox = (~(basis % FunctionBasis.tightBoxes)).at(sphereIndex)
-
- // The offsets into the PPDs for the edges of the tightbox
- val ppdStartOffsets = for(dim <- 0 to 2) yield tightbox % TightBox.startPts(dim) - 1
- val ppdFinishOffsets = for(dim <- 0 to 2) yield tightbox % TightBox.finishPts(dim) - 1
-
- // The first and last PPDs in PPD co-ordinates (0-based, inside simulation cell)
- val startPPDs = for(dim <- 0 to 2) yield
- producer.addExpression("start_ppd"+(dim+1), (tightbox % TightBox.startPPD(dim) + cellWidthPPDs(dim)-1) % cellWidthPPDs(dim))
- val finishPPDs = for(dim <- 0 to 2) yield
- producer.addExpression("finish_ppd"+(dim+1),(tightbox % TightBox.finishPPD(dim) + cellWidthPPDs(dim)-1) % cellWidthPPDs(dim))
-
- // The dimensions of the tightbox
- val tightboxStartPts = for(dim <- 0 to 2) yield
- producer.addExpression("tightbox_start_pt"+(dim+1), startPPDs(dim)*(CellInfo.public % CellInfo.ppdWidth(dim)) + ppdStartOffsets(dim))
-
- val tightboxFinishPts = for(dim <- 0 to 2) yield
- producer.addExpression("tightbox_finsh_pt"+(dim+1), finishPPDs(dim)*(CellInfo.public % CellInfo.ppdWidth(dim)) + ppdFinishOffsets(dim))
-
- val tightboxWidth = for(dim <- 0 to 2) yield
- producer.addExpression("tightbox_width_pt"+(dim+1), (tightboxFinishPts(dim) - tightboxStartPts(dim) + cellWidthPts(dim)) % cellWidthPts(dim))
-
- val tightboxCentre = for(dim <- 0 to 2) yield
- producer.addExpression("tightbox_centre_pt"+(dim+1), ((tightboxStartPts(dim) : Expression[IntType]) + tightboxWidth(dim) / 2) % cellWidthPts(dim))
-
- // Offsets for the current PPD being iterated over
- val loopStarts = for(dim <- 0 to 2) yield
- producer.addExpression("start_pt"+(dim+1), new ConditionalValue[IntType](startPPDs(dim) |==| ppdPos(dim), ppdStartOffsets(dim), 0))
-
- val loopEnds = for(dim <- 0 to 2) yield
- producer.addExpression("end_pt"+(dim+1), new ConditionalValue[IntType](finishPPDs(dim) |==| ppdPos(dim), ppdFinishOffsets(dim), ppdWidths(dim) - 1))
-
- // Loops for iterating over the PPD itself
- val ppdIndices = for(dim <- 0 to 2) yield producer.addIteration("point"+(dim+1), loopStarts(dim), loopEnds(dim))
-
- // Values exposed as indices and data
- val positions = for(dim <- 0 to 2) yield
- producer.addExpression("pos"+(dim+1), ppdPos(dim)*ppdWidths(dim) + ppdIndices(dim))
-
- val ppdDataStart =
- producer.addExpression("ppd_data_start", (sphere % Sphere.offset) + (ppdIndex-1) * (CellInfo.public % CellInfo.pointsInPPD))
-
- val ppdDataIndex = producer.addExpression("ppd_data_index", (ppdDataStart: Expression[IntType])
- + ppdIndices(2) * (CellInfo.public % CellInfo.ppdWidth(1)) * (CellInfo.public % CellInfo.ppdWidth(0))
- + ppdIndices(1) * (CellInfo.public % CellInfo.ppdWidth(0))
- + ppdIndices(0))
-
- val dataValue = producer.addExpression("data", data.at(ppdDataIndex))
-
- val discreteIndices = List[DiscreteIndex](new SphereIndex("sphere", sphereIndex))
- val spatialIndices = {
- val indexNames = List("x", "y", "z")
- for (dim <- 0 to 2) yield new PositionIndex(indexNames(dim), positions(dim), tightboxCentre(dim))
- }
-
- new PPDFunctionSet(discreteIndices, spatialIndices, dataValue, producer)
- }
-}
-*/
-
-class PPDFunctionSet(val basis: Expression[StructType], val data: Expression[ArrayType[FloatType]], indices: Seq[NamedIndex]) extends Field {
-
- class LocalFragment(parent: PPDFunctionSet, indices: Map[NamedIndex, Expression[IntType]]) extends PsincFragment {
- def setup(context: GenerationContext) {}
- def teardown(context: GenerationContext) {}
- def toReciprocal : ReciprocalFragment = toDensePsinc.toReciprocal
- def toDensePsinc = new LocalDense(parent, indices)
- }
-
- class LocalDense(parent: PPDFunctionSet, indices: Map[NamedIndex, Expression[IntType]]) extends DensePsincFragment {
- import OnetepTypes.FunctionBasis
-
- val sphereIndex = indices.get(parent.getSphereIndex) match {
- case Some(expression) => expression
- case None => throw new LogicError("Cannot find expression for index "+parent.getSphereIndex)
- }
-
- val fftbox = new DeclaredVarSymbol[ArrayType[FloatType]]("fftbox", new ArrayType[FloatType](3))
- fftbox.addProperty(new AllocatableProperty)
-
- val tightbox = (~(basis % FunctionBasis.tightBoxes)).at(sphereIndex)
- val sphere = (~(basis % FunctionBasis.spheres)).at(sphereIndex)
- val fftboxOffset = for(dim <- 0 to 2) yield new DeclaredVarSymbol[IntType]("fftbox_offset"+(dim+1))
- val tightboxOrigin = for(dim <- 0 to 2) yield new DeclaredVarSymbol[IntType]("tightbox_origin"+(dim+1))
-
- def setup(context: GenerationContext) {
- import OnetepTypes.FFTBoxInfo
-
- context.addDeclaration(fftbox)
- fftboxOffset.map(context.addDeclaration(_))
-
- val fftboxSize : Seq[Expression[IntType]] = getSize
- context += new AllocateStatement(fftbox, fftboxSize)
- context += new FunctionCallStatement(new FunctionCall(OnetepFunctions.basis_ket_start_wrt_fftbox,
- fftboxOffset.map(new VarRef[IntType](_)) ++ fftboxSize))
-
- var basisCopyParams : Seq[Expression[_]] = Nil
- basisCopyParams :+= (fftbox: Expression[ArrayType[FloatType]])
- basisCopyParams ++= fftboxSize
- basisCopyParams ++= fftboxOffset.map(new VarRef[IntType](_))
- basisCopyParams :+= tightbox
- basisCopyParams :+= (parent.data: Expression[ArrayType[FloatType]])
- basisCopyParams :+= sphere
-
- context += new FunctionCallStatement(new FunctionCall(OnetepFunctions.basis_copy_function_to_box, basisCopyParams))
-
- for (dim <- 0 to 2) yield {
- import OnetepTypes._
- val startPPD = tightbox % TightBox.startPPD(dim) - 1
- val startPPDPoint = startPPD * (CellInfo.public % CellInfo.ppdWidth(dim))
- val startPoint = startPPDPoint + tightbox % TightBox.startPts(dim)
-
- context.addDeclaration(tightboxOrigin(dim))
- context += new AssignStatement(tightboxOrigin(dim), startPoint)
- }
- }
-
- def teardown(context: GenerationContext) {
- context += new DeallocateStatement(fftbox)
- }
-
- def getSize = for (dim <- 0 to 2) yield OnetepTypes.FFTBoxInfo.public % OnetepTypes.FFTBoxInfo.totalPts(dim)
-
- def getOrigin = {
- for (dim <- 0 to 2) yield
- tightboxOrigin(dim) - fftboxOffset(dim)
- }
-
- def getBuffer = fftbox
-
- def toReciprocal = new DensePsincToReciprocal(this, indices)
- }
-
- private def getSphereIndex = indices.head
-
- def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : FieldFragment =
- new LocalFragment(this, indices)
-
- def getIterationInfo : IterationInfo = {
- val context = new IterationContext
- val numSpheres = basis % OnetepTypes.FunctionBasis.num
- val sphereIndexExpr = context.addIteration("sphere_index", 1, numSpheres)
- new IterationInfo(context, Map(getSphereIndex -> sphereIndexExpr))
- }
-}
+++ /dev/null
-package ofc.generators.onetep
-import ofc.codegen._
-
-class ReciprocalToPsinc(op: ReciprocalFragment) extends DensePsincFragment {
- val fftbox = new DeclaredVarSymbol[ArrayType[FloatType]]("fftbox", new ArrayType[FloatType](3))
- fftbox.addProperty(new AllocatableProperty)
-
- val dummybox = new DeclaredVarSymbol[ArrayType[FloatType]]("dummybox", new ArrayType[FloatType](3))
- dummybox.addProperty(new AllocatableProperty)
-
- def toReciprocal = op
-
- def setup(context: GenerationContext) {
- op.setup(context)
- context.addDeclaration(fftbox)
- context.addDeclaration(dummybox)
-
- val fftboxSize : Seq[Expression[IntType]] = getSize
- context += new AllocateStatement(fftbox, fftboxSize)
- context += new AllocateStatement(dummybox, fftboxSize)
-
- val fourierParams : Seq[Expression[_]] = Seq(new CharLiteral('C'), new CharLiteral('B'), fftbox, dummybox, op.getBuffer)
- context += new FunctionCallStatement(new FunctionCall(OnetepFunctions.fourier_apply_box_pair, fourierParams))
-
- op.teardown(context)
- }
-
- def teardown(context: GenerationContext) {
- context += new DeallocateStatement(fftbox)
- context += new DeallocateStatement(dummybox)
- }
-
- def getSize = for (dim <- 0 to 2) yield OnetepTypes.FFTBoxInfo.public % OnetepTypes.FFTBoxInfo.totalPts(dim)
-
- def getOrigin = op.getOrigin
-
- def getBuffer = fftbox
-}
+++ /dev/null
-package ofc.generators.onetep
-import ofc.codegen._
-
-class SPAM3(mat: Expression[StructType], position: Seq[NamedIndex]) extends Scalar {
-
- class LocalFragment(row: Expression[IntType], col: Expression[IntType]) extends ScalarFragment {
- def setup(context: GenerationContext) {
- }
-
- def getValue = throw new ofc.UnimplementedException("get unimplemented for SPAM3")
-
- def setValue(context: GenerationContext, value: Expression[FloatType]) {
- val functionCall = new FunctionCall(OnetepFunctions.sparse_put_element_real,
- Seq(value, mat, row, col))
- context += new FunctionCallStatement(functionCall)
- }
-
- def teardown(context: GenerationContext) {
- }
- }
-
- def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : ScalarFragment =
- new LocalFragment(indices.get(position(0)).get, indices.get(position(1)).get)
-
- def getIterationInfo : IterationInfo = {
- val context = new IterationContext
-
- // Create sparse index
- val header = new BlockStatement
- val indexLength = new FunctionCall(OnetepFunctions.sparse_index_length, Seq(mat))
- val index = new DeclaredVarSymbol[ArrayType[IntType]]("sparse_idx", new ArrayType[IntType](1))
- index.addProperty(new AllocatableProperty)
-
- header += new AllocateStatement(index, Seq(indexLength))
- header += new FunctionCallStatement(new FunctionCall(OnetepFunctions.sparse_generate_index, Seq(index, mat)))
- context.addHeader(header)
- context.addDeclaration(index)
-
- val footer = new DeallocateStatement(index)
- context.addFooter(footer)
-
- val firstCol = new FunctionCall(OnetepFunctions.sparse_first_elem_on_node,
- Seq(OnetepVariables.pub_my_node_id,
- mat,
- new CharLiteral('C')))
-
- val lastCol = new FunctionCall(OnetepFunctions.sparse_first_elem_on_node,
- Seq((OnetepVariables.pub_my_node_id: Expression[IntType])+1,
- mat,
- new CharLiteral('C'))) - 1
-
- val firstRow = new FunctionCall(OnetepFunctions.sparse_first_elem_on_node,
- Seq(OnetepVariables.pub_my_node_id,
- mat,
- new CharLiteral('R')))
-
- val lastRow = new FunctionCall(OnetepFunctions.sparse_first_elem_on_node,
- Seq((OnetepVariables.pub_my_node_id: Expression[IntType])+1,
- mat,
- new CharLiteral('R'))) - 1
-
- val col = context.addIteration("col", firstCol, lastCol)
- val colAtom = context.addExpression("col_atom", new FunctionCall(OnetepFunctions.sparse_atom_of_elem,
- Seq(col, mat, new CharLiteral('C'))))
- val localColAtom = context.addExpression("local_col_atom",
- colAtom - OnetepVariables.pub_first_atom_on_node.at(OnetepVariables.pub_my_node_id) + 1)
-
- val row = context.addIteration("row", firstRow, lastRow)
- val rowAtom = context.addExpression("row_atom", new FunctionCall(OnetepFunctions.sparse_atom_of_elem,
- Seq(row, mat, new CharLiteral('R'))))
-
- val rowIdx = context.addIteration("row_idx", index.at(localColAtom),
- index.at((localColAtom: Expression[IntType])+1)-1)
- context.addPredicate(index.at(rowIdx) |==| rowAtom)
-
- var indexMappings : Map[NamedIndex, Expression[IntType]] = Map.empty
- indexMappings += position(0) -> row
- indexMappings += position(1) -> col
-
- new IterationInfo(context, indexMappings)
- }
-}
+++ /dev/null
-package ofc.generators.onetep
-import ofc.codegen._
-
-trait Scalar extends Operand {
- def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : ScalarFragment
-}
+++ /dev/null
-package ofc.generators.onetep
-import ofc.codegen._
-
-trait ScalarFragment extends Fragment {
- def getValue : Expression[FloatType]
- def setValue(context: GenerationContext, value: Expression[FloatType])
-}
-
-trait NonAssignableScalarFragment {
- def setValue(context: GenerationContext, value: Expression[FloatType]) {
- throw new ofc.LogicError("Expression: "+this+" is not assignable.")
- }
-}
+++ /dev/null
-package ofc.generators.onetep
-import ofc.codegen._
-
-class ScalarLiteral(s: Double) extends Scalar {
- class LocalFragment(s: Double) extends ScalarFragment with NonAssignableScalarFragment {
- def setup(context: GenerationContext) {
- }
-
- def getValue =
- new FloatLiteral(s)
-
- def teardown(context: GenerationContext) {
- }
- }
-
- def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : ScalarFragment =
- new LocalFragment(s)
-
- def getIterationInfo : IterationInfo =
- new IterationInfo(new IterationContext, Map.empty)
-}
+++ /dev/null
-package ofc.generators.onetep
-import ofc.codegen._
-
-class ScaledField(op: Field, factor: Scalar) extends Field {
- class LocalFragment(parent: ScaledField, indices: Map[NamedIndex, Expression[IntType]]) extends DensePsincFragment {
- val transformed = new DeclaredVarSymbol[ArrayType[FloatType]]("scaled", new ArrayType[FloatType](3))
- transformed.addProperty(new AllocatableProperty)
-
- val scaleFragment = parent.getScalingFactor.getFragment(indices)
- val opFragment = parent.getOperand.getFragment(indices).toDensePsinc
-
- def setup(context: GenerationContext) {
- context.addDeclaration(transformed)
- opFragment.setup(context)
- scaleFragment.setup(context)
-
- context += new AllocateStatement(transformed, opFragment.getSize)
-
- val indices = for(dim <- 0 to 2) yield {
- val index = new DeclaredVarSymbol[IntType]("i"+(dim+1))
- context.addDeclaration(index)
- index
- }
-
- // Construct loops
- val loops = for(dim <- 0 to 2) yield new ForLoop(indices(dim), 1, getSize(dim))
-
- // Nest loops and add outer to context
- for(dim <- 1 to 2) loops(dim) += loops(dim-1)
- context += loops(2)
-
- val index = indices.map(new VarRef[IntType](_))
-
- loops(0) += new AssignStatement(transformed.at(index: _*),
- opFragment.getBuffer.at(index: _*) * scaleFragment.getValue)
-
- opFragment.teardown(context)
- scaleFragment.teardown(context)
- }
-
- private def magnitude(vector: Seq[Expression[FloatType]]) = {
- var result : Expression[FloatType] = new FloatLiteral(0.0)
- for(element <- vector) result += element * element
- result
- }
-
- def teardown(context: GenerationContext) {
- context += new DeallocateStatement(transformed)
- }
-
- def getSize = opFragment.getSize
-
- def getOrigin = opFragment.getOrigin
-
- def getBuffer = transformed
-
- def toReciprocal = new DensePsincToReciprocal(this, indices)
- }
-
- private def getOperand = op
-
- private def getScalingFactor = factor
-
- def getFragment(indices: Map[NamedIndex, Expression[IntType]]) =
- new LocalFragment(this, indices)
-
- def getIterationInfo : IterationInfo =
- op.getIterationInfo merge factor.getIterationInfo
-}
+++ /dev/null
-package ofc.parser
-import scala.util.parsing.combinator._
-import scala.util.parsing.input.Position
-import java.io.Reader
-
-sealed abstract class ScalarOperationTag
-case class MultiplicationTag() extends ScalarOperationTag
-case class DivisionTag() extends ScalarOperationTag
-
-class ParseException(message: String, pos: Position) extends Exception {
- override def toString : String =
- message + " at line " + pos.line + ", column " + pos.column + "."
-}
-
-class Parser extends JavaTokenParsers {
- def program : Parser[List[Statement]] = rep(comment | declarations | definition | target | comment | specifics)
- def comment : Parser[Comment] = "#"~!".*".r ^^ (v => new Comment(v._2))
- def identifier : Parser[Identifier] = ident ^^ (v => new Identifier(v))
-
- def oflType : Parser[OFLType] = arrayType | functionSetType | indexType
- def arrayType : Parser[Matrix] = "Array["~>repsep(indexType, ",")<~"]" ^^ (x => new Matrix(x))
- def functionSetType : Parser[FunctionSet] = "FunctionSet" ^^ (_ => new FunctionSet)
- def indexType : Parser[Index] = "FunctionIndex" ^^ (_ => new Index)
-
- def declarations: Parser[DeclarationList] = oflType~!repsep(identifier, ",") ^^
- (d => new DeclarationList(d._1, d._2))
-
- def definition: Parser[Definition] = indexedIdentifier~("="~>expr) ^^ (x => new Definition(x._1, x._2))
- def expr : Parser[Expression] = term~opt(scalarOperator~!expr) ^^
- (x => x._2 match {
- case None => x._1
- case Some(y) => y._1 match {
- case MultiplicationTag() => new Multiplication(x._1, y._2)
- case DivisionTag() => new Division(x._1, y._2)
- }
- })
-
- def scalarOperator : Parser[ScalarOperationTag] = mulOp | divOp
- def mulOp : Parser[ScalarOperationTag] = "*" ^^ (_ => MultiplicationTag())
- def divOp : Parser[ScalarOperationTag] = "/" ^^ (_ => DivisionTag())
-
- def term: Parser[Expression] = scalarConstant ||| indexedIdentifier ||| operator
- def scalarConstant : Parser[ScalarConstant] = floatingPointNumber ^^ (x => new ScalarConstant(x.toDouble))
- def indexedIdentifier: Parser[IndexedIdentifier] = identifier~opt("["~>repsep(identifier, ",")<~"]") ^^
- (x => new IndexedIdentifier(x._1, x._2 match {
- case Some(list) => list
- case None => Nil
- }))
- def operator : Parser[Operator] = identifier~("("~>repsep(expr, ",")<~")") ^^ (x => new Operator(x._1, x._2))
-
- def target : Parser[Target] = "target"~!identifier ^^ (x => new Target(x._2))
- def specifics : Parser[TargetAssignment] = "Variable"~>identifier~("="~>functionCall) ^^ (x => new TargetAssignment(x._1, x._2))
- def functionCall : Parser[FunctionCall] = identifier~("("~>repsep(functionParameter, ",")<~")") ^^
- (x => new FunctionCall(x._1, new ParameterList(x._2 : _*)))
-
- def functionParameter : Parser[Parameter] = stringParameter | numericParameter | identifier | parameterList
- def stringParameter : Parser[StringParameter] = stringLiteral ^^ (x => new StringParameter(x.slice(1, x.length-1)))
- def numericParameter : Parser[NumericParameter] = floatingPointNumber ^^ (x => new NumericParameter(x.toDouble))
- def parameterList : Parser[ParameterList] = "["~>repsep(functionParameter, ",")<~"]" ^^ (x => new ParameterList(x : _*))
-
- def parseProgram(in: Reader) : List[Statement] =
- parseAll(program, in) match {
- case Success(result, _) => result
- case r : NoSuccess =>
- throw new ParseException(r.msg, r.next.pos)
- }
-}
+++ /dev/null
-package ofc.parser
-
-case class Identifier(name: String) extends Parameter {
- override def toString : String = "id(\""+name+"\")"
- def getName = name
-}
-
-sealed abstract class Statement
-class Comment(value: String) extends Statement {
- override def toString : String = "comment(\""+value+"\")"
-}
-case class DeclarationList(oflType: OFLType, names: List[Identifier]) extends Statement {
- override def toString : String = "decl("+oflType+", "+names+")"
-}
-case class Definition(term: IndexedIdentifier, expr: Expression) extends Statement {
- override def toString : String = "define("+term+", "+expr+")"
-}
-case class Target(name: Identifier) extends Statement {
- override def toString : String = "target("+name+")"
-}
-case class TargetAssignment(id: Identifier, value: FunctionCall) extends Statement {
- override def toString : String = "target_assignment("+id+", "+value+")"
-}
-
-sealed abstract class OFLType
-case class Matrix(indices: List[Index]) extends OFLType {
- override def toString : String = "Matrix"
-}
-case class FunctionSet() extends OFLType {
- override def toString : String = "FunctionSet"
-}
-case class Index() extends OFLType {
- override def toString : String = "Index"
-}
-
-sealed abstract class Expression
-case class ScalarConstant(s: Double) extends Expression {
- override def toString : String = s.toString
-}
-case class IndexedIdentifier(id: Identifier, indices : List[Identifier]) extends Expression {
- override def toString : String = id+indices.mkString("[", ", ", "]")
-}
-case class Operator(id: Identifier, operands : List[Expression]) extends Expression {
- override def toString : String = "operator("+id+ ", " + operands.mkString("[", ", ", "]")+")"
-}
-
-sealed abstract class ScalarOperation extends Expression
-case class Multiplication(a: Expression, b: Expression) extends ScalarOperation {
- override def toString : String = "mul("+a.toString+","+b.toString+")"
-}
-case class Division(a: Expression, b: Expression) extends ScalarOperation {
- override def toString : String = "div("+a.toString+","+b.toString+")"
-}
-
-case class FunctionCall(name: Identifier, params: ParameterList) {
- override def toString : String = "call("+name+", "+params+")"
-}
-
-sealed abstract class Parameter
-case class ParameterList(params: Parameter*) extends Parameter {
- override def toString : String = params.mkString("[", ", ", "]")
- def toSeq : Seq[Parameter] = params.toSeq
-}
-case class StringParameter(value: String) extends Parameter {
- override def toString : String = "\""+value+"\""
-}
-case class NumericParameter(value: Double) extends Parameter {
- override def toString : String = value.toString
-}
+++ /dev/null
-package ofc.util
-import ofc.LogicError
-
-object DirectedGraph {
- private def topoSort(graph: DirectedGraph, queue: Queue[DirectedGraph#Vertex]) = {
- type Vertex = DirectedGraph#Vertex
- val degrees = scala.collection.mutable.Map[Vertex, Int]()
- val result = scala.collection.mutable.ArrayBuffer[Vertex]()
-
- degrees ++= { for (v <- graph.vertices) yield (v, graph.inDegree(v)) }
- queue ++= { for (v <- graph.vertices; if graph.inDegree(v) == 0) yield v }
-
- while(queue.nonEmpty) {
- val top = queue.pop()
- result += top
-
- for(outEdge <- graph.outEdges(top)) {
- val target = graph.target(outEdge)
- degrees.get(target) match {
- case None => throw new LogicError("Unknown vertex in topological sort")
- case Some(degree) => {
- val newDegree = degree-1
- degrees += (target -> newDegree)
- if (newDegree == 0) queue += target
- if (newDegree < 0) throw new LogicError("Degree of vertex dropped below zero in topological sort.")
- }
- }
- }
- }
-
- if (degrees.values.filter(x => x>0).nonEmpty) throw new LogicError("Cycle in graph.")
- result.toSeq
- }
-
- def topoSort(graph: DirectedGraph) : Seq[DirectedGraph#Vertex] =
- topoSort(graph, new StackQueue[DirectedGraph#Vertex]())
-
- def topoSort(graph: DirectedGraph, priority: DirectedGraph#Vertex => Int) : Seq[DirectedGraph#Vertex] =
- topoSort(graph, new PriorityQueue[DirectedGraph#Vertex](priority))
-}
-
-class DirectedGraph extends GraphBase {
- protected def canonicalEdge(e: Edge) = e
-
- def inEdges(v: Vertex) = {
- val info = getInfo(v)
- info.in.map(x => (x, v))
- }
-
- def outEdges(v: Vertex) = {
- val info = getInfo(v)
- info.out.map(x => (v, x))
- }
-
- def inDegree(v: Vertex) = {
- val info = getInfo(v)
- info.in.size
- }
-
- def outDegree(v: Vertex) = {
- val info = getInfo(v)
- info.out.size
- }
-}
+++ /dev/null
-package ofc.util
-
-trait Graph {
- type Vertex
- type Edge
- def nullVertex : Vertex
-}
-
-trait IncidenceGraph extends Graph {
- def source(e: Edge) : Vertex
- def target(e: Edge) : Vertex
- def outEdges(v: Vertex) : Traversable[Edge]
- def outDegree(v: Vertex) : Int
-}
-
-trait BidirectionalGraph extends IncidenceGraph {
- def inEdges(v: Vertex) : Traversable[Edge]
- def inDegree(v: Vertex) : Int
- def degree(v: Vertex) : Int
-}
-
-trait VertexListGraph extends Graph {
- def vertices : Traversable[Vertex]
- def numVertices: Int
-}
-
-trait EdgeListGraph extends Graph {
- def edges : Traversable[Edge]
- def numEdges : Int
-}
-
-trait AdjacencyMatrix extends Graph {
- def hasEdge(u: Vertex, v: Vertex) : Boolean
-}
-
-trait MutableGraph extends Graph {
- def addVertex() : Vertex
- def removeVertex(v: Vertex) : Unit
- def addEdge(u: Vertex, v: Vertex) : Edge
- def removeEdge(e: Edge) : Unit
-}
+++ /dev/null
-package ofc.util
-import ofc.LogicError
-
-abstract class GraphBase extends BidirectionalGraph with VertexListGraph with EdgeListGraph with AdjacencyMatrix with MutableGraph {
- import scala.collection.mutable
-
- type Vertex = Int
- type Edge = (Vertex, Vertex)
-
- case class VertexInfo(var in: Set[Vertex], var out: Set[Vertex]) {
- def this() = this(Set.empty, Set.empty)
- }
-
- val nullVertex = 0
- private var lastVertex = nullVertex
- private val vertexMap = collection.mutable.Map[Vertex, VertexInfo]()
-
- protected def getInfo(v: Vertex) : VertexInfo = {
- vertexMap.get(v) match {
- case Some(vertexInfo) => vertexInfo
- case _ => throw new LogicError("Cannot find vertex "+v+" in graph.")
- }
- }
-
- protected def canonicalEdge(e: Edge) : Edge
-
- def source(e: Edge) = e._1
- def target(e: Edge) = e._2
-
- def vertices = vertexMap.keys
- def numVertices = vertexMap.size
-
- def degree(v: Vertex) = {
- val info = getInfo(v)
- info.in.size + info.out.size
- }
-
- def addVertex = {
- lastVertex += 1
- vertexMap += (lastVertex -> new VertexInfo)
- lastVertex
- }
-
- def addEdge(u: Vertex, v: Vertex) = {
- val canonical = canonicalEdge(u -> v)
- getInfo(source(canonical)).out += target(canonical)
- getInfo(target(canonical)).in += source(canonical)
- u -> v
- }
-
- def removeEdge(e: Edge) {
- val canonical = canonicalEdge(e)
- getInfo(source(e)).out -= target(e)
- getInfo(target(e)).in -= source(e)
- }
-
- def hasEdge(u: Vertex, v: Vertex) = {
- val canonical = canonicalEdge(u -> v)
- val info = getInfo(source(canonical))
- info.out.contains(target(canonical))
- }
-
- def removeVertex(v: Vertex) = {
- val info = getInfo(v)
- for (in <- info.in) getInfo(in).out -= v
- for (out <- info.out) getInfo(out).in -= v
- vertexMap -= v
- }
-
- def edges = for(fromInfo <- vertexMap; to <- fromInfo._2.out) yield (fromInfo._1, to)
-
- def numEdges = vertexMap.values.map(info => info.out.size).sum
-}
+++ /dev/null
-package ofc.util
-
-object Ordering {
- def transitiveClosure[T](nodes: Seq[T], hasPath: (T,T) => Boolean) : Set[(T,T)] = {
- val ordering = scala.collection.mutable.Set[(T, T)]()
- ordering ++= { for(n1 <- nodes; n2 <- nodes; if hasPath(n1, n2)) yield (n1 -> n2) }
-
- for(via <- nodes; start <- nodes; end <- nodes)
- if (!ordering.contains(start -> end) && ordering.contains(start -> via) && ordering.contains(via -> end))
- ordering += (start -> end)
-
- ordering.toSet
- }
-}
+++ /dev/null
-package ofc.util
-import scala.collection.generic.Growable
-
-trait Queue[A] extends Growable[A] {
- def pop() : A
- def nonEmpty : Boolean
-}
-
-class StackQueue[A] extends Queue[A] {
- private var stack = List[A]()
-
- def nonEmpty = stack.nonEmpty
-
- def +=(e: A) = {
- stack = (e :: stack)
- this
- }
-
- def pop() = {
- val (head, tail) = (stack.head, stack.tail)
- stack = tail
- head
- }
-
- def clear() {
- stack = Nil
- }
-}
-
-// Nodes with smaller priorities are popped first
-class PriorityQueue[A](priority: A => Int) extends Queue[A] {
- // This class helps create an artificial total ordering by comparing
- // unique integers if the compare function is non-total. This is
- // probably unnecessary since priority queues permit duplicates.
- private class UniqueOrdering(priority: A => Int) extends scala.math.Ordering[(A, Int)] {
- def compare(x: (A, Int), y: (A, Int)) = {
- val xPri = priority(x._1)
- val yPri = priority(y._1)
- val comparison = -xPri.compareTo(yPri)
-
- if (comparison != 0)
- comparison
- else
- x._2.compareTo(y._2)
- }
- }
-
- var uniqueID = 0
- val queue = new scala.collection.mutable.PriorityQueue[(A, Int)]()(new UniqueOrdering(priority))
-
- def nonEmpty = queue.nonEmpty
- def pop() = queue.dequeue()._1
- def clear() = queue.clear()
-
- def +=(e: A) = {
- queue += (e -> uniqueID)
- uniqueID += 1
- this
- }
-}
-
+++ /dev/null
-package ofc.util
-
-class UndirectedGraph extends GraphBase {
- protected def canonicalEdge(e: Edge) = if (e._1 < e._2) e else (e._2, e._1)
-
- def inEdges(v: Vertex) = outEdges(v).map(_.swap)
-
- def outEdges(v: Vertex) = {
- val info = getInfo(v)
- (info.in.toSeq ++ info.out.toSeq).map(x => (v, x))
- }
-
- def inDegree(v: Vertex) = degree(v)
-
- def outDegree(v: Vertex) = degree(v)
-}