Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
kurtamohler committed Jan 31, 2025
1 parent 269d0bd commit 9a9272b
Showing 1 changed file with 104 additions and 1 deletion.
105 changes: 104 additions & 1 deletion test/test_storage_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch

from tensordict import assert_close, TensorDict
from torchrl.data import LazyTensorStorage, ListStorage, MCTSForest
from torchrl.data import LazyTensorStorage, ListStorage, MCTSForest, Tree
from torchrl.data.map import (
BinaryToDecimal,
QueryModule,
Expand Down Expand Up @@ -248,6 +248,109 @@ def test_map_rollout(self):
assert not contains[rollout.shape[-1] :].any()


# Tests Tree independent of MCTSForest
class TestTree:
def dummy_tree(self):
"""Creates a tree with the following node IDs:
0
├── 1
| ├── 3
| └── 4
└── 2
├── 5
└── 6
"""

class IDGen:
def __init__(self):
self.next_id = 0

def __call__(self):
res = self.next_id
self.next_id += 1
return res

gen_id = IDGen()
gen_hash = lambda: hash(torch.rand(1).item())

def dummy_node_stack(obervations):
return TensorDict.lazy_stack(
[
Tree(
node_data=TensorDict({"obs": torch.tensor(obs)}),
hash=gen_hash(),
node_id=gen_id(),
)
for obs in obervations
]
)

tree = dummy_node_stack([0])[0]
tree.subtree = dummy_node_stack([1, 2])
tree.subtree[0].subtree = dummy_node_stack([3, 4])
tree.subtree[1].subtree = dummy_node_stack([6, 7])
return tree

# Checks that when adding nodes to a tree, the `parent` property is set
# correctly
def test_parents(self):
tree = self.dummy_tree()

def check_parents_recursive(tree, parent):
if parent is None:
if tree.parent is not None:
return False
elif tree.parent.node_data is not parent.node_data:
return False

if tree.subtree is not None:
for subtree in tree.subtree:
if not check_parents_recursive(subtree, tree):
return False

return True

assert check_parents_recursive(tree, None)

def test_vertices(self):
tree = self.dummy_tree()
N = 7
assert tree.num_vertices(count_repeat=False) == N
assert tree.num_vertices(count_repeat=True) == N
assert len(tree.vertices(key_type="hash")) == N
assert len(tree.vertices(key_type="id")) == N
assert len(tree.vertices(key_type="path")) == N

for path, vertex in tree.vertices(key_type="path").items():
vertex_check = tree
for i in path:
vertex_check = vertex_check.subtree[i]
assert vertex.node_data is vertex_check.node_data

def test_in(self):
for tree in self.dummy_tree().vertices().values():
for path, subtree in tree.vertices(key_type="path").items():
assert subtree in tree

if len(path) == 0:
assert tree in subtree
else:
assert tree not in subtree

def test_valid_paths(self):
tree = self.dummy_tree()
paths = set(tree.valid_paths())
paths_check = {(0, 0), (0, 1), (1, 0), (1, 1)}
assert paths == paths_check

def test_edges(self):
tree = self.dummy_tree()
edges = set(tree.edges())
edges_check = {(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)}
assert edges == edges_check


class TestMCTSForest:
def dummy_rollouts(self) -> Tuple[TensorDict, ...]:
"""
Expand Down

0 comments on commit 9a9272b

Please sign in to comment.