From df749a30d0710404c0383ca79c6988c68911fc5d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 23 Apr 2024 16:38:02 +0100 Subject: [PATCH 1/2] [Feature] A PixelRenderTransform (#2099) --- docs/source/reference/data.rst | 1 + docs/source/reference/envs.rst | 70 +++++++++ docs/source/reference/trainers.rst | 63 ++++---- test/test_loggers.py | 40 ++++- test/test_specs.py | 28 ++++ torchrl/data/__init__.py | 1 + torchrl/data/tensor_specs.py | 125 ++++++++++++++- torchrl/envs/batched_envs.py | 1 + torchrl/envs/common.py | 7 +- torchrl/envs/utils.py | 11 +- torchrl/record/__init__.py | 2 +- torchrl/record/recorder.py | 234 +++++++++++++++++++++++++++-- 12 files changed, 532 insertions(+), 51 deletions(-) diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 43d5bfe0d00..efb1a755d0e 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -815,6 +815,7 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check. DiscreteTensorSpec MultiDiscreteTensorSpec MultiOneHotDiscreteTensorSpec + NonTensorSpec OneHotDiscreteTensorSpec UnboundedContinuousTensorSpec UnboundedDiscreteTensorSpec diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index fe72ea89a56..5c39c5a1349 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -759,6 +759,75 @@ to always know what the latest available actions are. You can do this like so: Recorders --------- +.. _Environment-Recorders: + +Recording data during environment rollout execution is crucial to keep an eye on the algorithm performance as well as +reporting results after training. + +TorchRL offers several tools to interact with the environment output: first and foremost, a ``callback`` callable +can be passed to the :meth:`~torchrl.envs.EnvBase.rollout` method. This function will be called upon the collected +tensordict at each iteration of the rollout (if some iterations have to be skipped, an internal variable should be added +to keep track of the call count within ``callback``). + +To save collected tensordicts on disk, the :class:`~torchrl.record.TensorDictRecorder` can be used. + +Recording videos +~~~~~~~~~~~~~~~~ + +Several backends offer the possibility of recording rendered images from the environment. +If the pixels are already part of the environment output (e.g. Atari or other game simulators), a +:class:`~torchrl.record.VideoRecorder` can be appended to the environment. This environment transform takes as input +a logger capable of recording videos (e.g. :class:`~torchrl.record.loggers.CSVLogger`, :class:`~torchrl.record.loggers.WandbLogger` +or :class:`~torchrl.record.loggers.TensorBoardLogger`) as well as a tag indicating where the video should be saved. +For instance, to save mp4 videos on disk, one can use :class:`~torchrl.record.loggers.CSVLogger` with a `video_format="mp4"` +argument. + +The :class:`~torchrl.record.VideoRecorder` transform can handle batched images and automatically detects numpy or PyTorch +formatted images (WHC or CWH). + + >>> logger = CSVLogger("dummy-exp", video_format="mp4") + >>> env = GymEnv("ALE/Pong-v5") + >>> env = env.append_transform(VideoRecorder(logger, tag="rendered", in_keys=["pixels"])) + >>> env.rollout(10) + >>> env.transform.dump() # Save the video and clear cache + +Note that the cache of the transform will keep on growing until dump is called. It is the user responsibility to +take care of calling dumpy when needed to avoid OOM issues. 
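Below is a minimal sketch (not part of this patch) of one way to keep that cache bounded during long rollouts, using the ``callback`` argument of :meth:`~torchrl.envs.EnvBase.rollout` mentioned above. The 100-step flushing interval is arbitrary, and the callback accepts ``*args`` because the exact callback signature may vary across TorchRL versions:

    >>> logger = CSVLogger("dummy-exp", video_format="mp4")
    >>> recorder = VideoRecorder(logger, tag="rendered", in_keys=["pixels"])
    >>> env = GymEnv("ALE/Pong-v5").append_transform(recorder)
    >>> step_count = {"n": 0}
    >>> def flush_periodically(*args):
    ...     # dump the cached frames every 100 steps so the cache stays bounded
    ...     step_count["n"] += 1
    ...     if step_count["n"] % 100 == 0:
    ...         recorder.dump()
    >>> env.rollout(1_000, callback=flush_periodically)
    >>> recorder.dump()  # flush whatever remains at the end of the rollout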
+ +In some cases, creating a testing environment where images can be collected is tedious or expensive, or simply impossible +(some libraries only allow one environment instance per workspace). +In these cases, assuming that a `render` method is available in the environment, the :class:`~torchrl.record.PixelRenderTransform` +can be used to call `render` on the parent environment and save the images in the rollout data stream. +This class works over single and batched environments alike: + + >>> from torchrl.envs import GymEnv, check_env_specs, ParallelEnv, EnvCreator + >>> from torchrl.record.loggers import CSVLogger + >>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder + >>> + >>> def make_env(): + >>> env = GymEnv("CartPole-v1", render_mode="rgb_array") + >>> # Uncomment this line to execute per-env + >>> # env = env.append_transform(PixelRenderTransform()) + >>> return env + >>> + >>> if __name__ == "__main__": + ... logger = CSVLogger("dummy", video_format="mp4") + ... + ... env = ParallelEnv(16, EnvCreator(make_env)) + ... env.start() + ... # Comment this line to execute per-env + ... env = env.append_transform(PixelRenderTransform()) + ... + ... env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record")) + ... env.rollout(3) + ... + ... check_env_specs(env) + ... + ... r = env.rollout(30) + ... env.transform.dump() + ... env.close() + + .. currentmodule:: torchrl.record Recorders are transforms that register data as they come in, for logging purposes. @@ -769,6 +838,7 @@ Recorders are transforms that register data as they come in, for logging purpose TensorDictRecorder VideoRecorder + PixelRenderTransform Helpers diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst index 04d4386c631..821902b2ee2 100644 --- a/docs/source/reference/trainers.rst +++ b/docs/source/reference/trainers.rst @@ -9,7 +9,7 @@ loop the optimization steps. We believe this fits multiple RL training schemes, on-policy, off-policy, model-based and model-free solutions, offline RL and others. More particular cases, such as meta-RL algorithms may have training schemes that differ substentially. -The :obj:`trainer.train()` method can be sketched as follows: +The ``trainer.train()`` method can be sketched as follows: .. code-block:: :caption: Trainer loops @@ -63,35 +63,35 @@ The :obj:`trainer.train()` method can be sketched as follows: ... self._post_steps_hook() # "post_steps" ... self._post_steps_log_hook(batch) # "post_steps_log" -There are 10 hooks that can be used in a trainer loop: :obj:`"batch_process"`, :obj:`"pre_optim_steps"`, -:obj:`"process_optim_batch"`, :obj:`"post_loss"`, :obj:`"post_steps"`, :obj:`"post_optim"`, :obj:`"pre_steps_log"`, -:obj:`"post_steps_log"`, :obj:`"post_optim_log"` and :obj:`"optimizer"`. They are indicated in the comments where they are applied. -Hooks can be split into 3 categories: **data processing** (:obj:`"batch_process"` and :obj:`"process_optim_batch"`), -**logging** (:obj:`"pre_steps_log"`, :obj:`"post_optim_log"` and :obj:`"post_steps_log"`) and **operations** hook -(:obj:`"pre_optim_steps"`, :obj:`"post_loss"`, :obj:`"post_optim"` and :obj:`"post_steps"`). - -- **Data processing** hooks update a tensordict of data. Hooks :obj:`__call__` method should accept - a :obj:`TensorDict` object as input and update it given some strategy. 
- Examples of such hooks include Replay Buffer extension (:obj:`ReplayBufferTrainer.extend`), data normalization (including normalization - constants update), data subsampling (:class:`~torchrl.trainers.BatchSubSampler`) and such. - -- **Logging** hooks take a batch of data presented as a :obj:`TensorDict` and write in the logger - some information retrieved from that data. Examples include the :obj:`Recorder` hook, the reward - logger (:obj:`LogReward`) and such. Hooks should return a dictionary (or a None value) containing the - data to log. The key :obj:`"log_pbar"` is reserved to boolean values indicating if the logged value +There are 10 hooks that can be used in a trainer loop: ``"batch_process"``, ``"pre_optim_steps"``, +``"process_optim_batch"``, ``"post_loss"``, ``"post_steps"``, ``"post_optim"``, ``"pre_steps_log"``, +``"post_steps_log"``, ``"post_optim_log"`` and ``"optimizer"``. They are indicated in the comments where they are applied. +Hooks can be split into 3 categories: **data processing** (``"batch_process"`` and ``"process_optim_batch"``), +**logging** (``"pre_steps_log"``, ``"post_optim_log"`` and ``"post_steps_log"``) and **operations** hook +(``"pre_optim_steps"``, ``"post_loss"``, ``"post_optim"`` and ``"post_steps"``). + +- **Data processing** hooks update a tensordict of data. Hooks ``__call__`` method should accept + a ``TensorDict`` object as input and update it given some strategy. + Examples of such hooks include Replay Buffer extension (``ReplayBufferTrainer.extend``), data normalization (including normalization + constants update), data subsampling (:class:``~torchrl.trainers.BatchSubSampler``) and such. + +- **Logging** hooks take a batch of data presented as a ``TensorDict`` and write in the logger + some information retrieved from that data. Examples include the ``Recorder`` hook, the reward + logger (``LogReward``) and such. Hooks should return a dictionary (or a None value) containing the + data to log. The key ``"log_pbar"`` is reserved to boolean values indicating if the logged value should be displayed on the progression bar printed on the training log. - **Operation** hooks are hooks that execute specific operations over the models, data collectors, - target network updates and such. For instance, syncing the weights of the collectors using :obj:`UpdateWeights` - or update the priority of the replay buffer using :obj:`ReplayBufferTrainer.update_priority` are examples - of operation hooks. They are data-independent (they do not require a :obj:`TensorDict` + target network updates and such. For instance, syncing the weights of the collectors using ``UpdateWeights`` + or update the priority of the replay buffer using ``ReplayBufferTrainer.update_priority`` are examples + of operation hooks. They are data-independent (they do not require a ``TensorDict`` input), they are just supposed to be executed once at every iteration (or every N iterations). -The hooks provided by TorchRL usually inherit from a common abstract class :obj:`TrainerHookBase`, -and all implement three base methods: a :obj:`state_dict` and :obj:`load_state_dict` method for -checkpointing and a :obj:`register` method that registers the hook at the default value in the +The hooks provided by TorchRL usually inherit from a common abstract class ``TrainerHookBase``, +and all implement three base methods: a ``state_dict`` and ``load_state_dict`` method for +checkpointing and a ``register`` method that registers the hook at the default value in the trainer. 
This method takes a trainer and a module name as input. For instance, the following logging -hook is executed every 10 calls to :obj:`"post_optim_log"`: +hook is executed every 10 calls to ``"post_optim_log"``: .. code-block:: @@ -122,22 +122,22 @@ Checkpointing ------------- The trainer class and hooks support checkpointing, which can be achieved either -using the `torchsnapshot `_ backend or -the regular torch backend. This can be controlled via the global variable :obj:`CKPT_BACKEND`: +using the ``torchsnapshot ``_ backend or +the regular torch backend. This can be controlled via the global variable ``CKPT_BACKEND``: .. code-block:: $ CKPT_BACKEND=torch python script.py -which defaults to :obj:`torchsnapshot`. The advantage of torchsnapshot over pytorch +which defaults to ``torchsnapshot``. The advantage of torchsnapshot over pytorch is that it is a more flexible API, which supports distributed checkpointing and also allows users to load tensors from a file stored on disk to a tensor with a physical storage (which pytorch currently does not support). This allows, for instance, to load tensors from and to a replay buffer that would otherwise not fit in memory. When building a trainer, one can provide a file path where the checkpoints are to -be written. With the :obj:`torchsnapshot` backend, a directory path is expected, -whereas the :obj:`torch` backend expects a file path (typically a :obj:`.pt` file). +be written. With the ``torchsnapshot`` backend, a directory path is expected, +whereas the ``torch`` backend expects a file path (typically a ``.pt`` file). .. code-block:: @@ -157,7 +157,7 @@ whereas the :obj:`torch` backend expects a file path (typically a :obj:`.pt` fi >>> # to load from a path >>> trainer.load_from_file(filepath) -The :obj:`Trainer.train()` method can be used to execute the above loop with all of +The ``Trainer.train()`` method can be used to execute the above loop with all of its hooks, although using the :obj:`Trainer` class for its checkpointing capability only is also a perfectly valid use. @@ -238,6 +238,8 @@ Loggers Recording utils --------------- +Recording utils are detailed :ref:`here `. + .. currentmodule:: torchrl.record .. autosummary:: @@ -246,3 +248,4 @@ Recording utils VideoRecorder TensorDictRecorder + PixelRenderTransform diff --git a/test/test_loggers.py b/test/test_loggers.py index 98a330d0daf..f51b9d290ab 100644 --- a/test/test_loggers.py +++ b/test/test_loggers.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. 
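# --- Editor's illustration (not part of this patch) ----------------------------------
# A minimal sketch of the hook pattern described in trainers.rst above: a logging hook
# registered at "post_optim_log" that emits a value only every 10th call. It assumes the
# usual TrainerHookBase API and the Trainer.register_op / register_module methods; the
# "some_value" batch entry is hypothetical and "log_pbar" controls progress-bar display.
from torchrl.trainers import TrainerHookBase


class TenthCallLogger(TrainerHookBase):
    def __init__(self):
        self.counter = 0

    def __call__(self, batch):
        out = None
        if self.counter % 10 == 0:
            # hypothetical entry; replace with whatever the batch actually carries
            out = {"some_value": batch.get("some_value").mean().item(), "log_pbar": False}
        self.counter += 1
        return out

    def state_dict(self):
        return {"counter": self.counter}

    def load_state_dict(self, state_dict):
        self.counter = state_dict["counter"]

    def register(self, trainer, name="tenth_call_logger"):
        trainer.register_op("post_optim_log", self)
        trainer.register_module(name, self)
# --------------------------------------------------------------------------------------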
import argparse +import importlib.util import os import os.path import pathlib @@ -12,12 +13,14 @@ import pytest import torch - from tensordict import MemoryMappedTensor + +from torchrl.envs import check_env_specs, GymEnv, ParallelEnv from torchrl.record.loggers.csv import CSVLogger from torchrl.record.loggers.mlflow import _has_mlflow, _has_tv, MLFlowLogger from torchrl.record.loggers.tensorboard import _has_tb, TensorboardLogger from torchrl.record.loggers.wandb import _has_wandb, WandbLogger +from torchrl.record.recorder import PixelRenderTransform, VideoRecorder if _has_tv: import torchvision @@ -28,6 +31,11 @@ if _has_mlflow: import mlflow +_has_gym = ( + importlib.util.find_spec("gym", None) is not None + or importlib.util.find_spec("gymnasium", None) is not None +) + @pytest.fixture def tb_logger(tmp_path_factory): @@ -397,6 +405,36 @@ def test_log_hparams(self, mlflow_fixture, config): logger.log_hparams(config) +@pytest.mark.skipif(not _has_gym, reason="gym required to test rendering") +class TestPixelRenderTransform: + @pytest.mark.parametrize("parallel", [False, True]) + @pytest.mark.parametrize("in_key", ["pixels", ("nested", "pix")]) + def test_pixel_render(self, parallel, in_key, tmpdir): + def make_env(): + env = GymEnv("CartPole-v1", render_mode="rgb_array", device=None) + env = env.append_transform(PixelRenderTransform(out_keys=in_key)) + return env + + if parallel: + env = ParallelEnv(2, make_env, mp_start_method="spawn") + else: + env = make_env() + logger = CSVLogger("dummy", log_dir=tmpdir) + try: + env = env.append_transform( + VideoRecorder(logger=logger, in_keys=[in_key], tag="pixels_record") + ) + check_env_specs(env) + env.rollout(10) + env.transform.dump() + assert os.path.isfile( + os.path.join(tmpdir, "dummy", "videos", "pixels_record_0.pt") + ) + finally: + if not env.is_closed: + env.close() + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_specs.py b/test/test_specs.py index 36f5aef65ca..058144c1a94 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -23,6 +23,7 @@ LazyStackedCompositeSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, + NonTensorSpec, OneHotDiscreteTensorSpec, TensorSpec, UnboundedContinuousTensorSpec, @@ -1462,6 +1463,14 @@ def test_multionehot(self, shape1, shape2): assert spec2.rand().shape == spec2.shape assert spec2.zero().shape == spec2.shape + def test_non_tensor(self): + spec = NonTensorSpec((3, 4), device="cpu") + assert ( + spec.expand(2, 3, 4) + == spec.expand((2, 3, 4)) + == NonTensorSpec((2, 3, 4), device="cpu") + ) + @pytest.mark.parametrize("shape1", [None, (), (5,)]) @pytest.mark.parametrize("shape2", [(), (10,)]) def test_onehot(self, shape1, shape2): @@ -1675,6 +1684,11 @@ def test_multionehot( assert spec == spec.clone() assert spec is not spec.clone() + def test_non_tensor(self): + spec = NonTensorSpec(shape=(3, 4), device="cpu") + assert spec.clone() == spec + assert spec.clone() is not spec + @pytest.mark.parametrize("shape1", [None, (), (5,)]) def test_onehot( self, @@ -1840,6 +1854,11 @@ def test_multionehot( with pytest.raises(ValueError): spec.unbind(-1) + def test_non_tensor(self): + spec = NonTensorSpec(shape=(3, 4), device="cpu") + assert spec.unbind(1)[0] == spec[:, 0] + assert spec.unbind(1)[0] is not spec[:, 0] + @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) def test_onehot( self, @@ -2114,6 +2133,15 @@ def test_stack_multionehot_zero(self, shape, 
stack_dim): r = c.zero() assert r.shape == c.shape + def test_stack_non_tensor(self, shape, stack_dim): + spec0 = NonTensorSpec(shape=shape, device="cpu") + spec1 = NonTensorSpec(shape=shape, device="cpu") + new_spec = torch.stack([spec0, spec1], stack_dim) + shape_insert = list(shape) + shape_insert.insert(stack_dim, 2) + assert new_spec.shape == torch.Size(shape_insert) + assert new_spec.device == torch.device("cpu") + def test_stack_onehot(self, shape, stack_dim): n = 5 shape = (*shape, 5) diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index cb84ce32586..bc512a585b7 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -51,6 +51,7 @@ LazyStackedTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, + NonTensorSpec, OneHotDiscreteTensorSpec, TensorSpec, UnboundedContinuousTensorSpec, diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 71598938eab..c9d0683ad9c 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -31,7 +31,13 @@ import numpy as np import torch -from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase, unravel_key +from tensordict import ( + LazyStackedTensorDict, + NonTensorData, + TensorDict, + TensorDictBase, + unravel_key, +) from tensordict.utils import _getitem_batch_size, NestedKey from torchrl._utils import get_binary_env_var @@ -715,8 +721,9 @@ def _flatten(self, start_dim, end_dim): shape = torch.zeros(self.shape, device="meta").flatten(start_dim, end_dim).shape return self._reshape(shape) + @abc.abstractmethod def _project(self, val: torch.Tensor) -> torch.Tensor: - raise NotImplementedError + raise NotImplementedError(type(self)) @abc.abstractmethod def is_in(self, val: torch.Tensor) -> bool: @@ -1917,6 +1924,107 @@ def _is_nested_list(index, notuple=False): return False +class NonTensorSpec(TensorSpec): + """A spec for non-tensor data.""" + + def __init__( + self, + shape: Union[torch.Size, int] = _DEFAULT_SHAPE, + device: Optional[DEVICE_TYPING] = None, + dtype: torch.dtype | None = None, + **kwargs, + ): + if isinstance(shape, int): + shape = torch.Size([shape]) + + _, device = _default_dtype_and_device(None, device) + domain = None + super().__init__( + shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs + ) + + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensorSpec: + if isinstance(dest, torch.dtype): + dest_dtype = dest + dest_device = self.device + elif dest is None: + return self + else: + dest_dtype = self.dtype + dest_device = torch.device(dest) + if dest_device == self.device and dest_dtype == self.dtype: + return self + return self.__class__(shape=self.shape, device=dest_device, dtype=None) + + def clone(self) -> NonTensorSpec: + return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype) + + def rand(self, shape): + return NonTensorData(data=None, shape=self.shape, device=self.device) + + def zero(self, shape): + return NonTensorData(data=None, shape=self.shape, device=self.device) + + def one(self, shape): + return NonTensorData(data=None, shape=self.shape, device=self.device) + + def is_in(self, val: torch.Tensor) -> bool: + shape = torch.broadcast_shapes(self.shape, val.shape) + return ( + isinstance(val, NonTensorData) + and val.shape == shape + and val.device == self.device + and val.dtype == self.dtype + ) + + def expand(self, *shape): + if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): + shape = shape[0] + shape = torch.Size(shape) + if not all( 
+ (old == 1) or (old == new) + for old, new in zip(self.shape, shape[-len(self.shape) :]) + ): + raise ValueError( + f"The last elements of the expanded shape must match the current one. Got shape={shape} while self.shape={self.shape}." + ) + return self.__class__(shape=shape, device=self.device, dtype=None) + + def _reshape(self, shape): + return self.__class__(shape=shape, device=self.device, dtype=self.dtype) + + def _unflatten(self, dim, sizes): + shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape + return self.__class__( + shape=shape, + device=self.device, + dtype=self.dtype, + ) + + def __getitem__(self, idx: SHAPE_INDEX_TYPING): + """Indexes the current TensorSpec based on the provided index.""" + indexed_shape = torch.Size(_shape_indexing(self.shape, idx)) + return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype) + + def unbind(self, dim: int): + orig_dim = dim + if dim < 0: + dim = len(self.shape) + dim + if dim < 0: + raise ValueError( + f"Cannot unbind along dim {orig_dim} with shape {self.shape}." + ) + shape = tuple(s for i, s in enumerate(self.shape) if i != dim) + return tuple( + self.__class__( + shape=shape, + device=self.device, + dtype=self.dtype, + ) + for i in range(self.shape[dim]) + ) + + @dataclass(repr=False) class UnboundedContinuousTensorSpec(TensorSpec): """An unbounded continuous tensor spec. @@ -1954,7 +2062,9 @@ def __init__( shape=shape, space=box, device=device, dtype=dtype, domain=domain, **kwargs ) - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + def to( + self, dest: Union[torch.dtype, DEVICE_TYPING] + ) -> UnboundedContinuousTensorSpec: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -1979,7 +2089,11 @@ def rand(self, shape=None) -> torch.Tensor: return torch.empty(shape, device=self.device, dtype=self.dtype).random_() def is_in(self, val: torch.Tensor) -> bool: - return True + shape = torch.broadcast_shapes(self.shape, val.shape) + return val.shape == shape and val.dtype == self.dtype + + def _project(self, val: torch.Tensor) -> torch.Tensor: + return torch.as_tensor(val, dtype=self.dtype).reshape(self.shape) def expand(self, *shape): if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): @@ -2130,7 +2244,8 @@ def rand(self, shape=None) -> torch.Tensor: return r.to(self.device) def is_in(self, val: torch.Tensor) -> bool: - return True + shape = torch.broadcast_shapes(self.shape, val.shape) + return val.shape == shape and val.dtype == self.dtype def expand(self, *shape): if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 660aecb3fd8..b3026da35ca 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1624,6 +1624,7 @@ def __getattr__(self, attr: str) -> Any: try: # _ = getattr(self._dummy_env, attr) if self.is_closed: + self.start() raise RuntimeError( "Trying to access attributes of closed/non started " "environments. 
Check that the batched environment " diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index f5d4625fd07..8712c74340a 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2298,7 +2298,7 @@ def rollout( self, max_steps: int, policy: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, - callback: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None, + callback: Optional[Callable[[TensorDictBase, ...], Any]] = None, auto_reset: bool = True, auto_cast_to_device: bool = False, break_when_any_done: bool = True, @@ -2320,7 +2320,10 @@ def rollout( The policy can be any callable that reads either a tensordict or the entire sequence of observation entries __sorted as__ the ``env.observation_spec.keys()``. Defaults to `None`. - callback (callable, optional): function to be called at each iteration with the given TensorDict. + callback (Callable[[TensorDict], Any], optional): function to be called at each iteration with the given + TensorDict. Defaults to ``None``. The output of ``callback`` will not be collected, it is the user + responsibility to save any result within the callback call if data needs to be carried over beyond + the call to ``rollout``. auto_reset (bool, optional): if ``True``, resets automatically the environment if it is in a done state when the rollout is initiated. Default is ``True``. diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 49cf58f8103..cd51b4fd23b 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -16,6 +16,7 @@ from enum import Enum from typing import Any, Dict, List, Union +import tensordict import torch from tensordict import ( @@ -25,6 +26,7 @@ TensorDictBase, unravel_key, ) +from tensordict.base import _is_leaf_nontensor from tensordict.nn import TensorDictModule, TensorDictModuleBase from tensordict.nn.probabilistic import ( # noqa # Note: the `set_interaction_mode` and their associated arg `default_interaction_mode` are being deprecated! @@ -183,10 +185,15 @@ def _is_reset(key: NestedKey): return key == "_reset" return key[-1] == "_reset" - actual = {key for key in tensordict.keys(True, True) if not _is_reset(key)} + actual = { + key + for key in tensordict.keys(True, True, is_leaf=_is_leaf_nontensor) + if not _is_reset(key) + } expected = set(expected) self.validated = expected.intersection(actual) == expected if not self.validated: + raise RuntimeError warnings.warn( "The expected key set and actual key set differ. " "This will work but with a slower throughput than " @@ -262,7 +269,7 @@ def _exclude( cls._exclude(nested_key_dict, td, td_out) return out has_set = False - for key, value in data_in.items(): + for key, value in data_in.items(is_leaf=tensordict.base._is_leaf_nontensor): subdict = nested_key_dict.get(key, NO_DEFAULT) if subdict is NO_DEFAULT: value = value.copy() if is_tensor_collection(value) else value diff --git a/torchrl/record/__init__.py b/torchrl/record/__init__.py index 726d29ea051..f6c9bcdefbb 100644 --- a/torchrl/record/__init__.py +++ b/torchrl/record/__init__.py @@ -4,4 +4,4 @@ # LICENSE file in the root directory of this source tree. 
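# --- Editor's illustration (not part of this patch) ----------------------------------
# A small usage sketch of the NonTensorSpec added in torchrl/data/tensor_specs.py above,
# mirroring the behaviour exercised by the new tests in test_specs.py: expand/clone/stack
# only manipulate shape metadata, while rand/zero/one return NonTensorData placeholders.
import torch

from torchrl.data import NonTensorSpec

spec = NonTensorSpec(shape=(3, 4), device="cpu")
assert spec.expand(2, 3, 4) == NonTensorSpec((2, 3, 4), device="cpu")
assert spec.clone() == spec and spec.clone() is not spec
stacked = torch.stack([spec, spec.clone()], 0)
assert stacked.shape == torch.Size([2, 3, 4])
# rand()/zero()/one() return a NonTensorData placeholder (data=None) with the spec's shape
placeholder = spec.rand(())
# --------------------------------------------------------------------------------------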
from .loggers import CSVLogger, MLFlowLogger, TensorboardLogger, WandbLogger -from .recorder import TensorDictRecorder, VideoRecorder +from .recorder import PixelRenderTransform, TensorDictRecorder, VideoRecorder diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py index a486b689feb..079c8b71e12 100644 --- a/torchrl/record/recorder.py +++ b/torchrl/record/recorder.py @@ -6,14 +6,20 @@ import importlib.util from copy import copy -from typing import Optional, Sequence +from typing import Callable, List, Optional, Sequence, Union +import numpy as np import torch -from tensordict import TensorDictBase +from tensordict import NonTensorData, TensorDict, TensorDictBase from tensordict.utils import NestedKey +from torchrl._utils import _can_be_pickled +from torchrl.data import TensorSpec +from torchrl.data.tensor_specs import NonTensorSpec, UnboundedContinuousTensorSpec +from torchrl.data.utils import CloudpickleWrapper +from torchrl.envs import EnvBase from torchrl.envs.transforms import ObservationTransform, Transform from torchrl.record.loggers import Logger @@ -155,20 +161,22 @@ def skip(self, value): self._skip = value def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: + if isinstance(observation, NonTensorData): + observation_trsf = torch.tensor(observation.data) + else: + observation_trsf = observation self.count += 1 if self.count % self.skip == 0: if ( - observation.ndim >= 3 - and observation.shape[-3] == 3 - and observation.shape[-2] > 3 - and observation.shape[-1] > 3 + observation_trsf.ndim >= 3 + and observation_trsf.shape[-3] == 3 + and observation_trsf.shape[-2] > 3 + and observation_trsf.shape[-1] > 3 ): # permute the channels to the last dim - observation_trsf = observation.permute( - *range(observation.ndim - 3), -2, -1, -3 + observation_trsf = observation_trsf.permute( + *range(observation_trsf.ndim - 3), -2, -1, -3 ) - else: - observation_trsf = observation if not ( observation_trsf.shape[-1] == 3 or observation_trsf.ndimension() == 2 ): @@ -321,3 +329,209 @@ def _reset( ) -> TensorDictBase: self._call(tensordict_reset) return tensordict_reset + + +class PixelRenderTransform(Transform): + """A transform to call render on the parent environment and register the pixel observation in the tensordict. + + This transform offers an alternative to the ``from_pixels`` syntatic sugar when instantiating an environment + that offers rendering is expensive, or when ``from_pixels`` is not implemented. + It can be used within a single environment or over batched environments alike. + + Args: + out_keys (List[NestedKey] or Nested): List of keys where to register the pixel observations. + preproc (Callable, optional): a preproc function. Can be used to reshape the observation, or apply + any other transformation that makes it possible to register it in the output data. + as_non_tensor (bool, optional): if ``True``, the data will be written as a :class:`~tensordict.NonTensorData` + thereby relaxing the shape requirements. If not provided, it will be inferred automatically from the + input data type and shape. + render_method (str, optional): the name of the render method. Defaults to ``"render"``. + **kwargs: additional keyword arguments to pass to the render function (e.g. ``mode="rgb_array"``). 
+ + Examples: + >>> from torchrl.envs import GymEnv, check_env_specs, ParallelEnv, EnvCreator + >>> from torchrl.record.loggers import CSVLogger + >>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder + >>> + >>> def make_env(): + >>> env = GymEnv("CartPole-v1", render_mode="rgb_array") + >>> env = env.append_transform(PixelRenderTransform()) + >>> return env + >>> + >>> if __name__ == "__main__": + ... logger = CSVLogger("dummy", video_format="mp4") + ... + ... env = ParallelEnv(4, EnvCreator(make_env)) + ... + ... env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record")) + ... env.rollout(3) + ... + ... check_env_specs(env) + ... + ... r = env.rollout(30) + ... print(env) + ... env.transform.dump() + ... env.close() + + This transform can also be used whenever a batched environment ``render()`` returns a single image: + + Examples: + >>> from torchrl.envs import check_env_specs + >>> from torchrl.envs.libs.vmas import VmasEnv + >>> from torchrl.record.loggers import CSVLogger + >>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder + >>> + >>> env = VmasEnv( + ... scenario="flocking", + ... num_envs=32, + ... continuous_actions=True, + ... max_steps=200, + ... device="cpu", + ... seed=None, + ... # Scenario kwargs + ... n_agents=5, + ... ) + >>> + >>> logger = CSVLogger("dummy", video_format="mp4") + >>> + >>> env = env.append_transform(PixelRenderTransform(mode="rgb_array", preproc=lambda x: x.copy())) + >>> env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record")) + >>> + >>> check_env_specs(env) + >>> + >>> r = env.rollout(30) + >>> env.transform[-1].dump() + + The transform can be disabled using the :meth:`~torchrl.record.PixelRenderTransform.switch` method, which will + turn the rendering on if it's off or off if it's on (an argument can also be passed to control this behaviour). + Since transforms are :class:`~torch.nn.Module` instances, :meth:`~torch.nn.Module.apply` can be used to control + this behaviour: + + >>> def switch(module): + ... if isinstance(module, PixelRenderTransform): + ... 
module.switch() + >>> env.apply(switch) + + """ + + def __init__( + self, + out_keys: List[NestedKey] = None, + preproc: Callable[ + [np.ndarray | torch.Tensor], np.ndarray | torch.Tensor + ] = None, + as_non_tensor: bool = None, + render_method: str = "render", + **kwargs, + ) -> None: + if out_keys is None: + out_keys = ["pixels"] + elif isinstance(out_keys, (str, tuple)): + out_keys = [out_keys] + if len(out_keys) != 1: + raise RuntimeError( + f"Expected one and only one out_key, got out_keys={out_keys}" + ) + if preproc is not None and not _can_be_pickled(preproc): + preproc = CloudpickleWrapper(preproc) + self.preproc = preproc + self.as_non_tensor = as_non_tensor + self.kwargs = kwargs + self.render_method = render_method + self._enabled = True + super().__init__(in_keys=[], out_keys=out_keys) + + def _reset( + self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase + ) -> TensorDictBase: + return self._call(tensordict_reset) + + def _call(self, tensordict: TensorDictBase) -> TensorDictBase: + if not self._enabled: + return tensordict + + array = getattr(self.parent, self.render_method)(**self.kwargs) + if self.preproc: + array = self.preproc(array) + if self.as_non_tensor is None: + if isinstance(array, list): + if isinstance(array[0], np.ndarray): + array = np.asarray(array) + else: + array = torch.as_tensor(array) + if ( + array.ndim == 3 + and array.shape[-1] == 3 + and self.parent.batch_size != () + ): + self.as_non_tensor = True + else: + self.as_non_tensor = False + if not self.as_non_tensor: + try: + tensordict.set(self.out_keys[0], array) + except Exception: + raise RuntimeError( + f"An exception was raised while writing the rendered array " + f"(shape={getattr(array, 'shape', None)}, dtype={getattr(array, 'dtype', None)}) in the tensordict with shape {tensordict.shape}. " + f"Consider adapting your preproc function in {type(self).__name__}. You can also " + f"pass keyword arguments to the render function of the parent environment, or save " + f"this observation as a non-tensor data with as_non_tensor=True." + ) + else: + tensordict.set_non_tensor(self.out_keys[0], array) + return tensordict + + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + # Adds the pixel observation spec by calling render on the parent env + switch = False + if not self.enabled: + switch = True + self.switch() + parent = self.parent + td_in = TensorDict({}, batch_size=parent.batch_size, device=parent.device) + self._call(td_in) + obs = td_in.get(self.out_keys[0]) + if isinstance(obs, NonTensorData): + spec = NonTensorSpec(device=obs.device, dtype=obs.dtype, shape=obs.shape) + else: + spec = UnboundedContinuousTensorSpec( + device=obs.device, dtype=obs.dtype, shape=obs.shape + ) + observation_spec[self.out_keys[0]] = spec + if switch: + self.switch() + return observation_spec + + def switch(self, mode: str | bool = None): + """Sets the transform on or off. + + Args: + mode (str or bool, optional): if provided, sets the switch to the desired mode. + ``"on"``, ``"off"``, ``True`` and ``False`` are accepted values. + By default, ``switch`` sets the mode to the opposite of the current one. 
+ + """ + if mode is None: + mode = not self._enabled + if not isinstance(mode, bool): + if mode not in ("on", "off"): + raise ValueError("mode must be either 'on' or 'off', or a boolean.") + mode = mode == "on" + self._enabled = mode + + @property + def enabled(self) -> bool: + """Whether the recorder is enabled.""" + return self._enabled + + def set_container(self, container: Union[Transform, EnvBase]) -> None: + out = super().set_container(container) + if isinstance(self.parent, EnvBase): + # Start the env if needed + method = getattr(self.parent, self.render_method, None) + if method is None or not callable(method): + raise ValueError( + f"The render method must exist and be a callable. Got render={method}." + ) + return out From 6c2e141b571391c83c505ccda3d5b8e0379fad8b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 23 Apr 2024 17:17:20 +0100 Subject: [PATCH 2/2] [Feature] Video recording in SOTA examples (#2070) --- .../linux_examples/scripts/run_test.sh | 32 --------------- sota-implementations/a2c/a2c_atari.py | 9 +++++ sota-implementations/a2c/a2c_mujoco.py | 18 +++++++-- sota-implementations/a2c/config_atari.yaml | 1 + sota-implementations/a2c/config_mujoco.yaml | 1 + sota-implementations/a2c/utils_atari.py | 8 ++++ sota-implementations/a2c/utils_mujoco.py | 16 +++++++- sota-implementations/cql/cql_offline.py | 15 ++++++- sota-implementations/cql/cql_online.py | 23 ++++++++--- .../cql/discrete_cql_config.yaml | 3 +- .../cql/discrete_cql_online.py | 8 +++- sota-implementations/cql/offline_config.yaml | 4 +- sota-implementations/cql/online_config.yaml | 6 +-- sota-implementations/cql/utils.py | 30 ++++++++++---- sota-implementations/ddpg/config.yaml | 4 +- sota-implementations/ddpg/ddpg.py | 12 +++++- sota-implementations/ddpg/utils.py | 32 +++++++++++---- .../decision_transformer/dt.py | 15 ++++++- .../decision_transformer/dt_config.yaml | 3 +- .../decision_transformer/odt_config.yaml | 6 +-- .../decision_transformer/online_dt.py | 15 ++++++- .../decision_transformer/utils.py | 20 +++++++--- sota-implementations/discrete_sac/config.yaml | 5 ++- .../discrete_sac/discrete_sac.py | 19 +++++++-- sota-implementations/discrete_sac/utils.py | 39 +++++++++++++++---- sota-implementations/dqn/config_atari.yaml | 3 +- sota-implementations/dqn/config_cartpole.yaml | 3 +- sota-implementations/dqn/dqn_atari.py | 25 +++++++++--- sota-implementations/dqn/dqn_cartpole.py | 27 ++++++++++--- sota-implementations/dqn/utils_atari.py | 7 ++++ sota-implementations/dqn/utils_cartpole.py | 12 +++++- sota-implementations/iql/discrete_iql.py | 11 +++++- sota-implementations/iql/discrete_iql.yaml | 3 +- sota-implementations/iql/iql_offline.py | 16 +++++++- sota-implementations/iql/iql_online.py | 11 +++++- sota-implementations/iql/offline_config.yaml | 3 +- sota-implementations/iql/online_config.yaml | 3 +- sota-implementations/iql/utils.py | 29 ++++++++++---- sota-implementations/ppo/config_atari.yaml | 1 + sota-implementations/ppo/config_mujoco.yaml | 1 + sota-implementations/ppo/ppo_atari.py | 8 ++++ sota-implementations/ppo/ppo_mujoco.py | 10 ++++- sota-implementations/ppo/utils_atari.py | 12 +++++- sota-implementations/ppo/utils_mujoco.py | 12 +++++- sota-implementations/redq/config.yaml | 2 +- sota-implementations/redq/redq.py | 32 ++++++++------- sota-implementations/redq/utils.py | 3 +- sota-implementations/sac/config.yaml | 3 +- sota-implementations/sac/sac.py | 12 +++++- sota-implementations/sac/utils.py | 35 ++++++++++++----- sota-implementations/td3/config.yaml | 3 +- 
sota-implementations/td3/td3.py | 12 +++++- sota-implementations/td3/utils.py | 34 +++++++++++----- torchrl/envs/transforms/transforms.py | 18 ++++++++- torchrl/record/loggers/csv.py | 3 +- torchrl/record/loggers/utils.py | 4 +- 56 files changed, 528 insertions(+), 174 deletions(-) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 4587be88ddc..bcc688b0a6d 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -36,7 +36,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/de optim.pretrain_gradient_steps=55 \ optim.updates_per_episode=3 \ optim.warmup_steps=10 \ - optim.device=cuda:0 \ logger.backend= \ env.backend=gymnasium \ env.name=HalfCheetah-v4 @@ -44,16 +43,13 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/de optim.pretrain_gradient_steps=55 \ optim.updates_per_episode=3 \ optim.warmup_steps=10 \ - optim.device=cuda:0 \ env.backend=gymnasium \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iql/iql_offline.py \ optim.gradient_steps=55 \ - optim.device=cuda:0 \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_offline.py \ optim.gradient_steps=55 \ - optim.device=cuda:0 \ logger.backend= # ==================================================================================== # @@ -86,8 +82,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dd optim.batch_size=10 \ collector.frames_per_batch=16 \ collector.env_per_collector=2 \ - collector.device=cuda:0 \ - network.device=cuda:0 \ optim.utd_ratio=1 \ replay_buffer.size=120 \ env.name=Pendulum-v1 \ @@ -112,7 +106,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dq collector.init_random_frames=10 \ collector.frames_per_batch=16 \ buffer.batch_size=10 \ - device=cuda:0 \ loss.num_updates=1 \ logger.backend= \ buffer.buffer_size=120 @@ -122,7 +115,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cq optim.batch_size=10 \ collector.frames_per_batch=16 \ collector.env_per_collector=2 \ - collector.device=cuda:0 \ replay_buffer.size=120 \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/redq/redq.py \ @@ -131,7 +123,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/re collector.init_random_frames=10 \ collector.frames_per_batch=16 \ collector.env_per_collector=2 \ - collector.device=cuda:0 \ buffer.batch_size=10 \ optim.steps_per_batch=1 \ logger.record_video=True \ @@ -143,22 +134,18 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/sa collector.init_random_frames=10 \ collector.frames_per_batch=16 \ collector.env_per_collector=2 \ - collector.device=cuda:0 \ optim.batch_size=10 \ optim.utd_ratio=1 \ replay_buffer.size=120 \ env.name=Pendulum-v1 \ - network.device=cuda:0 \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/discrete_sac/discrete_sac.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ collector.frames_per_batch=16 \ collector.env_per_collector=1 \ - collector.device=cuda:0 \ optim.batch_size=10 \ optim.utd_ratio=1 \ - network.device=cuda:0 \ optim.batch_size=10 \ optim.utd_ratio=1 \ replay_buffer.size=120 \ @@ -185,9 +172,6 @@ python 
.github/unittest/helpers/coverage_run_parallel.py sota-implementations/td collector.frames_per_batch=16 \ collector.num_workers=4 \ collector.env_per_collector=2 \ - collector.device=cuda:0 \ - collector.device=cuda:0 \ - network.device=cuda:0 \ logger.mode=offline \ env.name=Pendulum-v1 \ logger.backend= @@ -196,8 +180,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq optim.batch_size=10 \ collector.frames_per_batch=16 \ env.train_num_envs=2 \ - optim.device=cuda:0 \ - collector.device=cuda:0 \ logger.mode=offline \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iql/discrete_iql.py \ @@ -205,8 +187,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq optim.batch_size=10 \ collector.frames_per_batch=16 \ env.train_num_envs=2 \ - optim.device=cuda:0 \ - collector.device=cuda:0 \ logger.mode=offline \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_online.py \ @@ -214,8 +194,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq optim.batch_size=10 \ collector.frames_per_batch=16 \ env.train_num_envs=2 \ - collector.device=cuda:0 \ - optim.device=cuda:0 \ logger.mode=offline \ logger.backend= @@ -238,8 +216,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dd optim.batch_size=10 \ collector.frames_per_batch=16 \ collector.env_per_collector=1 \ - collector.device=cuda:0 \ - network.device=cuda:0 \ optim.utd_ratio=1 \ replay_buffer.size=120 \ env.name=Pendulum-v1 \ @@ -251,7 +227,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dq collector.init_random_frames=10 \ collector.frames_per_batch=16 \ buffer.batch_size=10 \ - device=cuda:0 \ loss.num_updates=1 \ logger.backend= \ buffer.buffer_size=120 @@ -262,7 +237,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/re collector.frames_per_batch=16 \ collector.env_per_collector=1 \ buffer.batch_size=10 \ - collector.device=cuda:0 \ optim.steps_per_batch=1 \ logger.record_video=True \ logger.record_frames=4 \ @@ -274,8 +248,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq collector.frames_per_batch=16 \ env.train_num_envs=1 \ logger.mode=offline \ - optim.device=cuda:0 \ - collector.device=cuda:0 \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_online.py \ collector.total_frames=48 \ @@ -283,8 +255,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cq collector.frames_per_batch=16 \ collector.env_per_collector=1 \ logger.mode=offline \ - optim.device=cuda:0 \ - collector.device=cuda:0 \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3/td3.py \ collector.total_frames=48 \ @@ -292,11 +262,9 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td collector.frames_per_batch=16 \ collector.num_workers=2 \ collector.env_per_collector=1 \ - collector.device=cuda:0 \ logger.mode=offline \ optim.batch_size=10 \ env.name=Pendulum-v1 \ - network.device=cuda:0 \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/multiagent/mappo_ippo.py \ collector.n_iters=2 \ diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py index 7ad39ed43e5..775dcfe206d 100644 --- 
a/sota-implementations/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import hydra from torchrl._utils import logger as torchrl_logger +from torchrl.record import VideoRecorder @hydra.main(config_path="", config_name="config_atari", version_base="1.1") @@ -104,6 +105,14 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create test environment test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True) + test_env.set_seed(0) + if cfg.logger.video: + test_env = test_env.insert_transform( + 0, + VideoRecorder( + logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"] + ), + ) test_env.eval() # Main loop diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py index 7b4a153e150..0276039058f 100644 --- a/sota-implementations/a2c/a2c_mujoco.py +++ b/sota-implementations/a2c/a2c_mujoco.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import hydra from torchrl._utils import logger as torchrl_logger +from torchrl.record import VideoRecorder @hydra.main(config_path="", config_name="config_mujoco", version_base="1.1") @@ -89,7 +90,15 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create test environment - test_env = make_env(cfg.env.env_name, device) + test_env = make_env(cfg.env.env_name, device, from_pixels=cfg.logger.video) + test_env.set_seed(0) + if cfg.logger.video: + test_env = test_env.insert_transform( + 0, + VideoRecorder( + logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"] + ), + ) test_env.eval() # Main loop @@ -178,9 +187,10 @@ def main(cfg: "DictConfig"): # noqa: F821 # Get test rewards with torch.no_grad(), set_exploration_type(ExplorationType.MODE): - if ((i - 1) * frames_in_batch) // cfg.logger.test_interval < ( - i * frames_in_batch - ) // cfg.logger.test_interval: + prev_test_frame = ((i - 1) * frames_in_batch) // cfg.logger.test_interval + cur_test_frame = (i * frames_in_batch) // cfg.logger.test_interval + final = collected_frames >= collector.total_frames + if prev_test_frame < cur_test_frame or final: actor.eval() eval_start = time.time() test_rewards = eval_model( diff --git a/sota-implementations/a2c/config_atari.yaml b/sota-implementations/a2c/config_atari.yaml index 8c94f62fb93..dd0f43b52cb 100644 --- a/sota-implementations/a2c/config_atari.yaml +++ b/sota-implementations/a2c/config_atari.yaml @@ -16,6 +16,7 @@ logger: exp_name: Atari_Schulman17 test_interval: 40_000_000 num_test_episodes: 3 + video: False # Optim optim: diff --git a/sota-implementations/a2c/config_mujoco.yaml b/sota-implementations/a2c/config_mujoco.yaml index b30b7304f61..03a0bde32c5 100644 --- a/sota-implementations/a2c/config_mujoco.yaml +++ b/sota-implementations/a2c/config_mujoco.yaml @@ -15,6 +15,7 @@ logger: exp_name: Mujoco_Schulman17 test_interval: 1_000_000 num_test_episodes: 5 + video: False # Optim optim: diff --git a/sota-implementations/a2c/utils_atari.py b/sota-implementations/a2c/utils_atari.py index 0ddcd79123e..240ebac96d2 100644 --- a/sota-implementations/a2c/utils_atari.py +++ b/sota-implementations/a2c/utils_atari.py @@ -36,6 +36,8 @@ TanhNormal, ValueOperator, ) +from torchrl.record import VideoRecorder + # ==================================================================== # Environment utils @@ -201,6 +203,11 @@ def make_ppo_models(env_name): # -------------------------------------------------------------------- +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() + + def 
eval_model(actor, test_env, num_episodes=3): test_rewards = [] for _ in range(num_episodes): @@ -213,5 +220,6 @@ def eval_model(actor, test_env, num_episodes=3): ) reward = td_test["next", "episode_reward"][td_test["next", "done"]] test_rewards = np.append(test_rewards, reward.cpu().numpy()) + test_env.apply(dump_video) del td_test return test_rewards.mean() diff --git a/sota-implementations/a2c/utils_mujoco.py b/sota-implementations/a2c/utils_mujoco.py index 50780a9d086..178678e4457 100644 --- a/sota-implementations/a2c/utils_mujoco.py +++ b/sota-implementations/a2c/utils_mujoco.py @@ -20,14 +20,20 @@ ) from torchrl.envs.libs.gym import GymEnv from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator +from torchrl.record import VideoRecorder + # ==================================================================== # Environment utils # -------------------------------------------------------------------- -def make_env(env_name="HalfCheetah-v4", device="cpu"): - env = GymEnv(env_name, device=device) +def make_env( + env_name="HalfCheetah-v4", device="cpu", from_pixels=False, pixels_only=False +): + env = GymEnv( + env_name, device=device, from_pixels=from_pixels, pixels_only=pixels_only + ) env = TransformedEnv(env) env.append_transform(RewardSum()) env.append_transform(StepCounter()) @@ -125,6 +131,11 @@ def make_ppo_models(env_name): # -------------------------------------------------------------------- +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() + + def eval_model(actor, test_env, num_episodes=3): test_rewards = [] for _ in range(num_episodes): @@ -137,5 +148,6 @@ def eval_model(actor, test_env, num_episodes=3): ) reward = td_test["next", "episode_reward"][td_test["next", "done"]] test_rewards = np.append(test_rewards, reward.cpu().numpy()) + test_env.apply(dump_video) del td_test return test_rewards.mean() diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py index 441cb3555e2..59b574090f9 100644 --- a/sota-implementations/cql/cql_offline.py +++ b/sota-implementations/cql/cql_offline.py @@ -20,6 +20,7 @@ from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( + dump_video, log_metrics, make_continuous_cql_optimizer, make_continuous_loss, @@ -49,16 +50,25 @@ def main(cfg: "DictConfig"): # noqa: F821 # Set seeds torch.manual_seed(cfg.env.seed) np.random.seed(cfg.env.seed) - device = torch.device(cfg.optim.device) + device = cfg.optim.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) # Create env - train_env, eval_env = make_environment(cfg, cfg.logger.eval_envs) + train_env, eval_env = make_environment( + cfg, train_num_envs=1, eval_num_envs=cfg.logger.eval_envs, logger=logger + ) # Create replay buffer replay_buffer = make_offline_replay_buffer(cfg.replay_buffer) # Create agent model = make_cql_model(cfg, train_env, eval_env, device) + del train_env # Create loss loss_module, target_net_updater = make_continuous_loss(cfg.loss, model) @@ -144,6 +154,7 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_td = eval_env.rollout( max_steps=eval_steps, policy=model[0], auto_cast_to_device=True ) + eval_env.apply(dump_video) eval_reward = eval_td["next", "reward"].sum(1).mean().item() to_log["evaluation_reward"] = eval_reward diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py index a70f9091cb6..dc9bd512285 100644 --- 
a/sota-implementations/cql/cql_online.py +++ b/sota-implementations/cql/cql_online.py @@ -23,6 +23,7 @@ from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( + dump_video, log_metrics, make_collector, make_continuous_cql_optimizer, @@ -54,13 +55,20 @@ def main(cfg: "DictConfig"): # noqa: F821 # Set seeds torch.manual_seed(cfg.env.seed) np.random.seed(cfg.env.seed) - device = torch.device(cfg.optim.device) + device = cfg.optim.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) # Create env train_env, eval_env = make_environment( cfg, cfg.env.train_num_envs, cfg.env.eval_num_envs, + logger=logger, ) # Create replay buffer @@ -99,12 +107,12 @@ def main(cfg: "DictConfig"): # noqa: F821 * cfg.optim.utd_ratio ) prb = cfg.replay_buffer.prb - eval_iter = cfg.logger.eval_iter frames_per_batch = cfg.collector.frames_per_batch - eval_rollout_steps = cfg.collector.max_frames_per_traj + evaluation_interval = cfg.logger.log_interval + eval_rollout_steps = cfg.logger.eval_steps sampling_start = time.time() - for tensordict in collector: + for i, tensordict in enumerate(collector): sampling_time = time.time() - sampling_start pbar.update(tensordict.numel()) # update weights of the inference policy @@ -191,7 +199,11 @@ def main(cfg: "DictConfig"): # noqa: F821 metrics_to_log["train/training_time"] = training_time # Evaluation - if abs(collected_frames % eval_iter) < frames_per_batch: + + prev_test_frame = ((i - 1) * frames_per_batch) // evaluation_interval + cur_test_frame = (i * frames_per_batch) // evaluation_interval + final = current_frames >= collector.total_frames + if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: with set_exploration_type(ExplorationType.MODE), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( @@ -202,6 +214,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() + eval_env.apply(dump_video) metrics_to_log["eval/reward"] = eval_reward metrics_to_log["eval/time"] = eval_time diff --git a/sota-implementations/cql/discrete_cql_config.yaml b/sota-implementations/cql/discrete_cql_config.yaml index 807479d45bd..644b8ec624e 100644 --- a/sota-implementations/cql/discrete_cql_config.yaml +++ b/sota-implementations/cql/discrete_cql_config.yaml @@ -30,6 +30,7 @@ logger: eval_steps: 200 mode: online eval_iter: 1000 + video: False # Buffer replay_buffer: @@ -41,7 +42,7 @@ replay_buffer: # Optimization optim: utd_ratio: 1 - device: cuda:0 + device: null lr: 1e-3 weight_decay: 0.0 batch_size: 256 diff --git a/sota-implementations/cql/discrete_cql_online.py b/sota-implementations/cql/discrete_cql_online.py index fd07684774d..4b6f14cd058 100644 --- a/sota-implementations/cql/discrete_cql_online.py +++ b/sota-implementations/cql/discrete_cql_online.py @@ -35,7 +35,13 @@ @hydra.main(version_base="1.1", config_path="", config_name="discrete_cql_config") def main(cfg: "DictConfig"): # noqa: F821 - device = torch.device(cfg.optim.device) + device = cfg.optim.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) # Create logger exp_name = generate_exp_name("DiscreteCQL", cfg.logger.exp_name) diff --git a/sota-implementations/cql/offline_config.yaml b/sota-implementations/cql/offline_config.yaml index 0047b74d14c..bf213d4e3c5 100644 --- 
a/sota-implementations/cql/offline_config.yaml +++ b/sota-implementations/cql/offline_config.yaml @@ -13,10 +13,12 @@ logger: project_name: torchrl_example_cql group_name: null exp_name: cql_${replay_buffer.dataset} + # eval iter in gradient steps eval_iter: 5000 eval_steps: 1000 mode: online eval_envs: 5 + video: False # replay buffer replay_buffer: @@ -25,7 +27,7 @@ replay_buffer: # optimization optim: - device: cuda:0 + device: null actor_lr: 3e-4 critic_lr: 3e-4 weight_decay: 0.0 diff --git a/sota-implementations/cql/online_config.yaml b/sota-implementations/cql/online_config.yaml index 9b3e5b5bf24..00db1d6bb62 100644 --- a/sota-implementations/cql/online_config.yaml +++ b/sota-implementations/cql/online_config.yaml @@ -26,9 +26,9 @@ logger: group_name: null exp_name: cql_${env.name} log_interval: 5000 # record interval in frames - eval_steps: 1000 mode: online - eval_iter: 1000 + eval_steps: 1000 + video: False # Buffer replay_buffer: @@ -39,7 +39,7 @@ replay_buffer: # Optimization optim: utd_ratio: 1 - device: cuda:0 + device: null actor_lr: 3e-4 critic_lr: 3e-4 weight_decay: 0.0 diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py index 350b105b441..46b84ee434b 100644 --- a/sota-implementations/cql/utils.py +++ b/sota-implementations/cql/utils.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import functools + import torch.nn import torch.optim from tensordict.nn import TensorDictModule, TensorDictSequential @@ -37,6 +39,7 @@ ValueOperator, ) from torchrl.objectives import CQLLoss, DiscreteCQLLoss, SoftUpdate +from torchrl.record import VideoRecorder from torchrl.trainers.helpers.models import ACTIVATIONS @@ -45,16 +48,17 @@ # ----------------- -def env_maker(cfg, device="cpu"): +def env_maker(cfg, device="cpu", from_pixels=False): lib = cfg.env.backend if lib in ("gym", "gymnasium"): with set_gym_backend(lib): return GymEnv( - cfg.env.name, - device=device, + cfg.env.name, device=device, from_pixels=from_pixels, pixels_only=False ) elif lib == "dm_control": - env = DMControlEnv(cfg.env.name, cfg.env.task) + env = DMControlEnv( + cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False + ) return TransformedEnv( env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") ) @@ -75,25 +79,32 @@ def apply_env_transforms( return transformed_env -def make_environment(cfg, train_num_envs=1, eval_num_envs=1): +def make_environment(cfg, train_num_envs=1, eval_num_envs=1, logger=None): """Make environments for training and evaluation.""" + maker = functools.partial(env_maker, cfg) parallel_env = ParallelEnv( train_num_envs, - EnvCreator(lambda cfg=cfg: env_maker(cfg)), + EnvCreator(maker), serial_for_single=True, ) parallel_env.set_seed(cfg.env.seed) train_env = apply_env_transforms(parallel_env) + maker = functools.partial(env_maker, cfg, from_pixels=cfg.logger.video) eval_env = TransformedEnv( ParallelEnv( eval_num_envs, - EnvCreator(lambda cfg=cfg: env_maker(cfg)), + EnvCreator(maker), serial_for_single=True, ), train_env.transform.clone(), ) + eval_env.set_seed(0) + if cfg.logger.video: + eval_env = eval_env.insert_transform( + 0, VideoRecorder(logger=logger, tag="rendered", in_keys=["pixels"]) + ) return train_env, eval_env @@ -373,3 +384,8 @@ def log_metrics(logger, metrics, step): if logger is not None: for metric_name, metric_value in metrics.items(): logger.log_scalar(metric_name, metric_value, step) + + +def dump_video(module): 
+ if isinstance(module, VideoRecorder): + module.dump() diff --git a/sota-implementations/ddpg/config.yaml b/sota-implementations/ddpg/config.yaml index 7d17038330b..43cb5093c09 100644 --- a/sota-implementations/ddpg/config.yaml +++ b/sota-implementations/ddpg/config.yaml @@ -32,12 +32,12 @@ optim: weight_decay: 1e-4 batch_size: 256 target_update_polyak: 0.995 + device: null # network network: hidden_sizes: [256, 256] activation: relu - device: "cuda:0" noise_type: "ou" # ou or gaussian # logging @@ -48,3 +48,5 @@ logger: exp_name: ${env.name}_DDPG mode: online eval_iter: 25000 + video: False + num_eval_envs: 1 diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py index e8313e6c342..eb0b88c26f7 100644 --- a/sota-implementations/ddpg/ddpg.py +++ b/sota-implementations/ddpg/ddpg.py @@ -23,6 +23,7 @@ from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( + dump_video, log_metrics, make_collector, make_ddpg_agent, @@ -35,7 +36,13 @@ @hydra.main(version_base="1.1", config_path="", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 - device = torch.device(cfg.network.device) + device = cfg.optim.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) # Create logger exp_name = generate_exp_name("DDPG", cfg.logger.exp_name) @@ -58,7 +65,7 @@ def main(cfg: "DictConfig"): # noqa: F821 np.random.seed(cfg.env.seed) # Create environments - train_env, eval_env = make_environment(cfg) + train_env, eval_env = make_environment(cfg, logger=logger) # Create agent model, exploration_policy = make_ddpg_agent(cfg, train_env, eval_env, device) @@ -186,6 +193,7 @@ def main(cfg: "DictConfig"): # noqa: F821 auto_cast_to_device=True, break_when_any_done=True, ) + eval_env.apply(dump_video) eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward diff --git a/sota-implementations/ddpg/utils.py b/sota-implementations/ddpg/utils.py index 4006fc27b38..45c6da7a342 100644 --- a/sota-implementations/ddpg/utils.py +++ b/sota-implementations/ddpg/utils.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import functools + import torch from torch import nn, optim @@ -34,6 +36,7 @@ from torchrl.objectives import SoftUpdate from torchrl.objectives.ddpg import DDPGLoss +from torchrl.record import VideoRecorder # ==================================================================== @@ -41,16 +44,17 @@ # ----------------- -def env_maker(cfg, device="cpu"): +def env_maker(cfg, device="cpu", from_pixels=False): lib = cfg.env.library if lib in ("gym", "gymnasium"): with set_gym_backend(lib): return GymEnv( - cfg.env.name, - device=device, + cfg.env.name, device=device, from_pixels=from_pixels, pixels_only=False ) elif lib == "dm_control": - env = DMControlEnv(cfg.env.name, cfg.env.task) + env = DMControlEnv( + cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False + ) return TransformedEnv( env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") ) @@ -71,11 +75,12 @@ def apply_env_transforms(env, max_episode_steps=1000): return transformed_env -def make_environment(cfg): +def make_environment(cfg, logger): """Make environments for training and evaluation.""" + maker = functools.partial(env_maker, cfg, from_pixels=False) parallel_env = ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda cfg=cfg: env_maker(cfg)), + EnvCreator(maker), serial_for_single=True, ) parallel_env.set_seed(cfg.env.seed) @@ -84,14 +89,20 @@ def make_environment(cfg): parallel_env, max_episode_steps=cfg.env.max_episode_steps ) + maker = functools.partial(env_maker, cfg, from_pixels=cfg.logger.video) eval_env = TransformedEnv( ParallelEnv( - cfg.collector.env_per_collector, - EnvCreator(lambda cfg=cfg: env_maker(cfg)), + cfg.logger.num_eval_envs, + EnvCreator(maker), serial_for_single=True, ), train_env.transform.clone(), ) + eval_env.set_seed(0) + if cfg.logger.video: + eval_env = eval_env.append_transform( + VideoRecorder(logger, tag="rendered", in_keys=["pixels"]) + ) return train_env, eval_env @@ -290,3 +301,8 @@ def get_activation(cfg): return nn.LeakyReLU else: raise NotImplementedError + + +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() diff --git a/sota-implementations/decision_transformer/dt.py b/sota-implementations/decision_transformer/dt.py index a79c0037205..59dbcafd8c9 100644 --- a/sota-implementations/decision_transformer/dt.py +++ b/sota-implementations/decision_transformer/dt.py @@ -17,8 +17,10 @@ from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper +from torchrl.record import VideoRecorder from utils import ( + dump_video, log_metrics, make_dt_loss, make_dt_model, @@ -34,6 +36,12 @@ def main(cfg: "DictConfig"): # noqa: F821 set_gym_backend(cfg.env.backend).set() model_device = cfg.optim.device + if model_device in ("", None): + if torch.cuda.is_available(): + model_device = "cuda:0" + else: + model_device = "cpu" + model_device = torch.device(model_device) # Set seeds torch.manual_seed(cfg.env.seed) @@ -48,7 +56,11 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create test environment - test_env = make_env(cfg.env, obs_loc, obs_std) + test_env = make_env(cfg.env, obs_loc, obs_std, from_pixels=cfg.logger.video) + if cfg.logger.video: + test_env = test_env.append_transform( + VideoRecorder(logger, tag="rendered", in_keys=["pixels"]) + ) # Create policy model actor = make_dt_model(cfg) @@ -109,6 +121,7 @@ def main(cfg: "DictConfig"): # noqa: F821 policy=inference_policy, auto_cast_to_device=True, ) + test_env.apply(dump_video) 
to_log["eval/reward"] = ( eval_td["next", "reward"].sum(1).mean().item() / reward_scaling ) diff --git a/sota-implementations/decision_transformer/dt_config.yaml b/sota-implementations/decision_transformer/dt_config.yaml index b42d8b58d35..4805785a62c 100644 --- a/sota-implementations/decision_transformer/dt_config.yaml +++ b/sota-implementations/decision_transformer/dt_config.yaml @@ -27,6 +27,7 @@ logger: pretrain_log_interval: 500 # record interval in frames fintune_log_interval: 1 eval_steps: 1000 + video: False # replay buffer replay_buffer: @@ -42,7 +43,7 @@ replay_buffer: # optimization optim: - device: cuda:0 + device: null lr: 1.0e-4 weight_decay: 5.0e-4 batch_size: 64 diff --git a/sota-implementations/decision_transformer/odt_config.yaml b/sota-implementations/decision_transformer/odt_config.yaml index f06972fd46b..eec2b455fb3 100644 --- a/sota-implementations/decision_transformer/odt_config.yaml +++ b/sota-implementations/decision_transformer/odt_config.yaml @@ -25,8 +25,9 @@ logger: exp_name: oDT-HalfCheetah-medium-v2 model_name: oDT pretrain_log_interval: 500 # record interval in frames - fintune_log_interval: 1 + finetune_log_interval: 1 eval_steps: 1000 + video: False # replay buffer replay_buffer: @@ -37,12 +38,11 @@ replay_buffer: buffer_prefetch: 64 capacity: 1_000_000 scratch_dir: - device: cuda:0 prefetch: 3 # optimizer optim: - device: cuda:0 + device: null lr: 1.0e-4 weight_decay: 5.0e-4 batch_size: 256 diff --git a/sota-implementations/decision_transformer/online_dt.py b/sota-implementations/decision_transformer/online_dt.py index 427b5d8eaa3..5cb297e5c0b 100644 --- a/sota-implementations/decision_transformer/online_dt.py +++ b/sota-implementations/decision_transformer/online_dt.py @@ -17,8 +17,10 @@ from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper +from torchrl.record import VideoRecorder from utils import ( + dump_video, log_metrics, make_env, make_logger, @@ -34,6 +36,12 @@ def main(cfg: "DictConfig"): # noqa: F821 set_gym_backend(cfg.env.backend).set() model_device = cfg.optim.device + if model_device in ("", None): + if torch.cuda.is_available(): + model_device = "cuda:0" + else: + model_device = "cpu" + model_device = torch.device(model_device) # Set seeds torch.manual_seed(cfg.env.seed) @@ -48,7 +56,11 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create test environment - test_env = make_env(cfg.env, obs_loc, obs_std) + test_env = make_env(cfg.env, obs_loc, obs_std, from_pixels=cfg.logger.video) + if cfg.logger.video: + test_env = test_env.append_transform( + VideoRecorder(logger, tag="rendered", in_keys=["pixels"]) + ) # Create policy model actor = make_odt_model(cfg) @@ -123,6 +135,7 @@ def main(cfg: "DictConfig"): # noqa: F821 auto_cast_to_device=True, break_when_any_done=False, ) + test_env.apply(dump_video) inference_policy.train() to_log["eval/reward"] = ( eval_td["next", "reward"].sum(1).mean().item() / reward_scaling diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index 9d479a8118d..a87b3cd8d9f 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -48,6 +48,7 @@ ) from torchrl.objectives import DTLoss, OnlineDTLoss +from torchrl.record import VideoRecorder from torchrl.record.loggers import generate_exp_name, get_logger from torchrl.trainers.helpers.envs import LIBS @@ -56,7 +57,7 @@ # ----------------- -def 
make_base_env(env_cfg): +def make_base_env(env_cfg, from_pixels=False): set_gym_backend(env_cfg.backend).set() env_library = LIBS[env_cfg.library] @@ -66,6 +67,8 @@ def make_base_env(env_cfg): env_kwargs = { "env_name": env_name, "frame_skip": frame_skip, + "from_pixels": from_pixels, + "pixels_only": False, } if env_library is DMControlEnv: env_task = env_cfg.task @@ -131,7 +134,7 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False): return transformed_env -def make_parallel_env(env_cfg, obs_loc, obs_std, train=False): +def make_parallel_env(env_cfg, obs_loc, obs_std, train=False, from_pixels=False): if train: num_envs = env_cfg.num_train_envs else: @@ -139,7 +142,7 @@ def make_parallel_env(env_cfg, obs_loc, obs_std, train=False): def make_env(): with set_gym_backend(env_cfg.backend): - return make_base_env(env_cfg) + return make_base_env(env_cfg, from_pixels=from_pixels) env = make_transformed_env( ParallelEnv(num_envs, EnvCreator(make_env), serial_for_single=True), @@ -151,8 +154,10 @@ def make_env(): return env -def make_env(env_cfg, obs_loc, obs_std, train=False): - env = make_parallel_env(env_cfg, obs_loc, obs_std, train=train) +def make_env(env_cfg, obs_loc, obs_std, train=False, from_pixels=False): + env = make_parallel_env( + env_cfg, obs_loc, obs_std, train=train, from_pixels=from_pixels + ) return env @@ -517,3 +522,8 @@ def make_logger(cfg): def log_metrics(logger, metrics, step): for metric_name, metric_value in metrics.items(): logger.log_scalar(metric_name, metric_value, step) + + +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() diff --git a/sota-implementations/discrete_sac/config.yaml b/sota-implementations/discrete_sac/config.yaml index df26c835ef0..aa852ca1fc3 100644 --- a/sota-implementations/discrete_sac/config.yaml +++ b/sota-implementations/discrete_sac/config.yaml @@ -14,7 +14,7 @@ collector: init_env_steps: 1000 frames_per_batch: 500 reset_at_each_iter: False - device: cuda:0 + device: null env_per_collector: 1 num_workers: 1 @@ -42,7 +42,7 @@ optim: network: hidden_sizes: [256, 256] activation: relu - device: "cuda:0" + device: null # logging logger: @@ -52,3 +52,4 @@ logger: exp_name: ${env.name}_DiscreteSAC mode: online eval_iter: 5000 + video: False diff --git a/sota-implementations/discrete_sac/discrete_sac.py b/sota-implementations/discrete_sac/discrete_sac.py index 40d9a1743c2..6e100f92dc3 100644 --- a/sota-implementations/discrete_sac/discrete_sac.py +++ b/sota-implementations/discrete_sac/discrete_sac.py @@ -23,6 +23,7 @@ from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( + dump_video, log_metrics, make_collector, make_environment, @@ -35,7 +36,13 @@ @hydra.main(version_base="1.1", config_path="", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 - device = torch.device(cfg.network.device) + device = cfg.network.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) # Create logger exp_name = generate_exp_name("DiscreteSAC", cfg.logger.exp_name) @@ -58,7 +65,7 @@ def main(cfg: "DictConfig"): # noqa: F821 np.random.seed(cfg.env.seed) # Create environments - train_env, eval_env = make_environment(cfg) + train_env, eval_env = make_environment(cfg, logger=logger) # Create agent model = make_sac_agent(cfg, train_env, eval_env, device) @@ -100,7 +107,7 @@ def main(cfg: "DictConfig"): # noqa: F821 frames_per_batch = cfg.collector.frames_per_batch sampling_start = 
time.time() - for tensordict in collector: + for i, tensordict in enumerate(collector): sampling_time = time.time() - sampling_start # Update weights of the inference policy @@ -193,7 +200,10 @@ def main(cfg: "DictConfig"): # noqa: F821 metrics_to_log["train/training_time"] = training_time # Evaluation - if abs(collected_frames % eval_iter) < frames_per_batch: + prev_test_frame = ((i - 1) * frames_per_batch) // eval_iter + cur_test_frame = (i * frames_per_batch) // eval_iter + final = current_frames >= collector.total_frames + if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: with set_exploration_type(ExplorationType.MODE), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( @@ -202,6 +212,7 @@ def main(cfg: "DictConfig"): # noqa: F821 auto_cast_to_device=True, break_when_any_done=True, ) + eval_env.apply(dump_video) eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward diff --git a/sota-implementations/discrete_sac/utils.py b/sota-implementations/discrete_sac/utils.py index 5821ed53465..ddffffc2a8e 100644 --- a/sota-implementations/discrete_sac/utils.py +++ b/sota-implementations/discrete_sac/utils.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import functools import tempfile from contextlib import nullcontext @@ -36,22 +37,25 @@ from torchrl.modules.tensordict_module.actors import ProbabilisticActor from torchrl.objectives import SoftUpdate from torchrl.objectives.sac import DiscreteSACLoss +from torchrl.record import VideoRecorder + # ==================================================================== # Environment utils # ----------------- -def env_maker(cfg, device="cpu"): +def env_maker(cfg, device="cpu", from_pixels=False): lib = cfg.env.library if lib in ("gym", "gymnasium"): with set_gym_backend(lib): return GymEnv( - cfg.env.name, - device=device, + cfg.env.name, device=device, from_pixels=from_pixels, pixels_only=False ) elif lib == "dm_control": - env = DMControlEnv(cfg.env.name, cfg.env.task) + env = DMControlEnv( + cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False + ) return TransformedEnv( env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") ) @@ -72,11 +76,12 @@ def apply_env_transforms(env, max_episode_steps): return transformed_env -def make_environment(cfg): +def make_environment(cfg, logger=None): """Make environments for training and evaluation.""" + maker = functools.partial(env_maker, cfg) parallel_env = ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda cfg=cfg: env_maker(cfg)), + EnvCreator(maker), serial_for_single=True, ) parallel_env.set_seed(cfg.env.seed) @@ -85,14 +90,19 @@ def make_environment(cfg): parallel_env, max_episode_steps=cfg.env.max_episode_steps ) + maker = functools.partial(env_maker, cfg, from_pixels=cfg.logger.video) eval_env = TransformedEnv( ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda cfg=cfg: env_maker(cfg)), + EnvCreator(maker), serial_for_single=True, ), train_env.transform.clone(), ) + if cfg.logger.video: + eval_env = eval_env.insert_transform( + 0, VideoRecorder(logger, tag="rendered", in_keys=["pixels"]) + ) return train_env, eval_env @@ -103,6 +113,13 @@ def make_environment(cfg): def make_collector(cfg, train_env, actor_model_explore): """Make collector.""" + device = cfg.collector.device + if device in ("", None): + if 
torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) collector = SyncDataCollector( train_env, actor_model_explore, @@ -110,7 +127,8 @@ def make_collector(cfg, train_env, actor_model_explore): frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, reset_at_each_iter=cfg.collector.reset_at_each_iter, - device=cfg.collector.device, + device=device, + storing_device="cpu", ) collector.set_seed(cfg.env.seed) return collector @@ -288,3 +306,8 @@ def get_activation(cfg): return nn.LeakyReLU else: raise NotImplementedError + + +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() diff --git a/sota-implementations/dqn/config_atari.yaml b/sota-implementations/dqn/config_atari.yaml index 691fb4ff626..50e374cef14 100644 --- a/sota-implementations/dqn/config_atari.yaml +++ b/sota-implementations/dqn/config_atari.yaml @@ -1,4 +1,4 @@ -device: cuda:0 +device: null # Environment env: @@ -27,6 +27,7 @@ logger: exp_name: DQN test_interval: 1_000_000 num_test_episodes: 3 + video: False # Optim optim: diff --git a/sota-implementations/dqn/config_cartpole.yaml b/sota-implementations/dqn/config_cartpole.yaml index 1ebeba42f8c..9a69762d6bd 100644 --- a/sota-implementations/dqn/config_cartpole.yaml +++ b/sota-implementations/dqn/config_cartpole.yaml @@ -1,4 +1,4 @@ -device: cuda:0 +device: null # Environment env: @@ -26,6 +26,7 @@ logger: exp_name: DQN test_interval: 50_000 num_test_episodes: 5 + video: False # Optim optim: diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py index ba5f7cbf761..90f93551d4d 100644 --- a/sota-implementations/dqn/dqn_atari.py +++ b/sota-implementations/dqn/dqn_atari.py @@ -22,6 +22,7 @@ from torchrl.envs import ExplorationType, set_exploration_type from torchrl.modules import EGreedyModule from torchrl.objectives import DQNLoss, HardUpdate +from torchrl.record import VideoRecorder from torchrl.record.loggers import generate_exp_name, get_logger from utils_atari import eval_model, make_dqn_model, make_env @@ -29,7 +30,13 @@ @hydra.main(config_path="", config_name="config_atari", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 - device = torch.device(cfg.device) + device = cfg.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) # Correct for frame_skip frame_skip = 4 @@ -111,6 +118,13 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create the test environment test_env = make_env(cfg.env.env_name, frame_skip, device, is_test=True) + if cfg.logger.video: + test_env.insert_transform( + 0, + VideoRecorder( + logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"] + ), + ) test_env.eval() # Main loop @@ -122,7 +136,7 @@ def main(cfg: "DictConfig"): # noqa: F821 num_test_episodes = cfg.logger.num_test_episodes q_losses = torch.zeros(num_updates, device=device) pbar = tqdm.tqdm(total=total_frames) - for data in collector: + for i, data in enumerate(collector): log_info = {} sampling_time = time.time() - sampling_start @@ -186,9 +200,10 @@ def main(cfg: "DictConfig"): # noqa: F821 # Get and log evaluation rewards and eval time with torch.no_grad(), set_exploration_type(ExplorationType.MODE): - if (collected_frames - frames_per_batch) // test_interval < ( - collected_frames // test_interval - ): + prev_test_frame = ((i - 1) * frames_per_batch) // test_interval + cur_test_frame = (i * frames_per_batch) // test_interval + final = 
current_frames >= collector.total_frames + if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: model.eval() eval_start = time.time() test_rewards = eval_model( diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py index cfe734173f5..ac3f17a9203 100644 --- a/sota-implementations/dqn/dqn_cartpole.py +++ b/sota-implementations/dqn/dqn_cartpole.py @@ -16,6 +16,7 @@ from torchrl.envs import ExplorationType, set_exploration_type from torchrl.modules import EGreedyModule from torchrl.objectives import DQNLoss, HardUpdate +from torchrl.record import VideoRecorder from torchrl.record.loggers import generate_exp_name, get_logger from utils_cartpole import eval_model, make_dqn_model, make_env @@ -23,7 +24,13 @@ @hydra.main(config_path="", config_name="config_cartpole", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 - device = torch.device(cfg.device) + device = cfg.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) # Make the components model = make_dqn_model(cfg.env.env_name) @@ -93,7 +100,14 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create the test environment - test_env = make_env(cfg.env.env_name, "cpu") + test_env = make_env(cfg.env.env_name, "cpu", from_pixels=cfg.logger.video) + if cfg.logger.video: + test_env.insert_transform( + 0, + VideoRecorder( + logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"] + ), + ) # Main loop collected_frames = 0 @@ -108,7 +122,7 @@ def main(cfg: "DictConfig"): # noqa: F821 sampling_start = time.time() q_losses = torch.zeros(num_updates, device=device) - for data in collector: + for i, data in enumerate(collector): log_info = {} sampling_time = time.time() - sampling_start @@ -167,9 +181,10 @@ def main(cfg: "DictConfig"): # noqa: F821 # Get and log evaluation rewards and eval time with torch.no_grad(), set_exploration_type(ExplorationType.MODE): - if (collected_frames - frames_per_batch) // test_interval < ( - collected_frames // test_interval - ): + prev_test_frame = ((i - 1) * frames_per_batch) // test_interval + cur_test_frame = (i * frames_per_batch) // test_interval + final = current_frames >= collector.total_frames + if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: model.eval() eval_start = time.time() test_rewards = eval_model(model, test_env, num_test_episodes) diff --git a/sota-implementations/dqn/utils_atari.py b/sota-implementations/dqn/utils_atari.py index b9805659e63..3dbbfe87af4 100644 --- a/sota-implementations/dqn/utils_atari.py +++ b/sota-implementations/dqn/utils_atari.py @@ -23,6 +23,7 @@ ) from torchrl.modules import ConvNet, MLP, QValueActor +from torchrl.record import VideoRecorder # ==================================================================== @@ -111,7 +112,13 @@ def eval_model(actor, test_env, num_episodes=3): break_when_any_done=True, max_steps=10_000_000, ) + test_env.apply(dump_video) reward = td_test["next", "episode_reward"][td_test["next", "done"]] test_rewards[i] = reward.sum() del td_test return test_rewards.mean() + + +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() diff --git a/sota-implementations/dqn/utils_cartpole.py b/sota-implementations/dqn/utils_cartpole.py index 8d2ec5fab06..2df280a04b4 100644 --- a/sota-implementations/dqn/utils_cartpole.py +++ b/sota-implementations/dqn/utils_cartpole.py @@ -9,14 +9,16 @@ from torchrl.envs import RewardSum, StepCounter, TransformedEnv from torchrl.envs.libs.gym 
import GymEnv from torchrl.modules import MLP, QValueActor +from torchrl.record import VideoRecorder + # ==================================================================== # Environment utils # -------------------------------------------------------------------- -def make_env(env_name="CartPole-v1", device="cpu"): - env = GymEnv(env_name, device=device) +def make_env(env_name="CartPole-v1", device="cpu", from_pixels=False): + env = GymEnv(env_name, device=device, from_pixels=from_pixels, pixels_only=False) env = TransformedEnv(env) env.append_transform(RewardSum()) env.append_transform(StepCounter()) @@ -74,7 +76,13 @@ def eval_model(actor, test_env, num_episodes=3): break_when_any_done=True, max_steps=10_000_000, ) + test_env.apply(dump_video) reward = td_test["next", "episode_reward"][td_test["next", "done"]] test_rewards[i] = reward.sum() del td_test return test_rewards.mean() + + +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py index c0101f1c941..33513dd3973 100644 --- a/sota-implementations/iql/discrete_iql.py +++ b/sota-implementations/iql/discrete_iql.py @@ -24,6 +24,7 @@ from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( + dump_video, log_metrics, make_collector, make_discrete_iql_model, @@ -57,13 +58,20 @@ def main(cfg: "DictConfig"): # noqa: F821 # Set seeds torch.manual_seed(cfg.env.seed) np.random.seed(cfg.env.seed) - device = torch.device(cfg.optim.device) + device = cfg.optim.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) # Create environments train_env, eval_env = make_environment( cfg, cfg.env.train_num_envs, cfg.env.eval_num_envs, + logger=logger, ) # Create replay buffer @@ -186,6 +194,7 @@ def main(cfg: "DictConfig"): # noqa: F821 auto_cast_to_device=True, break_when_any_done=True, ) + eval_env.apply(dump_video) eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward diff --git a/sota-implementations/iql/discrete_iql.yaml b/sota-implementations/iql/discrete_iql.yaml index c21a320e375..9245d4c4832 100644 --- a/sota-implementations/iql/discrete_iql.yaml +++ b/sota-implementations/iql/discrete_iql.yaml @@ -28,6 +28,7 @@ logger: eval_steps: 200 mode: online eval_iter: 1000 + video: False # replay buffer replay_buffer: @@ -38,7 +39,7 @@ replay_buffer: # optimization optim: utd_ratio: 1 - device: cuda:0 + device: null lr: 3e-4 weight_decay: 0.0 batch_size: 256 diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py index 66c6d206c3d..d98724e1371 100644 --- a/sota-implementations/iql/iql_offline.py +++ b/sota-implementations/iql/iql_offline.py @@ -22,6 +22,7 @@ from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( + dump_video, log_metrics, make_environment, make_iql_model, @@ -54,10 +55,20 @@ def main(cfg: "DictConfig"): # noqa: F821 # Set seeds torch.manual_seed(cfg.env.seed) np.random.seed(cfg.env.seed) - device = torch.device(cfg.optim.device) + device = cfg.optim.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) # Creante env - train_env, eval_env = make_environment(cfg, cfg.logger.eval_envs) + train_env, eval_env = make_environment( + cfg, + cfg.logger.eval_envs, + 
logger=logger, + ) # Create replay buffer replay_buffer = make_offline_replay_buffer(cfg.replay_buffer) @@ -123,6 +134,7 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_td = eval_env.rollout( max_steps=eval_steps, policy=model[0], auto_cast_to_device=True ) + eval_env.apply(dump_video) eval_reward = eval_td["next", "reward"].sum(1).mean().item() to_log["evaluation_reward"] = eval_reward if logger is not None: diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py index 307f6df5e2b..b66c6f9dcf2 100644 --- a/sota-implementations/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -24,6 +24,7 @@ from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( + dump_video, log_metrics, make_collector, make_environment, @@ -57,13 +58,20 @@ def main(cfg: "DictConfig"): # noqa: F821 # Set seeds torch.manual_seed(cfg.env.seed) np.random.seed(cfg.env.seed) - device = torch.device(cfg.optim.device) + device = cfg.optim.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) # Create environments train_env, eval_env = make_environment( cfg, cfg.env.train_num_envs, cfg.env.eval_num_envs, + logger=logger, ) # Create replay buffer @@ -184,6 +192,7 @@ def main(cfg: "DictConfig"): # noqa: F821 auto_cast_to_device=True, break_when_any_done=True, ) + eval_env.apply(dump_video) eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward diff --git a/sota-implementations/iql/offline_config.yaml b/sota-implementations/iql/offline_config.yaml index f7486708c5a..5f34fa5651a 100644 --- a/sota-implementations/iql/offline_config.yaml +++ b/sota-implementations/iql/offline_config.yaml @@ -17,6 +17,7 @@ logger: eval_steps: 1000 mode: online eval_envs: 5 + video: False # replay buffer replay_buffer: @@ -25,7 +26,7 @@ replay_buffer: # optimization optim: - device: cuda:0 + device: null lr: 3e-4 weight_decay: 0.0 gradient_steps: 50000 diff --git a/sota-implementations/iql/online_config.yaml b/sota-implementations/iql/online_config.yaml index 511d77ec365..1f7bb361e6c 100644 --- a/sota-implementations/iql/online_config.yaml +++ b/sota-implementations/iql/online_config.yaml @@ -28,6 +28,7 @@ logger: eval_steps: 200 mode: online eval_iter: 1000 + video: False # replay buffer replay_buffer: @@ -38,7 +39,7 @@ replay_buffer: # optimization optim: utd_ratio: 1 - device: cuda:0 + device: null lr: 3e-4 weight_decay: 0.0 batch_size: 256 diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index 8b594d3a60c..2d5aee80ce2 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import functools + import torch.nn import torch.optim from tensordict.nn import InteractionType, TensorDictModule @@ -39,6 +41,7 @@ ValueOperator, ) from torchrl.objectives import DiscreteIQLLoss, HardUpdate, IQLLoss, SoftUpdate +from torchrl.record import VideoRecorder from torchrl.trainers.helpers.models import ACTIVATIONS @@ -48,16 +51,17 @@ # ----------------- -def env_maker(cfg, device="cpu"): +def env_maker(cfg, device="cpu", from_pixels=False): lib = cfg.env.backend if lib in ("gym", "gymnasium"): with set_gym_backend(lib): return GymEnv( - cfg.env.name, - device=device, + cfg.env.name, device=device, from_pixels=from_pixels, pixels_only=False ) elif lib == "dm_control": - env = DMControlEnv(cfg.env.name, cfg.env.task) + env = DMControlEnv( + cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False + ) return TransformedEnv( env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") ) @@ -79,25 +83,31 @@ def apply_env_transforms( return transformed_env -def make_environment(cfg, train_num_envs=1, eval_num_envs=1): +def make_environment(cfg, train_num_envs=1, eval_num_envs=1, logger=None): """Make environments for training and evaluation.""" + maker = functools.partial(env_maker, cfg) parallel_env = ParallelEnv( train_num_envs, - EnvCreator(lambda: env_maker(cfg)), + EnvCreator(maker), serial_for_single=True, ) parallel_env.set_seed(cfg.env.seed) train_env = apply_env_transforms(parallel_env) + maker = functools.partial(env_maker, cfg, from_pixels=cfg.logger.video) eval_env = TransformedEnv( ParallelEnv( eval_num_envs, - EnvCreator(lambda: env_maker(cfg)), + EnvCreator(maker), serial_for_single=True, ), train_env.transform.clone(), ) + if cfg.logger.video: + eval_env.insert_transform( + 0, VideoRecorder(logger, tag="rendered", in_keys=["pixels"]) + ) return train_env, eval_env @@ -417,3 +427,8 @@ def log_metrics(logger, metrics, step): if logger is not None: for metric_name, metric_value in metrics.items(): logger.log_scalar(metric_name, metric_value, step) + + +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() diff --git a/sota-implementations/ppo/config_atari.yaml b/sota-implementations/ppo/config_atari.yaml index d6ec35ab5f2..31e6f13a58c 100644 --- a/sota-implementations/ppo/config_atari.yaml +++ b/sota-implementations/ppo/config_atari.yaml @@ -16,6 +16,7 @@ logger: exp_name: Atari_Schulman17 test_interval: 40_000_000 num_test_episodes: 3 + video: False # Optim optim: diff --git a/sota-implementations/ppo/config_mujoco.yaml b/sota-implementations/ppo/config_mujoco.yaml index 3320837ae3d..2dd3c6cc229 100644 --- a/sota-implementations/ppo/config_mujoco.yaml +++ b/sota-implementations/ppo/config_mujoco.yaml @@ -15,6 +15,7 @@ logger: exp_name: Mujoco_Schulman17 test_interval: 1_000_000 num_test_episodes: 5 + video: False # Optim optim: diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py index 69468e133a8..908cb7924a3 100644 --- a/sota-implementations/ppo/ppo_atari.py +++ b/sota-implementations/ppo/ppo_atari.py @@ -9,6 +9,7 @@ """ import hydra from torchrl._utils import logger as torchrl_logger +from torchrl.record import VideoRecorder @hydra.main(config_path="", config_name="config_atari", version_base="1.1") @@ -104,9 +105,16 @@ def main(cfg: "DictConfig"): # noqa: F821 "group": cfg.logger.group_name, }, ) + logger_video = cfg.logger.video + else: + logger_video = False # Create test environment test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True) + if logger_video: 
+ test_env = test_env.append_transform( + VideoRecorder(logger, tag="rendering/test", in_keys=["pixels_int"]) + ) test_env.eval() # Main loop diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index ae4ba9ea9e5..e3e74971a49 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -9,6 +9,7 @@ """ import hydra from torchrl._utils import logger as torchrl_logger +from torchrl.record import VideoRecorder @hydra.main(config_path="", config_name="config_mujoco", version_base="1.1") @@ -96,9 +97,16 @@ def main(cfg: "DictConfig"): # noqa: F821 "group": cfg.logger.group_name, }, ) + logger_video = cfg.logger.video + else: + logger_video = False # Create test environment - test_env = make_env(cfg.env.env_name, device) + test_env = make_env(cfg.env.env_name, device, from_pixels=logger_video) + if logger_video: + test_env = test_env.append_transform( + VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"]) + ) test_env.eval() # Main loop diff --git a/sota-implementations/ppo/utils_atari.py b/sota-implementations/ppo/utils_atari.py index 5cb838cac47..f2e4ae8cebf 100644 --- a/sota-implementations/ppo/utils_atari.py +++ b/sota-implementations/ppo/utils_atari.py @@ -18,6 +18,7 @@ GymEnv, NoopResetEnv, ParallelEnv, + RenameTransform, Resize, RewardSum, SignTransform, @@ -35,6 +36,8 @@ TanhNormal, ValueOperator, ) +from torchrl.record import VideoRecorder + # ==================================================================== # Environment utils @@ -64,7 +67,8 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): device=device, ) env = TransformedEnv(env) - env.append_transform(ToTensorImage()) + env.append_transform(RenameTransform(in_keys=["pixels"], out_keys=["pixels_int"])) + env.append_transform(ToTensorImage(in_keys=["pixels_int"], out_keys=["pixels"])) env.append_transform(GrayScale()) env.append_transform(Resize(84, 84)) env.append_transform(CatFrames(N=4, dim=-3)) @@ -198,6 +202,11 @@ def make_ppo_models(env_name): # -------------------------------------------------------------------- +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() + + def eval_model(actor, test_env, num_episodes=3): test_rewards = [] for _ in range(num_episodes): @@ -208,6 +217,7 @@ def eval_model(actor, test_env, num_episodes=3): break_when_any_done=True, max_steps=10_000_000, ) + test_env.apply(dump_video) reward = td_test["next", "episode_reward"][td_test["next", "done"]] test_rewards.append(reward.cpu()) del td_test diff --git a/sota-implementations/ppo/utils_mujoco.py b/sota-implementations/ppo/utils_mujoco.py index 7be234b322d..eefd8bebb6b 100644 --- a/sota-implementations/ppo/utils_mujoco.py +++ b/sota-implementations/ppo/utils_mujoco.py @@ -19,14 +19,16 @@ ) from torchrl.envs.libs.gym import GymEnv from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator +from torchrl.record import VideoRecorder + # ==================================================================== # Environment utils # -------------------------------------------------------------------- -def make_env(env_name="HalfCheetah-v4", device="cpu"): - env = GymEnv(env_name, device=device) +def make_env(env_name="HalfCheetah-v4", device="cpu", from_pixels: bool = False): + env = GymEnv(env_name, device=device, from_pixels=from_pixels, pixels_only=False) env = TransformedEnv(env) env.append_transform(VecNorm(in_keys=["observation"], decay=0.99999, eps=1e-2)) 
env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10)) @@ -126,6 +128,11 @@ def make_ppo_models(env_name): # -------------------------------------------------------------------- +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() + + def eval_model(actor, test_env, num_episodes=3): test_rewards = [] for _ in range(num_episodes): @@ -138,5 +145,6 @@ def eval_model(actor, test_env, num_episodes=3): ) reward = td_test["next", "episode_reward"][td_test["next", "done"]] test_rewards.append(reward.cpu()) + test_env.apply(dump_video) del td_test return torch.cat(test_rewards, 0).mean() diff --git a/sota-implementations/redq/config.yaml b/sota-implementations/redq/config.yaml index c67543716dc..e60191c0f93 100644 --- a/sota-implementations/redq/config.yaml +++ b/sota-implementations/redq/config.yaml @@ -43,11 +43,11 @@ logger: project_name: torchrl_example_redq group_name: null exp_name: cheetah - record_video: 0 record_interval: 10 record_frames: 10000 mode: online recorder_log_keys: + video: False optim: optimizer: adam diff --git a/sota-implementations/redq/redq.py b/sota-implementations/redq/redq.py index d9aef64b525..d6b1668aadf 100644 --- a/sota-implementations/redq/redq.py +++ b/sota-implementations/redq/redq.py @@ -63,18 +63,20 @@ def main(cfg: "DictConfig"): # noqa: F821 ] ) - logger = get_logger( - logger_type=cfg.logger.backend, - logger_name="redq_logging", - experiment_name=exp_name, - wandb_kwargs={ - "mode": cfg.logger.mode, - "config": dict(cfg), - "project": cfg.logger.project_name, - "group": cfg.logger.group_name, - }, - ) - video_tag = exp_name if cfg.logger.record_video else "" + if cfg.logger.backend: + logger = get_logger( + logger_type=cfg.logger.backend, + logger_name="redq_logging", + experiment_name=exp_name, + wandb_kwargs={ + "mode": cfg.logger.mode, + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, + ) + else: + logger = "" key, init_env_steps, stats = None, None, None if not cfg.env.vecnorm and cfg.env.norm_stats: @@ -146,7 +148,7 @@ def main(cfg: "DictConfig"): # noqa: F821 recorder = transformed_env_constructor( cfg, - video_tag=video_tag, + video_tag="rendering/test", norm_obs_only=True, obs_norm_state_dict=obs_norm_state_dict, logger=logger, @@ -162,8 +164,8 @@ def main(cfg: "DictConfig"): # noqa: F821 recorder.transform = create_env_fn.transform.clone() else: raise NotImplementedError(f"Unsupported env type {type(create_env_fn)}") - if logger is not None and video_tag: - recorder.insert_transform(0, VideoRecorder(logger=logger, tag=video_tag)) + if logger is not None and cfg.logger.video: + recorder.insert_transform(0, VideoRecorder(logger=logger, tag="rendering/test")) # reset reward scaling for t in recorder.transform: diff --git a/sota-implementations/redq/utils.py b/sota-implementations/redq/utils.py index 37e7da91b4a..0d2e53b9cb1 100644 --- a/sota-implementations/redq/utils.py +++ b/sota-implementations/redq/utils.py @@ -282,7 +282,7 @@ def make_trainer( rb_trainer = ReplayBufferTrainer( replay_buffer, cfg.buffer.batch_size, - flatten_tensordicts=False, + flatten_tensordicts=True, memmap=False, device=device, ) @@ -1044,7 +1044,6 @@ def make_replay_buffer( storage=LazyMemmapStorage( cfg.buffer.size, scratch_dir=cfg.buffer.scratch_dir, - # device=device, # when using prefetch, this can overload the GPU memory ), sampler=sampler, pin_memory=device != torch.device("cpu"), diff --git a/sota-implementations/sac/config.yaml b/sota-implementations/sac/config.yaml 
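b/sota-implementations/sac/config.yaml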
index 6546f1e30b7..29586f2e9a7 100644 --- a/sota-implementations/sac/config.yaml +++ b/sota-implementations/sac/config.yaml @@ -40,7 +40,7 @@ network: activation: relu default_policy_scale: 1.0 scale_lb: 0.1 - device: "cuda:0" + device: # logging logger: @@ -50,3 +50,4 @@ logger: exp_name: ${env.name}_SAC mode: online eval_iter: 25000 + video: False diff --git a/sota-implementations/sac/sac.py b/sota-implementations/sac/sac.py index 576de96394d..f7a399cda72 100644 --- a/sota-implementations/sac/sac.py +++ b/sota-implementations/sac/sac.py @@ -24,6 +24,7 @@ from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( + dump_video, log_metrics, make_collector, make_environment, @@ -36,7 +37,13 @@ @hydra.main(version_base="1.1", config_path="", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 - device = torch.device(cfg.network.device) + device = cfg.network.device + if device in ("", None): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") + device = torch.device(device) # Create logger exp_name = generate_exp_name("SAC", cfg.logger.exp_name) @@ -58,7 +65,7 @@ def main(cfg: "DictConfig"): # noqa: F821 np.random.seed(cfg.env.seed) # Create environments - train_env, eval_env = make_environment(cfg) + train_env, eval_env = make_environment(cfg, logger=logger) # Create agent model, exploration_policy = make_sac_agent(cfg, train_env, eval_env, device) @@ -198,6 +205,7 @@ def main(cfg: "DictConfig"): # noqa: F821 auto_cast_to_device=True, break_when_any_done=True, ) + eval_env.apply(dump_video) eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward diff --git a/sota-implementations/sac/utils.py b/sota-implementations/sac/utils.py index afb731dcc95..d190769772c 100644 --- a/sota-implementations/sac/utils.py +++ b/sota-implementations/sac/utils.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import functools import torch from tensordict.nn import InteractionType, TensorDictModule @@ -26,6 +27,7 @@ from torchrl.modules.distributions import TanhNormal from torchrl.objectives import SoftUpdate from torchrl.objectives.sac import SACLoss +from torchrl.record import VideoRecorder # ==================================================================== @@ -33,16 +35,20 @@ # ----------------- -def env_maker(cfg, device="cpu"): +def env_maker(cfg, device="cpu", from_pixels=False): lib = cfg.env.library if lib in ("gym", "gymnasium"): with set_gym_backend(lib): return GymEnv( cfg.env.name, device=device, + from_pixels=from_pixels, + pixels_only=False, ) elif lib == "dm_control": - env = DMControlEnv(cfg.env.name, cfg.env.task) + env = DMControlEnv( + cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False + ) return TransformedEnv( env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") ) @@ -63,24 +69,31 @@ def apply_env_transforms(env, max_episode_steps=1000): return transformed_env -def make_environment(cfg): +def make_environment(cfg, logger=None): """Make environments for training and evaluation.""" + partial = functools.partial(env_maker, cfg=cfg) parallel_env = ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda cfg=cfg: env_maker(cfg)), + EnvCreator(partial), serial_for_single=True, ) parallel_env.set_seed(cfg.env.seed) train_env = apply_env_transforms(parallel_env, cfg.env.max_episode_steps) + partial = functools.partial(env_maker, cfg=cfg, from_pixels=cfg.logger.video) + trsf_clone = train_env.transform.clone() + if cfg.logger.video: + trsf_clone.insert( + 0, VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"]) + ) eval_env = TransformedEnv( ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda cfg=cfg: env_maker(cfg)), + EnvCreator(partial), serial_for_single=True, ), - train_env.transform.clone(), + trsf_clone, ) return train_env, eval_env @@ -211,13 +224,10 @@ def make_sac_agent(cfg, train_env, eval_env, device): # init nets with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): - td = eval_env.reset() + td = eval_env.fake_tensordict() td = td.to(device) for net in model: net(td) - del td - eval_env.close() - return model, model[0] @@ -298,3 +308,8 @@ def get_activation(cfg): return nn.LeakyReLU else: raise NotImplementedError + + +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() diff --git a/sota-implementations/td3/config.yaml b/sota-implementations/td3/config.yaml index e94a5b6b774..7f7854b68b3 100644 --- a/sota-implementations/td3/config.yaml +++ b/sota-implementations/td3/config.yaml @@ -41,7 +41,7 @@ optim: network: hidden_sizes: [256, 256] activation: relu - device: "cuda:0" + device: null # logging logger: @@ -51,3 +51,4 @@ logger: exp_name: ${env.name}_TD3 mode: online eval_iter: 25000 + video: False diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py index 6b1ee046d55..97fd039c238 100644 --- a/sota-implementations/td3/td3.py +++ b/sota-implementations/td3/td3.py @@ -23,6 +23,7 @@ from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( + dump_video, log_metrics, make_collector, make_environment, @@ -35,7 +36,13 @@ @hydra.main(version_base="1.1", config_path="", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 - device = torch.device(cfg.network.device) + device = cfg.network.device + if device in ("", None): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + 
device = torch.device("cpu") + device = torch.device(device) # Create logger exp_name = generate_exp_name("TD3", cfg.logger.exp_name) @@ -58,7 +65,7 @@ def main(cfg: "DictConfig"): # noqa: F821 np.random.seed(cfg.env.seed) # Create environments - train_env, eval_env = make_environment(cfg) + train_env, eval_env = make_environment(cfg, logger=logger) # Create agent model, exploration_policy = make_td3_agent(cfg, train_env, eval_env, device) @@ -196,6 +203,7 @@ def main(cfg: "DictConfig"): # noqa: F821 auto_cast_to_device=True, break_when_any_done=True, ) + eval_env.apply(dump_video) eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward diff --git a/sota-implementations/td3/utils.py b/sota-implementations/td3/utils.py index fed055f98bf..c597ae205a2 100644 --- a/sota-implementations/td3/utils.py +++ b/sota-implementations/td3/utils.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import functools import tempfile from contextlib import nullcontext @@ -36,6 +37,7 @@ from torchrl.objectives import SoftUpdate from torchrl.objectives.td3 import TD3Loss +from torchrl.record import VideoRecorder # ==================================================================== @@ -43,16 +45,20 @@ # ----------------- -def env_maker(cfg, device="cpu"): +def env_maker(cfg, device="cpu", from_pixels=False): lib = cfg.env.library if lib in ("gym", "gymnasium"): with set_gym_backend(lib): return GymEnv( cfg.env.name, device=device, + from_pixels=from_pixels, + pixels_only=False, ) elif lib == "dm_control": - env = DMControlEnv(cfg.env.name, cfg.env.task) + env = DMControlEnv( + cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False + ) return TransformedEnv( env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") ) @@ -73,26 +79,31 @@ def apply_env_transforms(env, max_episode_steps): return transformed_env -def make_environment(cfg): +def make_environment(cfg, logger=None): """Make environments for training and evaluation.""" + partial = functools.partial(env_maker, cfg=cfg) parallel_env = ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda cfg=cfg: env_maker(cfg)), + EnvCreator(partial), serial_for_single=True, ) parallel_env.set_seed(cfg.env.seed) - train_env = apply_env_transforms( - parallel_env, max_episode_steps=cfg.env.max_episode_steps - ) + train_env = apply_env_transforms(parallel_env, cfg.env.max_episode_steps) + partial = functools.partial(env_maker, cfg=cfg, from_pixels=cfg.logger.video) + trsf_clone = train_env.transform.clone() + if cfg.logger.video: + trsf_clone.insert( + 0, VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"]) + ) eval_env = TransformedEnv( ParallelEnv( cfg.collector.env_per_collector, - EnvCreator(lambda cfg=cfg: env_maker(cfg)), + EnvCreator(partial), serial_for_single=True, ), - train_env.transform.clone(), + trsf_clone, ) return train_env, eval_env @@ -297,3 +308,8 @@ def get_activation(cfg): return nn.LeakyReLU else: raise NotImplementedError + + +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 24cf2819eab..ccab829d480 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -929,7 +929,20 @@ def __getattr__(self, attr: str) -> Any: attr ) # make sure that 
appropriate exceptions are raised except AttributeError as err: - if attr.endswith("_spec"): + if attr in ( + "action_spec", + "done_spec", + "full_action_spec", + "full_done_spec", + "full_observation_spec", + "full_reward_spec", + "full_state_spec", + "input_spec", + "observation_spec", + "output_spec", + "reward_spec", + "state_spec", + ): raise AttributeError( f"Could not get {attr} because an internal error was raised. To find what this error " f"is, call env.transform.transform__spec(env.base_env.spec)." @@ -3511,9 +3524,10 @@ def func(name, item): item = self._apply_transform(item) tensordict.set(name, item) - return tensordict._fast_apply( + tensordict._fast_apply( func, named=True, nested_keys=True, filter_empty=True ) + return tensordict else: # we made sure that if in_keys is not None, out_keys is not None either for in_key, out_key in zip(in_keys, out_keys): diff --git a/torchrl/record/loggers/csv.py b/torchrl/record/loggers/csv.py index 3f188a02a61..dc3aff2ad6b 100644 --- a/torchrl/record/loggers/csv.py +++ b/torchrl/record/loggers/csv.py @@ -43,7 +43,8 @@ def add_scalar(self, name: str, value: float, global_step: Optional[int] = None) if not os.path.isfile(filepath): os.makedirs(Path(filepath).parent, exist_ok=True) if filepath not in self.files: - self.files[filepath] = open(filepath, "a") + os.makedirs(Path(filepath).parent, exist_ok=True) + self.files[filepath] = open(filepath, "a+") fd = self.files[filepath] fd.write(",".join([str(global_step), str(value)]) + "\n") fd.flush() diff --git a/torchrl/record/loggers/utils.py b/torchrl/record/loggers/utils.py index ec7321f5bbd..226135f333f 100644 --- a/torchrl/record/loggers/utils.py +++ b/torchrl/record/loggers/utils.py @@ -44,7 +44,9 @@ def get_logger( elif logger_type == "csv": from torchrl.record.loggers.csv import CSVLogger - logger = CSVLogger(log_dir=logger_name, exp_name=experiment_name) + logger = CSVLogger( + log_dir=logger_name, exp_name=experiment_name, video_format="mp4" + ) elif logger_type == "wandb": from torchrl.record.loggers.wandb import WandbLogger