diff --git a/test/test_loggers.py b/test/test_loggers.py index 98a330d0daf..7381286703b 100644 --- a/test/test_loggers.py +++ b/test/test_loggers.py @@ -8,16 +8,21 @@ import os.path import pathlib import tempfile +from sys import platform from time import sleep import pytest +import tensordict.utils import torch - +import importlib.util from tensordict import MemoryMappedTensor + +from torchrl.envs import GymEnv, ParallelEnv, check_env_specs from torchrl.record.loggers.csv import CSVLogger from torchrl.record.loggers.mlflow import _has_mlflow, _has_tv, MLFlowLogger from torchrl.record.loggers.tensorboard import _has_tb, TensorboardLogger from torchrl.record.loggers.wandb import _has_wandb, WandbLogger +from torchrl.record.recorder import PixelRenderTransform, VideoRecorder if _has_tv: import torchvision @@ -28,6 +33,7 @@ if _has_mlflow: import mlflow +_has_gym = importlib.util.find_spec("gym", None) is not None or importlib.util.find_spec("gymnasium", None) is not None @pytest.fixture def tb_logger(tmp_path_factory): @@ -396,6 +402,29 @@ def test_log_hparams(self, mlflow_fixture, config): logger, client = mlflow_fixture logger.log_hparams(config) +@pytest.mark.skipif(not _has_gym, reason="gym required to test rendering") +class TestPixelRenderTransform: + @pytest.mark.parametrize("parallel", [False, True]) + @pytest.mark.parametrize("in_key", ["pixels", ("nested", "pix")]) + def test_pixel_render(self, parallel, in_key, tmpdir): + def make_env(): + env = GymEnv("CartPole-v1", render_mode="rgb_array", device=None) + env = env.append_transform(PixelRenderTransform(out_keys=in_key)) + return env + if parallel: + env = ParallelEnv(2, make_env, mp_start_method="spawn") + else: + env = make_env() + logger = CSVLogger("dummy", log_dir=tmpdir) + try: + env = env.append_transform(VideoRecorder(logger=logger, in_keys=[in_key], tag="pixels_record")) + check_env_specs(env) + env.rollout(10) + env.transform.dump() + assert os.path.isfile(os.path.join(tmpdir, "dummy", "videos", "pixels_record_0.pt")) + finally: + if not env.is_closed: + env.close() if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 71598938eab..cb327929828 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -31,7 +31,13 @@ import numpy as np import torch -from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase, unravel_key +from tensordict import ( + LazyStackedTensorDict, + NonTensorData, + TensorDict, + TensorDictBase, + unravel_key, +) from tensordict.utils import _getitem_batch_size, NestedKey from torchrl._utils import get_binary_env_var @@ -1917,6 +1923,102 @@ def _is_nested_list(index, notuple=False): return False +class NonTensorSpec(TensorSpec): + """A spec for non-tensor data.""" + def __init__( + self, + shape: Union[torch.Size, int] = _DEFAULT_SHAPE, + device: Optional[DEVICE_TYPING] = None, + dtype: torch.dtype | None = None, + **kwargs, + ): + if isinstance(shape, int): + shape = torch.Size([shape]) + + _, device = _default_dtype_and_device(None, device) + domain = None + super().__init__( + shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs + ) + + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensorSpec: + if isinstance(dest, torch.dtype): + dest_dtype = dest + dest_device = self.device + elif dest is None: + return self + else: + dest_dtype = self.dtype + dest_device = torch.device(dest) + if dest_device == self.device and dest_dtype == self.dtype: + return self + return self.__class__(shape=self.shape, device=dest_device, dtype=None) + + def clone(self) -> NonTensorSpec: + return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype) + + def rand(self, shape): + return NonTensorData(data=None, shape=self.shape, device=self.device) + + def zero(self, shape): + return NonTensorData(data=None, shape=self.shape, device=self.device) + + def one(self, shape): + return NonTensorData(data=None, shape=self.shape, device=self.device) + + def is_in(self, val: torch.Tensor) -> bool: + return ( + isinstance(val, NonTensorData) + and val.shape == self.shape + and val.device == self.device + and val.dtype == self.dtype + ) + + def expand(self, *shape): + if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): + shape = shape[0] + shape = torch.Size(shape) + if not shape[-len(self.shape)] == self.shape: + raise ValueError( + f"The last elements of the expanded shape must match the current one. Got shape={shape} while self.shape={self.shape}." + ) + return self.__class__(shape=shape, device=self.device, dtype=None) + + def _reshape(self, shape): + return self.__class__(shape=shape, device=self.device, dtype=self.dtype) + + def _unflatten(self, dim, sizes): + shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape + return self.__class__( + shape=shape, + device=self.device, + dtype=self.dtype, + ) + + def __getitem__(self, idx: SHAPE_INDEX_TYPING): + """Indexes the current TensorSpec based on the provided index.""" + indexed_shape = torch.Size(_shape_indexing(self.shape, idx)) + return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype) + + def unbind(self, dim: int): + orig_dim = dim + if dim < 0: + dim = len(self.shape) + dim + if dim < 0: + raise ValueError( + f"Cannot unbind along dim {orig_dim} with shape {self.shape}." + ) + shape = tuple(s for i, s in enumerate(self.shape) if i != dim) + return tuple( + self.__class__( + shape=shape, + device=self.device, + dtype=self.dtype, + ) + for i in range(self.shape[dim]) + ) + + @dataclass(repr=False) class UnboundedContinuousTensorSpec(TensorSpec): """An unbounded continuous tensor spec. @@ -1954,7 +2056,9 @@ def __init__( shape=shape, space=box, device=device, dtype=dtype, domain=domain, **kwargs ) - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + def to( + self, dest: Union[torch.dtype, DEVICE_TYPING] + ) -> UnboundedContinuousTensorSpec: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -1979,7 +2083,7 @@ def rand(self, shape=None) -> torch.Tensor: return torch.empty(shape, device=self.device, dtype=self.dtype).random_() def is_in(self, val: torch.Tensor) -> bool: - return True + return val.shape == self.shape and val.dtype == self.dtype def expand(self, *shape): if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 67636523e46..da071e33ac8 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -16,6 +16,7 @@ from enum import Enum from typing import Any, Dict, List, Union +import tensordict import torch from tensordict import ( @@ -25,6 +26,7 @@ TensorDictBase, unravel_key, ) +from tensordict.base import _is_leaf_nontensor from tensordict.nn import TensorDictModule, TensorDictModuleBase from tensordict.nn.probabilistic import ( # noqa # Note: the `set_interaction_mode` and their associated arg `default_interaction_mode` are being deprecated! @@ -179,10 +181,15 @@ def _is_reset(key: NestedKey): return key == "_reset" return key[-1] == "_reset" - actual = {key for key in tensordict.keys(True, True) if not _is_reset(key)} + actual = { + key + for key in tensordict.keys(True, True, is_leaf=_is_leaf_nontensor) + if not _is_reset(key) + } expected = set(expected) self.validated = expected.intersection(actual) == expected if not self.validated: + raise RuntimeError warnings.warn( "The expected key set and actual key set differ. " "This will work but with a slower throughput than " @@ -243,7 +250,7 @@ def _exclude( cls._exclude(nested_key_dict, td, td_out) return out has_set = False - for key, value in data_in.items(): + for key, value in data_in.items(is_leaf=tensordict.base._is_leaf_nontensor): subdict = nested_key_dict.get(key, NO_DEFAULT) if subdict is NO_DEFAULT: value = value.copy() if is_tensor_collection(value) else value diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py index a486b689feb..9e8d681de7c 100644 --- a/torchrl/record/recorder.py +++ b/torchrl/record/recorder.py @@ -6,14 +6,19 @@ import importlib.util from copy import copy -from typing import Optional, Sequence +from typing import Callable, List, Optional, Sequence +import numpy as np import torch -from tensordict import TensorDictBase +from tensordict import NonTensorData, TensorDict, TensorDictBase from tensordict.utils import NestedKey +from torchrl._utils import _can_be_pickled +from torchrl.data import TensorSpec +from torchrl.data.tensor_specs import NonTensorSpec, UnboundedContinuousTensorSpec +from torchrl.data.utils import CloudpickleWrapper from torchrl.envs.transforms import ObservationTransform, Transform from torchrl.record.loggers import Logger @@ -155,20 +160,22 @@ def skip(self, value): self._skip = value def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: + if isinstance(observation, NonTensorData): + observation_trsf = torch.tensor(observation.data) + else: + observation_trsf = observation self.count += 1 if self.count % self.skip == 0: if ( - observation.ndim >= 3 - and observation.shape[-3] == 3 - and observation.shape[-2] > 3 - and observation.shape[-1] > 3 + observation_trsf.ndim >= 3 + and observation_trsf.shape[-3] == 3 + and observation_trsf.shape[-2] > 3 + and observation_trsf.shape[-1] > 3 ): # permute the channels to the last dim - observation_trsf = observation.permute( - *range(observation.ndim - 3), -2, -1, -3 + observation_trsf = observation_trsf.permute( + *range(observation_trsf.ndim - 3), -2, -1, -3 ) - else: - observation_trsf = observation if not ( observation_trsf.shape[-1] == 3 or observation_trsf.ndimension() == 2 ): @@ -321,3 +328,150 @@ def _reset( ) -> TensorDictBase: self._call(tensordict_reset) return tensordict_reset + + +class PixelRenderTransform(Transform): + """A transform to call render on the parent environment and register the pixel observation in the tensordict. + + This transform offers an alternative to the ``from_pixels`` syntatic sugar when instantiating an environment + that offers rendering is expensive, or when ``from_pixels`` is not implemented. + + Args: + out_keys (List[NestedKey] or Nested): List of keys where to register the pixel observations. + preproc (Callable, optional): a preproc function. Can be used to reshape the observation, or apply + any other transformation that makes it possible to register it in the output data. + as_non_tensor (bool, optional): if ``True``, the data will be written as a :class:`~tensordict.NonTensorData` + thereby relaxing the shape requirements. If not provided, it will be inferred automatically from the + input data type and shape. + render_method (str, optional): the name of the render method. Defaults to ``"render"``. + **kwargs: additional keyword arguments to pass to the render function (e.g. ``mode="rgb_array"``). + + Examples: + >>> from torchrl.envs import GymEnv, check_env_specs, ParallelEnv, EnvCreator + >>> from torchrl.record.loggers import CSVLogger + >>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder + >>> + >>> def make_env(): + >>> env = GymEnv("CartPole-v1", render_mode="rgb_array") + >>> env = env.append_transform(PixelRenderTransform()) + >>> return env + >>> + >>> if __name__ == "__main__": + ... logger = CSVLogger("dummy", video_format="mp4") + ... + ... env = ParallelEnv(4, EnvCreator(make_env)) + ... + ... env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record")) + ... env.rollout(3) + ... + ... check_env_specs(env) + ... + ... r = env.rollout(30) + ... print(env) + ... env.transform.dump() + ... env.close() + + This transform can also be used whenever a batched environment ``render()`` returns a single image: + + Examples: + >>> from torchrl.envs import check_env_specs + >>> from torchrl.envs.libs.vmas import VmasEnv + >>> from torchrl.record.loggers import CSVLogger + >>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder + >>> + >>> env = VmasEnv( + ... scenario="flocking", + ... num_envs=32, + ... continuous_actions=True, + ... max_steps=200, + ... device="cpu", + ... seed=None, + ... # Scenario kwargs + ... n_agents=5, + ... ) + >>> + >>> logger = CSVLogger("dummy", video_format="mp4") + >>> + >>> env = env.append_transform(PixelRenderTransform(mode="rgb_array", preproc=lambda x: x.copy())) + >>> env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record")) + >>> + >>> check_env_specs(env) + >>> + >>> r = env.rollout(30) + >>> env.transform[-1].dump() + + """ + + def __init__( + self, + out_keys: List[NestedKey] = None, + preproc: Callable[ + [np.ndarray | torch.Tensor], np.ndarray | torch.Tensor + ] = None, + as_non_tensor: bool = None, + render_method: str = "render", + **kwargs, + ) -> None: + if out_keys is None: + out_keys = ["pixels"] + elif isinstance(out_keys, (str, tuple)): + out_keys = [out_keys] + if len(out_keys) != 1: + raise RuntimeError( + f"Expected one and only one out_key, got out_keys={out_keys}" + ) + if preproc is not None and not _can_be_pickled(preproc): + preproc = CloudpickleWrapper(preproc) + self.preproc = preproc + self.as_non_tensor = as_non_tensor + self.kwargs = kwargs + self.render_method = render_method + super().__init__(in_keys=[], out_keys=out_keys) + + def _reset( + self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase + ) -> TensorDictBase: + return self._call(tensordict_reset) + + def _call(self, tensordict: TensorDictBase) -> TensorDictBase: + array = getattr(self.parent, self.render_method)(**self.kwargs) + if self.preproc: + array = self.preproc(array) + if self.as_non_tensor is None: + if ( + array.ndim == 3 + and array.shape[-1] == 3 + and self.parent.batch_size != () + ): + self.as_non_tensor = True + else: + self.as_non_tensor = False + if not self.as_non_tensor: + try: + tensordict.set(self.out_keys[0], array) + except Exception: + raise RuntimeError( + f"An exception was raised while writing the rendered array " + f"(shape={getattr(array, 'shape', None)}, dtype={getattr(array, 'dtype', None)}) in the tensordict with shape {tensordict.shape}. " + f"Consider adapting your preproc function in {type(self).__name__}. You can also " + f"pass keyword arguments to the render function of the parent environment, or save " + f"this observation as a non-tensor data with as_non_tensor=True." + ) + else: + tensordict.set_non_tensor(self.out_keys[0], array) + return tensordict + + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + # Adds the pixel observation spec by calling render on the parent env + parent = self.parent + td_in = TensorDict({}, batch_size=parent.batch_size, device=parent.device) + self._call(td_in) + obs = td_in.get(self.out_keys[0]) + if isinstance(obs, NonTensorData): + spec = NonTensorSpec(device=obs.device, dtype=obs.dtype, shape=obs.shape) + else: + spec = UnboundedContinuousTensorSpec( + device=obs.device, dtype=obs.dtype, shape=obs.shape + ) + observation_spec[self.out_keys[0]] = spec + return observation_spec