From 9a9272bcffbb57ec0651a082243cdf2452361b00 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Fri, 31 Jan 2025 15:41:17 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- test/test_storage_map.py | 105 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 104 insertions(+), 1 deletion(-) diff --git a/test/test_storage_map.py b/test/test_storage_map.py index 90b16db00d3..9866d893ff9 100644 --- a/test/test_storage_map.py +++ b/test/test_storage_map.py @@ -12,7 +12,7 @@ import torch from tensordict import assert_close, TensorDict -from torchrl.data import LazyTensorStorage, ListStorage, MCTSForest +from torchrl.data import LazyTensorStorage, ListStorage, MCTSForest, Tree from torchrl.data.map import ( BinaryToDecimal, QueryModule, @@ -248,6 +248,109 @@ def test_map_rollout(self): assert not contains[rollout.shape[-1] :].any() +# Tests Tree independent of MCTSForest +class TestTree: + def dummy_tree(self): + """Creates a tree with the following node IDs: + + 0 + ├── 1 + | ├── 3 + | └── 4 + └── 2 + ├── 5 + └── 6 + """ + + class IDGen: + def __init__(self): + self.next_id = 0 + + def __call__(self): + res = self.next_id + self.next_id += 1 + return res + + gen_id = IDGen() + gen_hash = lambda: hash(torch.rand(1).item()) + + def dummy_node_stack(obervations): + return TensorDict.lazy_stack( + [ + Tree( + node_data=TensorDict({"obs": torch.tensor(obs)}), + hash=gen_hash(), + node_id=gen_id(), + ) + for obs in obervations + ] + ) + + tree = dummy_node_stack([0])[0] + tree.subtree = dummy_node_stack([1, 2]) + tree.subtree[0].subtree = dummy_node_stack([3, 4]) + tree.subtree[1].subtree = dummy_node_stack([6, 7]) + return tree + + # Checks that when adding nodes to a tree, the `parent` property is set + # correctly + def test_parents(self): + tree = self.dummy_tree() + + def check_parents_recursive(tree, parent): + if parent is None: + if tree.parent is not None: + return False + elif tree.parent.node_data is not parent.node_data: + return False + + if tree.subtree is not None: + for subtree in tree.subtree: + if not check_parents_recursive(subtree, tree): + return False + + return True + + assert check_parents_recursive(tree, None) + + def test_vertices(self): + tree = self.dummy_tree() + N = 7 + assert tree.num_vertices(count_repeat=False) == N + assert tree.num_vertices(count_repeat=True) == N + assert len(tree.vertices(key_type="hash")) == N + assert len(tree.vertices(key_type="id")) == N + assert len(tree.vertices(key_type="path")) == N + + for path, vertex in tree.vertices(key_type="path").items(): + vertex_check = tree + for i in path: + vertex_check = vertex_check.subtree[i] + assert vertex.node_data is vertex_check.node_data + + def test_in(self): + for tree in self.dummy_tree().vertices().values(): + for path, subtree in tree.vertices(key_type="path").items(): + assert subtree in tree + + if len(path) == 0: + assert tree in subtree + else: + assert tree not in subtree + + def test_valid_paths(self): + tree = self.dummy_tree() + paths = set(tree.valid_paths()) + paths_check = {(0, 0), (0, 1), (1, 0), (1, 1)} + assert paths == paths_check + + def test_edges(self): + tree = self.dummy_tree() + edges = set(tree.edges()) + edges_check = {(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)} + assert edges == edges_check + + class TestMCTSForest: def dummy_rollouts(self) -> Tuple[TensorDict, ...]: """