Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Test] Add tests for Tree #2738

Merged
merged 1 commit into from
Feb 3, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 104 additions & 1 deletion test/test_storage_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch

from tensordict import assert_close, TensorDict
from torchrl.data import LazyTensorStorage, ListStorage, MCTSForest
from torchrl.data import LazyTensorStorage, ListStorage, MCTSForest, Tree
from torchrl.data.map import (
BinaryToDecimal,
QueryModule,
Expand Down Expand Up @@ -248,6 +248,109 @@ def test_map_rollout(self):
assert not contains[rollout.shape[-1] :].any()


# Tests Tree independent of MCTSForest
class TestTree:
def dummy_tree(self):
"""Creates a tree with the following node IDs:

0
├── 1
| ├── 3
| └── 4
└── 2
├── 5
└── 6
"""

class IDGen:
def __init__(self):
self.next_id = 0

def __call__(self):
res = self.next_id
self.next_id += 1
return res

gen_id = IDGen()
gen_hash = lambda: hash(torch.rand(1).item())

def dummy_node_stack(obervations):
return TensorDict.lazy_stack(
[
Tree(
node_data=TensorDict({"obs": torch.tensor(obs)}),
hash=gen_hash(),
node_id=gen_id(),
)
for obs in obervations
]
)

tree = dummy_node_stack([0])[0]
tree.subtree = dummy_node_stack([1, 2])
tree.subtree[0].subtree = dummy_node_stack([3, 4])
tree.subtree[1].subtree = dummy_node_stack([6, 7])
return tree

# Checks that when adding nodes to a tree, the `parent` property is set
# correctly
def test_parents(self):
tree = self.dummy_tree()

def check_parents_recursive(tree, parent):
if parent is None:
if tree.parent is not None:
return False
elif tree.parent.node_data is not parent.node_data:
return False

if tree.subtree is not None:
for subtree in tree.subtree:
if not check_parents_recursive(subtree, tree):
return False

return True

assert check_parents_recursive(tree, None)

def test_vertices(self):
tree = self.dummy_tree()
N = 7
assert tree.num_vertices(count_repeat=False) == N
assert tree.num_vertices(count_repeat=True) == N
assert len(tree.vertices(key_type="hash")) == N
assert len(tree.vertices(key_type="id")) == N
assert len(tree.vertices(key_type="path")) == N

for path, vertex in tree.vertices(key_type="path").items():
vertex_check = tree
for i in path:
vertex_check = vertex_check.subtree[i]
assert vertex.node_data is vertex_check.node_data

def test_in(self):
for tree in self.dummy_tree().vertices().values():
for path, subtree in tree.vertices(key_type="path").items():
assert subtree in tree

if len(path) == 0:
assert tree in subtree
else:
assert tree not in subtree

def test_valid_paths(self):
tree = self.dummy_tree()
paths = set(tree.valid_paths())
paths_check = {(0, 0), (0, 1), (1, 0), (1, 1)}
assert paths == paths_check

def test_edges(self):
tree = self.dummy_tree()
edges = set(tree.edges())
edges_check = {(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)}
assert edges == edges_check


class TestMCTSForest:
def dummy_rollouts(self) -> Tuple[TensorDict, ...]:
"""
Expand Down
Loading