Commit: init
vmoens committed Apr 23, 2024
1 parent 0ea236d commit 8e09bfa
Showing 4 changed files with 310 additions and 16 deletions.
31 changes: 30 additions & 1 deletion test/test_loggers.py
@@ -8,16 +8,21 @@
import os.path
import pathlib
import tempfile
from sys import platform
from time import sleep

import pytest
import tensordict.utils
import torch

import importlib.util
from tensordict import MemoryMappedTensor

from torchrl.envs import GymEnv, ParallelEnv, check_env_specs
from torchrl.record.loggers.csv import CSVLogger
from torchrl.record.loggers.mlflow import _has_mlflow, _has_tv, MLFlowLogger
from torchrl.record.loggers.tensorboard import _has_tb, TensorboardLogger
from torchrl.record.loggers.wandb import _has_wandb, WandbLogger
from torchrl.record.recorder import PixelRenderTransform, VideoRecorder

if _has_tv:
    import torchvision
@@ -28,6 +33,7 @@
if _has_mlflow:
    import mlflow

_has_gym = (
    importlib.util.find_spec("gym", None) is not None
    or importlib.util.find_spec("gymnasium", None) is not None
)

@pytest.fixture
def tb_logger(tmp_path_factory):
@@ -396,6 +402,29 @@ def test_log_hparams(self, mlflow_fixture, config):
        logger, client = mlflow_fixture
        logger.log_hparams(config)

@pytest.mark.skipif(not _has_gym, reason="gym required to test rendering")
class TestPixelRenderTransform:
    @pytest.mark.parametrize("parallel", [False, True])
    @pytest.mark.parametrize("in_key", ["pixels", ("nested", "pix")])
    def test_pixel_render(self, parallel, in_key, tmpdir):
        def make_env():
            env = GymEnv("CartPole-v1", render_mode="rgb_array", device=None)
            env = env.append_transform(PixelRenderTransform(out_keys=in_key))
            return env

        if parallel:
            env = ParallelEnv(2, make_env, mp_start_method="spawn")
        else:
            env = make_env()
        logger = CSVLogger("dummy", log_dir=tmpdir)
        try:
            env = env.append_transform(
                VideoRecorder(logger=logger, in_keys=[in_key], tag="pixels_record")
            )
            check_env_specs(env)
            env.rollout(10)
            env.transform.dump()
            assert os.path.isfile(
                os.path.join(tmpdir, "dummy", "videos", "pixels_record_0.pt")
            )
        finally:
            if not env.is_closed:
                env.close()

if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
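Note: the test above doubles as a usage recipe. Here is a minimal standalone sketch of the same recording pipeline (not part of the commit; it assumes a gym/gymnasium backend providing CartPole-v1, and the experiment name "dummy", the log directory "./logs" and the tag "pixels_record" are arbitrary):

# Hedged sketch: record a short CartPole rollout to disk via CSVLogger.
from torchrl.envs import GymEnv
from torchrl.record.loggers.csv import CSVLogger
from torchrl.record.recorder import PixelRenderTransform, VideoRecorder

env = GymEnv("CartPole-v1", render_mode="rgb_array")
# Render frames into the "pixels" entry at each step.
env = env.append_transform(PixelRenderTransform(out_keys=["pixels"]))
logger = CSVLogger("dummy", log_dir="./logs")
# Accumulate the rendered frames under the given tag.
env = env.append_transform(
    VideoRecorder(logger=logger, in_keys=["pixels"], tag="pixels_record")
)
env.rollout(10)
env.transform.dump()  # flush the recorded frames to ./logs/dummy/videos
env.close()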
110 changes: 107 additions & 3 deletions torchrl/data/tensor_specs.py
@@ -31,7 +31,13 @@

import numpy as np
import torch
-from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase, unravel_key
+from tensordict import (
+    LazyStackedTensorDict,
+    NonTensorData,
+    TensorDict,
+    TensorDictBase,
+    unravel_key,
+)
from tensordict.utils import _getitem_batch_size, NestedKey

from torchrl._utils import get_binary_env_var
@@ -1917,6 +1923,102 @@ def _is_nested_list(index, notuple=False):
    return False


class NonTensorSpec(TensorSpec):
    """A spec for non-tensor data."""

    def __init__(
        self,
        shape: Union[torch.Size, int] = _DEFAULT_SHAPE,
        device: Optional[DEVICE_TYPING] = None,
        dtype: torch.dtype | None = None,
        **kwargs,
    ):
        if isinstance(shape, int):
            shape = torch.Size([shape])

        _, device = _default_dtype_and_device(None, device)
        domain = None
        super().__init__(
            shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
        )

    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensorSpec:
        if isinstance(dest, torch.dtype):
            dest_dtype = dest
            dest_device = self.device
        elif dest is None:
            return self
        else:
            dest_dtype = self.dtype
            dest_device = torch.device(dest)
        if dest_device == self.device and dest_dtype == self.dtype:
            return self
        return self.__class__(shape=self.shape, device=dest_device, dtype=None)

    def clone(self) -> NonTensorSpec:
        return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype)

    def rand(self, shape=None):
        return NonTensorData(data=None, shape=self.shape, device=self.device)

    def zero(self, shape=None):
        return NonTensorData(data=None, shape=self.shape, device=self.device)

    def one(self, shape=None):
        return NonTensorData(data=None, shape=self.shape, device=self.device)

    def is_in(self, val: torch.Tensor) -> bool:
        return (
            isinstance(val, NonTensorData)
            and val.shape == self.shape
            and val.device == self.device
            and val.dtype == self.dtype
        )

    def expand(self, *shape):
        if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
            shape = shape[0]
        shape = torch.Size(shape)
        if not shape[-len(self.shape) :] == self.shape:
            raise ValueError(
                "The last elements of the expanded shape must match the current "
                f"one. Got shape={shape} while self.shape={self.shape}."
            )
        return self.__class__(shape=shape, device=self.device, dtype=None)

    def _reshape(self, shape):
        return self.__class__(shape=shape, device=self.device, dtype=self.dtype)

    def _unflatten(self, dim, sizes):
        shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
        return self.__class__(
            shape=shape,
            device=self.device,
            dtype=self.dtype,
        )

    def __getitem__(self, idx: SHAPE_INDEX_TYPING):
        """Indexes the current TensorSpec based on the provided index."""
        indexed_shape = torch.Size(_shape_indexing(self.shape, idx))
        return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype)

    def unbind(self, dim: int):
        orig_dim = dim
        if dim < 0:
            dim = len(self.shape) + dim
        if dim < 0:
            raise ValueError(
                f"Cannot unbind along dim {orig_dim} with shape {self.shape}."
            )
        shape = tuple(s for i, s in enumerate(self.shape) if i != dim)
        return tuple(
            self.__class__(
                shape=shape,
                device=self.device,
                dtype=self.dtype,
            )
            for i in range(self.shape[dim])
        )

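A quick illustration of how the new spec composes (a sketch, not part of the commit; it assumes the class exactly as defined above):

# Shape arithmetic on NonTensorSpec mirrors the tensor specs.
spec = NonTensorSpec(shape=(3,))
sample = spec.rand()          # NonTensorData(data=None) spanning the spec's shape
expanded = spec.expand(2, 3)  # leading dims may be added; trailing dims must match
sub = spec[0]                 # indexing reduces the shape, as for tensor specs
parts = spec.unbind(0)        # tuple of 3 specs with shape torch.Size([])
assert expanded.shape == torch.Size([2, 3])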

@dataclass(repr=False)
class UnboundedContinuousTensorSpec(TensorSpec):
"""An unbounded continuous tensor spec.
@@ -1954,7 +2056,9 @@ def __init__(
            shape=shape, space=box, device=device, dtype=dtype, domain=domain, **kwargs
        )

-    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
+    def to(
+        self, dest: Union[torch.dtype, DEVICE_TYPING]
+    ) -> UnboundedContinuousTensorSpec:
        if isinstance(dest, torch.dtype):
            dest_dtype = dest
            dest_device = self.device
@@ -1979,7 +2083,7 @@ def rand(self, shape=None) -> torch.Tensor:
        return torch.empty(shape, device=self.device, dtype=self.dtype).random_()

    def is_in(self, val: torch.Tensor) -> bool:
-        return True
+        return val.shape == self.shape and val.dtype == self.dtype

    def expand(self, *shape):
        if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
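The is_in change above tightens validation: previously any value was accepted. A small sketch of the new semantics (assuming torch is imported and a spec built as in this file):

spec = UnboundedContinuousTensorSpec(shape=(3,), dtype=torch.float32)
assert spec.is_in(torch.zeros(3))                          # matching shape and dtype
assert not spec.is_in(torch.zeros(4))                      # wrong shape now rejected
assert not spec.is_in(torch.zeros(3, dtype=torch.int64))   # wrong dtype now rejected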
11 changes: 9 additions & 2 deletions torchrl/envs/utils.py
@@ -16,6 +16,7 @@
from enum import Enum
from typing import Any, Dict, List, Union

import tensordict
import torch

from tensordict import (
@@ -25,6 +26,7 @@
    TensorDictBase,
    unravel_key,
)
from tensordict.base import _is_leaf_nontensor
from tensordict.nn import TensorDictModule, TensorDictModuleBase
from tensordict.nn.probabilistic import (  # noqa
    # Note: the `set_interaction_mode` and their associated arg `default_interaction_mode` are being deprecated!
@@ -179,10 +181,15 @@ def _is_reset(key: NestedKey):
                return key == "_reset"
            return key[-1] == "_reset"

-        actual = {key for key in tensordict.keys(True, True) if not _is_reset(key)}
+        actual = {
+            key
+            for key in tensordict.keys(True, True, is_leaf=_is_leaf_nontensor)
+            if not _is_reset(key)
+        }
        expected = set(expected)
        self.validated = expected.intersection(actual) == expected
        if not self.validated:
-            raise RuntimeError
+            warnings.warn(
                "The expected key set and actual key set differ. "
                "This will work but with a slower throughput than "
@@ -243,7 +250,7 @@ def _exclude(
                cls._exclude(nested_key_dict, td, td_out)
            return out
        has_set = False
-        for key, value in data_in.items():
+        for key, value in data_in.items(is_leaf=tensordict.base._is_leaf_nontensor):
            subdict = nested_key_dict.get(key, NO_DEFAULT)
            if subdict is NO_DEFAULT:
                value = value.copy() if is_tensor_collection(value) else value
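The is_leaf=_is_leaf_nontensor arguments threaded through keys() and items() above make NonTensorData nodes surface as leaves instead of being recursed into as sub-tensordicts. A sketch (assuming a tensordict version that ships NonTensorData and tensordict.base._is_leaf_nontensor; the keys "obs" and "info" are arbitrary):

import torch
from tensordict import NonTensorData, TensorDict
from tensordict.base import _is_leaf_nontensor

td = TensorDict(
    {
        "obs": torch.zeros(3),
        "info": NonTensorData(data="episode-42", batch_size=[]),
    },
    batch_size=[],
)
# Plain leaf iteration can skip or descend into the non-tensor node;
# the predicate makes it come back as a single leaf key.
assert "info" in td.keys(True, True, is_leaf=_is_leaf_nontensor)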
