--- /dev/null
+import sympy
+import numpy
+import itertools
+import copy
+from sets import Set
+
+
+class Block:
+ def __init__(self, shape):
+ self.shape = shape
+ self.indices = map(sympy.Dummy,
+ map(chr, range(ord('i'), ord('z'))[:len(shape)]))
+ self.blocks = numpy.empty([1] * len(shape), dtype=numpy.object)
+ self.blocks.flat[0] = DenseBlock(shape, 0)
+
+ def __setitem__(self, key, value):
+ print str(key)
+
+ @staticmethod
+ def from_block_shapes(block_shapes):
+ block = Block(map(sum, block_shapes))
+ block.blocks = numpy.empty(map(len, block_shapes), dtype=numpy.object)
+
+ for idx, block_shape in enumerate(itertools.product(*block_shapes)):
+ block.blocks.flat[idx] = DenseBlock(block_shape, 0)
+
+ return block
+
+ def block_shape(self):
+ """Return the shape of the subblocks of this block"""
+ return self.blocks.shape
+
+ def block_sizes(self):
+ """Return the sizes of the subblocks of this block"""
+ sizes = []
+ for dimension in range(0, len(self.shape)):
+ slice_sizes = []
+ slice_index = ([0] * dimension +
+ [slice(None, None, None)] +
+ [0] * (len(self.shape) - dimension - 1))
+ for block in self.blocks[slice_index]:
+ slice_sizes.append(block.shape[dimension])
+
+ sizes.append(slice_sizes)
+ return sizes
+
+ def split(self, dimension, offset):
+ assert(dimension < len(self.shape))
+ assert(offset < self.shape[dimension])
+
+ result = copy.deepcopy(self)
+
+ dimension_sizes = self.block_sizes()[dimension]
+
+ block_offset = 0
+ local_offset = offset
+ block_index = 0
+ while(local_offset - dimension_sizes[block_index] >= 0):
+ block_index += 1
+ block_offset += dimension_sizes[block_index]
+ local_offset -= dimension_sizes[block_index]
+
+ if (local_offset == 0):
+ return self
+
+ slice_index = ([slice(None, None, None)] * dimension +
+ [block_index] +
+ [slice(None, None, None)] *
+ (len(self.shape) - dimension - 1))
+
+ split_blocks = []
+ for block in self.blocks[slice_index].flat:
+ split_blocks.append(block.split(dimension, local_offset))
+
+ result.blocks = numpy.delete(result.blocks, block_index, dimension)
+ result.blocks = numpy.insert(result.blocks, block_index, split_blocks,
+ dimension)
+
+ return result
+
+ def compute_splits(self):
+ splits = []
+ for dimension in range(0, len(self.shape)):
+ offsets = Set()
+ offset = 0
+ for block_index in range(0, self.block_shape()[dimension]):
+ slice_index = ([slice(None, None, None)] * dimension +
+ [block_index] +
+ [slice(None, None, None)] *
+ (len(self.shape) - dimension - 1))
+
+ for block in self.blocks[slice_index].flat:
+ for local_offset in block.compute_splits()[dimension]:
+ offsets.add(offset + local_offset)
+
+ offset += self.block_sizes()[dimension][block_index]
+ splits.append(sorted(list(offsets)))
+
+ return splits
+
+ def flatten(self):
+ raise NotImplementedError("Not yet done.")
+
+
+class DenseBlock:
+ def __init__(self, shape, expr):
+ self.expr = expr
+ self.shape = shape
+
+ def block_sizes(self):
+ return self.shape
+
+ def compute_splits(self):
+ return [[size] for size in self.shape]
+
+ def split(self, dimension, offset):
+ assert(dimension < len(self.shape))
+ assert(offset < self.shape[dimension])
+
+ before_shape = (self.shape[0:dimension] + [offset] +
+ self.shape[dimension + 1:])
+ after_shape = (self.shape[0:dimension] + [self.shape[dimension] - offset] +
+ self.shape[dimension + 1:])
+
+ before = DenseBlock(before_shape, self.expr)
+ after = DenseBlock(after_shape, self.expr)
+
+ result = Block(self.shape)
+ result.blocks = numpy.empty([1] * dimension + [2] + [1] * (len(self.shape) - dimension - 1), dtype=numpy.object)
+ result.blocks.flat[0] = before
+ result.blocks.flat[1] = after
+
+ return result