From 5a08121453355253ca5269e80dfd4aefd89b0a0e Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 23 Nov 2023 16:43:42 +0000 Subject: [PATCH] fix --- benchmarks/nn/functional_benchmarks_test.py | 2 +- tensordict/_lazy.py | 10 ++++------ tensordict/_td.py | 6 +++--- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/benchmarks/nn/functional_benchmarks_test.py b/benchmarks/nn/functional_benchmarks_test.py index 0308797eb..2276962da 100644 --- a/benchmarks/nn/functional_benchmarks_test.py +++ b/benchmarks/nn/functional_benchmarks_test.py @@ -193,7 +193,7 @@ def test_vmap_mlp_speed(benchmark, stack, tdmodule): @torch.no_grad() @pytest.mark.parametrize("stack", [True, False]) -@pytest.mark.parametrize("tdmodule", [True, False]) +@pytest.mark.parametrize("tdmodule", [False, True]) def test_vmap_mlp_speed_decorator(benchmark, stack, tdmodule): # tests speed of vmapping over a transformer device = "cuda" if torch.cuda.device_count() else "cpu" diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index bffbd2840..e5934abf6 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -1088,13 +1088,11 @@ def _iterate_over_keys(self) -> None: @cache # noqa: B019 def _key_list(self, leaves_only=False, nodes_only=False): - keys = set( - self.tensordicts[0].keys(leaves_only=leaves_only, nodes_only=nodes_only) - ) + s0 = self.tensordicts[0].keys(leaves_only=leaves_only, nodes_only=nodes_only) + keys = set(s0) for td in self.tensordicts[1:]: - keys = keys.intersection( - td.keys(leaves_only=leaves_only, nodes_only=nodes_only) - ) + s = set(td.keys(leaves_only=leaves_only, nodes_only=nodes_only)) + keys = keys.intersection(s) return sorted(keys, key=str) def entry_class(self, key: NestedKey) -> type: diff --git a/tensordict/_td.py b/tensordict/_td.py index 209c2b47a..1d2e41c20 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -284,6 +284,7 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None) _swap = {} for key, value in 
self.items(include_nested=False, leaves_only=True): + assert not is_tensor_collection(value), (key, self) if module.__class__.__setattr__ is __base__setattr__: # if setattr is the native nn.Module.setattr, we can rely on _set_tensor_dict local_out = _set_tensor_dict(__dict__, module, key, value) @@ -296,6 +297,7 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None) _swap[key] = local_out for key, value in self.items(include_nested=False, nodes_only=True): + assert is_tensor_collection(value) for _ in value.keys(): # if there is at least one key, we must populate the module. # Otherwise we just go to the next key @@ -1721,9 +1723,7 @@ def clone(self, recurse: bool = True) -> T: source._tensor_dict[key] = val for key in self.keys(nodes_only=True): - source._tensor_dict[key] = self._get_str(key, NO_DEFAULT).clone( - recurse=False - ) + source._tensor_dict[key] = self._get_str(key, NO_DEFAULT).clone(recurse=False) return TensorDict( source=source,