diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py
index 966f36872..ae5c642b3 100644
--- a/tensordict/_lazy.py
+++ b/tensordict/_lazy.py
@@ -1235,9 +1235,13 @@ def _add_batch_dim(self, *, in_dim, vmap_level):
         else:
             in_dim = in_dim - 1
         stack_dim = td.stack_dim
+
+        def addbatchdim(_arg):
+            return _add_batch_dim(_arg, in_dim, vmap_level)
+
         tds = [
             td._fast_apply(
-                lambda _arg: _add_batch_dim(_arg, in_dim, vmap_level),
+                addbatchdim,
                 batch_size=[b for i, b in enumerate(td.batch_size) if i != in_dim],
                 names=(
                     [name for i, name in enumerate(td.names) if i != in_dim]
@@ -3568,7 +3572,10 @@ def is_contiguous(self) -> bool:
         return all([value.is_contiguous() for _, value in self.items()])
 
     def contiguous(self) -> T:
-        return self._fast_apply(lambda x: x.contiguous(), propagate_lock=True)
+        def contiguous(x):
+            return x.contiguous()
+
+        return self._fast_apply(contiguous, propagate_lock=True)
 
     def rename_key_(
         self, old_key: NestedKey, new_key: NestedKey, safe: bool = False
diff --git a/tensordict/_td.py b/tensordict/_td.py
index 1730e9200..f7f3a4b3f 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -1871,10 +1871,14 @@ def repeat_interleave(
                 for i, s in enumerate(self.batch_size)
             ]
         )
-        return self._fast_apply(
-            lambda leaf: leaf.repeat_interleave(
-                repeats=repeats, dim=dim_corrected, output_size=output_size
-            ),
+
+        def rep(leaf):
+            return leaf.repeat_interleave(
+                repeats=repeats, dim=dim_corrected, output_size=output_size
+            )
+
+        return self._fast_apply(
+            rep,
             batch_size=new_batch_size,
             call_on_nested=True,
             propagate_lock=True,
@@ -1882,8 +1886,12 @@ def repeat_interleave(
 
     def _repeat(self, *repeats: int) -> TensorDictBase:
         new_batch_size = torch.Size([i * r for i, r in zip(self.batch_size, repeats)])
+
+        def rep(leaf):
+            return leaf.repeat(*repeats, *((1,) * (leaf.ndim - self.ndim)))
+
         return self._fast_apply(
-            lambda leaf: leaf.repeat(*repeats, *((1,) * (leaf.ndim - self.ndim))),
+            rep,
             batch_size=new_batch_size,
             call_on_nested=True,
             propagate_lock=True,
@@ -2008,8 +2016,11 @@ def _squeeze(tensor):
         if names:
             names.pop(dim)
 
+        def squeeze(x):
+            return x.squeeze(newdim)
+
         result = self._fast_apply(
-            lambda x: x.squeeze(newdim),
+            squeeze,
             batch_size=batch_size,
             names=names,
             inplace=False,
@@ -3362,9 +3373,13 @@ def keys(
             if not sort:
                 return _StringKeys(self._tensordict.keys())
             else:
+
+                def keyfunc(x):
+                    return ".".join(x) if isinstance(x, tuple) else x
+
                 return sorted(
                     _StringKeys(self._tensordict.keys()),
-                    key=lambda x: ".".join(x) if isinstance(x, tuple) else x,
+                    key=keyfunc,
                 )
         else:
             return self._nested_keys(
@@ -3403,7 +3418,11 @@ def items(
         if not include_nested and not leaves_only:
             if not sort:
                 return self._tensordict.items()
-            return sorted(self._tensordict.items(), key=lambda x: x[0])
+
+            def keyfunc(x):
+                return x[0]
+
+            return sorted(self._tensordict.items(), key=keyfunc)
         elif include_nested and leaves_only and not sort:
             is_leaf = _default_is_leaf if is_leaf is None else is_leaf
             result = []
@@ -3449,9 +3468,11 @@ def values(
             if not sort:
                 return self._tensordict.values()
             else:
-                return list(zip(*sorted(self._tensordict.items(), key=lambda x: x[0])))[
-                    1
-                ]
+
+                def keyfunc(x):
+                    return x[0]
+
+                return list(zip(*sorted(self._tensordict.items(), key=keyfunc)))[1]
         else:
             return TensorDictBase.values(
                 self,
@@ -4061,8 +4082,12 @@ def expand(self, *args: int, inplace: bool = False) -> T:
             shape = tuple(args[0])
         else:
             shape = args
+
+        def expand(x):
+            return x.expand((*shape, *x.shape[self.ndim :]))
+
         return self._fast_apply(
-            lambda x: x.expand((*shape, *x.shape[self.ndim :])),
+            expand,
             batch_size=shape,
             propagate_lock=True,
         )
@@ -4432,9 +4457,13 @@ def _iter():
             )
 
         if self.sort:
+
+            def keyfunc(key):
+                return ".".join(key) if isinstance(key, tuple) else key
+
             yield from sorted(
                 _iter(),
-                key=lambda key: ".".join(key) if isinstance(key, tuple) else key,
+                key=keyfunc,
             )
         else:
             yield from _iter()
diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py
index a587360d7..26095372d 100644
--- a/tensordict/_torch_func.py
+++ b/tensordict/_torch_func.py
@@ -144,8 +144,11 @@ def _gather_tensor(tensor, dest_container=None, dest_key=None):
 
 @implements_for_td(torch.full_like)
 def _full_like(td: T, fill_value: float, *args, **kwargs: Any) -> T:
+    def full_like(x):
+        return torch.full_like(x, fill_value, *args, **kwargs)
+
     return td._fast_apply(
-        lambda x: torch.full_like(x, fill_value, *args, **kwargs),
+        full_like,
         inplace=True,
         propagate_lock=True,
         device=kwargs.get("device", NO_DEFAULT),
@@ -154,8 +157,11 @@ def _full_like(td: T, fill_value: float, *args, **kwargs: Any) -> T:
 
 @implements_for_td(torch.zeros_like)
 def _zeros_like(td: T, *args, **kwargs: Any) -> T:
+    def zeros_like(x):
+        return torch.zeros_like(x, *args, **kwargs)
+
     td_clone = td._fast_apply(
-        lambda x: torch.zeros_like(x, *args, **kwargs),
+        zeros_like,
         propagate_lock=True,
         device=kwargs.get("device", NO_DEFAULT),
     )
@@ -173,8 +179,11 @@ def _zeros_like(td: T, *args, **kwargs: Any) -> T:
 
 @implements_for_td(torch.ones_like)
 def _ones_like(td: T, *args, **kwargs: Any) -> T:
+    def ones_like(x):
+        return torch.ones_like(x, *args, **kwargs)
+
     td_clone = td._fast_apply(
-        lambda x: torch.ones_like(x, *args, **kwargs),
+        ones_like,
         propagate_lock=True,
         device=kwargs.get("device", NO_DEFAULT),
     )
@@ -190,8 +199,11 @@ def _ones_like(td: T, *args, **kwargs: Any) -> T:
 
 @implements_for_td(torch.rand_like)
 def _rand_like(td: T, *args, **kwargs: Any) -> T:
+    def rand_like(x):
+        return torch.rand_like(x, *args, **kwargs)
+
     td_clone = td._fast_apply(
-        lambda x: torch.rand_like(x, *args, **kwargs),
+        rand_like,
         propagate_lock=True,
         device=kwargs.get("device", NO_DEFAULT),
     )
@@ -207,8 +219,11 @@ def _rand_like(td: T, *args, **kwargs: Any) -> T:
 
 @implements_for_td(torch.randn_like)
 def _randn_like(td: T, *args, **kwargs: Any) -> T:
+    def randn_like(x):
+        return torch.randn_like(x, *args, **kwargs)
+
     td_clone = td._fast_apply(
-        lambda x: torch.randn_like(x, *args, **kwargs),
+        randn_like,
         propagate_lock=True,
         device=kwargs.get("device", NO_DEFAULT),
     )
@@ -224,8 +239,11 @@ def _randn_like(td: T, *args, **kwargs: Any) -> T:
 
 @implements_for_td(torch.empty_like)
 def _empty_like(td: T, *args, **kwargs) -> T:
+    def empty_like(x):
+        return torch.empty_like(x, *args, **kwargs)
+
     return td._fast_apply(
-        lambda x: torch.empty_like(x, *args, **kwargs),
+        empty_like,
         propagate_lock=True,
         device=kwargs.get("device", NO_DEFAULT),
     )
diff --git a/tensordict/base.py b/tensordict/base.py
index 621db37db..9b6818a98 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -517,8 +517,12 @@ def isfinite(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = [val.isfinite() for val in vals]
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -534,8 +538,12 @@ def isnan(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = [val.isnan() for val in vals]
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -547,8 +555,12 @@ def isneginf(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = [val.isneginf() for val in vals]
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -560,8 +572,12 @@ def isposinf(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = [val.isposinf() for val in vals]
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -573,8 +589,12 @@ def isreal(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = [val.isreal() for val in vals]
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -2874,9 +2894,11 @@ def expand_as(self, other: TensorDictBase | torch.Tensor) -> TensorDictBase:
 
         """
         if _is_tensor_collection(type(other)):
-            return self.apply(
-                lambda x, y: x.expand_as(y), other, batch_size=other.batch_size
-            )
+
+            def expand_as(x, y):
+                return x.expand_as(y)
+
+            return self.apply(expand_as, other, batch_size=other.batch_size)
         return self.expand(other.shape)
 
     def new_zeros(
@@ -3828,7 +3850,10 @@ def view(
 
     def _view_dtype(self, *, dtype, batch_size):
         # We use apply because we want to check the shapes
-        return self.apply(lambda x: x.view(dtype), batch_size=batch_size)
+        def view(x):
+            return x.view(dtype)
+
+        return self.apply(view, batch_size=batch_size)
 
     def _legacy_view(
         self,
@@ -4458,8 +4483,12 @@ def pin_memory(self, num_threads: int | None = None, inplace: bool = False) -> T
                 Defaults to ``False``.
 
         """
+
+        def pin_memory(x):
+            return x.pin_memory()
+
         return self._fast_apply(
-            lambda x: x.pin_memory(),
+            pin_memory,
             num_threads=num_threads,
             inplace=inplace,
             propagate_lock=True,
@@ -5676,14 +5705,16 @@ def memmap_like(
         if return_early:
             executor = ThreadPoolExecutor(max_workers=num_threads)
             futures = []
+
             # we create an empty copy of self
             # This is because calling MMapTensor.from_tensor(mmap_tensor) does nothing
             # if both are in filesystem
-            input = self.apply(
-                lambda x: torch.empty((), device=x.device, dtype=x.dtype).expand(
+            def empty(x):
+                return torch.empty((), device=x.device, dtype=x.dtype).expand(
                     x.shape
                 )
-            )
+
+            input = self.apply(empty)
             result = input._memmap_(
                 prefix=prefix,
                 copy_existing=copy_existing,
@@ -5699,9 +5730,11 @@ def memmap_like(
                 return result
             else:
                 return TensorDictFuture(futures, result)
-        input = self.apply(
-            lambda x: torch.empty((), device=x.device, dtype=x.dtype).expand(x.shape)
-        )
+
+        def empty_expand(x):
+            return torch.empty((), device=x.device, dtype=x.dtype).expand(x.shape)
+
+        input = self.apply(empty_expand)
         return input._memmap_(
             prefix=prefix,
             copy_existing=copy_existing,
@@ -6795,15 +6828,17 @@ def items(
 
         """
         if sort:
+
+            def keyfunc(item):
+                return item[0] if isinstance(item[0], str) else ".".join(item[0])
+
             yield from sorted(
                 self.items(
                     include_nested=include_nested,
                     leaves_only=leaves_only,
                     is_leaf=is_leaf,
                 ),
-                key=lambda item: (
-                    item[0] if isinstance(item[0], str) else ".".join(item[0])
-                ),
+                key=keyfunc,
             )
         else:
 
@@ -6993,8 +7028,12 @@ def _grad(self):
         keys, vals = self._items_list(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS)
         grads = [val.grad for val in vals]
         items = dict(zip(keys, grads))
+
+        def get(name, val):
+            return items[name]
+
         return self._fast_apply(
-            lambda name, val: items[name],
+            get,
             named=True,
             nested_keys=True,
             propagate_lock=True,
@@ -7006,8 +7045,12 @@ def _data(self):
         keys, vals = self._items_list(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS)
         data = [val.data for val in vals]
         items = dict(zip(keys, data))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name),
+            get,
             named=True,
             nested_keys=True,
             propagate_lock=True,
@@ -8890,8 +8933,12 @@ def abs(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_abs(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -8908,8 +8955,12 @@ def acos(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_acos(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -8926,8 +8977,12 @@ def exp(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_exp(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -8944,8 +8999,12 @@ def neg(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_neg(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -8962,8 +9021,12 @@ def reciprocal(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_reciprocal(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -8980,8 +9043,12 @@ def sigmoid(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_sigmoid(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -8998,8 +9065,12 @@ def sign(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_sign(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9016,8 +9087,12 @@ def sin(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_sin(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9034,8 +9109,12 @@ def sinh(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_sinh(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9052,8 +9131,12 @@ def tan(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_tan(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9070,8 +9153,12 @@ def tanh(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_tanh(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9088,8 +9175,12 @@ def trunc(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_trunc(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9120,8 +9211,12 @@ def norm(
             raise RuntimeError("dtype must be None for torch <= 2.3")
         vals = torch._foreach_norm(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             batch_size=[],
@@ -9146,8 +9241,12 @@ def norm(  # noqa: F811
         keys, vals = self._items_list(True, True, collapse=True)
         vals = torch._foreach_norm(vals, dtype=dtype)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             batch_size=[],
@@ -9160,8 +9259,12 @@ def lgamma(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_lgamma(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9178,8 +9281,12 @@ def frac(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_frac(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9196,8 +9303,12 @@ def expm1(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_expm1(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9214,8 +9325,12 @@ def log(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_log(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9232,8 +9347,12 @@ def log10(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_log10(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9250,8 +9369,12 @@ def log1p(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_log1p(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9268,8 +9391,12 @@ def log2(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_log2(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9286,8 +9413,12 @@ def ceil(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_ceil(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9304,8 +9435,12 @@ def floor(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_floor(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9322,8 +9457,12 @@ def round(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_round(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9340,8 +9479,12 @@ def erf(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_erf(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9358,8 +9501,12 @@ def erfc(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_erfc(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9376,8 +9523,12 @@ def asin(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_asin(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9394,8 +9545,12 @@ def atan(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_atan(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9412,8 +9567,12 @@ def cos(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_cos(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9430,8 +9589,12 @@ def cosh(self) -> T:
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_cosh(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9476,8 +9639,12 @@ def _clone_recurse(self) -> TensorDictBase:  # noqa: D417
 
         items = foreach_vals
         items.update(iter_vals)
+
+        def pop(name, val):
+            return items.pop(name, None)
+
         result = self._fast_apply(
-            lambda name, val: items.pop(name, None),
+            pop,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9530,8 +9697,12 @@ def add(
         else:
             vals = torch._foreach_add(vals, other_val)
         items = dict(zip(keys, vals))
+
+        def pop(name, val):
+            return items.pop(name, None)
+
         result = self._fast_apply(
-            lambda name, val: items.pop(name, None),
+            pop,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9593,8 +9764,12 @@ def lerp(
             weight_val = weight
         vals = torch._foreach_lerp(vals, end_val, weight_val)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9651,8 +9826,12 @@ def addcdiv(
             other2_val = other2
         vals = torch._foreach_addcdiv(vals, other1_val, other2_val, value=value)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9704,8 +9883,12 @@ def addcmul(self, other1, other2, *, value: float | None = 1):  # noqa: D417
             other2_val = other2
         vals = torch._foreach_addcmul(vals, other1_val, other2_val, value=value)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9771,8 +9954,12 @@ def sub(
         else:
             vals = torch._foreach_sub(vals, other_val)
         items = dict(zip(keys, vals))
+
+        def pop(name, val):
+            return items.pop(name, None)
+
         result = self._fast_apply(
-            lambda name, val: items.pop(name, None),
+            pop,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9857,8 +10044,12 @@ def mul(
             other_val = other
         vals = torch._foreach_mul(vals, other_val)
         items = dict(zip(keys, vals))
+
+        def pop(name, val):
+            return items.pop(name, None)
+
         result = self._fast_apply(
-            lambda name, val: items.pop(name, None),
+            pop,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9919,8 +10110,12 @@ def maximum(
             other_val = other
         vals = torch._foreach_maximum(vals, other_val)
         items = dict(zip(keys, vals))
+
+        def pop(name, val):
+            return items.pop(name, None)
+
         result = self._fast_apply(
-            lambda name, val: items.pop(name, None),
+            pop,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -9981,8 +10176,12 @@ def minimum(
             other_val = other
         vals = torch._foreach_minimum(vals, other_val)
         items = dict(zip(keys, vals))
+
+        def pop(name, val):
+            return items.pop(name, None)
+
         result = self._fast_apply(
-            lambda name, val: items.pop(name, None),
+            pop,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -10043,8 +10242,12 @@ def clamp_max(
             other_val = other
         vals = torch._foreach_clamp_max(vals, other_val)
         items = dict(zip(keys, vals))
+
+        def pop(name, val):
+            return items.pop(name, None)
+
         result = self._fast_apply(
-            lambda name, val: items.pop(name, None),
+            pop,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -10104,8 +10307,12 @@ def clamp_min(
             other_val = other
         vals = torch._foreach_clamp_min(vals, other_val)
         items = dict(zip(keys, vals))
+
+        def pop(name, val):
+            return items.pop(name, None)
+
         result = self._fast_apply(
-            lambda name, val: items.pop(name, None),
+            pop,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -10171,8 +10378,12 @@ def pow(
             other_val = other
         vals = torch._foreach_pow(vals, other_val)
         items = dict(zip(keys, vals))
+
+        def pop(name, val):
+            return items.pop(name, None)
+
         result = self._fast_apply(
-            lambda name, val: items.pop(name, None),
+            pop,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -10239,8 +10450,12 @@ def div(
             other_val = other
         vals = torch._foreach_div(vals, other_val)
         items = dict(zip(keys, vals))
+
+        def pop(name, val):
+            return items.pop(name, None)
+
         result = self._fast_apply(
-            lambda name, val: items.pop(name, None),
+            pop,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -10262,8 +10477,12 @@ def sqrt(self):
         keys, vals = self._items_list(True, True)
         vals = torch._foreach_sqrt(vals)
         items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
         return self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -11351,7 +11570,11 @@ def fill_(self, key: NestedKey, value: float | bool) -> T:
         key = _unravel_key_to_tuple(key)
         data = self._get_tuple(key, NO_DEFAULT)
         if _is_tensor_collection(type(data)):
-            data._fast_apply(lambda x: x.fill_(value), inplace=True)
+
+            def fill(x):
+                return x.fill_(value)
+
+            data._fast_apply(fill, inplace=True)
         else:
             data = data.fill_(value)
         self._set_tuple(key, data, inplace=True, validated=True, non_blocking=False)
@@ -12100,8 +12323,12 @@ def _to_cuda_with_pin_mem(
         finally:
             for thread in threads:
                 thread.join(timeout=_PIN_MEM_TIMEOUT)
+
+        def get(name, val):
+            return items.get(name, val)
+
         result = self._fast_apply(
-            lambda name, val: items.get(name, val),
+            get,
             named=True,
             nested_keys=True,
             is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -12256,8 +12483,12 @@ def to_pinmem(tensor, _to=to):
             keys, tensors = self._items_list(True, True)
             tensors = [to(t) for t in tensors]
             items = dict(zip(keys, tensors))
+
+            def get(name, val):
+                return items.get(name, val)
+
             result = self._fast_apply(
-                lambda name, val: items.get(name, val),
+                get,
                 named=True,
                 nested_keys=True,
                 is_leaf=_NESTED_TENSORS_AS_LISTS,
@@ -12396,23 +12627,43 @@ def is_floating_point(self):
 
     def double(self):
         r"""Casts all tensors to ``torch.bool``."""
-        return self._fast_apply(lambda x: x.double(), propagate_lock=True)
+
+        def dble(x):
+            return x.double()
+
+        return self._fast_apply(dble, propagate_lock=True)
 
     def float(self):
         r"""Casts all tensors to ``torch.float``."""
-        return self._fast_apply(lambda x: x.float(), propagate_lock=True)
+
+        def tofloat(x):
+            return x.float()
+
+        return self._fast_apply(tofloat, propagate_lock=True)
 
     def int(self):
         r"""Casts all tensors to ``torch.int``."""
-        return self._fast_apply(lambda x: x.int(), propagate_lock=True)
+
+        def toint(x):
+            return x.int()
+
+        return self._fast_apply(toint, propagate_lock=True)
 
     def bool(self):
         r"""Casts all tensors to ``torch.bool``."""
-        return self._fast_apply(lambda x: x.bool(), propagate_lock=True)
+
+        def tobool(x):
+            return x.bool()
+
+        return self._fast_apply(tobool, propagate_lock=True)
 
     def half(self):
         r"""Casts all tensors to ``torch.half``."""
-        return self._fast_apply(lambda x: x.half(), propagate_lock=True)
+
+        def tohalf(x):
+            return x.half()
+
+        return self._fast_apply(tohalf, propagate_lock=True)
 
     def type(self, dst_type):
         r"""Casts all tensors to :attr:`dst_type`.
@@ -12421,7 +12672,11 @@ def type(self, dst_type):
             dst_type (type or string): the desired type
 
         """
-        return self._fast_apply(lambda x: x.type(dst_type))
+
+        def totype(x):
+            return x.type(dst_type)
+
+        return self._fast_apply(totype)
 
     # Gradient compatibility
     @property
@@ -12460,8 +12715,12 @@ def detach(self) -> T:
             a new tensordict with no tensor requiring gradient.
 
         """
+
+        def detach(x):
+            return x.detach()
+
         return self._fast_apply(
-            lambda x: x.detach(),
+            detach,
             propagate_lock=True,
         )
diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py
index e5a433af2..89b0c1d42 100644
--- a/tensordict/tensorclass.py
+++ b/tensordict/tensorclass.py
@@ -68,6 +68,11 @@ except ImportError:  # torch 2.0
     from torch._dynamo import is_compiling as is_dynamo_compiling
 
+
+def _identity(cls):
+    return cls
+
+
 
 try:
     from typing import dataclass_transform
 except ImportError:
@@ -77,7 +82,7 @@ def dataclass_transform(*args, **kwargs):
         Placeholder for dataclass_transform (python<3.11).
 
         """
-        return lambda cls: cls
+        return _identity
 
 
 T = TypeVar("T", bound=TensorDictBase)
@@ -673,7 +678,7 @@ def __torch_function__(
                 f"Attribute name {attr} can't be used with @tensorclass"
             )
 
-    cls.fields = classmethod(lambda cls: dataclasses.fields(cls))
+    cls.fields = classmethod(dataclasses.fields)
     for field in cls.fields():
         if hasattr(cls, field.name):
             delattr(cls, field.name)
@@ -936,7 +941,7 @@ def wrapper(
     return wrapper
 
 
-_cast_funcs = KeyDependentDefaultDict(lambda cls: cls)
+_cast_funcs = KeyDependentDefaultDict(_identity)
 _cast_funcs[torch.Tensor] = torch.as_tensor
 _cast_funcs[np.ndarray] = np.asarray
 
@@ -1332,7 +1337,7 @@ def deliver_result(self, result, kwargs):
             non_tensordict = dict(non_tensordict)
             if copy_non_tensor and non_tensordict:
                 # use tree_map to copy
-                non_tensordict = tree_map(lambda x: x, non_tensordict)
+                non_tensordict = tree_map(_identity, non_tensordict)
             return self._from_tensordict(result, non_tensordict, safe=False)
         return result
 
diff --git a/tensordict/utils.py b/tensordict/utils.py
index f282edd53..ae6d3b44f 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -1571,8 +1571,11 @@ def assert_close(
         set2 = set(expected.keys())
     except ValueError:
         # Persistent tensordicts do not work with is_leaf
-        set1 = set(actual.keys(is_leaf=lambda cls: issubclass(cls, torch.Tensor)))
-        set2 = set(expected.keys(is_leaf=lambda cls: issubclass(cls, torch.Tensor)))
+        def istensor(cls):
+            return issubclass(cls, torch.Tensor)
+
+        set1 = set(actual.keys(is_leaf=istensor))
+        set2 = set(expected.keys(is_leaf=istensor))
     if not intersection and (
         not (len(set1.difference(set2)) == 0 and len(set2) == len(set1))
     ):
@@ -2505,7 +2508,11 @@ def new_func(self):
             raise NotImplementedError(
                 f"Your pytorch version {torch.__version__} does not support {dtype}."
             )
-        return self._fast_apply(lambda x: x.to(dtype), propagate_lock=True)
+
+        def todtype(x):
+            return x.to(dtype)
+
+        return self._fast_apply(todtype, propagate_lock=True)
 
     new_func.__doc__ = rf"""Casts all tensors to ``{str(dtype)}``."""
     return new_func
@@ -2605,7 +2612,7 @@ def _prefix_last_key(key, prefix):
     "version."
 )
 
-_DEVICE2STRDEVICE = KeyDependentDefaultDict(lambda key: str(key))
+_DEVICE2STRDEVICE = KeyDependentDefaultDict(str)
 
 
 def _lock_warn():
@@ -2802,13 +2809,16 @@ def _rebuild_njt_from_njt(x, values, offsets, lengths):
 
 
 def _mismatch_keys(keys1, keys2):
+    def keyfunc(key):
+        return "".join(key) if isinstance(key, tuple) else key
+
     keys1 = sorted(
         keys1,
-        key=lambda key: "".join(key) if isinstance(key, tuple) else key,
+        key=keyfunc,
     )
     keys2 = sorted(
         keys2,
-        key=lambda key: "".join(key) if isinstance(key, tuple) else key,
+        key=keyfunc,
    )
    if set(keys1) - set(keys2):
        sub1 = rf"The first TD has keys {set(keys1) - set(keys2)} that the second does not have."