diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py
index c136fca5f3f..93845bb00bd 100644
--- a/torchrl/objectives/td3_bc.py
+++ b/torchrl/objectives/td3_bc.py
@@ -32,7 +32,14 @@ class TD3BCLoss(LossModule):
     r"""TD3+BC Loss Module.
 
     Implementation of the TD3+BC loss presented in the paper `"A Minimalist Approach to
-    Offline Reinforcement Learning" `
+    Offline Reinforcement Learning" `.
+
+    This class incorporates two loss functions, executed sequentially within the `forward` method:
+
+    1. :meth:`~.qvalue_loss`
+    2. :meth:`~.actor_loss`
+
+    Users also have the option to call these functions directly in the same order if preferred.
 
     Args:
         actor_network (TensorDictModule): the actor to be trained
@@ -370,6 +377,17 @@ def _cached_stack_actor_params(self):
         )
 
     def actor_loss(self, tensordict):
+        """Compute the actor loss.
+
+        The actor loss should be computed after the :meth:`~.qvalue_loss` and is usually delayed 1-3 critic updates.
+
+        Args:
+            tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields
+                are required for this to be computed.
+        Returns: a differentiable tensor with the actor loss along with a metadata dictionary containing the detached `"bc_loss"`
+            used in the combined actor loss as well as the detached `"state_action_value_actor"` used to calculate the lambda
+            value, and the lambda value `"lmbd"` itself.
+        """
         tensordict_actor_grad = tensordict.select(
             *self.actor_network.in_keys, strict=False
         )
@@ -398,14 +416,24 @@ def actor_loss(self, tensordict):
         loss_actor = -lmbd * state_action_value_actor[0] + bc_loss
 
         metadata = {
-            "state_action_value_actor": state_action_value_actor.detach(),
-            "bc_loss": bc_loss,
+            "state_action_value_actor": state_action_value_actor[0].detach(),
+            "bc_loss": bc_loss.detach(),
             "lmbd": lmbd,
         }
         loss_actor = _reduce(loss_actor, reduction=self.reduction)
         return loss_actor, metadata
 
-    def value_loss(self, tensordict):
+    def qvalue_loss(self, tensordict):
+        """Compute the q-value loss.
+
+        The q-value loss should be computed before the :meth:`~.actor_loss`.
+
+        Args:
+            tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields
+                are required for this to be computed.
+        Returns: a differentiable tensor with the qvalue loss along with a metadata dictionary containing
+            the detached `"td_error"` to be used for prioritized sampling, the detached `"next_state_value"`, the detached `"pred_value"`, and the detached `"target_value"`.
+        """
         tensordict = tensordict.clone(False)
 
         act = tensordict.get(self.tensor_keys.action)
@@ -484,9 +512,16 @@ def value_loss(self, tensordict):
 
     @dispatch
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        """The forward method.
+
+        Computes successively the :meth:`~.actor_loss`, :meth:`~.qvalue_loss`, and returns
+        a tensordict with these values.
+        To see what keys are expected in the input tensordict and what keys are expected as output, check the
+        class's `"in_keys"` and `"out_keys"` attributes.
+        """
         tensordict_save = tensordict
         loss_actor, metadata_actor = self.actor_loss(tensordict)
-        loss_qval, metadata_value = self.value_loss(tensordict_save)
+        loss_qval, metadata_value = self.qvalue_loss(tensordict_save)
         tensordict_save.set(
             self.tensor_keys.priority, metadata_value.pop("td_error").detach().max(0)[0]
         )
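
For context, the sketch below exercises the two losses documented above, either through `forward` or by calling `qvalue_loss` followed by `actor_loss` directly, in the order the new docstrings describe. It is a minimal illustration under stated assumptions, not code from this patch: the toy networks, tensordict keys, and the `Bounded` spec (exposed as `BoundedTensorSpec` in older torchrl releases) are placeholders chosen for the example.

# Minimal TD3BCLoss usage sketch (illustrative only; shapes and key names are assumptions).
import torch
from tensordict import TensorDict
from torch import nn

from torchrl.data import Bounded
from torchrl.modules import Actor, ValueOperator
from torchrl.objectives.td3_bc import TD3BCLoss

n_obs, n_act, batch = 3, 4, 2

# Deterministic policy constrained to a bounded action range.
action_spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
actor = Actor(nn.Linear(n_obs, n_act), in_keys=["observation"], spec=action_spec)


# State-action value network wrapped so it reads "observation" and "action".
class QValueNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(n_obs + n_act, 1)

    def forward(self, obs, act):
        return self.net(torch.cat([obs, act], dim=-1))


qvalue = ValueOperator(QValueNet(), in_keys=["observation", "action"])
loss_module = TD3BCLoss(actor, qvalue, action_spec=action_spec)

# A toy offline batch with the fields the loss expects (see TD3BCLoss.in_keys).
data = TensorDict(
    {
        "observation": torch.randn(batch, n_obs),
        "action": action_spec.rand((batch,)),
        ("next", "observation"): torch.randn(batch, n_obs),
        ("next", "reward"): torch.randn(batch, 1),
        ("next", "done"): torch.zeros(batch, 1, dtype=torch.bool),
        ("next", "terminated"): torch.zeros(batch, 1, dtype=torch.bool),
    },
    batch_size=[batch],
)

# Option 1: run both losses through forward().
losses = loss_module(data)

# Option 2: call the losses directly, q-value loss first, then the (typically delayed) actor loss.
loss_qval, value_metadata = loss_module.qvalue_loss(data)
loss_actor, actor_metadata = loss_module.actor_loss(data)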