From 51ac5ec170c466d05c8af2c4349fe53a6ded11a2 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 22 Nov 2023 12:53:13 +0000 Subject: [PATCH 1/8] minor --- tensordict/_td.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index 2e01b91cd..b644b17a7 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -320,7 +320,6 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None) swap = swap.to(device=local_out.device) if return_swap: - assert local_out is not None, key swap._set_str(key, local_out, inplace=False, validated=True) return swap @@ -1242,12 +1241,13 @@ def _set_str( inplace: bool, validated: bool, ) -> T: - best_attempt = inplace is BEST_ATTEMPT_INPLACE - inplace = self._convert_inplace(inplace, key) + if inplace is not False: + best_attempt = inplace is BEST_ATTEMPT_INPLACE + inplace = self._convert_inplace(inplace, key) if not validated: value = self._validate_value(value, check_shape=True) if not inplace: - if self.is_locked: + if self._is_locked: raise RuntimeError(_LOCK_ERROR) self._tensordict[key] = value else: @@ -1703,14 +1703,13 @@ def contiguous(self) -> T: def empty(self, recurse=False) -> T: if not recurse: return TensorDict( - device=self.device, - batch_size=self.batch_size, + device=self._device, + batch_size=self._batch_size, source={}, - # names=self.names if self._has_names() else None, names=self._td_dim_names, _run_checks=False, - _is_memmap=self._is_memmap, - _is_shared=self._is_shared, + _is_memmap=False, + _is_shared=False, ) return super().empty(recurse=recurse) From ccc1d706fe8c006c19ac6a88d22dcf69ef8b058c Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 22 Nov 2023 14:03:35 +0000 Subject: [PATCH 2/8] amend --- tensordict/_td.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index b644b17a7..8091b1d7d 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -276,6 +276,7 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None) else: swap = swap_dest memo[id(module)] = swap + _swap = {} for key, value in self.items(): if isinstance(value, (Tensor, ftdim.Tensor)): @@ -320,7 +321,12 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None) swap = swap.to(device=local_out.device) if return_swap: - swap._set_str(key, local_out, inplace=False, validated=True) + _swap[key] = local_out + if return_swap: + if isinstance(swap, TensorDict): + swap._tensordict.update(_swap) + else: + swap.update(_swap) return swap def __ne__(self, other: object) -> T | bool: From af9a7e4a0ab701347969a99cae0bbaa825cc1074 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 22 Nov 2023 14:19:52 +0000 Subject: [PATCH 3/8] amend --- tensordict/_lazy.py | 6 +++--- tensordict/_td.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 488681dcd..62f2a1c3b 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -359,11 +359,11 @@ def _set_str( "permitted if all members of the stack have this key in " "their register." 
) from e + if self.hook_in is not None: + value = self.hook_in(value) if not validated: value = self._validate_value(value) validated = True - if self.hook_in is not None: - value = self.hook_in(value) values = value.unbind(self.stack_dim) for tensordict, item in zip(self.tensordicts, values): tensordict._set_str(key, item, inplace=inplace, validated=validated) @@ -1584,7 +1584,7 @@ def update(self, input_dict_or_td: T, clone: bool = False, **kwargs: Any) -> T: for key, value in input_dict_or_td.items(): if clone and hasattr(value, "clone"): value = value.clone() - else: + elif clone: value = tree_map(torch.clone, value) if isinstance(key, tuple): key, subkey = key[0], key[1:] diff --git a/tensordict/_td.py b/tensordict/_td.py index 8091b1d7d..c6b8e8653 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -271,7 +271,7 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None) # For batch-size it is a minor issue (unlikely that a td with batch-size # is passed with to_module) but for the device it could be a problem. if swap_dest is None: - swap = self.empty() + swap = TensorDict({}, batch_size=[]) swap.clear_device_() else: swap = swap_dest From 06925be799687c294942642ca5a025ca5474b316 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 22 Nov 2023 15:32:49 +0000 Subject: [PATCH 4/8] fix lazy stack get within vmap --- tensordict/_lazy.py | 62 +++++++++++++++++++++++++-------------------- 1 file changed, 35 insertions(+), 27 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 62f2a1c3b..a5a2f7388 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -19,7 +19,7 @@ import torch from functorch import dim as ftdim from tensordict._td import _SubTensorDict, _TensorDictKeysView, TensorDict -from tensordict._tensordict import _unravel_key_to_tuple +from tensordict._tensordict import _unravel_key_to_tuple, unravel_key from tensordict.base import ( _ACCEPTED_CLASSES, _is_tensor_collection, @@ -359,11 +359,11 @@ def _set_str( "permitted if all members of the stack have this key in " "their register." 
) from e - if self.hook_in is not None: - value = self.hook_in(value) if not validated: value = self._validate_value(value) validated = True + if self.hook_in is not None: + value = self.hook_in(value) values = value.unbind(self.stack_dim) for tensordict, item in zip(self.tensordicts, values): tensordict._set_str(key, item, inplace=inplace, validated=validated) @@ -778,10 +778,17 @@ def _get_str( # then it's a LazyStackedTD out.hook_out = self.hook_out out.hook_in = self.hook_in + out._batch_size = ( + self._batch_size + out.batch_size[(len(self._batch_size) + 1) :] + ) else: # then it's a tensorclass out._tensordict.hook_out = self.hook_out out._tensordict.hook_in = self.hook_in + out._tensordict._batch_size = ( + self._batch_size + + out._tensordict.batch_size[(len(self._batch_size) + 1) :] + ) elif self.hook_out is not None: out = self.hook_out(out) return out @@ -802,7 +809,7 @@ def _get_str( def _get_tuple(self, key, default): first = self._get_str(key[0], None) if first is None: - return self._default_get(first, default) + return self._default_get(key[0], default) if len(key) == 1: return first try: @@ -1580,36 +1587,37 @@ def update(self, input_dict_or_td: T, clone: bool = False, **kwargs: Any) -> T: td_dest.update(td_source, clone=clone, **kwargs) return self - keys = self.keys(False) + inplace = kwargs.get("inplace", False) for key, value in input_dict_or_td.items(): if clone and hasattr(value, "clone"): value = value.clone() elif clone: value = tree_map(torch.clone, value) + key = unravel_key(key) if isinstance(key, tuple): - key, subkey = key[0], key[1:] - else: - subkey = () - # the key must be a string by now. Let's check if it is present - if key in keys: - target_class = self.entry_class(key) - if _is_tensor_collection(target_class): - if isinstance(value, dict): - value_unbind = TensorDict( - value, self.batch_size, _run_checks=False - ).unbind(self.stack_dim) - else: - value_unbind = value.unbind(self.stack_dim) - for t, _value in zip(self.tensordicts, value_unbind): - if len(subkey): - t.update({key: {subkey: _value}}, clone=clone, **kwargs) - else: - t.update({key: _value}, clone=clone, **kwargs) - continue - if len(subkey): - self.set((key, *subkey), value, **kwargs) + # we must check that the target is not a leaf + target = self._get_str(key[0], default=None) + if is_tensor_collection(target): + target.update({key[1:]: value}, inplace=inplace, clone=clone) + elif target is None: + self._set_tuple(key, value, inplace=inplace, validated=False) + else: + raise TypeError( + f"Type mismatch: self.get(key[0]) is {type(target)} but expected a tensor collection." + ) else: - self.set(key, value, **kwargs) + target = self._get_str(key, default=None) + if is_tensor_collection(target) and ( + is_tensor_collection(value) or isinstance(value, dict) + ): + target.update(value, inplace=inplace, clone=clone) + elif target is None or not is_tensor_collection(value): + self._set_str(key, value, inplace=inplace, validated=False) + else: + raise TypeError( + f"Type mismatch: self.get(key) is {type(target)} but value is of type {type(value)}." 
+ ) + return self def update_( From c8d7b3552e64f7a35c625675fb64e247f4fdd94a Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 22 Nov 2023 20:48:30 +0000 Subject: [PATCH 5/8] amend --- tensordict/_lazy.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index a5a2f7388..59f5fbae7 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -778,16 +778,19 @@ def _get_str( # then it's a LazyStackedTD out.hook_out = self.hook_out out.hook_in = self.hook_in + incr = 0 if self.hook_in is None else 1 out._batch_size = ( - self._batch_size + out.batch_size[(len(self._batch_size) + 1) :] + self._batch_size + + out.batch_size[(len(self._batch_size) + incr) :] ) else: # then it's a tensorclass out._tensordict.hook_out = self.hook_out out._tensordict.hook_in = self.hook_in + incr = 0 if self.hook_in is None else 1 out._tensordict._batch_size = ( self._batch_size - + out._tensordict.batch_size[(len(self._batch_size) + 1) :] + + out._tensordict.batch_size[(len(self._batch_size) + incr) :] ) elif self.hook_out is not None: out = self.hook_out(out) From 66afc321b3ace004fd379c39e146a78a97fb9365 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 22 Nov 2023 20:52:43 +0000 Subject: [PATCH 6/8] remove unlock --- tensordict/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index 74e4980f3..4e2abf7a6 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -3125,8 +3125,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): return self.lock_() if last_op == self.__class__.to_module.__name__: if is_tensor_collection(out): - with out.unlock_(): - return self.to_module(*args, **kwargs, swap_dest=out) + # with out.unlock_(): + return self.to_module(*args, **kwargs, swap_dest=out) else: raise RuntimeError( "to_module cannot be used as a decorator when return_swap=False." 
From 0edf1485b54f73e479c750c987a8674c5b76202d Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 23 Nov 2023 08:10:50 +0000 Subject: [PATCH 7/8] amend --- tensordict/_lazy.py | 21 +++++++++++++-------- tensordict/base.py | 4 ++++ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 59f5fbae7..d060beda2 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -133,6 +133,8 @@ class LazyStackedTensorDict(TensorDictBase): """ + _is_vmapped: bool = False + @classmethod def __torch_function__( cls, @@ -362,7 +364,7 @@ def _set_str( if not validated: value = self._validate_value(value) validated = True - if self.hook_in is not None: + if self._is_vmapped: value = self.hook_in(value) values = value.unbind(self.stack_dim) for tensordict, item in zip(self.tensordicts, values): @@ -397,7 +399,7 @@ def _set_tuple( if not validated: value = self._validate_value(value) validated = True - if self.hook_in is not None: + if self._is_vmapped: value = self.hook_in(value) values = value.unbind(self.stack_dim) for tensordict, item in zip(self.tensordicts, values): @@ -554,7 +556,7 @@ def _set_at_str(self, key, value, index, *, validated): if not validated: value = self._validate_value(value, check_shape=False) validated = True - if self.hook_in is not None: + if self._is_vmapped: value = self.hook_in(value) split_index = self._split_index(index) converted_idx = split_index["index_dict"] @@ -649,7 +651,7 @@ def _set_at_tuple(self, key, value, idx, *, validated): if not validated: value = self._validate_value(value, check_shape=False) validated = True - if self.hook_in is not None: + if self._is_vmapped: value = self.hook_in(value) item = td._get_str(key, NO_DEFAULT) item[idx] = value @@ -778,7 +780,8 @@ def _get_str( # then it's a LazyStackedTD out.hook_out = self.hook_out out.hook_in = self.hook_in - incr = 0 if self.hook_in is None else 1 + out._is_vmapped = self._is_vmapped + incr = 0 if not self._is_vmapped else 1 out._batch_size = ( self._batch_size + out.batch_size[(len(self._batch_size) + incr) :] @@ -787,7 +790,8 @@ def _get_str( # then it's a tensorclass out._tensordict.hook_out = self.hook_out out._tensordict.hook_in = self.hook_in - incr = 0 if self.hook_in is None else 1 + out._tensordict._is_vmapped = self._is_vmapped + incr = 0 if not self._is_vmapped else 1 out._tensordict._batch_size = ( self._batch_size + out._tensordict.batch_size[(len(self._batch_size) + incr) :] @@ -860,7 +864,7 @@ def _cached_add_batch_dims(cls, td, in_dim, vmap_level): # we return a stack with hook_out, and hack the batch_size and names # Per se it is still a LazyStack but the stacking dim is "hidden" from # the outside - out = td.clone(False) + out = td.copy() def hook_out(tensor, in_dim=in_dim, vmap_level=vmap_level): return _add_batch_dim(tensor, in_dim, vmap_level) @@ -879,6 +883,7 @@ def hook_in( out.hook_out = hook_out out.hook_in = hook_in + out._is_vmapped = True out._batch_size = torch.Size( [dim for i, dim in enumerate(out._batch_size) if i != out.stack_dim] ) @@ -1580,7 +1585,7 @@ def update(self, input_dict_or_td: T, clone: bool = False, **kwargs: Any) -> T: isinstance(input_dict_or_td, LazyStackedTensorDict) and input_dict_or_td.stack_dim == self.stack_dim ): - if not input_dict_or_td.shape[self.stack_dim] == len(self.tensordicts): + if len(input_dict_or_td.tensordicts) != len(self.tensordicts): raise ValueError( "cannot update stacked tensordicts with different shapes." 
) diff --git a/tensordict/base.py b/tensordict/base.py index 4e2abf7a6..e54d77d22 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -3520,6 +3520,10 @@ def flatten_keys(self, separator: str = ".", inplace: bool = False) -> T: result._set_str( leaf_flat, self.get(leaf), validated=True, inplace=False ) + shared = result._is_shared = self._is_shared + mmap = result._is_memmap = self._is_memmap + if shared or mmap: + result._is_locked = True return result @cache # noqa: B019 From c55e28afe4e00c8e3e8f32d92300c312622935dd Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 23 Nov 2023 08:57:31 +0000 Subject: [PATCH 8/8] amend --- tensordict/_td.py | 2 +- tensordict/base.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index c6b8e8653..5c34d8fbd 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -272,7 +272,6 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None) # is passed with to_module) but for the device it could be a problem. if swap_dest is None: swap = TensorDict({}, batch_size=[]) - swap.clear_device_() else: swap = swap_dest memo[id(module)] = swap @@ -324,6 +323,7 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None) _swap[key] = local_out if return_swap: if isinstance(swap, TensorDict): + # this is very ad-hoc but faster than calling _set_str every time swap._tensordict.update(_swap) else: swap.update(_swap) diff --git a/tensordict/base.py b/tensordict/base.py index e54d77d22..e9f0b7b5a 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -3125,7 +3125,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): return self.lock_() if last_op == self.__class__.to_module.__name__: if is_tensor_collection(out): - # with out.unlock_(): return self.to_module(*args, **kwargs, swap_dest=out) else: raise RuntimeError(
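
For reference, a minimal sketch of the swap/restore pattern that the `to_module` and `__exit__` hunks above exercise. It is an illustration under assumptions, not part of the patch series: `TensorDict.from_module`, `apply`, and the toy `nn.Sequential` module are assumed for the example and do not appear in the diffs; `to_module` (with its default `return_swap=True`) and the context-manager exit path are the ones shown in the hunks above.

```python
# Illustration only (not part of the patches): the to_module swap/restore pattern.
# Assumes TensorDict.from_module and the context-manager handling of to_module
# seen in the __exit__ hunk above are available in the installed tensordict.
import torch
from torch import nn
from tensordict import TensorDict

module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 1))
params = TensorDict.from_module(module)      # parameters/buffers as a nested TensorDict
zeros = params.apply(torch.zeros_like)       # a modified copy with the same structure

# to_module() writes `zeros` into the module and returns the previous values as a
# swap; used as a context manager, the swap is written back into the module on exit.
with zeros.to_module(module):
    out = module(torch.randn(2, 3))          # forward pass runs with zeroed parameters
    assert (out == 0).all()
# outside the block the original parameters are restored
```

The swap populated inside `to_module` is what patch 2 collects in the `_swap` dict and writes back through a single `swap._tensordict.update(_swap)` call, which patch 8 notes is faster than one `_set_str` per key.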