sizes.append(slice_sizes)
return sizes
+ def get_expression(self, multi_index):
+ assert(len(multi_index) == len(self.shape))
+ block_sizes = self.block_sizes()
+ block_multi_index = []
+ block_local_index = []
+
+ for dimension in range(0, len(block_sizes)):
+ offset = multi_index[dimension]
+ block_index = 0
+ while(offset - block_sizes[dimension][block_index] >= 0):
+ block_index += 1
+ offset -= block_sizes[dimension][block_index]
+
+ block_multi_index.append(block_index)
+ block_local_index.append(offset)
+
+ return self.blocks[tuple(block_multi_index)].get_expression(tuple(block_local_index))
+
def split(self, dimension, offset):
assert(dimension < len(self.shape))
assert(offset < self.shape[dimension])
return result
- def compute_splits(self):
+ def compute_nested_splits(self):
splits = []
for dimension in range(0, len(self.shape)):
offsets = Set()
(len(self.shape) - dimension - 1))
for block in self.blocks[slice_index].flat:
- for local_offset in block.compute_splits()[dimension]:
+ for local_offset in block.compute_nested_splits()[dimension]:
offsets.add(offset + local_offset)
offset += self.block_sizes()[dimension][block_index]
return splits
+ def compute_nested_sizes(self):
+ splits = self.compute_nested_splits()
+ sizes = []
+ for dimension_splits in splits:
+ dimension_sizes = []
+ for idx, split_offset in enumerate(dimension_splits):
+ if idx == 0:
+ dimension_sizes.append(dimension_splits[0])
+ else:
+ dimension_sizes.append(dimension_splits[idx] -
+ dimension_splits[idx - 1])
+
+ sizes.append(dimension_sizes)
+
+ return sizes
+
def flatten(self):
- raise NotImplementedError("Not yet done.")
+ sizes = self.compute_nested_sizes()
+ splits = self.compute_nested_splits()
+ result = self.from_block_shapes(sizes)
+
+ for idx, index in enumerate(itertools.product(*splits)):
+ # We offset each index by -1 since we are dealing with sizes, and we need a valid index
+ offset_index = [i - 1 for i in index]
+ expression = self.get_expression(offset_index)
+ result.blocks.flat[idx] = DenseBlock(result.blocks.flat[idx].shape, expression)
+
+ return result
class DenseBlock:
def block_sizes(self):
return self.shape
- def compute_splits(self):
+ def compute_nested_splits(self):
return [[size] for size in self.shape]
+ def get_expression(self, multi_index):
+ assert(len(multi_index) == len(self.shape))
+ return self.expr
+
def split(self, dimension, offset):
assert(dimension < len(self.shape))
assert(offset < self.shape[dimension])