From b8ad113762f89d355b70c27ccbc878c716ad411c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 27 Feb 2024 14:35:28 +0000 Subject: [PATCH] [Refactor] Remove remnant legacy functional calls (#1973) --- test/test_tensordictmodules.py | 1256 +++----------------- torchrl/envs/utils.py | 1 - tutorials/sphinx-tutorials/coding_ddpg.py | 17 +- tutorials/sphinx-tutorials/torchrl_demo.py | 21 +- 4 files changed, 190 insertions(+), 1105 deletions(-) diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index c2df40be012..7e0fef99786 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -9,12 +9,7 @@ import torch from mocking_classes import DiscreteActionVecMockEnv from tensordict import pad, TensorDict, unravel_key_list -from tensordict.nn import ( - InteractionType, - make_functional, - TensorDictModule, - TensorDictSequential, -) +from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential from torch import nn from torchrl.data.tensor_specs import ( BoundedTensorSpec, @@ -255,15 +250,33 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys) elif safe and spec_type == "bounded": assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() + +class TestTDSequence: + # Temporarily disabling this test until 473 is merged in tensordict + # def test_in_key_warning(self): + # with pytest.warns(UserWarning, match='key "_" is for ignoring output'): + # tensordict_module = SafeModule( + # nn.Linear(3, 4), in_keys=["_"], out_keys=["out1"] + # ) + # with pytest.warns(UserWarning, match='key "_" is for ignoring output'): + # tensordict_module = SafeModule( + # nn.Linear(3, 4), in_keys=["_", "key2"], out_keys=["out1"] + # ) + @pytest.mark.parametrize("safe", [True, False]) @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional(self, safe, spec_type): + @pytest.mark.parametrize("lazy", [True, False]) + def test_stateful(self, safe, spec_type, lazy): torch.manual_seed(0) param_multiplier = 1 - - net = nn.Linear(3, 4 * param_multiplier) - - params = make_functional(net) + if lazy: + net1 = nn.LazyLinear(4) + dummy_net = nn.LazyLinear(4) + net2 = nn.LazyLinear(4 * param_multiplier) + else: + net1 = nn.Linear(3, 4) + dummy_net = nn.Linear(4, 4) + net2 = nn.Linear(4, 4 * param_multiplier) if spec_type is None: spec = None @@ -272,31 +285,51 @@ def test_functional(self, safe, spec_type): elif spec_type == "unbounded": spec = UnboundedContinuousTensorSpec(4) + kwargs = {} + if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - tensordict_module = SafeModule( - spec=spec, - module=net, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - return + pytest.skip("safe and spec is None is checked elsewhere") else: - tensordict_module = SafeModule( - spec=spec, - module=net, + tdmodule1 = SafeModule( + net1, + spec=None, in_keys=["in"], + out_keys=["hidden"], + safe=False, + ) + dummy_tdmodule = SafeModule( + dummy_net, + spec=None, + in_keys=["hidden"], + out_keys=["hidden"], + safe=False, + ) + tdmodule2 = SafeModule( + spec=spec, + module=net2, + in_keys=["hidden"], out_keys=["out"], - safe=safe, + safe=False, + **kwargs, ) + tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) + + assert hasattr(tdmodule, "__setitem__") + assert len(tdmodule) == 3 + tdmodule[1] = tdmodule2 + assert len(tdmodule) == 3 + + assert hasattr(tdmodule, "__delitem__") + assert 
len(tdmodule) == 3 + del tdmodule[2] + assert len(tdmodule) == 2 + + assert hasattr(tdmodule, "__getitem__") + assert tdmodule[0] is tdmodule1 + assert tdmodule[1] is tdmodule2 td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tensordict_module(td, params=TensorDict({"module": params}, [])) + tdmodule(td) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) @@ -308,16 +341,19 @@ def test_functional(self, safe, spec_type): @pytest.mark.parametrize("safe", [True, False]) @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_probabilistic(self, safe, spec_type): + @pytest.mark.parametrize("lazy", [True, False]) + def test_stateful_probabilistic(self, safe, spec_type, lazy): torch.manual_seed(0) param_multiplier = 2 - - tdnet = SafeModule( - module=NormalParamWrapper(nn.Linear(3, 4 * param_multiplier)), - spec=None, - in_keys=["in"], - out_keys=["loc", "scale"], - ) + if lazy: + net1 = nn.LazyLinear(4) + dummy_net = nn.LazyLinear(4) + net2 = nn.LazyLinear(4 * param_multiplier) + else: + net1 = nn.Linear(3, 4) + dummy_net = nn.Linear(4, 4) + net2 = nn.Linear(4, 4 * param_multiplier) + net2 = NormalParamWrapper(net2) if spec_type is None: spec = None @@ -331,1075 +367,128 @@ def test_functional_probabilistic(self, safe, spec_type): kwargs = {"distribution_class": TanhNormal} if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - return + pytest.skip("safe and spec is None is checked elsewhere") else: + tdmodule1 = SafeModule( + net1, + in_keys=["in"], + out_keys=["hidden"], + spec=None, + safe=False, + ) + dummy_tdmodule = SafeModule( + dummy_net, + in_keys=["hidden"], + out_keys=["hidden"], + spec=None, + safe=False, + ) + tdmodule2 = SafeModule( + module=net2, + in_keys=["hidden"], + out_keys=["loc", "scale"], + spec=None, + safe=False, + ) + prob_module = SafeProbabilisticModule( + spec=spec, in_keys=["loc", "scale"], out_keys=["out"], - spec=spec, - safe=safe, + safe=False, **kwargs, ) + tdmodule = SafeProbabilisticTensorDictSequential( + tdmodule1, dummy_tdmodule, tdmodule2, prob_module + ) + + assert hasattr(tdmodule, "__setitem__") + assert len(tdmodule) == 4 + tdmodule[1] = tdmodule2 + tdmodule[2] = prob_module + assert len(tdmodule) == 4 - tensordict_module = SafeProbabilisticTensorDictSequential(tdnet, prob_module) - params = make_functional(tensordict_module) + assert hasattr(tdmodule, "__delitem__") + assert len(tdmodule) == 4 + del tdmodule[3] + assert len(tdmodule) == 3 + + assert hasattr(tdmodule, "__getitem__") + assert tdmodule[0] is tdmodule1 + assert tdmodule[1] is tdmodule2 + assert tdmodule[2] is prob_module td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tensordict_module(td, params=params) + tdmodule(td) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) + dist = tdmodule.get_dist(td) + assert dist.rsample().shape[: td.ndimension()] == td.shape + # test bounds if not safe and spec_type == "bounded": assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() elif safe and spec_type == "bounded": assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_with_buffer(self, safe, spec_type): - 
torch.manual_seed(0) - param_multiplier = 1 - - net = nn.BatchNorm1d(32 * param_multiplier) - params = make_functional(net) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 32) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(32) - - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - tdmodule = SafeModule( - spec=spec, - module=net, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - return - else: - tdmodule = SafeModule( - spec=spec, - module=net, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - - td = TensorDict({"in": torch.randn(3, 32 * param_multiplier)}, [3]) - tdmodule(td, params=TensorDict({"module": params}, [])) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 32]) + def test_submodule_sequence(self): + td_module_1 = SafeModule( + nn.Linear(3, 2), + in_keys=["in"], + out_keys=["hidden"], + ) + td_module_2 = SafeModule( + nn.Linear(2, 4), + in_keys=["hidden"], + out_keys=["out"], + ) + td_module = SafeSequential(td_module_1, td_module_2) - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() + td_1 = TensorDict({"in": torch.randn(5, 3)}, [5]) + sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"]) + sub_seq_1(td_1) + assert "hidden" in td_1.keys() + assert "out" not in td_1.keys() + td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5]) + sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"]) + sub_seq_2(td_2) + assert "out" in td_2.keys() + assert td_2.get("out").shape == torch.Size([5, 4]) - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_with_buffer_probabilistic(self, safe, spec_type): + @pytest.mark.parametrize("stack", [True, False]) + def test_sequential_partial(self, stack): torch.manual_seed(0) param_multiplier = 2 - tdnet = SafeModule( - module=NormalParamWrapper(nn.BatchNorm1d(32 * param_multiplier)), - spec=None, - in_keys=["in"], - out_keys=["loc", "scale"], - ) + net1 = nn.Linear(3, 4) - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 32) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(32) - else: - raise NotImplementedError + net2 = nn.Linear(4, 4 * param_multiplier) + net2 = NormalParamWrapper(net2) + net2 = SafeModule(net2, in_keys=["b"], out_keys=["loc", "scale"]) - kwargs = {"distribution_class": TanhNormal} + net3 = nn.Linear(4, 4 * param_multiplier) + net3 = NormalParamWrapper(net3) + net3 = SafeModule(net3, in_keys=["c"], out_keys=["loc", "scale"]) - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) + spec = BoundedTensorSpec(-0.1, 0.1, 4) - return - else: - prob_module = SafeProbabilisticModule( + kwargs = {"distribution_class": TanhNormal} + + tdmodule1 = SafeModule( + net1, + in_keys=["a"], + out_keys=["hidden"], + spec=None, + safe=False, + ) + tdmodule2 = SafeProbabilisticTensorDictSequential( + net2, + SafeProbabilisticModule( in_keys=["loc", 
"scale"], out_keys=["out"], spec=spec, - safe=safe, - **kwargs, - ) - - tdmodule = SafeProbabilisticTensorDictSequential(tdnet, prob_module) - params = make_functional(tdmodule) - - td = TensorDict({"in": torch.randn(3, 32 * param_multiplier)}, [3]) - tdmodule(td, params=params) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 32]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.skipif( - not _has_functorch, reason="vmap can only be used with functorch" - ) - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_vmap(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 1 - - net = nn.Linear(3, 4 * param_multiplier) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - tdmodule = SafeModule( - spec=spec, - module=net, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - return - else: - tdmodule = SafeModule( - spec=spec, - module=net, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - - params = make_functional(tdmodule) - - # vmap = True - params = params.expand(10) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - if safe and spec_type == "bounded": - with pytest.raises( - RuntimeError, match="vmap cannot be used with safe=True" - ): - td_out = vmap(tdmodule, (None, 0))(td, params) - return - else: - td_out = vmap(tdmodule, (None, 0))(td, params) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, 0) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_repeat = td.expand(10, *td.batch_size) - td_out = vmap(tdmodule, (0, 0))(td_repeat, params) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - @pytest.mark.skipif( - not _has_functorch, reason="vmap can only be used with functorch" - ) - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_vmap_probabilistic(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 2 - - net = NormalParamWrapper(nn.Linear(3, 4 * param_multiplier)) - tdnet = SafeModule( - module=net, in_keys=["in"], out_keys=["loc", "scale"], spec=None - ) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - else: - raise NotImplementedError - - kwargs = {"distribution_class": TanhNormal} 
- - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - return - else: - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - - tdmodule = SafeProbabilisticTensorDictSequential(tdnet, prob_module) - params = make_functional(tdmodule) - - # vmap = True - params = params.expand(10) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - if safe and spec_type == "bounded": - with pytest.raises( - RuntimeError, match="vmap cannot be used with safe=True" - ): - td_out = vmap(tdmodule, (None, 0))(td, params) - return - else: - td_out = vmap(tdmodule, (None, 0))(td, params) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, 0) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_repeat = td.expand(10, *td.batch_size) - td_out = vmap(tdmodule, (0, 0))(td_repeat, params) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - -class TestTDSequence: - # Temporarily disabling this test until 473 is merged in tensordict - # def test_in_key_warning(self): - # with pytest.warns(UserWarning, match='key "_" is for ignoring output'): - # tensordict_module = SafeModule( - # nn.Linear(3, 4), in_keys=["_"], out_keys=["out1"] - # ) - # with pytest.warns(UserWarning, match='key "_" is for ignoring output'): - # tensordict_module = SafeModule( - # nn.Linear(3, 4), in_keys=["_", "key2"], out_keys=["out1"] - # ) - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - @pytest.mark.parametrize("lazy", [True, False]) - def test_stateful(self, safe, spec_type, lazy): - torch.manual_seed(0) - param_multiplier = 1 - if lazy: - net1 = nn.LazyLinear(4) - dummy_net = nn.LazyLinear(4) - net2 = nn.LazyLinear(4 * param_multiplier) - else: - net1 = nn.Linear(3, 4) - dummy_net = nn.Linear(4, 4) - net2 = nn.Linear(4, 4 * param_multiplier) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - - kwargs = {} - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, - spec=None, - in_keys=["in"], - out_keys=["hidden"], - safe=False, - ) - dummy_tdmodule = SafeModule( - dummy_net, - spec=None, - in_keys=["hidden"], - out_keys=["hidden"], - safe=False, - ) - tdmodule2 = SafeModule( - spec=spec, - module=net2, - in_keys=["hidden"], - out_keys=["out"], - safe=False, - **kwargs, - ) - tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 
3 - tdmodule[1] = tdmodule2 - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 3 - del tdmodule[2] - assert len(tdmodule) == 2 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - @pytest.mark.parametrize("lazy", [True, False]) - def test_stateful_probabilistic(self, safe, spec_type, lazy): - torch.manual_seed(0) - param_multiplier = 2 - if lazy: - net1 = nn.LazyLinear(4) - dummy_net = nn.LazyLinear(4) - net2 = nn.LazyLinear(4 * param_multiplier) - else: - net1 = nn.Linear(3, 4) - dummy_net = nn.Linear(4, 4) - net2 = nn.Linear(4, 4 * param_multiplier) - net2 = NormalParamWrapper(net2) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - else: - raise NotImplementedError - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, - in_keys=["in"], - out_keys=["hidden"], - spec=None, - safe=False, - ) - dummy_tdmodule = SafeModule( - dummy_net, - in_keys=["hidden"], - out_keys=["hidden"], - spec=None, - safe=False, - ) - tdmodule2 = SafeModule( - module=net2, - in_keys=["hidden"], - out_keys=["loc", "scale"], - spec=None, - safe=False, - ) - - prob_module = SafeProbabilisticModule( - spec=spec, - in_keys=["loc", "scale"], - out_keys=["out"], - safe=False, - **kwargs, - ) - tdmodule = SafeProbabilisticTensorDictSequential( - tdmodule1, dummy_tdmodule, tdmodule2, prob_module - ) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 4 - tdmodule[1] = tdmodule2 - tdmodule[2] = prob_module - assert len(tdmodule) == 4 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 4 - del tdmodule[3] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - assert tdmodule[2] is prob_module - - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - dist = tdmodule.get_dist(td) - assert dist.rsample().shape[: td.ndimension()] == td.shape - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 1 - - net1 = nn.Linear(3, 4) - dummy_net = nn.Linear(4, 4) - net2 = nn.Linear(4, 4 * param_multiplier) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = 
UnboundedContinuousTensorSpec(4) - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False - ) - dummy_tdmodule = SafeModule( - dummy_net, - spec=None, - in_keys=["hidden"], - out_keys=["hidden"], - safe=False, - ) - tdmodule2 = SafeModule( - net2, - spec=spec, - in_keys=["hidden"], - out_keys=["out"], - safe=safe, - ) - tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) - - params = make_functional(tdmodule) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 3 - tdmodule[1] = tdmodule2 - with params.unlock_(): - params["module", "1"] = params["module", "2"] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 3 - del tdmodule[2] - with params.unlock_(): - del params["module", "2"] - assert len(tdmodule) == 2 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td, params) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_probabilistic(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 2 - - net1 = nn.Linear(3, 4) - dummy_net = nn.Linear(4, 4) - net2 = nn.Linear(4, 4 * param_multiplier) - net2 = NormalParamWrapper(net2) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - else: - raise NotImplementedError - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False - ) - dummy_tdmodule = SafeModule( - dummy_net, - spec=None, - in_keys=["hidden"], - out_keys=["hidden"], - safe=False, - ) - tdmodule2 = SafeModule( - module=net2, in_keys=["hidden"], out_keys=["loc", "scale"] - ) - - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - tdmodule = SafeProbabilisticTensorDictSequential( - tdmodule1, dummy_tdmodule, tdmodule2, prob_module - ) - - params = make_functional(tdmodule, funs_to_decorate=["forward", "get_dist"]) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 4 - tdmodule[1] = tdmodule2 - tdmodule[2] = prob_module - with params.unlock_(): - params["module", "1"] = params["module", "2"] - params["module", "2"] = params["module", "3"] - assert len(tdmodule) == 4 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 4 - del tdmodule[3] - with params.unlock_(): - del params["module", "3"] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - assert tdmodule[2] is prob_module - - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td, params=params) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == 
torch.Size([3, 4]) - - dist = tdmodule.get_dist(td, params=params) - assert dist.rsample().shape[: td.ndimension()] == td.shape - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_with_buffer(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 1 - - net1 = nn.Sequential(nn.Linear(7, 7), nn.BatchNorm1d(7)) - dummy_net = nn.Sequential(nn.Linear(7, 7), nn.BatchNorm1d(7)) - net2 = nn.Sequential( - nn.Linear(7, 7 * param_multiplier), nn.BatchNorm1d(7 * param_multiplier) - ) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 7) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(7) - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False - ) - dummy_tdmodule = SafeModule( - dummy_net, - spec=None, - in_keys=["hidden"], - out_keys=["hidden"], - safe=False, - ) - tdmodule2 = SafeModule( - net2, - spec=spec, - in_keys=["hidden"], - out_keys=["out"], - safe=safe, - ) - tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) - - params = make_functional(tdmodule) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 3 - tdmodule[1] = tdmodule2 - with params.unlock_(): - params["module", "1"] = params["module", "2"] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 3 - del tdmodule[2] - with params.unlock_(): - del params["module", "2"] - assert len(tdmodule) == 2 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - - td = TensorDict({"in": torch.randn(3, 7)}, [3]) - tdmodule(td, params=params) - - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 7]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_with_buffer_probabilistic(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 2 - - net1 = nn.Sequential(nn.Linear(7, 7), nn.BatchNorm1d(7)) - dummy_net = nn.Sequential(nn.Linear(7, 7), nn.BatchNorm1d(7)) - net2 = nn.Sequential( - nn.Linear(7, 7 * param_multiplier), nn.BatchNorm1d(7 * param_multiplier) - ) - net2 = NormalParamWrapper(net2) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 7) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(7) - else: - raise NotImplementedError - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, in_keys=["in"], out_keys=["hidden"], spec=None, safe=False - ) - dummy_tdmodule = SafeModule( - dummy_net, - in_keys=["hidden"], - out_keys=["hidden"], - spec=None, - safe=False, - ) - tdmodule2 = SafeModule( - net2, - in_keys=["hidden"], - 
out_keys=["loc", "scale"], - spec=None, - safe=False, - ) - - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - tdmodule = SafeProbabilisticTensorDictSequential( - tdmodule1, dummy_tdmodule, tdmodule2, prob_module - ) - - params = make_functional(tdmodule, ["forward", "get_dist"]) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 4 - tdmodule[1] = tdmodule2 - tdmodule[2] = prob_module - with params.unlock_(): - params["module", "1"] = params["module", "2"] - params["module", "2"] = params["module", "3"] - assert len(tdmodule) == 4 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 4 - del tdmodule[3] - with params.unlock_(): - del params["module", "3"] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - assert tdmodule[2] is prob_module - - td = TensorDict({"in": torch.randn(3, 7)}, [3]) - tdmodule(td, params=params) - - dist = tdmodule.get_dist(td, params=params) - assert dist.rsample().shape[: td.ndimension()] == td.shape - - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 7]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.skipif( - not _has_functorch, reason="vmap can only be used with functorch" - ) - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_vmap(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 1 - - net1 = nn.Linear(3, 4) - dummy_net = nn.Linear(4, 4) - net2 = nn.Linear(4, 4 * param_multiplier) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, - spec=None, - in_keys=["in"], - out_keys=["hidden"], - safe=False, - ) - dummy_tdmodule = SafeModule( - dummy_net, - spec=None, - in_keys=["hidden"], - out_keys=["hidden"], - safe=False, - ) - tdmodule2 = SafeModule( - net2, - spec=spec, - in_keys=["hidden"], - out_keys=["out"], - safe=safe, - ) - tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) - - params = make_functional(tdmodule) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 3 - tdmodule[1] = tdmodule2 - with params.unlock_(): - params["module", "1"] = params["module", "2"] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 3 - del tdmodule[2] - with params.unlock_(): - del params["module", "2"] - assert len(tdmodule) == 2 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - - # vmap = True - params = params.expand(10) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - if safe and spec_type == "bounded": - with pytest.raises( - RuntimeError, match="vmap cannot be used with safe=True" - ): - td_out = vmap(tdmodule, (None, 0))(td, params) - return - else: - td_out = vmap(tdmodule, (None, 0))(td, params) - - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test 
bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, 0) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_repeat = td.expand(10, *td.batch_size) - td_out = vmap(tdmodule, (0, 0))(td_repeat, params) - assert td_out is not td_repeat - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - @pytest.mark.skipif( - not _has_functorch, reason="vmap can only be used with functorch" - ) - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_vmap_probabilistic(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 2 - - net1 = nn.Linear(3, 4) - - net2 = nn.Linear(4, 4 * param_multiplier) - net2 = NormalParamWrapper(net2) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - else: - raise NotImplementedError - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, - spec=None, - in_keys=["in"], - out_keys=["hidden"], - safe=False, - ) - tdmodule2 = SafeModule(net2, in_keys=["hidden"], out_keys=["loc", "scale"]) - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - tdmodule = SafeProbabilisticTensorDictSequential( - tdmodule1, tdmodule2, prob_module - ) - - params = make_functional(tdmodule) - - # vmap = True - params = params.expand(10) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - if safe and spec_type == "bounded": - with pytest.raises( - RuntimeError, match="vmap cannot be used with safe=True" - ): - td_out = vmap(tdmodule, (None, 0))(td, params) - return - else: - td_out = vmap(tdmodule, (None, 0))(td, params) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, 0) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_repeat = td.expand(10, *td.batch_size) - td_out = vmap(tdmodule, (0, 0))(td_repeat, params) - assert td_out is not td_repeat - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - @pytest.mark.parametrize("functional", [True, False]) - def test_submodule_sequence(self, functional): - td_module_1 = SafeModule( - nn.Linear(3, 2), - in_keys=["in"], - out_keys=["hidden"], - ) - td_module_2 = SafeModule( - nn.Linear(2, 4), - in_keys=["hidden"], - 
out_keys=["out"], - ) - td_module = SafeSequential(td_module_1, td_module_2) - - if functional: - td_1 = TensorDict({"in": torch.randn(5, 3)}, [5]) - sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"]) - params = make_functional(sub_seq_1) - sub_seq_1(td_1, params=params) - assert "hidden" in td_1.keys() - assert "out" not in td_1.keys() - td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5]) - sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"]) - params = make_functional(sub_seq_2) - sub_seq_2(td_2, params=params) - assert "out" in td_2.keys() - assert td_2.get("out").shape == torch.Size([5, 4]) - else: - td_1 = TensorDict({"in": torch.randn(5, 3)}, [5]) - sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"]) - sub_seq_1(td_1) - assert "hidden" in td_1.keys() - assert "out" not in td_1.keys() - td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5]) - sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"]) - sub_seq_2(td_2) - assert "out" in td_2.keys() - assert td_2.get("out").shape == torch.Size([5, 4]) - - @pytest.mark.parametrize("stack", [True, False]) - @pytest.mark.parametrize("functional", [True, False]) - def test_sequential_partial(self, stack, functional): - torch.manual_seed(0) - param_multiplier = 2 - - net1 = nn.Linear(3, 4) - - net2 = nn.Linear(4, 4 * param_multiplier) - net2 = NormalParamWrapper(net2) - net2 = SafeModule(net2, in_keys=["b"], out_keys=["loc", "scale"]) - - net3 = nn.Linear(4, 4 * param_multiplier) - net3 = NormalParamWrapper(net3) - net3 = SafeModule(net3, in_keys=["c"], out_keys=["loc", "scale"]) - - spec = BoundedTensorSpec(-0.1, 0.1, 4) - - kwargs = {"distribution_class": TanhNormal} - - tdmodule1 = SafeModule( - net1, - in_keys=["a"], - out_keys=["hidden"], - spec=None, - safe=False, - ) - tdmodule2 = SafeProbabilisticTensorDictSequential( - net2, - SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=True, + safe=True, **kwargs, ), ) @@ -1417,11 +506,6 @@ def test_sequential_partial(self, stack, functional): tdmodule1, tdmodule2, tdmodule3, partial_tolerant=True ) - if functional: - params = make_functional(tdmodule) - else: - params = None - if stack: td = torch.stack( [ @@ -1430,10 +514,7 @@ def test_sequential_partial(self, stack, functional): ], 0, ) - if functional: - tdmodule(td, params=params) - else: - tdmodule(td) + tdmodule(td) assert "loc" in td.keys() assert "scale" in td.keys() assert "out" in td.keys() @@ -1444,10 +525,7 @@ def test_sequential_partial(self, stack, functional): assert "b" in td[0].keys() else: td = TensorDict({"a": torch.randn(3), "b": torch.randn(4)}, []) - if functional: - tdmodule(td, params=params) - else: - tdmodule(td) + tdmodule(td) assert "loc" in td.keys() assert "scale" in td.keys() assert "out" in td.keys() diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index fa3d28848a8..e779bfc165d 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -1145,7 +1145,6 @@ def _make_compatible_policy(policy, observation_spec, env=None, fast_wrap=False) ) try: - # signature modified by make_functional sig = policy.forward.__signature__ except AttributeError: sig = inspect.signature(policy.forward) diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 252b4fd2146..4a818474985 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -297,12 +297,11 @@ def _loss_actor( ) -> torch.Tensor: td_copy = tensordict.select(*self.actor_in_keys) # Get 
an action from the actor network: the actor parameters are swapped in with to_module
-        td_copy = self.actor_network(td_copy, params=self.actor_network_params)
+        with self.actor_network_params.to_module(self.actor_network):
+            td_copy = self.actor_network(td_copy)
         # get the value associated with that action
-        td_copy = self.value_network(
-            td_copy,
-            params=self.value_network_params.detach(),
-        )
+        with self.value_network_params.detach().to_module(self.value_network):
+            td_copy = self.value_network(td_copy)
         return -td_copy.get("state_action_value")
 
 
@@ -324,7 +323,8 @@ def _loss_value(
         td_copy = tensordict.clone()
 
         # V(s, a)
-        self.value_network(td_copy, params=self.value_network_params)
+        with self.value_network_params.to_module(self.value_network):
+            self.value_network(td_copy)
 
         pred_val = td_copy.get("state_action_value").squeeze(-1)
 
         # we manually reconstruct the parameters of the actor-critic, where the first
@@ -339,9 +339,8 @@
             batch_size=self.target_actor_network_params.batch_size,
             device=self.target_actor_network_params.device,
         )
-        target_value = self.value_estimator.value_estimate(
-            tensordict, target_params=target_params
-        ).squeeze(-1)
+        with target_params.to_module(self.value_estimator):
+            target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
 
         # Computes the value loss: L2, L1 or smooth L1 depending on `self.loss_function`
         loss_value = distance_loss(pred_val, target_value, loss_function=self.loss_function)
 
diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py
index ce3f0bb4b98..25213503e19 100644
--- a/tutorials/sphinx-tutorials/torchrl_demo.py
+++ b/tutorials/sphinx-tutorials/torchrl_demo.py
@@ -543,21 +543,30 @@
 # Functional Programming (Ensembling / Meta-RL)
 # ----------------------------------------------
 
-from tensordict.nn import make_functional
+from tensordict import TensorDict
 
-params = make_functional(sequence)
-len(list(sequence.parameters()))  # functional modules have no parameters
+params = TensorDict.from_module(sequence)
+print("extracted params", params)
 
 ###############################################################################
+# Functional call using tensordict:
 
-sequence(tensordict, params)
+with params.to_module(sequence):
+    sequence(tensordict)
 
 ###############################################################################
-
+# Using vectorized map for model ensembling
 from torch import vmap
 
 params_expand = params.expand(4)
-tensordict_exp = vmap(sequence, (None, 0))(tensordict, params_expand)
+
+
+def exec_sequence(params, data):
+    with params.to_module(sequence):
+        return sequence(data)
+
+
+tensordict_exp = vmap(exec_sequence, (0, None))(params_expand, tensordict)
 print(tensordict_exp)
 
 ###############################################################################
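-- 
For quick reference, the pattern this patch applies throughout, reduced to one
self-contained sketch. The module, key names and shapes below are illustrative
assumptions (nothing here is copied from the diff); the API calls it relies on
(TensorDict.from_module, to_module, vmap) are the same ones the patch migrates to.

    # Old convention (removed by this patch):
    #   params = make_functional(module); module(td, params=params)
    # New convention: extract params as a TensorDict, swap them in contextually.
    import torch
    from torch import nn, vmap

    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule

    # Illustrative module: any nn.Module wrapped as a TensorDictModule.
    module = TensorDictModule(nn.Linear(3, 4), in_keys=["in"], out_keys=["out"])
    data = TensorDict({"in": torch.randn(5, 3)}, [5])

    # The module stays a plain, stateful nn.Module; its parameters are simply
    # mirrored into a TensorDict.
    params = TensorDict.from_module(module)

    # Parameters are swapped into the module for the duration of the block,
    # replacing the legacy `module(td, params=params)` call convention.
    with params.to_module(module):
        module(data)
    assert data["out"].shape == torch.Size([5, 4])

    # Ensembling with vmap, as in the updated torchrl_demo.py: map over a
    # stack of parameter sets while broadcasting the input data.
    def exec_module(p, td):
        with p.to_module(module):
            return module(td)

    params_expand = params.expand(4)  # a stack of 4 (storage-sharing) parameter sets
    out = vmap(exec_module, (0, None))(params_expand, data)
    print(out)  # batch_size is now [4, 5]

The context-manager form keeps modules stateful by default and drops the
`params=` keyword that make_functional used to graft onto forward, which is why
the `# signature modified by make_functional` comment in torchrl/envs/utils.py
could be dropped.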