]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Fix so generated code compiles (except for long lines).
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Sun, 8 Apr 2012 02:57:00 +0000 (03:57 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Sun, 8 Apr 2012 02:57:00 +0000 (03:57 +0100)
src/ofc/codegen/Expression.scala
src/ofc/codegen/Type.scala
src/ofc/generators/onetep/OnetepTypes.scala
src/ofc/generators/onetep/PPDFunctionSet.scala

index 7325ce4ebf5679623ff924d30f999e287dbc32e6..fd5b00a570d2199cf70514d99ef8c72ba576f5cc 100644 (file)
@@ -1,4 +1,5 @@
 package ofc.codegen
+import ofc.LogicError
 
 object Expression {
   implicit def fromInt(i: Int) : Expression[IntType] = new IntegerLiteral(i)
@@ -84,6 +85,9 @@ class FieldAccess[T <: Type](expression: Expression[StructType], field: FieldSym
 }
 
 class ArrayRead[E <: Type](expression: Expression[ArrayType[E]], index: Seq[Expression[IntType]]) extends Expression[E] {
+  if (index.size != expression.getType.getRank)
+    throw new LogicError("Array of rank "+expression.getType.getRank+" indexed with rank "+index.size+" index.")
+
   def foreach[U](f: Expression[_] => U) = (index :+ expression).foreach(f)
   def getArrayExpression = expression
   def getIndexExpressions = index
index 4d07aec5d22d5717c6dd6f5e6b02362216e77afb..8f90c1ed5a2d9c8021b7f2b655952615ba92f650 100644 (file)
@@ -22,6 +22,7 @@ final case class ArrayType[ElementType <: Type](rank: Int, eType: ElementType) e
   def this(rank: Int)(implicit builder: TypeBuilder[ElementType]) = this(rank, builder())
   def getElementType = eType
   def getFortranAttributes = eType.getFortranAttributes ++ Set("allocatable", (":"*rank).mkString("dimension(",",",")"))
+  def getRank = rank
 }
 
 final case class PointerType[TargetType <: Type](tType: TargetType) extends Type {
index 9770d65c27c63a6022d57bd5b3b762058f7ddf88..40e6cc76684572815c8a322bb15aa4e0b37b6957 100644 (file)
@@ -8,11 +8,6 @@ object OnetepTypes {
       new FieldSymbol[PointerType[ArrayType[IntType]]]("n_ppds_sphere", fieldType)
     }
 
-    val ppdList = {
-      val fieldType = new PointerType[ArrayType[IntType]](new ArrayType[IntType](1))
-      new FieldSymbol[PointerType[ArrayType[IntType]]]("ppd_list", fieldType)
-    }
-
     val num = new FieldSymbol[IntType]("num")
     
     val tightBoxes = {
@@ -20,9 +15,23 @@ object OnetepTypes {
       new FieldSymbol[PointerType[ArrayType[StructType]]]("tight_boxes", fieldType)
     }
 
+    val spheres = {
+      val fieldType = new PointerType[ArrayType[StructType]](new ArrayType[StructType](1, Sphere))
+      new FieldSymbol[PointerType[ArrayType[StructType]]]("spheres", fieldType)
+    }
+
     def getFortranAttributes = Set("type(FUNC_BASIS)")
   }
 
+  object Sphere extends StructType {
+    val ppdList = {
+      val fieldType = new PointerType[ArrayType[IntType]](new ArrayType[IntType](2))
+      new FieldSymbol[PointerType[ArrayType[IntType]]]("ppd_list", fieldType)
+    }
+
+    def getFortranAttributes = Set("type(SPHERE)")
+  }
+
   object CellInfo extends StructType {
     val ppdWidth = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("n_pt"+dim)}.toSeq
     val numPPDs = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("n_ppds_a"+dim)}.toSeq
@@ -32,8 +41,8 @@ object OnetepTypes {
   object TightBox extends StructType {
     val startPts = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("start_pts"+dim)}.toSeq
     val finishPts = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("finish_pts"+dim)}.toSeq
-    val startPPD = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("start_ppd"+dim)}.toSeq
-    val finishPPD = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("finish_ppd"+dim)}.toSeq
+    val startPPD = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("start_ppds"+dim)}.toSeq
+    val finishPPD = {for (dim <- 1 to 3) yield new FieldSymbol[IntType]("finish_ppds"+dim)}.toSeq
     def getFortranAttributes = Set("type(FUNCTION_TIGHT_BOX)")
   }
 }
index 827a02e22b6b438d30ae28e6ab82a9356a6eeb1b..ce91b26b41176474193b8b5f2c06e5abe85f79f0 100644 (file)
@@ -17,7 +17,7 @@ class PPDFunctionSet(val basisName: String, dataName: String) extends FunctionSe
     val sphereIndex = producer.addIteration("sphere_index", 1, numSpheres)
     val numPPDs = (~(basis % FunctionBasis.numPPDsInSphere)).readAt(sphereIndex)
     val ppdIndex = producer.addIteration("ppd_index", 1, numPPDs)
-    val ppdGlobalCount = (~(basis % FunctionBasis.ppdList)).readAt(ppdIndex, 1) - 1
+    val ppdGlobalCount = (~((~(basis % FunctionBasis.spheres)).readAt(sphereIndex) % Sphere.ppdList)).readAt(ppdIndex, 1) - 1
 
     // The integer co-ordinates of the PPD (0-based)
     val a3pos = producer.addExpression("ppd_pos1", ppdGlobalCount / (cellWidthInPPDs(0)*cellWidthInPPDs(1)))