From 589a249455dbd89d52794cee48ee34dafbf3dc77 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 6 Feb 2025 13:11:55 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- tensordict/base.py | 1 - tensordict/nn/common.py | 26 ++++++++++++++------------ tensordict/nn/cudagraphs.py | 4 ++-- tensordict/nn/params.py | 5 +++-- tensordict/nn/probabilistic.py | 7 +++---- tensordict/nn/sequence.py | 6 +++--- tensordict/tensorclass.pyi | 7 +++++-- 7 files changed, 30 insertions(+), 26 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index 354237cc6..376830c31 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -6444,7 +6444,6 @@ def _get_tuple_maybe_non_tensor(self, key, default): return result.data return result - @overload def get_at(self, key, index): ... diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index de28f6d49..02b98cf30 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -1055,7 +1055,7 @@ def _write_to_tensordict( tensordict_out = TensorDict() else: tensordict_out = tensordict - for _out_key, _tensor in zip(out_keys, tensors): + for _out_key, _tensor in _zip_strict(out_keys, tensors): if _out_key != "_": tensordict_out.set(_out_key, TensorDict.from_any(_tensor)) return tensordict_out @@ -1097,7 +1097,9 @@ def forward( for in_key in self.in_keys ) try: - tensors = self._call_module(tensors, **kwargs) + tensors_out = self._call_module(tensors, **kwargs) + if tensors_out is None: + tensors_out = () except Exception as err: if any(tensor is None for tensor in tensors) and "None" in str(err): none_set = { @@ -1112,18 +1114,18 @@ def forward( ) from err else: raise err - if isinstance(tensors, (dict, TensorDictBase)) and all( - key in tensors for key in self.out_keys + if isinstance(tensors_out, (dict, TensorDictBase)) and all( + key in tensors_out for key in self.out_keys ): - if isinstance(tensors, dict): - keys = unravel_key_list(list(tensors.keys())) - values = tensors.values() - tensors = dict(_zip_strict(keys, values)) - tensors = tuple(tensors.get(key) for key in self.out_keys) - if not isinstance(tensors, tuple): - tensors = (tensors,) + if isinstance(tensors_out, dict): + keys = unravel_key_list(list(tensors_out.keys())) + values = tensors_out.values() + tensors_out = dict(_zip_strict(keys, values)) + tensors_out = tuple(tensors_out.get(key) for key in self.out_keys) + if not isinstance(tensors_out, tuple): + tensors_out = (tensors_out,) tensordict_out = self._write_to_tensordict( - tensordict, tensors, tensordict_out + tensordict, tensors_out, tensordict_out ) return tensordict_out except Exception as err: diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py index c3fcffe5f..f7da93cb9 100644 --- a/tensordict/nn/cudagraphs.py +++ b/tensordict/nn/cudagraphs.py @@ -22,7 +22,7 @@ PYTREE_REGISTERED_LAZY_TDS, PYTREE_REGISTERED_TDS, ) -from tensordict.utils import strtobool +from tensordict.utils import _zip_strict, strtobool from torch import Tensor from torch.utils._pytree import SUPPORTED_NODES, tree_map @@ -296,7 +296,7 @@ def check_tensor_id(name, t0, t1): def _call(*args: torch.Tensor, **kwargs: torch.Tensor): if self.counter >= self._warmup: srcs, dests = [], [] - for arg_src, arg_dest in zip( + for arg_src, arg_dest in _zip_strict( tree_leaves((args, kwargs)), self._flat_tree ): self._maybe_copy_onto_(arg_src, arg_dest, srcs, dests) diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 0a5107e9e..64295667e 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -34,6 +34,7 @@ from tensordict.memmap import MemoryMappedTensor from tensordict.utils import ( _LOCK_ERROR, + _zip_strict, BufferLegacy, erase_cache, implement_for, @@ -475,8 +476,8 @@ def _reset_params(self, params: dict | None = None, buffers: dict | None = None) buffer_keys.append(key) buffers.append(value) - self._parameters.update(dict(zip(param_keys, params))) - self._buffers.update(dict(zip(buffer_keys, buffers))) + self._parameters.update(dict(_zip_strict(param_keys, params))) + self._buffers.update(dict(_zip_strict(buffer_keys, buffers))) else: self._parameters.update(params) self._buffers.update(buffers) diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index ee7b72076..666a772f5 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -576,7 +576,7 @@ def log_prob( return dist.log_prob(tensordict.get(self.out_keys[0])) def _update_td_lp(self, lp): - for out_key, lp_key in zip(self.dist_sample_keys, self.log_prob_keys): + for out_key, lp_key in _zip_strict(self.dist_sample_keys, self.log_prob_keys): lp_key_expected = _add_suffix(out_key, "_log_prob") if lp_key != lp_key_expected: lp.rename_key_(lp_key_expected, lp_key) @@ -637,7 +637,7 @@ def forward( if isinstance(out_tensors, Tensor): out_tensors = (out_tensors,) tensordict_out.update( - {key: value for key, value in zip(self.out_keys, out_tensors)} + dict(_zip_strict(self.dist_sample_keys, out_tensors)) ) if self.return_log_prob: log_prob = dist.log_prob(*out_tensors) @@ -1155,8 +1155,7 @@ def get_dist( if isinstance(tdm, ProbabilisticTensorDictModule): if isinstance(sample, torch.Tensor): sample = [sample] - for val, key in zip(sample, tdm.out_keys): - td_copy.set(key, val) + td_copy.update(dict(_zip_strict(tdm.dist_sample_keys, sample))) else: td_copy.update(sample) dists[tdm.out_keys[0]] = dist diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py index adb2ff314..1517e3067 100644 --- a/tensordict/nn/sequence.py +++ b/tensordict/nn/sequence.py @@ -20,7 +20,7 @@ ) from tensordict.nn.utils import _set_skip_existing_None from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase -from tensordict.utils import unravel_key_list +from tensordict.utils import _zip_strict, unravel_key_list from torch import nn _has_functorch = False @@ -205,7 +205,7 @@ def __init__( in_keys, out_keys = self._compute_in_and_out_keys(modules_vals) self._complete_out_keys = list(out_keys) modules = collections.OrderedDict( - **{key: val for key, val in zip(modules[0], modules_vals)} + **{key: val for key, val in _zip_strict(modules[0], modules_vals)} ) super().__init__( module=nn.ModuleDict(modules), in_keys=in_keys, out_keys=out_keys @@ -493,7 +493,7 @@ def select_subsequence( else: keys = [key for key in self.module if self.module[key] in modules] modules_dict = collections.OrderedDict( - **{key: val for key, val in zip(keys, modules)} + **{key: val for key, val in _zip_strict(keys, modules)} ) return type(self)(modules_dict) diff --git a/tensordict/tensorclass.pyi b/tensordict/tensorclass.pyi index 615c72658..9db6d2c71 100644 --- a/tensordict/tensorclass.pyi +++ b/tensordict/tensorclass.pyi @@ -601,11 +601,14 @@ class TensorClass: def get(self, key, default): ... def get(self, key: NestedKey, *args, **kwargs) -> CompatibleType: ... @overload - def get_at(self, key, index):... + def get_at(self, key, index): ... @overload def get_at(self, key, index, default): ... def get_at( - self, key: NestedKey, *args, **kwargs, + self, + key: NestedKey, + *args, + **kwargs, ) -> CompatibleType: ... def get_item_shape(self, key: NestedKey): ... def update(