Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Nov 27, 2023
1 parent f3cc664 commit 58b0aae
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 17 deletions.
12 changes: 3 additions & 9 deletions torchrl/objectives/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,15 +179,9 @@ def convert_to_functional(
Args:
module (TensorDictModule or compatible): a stateful tensordict module.
This module will be made functional, yet still stateful, meaning
that it will be callable with the following alternative signatures:
>>> module(tensordict)
>>> module(tensordict, params=params)
``params`` is a :class:`tensordict.TensorDict` instance with parameters
structured as the output of :func:`tensordict.TensorDict.from_module`
is.
Parameters from this module will be isolated in the `<module_name>_params`
attribute and a stateless version of the module will be registered
under the `module_name` attribute.
module_name (str): name where the module will be found.
The parameters of the module will be found under ``loss_module.<module_name>_params``
whereas the module will be found under ``loss_module.<module_name>``.
Expand Down
7 changes: 4 additions & 3 deletions torchrl/objectives/decision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict = tensordict.clone(False)
target_actions = tensordict.get(self.tensor_keys.action_target).detach()

pred_actions = self.actor_network(
tensordict, params=self.actor_network_params
).get(self.tensor_keys.action_pred)
with self.actor_network_params.to_module(self.actor_network):
pred_actions = self.actor_network(tensordict).get(
self.tensor_keys.action_pred
)
loss = distance_loss(
pred_actions,
target_actions,
Expand Down
11 changes: 6 additions & 5 deletions torchrl/objectives/multiagent/qmixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,10 +327,10 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDict:
td_copy = tensordict.clone(False)
self.local_value_network(
td_copy,
params=self.local_value_network_params,
)
with self.local_value_network_params.to_module(self.local_value_network):
self.local_value_network(
td_copy,
)

action = tensordict.get(self.tensor_keys.action)
pred_val = td_copy.get(
Expand All @@ -347,7 +347,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
pred_val_index = (pred_val * action).sum(-1, keepdim=True)

td_copy.set(self.tensor_keys.local_value, pred_val_index) # [*B, n_agents, 1]
self.mixer_network(td_copy, params=self.mixer_network_params)
with self.mixer_network_params.to_module(self.mixer_network):
self.mixer_network(td_copy)
pred_val_index = td_copy.get(self.tensor_keys.global_value).squeeze(-1)
# [*B] this is global and shared among the agents as will be the target

Expand Down

0 comments on commit 58b0aae

Please sign in to comment.