Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 24, 2024
1 parent ccd3ddc commit 9b2cc80
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 17 deletions.
4 changes: 2 additions & 2 deletions test/test_functorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,8 @@ def test_vmap_write_lazystack(
)
td0 = TensorDict({key: [1.0]}, [1])
td1 = TensorDict({key: [2.0]}, [1])
x = torch.stack([td0, td0.clone()], stack_dim)
y = torch.stack([td1, td1.clone()], stack_dim)
x = LazyStackedTensorDict.lazy_stack([td0, td0.clone()], stack_dim)
y = LazyStackedTensorDict.lazy_stack([td1, td1.clone()], stack_dim)
if lock_x:
x.lock_()
if lock_y:
Expand Down
49 changes: 34 additions & 15 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,9 @@ def test_dense_stack_tds(self, stack_dim, nested_stack_dim):
td_container_clone.apply_(lambda x: x + 1)

assert td_lazy.stack_dim == nested_stack_dim
td_stack = LazyStackedTensorDict.lazy_stack([td_container, td_container_clone], dim=stack_dim)
td_stack = LazyStackedTensorDict.lazy_stack(
[td_container, td_container_clone], dim=stack_dim
)
assert td_stack.stack_dim == stack_dim

assert isinstance(td_stack, LazyStackedTensorDict)
Expand Down Expand Up @@ -490,7 +492,9 @@ def test_filling_empty_tensordict(self, device, td_type, update):
elif td_type == "squeeze":
td = TensorDict({}, batch_size=[16, 1], device=device).squeeze(-1)
elif td_type == "stack":
td = LazyStackedTensorDict.lazy_stack([TensorDict({}, [], device=device) for _ in range(16)], 0)
td = LazyStackedTensorDict.lazy_stack(
[TensorDict({}, [], device=device) for _ in range(16)], 0
)
else:
raise NotImplementedError

Expand Down Expand Up @@ -4057,14 +4061,18 @@ def test_stack_onto(self, td_name, device, tmpdir):
else:
td1.apply_(lambda x: x.zero_() + 1)

is_lazy = td_name in (
"sub_td",
"sub_td2",
"permute_td",
"unsqueezed_td",
"squeezed_td",
"td_h5",
) and not lazy_legacy()
is_lazy = (
td_name
in (
"sub_td",
"sub_td2",
"permute_td",
"unsqueezed_td",
"squeezed_td",
"td_h5",
)
and not lazy_legacy()
)
error_dec = (
pytest.raises(RuntimeError, match="Make it dense")
if is_lazy
Expand Down Expand Up @@ -5706,7 +5714,10 @@ def test_add_batch_dim_cache_nested(self):
td = TensorDict(
{"a": torch.rand(3, 4, 5), ("b", "c"): torch.rand(3, 4, 5)}, [3, 4, 5]
)
td = TensorDict({"parent": LazyStackedTensorDict.lazy_stack([td, td.clone()], 0)}, [2, 3, 4, 5])
td = TensorDict(
{"parent": LazyStackedTensorDict.lazy_stack([td, td.clone()], 0)},
[2, 3, 4, 5],
)
from tensordict.nn import TensorDictModule # noqa
from torch import vmap

Expand Down Expand Up @@ -6065,7 +6076,9 @@ def test_stack_apply(self):
},
[3],
)
td = TensorDict({"parent": LazyStackedTensorDict.lazy_stack([td0, td1], 0)}, [2])
td = TensorDict(
{"parent": LazyStackedTensorDict.lazy_stack([td0, td1], 0)}, [2]
)
td2 = td.clone()
tdapply = td.apply(lambda x, y: x + y, td2)
assert isinstance(tdapply["parent", "a", "b"], LazyStackedTensorDict)
Expand Down Expand Up @@ -6190,7 +6203,9 @@ def test_stacked_indexing(self, device, stack_dim):
device=device,
)

tds = LazyStackedTensorDict.lazy_stack(list(tensordict.unbind(stack_dim)), stack_dim)
tds = LazyStackedTensorDict.lazy_stack(
list(tensordict.unbind(stack_dim)), stack_dim
)

for item, expected_shape in (
((2, 2), torch.Size([5])),
Expand Down Expand Up @@ -6315,7 +6330,9 @@ def test_update_with_lazy(self):
},
[3],
)
td = TensorDict({"parent": LazyStackedTensorDict.lazy_stack([td0, td1], 0)}, [2])
td = TensorDict(
{"parent": LazyStackedTensorDict.lazy_stack([td0, td1], 0)}, [2]
)

td_void = TensorDict(
{
Expand Down Expand Up @@ -7469,7 +7486,9 @@ def test_set(self, non_tensor_data):

def test_stack(self, non_tensor_data):
assert (
LazyStackedTensorDict.lazy_stack([non_tensor_data, non_tensor_data], 0).get(("nested", "int"))
LazyStackedTensorDict.lazy_stack([non_tensor_data, non_tensor_data], 0).get(
("nested", "int")
)
== NonTensorData(3, batch_size=[2])
).all()
assert (
Expand Down

0 comments on commit 9b2cc80

Please sign in to comment.