case (x : NullStatement) => ()
case (x : Comment) => addLine("!" + x.getValue)
case (x : BlockStatement) => processScope(x)
- case (x : IteratedStatement) => processStatement(x.toConcrete)
+ case (x : IterationContext) => processStatement(x.toConcrete)
case (x : ForLoop) => processForLoop(x)
case (a : AssignStatement) => processAssignment(a)
case (i : IfStatement) => processIf(i)
import ofc.LogicError
import ofc.util.DirectedGraph
-object IteratedStatement {
+object IterationContext {
object Context {
private def priority(context: Context) : Int = {
// This ensures that the nesting ordering is Predicate, DerivedExpression, VariableRange
}
}
-class IteratedStatement extends Statement {
- import IteratedStatement._
+class IterationContext extends Statement {
+ import IterationContext._
var statement = new Comment("Placeholder statement for consumer.")
var ranges : Seq[VariableRange] = Nil
predicates +:= new Predicate(condition)
}
- def merge(statement: IteratedStatement) : IteratedStatement = {
- val result = new IteratedStatement
+ def merge(statement: IterationContext) : IterationContext = {
+ val result = new IterationContext
result.ranges = ranges ++ statement.ranges
result.predicates = predicates ++ statement.predicates
result.expressions = expressions ++ statement.expressions
package ofc.generators.onetep
import ofc.codegen._
-trait Field {
+trait Field extends Operand {
def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : FieldFragment
}
def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : ScalarFragment =
new LocalFragment(left.getFragment(indices), right.getFragment(indices))
+
+ def getIterationInfo : IterationInfo = {
+ var leftInfo = left.getIterationInfo
+ var rightInfo = right.getIterationInfo
+ leftInfo merge rightInfo
+ }
}
def getFragment(indices: Map[NamedIndex, Expression[IntType]]) =
new LocalFragment(this, indices)
+
+ def getIterationInfo : IterationInfo =
+ op.getIterationInfo
}
val totalPts = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("total_pt"+dim)}.toSeq
def getFortranAttributes = Set("type(FFTBOX_INFO)")
}
+
+ object SPAM3 extends StructType {
+ def getFortranAttributes = Set("type(SPAM3)")
+ }
}
--- /dev/null
+package ofc.generators.onetep
+import ofc.codegen._
+
+object OnetepVariables {
+ // parallel_strategy
+ val pub_first_atom_on_node = new NamedUnboundVarSymbol[ArrayType[IntType]]("pub_first_atom_on_node", new ArrayType[IntType](1))
+ val pub_num_atoms_on_node = new NamedUnboundVarSymbol[ArrayType[IntType]]("pub_num_atoms_on_node", new ArrayType[IntType](1))
+}
--- /dev/null
+package ofc.generators.onetep
+import ofc.codegen._
+
+class IterationInfo(val context: IterationContext, val indexMappings: Map[NamedIndex, Expression[IntType]]) {
+ def merge(other: IterationInfo) : IterationInfo =
+ new IterationInfo(context merge other.context, indexMappings ++ other.indexMappings)
+}
+
+trait Operand {
+ def getIterationInfo : IterationInfo
+}
def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : FieldFragment =
new LocalFragment(this, indices)
+
+ def getIterationInfo : IterationInfo = {
+ val context = new IterationContext
+ val numSpheres = basis % OnetepTypes.FunctionBasis.num
+ val sphereIndexExpr = context.addIteration("sphere_index", 1, numSpheres)
+ new IterationInfo(context, Map(getSphereIndex -> sphereIndexExpr))
+ }
}
import ofc.codegen._
class SPAM3(name : String, indices: Seq[NamedIndex]) extends Scalar {
+ val mat = new NamedUnboundVarSymbol[StructType](name, OnetepTypes.SPAM3)
+
class LocalFragment extends ScalarFragment {
def setup(context: GenerationContext) {
}
def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : ScalarFragment =
new LocalFragment
+
+ def getIterationInfo : IterationInfo = {
+ val context = new IterationContext
+ throw new ofc.UnimplementedException("not yet implemented.")
+ }
}
package ofc.generators.onetep
import ofc.codegen._
-trait Scalar {
+trait Scalar extends Operand {
def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : ScalarFragment
}
def getFragment(indices: Map[NamedIndex, Expression[IntType]]) : ScalarFragment =
new LocalFragment(s)
+
+ def getIterationInfo : IterationInfo =
+ new IterationInfo(new IterationContext, Map.empty)
}
def getFragment(indices: Map[NamedIndex, Expression[IntType]]) =
new LocalFragment(this, indices)
+ def getIterationInfo : IterationInfo =
+ op.getIterationInfo merge factor.getIterationInfo
}