From 447348540fe6615aa48c4a36736553ced4e1423b Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 13 Jan 2025 10:17:23 +0000
Subject: [PATCH] [BugFix] Make min/max tensorclasses be interchangeable with
 PT equivalent (#1180)

---
 tensordict/base.py         | 20 ++++++-------
 tensordict/return_types.py | 29 +++++++++++++++++++
 test/test_tensordict.py    | 57 +++++++++++++++++---------------------
 3 files changed, 63 insertions(+), 43 deletions(-)

diff --git a/tensordict/base.py b/tensordict/base.py
index e0b1d0411..a73e4163f 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -872,7 +872,7 @@ def min(
         )
         if dim is not NO_DEFAULT and return_indices:
             # Split the tensordict
-            from .return_types import min
+            from torch.return_types import min
 
             values_dict = {}
             indices_dict = {}
@@ -882,8 +882,7 @@ def min(
                 else:
                     indices_dict[key] = key[:-1]
             return min(
-                *result.split_keys(values_dict, indices_dict),
-                batch_size=result.batch_size,
+                result.split_keys(values_dict, indices_dict)[:2],
             )
         return result
 
@@ -1006,7 +1005,7 @@ def max(
         )
         if dim is not NO_DEFAULT and return_indices:
             # Split the tensordict
-            from .return_types import max
+            from torch.return_types import max
 
             values_dict = {}
             indices_dict = {}
@@ -1016,8 +1015,7 @@ def max(
                 else:
                     indices_dict[key] = key[:-1]
             return max(
-                *result.split_keys(values_dict, indices_dict),
-                batch_size=result.batch_size,
+                result.split_keys(values_dict, indices_dict)[:2],
             )
         return result
 
@@ -1110,7 +1108,7 @@ def cummin(
             return result
         if dim is not NO_DEFAULT and return_indices:
             # Split the tensordict
-            from .return_types import cummin
+            from torch.return_types import cummin
 
             values_dict = {}
             indices_dict = {}
@@ -1120,8 +1118,7 @@ def cummin(
                 else:
                     indices_dict[key] = key[:-1]
             return cummin(
-                *result.split_keys(values_dict, indices_dict),
-                batch_size=result.batch_size,
+                result.split_keys(values_dict, indices_dict)[:2],
             )
         return result
 
@@ -1214,7 +1211,7 @@ def cummax(
             return result
         if dim is not NO_DEFAULT and return_indices:
             # Split the tensordict
-            from .return_types import cummax
+            from torch.return_types import cummax
 
             values_dict = {}
             indices_dict = {}
@@ -1224,8 +1221,7 @@ def cummax(
                 else:
                     indices_dict[key] = key[:-1]
             return cummax(
-                *result.split_keys(values_dict, indices_dict),
-                batch_size=result.batch_size,
+                result.split_keys(values_dict, indices_dict)[:2],
             )
         return result
 
diff --git a/tensordict/return_types.py b/tensordict/return_types.py
index 0c6668305..ec559598c 100644
--- a/tensordict/return_types.py
+++ b/tensordict/return_types.py
@@ -2,6 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 
 from tensordict.tensorclass import tensorclass
 from tensordict.tensordict import TensorDict
@@ -14,6 +15,13 @@ class min:
     vals: TensorDict
     indices: TensorDict
 
+    def __post_init__(self):
+        warnings.warn(
+            f"{type(self)}.min is deprecated and will be removed in v0.9. "
+            f"Use torch.return_types.min instead.",
+            category=DeprecationWarning,
+        )
+
 
 @tensorclass
 class max:
@@ -22,6 +30,13 @@ class max:
     vals: TensorDict
     indices: TensorDict
 
+    def __post_init__(self):
+        warnings.warn(
+            f"{type(self)}.max is deprecated and will be removed in v0.9. "
+            f"Use torch.return_types.max instead.",
+            category=DeprecationWarning,
+        )
+
 
 @tensorclass
 class cummin:
@@ -30,6 +45,13 @@ class cummin:
     vals: TensorDict
     indices: TensorDict
 
+    def __post_init__(self):
+        warnings.warn(
+            f"{type(self)}.cummin is deprecated and will be removed in v0.9. "
" + f"Use torch.return_types.cummin instead.", + category=DeprecationWarning, + ) + @tensorclass class cummax: @@ -37,3 +59,10 @@ class cummax: vals: TensorDict indices: TensorDict + + def __post_init__(self): + warnings.warn( + f"{type(self)}.cummax is deprecated and will be removed in v0.9. " + f"Use torch.return_types.cummax instead.", + category=DeprecationWarning, + ) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index ce17a33c1..0fc698f59 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -5303,7 +5303,10 @@ def test_memmap_threads(self, td_name, device, use_dir, tmpdir, num_threads): ], ) def test_min_max_cummin_cummax(self, td_name, device, dim, keepdim, return_indices): - import tensordict.return_types as return_types + def _get_td(v): + if not is_tensor_collection(v): + return v.values + return v td = getattr(self, td_name)(device) # min @@ -5315,71 +5318,63 @@ def test_min_max_cummin_cummax(self, td_name, device, dim, keepdim, return_indic if not return_indices and dim is not None: assert_allclose_td(r, td.amin(dim=dim, keepdim=keepdim)) if return_indices: - assert is_tensorclass(r) - assert isinstance(r, return_types.min) - assert not r.vals.is_empty() + # assert is_tensorclass(r) + assert isinstance(r, torch.return_types.min) + assert not r.values.is_empty() assert not r.indices.is_empty() - else: - assert not is_tensorclass(r) if dim is None: - assert r.batch_size == () + assert _get_td(r).batch_size == () elif keepdim: s = list(td.batch_size) s[dim] = 1 - assert r.batch_size == tuple(s) + assert _get_td(r).batch_size == tuple(s) else: s = list(td.batch_size) s.pop(dim) - assert r.batch_size == tuple(s) + assert _get_td(r).batch_size == tuple(s) r = td.max(**kwargs) if not return_indices and dim is not None: assert_allclose_td(r, td.amax(dim=dim, keepdim=keepdim)) if return_indices: - assert is_tensorclass(r) - assert isinstance(r, return_types.max) - assert not r.vals.is_empty() + # assert is_tensorclass(r) + assert isinstance(r, torch.return_types.max) + assert not r.values.is_empty() assert not r.indices.is_empty() - else: - assert not is_tensorclass(r) if dim is None: - assert r.batch_size == () + assert _get_td(r).batch_size == () elif keepdim: s = list(td.batch_size) s[dim] = 1 - assert r.batch_size == tuple(s) + assert _get_td(r).batch_size == tuple(s) else: s = list(td.batch_size) s.pop(dim) - assert r.batch_size == tuple(s) + assert _get_td(r).batch_size == tuple(s) if dim is None: return kwargs.pop("keepdim") r = td.cummin(**kwargs) if return_indices: - assert is_tensorclass(r) - assert isinstance(r, return_types.cummin) - assert not r.vals.is_empty() + # assert is_tensorclass(r) + assert isinstance(r, torch.return_types.cummin) + assert not r.values.is_empty() assert not r.indices.is_empty() - else: - assert not is_tensorclass(r) if dim is None: - assert r.batch_size == () + assert _get_td(r).batch_size == () else: - assert r.batch_size == td.batch_size + assert _get_td(r).batch_size == td.batch_size r = td.cummax(**kwargs) if return_indices: - assert is_tensorclass(r) - assert isinstance(r, return_types.cummax) - assert not r.vals.is_empty() + # assert is_tensorclass(r) + assert isinstance(r, torch.return_types.cummax) + assert not r.values.is_empty() assert not r.indices.is_empty() - else: - assert not is_tensorclass(r) if dim is None: - assert r.batch_size == () + assert _get_td(r).batch_size == () else: - assert r.batch_size == td.batch_size + assert _get_td(r).batch_size == td.batch_size @pytest.mark.parametrize("inplace", 
     def test_named_apply(self, td_name, device, inplace):
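
Usage note (not part of the patch): after this change, `td.min`, `td.max`, `td.cummin` and `td.cummax` with `return_indices=True` return the corresponding `torch.return_types` structseq instead of a tensordict-specific tensorclass, so the result unpacks exactly like the output of `torch.min` on a plain tensor; the old `tensordict.return_types` classes now emit a `DeprecationWarning`. Below is a minimal sketch of the new behavior, assuming a tensordict build that includes this patch; the keys ("a", "b") and shapes are illustrative only:

# Illustrative example, not part of the diff above.
import torch
from tensordict import TensorDict

td = TensorDict(
    {"a": torch.randn(3, 4), "b": torch.randn(3, 4, 5)},
    batch_size=[3, 4],
)

# The result is a torch.return_types.min and unpacks like torch.min's output.
result = td.min(dim=0, return_indices=True)
assert isinstance(result, torch.return_types.min)
values, indices = result
assert values["a"].shape == torch.Size([4])      # batch dim 0 reduced away
assert indices["b"].shape == torch.Size([4, 5])  # indices share the reduced batch shape

Returning the torch structseq directly means downstream code written against `torch.min`/`torch.max` (tuple unpacking, `.values`/`.indices` access) works unchanged on tensordicts.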