
Commit

update docstrings
BY571 committed Jul 9, 2024
1 parent 5476e36 commit 235ca51
Showing 1 changed file with 40 additions and 5 deletions.
45 changes: 40 additions & 5 deletions torchrl/objectives/td3_bc.py
@@ -32,7 +32,14 @@ class TD3BCLoss(LossModule):
r"""TD3+BC Loss Module.
Implementation of the TD3+BC loss presented in the paper `"A Minimalist Approach to
Offline Reinforcement Learning" <https://arxiv.org/pdf/2106.06860>`
Offline Reinforcement Learning" <https://arxiv.org/pdf/2106.06860>`.
This class incorporates two loss functions, executed sequentially within the `forward` method:
1. :meth:`~.qvalue_loss`
2. :meth:`~.actor_loss`
Users also have the option to call these functions directly in the same order if preferred.
Args:
actor_network (TensorDictModule): the actor to be trained
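Editor's note: as a rough illustration of the call pattern this new class docstring describes (both losses via `forward`, or the two methods called directly with the q-value loss first), here is a hedged usage sketch. The network sizes, the `MLP`/`ValueOperator` wrappers, the `action_spec` construction, and the fake batch are illustrative assumptions, not taken from this commit.

```python
# Hedged usage sketch for TD3BCLoss; shapes, wrappers and the fake batch are
# illustrative assumptions, not taken from this commit.
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.data import BoundedTensorSpec
from torchrl.modules import MLP, ValueOperator
from torchrl.objectives import TD3BCLoss

obs_dim, act_dim, batch_size = 4, 2, 8
# deterministic actor: observation -> action
actor = TensorDictModule(
    MLP(in_features=obs_dim, out_features=act_dim, num_cells=[64, 64]),
    in_keys=["observation"],
    out_keys=["action"],
)
# critic: (observation, action) -> state_action_value
qvalue = ValueOperator(
    MLP(in_features=obs_dim + act_dim, out_features=1, num_cells=[64, 64]),
    in_keys=["observation", "action"],
)
loss_module = TD3BCLoss(
    actor_network=actor,
    qvalue_network=qvalue,
    action_spec=BoundedTensorSpec(-1.0, 1.0, shape=(act_dim,)),
)

data = TensorDict(
    {
        "observation": torch.randn(batch_size, obs_dim),
        "action": torch.rand(batch_size, act_dim) * 2 - 1,
        ("next", "observation"): torch.randn(batch_size, obs_dim),
        ("next", "reward"): torch.randn(batch_size, 1),
        ("next", "done"): torch.zeros(batch_size, 1, dtype=torch.bool),
        ("next", "terminated"): torch.zeros(batch_size, 1, dtype=torch.bool),
    },
    batch_size=[batch_size],
)

# option 1: both losses in one call through forward()
td_out = loss_module(data)

# option 2: call the two loss methods directly, q-value loss first
loss_qval, value_meta = loss_module.qvalue_loss(data)
loss_actor, actor_meta = loss_module.actor_loss(data)
```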
@@ -370,6 +377,17 @@ def _cached_stack_actor_params(self):
)

def actor_loss(self, tensordict):
"""Compute the actor loss.
The actor loss should be computed after the :meth:`~.qvalue_loss` and is usually delayed 1-3 critic updates.
Args:
tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields
are required for this to be computed.
Returns: a differentiable tensor with the actor loss along with a metadata dictionary containing the detached `"bc_loss"`
used in the combined actor loss as well as the detached `"state_action_value_actor"` used to calculate the lambda
value, and the lambda value `"lmbd"` itself.
"""
tensordict_actor_grad = tensordict.select(
*self.actor_network.in_keys, strict=False
)
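Editor's note: the `bc_loss`, `state_action_value_actor` and `lmbd` entries described above correspond to the TD3+BC combination from the referenced paper: a behavior-cloning term plus a Q term scaled by a normalizing lambda. The standalone sketch below illustrates that combination only; the `alpha` coefficient name (2.5 is the paper's default) and the exact detach placement are assumptions and may differ from the torchrl implementation.

```python
# Standalone sketch of the TD3+BC actor objective from the referenced paper:
# a behavior-cloning (MSE) term plus a Q term scaled by a normalizing lambda.
# The `alpha` name and detach placement are assumptions, not this module's code.
import torch
import torch.nn.functional as F


def td3_bc_actor_objective(q_pi, policy_actions, dataset_actions, alpha=2.5):
    """q_pi: Q(s, pi(s)) for the current policy, shape [batch]."""
    # lambda rescales the Q term by its average magnitude so the RL term and
    # the BC term stay on comparable scales
    lmbd = alpha / q_pi.abs().mean().detach()
    bc_loss = F.mse_loss(policy_actions, dataset_actions)
    loss = -lmbd * q_pi.mean() + bc_loss
    metadata = {
        "bc_loss": bc_loss.detach(),
        "state_action_value_actor": q_pi.detach(),
        "lmbd": lmbd,
    }
    return loss, metadata


# toy usage
loss, meta = td3_bc_actor_objective(
    q_pi=torch.randn(8),
    policy_actions=torch.randn(8, 2),
    dataset_actions=torch.randn(8, 2),
)
```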
@@ -398,14 +416,24 @@ def actor_loss(self, tensordict):
loss_actor = -lmbd * state_action_value_actor[0] + bc_loss

metadata = {
"state_action_value_actor": state_action_value_actor.detach(),
"bc_loss": bc_loss,
"state_action_value_actor": state_action_value_actor[0].detach(),
"bc_loss": bc_loss.detach(),
"lmbd": lmbd,
}
loss_actor = _reduce(loss_actor, reduction=self.reduction)
return loss_actor, metadata

-def value_loss(self, tensordict):
+def qvalue_loss(self, tensordict):
+"""Compute the q-value loss.
+The q-value loss should be computed before the :meth:`~.actor_loss`.
+Args:
+tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields
+are required for this to be computed.
+Returns: a differentiable tensor with the q-value loss, along with a metadata dictionary containing the detached
+`"td_error"` to be used for prioritized sampling, the detached `"next_state_value"`, the detached `"pred_value"`,
+and the detached `"target_value"`.
+"""
tensordict = tensordict.clone(False)

act = tensordict.get(self.tensor_keys.action)
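Editor's note: the `"td_error"` entry that `qvalue_loss` reports is intended for prioritized sampling, and `forward` (below) writes it to the priority key. The hedged sketch below shows one way it could plug into a prioritized replay buffer; the buffer settings, key names and fake data are assumptions, not part of this commit.

```python
# Hedged sketch: feeding a detached "td_error" (as reported by qvalue_loss) back
# into a prioritized replay buffer. Buffer settings and fake data are assumptions.
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictPrioritizedReplayBuffer

rb = TensorDictPrioritizedReplayBuffer(
    alpha=0.7,
    beta=0.5,
    storage=LazyTensorStorage(1_000),
    priority_key="td_error",
)
# stand-in for an offline dataset
rb.extend(TensorDict({"observation": torch.randn(64, 4)}, batch_size=[64]))

sample = rb.sample(32)
# in a real loop this would be metadata["td_error"] from TD3BCLoss.qvalue_loss
sample.set("td_error", torch.rand(sample.shape[0]))
rb.update_tensordict_priority(sample)  # larger TD error -> sampled more often
```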
@@ -484,9 +512,16 @@ def value_loss(self, tensordict):

@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"""The forward method.
Computes successively the :meth:`~.actor_loss`, :meth:`~.qvalue_loss`, and returns
a tensordict with these values.
To see what keys are expected in the input tensordict and what keys are expected as output, check the
class's `"in_keys"` and `"out_keys"` attributes.
"""
tensordict_save = tensordict
loss_actor, metadata_actor = self.actor_loss(tensordict)
-loss_qval, metadata_value = self.value_loss(tensordict_save)
+loss_qval, metadata_value = self.qvalue_loss(tensordict_save)
tensordict_save.set(
self.tensor_keys.priority, metadata_value.pop("td_error").detach().max(0)[0]
)
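Editor's note: putting it together, a typical offline training step built around `forward` might look like the sketch below. The output key names (`"loss_actor"`, `"loss_qvalue"`), the optimizer setup and the delayed actor-update schedule follow the usual TD3 recipe and are assumptions, not taken from this diff.

```python
# Hedged sketch of a training step around TD3BCLoss.forward(); output key names,
# optimizer setup and the policy delay follow the usual TD3 recipe and are
# assumptions, not taken from this commit.
import torch
from torchrl.objectives import SoftUpdate


def make_training_state(loss_module, lr=3e-4, tau=0.005):
    """Build optimizers and a target-network updater for a TD3BCLoss module."""
    actor_params = list(loss_module.actor_network_params.flatten_keys().values())
    critic_params = list(loss_module.qvalue_network_params.flatten_keys().values())
    optim_actor = torch.optim.Adam(actor_params, lr=lr)
    optim_critic = torch.optim.Adam(critic_params, lr=lr)
    target_updater = SoftUpdate(loss_module, tau=tau)  # polyak-averaged target nets
    return optim_actor, optim_critic, target_updater


def training_step(loss_module, batch, optim_actor, optim_critic, target_updater,
                  step, policy_delay=2):
    td_out = loss_module(batch)  # computes actor_loss and qvalue_loss internally
    optim_critic.zero_grad()
    td_out["loss_qvalue"].backward(retain_graph=True)
    optim_critic.step()
    if step % policy_delay == 0:  # delayed actor update, as in TD3
        optim_actor.zero_grad()
        td_out["loss_actor"].backward()
        optim_actor.step()
    target_updater.step()
    return td_out
```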
