From ddffb2e4548cc28c4f126ac9332aa033c9a86a22 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 26 Feb 2024 17:47:04 +0000 Subject: [PATCH 1/7] amend --- tensordict/nn/common.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 10ba56cc7..9a9a2fde1 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -13,10 +13,11 @@ import torch from cloudpickle import dumps as cloudpickle_dumps, loads as cloudpickle_loads -from tensordict._td import is_tensor_collection, TensorDictBase from tensordict._tensordict import _unravel_key_to_tuple, unravel_key_list -from tensordict.functional import make_tensordict +from torch import nn, Tensor +from tensordict._td import is_tensor_collection, TensorDictBase +from tensordict.functional import make_tensordict from tensordict.nn.functional_modules import ( _swap_state, extract_weights_and_buffers, @@ -24,14 +25,12 @@ make_functional, repopulate_module, ) - from tensordict.nn.utils import ( _auto_make_functional, _dispatch_td_nn_modules, set_skip_existing, ) from tensordict.utils import implement_for, NestedKey -from torch import nn, Tensor try: from functorch import FunctionalModule, FunctionalModuleWithBuffers @@ -248,7 +247,6 @@ def __call__(self, func: Callable) -> Callable: @functools.wraps(func) def wrapper(_self, *args: Any, **kwargs: Any) -> Any: - if not _dispatch_td_nn_modules(): return func(_self, *args, **kwargs) @@ -830,7 +828,11 @@ def reset_parameters_recursive( False """ if parameters is None: - self._reset_parameters(self) + any_reset = self._reset_parameters(self) + if not any_reset: + warnings.warn( + "reset_parameters_recursive was called without parameters and did not apply any reset" + ) return elif parameters.ndim: raise RuntimeError( @@ -868,13 +870,16 @@ def reset_parameters_recursive( self._reset_parameters(self) return sanitized_parameters - def _reset_parameters(self, module: nn.Module) -> None: + def _reset_parameters(self, module: nn.Module) -> bool: + any_reset = False for child in module.children(): if isinstance(child, nn.Module): - self._reset_parameters(child) + any_reset += self._reset_parameters(child) if hasattr(child, "reset_parameters"): child.reset_parameters() + any_reset += True + return any_reset class TensorDictModule(TensorDictModuleBase): From 35c588976475ee3daafa7c08a9e9c9e2ad9848a2 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 26 Feb 2024 17:51:19 +0000 Subject: [PATCH 2/7] amend --- tensordict/nn/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 9a9a2fde1..37a5b8cb7 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -831,7 +831,7 @@ def reset_parameters_recursive( any_reset = self._reset_parameters(self) if not any_reset: warnings.warn( - "reset_parameters_recursive was called without parameters and did not apply any reset" + "reset_parameters_recursive was called without parameters and did not find any parameters to reset" ) return elif parameters.ndim: From ba4412bcf0c43df274f1ed820f10a21d3af3895e Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 26 Feb 2024 18:07:05 +0000 Subject: [PATCH 3/7] amend --- tensordict/nn/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 37a5b8cb7..a4b0cccd8 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -831,7 +831,7 @@ def reset_parameters_recursive( any_reset = self._reset_parameters(self) if not any_reset: warnings.warn( - "reset_parameters_recursive was called without parameters and did not find any parameters to reset" + "reset_parameters_recursive was called without the parameters argument and did not find any parameters to reset" ) return elif parameters.ndim: From ba07d952c57280a1d09ee8856b23954e4d7f87f2 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 26 Feb 2024 18:10:57 +0000 Subject: [PATCH 4/7] test --- test/test_nn.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index 37d5975b0..4652b4d75 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -12,9 +12,12 @@ import pytest import torch +from tensordict._tensordict import unravel_key_list +from torch import distributions as d, nn +from torch.distributions import Normal +from torch.utils._pytree import tree_map from tensordict import tensorclass, TensorDict -from tensordict._tensordict import unravel_key_list from tensordict.nn import ( dispatch, probabilistic as nn_probabilistic, @@ -42,9 +45,6 @@ set_skip_existing, skip_existing, ) -from torch import distributions as d, nn -from torch.distributions import Normal -from torch.utils._pytree import tree_map try: import functorch # noqa @@ -165,6 +165,16 @@ def test_reset(self): seq.reset_parameters_recursive() assert torch.all(old_param != net[0][0].weight.data) + def test_reset_warning(self): + torch.manual_seed(0) + net = nn.ModuleList([nn.Tanh(), nn.ReLU()]) + module = TensorDictModule(net, in_keys=["in"], out_keys=["out"]) + with pytest.warns( + UserWarning, + match="reset_parameters_recursive was called without the parameters argument and did not find any parameters to reset", + ): + module.reset_parameters_recursive() + @pytest.mark.parametrize( "net", [ @@ -2666,7 +2676,6 @@ def test_module_buffer(): ], ) def test_nested_keys_probabilistic_delta(log_prob_key): - policy_module = TensorDictModule( nn.Linear(1, 1), in_keys=[("data", "states")], out_keys=[("data", "param")] ) @@ -2711,7 +2720,6 @@ def test_nested_keys_probabilistic_delta(log_prob_key): ], ) def test_nested_keys_probabilistic_normal(log_prob_key): - loc_module = TensorDictModule( nn.Linear(1, 1), in_keys=[("data", "states")], From 169181ca7c88c394cdb69d3a16e8ce84af9c92ea Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 26 Feb 2024 18:50:16 +0000 Subject: [PATCH 5/7] Update tensordict/nn/common.py --- tensordict/nn/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index a4b0cccd8..7cbedd655 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -874,7 +874,7 @@ def _reset_parameters(self, module: nn.Module) -> bool: any_reset = False for child in module.children(): if isinstance(child, nn.Module): - any_reset += self._reset_parameters(child) + any_reset |= self._reset_parameters(child) if hasattr(child, "reset_parameters"): child.reset_parameters() From d893c371b591b31f74dfbe5adaca68d72b01e43c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 26 Feb 2024 18:50:21 +0000 Subject: [PATCH 6/7] Update tensordict/nn/common.py --- tensordict/nn/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 7cbedd655..ad166bf08 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -878,7 +878,7 @@ def _reset_parameters(self, module: nn.Module) -> bool: if hasattr(child, "reset_parameters"): child.reset_parameters() - any_reset += True + any_reset |= True return any_reset From 46f12a46fa14f62a9e1d1e4e3e5df7a9d53bd393 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 26 Feb 2024 14:35:01 -0500 Subject: [PATCH 7/7] lint --- tensordict/nn/common.py | 4 ++-- test/test_nn.py | 50 ++++++++++++++++++++++++++++------------- 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index ad166bf08..79f99a224 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -13,10 +13,9 @@ import torch from cloudpickle import dumps as cloudpickle_dumps, loads as cloudpickle_loads -from tensordict._tensordict import _unravel_key_to_tuple, unravel_key_list -from torch import nn, Tensor from tensordict._td import is_tensor_collection, TensorDictBase +from tensordict._tensordict import _unravel_key_to_tuple, unravel_key_list from tensordict.functional import make_tensordict from tensordict.nn.functional_modules import ( _swap_state, @@ -31,6 +30,7 @@ set_skip_existing, ) from tensordict.utils import implement_for, NestedKey +from torch import nn, Tensor try: from functorch import FunctionalModule, FunctionalModuleWithBuffers diff --git a/test/test_nn.py b/test/test_nn.py index 4652b4d75..34a047c6f 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -12,12 +12,9 @@ import pytest import torch -from tensordict._tensordict import unravel_key_list -from torch import distributions as d, nn -from torch.distributions import Normal -from torch.utils._pytree import tree_map from tensordict import tensorclass, TensorDict +from tensordict._tensordict import unravel_key_list from tensordict.nn import ( dispatch, probabilistic as nn_probabilistic, @@ -45,6 +42,9 @@ set_skip_existing, skip_existing, ) +from torch import distributions, nn +from torch.distributions import Normal +from torch.utils._pytree import tree_map try: import functorch # noqa @@ -345,7 +345,7 @@ def test_stateful_probabilistic_kwargs( net = TensorDictModule(module=net, in_keys=in_keys, out_keys=out_keys) kwargs = { - "distribution_class": torch.distributions.Uniform, + "distribution_class": distributions.Uniform, "distribution_kwargs": {"high": max_dist}, } if out_keys == ["low"]: @@ -3091,7 +3091,10 @@ def test_const(self): ) dist = CompositeDistribution( params, - distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical}, + distribution_map={ + "cont": distributions.Normal, + ("nested", "disc"): distributions.Categorical, + }, ) assert dist.batch_shape == params.shape assert len(dist.dists) == 2 @@ -3108,7 +3111,10 @@ def test_sample(self): ) dist = CompositeDistribution( params, - distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical}, + distribution_map={ + "cont": distributions.Normal, + ("nested", "disc"): distributions.Categorical, + }, ) sample = dist.sample() assert sample.shape == params.shape @@ -3129,8 +3135,8 @@ def test_rsample(self): dist = CompositeDistribution( params, distribution_map={ - "cont": d.Normal, - ("nested", "disc"): d.RelaxedOneHotCategorical, + "cont": distributions.Normal, + ("nested", "disc"): distributions.RelaxedOneHotCategorical, }, extra_kwargs={("nested", "disc"): {"temperature": torch.tensor(1.0)}}, ) @@ -3155,8 +3161,8 @@ def test_log_prob(self): dist = CompositeDistribution( params, distribution_map={ - "cont": d.Normal, - ("nested", "disc"): d.RelaxedOneHotCategorical, + "cont": distributions.Normal, + ("nested", "disc"): distributions.RelaxedOneHotCategorical, }, extra_kwargs={("nested", "disc"): {"temperature": torch.tensor(1.0)}}, ) @@ -3180,7 +3186,11 @@ def test_cdf(self): [3], ) dist = CompositeDistribution( - params, distribution_map={"cont": d.Normal, ("nested", "cont"): d.Normal} + params, + distribution_map={ + "cont": distributions.Normal, + ("nested", "cont"): distributions.Normal, + }, ) sample = dist.rsample((4,)) sample = dist.cdf(sample) @@ -3202,7 +3212,11 @@ def test_icdf(self): [3], ) dist = CompositeDistribution( - params, distribution_map={"cont": d.Normal, ("nested", "cont"): d.Normal} + params, + distribution_map={ + "cont": distributions.Normal, + ("nested", "cont"): distributions.Normal, + }, ) sample = dist.rsample((4,)) sample = dist.cdf(sample) @@ -3233,7 +3247,10 @@ def test_prob_module(self, interaction, return_log_prob): ) in_keys = ["params"] out_keys = ["cont", ("nested", "cont")] - distribution_map = {"cont": d.Normal, ("nested", "cont"): d.Normal} + distribution_map = { + "cont": distributions.Normal, + ("nested", "cont"): distributions.Normal, + } module = ProbabilisticTensorDictModule( in_keys=in_keys, out_keys=out_keys, @@ -3283,7 +3300,10 @@ def test_prob_module_seq(self, interaction, return_log_prob): ) in_keys = ["params"] out_keys = ["cont", ("nested", "cont")] - distribution_map = {"cont": d.Normal, ("nested", "cont"): d.Normal} + distribution_map = { + "cont": distributions.Normal, + ("nested", "cont"): distributions.Normal, + } backbone = TensorDictModule(lambda: None, in_keys=[], out_keys=[]) module = ProbabilisticTensorDictSequential( backbone,