Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Feb 6, 2025
1 parent d1b97e0 commit 589a249
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 26 deletions.
1 change: 0 additions & 1 deletion tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6444,7 +6444,6 @@ def _get_tuple_maybe_non_tensor(self, key, default):
return result.data
return result


@overload
def get_at(self, key, index): ...

Expand Down
26 changes: 14 additions & 12 deletions tensordict/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,7 +1055,7 @@ def _write_to_tensordict(
tensordict_out = TensorDict()
else:
tensordict_out = tensordict
for _out_key, _tensor in zip(out_keys, tensors):
for _out_key, _tensor in _zip_strict(out_keys, tensors):
if _out_key != "_":
tensordict_out.set(_out_key, TensorDict.from_any(_tensor))
return tensordict_out
Expand Down Expand Up @@ -1097,7 +1097,9 @@ def forward(
for in_key in self.in_keys
)
try:
tensors = self._call_module(tensors, **kwargs)
tensors_out = self._call_module(tensors, **kwargs)
if tensors_out is None:
tensors_out = ()
except Exception as err:
if any(tensor is None for tensor in tensors) and "None" in str(err):
none_set = {
Expand All @@ -1112,18 +1114,18 @@ def forward(
) from err
else:
raise err
if isinstance(tensors, (dict, TensorDictBase)) and all(
key in tensors for key in self.out_keys
if isinstance(tensors_out, (dict, TensorDictBase)) and all(
key in tensors_out for key in self.out_keys
):
if isinstance(tensors, dict):
keys = unravel_key_list(list(tensors.keys()))
values = tensors.values()
tensors = dict(_zip_strict(keys, values))
tensors = tuple(tensors.get(key) for key in self.out_keys)
if not isinstance(tensors, tuple):
tensors = (tensors,)
if isinstance(tensors_out, dict):
keys = unravel_key_list(list(tensors_out.keys()))
values = tensors_out.values()
tensors_out = dict(_zip_strict(keys, values))
tensors_out = tuple(tensors_out.get(key) for key in self.out_keys)
if not isinstance(tensors_out, tuple):
tensors_out = (tensors_out,)
tensordict_out = self._write_to_tensordict(
tensordict, tensors, tensordict_out
tensordict, tensors_out, tensordict_out
)
return tensordict_out
except Exception as err:
Expand Down
4 changes: 2 additions & 2 deletions tensordict/nn/cudagraphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
PYTREE_REGISTERED_LAZY_TDS,
PYTREE_REGISTERED_TDS,
)
from tensordict.utils import strtobool
from tensordict.utils import _zip_strict, strtobool
from torch import Tensor

from torch.utils._pytree import SUPPORTED_NODES, tree_map
Expand Down Expand Up @@ -296,7 +296,7 @@ def check_tensor_id(name, t0, t1):
def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
if self.counter >= self._warmup:
srcs, dests = [], []
for arg_src, arg_dest in zip(
for arg_src, arg_dest in _zip_strict(
tree_leaves((args, kwargs)), self._flat_tree
):
self._maybe_copy_onto_(arg_src, arg_dest, srcs, dests)
Expand Down
5 changes: 3 additions & 2 deletions tensordict/nn/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from tensordict.memmap import MemoryMappedTensor
from tensordict.utils import (
_LOCK_ERROR,
_zip_strict,
BufferLegacy,
erase_cache,
implement_for,
Expand Down Expand Up @@ -475,8 +476,8 @@ def _reset_params(self, params: dict | None = None, buffers: dict | None = None)
buffer_keys.append(key)
buffers.append(value)

self._parameters.update(dict(zip(param_keys, params)))
self._buffers.update(dict(zip(buffer_keys, buffers)))
self._parameters.update(dict(_zip_strict(param_keys, params)))
self._buffers.update(dict(_zip_strict(buffer_keys, buffers)))
else:
self._parameters.update(params)
self._buffers.update(buffers)
Expand Down
7 changes: 3 additions & 4 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ def log_prob(
return dist.log_prob(tensordict.get(self.out_keys[0]))

def _update_td_lp(self, lp):
for out_key, lp_key in zip(self.dist_sample_keys, self.log_prob_keys):
for out_key, lp_key in _zip_strict(self.dist_sample_keys, self.log_prob_keys):
lp_key_expected = _add_suffix(out_key, "_log_prob")
if lp_key != lp_key_expected:
lp.rename_key_(lp_key_expected, lp_key)
Expand Down Expand Up @@ -637,7 +637,7 @@ def forward(
if isinstance(out_tensors, Tensor):
out_tensors = (out_tensors,)
tensordict_out.update(
{key: value for key, value in zip(self.out_keys, out_tensors)}
dict(_zip_strict(self.dist_sample_keys, out_tensors))
)
if self.return_log_prob:
log_prob = dist.log_prob(*out_tensors)
Expand Down Expand Up @@ -1155,8 +1155,7 @@ def get_dist(
if isinstance(tdm, ProbabilisticTensorDictModule):
if isinstance(sample, torch.Tensor):
sample = [sample]
for val, key in zip(sample, tdm.out_keys):
td_copy.set(key, val)
td_copy.update(dict(_zip_strict(tdm.dist_sample_keys, sample)))
else:
td_copy.update(sample)
dists[tdm.out_keys[0]] = dist
Expand Down
6 changes: 3 additions & 3 deletions tensordict/nn/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from tensordict.nn.utils import _set_skip_existing_None
from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase
from tensordict.utils import unravel_key_list
from tensordict.utils import _zip_strict, unravel_key_list
from torch import nn

_has_functorch = False
Expand Down Expand Up @@ -205,7 +205,7 @@ def __init__(
in_keys, out_keys = self._compute_in_and_out_keys(modules_vals)
self._complete_out_keys = list(out_keys)
modules = collections.OrderedDict(
**{key: val for key, val in zip(modules[0], modules_vals)}
**{key: val for key, val in _zip_strict(modules[0], modules_vals)}
)
super().__init__(
module=nn.ModuleDict(modules), in_keys=in_keys, out_keys=out_keys
Expand Down Expand Up @@ -493,7 +493,7 @@ def select_subsequence(
else:
keys = [key for key in self.module if self.module[key] in modules]
modules_dict = collections.OrderedDict(
**{key: val for key, val in zip(keys, modules)}
**{key: val for key, val in _zip_strict(keys, modules)}
)
return type(self)(modules_dict)

Expand Down
7 changes: 5 additions & 2 deletions tensordict/tensorclass.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -601,11 +601,14 @@ class TensorClass:
def get(self, key, default): ...
def get(self, key: NestedKey, *args, **kwargs) -> CompatibleType: ...
@overload
def get_at(self, key, index):...
def get_at(self, key, index): ...
@overload
def get_at(self, key, index, default): ...
def get_at(
self, key: NestedKey, *args, **kwargs,
self,
key: NestedKey,
*args,
**kwargs,
) -> CompatibleType: ...
def get_item_shape(self, key: NestedKey): ...
def update(
Expand Down

0 comments on commit 589a249

Please sign in to comment.