
Commit

amend
vmoens committed Jan 24, 2024
1 parent 6af7b74 commit 3b47582
Showing 1 changed file with 28 additions and 28 deletions.
56 changes: 28 additions & 28 deletions test/test_tensordict.py
@@ -5573,7 +5573,7 @@ def dense_stack_tds_v2(self, td_list, stack_dim: int) -> TensorDictBase:
out = td_list[0].unsqueeze(stack_dim).expand(shape).clone()

data_ptr_set_before = {val.data_ptr() for val in decompose(out)}
- res = torch.stack(td_list, dim=stack_dim, out=out)
+ res = LazyStackedTensorDict.lazy_stack(td_list, dim=stack_dim, out=out)
data_ptr_set_after = {val.data_ptr() for val in decompose(out)}
assert data_ptr_set_before == data_ptr_set_after

@@ -5619,7 +5619,7 @@ def nested_lazy_het_td(batch_size):
td[f"individual_{i}_td"] = td.clone()
td["shared_td"] = td.clone()

- td_stack = torch.stack(td_list, dim=0)
+ td_stack = LazyStackedTensorDict.lazy_stack(td_list, dim=0)
obs = TensorDict(
{"lazy": td_stack, "dense": torch.zeros(3, 3, 2)},
[],
@@ -5647,7 +5647,7 @@ def test_add_batch_dim_cache(self):
td = TensorDict(
{"a": torch.rand(3, 4, 5), ("b", "c"): torch.rand(3, 4, 5)}, [3, 4, 5]
)
- td = torch.stack([td, td.clone()], 0)
+ td = LazyStackedTensorDict.lazy_stack([td, td.clone()], 0)
from tensordict.nn import TensorDictModule # noqa
from torch import vmap

@@ -5665,7 +5665,7 @@ def test_add_batch_dim_cache_nested(self):
td = TensorDict(
{"a": torch.rand(3, 4, 5), ("b", "c"): torch.rand(3, 4, 5)}, [3, 4, 5]
)
- td = TensorDict({"parent": torch.stack([td, td.clone()], 0)}, [2, 3, 4, 5])
+ td = TensorDict({"parent": LazyStackedTensorDict.lazy_stack([td, td.clone()], 0)}, [2, 3, 4, 5])
from tensordict.nn import TensorDictModule # noqa
from torch import vmap

@@ -5682,7 +5682,7 @@ def test_add_batch_dim_cache_nested(self):
def test_all_keys(self):
td = TensorDict({"a": torch.zeros(1)}, [])
td2 = TensorDict({"a": torch.zeros(2)}, [])
- stack = torch.stack([td, td2])
+ stack = LazyStackedTensorDict.lazy_stack([td, td2])
assert set(stack.keys(True, True)) == {"a"}

def test_best_intention_stack(self):
@@ -5853,7 +5853,7 @@ def test_lazy_stack_stack(self, batch_size):
assert obs["lazy"].shape == (*batch_size, 3)
assert isinstance(obs["lazy"][..., 0], TensorDict) # succeeds

- obs_stack = torch.stack([obs])
+ obs_stack = LazyStackedTensorDict.lazy_stack([obs])

assert (
isinstance(obs_stack, LazyStackedTensorDict) and obs_stack.stack_dim == 0
@@ -5866,7 +5866,7 @@ def test_lazy_stack_stack(self, batch_size):
assert obs_stack["lazy"][0] is obs["lazy"]

obs2 = obs.clone()
- obs_stack = torch.stack([obs, obs2])
+ obs_stack = LazyStackedTensorDict.lazy_stack([obs, obs2])

assert (
isinstance(obs_stack, LazyStackedTensorDict) and obs_stack.stack_dim == 0
@@ -5882,7 +5882,7 @@ def test_lazy_stack_stack(self, batch_size):
@pytest.mark.parametrize("device", get_available_devices())
def test_lazy_stacked_append(self, dim, device):
td = TensorDict({"a": torch.zeros(4)}, [4], device=device)
- lstd = torch.stack([td] * 2, dim=dim)
+ lstd = LazyStackedTensorDict.lazy_stack([td] * 2, dim=dim)

lstd.append(
TensorDict(
@@ -5921,7 +5921,7 @@ def test_lazy_stacked_contains(self):
td = TensorDict(
{"a": TensorDict({"b": torch.rand(1, 2)}, [1, 2]), "c": torch.rand(1)}, [1]
)
- lstd = torch.stack([td, td, td])
+ lstd = LazyStackedTensorDict.lazy_stack([td, td, td])

assert td in lstd
assert td.clone() not in lstd
@@ -5937,7 +5937,7 @@ def test_lazy_stacked_contains(self):
@pytest.mark.parametrize("device", get_available_devices())
def test_lazy_stacked_insert(self, dim, index, device):
td = TensorDict({"a": torch.zeros(4)}, [4], device=device)
- lstd = torch.stack([td] * 2, dim=dim)
+ lstd = LazyStackedTensorDict.lazy_stack([td] * 2, dim=dim)

lstd.insert(
index,
@@ -5997,7 +5997,7 @@ def test_setitem_hetero(self, batch_size, stack_dim):
def test_stack(self, device):
torch.manual_seed(1)
tds_list = [TensorDict(source={}, batch_size=(4, 5)) for _ in range(3)]
- tds = stack_td(tds_list, 0, contiguous=False)
+ tds = LazyStackedTensorDict.lazy_stack(tds_list, 0)
assert tds[0] is tds_list[0]

td = TensorDict(
@@ -6024,7 +6024,7 @@ def test_stack_apply(self):
},
[3],
)
- td = TensorDict({"parent": torch.stack([td0, td1], 0)}, [2])
+ td = TensorDict({"parent": LazyStackedTensorDict.lazy_stack([td0, td1], 0)}, [2])
td2 = td.clone()
tdapply = td.apply(lambda x, y: x + y, td2)
assert isinstance(tdapply["parent", "a", "b"], LazyStackedTensorDict)
@@ -6039,7 +6039,7 @@ def test_stack_hetero(self, batch_size):
obs2 = obs.clone()
obs2.apply_(lambda x: x + 1)

- obs_stack = torch.stack([obs, obs2])
+ obs_stack = LazyStackedTensorDict.lazy_stack([obs, obs2])
obs_stack_resolved = self.dense_stack_tds_v2([obs, obs2], stack_dim=0)

assert isinstance(obs_stack, LazyStackedTensorDict) and obs_stack.stack_dim == 0
@@ -6129,7 +6129,7 @@ def test_stack_keys(self):
@pytest.mark.parametrize("unsqueeze_dim", [0, 1, -1, -2])
def test_stack_unsqueeze(self, unsqueeze_dim):
td = TensorDict({("a", "b"): torch.ones(3, 4, 5)}, [3, 4])
- td_stack = torch.stack(td.unbind(1), 1)
+ td_stack = LazyStackedTensorDict.lazy_stack(td.unbind(1), 1)
td_unsqueeze = td.unsqueeze(unsqueeze_dim)
td_stack_unsqueeze = td_stack.unsqueeze(unsqueeze_dim)
assert isinstance(td_stack_unsqueeze, LazyStackedTensorDict)
@@ -6140,7 +6140,7 @@ def test_stack_unsqueeze(self, unsqueeze_dim):
def test_stack_update_heter_stacked_td(self, stack_dim):
td1 = TensorDict({"a": torch.randn(3, 4)}, [3])
td2 = TensorDict({"a": torch.randn(3, 5)}, [3])
- td_a = torch.stack([td1, td2], stack_dim)
+ td_a = LazyStackedTensorDict.lazy_stack([td1, td2], stack_dim)
td_b = td_a.clone()
td_a.update(td_b)
with pytest.raises(
@@ -6164,7 +6164,7 @@ def test_stacked_indexing(self, device, stack_dim):
device=device,
)

- tds = torch.stack(list(tensordict.unbind(stack_dim)), stack_dim)
+ tds = LazyStackedTensorDict.lazy_stack(list(tensordict.unbind(stack_dim)), stack_dim)

for item, expected_shape in (
((2, 2), torch.Size([5])),
@@ -6240,7 +6240,7 @@ def test_stacked_td(self, stack_dim, device):
assert (std5.contiguous() == sub_td.contiguous().unbind(1)[0]).all()

def test_stacked_td_nested_keys(self):
- td = torch.stack(
+ td = LazyStackedTensorDict.lazy_stack(
[
TensorDict({"a": {"b": {"d": [1]}, "c": [2]}}, []),
TensorDict({"a": {"b": {"d": [1]}, "d": [2]}}, []),
@@ -6269,7 +6269,7 @@ def test_unbind_lazystack(self):
},
[3, 4],
)
- td = torch.stack([td0, td0, td0], 1)
+ td = LazyStackedTensorDict.lazy_stack([td0, td0, td0], 1)

assert all(_td is td0 for _td in td.unbind(1))

@@ -6289,7 +6289,7 @@ def test_update_with_lazy(self):
},
[3],
)
- td = TensorDict({"parent": torch.stack([td0, td1], 0)}, [2])
+ td = TensorDict({"parent": LazyStackedTensorDict.lazy_stack([td0, td1], 0)}, [2])

td_void = TensorDict(
{
@@ -6725,11 +6725,11 @@ def test_squeeze_td(self):

def test_stack(self):
td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"])
- tds = torch.stack([td, td], 0)
+ tds = LazyStackedTensorDict.lazy_stack([td, td], 0)
assert tds.names == [None, "a", "b", "c", "d"]
- tds = torch.stack([td, td], -1)
+ tds = LazyStackedTensorDict.lazy_stack([td, td], -1)
assert tds.names == ["a", "b", "c", "d", None]
- tds = torch.stack([td, td], 2)
+ tds = LazyStackedTensorDict.lazy_stack([td, td], 2)
tds.names = list("mnopq")
assert tds.names == list("mnopq")
assert td.names == ["m", "n", "p", "q"]
@@ -6738,7 +6738,7 @@ def test_stack_assign(self):
td = TensorDict(
{"": TensorDict({}, [3, 4], names=["c", "d"])}, [3], names=["c"]
)
- tds = torch.stack([td, td], -1)
+ tds = LazyStackedTensorDict.lazy_stack([td, td], -1)
assert tds.names == ["c", None]
assert tds[""].names == ["c", None, "d"]
with pytest.raises(ValueError):
@@ -6848,7 +6848,7 @@ def check_weakref_count(weakref_list, expected):
def test_lock_stack(self):
td0 = TensorDict({("a", "b", "c", "d"): 1.0}, [])
td1 = td0.clone()
- td = torch.stack([td0, td1])
+ td = LazyStackedTensorDict.lazy_stack([td0, td1])
td = td.lock_()
a = td["a"]
b = td["a", "b"]
@@ -6983,7 +6983,7 @@ def test_nested_lock_erros(self):
def test_stack_cache_lock(self):
td0 = TensorDict({("a", "b", "c", "d"): 1.0}, [])
td1 = td0.clone()
- td = torch.stack([td0, td1])
+ td = LazyStackedTensorDict.lazy_stack([td0, td1])
assert td._is_locked is None
td = td.lock_()
assert td._is_locked
Expand Down Expand Up @@ -7019,7 +7019,7 @@ def test_stack_cache_lock(self):
def test_stacked_append_and_insert(self):
td0 = TensorDict({("a", "b", "c", "d"): 1.0}, [])
td1 = td0.clone()
- td = torch.stack([td0, td1])
+ td = LazyStackedTensorDict.lazy_stack([td0, td1])
td.lock_()
with pytest.raises(RuntimeError, match=re.escape(_LOCK_ERROR)):
td.insert(0, td0)
@@ -7392,7 +7392,7 @@ def test_map_unbind(self):
mp.set_start_method("spawn")
td0 = TensorDict({"0": 0}, [])
td1 = TensorDict({"1": 1}, [])
- td = torch.stack([td0, td1], 0)
+ td = LazyStackedTensorDict.lazy_stack([td0, td1], 0)
td_out = td.map(self._set_2, chunksize=0, num_workers=4)
assert td_out[0]["0"] == 0
assert td_out[1]["1"] == 1
@@ -7443,7 +7443,7 @@ def test_set(self, non_tensor_data):

def test_stack(self, non_tensor_data):
assert (
- torch.stack([non_tensor_data, non_tensor_data], 0).get(("nested", "int"))
+ LazyStackedTensorDict.lazy_stack([non_tensor_data, non_tensor_data], 0).get(("nested", "int"))
== NonTensorData(3, batch_size=[2])
).all()
assert (

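Every change in this commit replaces a torch.stack call over tensordicts with LazyStackedTensorDict.lazy_stack. A minimal sketch of the stacking behavior these tests pin down, assuming the tensordict API at this commit (td0 and td1 are illustrative names, not taken from the diff):

import torch
from tensordict import LazyStackedTensorDict, TensorDict

td0 = TensorDict({"a": torch.zeros(3)}, [3])
td1 = TensorDict({"a": torch.ones(3)}, [3])

# lazy_stack holds the inputs by reference along the new stack
# dimension instead of copying them into a dense tensordict.
stacked = LazyStackedTensorDict.lazy_stack([td0, td1], dim=0)
assert isinstance(stacked, LazyStackedTensorDict)
assert stacked[0] is td0  # mirrors `tds[0] is tds_list[0]` in test_stack

# With a preallocated dense `out`, the data is written in place instead,
# which dense_stack_tds_v2 above verifies by comparing data_ptr() sets.
out = td0.unsqueeze(0).expand(2, 3).clone()
res = LazyStackedTensorDict.lazy_stack([td0, td1], dim=0, out=out)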