[Feature] A PixelRenderTransform #2099

Merged
merged 11 commits into from
Apr 23, 2024
Changes from 6 commits
1 change: 1 addition & 0 deletions docs/source/reference/data.rst
@@ -815,6 +815,7 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
DiscreteTensorSpec
MultiDiscreteTensorSpec
MultiOneHotDiscreteTensorSpec
NonTensorSpec
OneHotDiscreteTensorSpec
UnboundedContinuousTensorSpec
UnboundedDiscreteTensorSpec
40 changes: 40 additions & 0 deletions docs/source/reference/envs.rst
@@ -759,6 +759,46 @@ to always know what the latest available actions are. You can do this like so:
Recorders
---------

Recording data during environment rollouts is crucial for monitoring algorithm performance and for
reporting results after training.

TorchRL offers several tools to interact with the environment output: first and foremost, a ``callback`` callable

Review comment: Bit of a side note, but at least in the current public doc version the signature of the callback is not very clear. Maybe clarify here (or ideally in rollout's doc) that the callback is called with callback(self, tensordict)?

can be passed to the :meth:`~torchrl.envs.EnvBase.rollout` method. This function is called on the collected
tensordict at each iteration of the rollout. If some iterations have to be skipped, keep an internal counter inside
``callback`` to track the call count.
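As a rough illustration of the call-count pattern described above, the sketch below wraps a user function in a small callable that keeps its own counter and only acts on every ``every``-th call. The ``EveryN`` class and the stand-in loop are hypothetical and not part of TorchRL. The ``callback(env, tensordict)`` call signature is an assumption taken from the review discussion, not from the documented API.

```python
from typing import Any, Callable, List


class EveryN:
    """Hypothetical helper: forwards every `every`-th call to `fn`."""

    def __init__(self, fn: Callable[[Any, Any], None], every: int) -> None:
        if every < 1:
            raise ValueError("every must be >= 1")
        self.fn = fn
        self.every = every
        self.calls = 0  # internal variable tracking the call count

    def __call__(self, env: Any, tensordict: Any) -> None:
        # Assumed signature: callback(env, tensordict)
        self.calls += 1
        if self.calls % self.every == 0:
            self.fn(env, tensordict)


# Stand-in for a rollout loop: plain dicts play the role of tensordicts.
seen: List[int] = []
callback = EveryN(lambda env, td: seen.append(td["step"]), every=3)
for step in range(10):
    callback(None, {"step": step})
```

With a real environment, the same object would presumably be handed to the rollout as ``env.rollout(max_steps, callback=callback)``.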

To save collected tensordicts on disk, the :class:`~torchrl.record.TensorDictRecorder` can be used.

Recording videos
~~~~~~~~~~~~~~~~

Several backends offer the possibility of recording rendered images from the environment.
If the pixels are already part of the environment output (e.g. Atari or other game simulators), a
:class:`~torchrl.record.VideoRecorder` can be appended to the environment. This environment transform takes as input
a logger capable of recording videos (e.g. :class:`~torchrl.record.loggers.CSVLogger`, :class:`~torchrl.record.loggers.WandbLogger`
or :class:`~torchrl.record.loggers.TensorboardLogger`) as well as a tag indicating where the video should be saved.
For instance, to save mp4 videos on disk, one can use :class:`~torchrl.record.loggers.CSVLogger` with a ``video_format="mp4"``
argument.

The :class:`~torchrl.record.VideoRecorder` transform can handle batched images and automatically detects numpy or PyTorch
formatted images (WHC or CWH).

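>>> from torchrl.envs import GymEnv
>>> from torchrl.record import VideoRecorder
>>> from torchrl.record.loggers.csv import CSVLogger
>>> logger = CSVLogger("dummy-exp", video_format="mp4")
>>> env = GymEnv("ALE/Pong-v5")
>>> env = env.append_transform(VideoRecorder(logger, tag="rendered", in_keys=["pixels"]))
>>> env.rollout(10)
>>> env.transform.dump()  # Save the video and clear cache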

Note that the cache of the transform will keep on growing until ``dump`` is called. It is the user's responsibility to
call ``dump`` when needed to avoid out-of-memory (OOM) issues.
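One way to keep the cache bounded is to split a long collection into shorter rollouts and dump after each one. The sketch below is only an illustration under stated assumptions: ``rollout_in_chunks`` is a hypothetical helper, it relies solely on the ``env.rollout(n)`` and ``dump()`` calls shown in the example above, and the two ``_Fake*`` classes are stand-ins so that the snippet runs without TorchRL. Each chunk is treated as an independent rollout.

```python
from typing import Any, List


def rollout_in_chunks(env: Any, recorder: Any, total_steps: int, chunk: int) -> List[Any]:
    """Run several short rollouts, dumping the recorder after each one."""
    if chunk < 1:
        raise ValueError("chunk must be >= 1")
    outputs = []
    remaining = total_steps
    while remaining > 0:
        n = min(chunk, remaining)
        outputs.append(env.rollout(n))  # each call is an independent rollout
        recorder.dump()  # write the frames out and clear the cache
        remaining -= n
    return outputs


class _FakeRecorder:
    """Stand-in that mimics a growing frame cache."""

    def __init__(self) -> None:
        self.cache: List[int] = []
        self.dumps = 0
        self.max_cache = 0

    def record(self, n: int) -> None:
        self.cache.extend(range(n))
        self.max_cache = max(self.max_cache, len(self.cache))

    def dump(self) -> None:
        self.dumps += 1
        self.cache.clear()


class _FakeEnv:
    """Stand-in whose rollout feeds the recorder and returns the step count."""

    def __init__(self, recorder: _FakeRecorder) -> None:
        self.recorder = recorder

    def rollout(self, n: int) -> int:
        self.recorder.record(n)
        return n


recorder = _FakeRecorder()
outputs = rollout_in_chunks(_FakeEnv(recorder), recorder, total_steps=25, chunk=10)
```

In the documented example, the object passed as ``recorder`` would be ``env.transform``, since that is where ``dump`` is called.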

In some cases, creating a testing environment where images can be collected is tedious or expensive, or simply impossible
(some libraries only allow one environment instance per workspace).
In these cases, assuming that a ``render`` method is available in the environment, the :class:`~torchrl.record.PixelRenderTransform`
can be used to call ``render`` on the parent environment and save the images in the rollout data stream.
This class should only be used within the same process as the environment that is being rendered (remote calls to ``render``

Review comment (on "This class should only be used within the same process as the environment that is being rendered"): This part is a bit unclear. Is it saying that, for example, using ParallelEnv with multiprocessing won't work, and the environments need to be serial/in the same process for things to work?

are not allowed).
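A minimal usage sketch, adapted from the test added in this PR (``test_pixel_render``): the transform is appended inside the same function that builds the environment, so ``render`` is called in the process that owns it. The function name ``record_rendered_rollout`` and its parameters are hypothetical; the classes, import paths and keyword arguments are the ones used in that test. The imports are kept local so the function can be defined without TorchRL installed.

```python
def record_rendered_rollout(log_dir: str, steps: int = 10) -> None:
    """Render CartPole frames into the rollout and save them through a logger."""
    # Local imports: calling this function requires TorchRL and gym/gymnasium.
    from torchrl.envs import GymEnv
    from torchrl.record.loggers.csv import CSVLogger
    from torchrl.record.recorder import PixelRenderTransform, VideoRecorder

    # CartPole observations are state vectors, so frames must come from render().
    env = GymEnv("CartPole-v1", render_mode="rgb_array", device=None)
    # Store the rendered images under the "pixels" key of the rollout data.
    env = env.append_transform(PixelRenderTransform(out_keys="pixels"))
    logger = CSVLogger("dummy", log_dir=log_dir)
    env = env.append_transform(
        VideoRecorder(logger=logger, in_keys=["pixels"], tag="pixels_record")
    )
    try:
        env.rollout(steps)
        env.transform.dump()  # flush the recorded frames and clear the cache
    finally:
        if not env.is_closed:
            env.close()
```

The same test also wraps the environment factory in ``ParallelEnv(2, make_env, mp_start_method="spawn")``, with the transform appended inside ``make_env``.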

.. currentmodule:: torchrl.record

Recorders are transforms that register data as they come in, for logging purposes.
40 changes: 39 additions & 1 deletion test/test_loggers.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import argparse
import importlib.util
import os
import os.path
import pathlib
@@ -12,12 +13,14 @@

import pytest
import torch

from tensordict import MemoryMappedTensor

from torchrl.envs import check_env_specs, GymEnv, ParallelEnv
from torchrl.record.loggers.csv import CSVLogger
from torchrl.record.loggers.mlflow import _has_mlflow, _has_tv, MLFlowLogger
from torchrl.record.loggers.tensorboard import _has_tb, TensorboardLogger
from torchrl.record.loggers.wandb import _has_wandb, WandbLogger
from torchrl.record.recorder import PixelRenderTransform, VideoRecorder

if _has_tv:
    import torchvision
@@ -28,6 +31,11 @@
if _has_mlflow:
    import mlflow

_has_gym = (
    importlib.util.find_spec("gym", None) is not None
    or importlib.util.find_spec("gymnasium", None) is not None
)


@pytest.fixture
def tb_logger(tmp_path_factory):
@@ -397,6 +405,36 @@ def test_log_hparams(self, mlflow_fixture, config):
        logger.log_hparams(config)


@pytest.mark.skipif(not _has_gym, reason="gym required to test rendering")
class TestPixelRenderTransform:
    @pytest.mark.parametrize("parallel", [False, True])
    @pytest.mark.parametrize("in_key", ["pixels", ("nested", "pix")])
    def test_pixel_render(self, parallel, in_key, tmpdir):
        def make_env():
            env = GymEnv("CartPole-v1", render_mode="rgb_array", device=None)
            env = env.append_transform(PixelRenderTransform(out_keys=in_key))
            return env

        if parallel:
            env = ParallelEnv(2, make_env, mp_start_method="spawn")
        else:
            env = make_env()
        logger = CSVLogger("dummy", log_dir=tmpdir)
        try:
            env = env.append_transform(
                VideoRecorder(logger=logger, in_keys=[in_key], tag="pixels_record")
            )
            check_env_specs(env)
            env.rollout(10)
            env.transform.dump()
            assert os.path.isfile(
                os.path.join(tmpdir, "dummy", "videos", "pixels_record_0.pt")
            )
        finally:
            if not env.is_closed:
                env.close()


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
28 changes: 28 additions & 0 deletions test/test_specs.py
@@ -23,6 +23,7 @@
    LazyStackedCompositeSpec,
    MultiDiscreteTensorSpec,
    MultiOneHotDiscreteTensorSpec,
    NonTensorSpec,
    OneHotDiscreteTensorSpec,
    TensorSpec,
    UnboundedContinuousTensorSpec,
@@ -1462,6 +1463,14 @@ def test_multionehot(self, shape1, shape2):
        assert spec2.rand().shape == spec2.shape
        assert spec2.zero().shape == spec2.shape

    def test_non_tensor(self):
        spec = NonTensorSpec((3, 4), device="cpu")
        assert (
            spec.expand(2, 3, 4)
            == spec.expand((2, 3, 4))
            == NonTensorSpec((2, 3, 4), device="cpu")
        )

    @pytest.mark.parametrize("shape1", [None, (), (5,)])
    @pytest.mark.parametrize("shape2", [(), (10,)])
    def test_onehot(self, shape1, shape2):
@@ -1675,6 +1684,11 @@ def test_multionehot(
        assert spec == spec.clone()
        assert spec is not spec.clone()

    def test_non_tensor(self):
        spec = NonTensorSpec(shape=(3, 4), device="cpu")
        assert spec.clone() == spec
        assert spec.clone() is not spec

    @pytest.mark.parametrize("shape1", [None, (), (5,)])
    def test_onehot(
        self,
@@ -1840,6 +1854,11 @@ def test_multionehot(
        with pytest.raises(ValueError):
            spec.unbind(-1)

    def test_non_tensor(self):
        spec = NonTensorSpec(shape=(3, 4), device="cpu")
        assert spec.unbind(1)[0] == spec[:, 0]
        assert spec.unbind(1)[0] is not spec[:, 0]

    @pytest.mark.parametrize("shape1", [(5,), (5, 6)])
    def test_onehot(
        self,
@@ -2114,6 +2133,15 @@ def test_stack_multionehot_zero(self, shape, stack_dim):
        r = c.zero()
        assert r.shape == c.shape

    def test_stack_non_tensor(self, shape, stack_dim):
        spec0 = NonTensorSpec(shape=shape, device="cpu")
        spec1 = NonTensorSpec(shape=shape, device="cpu")
        new_spec = torch.stack([spec0, spec1], stack_dim)
        shape_insert = list(shape)
        shape_insert.insert(stack_dim, 2)
        assert new_spec.shape == torch.Size(shape_insert)
        assert new_spec.device == torch.device("cpu")

    def test_stack_onehot(self, shape, stack_dim):
        n = 5
        shape = (*shape, 5)
1 change: 1 addition & 0 deletions torchrl/data/__init__.py
@@ -51,6 +51,7 @@
    LazyStackedTensorSpec,
    MultiDiscreteTensorSpec,
    MultiOneHotDiscreteTensorSpec,
    NonTensorSpec,
    OneHotDiscreteTensorSpec,
    TensorSpec,
    UnboundedContinuousTensorSpec,
125 changes: 120 additions & 5 deletions torchrl/data/tensor_specs.py
@@ -31,7 +31,13 @@

 import numpy as np
 import torch
-from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase, unravel_key
+from tensordict import (
+    LazyStackedTensorDict,
+    NonTensorData,
+    TensorDict,
+    TensorDictBase,
+    unravel_key,
+)
 from tensordict.utils import _getitem_batch_size, NestedKey

 from torchrl._utils import get_binary_env_var
@@ -715,8 +721,9 @@ def _flatten(self, start_dim, end_dim):
         shape = torch.zeros(self.shape, device="meta").flatten(start_dim, end_dim).shape
         return self._reshape(shape)

+    @abc.abstractmethod
     def _project(self, val: torch.Tensor) -> torch.Tensor:
-        raise NotImplementedError
+        raise NotImplementedError(type(self))

     @abc.abstractmethod
     def is_in(self, val: torch.Tensor) -> bool:
@@ -1917,6 +1924,107 @@ def _is_nested_list(index, notuple=False):
    return False


class NonTensorSpec(TensorSpec):
    """A spec for non-tensor data."""

    def __init__(
        self,
        shape: Union[torch.Size, int] = _DEFAULT_SHAPE,
        device: Optional[DEVICE_TYPING] = None,
        dtype: torch.dtype | None = None,
        **kwargs,
    ):
        if isinstance(shape, int):
            shape = torch.Size([shape])

        _, device = _default_dtype_and_device(None, device)
        domain = None
        super().__init__(
            shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
        )

    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensorSpec:
        if isinstance(dest, torch.dtype):
            dest_dtype = dest
            dest_device = self.device
        elif dest is None:
            return self
        else:
            dest_dtype = self.dtype
            dest_device = torch.device(dest)
        if dest_device == self.device and dest_dtype == self.dtype:
            return self
        return self.__class__(shape=self.shape, device=dest_device, dtype=None)

    def clone(self) -> NonTensorSpec:
        return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype)

    def rand(self, shape):
        return NonTensorData(data=None, shape=self.shape, device=self.device)

    def zero(self, shape):
        return NonTensorData(data=None, shape=self.shape, device=self.device)

    def one(self, shape):
        return NonTensorData(data=None, shape=self.shape, device=self.device)

    def is_in(self, val: torch.Tensor) -> bool:
        shape = torch.broadcast_shapes(self.shape, val.shape)
        return (
            isinstance(val, NonTensorData)
            and val.shape == shape
            and val.device == self.device
            and val.dtype == self.dtype
        )

    def expand(self, *shape):
        if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
            shape = shape[0]
        shape = torch.Size(shape)
        if not all(
            (old == 1) or (old == new)
            for old, new in zip(self.shape, shape[-len(self.shape) :])
        ):
            raise ValueError(
                f"The last elements of the expanded shape must match the current one. Got shape={shape} while self.shape={self.shape}."
            )
        return self.__class__(shape=shape, device=self.device, dtype=None)

    def _reshape(self, shape):
        return self.__class__(shape=shape, device=self.device, dtype=self.dtype)

    def _unflatten(self, dim, sizes):
        shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
        return self.__class__(
            shape=shape,
            device=self.device,
            dtype=self.dtype,
        )

    def __getitem__(self, idx: SHAPE_INDEX_TYPING):
        """Indexes the current TensorSpec based on the provided index."""
        indexed_shape = torch.Size(_shape_indexing(self.shape, idx))
        return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype)

    def unbind(self, dim: int):
        orig_dim = dim
        if dim < 0:
            dim = len(self.shape) + dim
        if dim < 0:
            raise ValueError(
                f"Cannot unbind along dim {orig_dim} with shape {self.shape}."
            )
        shape = tuple(s for i, s in enumerate(self.shape) if i != dim)
        return tuple(
            self.__class__(
                shape=shape,
                device=self.device,
                dtype=self.dtype,
            )
            for i in range(self.shape[dim])
        )
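To make the shape rule in ``NonTensorSpec.expand`` easier to follow, here is a plain-Python sketch of the same compatibility check, using tuples instead of ``torch.Size``. The function name ``check_expand`` is hypothetical; only the comparison logic is taken from the diff above.

```python
from typing import Sequence, Tuple


def check_expand(old_shape: Sequence[int], new_shape: Sequence[int]) -> Tuple[int, ...]:
    """Return `new_shape` if `old_shape` passes the expand check, else raise."""
    old_shape = tuple(old_shape)
    new_shape = tuple(new_shape)
    # Compare the current dims against the trailing dims of the target shape.
    # With an empty old_shape the slice is the whole target and zip yields nothing.
    trailing = new_shape[-len(old_shape):]
    if not all(old == 1 or old == new for old, new in zip(old_shape, trailing)):
        raise ValueError(
            f"The last elements of the expanded shape must match the current one. "
            f"Got shape={new_shape} while self.shape={old_shape}."
        )
    return new_shape


expanded = check_expand((3, 4), (2, 3, 4))
from_singleton = check_expand((1, 4), (2, 3, 4))
from_scalar = check_expand((), (5,))
```

A dimension of size 1 may grow to any size, any other dimension must be kept as is, and extra leading dimensions are accepted. For example, ``check_expand((3, 4), (2, 5, 4))`` raises ``ValueError``.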


@dataclass(repr=False)
class UnboundedContinuousTensorSpec(TensorSpec):
    """An unbounded continuous tensor spec.
@@ -1954,7 +2062,9 @@ def __init__(
             shape=shape, space=box, device=device, dtype=dtype, domain=domain, **kwargs
         )

-    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
+    def to(
+        self, dest: Union[torch.dtype, DEVICE_TYPING]
+    ) -> UnboundedContinuousTensorSpec:
         if isinstance(dest, torch.dtype):
             dest_dtype = dest
             dest_device = self.device
@@ -1979,7 +2089,11 @@ def rand(self, shape=None) -> torch.Tensor:
         return torch.empty(shape, device=self.device, dtype=self.dtype).random_()

     def is_in(self, val: torch.Tensor) -> bool:
-        return True
+        shape = torch.broadcast_shapes(self.shape, val.shape)
+        return val.shape == shape and val.dtype == self.dtype
+
+    def _project(self, val: torch.Tensor) -> torch.Tensor:
+        return torch.as_tensor(val, dtype=self.dtype).reshape(self.shape)

     def expand(self, *shape):
         if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
@@ -2130,7 +2244,8 @@ def rand(self, shape=None) -> torch.Tensor:
         return r.to(self.device)

     def is_in(self, val: torch.Tensor) -> bool:
-        return True
+        shape = torch.broadcast_shapes(self.shape, val.shape)
+        return val.shape == shape and val.dtype == self.dtype

     def expand(self, *shape):
         if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
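The new ``is_in`` bodies replace an unconditional ``True`` with a shape and dtype check. The sketch below illustrates only the shape half, in plain Python so that it runs without PyTorch. Both ``broadcast_shapes`` (a two-shape stand-in for ``torch.broadcast_shapes``) and ``shape_is_in`` are hypothetical helpers. Raising ``ValueError`` on incompatible shapes is a choice of this sketch, and the dtype comparison is left out.

```python
from itertools import zip_longest
from typing import Sequence, Tuple


def broadcast_shapes(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Plain-Python stand-in for torch.broadcast_shapes, limited to two shapes."""
    out = []
    # Align from the trailing dimension; missing leading dims count as 1.
    for x, y in zip_longest(reversed(tuple(a)), reversed(tuple(b)), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"Shapes {tuple(a)} and {tuple(b)} are not broadcastable.")
        out.append(y if x == 1 else x)
    return tuple(reversed(out))


def shape_is_in(spec_shape: Sequence[int], val_shape: Sequence[int]) -> bool:
    """Shape half of the check: the value's shape must equal the broadcast shape."""
    val_shape = tuple(val_shape)
    return broadcast_shapes(spec_shape, val_shape) == val_shape


same = shape_is_in((3, 4), (3, 4))
batched = shape_is_in((3, 4), (2, 3, 4))
too_small = shape_is_in((3, 4), (4,))
```

Under this rule a value may carry extra leading (batch) dimensions, but a value that only covers the trailing part of the spec shape is rejected.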