Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Nov 23, 2023
1 parent fd0344d commit 5a08121
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 10 deletions.
2 changes: 1 addition & 1 deletion benchmarks/nn/functional_benchmarks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def test_vmap_mlp_speed(benchmark, stack, tdmodule):

@torch.no_grad()
@pytest.mark.parametrize("stack", [True, False])
@pytest.mark.parametrize("tdmodule", [True, False])
@pytest.mark.parametrize("tdmodule", [False, True])
def test_vmap_mlp_speed_decorator(benchmark, stack, tdmodule):
# tests speed of vmapping over a transformer
device = "cuda" if torch.cuda.device_count() else "cpu"
Expand Down
10 changes: 4 additions & 6 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,13 +1088,11 @@ def _iterate_over_keys(self) -> None:

@cache # noqa: B019
def _key_list(self, leaves_only=False, nodes_only=False):
keys = set(
self.tensordicts[0].keys(leaves_only=leaves_only, nodes_only=nodes_only)
)
s0 = self.tensordicts[0].keys(leaves_only=leaves_only, nodes_only=nodes_only)
keys = set(s0)
for td in self.tensordicts[1:]:
keys = keys.intersection(
td.keys(leaves_only=leaves_only, nodes_only=nodes_only)
)
s = set(td.keys(leaves_only=leaves_only, nodes_only=nodes_only))
keys = keys.intersection(s)
return sorted(keys, key=str)

def entry_class(self, key: NestedKey) -> type:
Expand Down
6 changes: 3 additions & 3 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None)
_swap = {}

for key, value in self.items(include_nested=False, leaves_only=True):
assert not is_tensor_collection(value), (key, self)
if module.__class__.__setattr__ is __base__setattr__:
# if setattr is the native nn.Module.setattr, we can rely on _set_tensor_dict
local_out = _set_tensor_dict(__dict__, module, key, value)
Expand All @@ -296,6 +297,7 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None)
_swap[key] = local_out

for key, value in self.items(include_nested=False, nodes_only=True):
assert is_tensor_collection(value)
for _ in value.keys():
# if there is at least one key, we must populate the module.
# Otherwise we just go to the next key
Expand Down Expand Up @@ -1721,9 +1723,7 @@ def clone(self, recurse: bool = True) -> T:
source._tensor_dict[key] = val

for key in self.keys(nodes_only=True):
source._tensor_dict[key] = self._get_str(key, NO_DEFAULT).clone(
recurse=False
)
source._dict_dict[key] = self._get_str(key, NO_DEFAULT).clone(recurse=False)

return TensorDict(
source=source,
Expand Down

0 comments on commit 5a08121

Please sign in to comment.