diff --git a/test/test_cost.py b/test/test_cost.py
index 41fc0781a01..60d1b1e374f 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -514,7 +514,7 @@ def test_dqn(self, delay_value, double_dqn, device, action_spec_type, td_est):

         assert loss_fn.tensor_keys.priority in td.keys()

-        sum([item for _, item in loss.items()]).backward()
+        sum([item for name, item in loss.items() if name.startswith("loss")]).backward()
         assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0

         # Check param update effect on targets
@@ -581,15 +581,21 @@ def test_dqn_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9):
         loss = loss_fn(td)
         if n == 0:
             assert_allclose_td(td, ms_td.select(*td.keys(True, True)))
-            _loss = sum([item for _, item in loss.items()])
-            _loss_ms = sum([item for _, item in loss_ms.items()])
+            _loss = sum(
+                [item for name, item in loss.items() if name.startswith("loss")]
+            )
+            _loss_ms = sum(
+                [item for name, item in loss_ms.items() if name.startswith("loss")]
+            )
             assert (
                 abs(_loss - _loss_ms) < 1e-3
             ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0"
         else:
             with pytest.raises(AssertionError):
                 assert_allclose_td(loss, loss_ms)
-        sum([item for _, item in loss_ms.items()]).backward()
+        sum(
+            [item for name, item in loss_ms.items() if name.startswith("loss")]
+        ).backward()
         assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0

         # Check param update effect on targets
@@ -727,7 +733,7 @@ def test_distributional_dqn(

         assert loss_fn.tensor_keys.priority in td.keys()

-        sum([item for _, item in loss.items()]).backward()
+        sum([item for name, item in loss.items() if name.startswith("loss")]).backward()
         assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0

         if delay_value:
@@ -875,6 +881,8 @@ def test_dqn_reduction(self, reduction):
             assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
+                if not key.startswith("loss"):
+                    continue
                 assert loss[key].shape == torch.Size([])

     @pytest.mark.parametrize("atoms", range(4, 10))
@@ -901,6 +909,8 @@ def test_distributional_dqn_reduction(self, reduction, atoms):
             assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
+                if not key.startswith("loss"):
+                    continue
                 assert loss[key].shape == torch.Size([])

@@ -1065,7 +1075,7 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est):
         loss = loss_fn(td)
         assert loss_fn.tensor_keys.priority in td.keys()

-        sum([item for _, item in loss.items()]).backward()
+        sum([item for name, item in loss.items() if name.startswith("loss")]).backward()
         assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0

         if delay_value:
@@ -1150,15 +1160,21 @@ def test_qmix_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9)
         loss = loss_fn(td)
         if n == 0:
             assert_allclose_td(td, ms_td.select(*td.keys(True, True)))
-            _loss = sum([item for _, item in loss.items()])
-            _loss_ms = sum([item for _, item in loss_ms.items()])
+            _loss = sum(
+                [item for name, item in loss.items() if name.startswith("loss")]
+            )
+            _loss_ms = sum(
+                [item for name, item in loss_ms.items() if name.startswith("loss")]
+            )
             assert (
                 abs(_loss - _loss_ms) < 1e-3
             ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0"
         else:
             with pytest.raises(AssertionError):
                 assert_allclose_td(loss, loss_ms)
-        sum([item for _, item in loss_ms.items()]).backward()
+        sum(
+            [item for name, item in loss_ms.items() if name.startswith("loss")]
+        ).backward()
         assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0

         # Check param update effect on targets
@@ -1604,7 +1620,9 @@ def test_ddpg(self, delay_actor, delay_value, device, td_est):
         loss_fn.zero_grad()

         # check overall grad
-        sum([item for _, item in loss.items()]).backward()
+        sum(
+            [item for name, item in loss.items() if name.startswith("loss_")]
+        ).backward()
         parameters = list(actor.parameters()) + list(value.parameters())
         for p in parameters:
             assert p.grad.norm() > 0.0
@@ -1816,15 +1834,21 @@ def test_ddpg_batcher(self, n, delay_actor, delay_value, device, gamma=0.9):
         loss = loss_fn(td)
         if n == 0:
             assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
-            _loss = sum([item for _, item in loss.items()])
-            _loss_ms = sum([item for _, item in loss_ms.items()])
+            _loss = sum(
+                [item for name, item in loss.items() if name.startswith("loss_")]
+            )
+            _loss_ms = sum(
+                [item for name, item in loss_ms.items() if name.startswith("loss_")]
+            )
             assert (
                 abs(_loss - _loss_ms) < 1e-3
             ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0"
         else:
             with pytest.raises(AssertionError):
                 assert_allclose_td(loss, loss_ms)
-        sum([item for _, item in loss_ms.items()]).backward()
+        sum(
+            [item for name, item in loss_ms.items() if name.startswith("loss_")]
+        ).backward()
         parameters = list(actor.parameters()) + list(value.parameters())
         for p in parameters:
             assert p.grad.norm() > 0.0
@@ -1971,6 +1995,8 @@ def test_ddpg_reduction(self, reduction):
             assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
+                if not key.startswith("loss_"):
+                    continue
                 assert loss[key].shape == torch.Size([])

@@ -2259,7 +2285,9 @@ def test_td3(
                 raise NotImplementedError(k)
         loss_fn.zero_grad()

-        sum([item for _, item in loss.items()]).backward()
+        sum(
+            [item for name, item in loss.items() if name.startswith("loss_")]
+        ).backward()
         named_parameters = list(loss_fn.named_parameters())
         named_buffers = list(loss_fn.named_buffers())

@@ -2453,8 +2481,12 @@ def test_td3_batcher(
         if n == 0:
             assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))

-            _loss = sum([item for _, item in loss.items()])
-            _loss_ms = sum([item for _, item in loss_ms.items()])
+            _loss = sum(
+                [item for name, item in loss.items() if name.startswith("loss_")]
+            )
+            _loss_ms = sum(
+                [item for name, item in loss_ms.items() if name.startswith("loss_")]
+            )
             assert (
                 abs(_loss - _loss_ms) < 1e-3
             ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0"
@@ -2462,7 +2494,9 @@ def test_td3_batcher(
             with pytest.raises(AssertionError):
                 assert_allclose_td(loss, loss_ms)

-        sum([item for _, item in loss_ms.items()]).backward()
+        sum(
+            [item for name, item in loss_ms.items() if name.startswith("loss_")]
+        ).backward()
         named_parameters = loss_fn.named_parameters()

         for name, p in named_parameters:
@@ -2620,11 +2654,8 @@ def test_td3_notensordict(
             loss_val_td = loss(td)
             torch.manual_seed(0)
             loss_val = loss(**kwargs)
-            for i in loss_val:
-                assert i in loss_val_td.values(), f"{i} not in {loss_val_td.values()}"
-
-            for i, key in enumerate(loss.out_keys):
-                torch.testing.assert_close(loss_val_td.get(key), loss_val[i])
+            loss_val_reconstruct = TensorDict(dict(zip(loss.out_keys, loss_val)), [])
+            assert_allclose_td(loss_val_reconstruct, loss_val_td)

             # test select
             loss.select_out_keys("loss_actor", "loss_qvalue")
@@ -2673,6 +2704,8 @@ def test_td3_reduction(self, reduction):
             assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
+                if not key.startswith("loss"):
+                    continue
                 assert loss[key].shape == torch.Size([])

@@ -3033,7 +3066,9 @@ def test_sac(
                 raise NotImplementedError(k)
         loss_fn.zero_grad()

-        sum([item for _, item in loss.items()]).backward()
+        sum(
+            [item for name, item in loss.items() if name.startswith("loss_")]
+        ).backward()
         named_parameters = list(loss_fn.named_parameters())
         named_buffers = list(loss_fn.named_buffers())

@@ -3259,15 +3294,21 @@ def test_sac_batcher(
         loss = loss_fn(td)
         if n == 0:
             assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
-            _loss = sum([item for _, item in loss.items()])
-            _loss_ms = sum([item for _, item in loss_ms.items()])
+            _loss = sum(
+                [item for name, item in loss.items() if name.startswith("loss_")]
+            )
+            _loss_ms = sum(
+                [item for name, item in loss_ms.items() if name.startswith("loss_")]
+            )
             assert (
                 abs(_loss - _loss_ms) < 1e-3
             ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0"
         else:
             with pytest.raises(AssertionError):
                 assert_allclose_td(loss, loss_ms)
-        sum([item for _, item in loss_ms.items()]).backward()
+        sum(
+            [item for name, item in loss_ms.items() if name.startswith("loss_")]
+        ).backward()
         named_parameters = loss_fn.named_parameters()
         for name, p in named_parameters:
             if not name.startswith("target_"):
@@ -3576,6 +3617,8 @@ def test_sac_reduction(self, reduction, version):
             assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
+                if not key.startswith("loss"):
+                    continue
                 assert loss[key].shape == torch.Size([])

@@ -3818,7 +3861,9 @@ def test_discrete_sac(
                 raise NotImplementedError(k)
         loss_fn.zero_grad()

-        sum([item for _, item in loss.items()]).backward()
+        sum(
+            [item for name, item in loss.items() if name.startswith("loss_")]
+        ).backward()
         named_parameters = list(loss_fn.named_parameters())
         named_buffers = list(loss_fn.named_buffers())

@@ -3940,15 +3985,21 @@ def test_discrete_sac_batcher(
         loss = loss_fn(td)
         if n == 0:
             assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
-            _loss = sum([item for _, item in loss.items()])
-            _loss_ms = sum([item for _, item in loss_ms.items()])
+            _loss = sum(
+                [item for name, item in loss.items() if name.startswith("loss_")]
+            )
+            _loss_ms = sum(
+                [item for name, item in loss_ms.items() if name.startswith("loss_")]
+            )
             assert (
                 abs(_loss - _loss_ms) < 1e-3
             ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0"
         else:
             with pytest.raises(AssertionError):
                 assert_allclose_td(loss, loss_ms)
-        sum([item for _, item in loss_ms.items()]).backward()
+        sum(
+            [item for name, item in loss_ms.items() if name.startswith("loss_")]
+        ).backward()
         named_parameters = loss_fn.named_parameters()
         for name, p in named_parameters:
             if not name.startswith("target_"):
@@ -4155,6 +4206,8 @@ def test_discrete_sac_reduction(self, reduction):
             assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
+                if not key.startswith("loss"):
+                    continue
                 assert loss[key].shape == torch.Size([])

@@ -4474,7 +4527,9 @@ def test_redq(self, delay_qvalue, num_qvalue, device, td_est):
                 raise NotImplementedError(k)
         loss_fn.zero_grad()

-        sum([item for _, item in loss.items()]).backward()
+        sum(
+            [item for name, item in loss.items() if name.startswith("loss_")]
+        ).backward()
         named_parameters = list(loss_fn.named_parameters())
         named_buffers = list(loss_fn.named_buffers())

@@ -4861,15 +4916,21 @@ def test_redq_batcher(self, n, delay_qvalue, num_qvalue, device, gamma=0.9):
         loss = loss_fn(td)
         if n == 0:
             assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
-            _loss = sum([item for _, item in loss.items()])
-            _loss_ms = sum([item for _, item in loss_ms.items()])
+            _loss = sum(
+                [item for name, item in loss.items() if name.startswith("loss_")]
+            )
+            _loss_ms = sum(
+                [item for name, item in loss_ms.items() if name.startswith("loss_")]
+            )
             assert (
                 abs(_loss - _loss_ms) < 1e-3
             ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0"
         else:
             with pytest.raises(AssertionError):
                 assert_allclose_td(loss, loss_ms)
-        sum([item for _, item in loss_ms.items()]).backward()
+        sum(
+            [item for name, item in loss_ms.items() if name.startswith("loss_")]
+        ).backward()
         named_parameters = loss_fn.named_parameters()
         for name, p in named_parameters:
             if not name.startswith("target_"):
@@ -5099,6 +5160,8 @@ def test_redq_reduction(self, reduction, deprecated_loss):
             assert loss[key].shape[-1] == td.shape[0]
         else:
             for key in loss.keys():
+                if not key.startswith("loss"):
+                    continue
                 assert loss[key].shape == torch.Size([])

@@ -5350,7 +5413,9 @@ def test_cql(
             )
         )

-        sum([item for _, item in loss.items()]).backward()
+        sum(
+            [item for name, item in loss.items() if name.startswith("loss_")]
+        ).backward()
         named_parameters = list(loss_fn.named_parameters())
         named_buffers = list(loss_fn.named_buffers())

@@ -5474,15 +5539,21 @@ def test_cql_batcher(
         loss = loss_fn(td)
         if n == 0:
             assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
-            _loss = sum([item for _, item in loss.items()])
-            _loss_ms = sum([item for _, item in loss_ms.items()])
+            _loss = sum(
+                [item for name, item in loss.items() if name.startswith("loss_")]
+            )
+            _loss_ms = sum(
+                [item for name, item in loss_ms.items() if name.startswith("loss_")]
+            )
             assert (
                 abs(_loss - _loss_ms) < 1e-3
             ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0"
         else:
             with pytest.raises(AssertionError):
                 assert_allclose_td(loss, loss_ms)
-        sum([item for _, item in loss_ms.items()]).backward()
+        sum(
+            [item for name, item in loss_ms.items() if name.startswith("loss_")]
+        ).backward()
         named_parameters = loss_fn.named_parameters()
         for name, p in named_parameters:
             if not name.startswith("target_"):
@@ -8914,7 +8985,9 @@ def test_iql(
                 raise NotImplementedError(k)
         loss_fn.zero_grad()

-        sum([item for _, item in loss.items()]).backward()
+        sum(
+            [item for name, item in loss.items() if name.startswith("loss_")]
+        ).backward()
         named_parameters = list(loss_fn.named_parameters())
         named_buffers = list(loss_fn.named_buffers())

@@ -9181,15 +9254,21 @@ def test_iql_batcher(
         loss = loss_fn(td)
         if n == 0:
             assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
-            _loss = sum([item for _, item in loss.items()])
-            _loss_ms = sum([item for _, item in loss_ms.items()])
+            _loss = sum(
+                [item for name, item in loss.items() if name.startswith("loss_")]
+            )
+            _loss_ms = sum(
+                [item for name, item in loss_ms.items() if name.startswith("loss_")]
+            )
             assert (
                 abs(_loss - _loss_ms) < 1e-3
             ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0"
         else:
             with pytest.raises(AssertionError):
                 assert_allclose_td(loss, loss_ms)
-        sum([item for _, item in loss_ms.items()]).backward()
+        sum(
+            [item for name, item in loss_ms.items() if name.startswith("loss_")]
+        ).backward()
         named_parameters = loss_fn.named_parameters()
         for name, p in named_parameters:
             if not name.startswith("target_"):
@@ -9679,7 +9758,9 @@ def test_discrete_iql(
                 raise NotImplementedError(k)
         loss_fn.zero_grad()

-        sum([item for _, item in loss.items()]).backward()
+        sum(
+            [item for name, item in loss.items() if name.startswith("loss_")]
+        ).backward()
         named_parameters = list(loss_fn.named_parameters())
         named_buffers = list(loss_fn.named_buffers())

@@ -9949,15 +10030,21 @@ def test_discrete_iql_batcher(
         loss = loss_fn(td)
         if n == 0:
             assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
-            _loss = sum([item for _, item in loss.items()])
-            _loss_ms = sum([item for _, item in loss_ms.items()])
+            _loss = sum(
+                [item for name, item in loss.items() if name.startswith("loss_")]
+            )
+            _loss_ms = sum(
+                [item for name, item in loss_ms.items() if name.startswith("loss_")]
+            )
             assert (
                 abs(_loss - _loss_ms) < 1e-3
             ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0"
         else:
             with pytest.raises(AssertionError):
                 assert_allclose_td(loss, loss_ms)
-        sum([item for _, item in loss_ms.items()]).backward()
+        sum(
+            [item for name, item in loss_ms.items() if name.startswith("loss_")]
+        ).backward()
         named_parameters = loss_fn.named_parameters()
         for name, p in named_parameters:
             if not name.startswith("target_"):
diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py
index 5024068889f..1846db4989a 100644
--- a/torchrl/objectives/a2c.py
+++ b/torchrl/objectives/a2c.py
@@ -3,7 +3,6 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import contextlib
-import functools
 import warnings
 from copy import deepcopy
 from dataclasses import dataclass
@@ -472,13 +471,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         td_out = TensorDict({"loss_objective": loss}, batch_size=[])
         if self.entropy_bonus:
             entropy = self.get_entropy_bonus(dist)
-            td_out.set("entropy", entropy.detach())  # for logging
+            td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
         if self.critic_coef:
             loss_critic = self.loss_critic(tensordict)
             td_out.set("loss_critic", loss_critic)
-        td_out = td_out.apply(
-            functools.partial(_reduce, reduction=self.reduction), batch_size=[]
+        td_out = td_out.named_apply(
+            lambda name, value: _reduce(value, reduction=self.reduction)
+            if name.startswith("loss_")
+            else value,
+            batch_size=[],
         )
         return td_out

diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py
index 2ecdfde6bb3..ba78e5f193a 100644
--- a/torchrl/objectives/ddpg.py
+++ b/torchrl/objectives/ddpg.py
@@ -5,7 +5,6 @@

 from __future__ import annotations

-import functools
 from copy import deepcopy
 from dataclasses import dataclass
 from typing import Tuple
@@ -296,9 +295,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             source={"loss_actor": loss_actor, "loss_value": loss_value, **metadata},
             batch_size=[],
         )
-        td_out = td_out.apply(
-            functools.partial(_reduce, reduction=self.reduction), batch_size=[]
-        )
         return td_out

     def loss_actor(
@@ -314,6 +310,7 @@ def loss_actor(
         td_copy = self.value_network(td_copy)
         loss_actor = -td_copy.get(self.tensor_keys.state_action_value).squeeze(-1)
         metadata = {}
+        loss_actor = _reduce(loss_actor, self.reduction)
         return loss_actor, metadata

     def loss_value(
@@ -352,6 +349,7 @@ def loss_value(
             "target_value_max": target_value.max(),
             "pred_value_max": pred_val.max(),
         }
+        loss_value = _reduce(loss_value, self.reduction)
         return loss_value, metadata

     def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py
index 954bd0b9a42..ec6ed4f5252 100644
--- a/torchrl/objectives/decision_transformer.py
+++ b/torchrl/objectives/decision_transformer.py
@@ -220,7 +220,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             "loss_log_likelihood": -log_likelihood,
             "loss_entropy": -entropy_bonus,
             "loss_alpha": loss_alpha,
-            "entropy": entropy.detach(),
+            "entropy": entropy.detach().mean(),
             "alpha": self.alpha.detach(),
         }
         return TensorDict(out, [])
diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py
index d08edee71bd..9eb7d9a07e3 100644
--- a/torchrl/objectives/deprecated.py
+++ b/torchrl/objectives/deprecated.py
@@ -2,7 +2,6 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import functools
 import math
 from dataclasses import dataclass
 from numbers import Number
@@ -312,12 +311,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
                 "loss_qvalue": loss_qval,
                 "loss_alpha": loss_alpha,
                 "alpha": self.alpha,
-                "entropy": -sample_log_prob.detach(),
+                "entropy": -sample_log_prob.detach().mean(),
             },
             [],
         )
-        td_out = td_out.apply(
-            functools.partial(_reduce, reduction=self.reduction), batch_size=[]
+        td_out = td_out.named_apply(
+            lambda name, value: _reduce(value, reduction=self.reduction)
+            if name.startswith("loss_")
+            else value,
+            batch_size=[],
         )
         return td_out
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index 04756a94c02..d871098cbb8 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -555,7 +555,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         td_out = TensorDict({"loss_objective": -neg_loss}, batch_size=[])
         if self.entropy_bonus:
             entropy = self.get_entropy_bonus(dist)
-            td_out.set("entropy", entropy.detach())  # for logging
+            td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
         if self.critic_coef:
             loss_critic = self.loss_critic(tensordict)
             td_out.set("loss_critic", loss_critic)
@@ -799,15 +799,18 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:

         if self.entropy_bonus:
             entropy = self.get_entropy_bonus(dist)
-            td_out.set("entropy", entropy.detach())  # for logging
+            td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
         if self.critic_coef:
             loss_critic = self.loss_critic(tensordict)
             td_out.set("loss_critic", loss_critic)
         td_out.set("ESS", _reduce(ess, self.reduction) / batch)
-        td_out = td_out.apply(
-            functools.partial(_reduce, reduction=self.reduction), batch_size=[]
+        td_out = td_out.named_apply(
+            lambda name, value: _reduce(value, reduction=self.reduction)
+            if name.startswith("loss_")
+            else value,
+            batch_size=[],
         )
         return td_out
@@ -1061,13 +1064,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:

         if self.entropy_bonus:
             entropy = self.get_entropy_bonus(dist)
-            td_out.set("entropy", entropy.detach())  # for logging
+            td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
         if self.critic_coef:
             loss_critic = self.loss_critic(tensordict_copy)
             td_out.set("loss_critic", loss_critic)
-        td_out = td_out.apply(
-            functools.partial(_reduce, reduction=self.reduction), batch_size=[]
+        td_out = td_out.named_apply(
+            lambda name, value: _reduce(value, reduction=self.reduction)
+            if name.startswith("loss_")
+            else value,
+            batch_size=[],
         )
         return td_out
diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py
index 1c4e8785240..817483a0269 100644
--- a/torchrl/objectives/redq.py
+++ b/torchrl/objectives/redq.py
@@ -2,7 +2,6 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import functools
 import math
 from dataclasses import dataclass
 from numbers import Number
@@ -564,7 +563,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
                 "loss_qvalue": loss_qval,
                 "loss_alpha": loss_alpha,
                 "alpha": self.alpha.detach(),
-                "entropy": -sample_log_prob.detach(),
+                "entropy": -sample_log_prob.detach().mean(),
                 "state_action_value_actor": state_action_value_actor.detach(),
                 "action_log_prob_actor": action_log_prob_actor.detach(),
                 "next.state_value": next_state_value.detach(),
@@ -572,8 +571,11 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             },
             [],
         )
-        td_out = td_out.apply(
-            functools.partial(_reduce, reduction=self.reduction), batch_size=[]
+        td_out = td_out.named_apply(
+            lambda name, value: _reduce(value, reduction=self.reduction)
+            if name.startswith("loss_")
+            else value,
+            batch_size=[],
         )
         return td_out
diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py
index eb45a7c106e..37b595f7cb7 100644
--- a/torchrl/objectives/reinforce.py
+++ b/torchrl/objectives/reinforce.py
@@ -5,7 +5,6 @@
 from __future__ import annotations

 import contextlib
-import functools
 import warnings
 from copy import deepcopy
 from dataclasses import dataclass
@@ -402,8 +401,11 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:

         td_out = TensorDict({"loss_actor": loss_actor}, batch_size=[])
         td_out.set("loss_value", self.loss_critic(tensordict))
-        td_out = td_out.apply(
-            functools.partial(_reduce, reduction=self.reduction), batch_size=[]
+        td_out = td_out.named_apply(
+            lambda name, value: _reduce(value, reduction=self.reduction)
+            if name.startswith("loss_")
+            else value,
+            batch_size=[],
         )
         return td_out
diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py
index c80b56ae77b..277d068ca3e 100644
--- a/torchrl/objectives/sac.py
+++ b/torchrl/objectives/sac.py
@@ -2,7 +2,6 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import functools
 import math
 import warnings
 from dataclasses import dataclass
@@ -573,13 +572,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             "loss_qvalue": loss_qvalue,
             "loss_alpha": loss_alpha,
             "alpha": self._alpha,
-            "entropy": entropy,
+            "entropy": entropy.detach().mean(),
         }
         if self._version == 1:
             out["loss_value"] = loss_value
         td_out = TensorDict(out, [])
-        td_out = td_out.apply(
-            functools.partial(_reduce, reduction=self.reduction), batch_size=[]
+        td_out = td_out.named_apply(
+            lambda name, value: _reduce(value, reduction=self.reduction)
+            if name.startswith("loss_")
+            else value,
+            batch_size=[],
         )
         return td_out
@@ -1134,11 +1136,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             "loss_qvalue": loss_value,
             "loss_alpha": loss_alpha,
             "alpha": self._alpha,
-            "entropy": entropy,
+            "entropy": entropy.detach().mean(),
         }
         td_out = TensorDict(out, [])
-        td_out = td_out.apply(
-            functools.partial(_reduce, reduction=self.reduction), batch_size=[]
+        td_out = td_out.named_apply(
+            lambda name, value: _reduce(value, reduction=self.reduction)
+            if name.startswith("loss_")
+            else value,
+            batch_size=[],
         )
         return td_out
diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py
index d6f4d2c10c8..ee14eb66fea 100644
--- a/torchrl/objectives/td3.py
+++ b/torchrl/objectives/td3.py
@@ -2,7 +2,6 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import functools
 from dataclasses import dataclass
 from typing import Optional, Tuple

@@ -375,6 +374,7 @@ def actor_loss(self, tensordict):
         metadata = {
             "state_action_value_actor": state_action_value_actor.detach(),
         }
+        loss_actor = _reduce(loss_actor, reduction=self.reduction)
         return loss_actor, metadata

     def value_loss(self, tensordict):
@@ -449,6 +449,7 @@ def value_loss(self, tensordict):
             "pred_value": current_qvalue.detach(),
             "target_value": target_value.detach(),
         }
+        loss_qval = _reduce(loss_qval, reduction=self.reduction)
         return loss_qval, metadata

     @dispatch
@@ -472,9 +473,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             },
             batch_size=[],
         )
-        td_out = td_out.apply(
-            functools.partial(_reduce, reduction=self.reduction), batch_size=[]
-        )
         return td_out

     def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
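
Note (not part of the patch): the convention this diff converges on is that a loss module reduces only the differentiable "loss_*" entries of its output TensorDict, via TensorDict.named_apply, while diagnostics such as "entropy" or "alpha" are stored as detached scalars for logging; consumers then sum only the "loss_*" entries before calling backward(), exactly as the updated tests do. The following is a minimal, self-contained sketch of that convention; reduce_losses and the key names are illustrative stand-ins, not torchrl API.

import torch
from tensordict import TensorDict


def reduce_losses(td_out: TensorDict, reduction: str = "mean") -> TensorDict:
    # Mirror the named_apply call used throughout the patch: reduce only the
    # keys prefixed with "loss_"; logging-only entries pass through unchanged.
    def _reduce(value: torch.Tensor) -> torch.Tensor:
        return value.mean() if reduction == "mean" else value.sum()

    return td_out.named_apply(
        lambda name, value: _reduce(value) if name.startswith("loss_") else value,
        batch_size=[],
    )


per_sample_actor = torch.randn(16, requires_grad=True)
per_sample_value = torch.randn(16, requires_grad=True)
td_out = TensorDict(
    {
        "loss_actor": per_sample_actor.pow(2),  # differentiable, per-sample
        "loss_value": per_sample_value.pow(2),  # differentiable, per-sample
        "entropy": torch.randn(16).mean(),  # detached diagnostic, logging only
    },
    batch_size=[],
)
td_out = reduce_losses(td_out)

# Consumer side, as in the updated tests: back-propagate only through the
# "loss_*" entries so detached diagnostics never enter the graph.
sum(v for k, v in td_out.items() if k.startswith("loss_")).backward()
assert per_sample_actor.grad is not None

Filtering by key prefix is what lets the patch delete the blanket td_out.apply(_reduce, ...) calls: summing every entry would otherwise mix non-differentiable logging values (alpha, entropy, ESS) into the backward pass, which is precisely what the test changes guard against.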