]> git.unchartedbackwaters.co.uk Git - francis/lta2.git/commitdiff
Implement flattening of nested blocks. master
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Tue, 9 Jul 2013 15:44:08 +0000 (16:44 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Tue, 9 Jul 2013 15:44:08 +0000 (16:44 +0100)
block.py

index 8dbb0e861ab20e279fd081addcfa4b399dd8e0fb..b49e71233daeb3441153b89e0cda14ecc9359a6e 100644 (file)
--- a/block.py
+++ b/block.py
@@ -44,6 +44,24 @@ class Block:
             sizes.append(slice_sizes)
         return sizes
 
+    def get_expression(self, multi_index):
+        assert(len(multi_index) == len(self.shape))
+        block_sizes = self.block_sizes()
+        block_multi_index = []
+        block_local_index = []
+
+        for dimension in range(0, len(block_sizes)):
+            offset = multi_index[dimension]
+            block_index = 0
+            while(offset - block_sizes[dimension][block_index] >= 0):
+                block_index += 1
+                offset -= block_sizes[dimension][block_index]
+
+            block_multi_index.append(block_index)
+            block_local_index.append(offset)
+
+        return self.blocks[tuple(block_multi_index)].get_expression(tuple(block_local_index))
+
     def split(self, dimension, offset):
         assert(dimension < len(self.shape))
         assert(offset < self.shape[dimension])
@@ -78,7 +96,7 @@ class Block:
 
         return result
 
-    def compute_splits(self):
+    def compute_nested_splits(self):
         splits = []
         for dimension in range(0, len(self.shape)):
             offsets = Set()
@@ -90,7 +108,7 @@ class Block:
                                 (len(self.shape) - dimension - 1))
 
                 for block in self.blocks[slice_index].flat:
-                    for local_offset in block.compute_splits()[dimension]:
+                    for local_offset in block.compute_nested_splits()[dimension]:
                         offsets.add(offset + local_offset)
 
                 offset += self.block_sizes()[dimension][block_index]
@@ -98,8 +116,34 @@ class Block:
 
         return splits
 
+    def compute_nested_sizes(self):
+        splits = self.compute_nested_splits()
+        sizes = []
+        for dimension_splits in splits:
+            dimension_sizes = []
+            for idx, split_offset in enumerate(dimension_splits):
+                if idx == 0:
+                    dimension_sizes.append(dimension_splits[0])
+                else:
+                    dimension_sizes.append(dimension_splits[idx] -
+                                           dimension_splits[idx - 1])
+
+            sizes.append(dimension_sizes)
+
+        return sizes
+
     def flatten(self):
-        raise NotImplementedError("Not yet done.")
+        sizes = self.compute_nested_sizes()
+        splits = self.compute_nested_splits()
+        result = self.from_block_shapes(sizes)
+
+        for idx, index in enumerate(itertools.product(*splits)):
+            # We offset each index by -1 since we are dealing with sizes, and we need a valid index
+            offset_index = [i - 1 for i in index]
+            expression = self.get_expression(offset_index)
+            result.blocks.flat[idx] = DenseBlock(result.blocks.flat[idx].shape, expression)
+
+        return result
 
 
 class DenseBlock:
@@ -110,9 +154,13 @@ class DenseBlock:
     def block_sizes(self):
         return self.shape
 
-    def compute_splits(self):
+    def compute_nested_splits(self):
         return [[size] for size in self.shape]
 
+    def get_expression(self, multi_index):
+        assert(len(multi_index) == len(self.shape))
+        return self.expr
+
     def split(self, dimension, offset):
         assert(dimension < len(self.shape))
         assert(offset < self.shape[dimension])