amend
vmoens committed Apr 23, 2024
1 parent 675f5c9 commit f068784
Showing 2 changed files with 36 additions and 30 deletions.
3 changes: 3 additions & 0 deletions docs/source/reference/envs.rst
@@ -759,6 +759,8 @@ to always know what the latest available actions are. You can do this like so:
Recorders
---------

.. _Environment-Recorders:

Recording data during environment rollout execution is crucial for keeping an eye on the algorithm's
performance as well as for reporting results after training.

@@ -809,6 +811,7 @@ Recorders are transforms that register data as they come in, for logging purposes

TensorDictRecorder
VideoRecorder
PixelRenderTransform
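
A short, hedged example of how a recorder can be attached to an environment is given
below. The environment name, logger configuration and tag are illustrative choices
made for this sketch, not values taken from elsewhere in the documentation:

.. code-block::

    from torchrl.envs import GymEnv, TransformedEnv
    from torchrl.record import VideoRecorder
    from torchrl.record.loggers import CSVLogger

    # a logger that will receive the recorded frames
    logger = CSVLogger(exp_name="video_demo", log_dir="./logs")

    # wrap a pixel-rendering env so that frames are registered as they come in
    env = TransformedEnv(
        GymEnv("CartPole-v1", from_pixels=True, pixels_only=True),
        VideoRecorder(logger=logger, tag="rollout_video"),
    )
    env.rollout(100)

    # flush the collected frames to the logger
    env.transform.dump()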


Helpers
63 changes: 33 additions & 30 deletions docs/source/reference/trainers.rst
@@ -9,7 +9,7 @@ loop the optimization steps. We believe this fits multiple RL training schemes,
on-policy, off-policy, model-based and model-free solutions, offline RL and others.
More particular cases, such as meta-RL algorithms, may have training schemes that differ substantially.

The ``trainer.train()`` method can be sketched as follows:

.. code-block::
:caption: Trainer loops
@@ -63,35 +63,35 @@ The :obj:`trainer.train()` method can be sketched as follows:
... self._post_steps_hook() # "post_steps"
... self._post_steps_log_hook(batch) # "post_steps_log"
There are 10 hooks that can be used in a trainer loop: ``"batch_process"``, ``"pre_optim_steps"``,
``"process_optim_batch"``, ``"post_loss"``, ``"post_steps"``, ``"post_optim"``, ``"pre_steps_log"``,
``"post_steps_log"``, ``"post_optim_log"`` and ``"optimizer"``. They are indicated in the comments where they are applied.
Hooks can be split into 3 categories: **data processing** (``"batch_process"`` and ``"process_optim_batch"``),
**logging** (``"pre_steps_log"``, ``"post_optim_log"`` and ``"post_steps_log"``) and **operations** hooks
(``"pre_optim_steps"``, ``"post_loss"``, ``"post_optim"`` and ``"post_steps"``).

- **Data processing** hooks update a tensordict of data. A hook's ``__call__`` method should accept
a ``TensorDict`` object as input and update it given some strategy.
Examples of such hooks include Replay Buffer extension (``ReplayBufferTrainer.extend``), data normalization (including normalization
constants update), data subsampling (:class:`~torchrl.trainers.BatchSubSampler`) and such. A minimal
sketch of a data processing hook and a logging hook is given after this list.

- **Logging** hooks take a batch of data presented as a ``TensorDict`` and write some
information retrieved from that data to the logger. Examples include the ``Recorder`` hook, the reward
logger (``LogReward``) and such. Hooks should return a dictionary (or a ``None`` value) containing the
data to log. The key ``"log_pbar"`` is reserved for boolean values indicating whether the logged value
should be displayed on the progress bar printed during training.

- **Operation** hooks are hooks that execute specific operations over the models, data collectors,
target network updates and such. For instance, syncing the weights of the collectors using ``UpdateWeights``
or updating the priority of the replay buffer using ``ReplayBufferTrainer.update_priority`` are examples
of operation hooks. They are data-independent (they do not require a ``TensorDict``
input), they are just supposed to be executed once at every iteration (or every N iterations).
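
To make these categories concrete, here is a minimal sketch of a data processing hook
and a logging hook following the conventions above (the class names, the reward key
and the scaling factor are illustrative choices made for this sketch):

.. code-block::

    from tensordict import TensorDict

    class ScaleReward:
        """Data processing hook: receives a TensorDict batch and updates it in place."""

        def __init__(self, scale: float = 0.1):
            self.scale = scale

        def __call__(self, batch: TensorDict) -> TensorDict:
            # rescale the reward stored under the ("next", "reward") nested key
            batch.set(("next", "reward"), batch.get(("next", "reward")) * self.scale)
            return batch

    class MeanRewardLogger:
        """Logging hook: reads a TensorDict batch and returns a dict of values to log."""

        def __call__(self, batch: TensorDict) -> dict:
            return {
                "mean_reward": batch.get(("next", "reward")).mean().item(),
                "log_pbar": True,  # display this value on the progress bar
            }

Such callables would then be attached to the trainer at the corresponding hook names
(``"batch_process"`` and ``"post_steps_log"`` in this case).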

The hooks provided by TorchRL usually inherit from a common abstract class ``TrainerHookBase``,
and all implement three base methods: ``state_dict`` and ``load_state_dict`` methods for
checkpointing and a ``register`` method that registers the hook at its default place in the
trainer. This method takes a trainer and a module name as input. For instance, the following logging
hook is executed every 10 calls to ``"post_optim_log"``:
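
(The sketch below is illustrative: the class name, the logged quantity and the
``trainer.register_op`` / ``trainer.register_module`` registration calls are
assumptions made for this example, and the every-10-calls behaviour is obtained by
counting calls inside the hook.)

.. code-block::

    from torchrl.trainers import Trainer, TrainerHookBase

    class OptimStepCountLogger(TrainerHookBase):
        """Hypothetical logging hook that emits a value every 10 "post_optim_log" calls."""

        def __init__(self):
            self._counter = 0

        def __call__(self, batch):
            self._counter += 1
            if self._counter % 10 == 0:
                return {"optim_steps": self._counter, "log_pbar": False}

        def state_dict(self):
            # what gets written when the trainer is checkpointed
            return {"counter": self._counter}

        def load_state_dict(self, state_dict):
            self._counter = state_dict["counter"]

        def register(self, trainer: Trainer, name: str = "optim_step_count_logger"):
            # attach the hook at its default entry point in the trainer
            trainer.register_op("post_optim_log", self)
            trainer.register_module(name, self)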

@@ -122,22 +122,22 @@ Checkpointing
-------------

The trainer class and hooks support checkpointing, which can be achieved either
using the `torchsnapshot <https://github.com/pytorch/torchsnapshot/>`_ backend or
the regular torch backend. This can be controlled via the global variable ``CKPT_BACKEND``:

.. code-block::
$ CKPT_BACKEND=torch python script.py
which defaults to ``torchsnapshot``. The advantage of torchsnapshot over pytorch
is that it offers a more flexible API, which supports distributed checkpointing and
also allows users to load tensors from a file stored on disk to a tensor with a
physical storage (which pytorch currently does not support). This makes it possible, for instance,
to load tensors from and to a replay buffer that would otherwise not fit in memory.
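
To make that last capability concrete, here is a hedged sketch using torchsnapshot
directly, independently of the trainer; the snapshot location and the manifest path
passed to ``read_object`` are assumptions made for this illustration:

.. code-block::

    import torch
    from torchsnapshot import Snapshot, StateDict

    # take a snapshot of some state
    state = StateDict(weights=torch.randn(3, 4))
    snapshot = Snapshot.take(path="/tmp/my_snapshot", app_state={"state": state})

    # later, restore the stored tensor directly into a pre-allocated buffer
    # (e.g. a memory-mapped tensor backing a replay buffer), without an
    # intermediate in-memory copy
    buffer = torch.empty(3, 4)
    snapshot.read_object(path="0/state/weights", obj_out=buffer)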

When building a trainer, one can provide a file path where the checkpoints are to
be written. With the ``torchsnapshot`` backend, a directory path is expected,
whereas the ``torch`` backend expects a file path (typically a ``.pt`` file).

.. code-block::
@@ -157,7 +157,7 @@ whereas the :obj:`torch` backend expects a file path (typically a :obj:`.pt` file).
>>> # to load from a path
>>> trainer.load_from_file(filepath)
The ``Trainer.train()`` method can be used to execute the above loop with all of
its hooks, although using the :obj:`Trainer` class for its checkpointing capability
only is also a perfectly valid use.

@@ -238,6 +238,8 @@ Loggers
Recording utils
---------------

Recording utils are detailed :ref:`here <Environment-Recorders>`.

.. currentmodule:: torchrl.record

.. autosummary::
@@ -246,3 +248,4 @@ Recording utils

VideoRecorder
TensorDictRecorder
PixelRenderTransform
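
As a brief, hedged sketch of how these pieces can be combined for a state-based
environment (the environment name, the ``render_mode`` keyword forwarded to the
backend and the default ``PixelRenderTransform`` arguments are assumptions made for
this example):

.. code-block::

    from torchrl.envs import Compose, GymEnv, TransformedEnv
    from torchrl.record import PixelRenderTransform, VideoRecorder
    from torchrl.record.loggers import CSVLogger

    logger = CSVLogger(exp_name="pendulum_video", log_dir="./logs")

    # render pixels on the fly from a state-based env, then record them
    env = TransformedEnv(
        GymEnv("Pendulum-v1", render_mode="rgb_array"),
        Compose(
            PixelRenderTransform(),
            VideoRecorder(logger=logger, tag="rollout_video"),
        ),
    )
    env.rollout(100)

    # flush the collected frames to the logger
    env.transform[-1].dump()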
