From 07fcfb1cff2e897fff406ef45c2e5ff7ea57a14f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 27 Nov 2023 15:12:44 +0000 Subject: [PATCH] [Minor] Hide params in ddpg actor-critic (#1716) --- torchrl/objectives/common.py | 12 +++--------- torchrl/objectives/ddpg.py | 2 +- torchrl/objectives/decision_transformer.py | 7 ++++--- torchrl/objectives/multiagent/qmixer.py | 13 +++++++------ 4 files changed, 15 insertions(+), 19 deletions(-) diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 76e5ef10900..00ba8cf456a 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -179,15 +179,9 @@ def convert_to_functional( Args: module (TensorDictModule or compatible): a stateful tensordict module. - This module will be made functional, yet still stateful, meaning - that it will be callable with the following alternative signatures: - - >>> module(tensordict) - >>> module(tensordict, params=params) - - ``params`` is a :class:`tensordict.TensorDict` instance with parameters - stuctured as the output of :func:`tensordict.TensorDict.from_module` - is. + Parameters from this module will be isolated in the `_params` + attribute and a stateless version of the module will be registered + under the `module_name` attribute. module_name (str): name where the module will be found. The parameters of the module will be found under ``loss_module._params`` whereas the module will be found under ``loss_module.``. 
diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index a72b84f69e4..3b4debe6259 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -200,7 +200,7 @@ def __init__( params = TensorDict.from_module(actor_critic) params_meta = params.apply(self._make_meta_params, device=torch.device("meta")) with params_meta.to_module(actor_critic): - self.actor_critic = deepcopy(actor_critic) + self.__dict__["actor_critic"] = deepcopy(actor_critic) self.convert_to_functional( actor_network, diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index ba7e2d4ba3f..52339d583dd 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -317,9 +317,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = tensordict.clone(False) target_actions = tensordict.get(self.tensor_keys.action_target).detach() - pred_actions = self.actor_network( - tensordict, params=self.actor_network_params - ).get(self.tensor_keys.action_pred) + with self.actor_network_params.to_module(self.actor_network): + pred_actions = self.actor_network(tensordict).get( + self.tensor_keys.action_pred + ) loss = distance_loss( pred_actions, target_actions, diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index 35e03d35744..23947696c9f 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -216,7 +216,7 @@ def __init__( with params.apply( self._make_meta_params, device=torch.device("meta") ).to_module(global_value_network): - self.global_value_network = deepcopy(global_value_network) + self.__dict__["global_value_network"] = deepcopy(global_value_network) self.convert_to_functional( local_value_network, @@ -327,10 +327,10 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams @dispatch def forward(self, tensordict: TensorDictBase) -> 
TensorDict: td_copy = tensordict.clone(False) - self.local_value_network( - td_copy, - params=self.local_value_network_params, - ) + with self.local_value_network_params.to_module(self.local_value_network): + self.local_value_network( + td_copy, + ) action = tensordict.get(self.tensor_keys.action) pred_val = td_copy.get( @@ -347,7 +347,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: pred_val_index = (pred_val * action).sum(-1, keepdim=True) td_copy.set(self.tensor_keys.local_value, pred_val_index) # [*B, n_agents, 1] - self.mixer_network(td_copy, params=self.mixer_network_params) + with self.mixer_network_params.to_module(self.mixer_network): + self.mixer_network(td_copy) pred_val_index = td_copy.get(self.tensor_keys.global_value).squeeze(-1) # [*B] this is global and shared among the agents as will be the target