From: Francis Russell Date: Tue, 25 Oct 2011 06:51:28 +0000 (+0000) Subject: Initial implementation of form compiler for ONETEP. X-Git-Url: https://git.unchartedbackwaters.co.uk/w/?a=commitdiff_plain;h=ee158ce953e2e551e508b33589a55589f9b1fbf0;p=francis%2Fofc.git Initial implementation of form compiler for ONETEP. --- ee158ce953e2e551e508b33589a55589f9b1fbf0 diff --git a/codegenerator.py b/codegenerator.py new file mode 100644 index 0000000..ccc0c81 --- /dev/null +++ b/codegenerator.py @@ -0,0 +1,74 @@ +class FortranVariable(object): + def __init__(self, name, typeString, size): + self._name = name + self._typeString = typeString + self._size = size + + def getName(self): + return self._name + + def getType(self): + return self._typeString + + def getSize(self): + return self._size + +class CodeGenerator(object): + def __init__(self): + self.code = "" + self.indentLevel = 0 + self.symbols = set() + + def addSymbol(self, sym): + self.symbols.add(sym) + + def getDeclarations(self): + declarations = "" + for symbol in self.symbols: + declarations += self.getIndent() + symbol.getType() + " :: " + symbol.getName() + symbol.getSize() + "\n" + + return declarations + + def getBody(self): + return self.code + + def getCode(self): + return self.getDeclarations() + "\n" + self.getBody() + + def getIndentString(self): + return " " + + def getIndent(self): + return self.getIndentString() * self.indentLevel + + def addLine(self, line): + line += "\n" + if (len(line) >= 120): + line = line.replace(", ", ", &\n" + self.getIndentString() * (self.indentLevel+1)) + + self.code += self.getIndent() + line + + def newInt(self, name): + var = FortranVariable(name, "integer", "") + self.addSymbol(var) + return var + + def newArray(self, name, shape): + var = FortranVariable(name, "real(kind=DP), allocatable", "(" + shape + ")") + self.addSymbol(var) + return var + + def enterDo(self, variable, initial, final): + self.addLine("do %s=%s,%s" % (variable.getName(), initial, final)) + self.indentLevel += 1 + + def exitDo(self): + self.indentLevel -= 1 + self.addLine("enddo") + + def callFunction(self, name, params): + call = "call %s(%s)" % (name, ", ".join(params)) + self.addLine(call) + + def assign(self, var, rhs): + self.addLine(var.getName() + " = " + rhs) diff --git a/common.py b/common.py new file mode 100644 index 0000000..f7fd8a3 --- /dev/null +++ b/common.py @@ -0,0 +1,3 @@ + +class OFCException(Exception): + pass diff --git a/frontend.py b/frontend.py new file mode 100644 index 0000000..e98c98b --- /dev/null +++ b/frontend.py @@ -0,0 +1,176 @@ +from region import * + +class Field(object): + def __mul__(self, value): + return ScaledField(self, value) + +class ScaledField(Field): + def __init__(self, field, value): + self._field = field + self._value = value + +class IndexedAssignment(object): + def __init__(self, obj, index, value): + self._obj = obj + self._index = index + self._value = value + +class IndexedRead(object): + def __init__(self, obj, index): + self._obj = obj + self._index = index + +class SPAM3(object): + def __setitem__(self, index, value): + if not isinstance (index, tuple): + index = (index,) + return IndexedAssignment(self, index, value) + +class FunctionSet(object): + def __init__(self, basisName, dataName): + self._basisName = basisName + self._dataName = dataName + + def __getitem__(self, index): + return FunctionSetElement(self, index) + + def getBasisName(self): + return self._basisName + + def getDataName(self): + return self._dataName + +class FunctionSetElement(Field): + def __init__(self, parent, index): + self._parent = parent + self._index = index + + def numPPDs(self, codeGenerator): + return "%s%%spheres(%s)%%n_ppds_sphere" % (self._parent.getBasisName(), self._index.getName()) + + def getGlobalIndices(self, codeGenerator, ppdIndex): + return "%s%%spheres(%s)%%ppd_list(1, %s)" % (self._parent.getBasisName(), self._index.getName(), ppdIndex.getName()) + + def getContributionLocations(self, codeGenerator, ppdIndex): + return "%s%%spheres(%s)%%ppd_list(2, %s)" % (self._parent.getBasisName(), self._index.getName(), ppdIndex.getName()) + + def getTightBox(self, codeGenerator): + return "%s%%tight_boxes(%s)" % (self._parent.getBasisName(), self._index.getName()) + + def getIndex(self): + return self._index + + def getBoxStart(self, codeGenerator): + tb = self.getTightBox(codeGenerator) + common = tb + "%start_pts" + return (common + "1", common + "2", common + "3") + + def getPPDOffset(self, codeGenerator, ppdIndex): + ppdPos1 = codeGenerator.newInt("ppd_pos1") + ppdPos2 = codeGenerator.newInt("ppd_pos2") + ppdPos3 = codeGenerator.newInt("ppd_pos3") + + ppdPos = (ppdPos1, ppdPos2, ppdPos3) + + codeGenerator.callFunction("basis_find_ppd_in_neighbour", + [ppdPos1.getName(), ppdPos2.getName(), ppdPos3.getName(), + self.getGlobalIndices(codeGenerator, ppdIndex), + self.getContributionLocations(codeGenerator, ppdIndex), + "pub_cell%n_ppds_a1", "pub_cell%n_ppds_a2", "pub_cell%n_ppds_a3"]) + + tb = self.getTightBox(codeGenerator) + + return tuple("(" + (ppdPos[i].getName() + " - " + tb + "%%start_ppds%i" % (i+1)) + ") * " + + self.getPPDSize(codeGenerator)[i] for i in [0, 1, 2]) + + def getPPDLocation(self, codeGenerator, ppdIndex): + boxStart = self.getBoxStart(codeGenerator) + offsets = self.getPPDOffset(codeGenerator, ppdIndex) + return tuple("(" + boxStart[i] + "+" + offsets[i] + ")" for i in [0,1,2]) + + def getPPDSize(self, codeGenerator): + return tuple("pub_cell%%n_pt%i" % x for x in [1, 2, 3]) + + def getPPDRegion(self, codeGenerator, ppdIndex): + return Region(self.getPPDLocation(codeGenerator, ppdIndex), self.getPPDSize(codeGenerator)) + + def getPPDRegionData(self, codeGenerator, ppdIndex): + return RegionDataColMajor(self.getPPDRegion(codeGenerator, ppdIndex), + self._parent.getDataName() + "+" + "pub_cell%n_pts * " + ppdIndex.getName()) + + def create(self, codeGenerator): + return + + def destroy(self, codeGenerator): + return + + def iterate(self, codeGenerator, callback): + ppdIndex = codeGenerator.newInt("fa_ppd") + codeGenerator.enterDo(ppdIndex, 1, self.numPPDs(codeGenerator)) + ppdData = self.getPPDRegionData(codeGenerator, ppdIndex) + callback(ppdData) + codeGenerator.exitDo() + +class FFTBoxCallback(object): + def __init__(self, parent, codeGenerator): + self._parent = parent + self._codeGenerator = codeGenerator + + def __call__(self, region): + self._parent.addRegionData(self._codeGenerator, region) + +class FFTBox(Field): + def __init__(self, operand): + self._operand = operand + + def create(self, codeGenerator): + self._operand.create(codeGenerator) + + size = tuple("pub_fftbox%%total_pt%i" % x for x in [1, 2, 3]) + position = tuple("(%s - %s/3)" % (self._operand.getBoxStart(codeGenerator)[i], size[i]) for i in [0,1,2]) + + region = Region(position, size) + data = codeGenerator.newArray("fftbox", ":,:,:") + + self._regionData = RegionDataColMajor(region, data.getName()) + codeGenerator.callFunction("allocate", + ["%s(%s, %s, %s)" % (self._regionData.getName(), size[0], size[1], size[2]), "stat=ierr"]) + + callback = FFTBoxCallback(self, codeGenerator) + self._operand.iterate(codeGenerator, callback) + self._operand.destroy(codeGenerator) + + def addRegionData(self, codeGenerator, regionData): + self._regionData.generateOperation(codeGenerator, regionData, "+=") + + def destroy(self, codeGenerator): + codeGenerator.callFunction("deallocate", [self._regionData.getName(), "stat=ierr"]) + +class Reciprocal(Field): + def __init__(self, operand): + self.operand = operand + +class InnerProduct(Field): + def __init__(self, left, right): + self._left = left + self._right = right + +class Laplacian(Field): + def __init__(self, operand): + self.operand = operand + +def fftbox(x): + return FFTBox(x) + +def reciprocal(x): + return Reciprocal(x) + +def inner(x, y): + return InnerProduct(x, y) + +def laplacian(x): + return Laplacian(x) + +class Index(object): + def getName(self): + return "ket_index" diff --git a/inputinfo.py b/inputinfo.py new file mode 100644 index 0000000..36291ff --- /dev/null +++ b/inputinfo.py @@ -0,0 +1,17 @@ +import os +from common import OFCException + +def read_ofc_file(filename): + if not os.path.exists(filename): + raise OFCException("File %s does not exist." % filename) + + with open(filename) as inputFile: + fcode = inputFile.read() + + namespace = {} + fcode = "from frontend import *\n" + fcode + exec fcode in namespace + return namespace + +class InputInfo(object): + "hai" diff --git a/integrals_kinetic.ofl b/integrals_kinetic.ofl new file mode 100644 index 0000000..c867db7 --- /dev/null +++ b/integrals_kinetic.ofl @@ -0,0 +1,25 @@ +from codegenerator import * + + +# Parameter information +kinet = SPAM3() +bra = FunctionSet("bra_basis", "bras_on_grid") +ket = FunctionSet("ket_basis", "kets_on_grid") + +# How do we capture that we need the range mapped by alpha and beta to be sparse? +alpha = Index(); +beta = Index(); + +# (alpha, beta) = sparsity(kinet) + +# Computation +kinet[alpha, beta] = inner(bra[alpha], reciprocal(laplacian(reciprocal(fftbox(ket[beta])))*-0.5)) + +# Function declaration +function = ["integrals_kinetic", "bras_on_grid", "bra_basis", "kets_on_grid", "ket_basis"] + + +codeGenerator = CodeGenerator() +fftbox(ket[beta]).create(codeGenerator) + +print codeGenerator.getCode() diff --git a/ofc b/ofc new file mode 100755 index 0000000..879ed4b --- /dev/null +++ b/ofc @@ -0,0 +1,21 @@ +#!/usr/bin/env python + +import sys +from inputinfo import * +from codegenerator import * +from common import OFCException + +def main(argv): + if len(argv) == 0: + print("Missing input file.") + return 1 + + try: + info = read_ofc_file(argv[0]) + #print "Info: %s" % info + except OFCException as e: + print "Exception: %s" % e + return 1 + +if __name__ == "__main__": + sys.exit(main(sys.argv[1:])) diff --git a/region.py b/region.py new file mode 100644 index 0000000..1f39eec --- /dev/null +++ b/region.py @@ -0,0 +1,64 @@ +class Region(object): + def __init__(self, position, size): + self._position = position + self._size = size + + def getPosition(self): + return self._position + + def getSize(self): + return self._size + +class RegionDataColMajor(object): + def __init__(self, region, data): + self._region = region + self._data = data + + def getPosition(self): + return self._region.getPosition() + + def getSize(self): + return self._region.getSize() + + def getName(self): + return self._data + + def accessElement(self, codeGenerator, index): + return self.getName() + "(" + ", ".join(index) + ")" + + def generateOperation(self, codeGenerator, regionData, operation): + + start1 = codeGenerator.newInt("start1") + start2 = codeGenerator.newInt("start2") + start3 = codeGenerator.newInt("start3") + + codeGenerator.assign(start1, "max(%s, %s)" % (self.getPosition()[0], regionData.getPosition()[0])) + codeGenerator.assign(start2, "max(%s, %s)" % (self.getPosition()[1], regionData.getPosition()[1])) + codeGenerator.assign(start3, "max(%s, %s)" % (self.getPosition()[2], regionData.getPosition()[2])) + + end1 = codeGenerator.newInt("end1") + end2 = codeGenerator.newInt("end2") + end3 = codeGenerator.newInt("end3") + + codeGenerator.assign(end1, "min(%s, %s)" % (self.getPosition()[0] + self.getSize()[0], regionData.getPosition()[0] + regionData.getSize()[0])) + codeGenerator.assign(end2, "min(%s, %s)" % (self.getPosition()[1] + self.getSize()[1], regionData.getPosition()[1] + regionData.getSize()[1])) + codeGenerator.assign(end3, "min(%s, %s)" % (self.getPosition()[2] + self.getSize()[2], regionData.getPosition()[2] + regionData.getSize()[2])) + + point1 = codeGenerator.newInt("point1") + point2 = codeGenerator.newInt("point2") + point3 = codeGenerator.newInt("point3") + point = (point1, point2, point3) + + codeGenerator.enterDo(point3, start3.getName(), end3.getName()) + codeGenerator.enterDo(point2, start2.getName(), end2.getName()) + codeGenerator.enterDo(point1, start1.getName(), end1.getName()) + + pointSelf = tuple(point[i].getName() + "-" + self.getPosition()[i] for i in [0,1,2]) + pointOther = tuple(point[i].getName() + "-" + regionData.getPosition()[i] for i in [0,1,2]) + + codeGenerator.addLine("%s = %s" % + (self.accessElement(codeGenerator, pointSelf), regionData.accessElement(codeGenerator, pointOther))) + + codeGenerator.exitDo() + codeGenerator.exitDo() + codeGenerator.exitDo()