From 9b2cc800ad790801430b28bdff36c0b896bed5c1 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 24 Jan 2024 16:21:03 +0000 Subject: [PATCH] amend --- test/test_functorch.py | 4 ++-- test/test_tensordict.py | 49 ++++++++++++++++++++++++++++------------- 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/test/test_functorch.py b/test/test_functorch.py index 74de2d3b4..68c13be0e 100644 --- a/test/test_functorch.py +++ b/test/test_functorch.py @@ -308,8 +308,8 @@ def test_vmap_write_lazystack( ) td0 = TensorDict({key: [1.0]}, [1]) td1 = TensorDict({key: [2.0]}, [1]) - x = torch.stack([td0, td0.clone()], stack_dim) - y = torch.stack([td1, td1.clone()], stack_dim) + x = LazyStackedTensorDict.lazy_stack([td0, td0.clone()], stack_dim) + y = LazyStackedTensorDict.lazy_stack([td1, td1.clone()], stack_dim) if lock_x: x.lock_() if lock_y: diff --git a/test/test_tensordict.py b/test/test_tensordict.py index c0c80c99c..5f0e2290c 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -339,7 +339,9 @@ def test_dense_stack_tds(self, stack_dim, nested_stack_dim): td_container_clone.apply_(lambda x: x + 1) assert td_lazy.stack_dim == nested_stack_dim - td_stack = LazyStackedTensorDict.lazy_stack([td_container, td_container_clone], dim=stack_dim) + td_stack = LazyStackedTensorDict.lazy_stack( + [td_container, td_container_clone], dim=stack_dim + ) assert td_stack.stack_dim == stack_dim assert isinstance(td_stack, LazyStackedTensorDict) @@ -490,7 +492,9 @@ def test_filling_empty_tensordict(self, device, td_type, update): elif td_type == "squeeze": td = TensorDict({}, batch_size=[16, 1], device=device).squeeze(-1) elif td_type == "stack": - td = LazyStackedTensorDict.lazy_stack([TensorDict({}, [], device=device) for _ in range(16)], 0) + td = LazyStackedTensorDict.lazy_stack( + [TensorDict({}, [], device=device) for _ in range(16)], 0 + ) else: raise NotImplementedError @@ -4057,14 +4061,18 @@ def test_stack_onto(self, td_name, device, tmpdir): else: td1.apply_(lambda x: x.zero_() + 1) - is_lazy = td_name in ( - "sub_td", - "sub_td2", - "permute_td", - "unsqueezed_td", - "squeezed_td", - "td_h5", - ) and not lazy_legacy() + is_lazy = ( + td_name + in ( + "sub_td", + "sub_td2", + "permute_td", + "unsqueezed_td", + "squeezed_td", + "td_h5", + ) + and not lazy_legacy() + ) error_dec = ( pytest.raises(RuntimeError, match="Make it dense") if is_lazy @@ -5706,7 +5714,10 @@ def test_add_batch_dim_cache_nested(self): td = TensorDict( {"a": torch.rand(3, 4, 5), ("b", "c"): torch.rand(3, 4, 5)}, [3, 4, 5] ) - td = TensorDict({"parent": LazyStackedTensorDict.lazy_stack([td, td.clone()], 0)}, [2, 3, 4, 5]) + td = TensorDict( + {"parent": LazyStackedTensorDict.lazy_stack([td, td.clone()], 0)}, + [2, 3, 4, 5], + ) from tensordict.nn import TensorDictModule # noqa from torch import vmap @@ -6065,7 +6076,9 @@ def test_stack_apply(self): }, [3], ) - td = TensorDict({"parent": LazyStackedTensorDict.lazy_stack([td0, td1], 0)}, [2]) + td = TensorDict( + {"parent": LazyStackedTensorDict.lazy_stack([td0, td1], 0)}, [2] + ) td2 = td.clone() tdapply = td.apply(lambda x, y: x + y, td2) assert isinstance(tdapply["parent", "a", "b"], LazyStackedTensorDict) @@ -6190,7 +6203,9 @@ def test_stacked_indexing(self, device, stack_dim): device=device, ) - tds = LazyStackedTensorDict.lazy_stack(list(tensordict.unbind(stack_dim)), stack_dim) + tds = LazyStackedTensorDict.lazy_stack( + list(tensordict.unbind(stack_dim)), stack_dim + ) for item, expected_shape in ( ((2, 2), torch.Size([5])), @@ -6315,7 +6330,9 @@ def test_update_with_lazy(self): }, [3], ) - td = TensorDict({"parent": LazyStackedTensorDict.lazy_stack([td0, td1], 0)}, [2]) + td = TensorDict( + {"parent": LazyStackedTensorDict.lazy_stack([td0, td1], 0)}, [2] + ) td_void = TensorDict( { @@ -7469,7 +7486,9 @@ def test_set(self, non_tensor_data): def test_stack(self, non_tensor_data): assert ( - LazyStackedTensorDict.lazy_stack([non_tensor_data, non_tensor_data], 0).get(("nested", "int")) + LazyStackedTensorDict.lazy_stack([non_tensor_data, non_tensor_data], 0).get( + ("nested", "int") + ) == NonTensorData(3, batch_size=[2]) ).all() assert (