Merge remote-tracking branch 'origin/main' into fix-stack-deprec
vmoens committed Apr 23, 2024
2 parents 5f9828a + 6c2e141 commit e1c6b98
Showing 68 changed files with 1,060 additions and 225 deletions.
32 changes: 0 additions & 32 deletions .github/unittest/linux_examples/scripts/run_test.sh
@@ -36,24 +36,20 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/de
optim.pretrain_gradient_steps=55 \
optim.updates_per_episode=3 \
optim.warmup_steps=10 \
optim.device=cuda:0 \
logger.backend= \
env.backend=gymnasium \
env.name=HalfCheetah-v4
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/decision_transformer/online_dt.py \
optim.pretrain_gradient_steps=55 \
optim.updates_per_episode=3 \
optim.warmup_steps=10 \
optim.device=cuda:0 \
env.backend=gymnasium \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iql/iql_offline.py \
optim.gradient_steps=55 \
optim.device=cuda:0 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_offline.py \
optim.gradient_steps=55 \
optim.device=cuda:0 \
logger.backend=

# ==================================================================================== #
@@ -86,8 +82,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dd
optim.batch_size=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=2 \
collector.device=cuda:0 \
network.device=cuda:0 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
env.name=Pendulum-v1 \
@@ -112,7 +106,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dq
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
buffer.batch_size=10 \
device=cuda:0 \
loss.num_updates=1 \
logger.backend= \
buffer.buffer_size=120
@@ -122,7 +115,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cq
optim.batch_size=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=2 \
collector.device=cuda:0 \
replay_buffer.size=120 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/redq/redq.py \
@@ -131,7 +123,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/re
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=2 \
collector.device=cuda:0 \
buffer.batch_size=10 \
optim.steps_per_batch=1 \
logger.record_video=True \
@@ -143,22 +134,18 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/sa
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=2 \
collector.device=cuda:0 \
optim.batch_size=10 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
env.name=Pendulum-v1 \
network.device=cuda:0 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/discrete_sac/discrete_sac.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=1 \
collector.device=cuda:0 \
optim.batch_size=10 \
optim.utd_ratio=1 \
network.device=cuda:0 \
optim.batch_size=10 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
@@ -185,9 +172,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td
collector.frames_per_batch=16 \
collector.num_workers=4 \
collector.env_per_collector=2 \
collector.device=cuda:0 \
collector.device=cuda:0 \
network.device=cuda:0 \
logger.mode=offline \
env.name=Pendulum-v1 \
logger.backend=
@@ -196,26 +180,20 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq
optim.batch_size=10 \
collector.frames_per_batch=16 \
env.train_num_envs=2 \
optim.device=cuda:0 \
collector.device=cuda:0 \
logger.mode=offline \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iql/discrete_iql.py \
collector.total_frames=48 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
env.train_num_envs=2 \
optim.device=cuda:0 \
collector.device=cuda:0 \
logger.mode=offline \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_online.py \
collector.total_frames=48 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
env.train_num_envs=2 \
collector.device=cuda:0 \
optim.device=cuda:0 \
logger.mode=offline \
logger.backend=

@@ -238,8 +216,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dd
optim.batch_size=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=1 \
collector.device=cuda:0 \
network.device=cuda:0 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
env.name=Pendulum-v1 \
@@ -251,7 +227,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dq
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
buffer.batch_size=10 \
device=cuda:0 \
loss.num_updates=1 \
logger.backend= \
buffer.buffer_size=120
@@ -262,7 +237,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/re
collector.frames_per_batch=16 \
collector.env_per_collector=1 \
buffer.batch_size=10 \
collector.device=cuda:0 \
optim.steps_per_batch=1 \
logger.record_video=True \
logger.record_frames=4 \
@@ -274,29 +248,23 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq
collector.frames_per_batch=16 \
env.train_num_envs=1 \
logger.mode=offline \
optim.device=cuda:0 \
collector.device=cuda:0 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_online.py \
collector.total_frames=48 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
collector.env_per_collector=1 \
logger.mode=offline \
optim.device=cuda:0 \
collector.device=cuda:0 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3/td3.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.num_workers=2 \
collector.env_per_collector=1 \
collector.device=cuda:0 \
logger.mode=offline \
optim.batch_size=10 \
env.name=Pendulum-v1 \
network.device=cuda:0 \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/multiagent/mappo_ippo.py \
collector.n_iters=2 \
1 change: 1 addition & 0 deletions docs/source/reference/data.rst
@@ -815,6 +815,7 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
DiscreteTensorSpec
MultiDiscreteTensorSpec
MultiOneHotDiscreteTensorSpec
NonTensorSpec
OneHotDiscreteTensorSpec
UnboundedContinuousTensorSpec
UnboundedDiscreteTensorSpec
70 changes: 70 additions & 0 deletions docs/source/reference/envs.rst
@@ -759,6 +759,75 @@ to always know what the latest available actions are. You can do this like so:
Recorders
---------

.. _Environment-Recorders:

Recording data during environment rollout execution is crucial for monitoring algorithm performance, as well as
for reporting results after training.

TorchRL offers several tools to interact with the environment output: first and foremost, a ``callback`` callable
can be passed to the :meth:`~torchrl.envs.EnvBase.rollout` method. This function is called on the collected
tensordict at each iteration of the rollout (if some iterations have to be skipped, keep an internal call counter
inside ``callback`` and only act every N calls, as sketched below).
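
For instance, a minimal sketch of such a callback, logging the reward on every other call (we assume here, for
illustration, that the callback receives the environment and the current step's tensordict; the ``log_reward``
name and the skip interval are made up for the example):

>>> from torchrl.envs import GymEnv
>>>
>>> rewards = []
>>>
>>> def log_reward(env, td, _calls=[0]):
...     # the mutable default argument serves as the internal call counter
...     _calls[0] += 1
...     if _calls[0] % 2 == 0:  # only act on every other call
...         rewards.append(td["next", "reward"].item())
>>>
>>> env = GymEnv("CartPole-v1")
>>> env.rollout(10, callback=log_reward)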

To save collected tensordicts on disk, the :class:`~torchrl.record.TensorDictRecorder` can be used.
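
A hedged sketch of its usage (we assume the recorder takes a file prefix as its first argument and, like other
recorders, caches data until ``dump`` is called; check the class docstring for the exact signature):

>>> from torchrl.envs import GymEnv
>>> from torchrl.record import TensorDictRecorder
>>>
>>> env = GymEnv("CartPole-v1")
>>> # cache every step of the rollout and write the result to disk
>>> env = env.append_transform(TensorDictRecorder("traj", skip=1))
>>> env.rollout(10)
>>> env.transform.dump()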

Recording videos
~~~~~~~~~~~~~~~~

Several backends offer the possibility of recording rendered images from the environment.
If the pixels are already part of the environment output (e.g. Atari or other game simulators), a
:class:`~torchrl.record.VideoRecorder` can be appended to the environment. This environment transform takes as input
a logger capable of recording videos (e.g. :class:`~torchrl.record.loggers.CSVLogger`, :class:`~torchrl.record.loggers.WandbLogger`
or :class:`~torchrl.record.loggers.TensorBoardLogger`) as well as a tag indicating where the video should be saved.
For instance, to save mp4 videos on disk, one can use :class:`~torchrl.record.loggers.CSVLogger` with a ``video_format="mp4"``
argument.

The :class:`~torchrl.record.VideoRecorder` transform can handle batched images and automatically detects numpy or PyTorch
formatted images (WHC or CWH).

>>> logger = CSVLogger("dummy-exp", video_format="mp4")
>>> env = GymEnv("ALE/Pong-v5")
>>> env = env.append_transform(VideoRecorder(logger, tag="rendered", in_keys=["pixels"]))
>>> env.rollout(10)
>>> env.transform.dump() # Save the video and clear cache

Note that the cache of the transform keeps growing until ``dump`` is called. It is the user's responsibility to
call ``dump`` when needed to avoid OOM issues.

In some cases, creating a testing environment where images can be collected is tedious or expensive, or simply impossible
(some libraries only allow one environment instance per workspace).
In these cases, assuming that a ``render`` method is available in the environment, the :class:`~torchrl.record.PixelRenderTransform`
can be used to call ``render`` on the parent environment and save the images in the rollout data stream.
This class works over single and batched environments alike:

>>> from torchrl.envs import GymEnv, check_env_specs, ParallelEnv, EnvCreator
>>> from torchrl.record.loggers import CSVLogger
>>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder
>>>
>>> def make_env():
...     env = GymEnv("CartPole-v1", render_mode="rgb_array")
...     # Uncomment this line to execute per-env
...     # env = env.append_transform(PixelRenderTransform())
...     return env
>>>
>>> if __name__ == "__main__":
...     logger = CSVLogger("dummy", video_format="mp4")
...
...     env = ParallelEnv(16, EnvCreator(make_env))
...     env.start()
...     # Comment this line to execute per-env
...     env = env.append_transform(PixelRenderTransform())
...
...     env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record"))
...     env.rollout(3)
...
...     check_env_specs(env)
...
...     r = env.rollout(30)
...     env.transform.dump()
...     env.close()


.. currentmodule:: torchrl.record

Recorders are transforms that register data as they come in, for logging purposes.
@@ -769,6 +838,7 @@ Recorders are transforms that register data as they come in, for logging purpose

TensorDictRecorder
VideoRecorder
PixelRenderTransform


Helpers
63 changes: 33 additions & 30 deletions docs/source/reference/trainers.rst
@@ -9,7 +9,7 @@ loop the optimization steps. We believe this fits multiple RL training schemes,
on-policy, off-policy, model-based and model-free solutions, offline RL and others.
More particular cases, such as meta-RL algorithms, may have training schemes that differ substantially.

The :obj:`trainer.train()` method can be sketched as follows:
The ``trainer.train()`` method can be sketched as follows:

.. code-block::
:caption: Trainer loops
@@ -63,35 +63,35 @@ The :obj:`trainer.train()` method can be sketched as follows:
... self._post_steps_hook() # "post_steps"
... self._post_steps_log_hook(batch) # "post_steps_log"
There are 10 hooks that can be used in a trainer loop: :obj:`"batch_process"`, :obj:`"pre_optim_steps"`,
:obj:`"process_optim_batch"`, :obj:`"post_loss"`, :obj:`"post_steps"`, :obj:`"post_optim"`, :obj:`"pre_steps_log"`,
:obj:`"post_steps_log"`, :obj:`"post_optim_log"` and :obj:`"optimizer"`. They are indicated in the comments where they are applied.
Hooks can be split into 3 categories: **data processing** (:obj:`"batch_process"` and :obj:`"process_optim_batch"`),
**logging** (:obj:`"pre_steps_log"`, :obj:`"post_optim_log"` and :obj:`"post_steps_log"`) and **operations** hook
(:obj:`"pre_optim_steps"`, :obj:`"post_loss"`, :obj:`"post_optim"` and :obj:`"post_steps"`).

- **Data processing** hooks update a tensordict of data. Hooks :obj:`__call__` method should accept
a :obj:`TensorDict` object as input and update it given some strategy.
Examples of such hooks include Replay Buffer extension (:obj:`ReplayBufferTrainer.extend`), data normalization (including normalization
constants update), data subsampling (:class:`~torchrl.trainers.BatchSubSampler`) and such.

- **Logging** hooks take a batch of data presented as a :obj:`TensorDict` and write in the logger
some information retrieved from that data. Examples include the :obj:`Recorder` hook, the reward
logger (:obj:`LogReward`) and such. Hooks should return a dictionary (or a None value) containing the
data to log. The key :obj:`"log_pbar"` is reserved to boolean values indicating if the logged value
There are 10 hooks that can be used in a trainer loop: ``"batch_process"``, ``"pre_optim_steps"``,
``"process_optim_batch"``, ``"post_loss"``, ``"post_steps"``, ``"post_optim"``, ``"pre_steps_log"``,
``"post_steps_log"``, ``"post_optim_log"`` and ``"optimizer"``. They are indicated in the comments where they are applied.
Hooks can be split into 3 categories: **data processing** (``"batch_process"`` and ``"process_optim_batch"``),
**logging** (``"pre_steps_log"``, ``"post_optim_log"`` and ``"post_steps_log"``) and **operations** hook
(``"pre_optim_steps"``, ``"post_loss"``, ``"post_optim"`` and ``"post_steps"``).

- **Data processing** hooks update a tensordict of data. A hook's ``__call__`` method should accept
a ``TensorDict`` object as input and update it given some strategy.
Examples of such hooks include Replay Buffer extension (``ReplayBufferTrainer.extend``), data normalization (including
normalization constants update), data subsampling (:class:`~torchrl.trainers.BatchSubSampler`) and such; a minimal
sketch is given after this list.

- **Logging** hooks take a batch of data presented as a ``TensorDict`` and write to the logger
some information retrieved from that data. Examples include the ``Recorder`` hook, the reward
logger (``LogReward``) and such. Hooks should return a dictionary (or a ``None`` value) containing the
data to log. The key ``"log_pbar"`` is reserved for boolean values indicating whether the logged value
should be displayed on the progress bar printed on the training log.

- **Operation** hooks execute specific operations over the models, data collectors,
target network updates and such. For instance, syncing the weights of the collectors using :obj:`UpdateWeights`
or update the priority of the replay buffer using :obj:`ReplayBufferTrainer.update_priority` are examples
of operation hooks. They are data-independent (they do not require a :obj:`TensorDict`
target network updates and such. For instance, syncing the weights of the collectors using ``UpdateWeights``
or updating the priority of the replay buffer using ``ReplayBufferTrainer.update_priority`` are examples
of operation hooks. They are data-independent (they do not require a ``TensorDict``
input); they are simply meant to be executed once at every iteration (or every N iterations).
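
For illustration, here is a minimal sketch of a data-processing hook registered on the ``"batch_process"`` slot.
The reward-clamping body and the key names are made up for the example, and we assume an existing ``trainer``
instance whose ``register_op`` method accepts the hook:

.. code-block::

    >>> def clamp_reward(batch):
    ...     # clamp the collected rewards in-place before any further processing
    ...     batch.get(("next", "reward")).clamp_(-1, 1)
    ...     return batch
    >>>
    >>> trainer.register_op("batch_process", clamp_reward)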

The hooks provided by TorchRL usually inherit from a common abstract class :obj:`TrainerHookBase`,
and all implement three base methods: a :obj:`state_dict` and :obj:`load_state_dict` method for
checkpointing and a :obj:`register` method that registers the hook at the default value in the
The hooks provided by TorchRL usually inherit from a common abstract class ``TrainerHookBase``,
and all implement three base methods: a ``state_dict`` and ``load_state_dict`` method for
checkpointing and a ``register`` method that registers the hook at the default value in the
trainer. This method takes a trainer and a module name as input. For instance, the following logging
hook is executed every 10 calls to :obj:`"post_optim_log"`:
hook is executed every 10 calls to ``"post_optim_log"``:

.. code-block::
@@ -122,22 +122,22 @@ Checkpointing
-------------

The trainer class and hooks support checkpointing, which can be achieved either
using the `torchsnapshot <https://github.com/pytorch/torchsnapshot/>`_ backend or
the regular torch backend. This can be controlled via the global variable :obj:`CKPT_BACKEND`:
using the `torchsnapshot <https://github.com/pytorch/torchsnapshot/>`_ backend or
the regular torch backend. This can be controlled via the global variable ``CKPT_BACKEND``:

.. code-block::
$ CKPT_BACKEND=torch python script.py
which defaults to :obj:`torchsnapshot`. The advantage of torchsnapshot over pytorch
which defaults to ``torchsnapshot``. The advantage of torchsnapshot over pytorch
is that it is a more flexible API, which supports distributed checkpointing and
also allows users to load tensors from a file stored on disk to a tensor with a
physical storage (which pytorch currently does not support). This allows, for instance,
to load tensors from and to a replay buffer that would otherwise not fit in memory.

When building a trainer, one can provide a file path where the checkpoints are to
be written. With the :obj:`torchsnapshot` backend, a directory path is expected,
whereas the :obj:`torch` backend expects a file path (typically a :obj:`.pt` file).
be written. With the ``torchsnapshot`` backend, a directory path is expected,
whereas the ``torch`` backend expects a file path (typically a ``.pt`` file).

.. code-block::
@@ -157,7 +157,7 @@ whereas the :obj:`torch` backend expects a file path (typically a :obj:`.pt` fi
>>> # to load from a path
>>> trainer.load_from_file(filepath)
The :obj:`Trainer.train()` method can be used to execute the above loop with all of
The ``Trainer.train()`` method can be used to execute the above loop with all of
its hooks, although using the :obj:`Trainer` class only for its checkpointing capability
is also a perfectly valid use.

@@ -238,6 +238,8 @@ Loggers
Recording utils
---------------

Recording utils are detailed :ref:`here <Environment-Recorders>`.

.. currentmodule:: torchrl.record

.. autosummary::
@@ -246,3 +248,4 @@ Recording utils

VideoRecorder
TensorDictRecorder
PixelRenderTransform