From 6fb2b6892c9b427ece935c72601f042028ab509f Mon Sep 17 00:00:00 2001 From: Francis Russell Date: Tue, 9 Jul 2013 16:44:08 +0100 Subject: [PATCH] Implement flattening of nested blocks. --- block.py | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 52 insertions(+), 4 deletions(-) diff --git a/block.py b/block.py index 8dbb0e8..b49e712 100644 --- a/block.py +++ b/block.py @@ -44,6 +44,24 @@ class Block: sizes.append(slice_sizes) return sizes + def get_expression(self, multi_index): + assert(len(multi_index) == len(self.shape)) + block_sizes = self.block_sizes() + block_multi_index = [] + block_local_index = [] + + for dimension in range(0, len(block_sizes)): + offset = multi_index[dimension] + block_index = 0 + while(offset - block_sizes[dimension][block_index] >= 0): + block_index += 1 + offset -= block_sizes[dimension][block_index] + + block_multi_index.append(block_index) + block_local_index.append(offset) + + return self.blocks[tuple(block_multi_index)].get_expression(tuple(block_local_index)) + def split(self, dimension, offset): assert(dimension < len(self.shape)) assert(offset < self.shape[dimension]) @@ -78,7 +96,7 @@ class Block: return result - def compute_splits(self): + def compute_nested_splits(self): splits = [] for dimension in range(0, len(self.shape)): offsets = Set() @@ -90,7 +108,7 @@ class Block: (len(self.shape) - dimension - 1)) for block in self.blocks[slice_index].flat: - for local_offset in block.compute_splits()[dimension]: + for local_offset in block.compute_nested_splits()[dimension]: offsets.add(offset + local_offset) offset += self.block_sizes()[dimension][block_index] @@ -98,8 +116,34 @@ class Block: return splits + def compute_nested_sizes(self): + splits = self.compute_nested_splits() + sizes = [] + for dimension_splits in splits: + dimension_sizes = [] + for idx, split_offset in enumerate(dimension_splits): + if idx == 0: + dimension_sizes.append(dimension_splits[0]) + else: + dimension_sizes.append(dimension_splits[idx] - + dimension_splits[idx - 1]) + + sizes.append(dimension_sizes) + + return sizes + def flatten(self): - raise NotImplementedError("Not yet done.") + sizes = self.compute_nested_sizes() + splits = self.compute_nested_splits() + result = self.from_block_shapes(sizes) + + for idx, index in enumerate(itertools.product(*splits)): + # We offset each index by -1 since we are dealing with sizes, and we need a valid index + offset_index = [i - 1 for i in index] + expression = self.get_expression(offset_index) + result.blocks.flat[idx] = DenseBlock(result.blocks.flat[idx].shape, expression) + + return result class DenseBlock: @@ -110,9 +154,13 @@ class DenseBlock: def block_sizes(self): return self.shape - def compute_splits(self): + def compute_nested_splits(self): return [[size] for size in self.shape] + def get_expression(self, multi_index): + assert(len(multi_index) == len(self.shape)) + return self.expr + def split(self, dimension, offset): assert(dimension < len(self.shape)) assert(offset < self.shape[dimension]) -- 2.47.3