diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 93da37b5068..6b963ce9ca1 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -27,7 +27,7 @@ jobs: strategy: matrix: python_version: ["3.10"] - cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: repository: pytorch/rl diff --git a/.github/workflows/test-linux-habitat.yml b/.github/workflows/test-linux-habitat.yml index 57564a2f6fa..95aad59f763 100644 --- a/.github/workflows/test-linux-habitat.yml +++ b/.github/workflows/test-linux-habitat.yml @@ -24,7 +24,7 @@ jobs: strategy: matrix: python_version: ["3.9"] - cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: diff --git a/.github/workflows/test-linux-libs.yml b/.github/workflows/test-linux-libs.yml index 87580d67235..df7c3279ebd 100644 --- a/.github/workflows/test-linux-libs.yml +++ b/.github/workflows/test-linux-libs.yml @@ -25,7 +25,7 @@ jobs: strategy: matrix: python_version: ["3.9"] - cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }} uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: @@ -59,7 +59,7 @@ jobs: strategy: matrix: python_version: ["3.11"] - cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }} uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: @@ -96,7 +96,7 @@ jobs: strategy: matrix: python_version: ["3.9"] - cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }} uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: @@ -131,7 +131,7 @@ jobs: strategy: matrix: python_version: ["3.9"] - cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }} uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: @@ -166,7 +166,7 @@ jobs: strategy: matrix: python_version: ["3.9"] - cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }} uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: @@ -200,7 +200,7 @@ jobs: strategy: matrix: python_version: ["3.9"] - cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: repository: pytorch/rl @@ -235,7 +235,7 @@ jobs: strategy: matrix: python_version: ["3.9"] - cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }} uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: @@ -256,7 +256,7 @@ jobs: set -euo pipefail export PYTHON_VERSION="3.9" - export CU_VERSION="12.1" + export CU_VERSION="12.4" export TAR_OPTIONS="--no-same-owner" export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 @@ -277,7 +277,7 @@ jobs: repository: pytorch/rl runner: "linux.g5.4xlarge.nvidia.gpu" gpu-arch-type: cuda - gpu-arch-version: "12.1" + gpu-arch-version: "12.4" docker-image: "nvidia/cuda:12.4.1-runtime-ubuntu22.04" timeout: 120 script: | 
@@ -291,7 +291,7 @@ jobs: set -euo pipefail export PYTHON_VERSION="3.11" - export CU_VERSION="12.1" + export CU_VERSION="12.4" export TAR_OPTIONS="--no-same-owner" export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 @@ -309,7 +309,7 @@ jobs: strategy: matrix: python_version: ["3.9"] - cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }} uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: @@ -330,7 +330,7 @@ jobs: set -euo pipefail export PYTHON_VERSION="3.9" - export CU_VERSION="12.1" + export CU_VERSION="12.4" export TAR_OPTIONS="--no-same-owner" export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 @@ -347,7 +347,7 @@ jobs: strategy: matrix: python_version: ["3.9"] - cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }} uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: @@ -368,7 +368,7 @@ jobs: set -euo pipefail export PYTHON_VERSION="3.9" - export CU_VERSION="12.1" + export CU_VERSION="12.4" export TAR_OPTIONS="--no-same-owner" export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 @@ -385,7 +385,7 @@ jobs: strategy: matrix: python_version: ["3.10.12"] - cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }} uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: @@ -406,7 +406,7 @@ jobs: set -euo pipefail export PYTHON_VERSION="3.10.12" - export CU_VERSION="12.1" + export CU_VERSION="12.4" export TAR_OPTIONS="--no-same-owner" export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 @@ -423,7 +423,7 @@ jobs: strategy: matrix: python_version: ["3.9"] - cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }} uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: @@ -458,7 +458,7 @@ jobs: strategy: matrix: python_version: ["3.9"] - cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }} uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: @@ -510,7 +510,7 @@ jobs: set -euo pipefail export PYTHON_VERSION="3.9" - export CU_VERSION="12.1" + export CU_VERSION="12.4" export TAR_OPTIONS="--no-same-owner" export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 @@ -528,7 +528,7 @@ jobs: strategy: matrix: python_version: ["3.9"] - cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }} uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: @@ -562,7 +562,7 @@ jobs: strategy: matrix: python_version: ["3.9"] - cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }} uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: @@ -597,7 +597,7 @@ jobs: strategy: matrix: python_version: ["3.9"] - cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: repository: pytorch/rl @@ -633,7 +633,7 @@ jobs: strategy: matrix: python_version: ["3.9"] - 
cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }} uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: @@ -654,7 +654,7 @@ jobs: set -euo pipefail export PYTHON_VERSION="3.9" - export CU_VERSION="12.1" + export CU_VERSION="12.4" export TAR_OPTIONS="--no-same-owner" export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 @@ -672,7 +672,7 @@ jobs: strategy: matrix: python_version: ["3.9"] - cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }} uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: @@ -707,7 +707,7 @@ jobs: strategy: matrix: python_version: ["3.9"] - cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }} uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: @@ -728,7 +728,7 @@ jobs: set -euo pipefail export PYTHON_VERSION="3.9" - export CU_VERSION="12.1" + export CU_VERSION="12.4" export TAR_OPTIONS="--no-same-owner" export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 diff --git a/.github/workflows/test-linux-rlhf.yml b/.github/workflows/test-linux-rlhf.yml index 2e647476b69..1b4e04d95f8 100644 --- a/.github/workflows/test-linux-rlhf.yml +++ b/.github/workflows/test-linux-rlhf.yml @@ -24,7 +24,7 @@ jobs: strategy: matrix: python_version: ["3.9"] - cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: repository: pytorch/rl diff --git a/.github/workflows/test-linux-sota.yml b/.github/workflows/test-linux-sota.yml index edab7e935ea..589761bfe8b 100644 --- a/.github/workflows/test-linux-sota.yml +++ b/.github/workflows/test-linux-sota.yml @@ -27,7 +27,7 @@ jobs: strategy: matrix: python_version: ["3.9"] - cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: diff --git a/.github/workflows/test-linux.yml b/.github/workflows/test-linux.yml index 15cd067a822..1565e49707e 100644 --- a/.github/workflows/test-linux.yml +++ b/.github/workflows/test-linux.yml @@ -89,7 +89,7 @@ jobs: strategy: matrix: python_version: ["3.11"] - cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: @@ -158,7 +158,7 @@ jobs: strategy: matrix: python_version: ["3.11"] - cuda_arch_version: ["12.1"] + cuda_arch_version: ["12.4"] fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 852592992b9..607be49211a 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -148,6 +148,7 @@ using the following components: LazyMemmapStorage LazyTensorStorage ListStorage + LazyStackStorage ListStorageCheckpointer NestedStorageCheckpointer PrioritizedSampler diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index b9ba8276796..da9c2114161 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -1118,7 +1118,7 @@ in the relevant functions: >>> print(env2._env.env.env) -We can see that the two libraries modify the value returned by :func:`~.gym.gym_backend()` +We can see that the two 
libraries modify the value returned by :func:`~torchrl.envs.gym.gym_backend()` which can be further used to indicate which library needs to be used for the current computation. :class:`~.gym.set_gym_backend` is also a decorator: we can use it to tell to a specific function what gym backend needs to be used @@ -1189,3 +1189,4 @@ the following function will return ``1`` when queried: VmasWrapper gym_backend set_gym_backend + register_gym_spec_conversion diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 71375fd13a2..d8dfeb6c2d1 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1068,6 +1068,11 @@ def _step( return tensordict +def get_random_string(min_size, max_size): + size = random.randint(min_size, max_size) + return "".join(random.choice(string.ascii_lowercase) for _ in range(size)) + + class CountingEnvWithString(CountingEnv): def __init__(self, *args, **kwargs): self.max_size = kwargs.pop("max_size", 30) @@ -1083,8 +1088,7 @@ def __init__(self, *args, **kwargs): ) def get_random_string(self): - size = random.randint(self.min_size, self.max_size) - return "".join(random.choice(string.ascii_lowercase) for _ in range(size)) + return get_random_string(self.min_size, self.max_size) def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: res = super()._reset(tensordict, **kwargs) @@ -2202,3 +2206,39 @@ def _step( def _set_seed(self, seed): ... + + +class Str2StrEnv(EnvBase): + def __init__(self, min_size=4, max_size=10, **kwargs): + self.observation_spec = Composite( + observation=NonTensor(example_data="an observation!", shape=()) + ) + self.full_action_spec = Composite( + action=NonTensor(example_data="an action!", shape=()) + ) + self.reward_spec = Unbounded(shape=(1,), dtype=torch.float) + self.min_size = min_size + self.max_size = max_size + super().__init__(**kwargs) + + def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: + assert isinstance(tensordict["action"], str) + out = tensordict.empty() + out.set("observation", self.get_random_string()) + out.set("done", torch.zeros(1, dtype=torch.bool).bernoulli_(0.01)) + out.set("reward", torch.zeros(1, dtype=torch.float).bernoulli_(0.01)) + return out + + def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: + out = tensordict.empty() if tensordict is not None else TensorDict() + out.set("observation", self.get_random_string()) + out.set("done", torch.zeros(1, dtype=torch.bool).bernoulli_(0.01)) + return out + + def get_random_string(self): + return get_random_string(self.min_size, self.max_size) + + def _set_seed(self, seed: Optional[int]): + random.seed(seed) + torch.manual_seed(0) + return seed diff --git a/test/test_collector.py b/test/test_collector.py index 413ce57ffe3..d2f1c102416 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -5,11 +5,14 @@ from __future__ import annotations import argparse +import contextlib import functools import gc import os import sys +from typing import Optional +from unittest.mock import patch import numpy as np import pytest @@ -469,6 +472,84 @@ def test_output_device(self, main_device, storing_device): break assert data.device == storing_device + class CudaPolicy(TensorDictSequential): + def __init__(self, n_obs): + module = torch.nn.Linear(n_obs, n_obs, device="cuda") + module.weight.data.copy_(torch.eye(n_obs)) + module.bias.data.fill_(0) + m0 = TensorDictModule(module, in_keys=["observation"], out_keys=["hidden"]) + m1 = TensorDictModule( + lambda a: a + 1, in_keys=["hidden"], 
out_keys=["action"] + ) + super().__init__(m0, m1) + + class GoesThroughEnv(EnvBase): + def __init__(self, n_obs, device): + self.observation_spec = Composite(observation=Unbounded(n_obs)) + self.action_spec = Unbounded(n_obs) + self.reward_spec = Unbounded(1) + self.full_done_specs = Composite(done=Unbounded(1, dtype=torch.bool)) + super().__init__(device=device) + + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + a = tensordict["action"] + if self.device is not None: + assert a.device == self.device + out = tensordict.empty() + out["observation"] = tensordict["observation"] + ( + a - tensordict["observation"] + ) + out["reward"] = torch.zeros((1,), device=self.device) + out["done"] = torch.zeros((1,), device=self.device, dtype=torch.bool) + return out + + def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: + return self.full_done_specs.zeros().update(self.observation_spec.zeros()) + + def _set_seed(self, seed: Optional[int]): + return seed + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device") + @pytest.mark.parametrize("env_device", ["cuda:0", "cpu"]) + @pytest.mark.parametrize("storing_device", [None, "cuda:0", "cpu"]) + @pytest.mark.parametrize("no_cuda_sync", [True, False]) + def test_no_synchronize(self, env_device, storing_device, no_cuda_sync): + """Tests that no_cuda_sync avoids any call to torch.cuda.synchronize() and that the data is not corrupted.""" + should_raise = not no_cuda_sync + should_raise = should_raise & ( + (env_device == "cpu") or (storing_device == "cpu") + ) + with patch("torch.cuda.synchronize") as mock_synchronize, pytest.raises( + AssertionError, match="Expected 'synchronize' to not have been called." + ) if should_raise else contextlib.nullcontext(): + collector = SyncDataCollector( + create_env_fn=functools.partial( + self.GoesThroughEnv, n_obs=1000, device=None + ), + policy=self.CudaPolicy(n_obs=1000), + frames_per_batch=100, + total_frames=1000, + env_device=env_device, + storing_device=storing_device, + policy_device="cuda:0", + no_cuda_sync=no_cuda_sync, + ) + assert collector.env.device == torch.device(env_device) + i = 0 + for d in collector: + for _d in d.unbind(0): + u = _d["observation"].unique() + assert u.numel() == 1, i + assert u == i, i + i += 1 + u = _d["next", "observation"].unique() + assert u.numel() == 1, i + assert u == i, i + mock_synchronize.assert_not_called() + # @pytest.mark.skipif( # IS_WINDOWS and PYTHON_3_10, diff --git a/test/test_env.py b/test/test_env.py index 600495a04c1..d215f858fa6 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -57,6 +57,7 @@ MultiKeyCountingEnv, MultiKeyCountingEnvPolicy, NestedCountingEnv, + Str2StrEnv, ) else: from _utils_internal import ( @@ -95,6 +96,7 @@ MultiKeyCountingEnv, MultiKeyCountingEnvPolicy, NestedCountingEnv, + Str2StrEnv, ) from packaging import version from tensordict import ( @@ -133,6 +135,7 @@ AutoResetTransform, Tokenizer, Transform, + UnsqueezeTransform, ) from torchrl.envs.utils import ( _StepMDP, @@ -174,6 +177,7 @@ _has_chess = importlib.util.find_spec("chess") is not None _has_tv = importlib.util.find_spec("torchvision") is not None _has_cairosvg = importlib.util.find_spec("cairosvg") is not None +_has_transformers = importlib.util.find_spec("transformers") is not None ## TO BE FIXED: DiscreteActionProjection queries a randint on each worker, which leads to divergent results between ## the serial and parallel batched envs # def _make_atari_env(atari_env): @@ -2614,6 +2618,7 @@ def test_parallel( 
NestedCountingEnv, HeterogeneousCountingEnv, MultiKeyCountingEnv, + Str2StrEnv, ], ) def test_mocking_envs(envclass): @@ -3285,8 +3290,9 @@ def test_dynamic_rollout(self): RuntimeError, match="The environment specs are dynamic. Call rollout with return_contiguous=False", ): - rollout = env.rollout(4) - rollout = env.rollout(4, return_contiguous=False) + env.rollout(4, return_contiguous=True) + env.rollout(4) + env.rollout(4, return_contiguous=False) check_env_specs(env, return_contiguous=False) @pytest.mark.skipif(not _has_gym, reason="requires gym to be installed") @@ -3440,6 +3446,96 @@ def test_partial_rest(self, batched): assert s_["string"] == ["0", "6"] assert s["next", "string"] == ["6", "6"] + @pytest.mark.skipif(not _has_transformers, reason="transformers required") + def test_str2str_env_tokenizer(self): + env = Str2StrEnv() + env.set_seed(0) + env = env.append_transform( + Tokenizer( + in_keys=["observation"], + out_keys=["obs_tokens"], + in_keys_inv=["action"], + out_keys_inv=["action_tokens"], + ) + ) + env.check_env_specs() + assert env._has_dynamic_specs + r = env.rollout(3, return_contiguous=False) + assert len(r) == 3 + assert isinstance(r["observation"], list) + r = r.densify(layout=torch.jagged) + assert isinstance(r["observation"], list) + assert isinstance(r["obs_tokens"], torch.Tensor) + assert isinstance(r["action_tokens"], torch.Tensor) + + @pytest.mark.skipif(not _has_transformers, reason="transformers required") + def test_str2str_env_tokenizer_catframes(self): + """Tests that we can use Unsqueeze + CatFrames with tokenized strings of variable lengths.""" + env = Str2StrEnv() + env.set_seed(0) + env = env.append_transform( + Tokenizer( + in_keys=["observation"], + out_keys=["obs_tokens"], + in_keys_inv=["action"], + out_keys_inv=["action_tokens"], + # We must use max_length otherwise we can't call cat + # Perhaps we could use NJT here? 
+ max_length=10, + ) + ) + env = env.append_transform( + UnsqueezeTransform( + dim=-2, in_keys=["obs_tokens"], out_keys=["obs_tokens_cat"] + ), + ) + env = env.append_transform(CatFrames(N=4, dim=-2, in_keys=["obs_tokens_cat"])) + r = env.rollout(3) + assert r["obs_tokens_cat"].shape == (3, 4, 10) + + @pytest.mark.skipif(not _has_transformers, reason="transformers required") + def test_str2str_rb_slicesampler(self): + """Dedicated test for replay buffer sampling of trajectories with variable token length""" + from torchrl.data import LazyStackStorage, ReplayBuffer, SliceSampler + from torchrl.envs import TrajCounter + + env = Str2StrEnv() + env.set_seed(0) + env = env.append_transform( + Tokenizer( + in_keys=["observation"], + out_keys=["obs_tokens"], + in_keys_inv=["action"], + out_keys_inv=["action_tokens"], + ) + ) + env = env.append_transform(StepCounter(max_steps=10)) + env = env.append_transform(TrajCounter()) + rb = ReplayBuffer( + storage=LazyStackStorage(100), + sampler=SliceSampler(slice_len=10, end_key=("next", "done")), + ) + r0 = env.rollout(20, break_when_any_done=False) + rb.extend(r0) + has_0 = False + has_1 = False + for _ in range(100): + v0 = rb.sample(10) + assert (v0["step_count"].squeeze() == torch.arange(10)).all() + assert (v0["next", "step_count"].squeeze() == torch.arange(1, 11)).all() + try: + traj = v0["traj_count"].unique().item() + except Exception: + raise RuntimeError( + f"More than one traj found in single slice: {v0['traj_count']}" + ) + has_0 |= traj == 0 + has_1 |= traj == 1 + if has_0 and has_1: + break + else: + raise RuntimeError("Failed to sample both trajs") + # fen strings for board positions generated with: # https://lichess.org/editor @@ -3675,6 +3771,7 @@ def test_reward( assert td["reward"] == expected_reward assert td["turn"] == (not expected_turn) + @pytest.mark.skipif(not _has_transformers, reason="transformers required") def test_chess_tokenized(self): env = ChessEnv(include_fen=True, stateful=True, include_san=True) assert isinstance(env.observation_spec["fen"], NonTensor) diff --git a/test/test_libs.py b/test/test_libs.py index d5d2ea88c1d..8db543c146a 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -79,6 +79,7 @@ Composite, MultiCategorical, MultiOneHot, + NonTensor, OneHot, ReplayBuffer, ReplayBufferEnsemble, @@ -119,6 +120,7 @@ GymWrapper, MOGymEnv, MOGymWrapper, + register_gym_spec_conversion, set_gym_backend, ) from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv @@ -337,6 +339,39 @@ def test_gym_spec_cast(self, categorical): assert spec == recon assert recon.shape == spec.shape + def test_gym_new_spec_reg(self): + Space = gym_backend("spaces").Space + + class MySpaceParent(Space): + ... + + s_parent = MySpaceParent() + + class MySpaceChild(MySpaceParent): + ... + + # We intentionally register first the child then the parent + @register_gym_spec_conversion(MySpaceChild) + def convert_myspace_child(spec, **kwargs): + return NonTensor((), example_data="child") + + @register_gym_spec_conversion(MySpaceParent) + def convert_myspace_parent(spec, **kwargs): + return NonTensor((), example_data="parent") + + s_child = MySpaceChild() + assert _gym_to_torchrl_spec_transform(s_parent).example_data == "parent" + assert _gym_to_torchrl_spec_transform(s_child).example_data == "child" + + class NoConversionSpace(Space): + ... 
+ + s_no_conv = NoConversionSpace() + with pytest.raises( + KeyError, match="No conversion tool could be found with the gym space" + ): + _gym_to_torchrl_spec_transform(s_no_conv) + @pytest.mark.parametrize("order", ["tuple_seq"]) @implement_for("gym") def test_gym_spec_cast_tuple_sequential(self, order): diff --git a/test/test_rb.py b/test/test_rb.py index a139d34f1a5..b63f888453d 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -81,6 +81,7 @@ from torchrl.data.replay_buffers.storages import ( LazyMemmapStorage, + LazyStackStorage, LazyTensorStorage, ListStorage, StorageEnsemble, @@ -1116,6 +1117,31 @@ def test_storage_inplace_writing_ndim(self, storage_type): assert (rb[:, 10:20] == 0).all() assert len(rb) == 100 + @pytest.mark.parametrize("max_size", [1000, None]) + @pytest.mark.parametrize("stack_dim", [-1, 0]) + def test_lazy_stack_storage(self, max_size, stack_dim): + # Create an instance of LazyStackStorage with given parameters + storage = LazyStackStorage(max_size=max_size, stack_dim=stack_dim) + # Create a ReplayBuffer using the created storage + rb = ReplayBuffer(storage=storage) + # Generate some random data to add to the buffer + torch.manual_seed(0) + data0 = TensorDict(a=torch.randn((10,)), b=torch.rand(4), c="a string!") + data1 = TensorDict(a=torch.randn((11,)), b=torch.rand(4), c="another string!") + # Add the data to the buffer + rb.add(data0) + rb.add(data1) + # Sample from the buffer + sample = rb.sample(10) + # Check that the sampled data has the correct shape and type + assert isinstance(sample, LazyStackedTensorDict) + assert sample["b"].shape[0] == 10 + assert all(isinstance(item, str) for item in sample["c"]) + # Densify the sample and check that the result is dense + sample = sample.densify(layout=torch.jagged) + assert isinstance(sample["a"], torch.Tensor) + assert sample["a"].shape[0] == 10 + @pytest.mark.parametrize("max_size", [1000]) @pytest.mark.parametrize("shape", [[3, 4]]) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 7e687b02999..a5b3a32ae7d 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -440,6 +440,11 @@ class SyncDataCollector(DataCollectorBase): cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped in :class:`~tensordict.nn.CudaGraphModule` with default kwargs. If a dictionary of kwargs is passed, it will be used to wrap the policy. + no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed. + For environments running directly on CUDA (`IsaacLab `_ + or `ManiSkill `_) CUDA synchronization may cause unexpected + crashes. + Defaults to ``False``. Examples: >>> from torchrl.envs.libs.gym import GymEnv @@ -532,6 +537,7 @@ def __init__( trust_policy: bool = None, compile_policy: bool | Dict[str, Any] | None = None, cudagraph_policy: bool | Dict[str, Any] | None = None, + no_cuda_sync: bool = False, **kwargs, ): from torchrl.envs.batched_envs import BatchedEnvBase @@ -625,6 +631,7 @@ def __init__( else: self._sync_policy = _do_nothing self.device = device + self.no_cuda_sync = no_cuda_sync # Check if we need to cast things from device to device # If the policy has a None device and the env too, no need to cast (we don't know # and assume the user knows what she's doing).
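Reviewer note: the new ``no_cuda_sync`` flag documented above is exercised by ``test_no_synchronize`` in ``test_collector.py``. For orientation, here is a minimal usage sketch, not part of the diff; the ``Pendulum-v1`` env and the tiny linear policy are illustrative stand-ins for a CUDA-native simulator and a real policy.

```python
# Hedged sketch of the new `no_cuda_sync` collector flag; env and policy are placeholders.
import torch
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv

# Pendulum-v1 has 3 observation dims and 1 action dim (illustrative only).
policy = TensorDictModule(
    torch.nn.Linear(3, 1, device="cuda:0"),
    in_keys=["observation"],
    out_keys=["action"],
)

collector = SyncDataCollector(
    create_env_fn=lambda: GymEnv("Pendulum-v1", device="cuda:0"),
    policy=policy,
    frames_per_batch=200,
    total_frames=2_000,
    env_device="cuda:0",
    policy_device="cuda:0",
    storing_device="cuda:0",
    no_cuda_sync=True,  # skip the collector's explicit torch.cuda.synchronize() calls
)
for batch in collector:
    ...  # training step goes here
collector.shutdown()
```

Because every device in the loop is CUDA, skipping the explicit synchronizations is safe here; with a CPU env or CPU storing device the flag should stay ``False``, which is exactly the case the new test asserts on.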
@@ -1010,12 +1017,16 @@ def iterator(self) -> Iterator[TensorDictBase]: Yields: TensorDictBase objects containing (chunks of) trajectories """ - if self.storing_device and self.storing_device.type == "cuda": + if ( + not self.no_cuda_sync + and self.storing_device + and self.storing_device.type == "cuda" + ): stream = torch.cuda.Stream(self.storing_device, priority=-1) event = stream.record_event() streams = [stream] events = [event] - elif self.storing_device is None: + elif not self.no_cuda_sync and self.storing_device is None: streams = [] events = [] # this way of checking cuda is robust to lazy stacks with mismatching shapes @@ -1166,10 +1177,17 @@ def rollout(self) -> TensorDictBase: else: if self._cast_to_policy_device: if self.policy_device is not None: + # This is unsafe if the shuttle is in pinned memory -- otherwise CUDA will be happy with non_blocking + non_blocking = ( + not self.no_cuda_sync + or self.policy_device.type == "cuda" + ) policy_input = self._shuttle.to( - self.policy_device, non_blocking=True + self.policy_device, + non_blocking=non_blocking, ) - self._sync_policy() + if not self.no_cuda_sync: + self._sync_policy() elif self.policy_device is None: # we know the tensordict has a device otherwise we would not be here # we can pass this, clear_device_ must have been called earlier @@ -1191,8 +1209,14 @@ def rollout(self) -> TensorDictBase: if self._cast_to_env_device: if self.env_device is not None: - env_input = self._shuttle.to(self.env_device, non_blocking=True) - self._sync_env() + non_blocking = ( + not self.no_cuda_sync or self.env_device.type == "cuda" + ) + env_input = self._shuttle.to( + self.env_device, non_blocking=non_blocking + ) + if not self.no_cuda_sync: + self._sync_env() elif self.env_device is None: # we know the tensordict has a device otherwise we would not be here # we can pass this, clear_device_ must have been called earlier @@ -1216,10 +1240,16 @@ def rollout(self) -> TensorDictBase: return else: if self.storing_device is not None: + non_blocking = ( + not self.no_cuda_sync or self.storing_device.type == "cuda" + ) tensordicts.append( - self._shuttle.to(self.storing_device, non_blocking=True) + self._shuttle.to( + self.storing_device, non_blocking=non_blocking + ) ) - self._sync_storage() + if not self.no_cuda_sync: + self._sync_storage() else: tensordicts.append(self._shuttle) @@ -1558,6 +1588,11 @@ class _MultiDataCollector(DataCollectorBase): cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped in :class:`~tensordict.nn.CudaGraphModule` with default kwargs. If a dictionary of kwargs is passed, it will be used to wrap the policy. + no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed. + For environments running directly on CUDA (`IsaacLab `_ + or `ManiSkill `_) CUDA synchronization may cause unexpected + crashes. + Defaults to ``False``.
""" @@ -1597,6 +1632,7 @@ def __init__( trust_policy: bool = None, compile_policy: bool | Dict[str, Any] | None = None, cudagraph_policy: bool | Dict[str, Any] | None = None, + no_cuda_sync: bool = False, ): self.closed = True self.num_workers = len(create_env_fn) @@ -1636,6 +1672,7 @@ def __init__( self.env_device = env_devices del storing_device, env_device, policy_device, device + self.no_cuda_sync = no_cuda_sync self._use_buffers = use_buffers self.replay_buffer = replay_buffer @@ -1909,6 +1946,7 @@ def _run_processes(self) -> None: "cudagraph_policy": self.cudagraphed_policy_kwargs if self.cudagraphed_policy else False, + "no_cuda_sync": self.no_cuda_sync, } proc = _ProcessNoWarn( target=_main_async_collector, @@ -2914,6 +2952,7 @@ def _main_async_collector( trust_policy: bool = False, compile_policy: bool = False, cudagraph_policy: bool = False, + no_cuda_sync: bool = False, ) -> None: pipe_parent.close() # init variables that will be cleared when closing @@ -2943,6 +2982,7 @@ def _main_async_collector( trust_policy=trust_policy, compile_policy=compile_policy, cudagraph_policy=cudagraph_policy, + no_cuda_sync=no_cuda_sync, ) use_buffers = inner_collector._use_buffers if verbose: diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 3ed65d59d16..7fa882cbbaa 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -23,6 +23,7 @@ H5StorageCheckpointer, ImmutableDatasetWriter, LazyMemmapStorage, + LazyStackStorage, LazyTensorStorage, ListStorage, ListStorageCheckpointer, diff --git a/torchrl/data/replay_buffers/__init__.py b/torchrl/data/replay_buffers/__init__.py index 25822dcfe4c..4f230f30701 100644 --- a/torchrl/data/replay_buffers/__init__.py +++ b/torchrl/data/replay_buffers/__init__.py @@ -32,6 +32,7 @@ ) from .storages import ( LazyMemmapStorage, + LazyStackStorage, LazyTensorStorage, ListStorage, Storage, diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index fc27401d5e5..fa92d84295a 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -1049,6 +1049,10 @@ def __init__( self._cache["stop-and-length"] = vals else: + if traj_key is not None: + self._fetch_traj = True + elif end_key is not None: + self._fetch_traj = False if end_key is None: end_key = ("next", "done") if traj_key is None: @@ -1331,7 +1335,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict] if start_idx.shape[1] != storage.ndim: raise RuntimeError( f"Expected the end-of-trajectory signal to be " - f"{storage.ndim}-dimensional. Got a {start_idx.shape[1]} tensor " + f"{storage.ndim}-dimensional. Got a tensor with shape[1]={start_idx.shape[1]} " "instead." ) seq_length, num_slices = self._adjusted_batch_size(batch_size) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 52d137208ad..344814e728c 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -297,7 +297,15 @@ def set( def get(self, index: Union[int, Sequence[int], slice]) -> Any: if isinstance(index, (INT_CLASSES, slice)): return self._storage[index] + elif isinstance(index, tuple): + if len(index) > 1: + raise RuntimeError( + f"{type(self).__name__} can only be indexed with one-length tuples." 
+ ) + return self.get(index[0]) else: + if isinstance(index, torch.Tensor) and index.device.type != "cpu": + index = index.cpu().tolist() return [self._storage[i] for i in index] def __len__(self): @@ -353,6 +361,77 @@ def contains(self, item): raise NotImplementedError(f"type {type(item)} is not supported yet.") +class LazyStackStorage(ListStorage): + """A ListStorage that returns LazyStackedTensorDict instances. + + This storage allows for heterogeneous structures to be indexed as a single `TensorDict` representation. + It uses :class:`~tensordict.LazyStackedTensorDict` which operates on non-contiguous lists of tensordicts, + lazily stacking items when queried. + This means that this storage is going to be fast to sample but data access may be slow (as it requires a stack). + Tensors of heterogeneous shapes can also be stored within the storage and stacked together. + Because the storage is represented as a list, the number of tensors to store in memory will grow linearly with + the size of the buffer. + + If possible, nested tensors can also be created via :meth:`~tensordict.LazyStackedTensorDict.densify` + (see :mod:`~torch.nested`). + + Args: + max_size (int, optional): the maximum number of elements stored in the storage. + If not provided, an unlimited storage is created. + + Keyword Args: + compilable (bool, optional): if ``True``, the storage will be made compatible with :func:`~torch.compile` at + the cost of not being executable in multiprocessed settings. + stack_dim (int, optional): the stack dimension in terms of TensorDict batch sizes. Defaults to `-1`. + + Examples: + >>> import torch + >>> from torchrl.data import ReplayBuffer, LazyStackStorage + >>> from tensordict import TensorDict + >>> _ = torch.manual_seed(0) + >>> rb = ReplayBuffer(storage=LazyStackStorage(max_size=1000, stack_dim=-1)) + >>> data0 = TensorDict(a=torch.randn((10,)), b=torch.rand(4), c="a string!") + >>> data1 = TensorDict(a=torch.randn((11,)), b=torch.rand(4), c="another string!") + >>> _ = rb.add(data0) + >>> _ = rb.add(data1) + >>> rb.sample(10) + LazyStackedTensorDict( + fields={ + a: Tensor(shape=torch.Size([10, -1]), device=cpu, dtype=torch.float32, is_shared=False), + b: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False), + c: NonTensorStack( + ['another string!', 'another string!', 'another st..., + batch_size=torch.Size([10]), + device=None)}, + exclusive_fields={ + }, + batch_size=torch.Size([10]), + device=None, + is_shared=False, + stack_dim=0) + """ + + def __init__( + self, + max_size: int | None = None, + *, + compilable: bool = False, + stack_dim: int = -1, + ): + super().__init__(max_size=max_size, compilable=compilable) + self.stack_dim = stack_dim + + def get(self, index: Union[int, Sequence[int], slice]) -> Any: + out = super().get(index=index) + if isinstance(out, list): + stack_dim = self.stack_dim + if stack_dim < 0: + stack_dim = out[0].ndim + 1 + stack_dim + out = LazyStackedTensorDict(*out, stack_dim=stack_dim) + return out + return out + + class TensorStorage(Storage): """A storage for tensors and tensordicts. diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 95aaaebd936..a7914f4f1d7 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -2450,6 +2450,9 @@ class NonTensor(TensorSpec): :meth:`.rand` will return a :class:`~tensordict.NonTensorData` object with `None` data value. (same will go for :meth:`.zero` and :meth:`.one`). + + .. note:: The default shape of `NonTensor` is `(1,)`.
+ """ example_data: Any = None diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 52f3bb3ac1b..a103a979174 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -32,6 +32,7 @@ OpenSpielWrapper, PettingZooEnv, PettingZooWrapper, + register_gym_spec_conversion, RoboHiveEnv, set_gym_backend, SMACv2Env, diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index c3a714fcf91..35f86d78eaf 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2699,6 +2699,7 @@ def specs(self) -> Composite: @property def _has_dynamic_specs(self) -> bool: + # TODO: cache this value return _has_dynamic_specs(self.specs) def rollout( @@ -2711,7 +2712,7 @@ def rollout( auto_cast_to_device: bool = False, break_when_any_done: bool | None = None, break_when_all_done: bool | None = None, - return_contiguous: bool = True, + return_contiguous: bool | None = False, tensordict: Optional[TensorDictBase] = None, set_truncated: bool = False, out=None, @@ -2746,7 +2747,8 @@ def rollout( break_when_all_done (bool, optional): if ``True``, break if all of the contained environments reach any of the done states. If ``False``, break if at least one environment reaches any of the done states. Default is ``False``. - return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True. + return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is `True` if + the env does not have dynamic specs, otherwise `False`. tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial tensordict must be provided. Rollout will check if this tensordict has done flags and reset the environment in those dimensions (if needed). @@ -2957,7 +2959,8 @@ def rollout( raise TypeError( "Cannot have both break_when_all_done and break_when_any_done True at the same time." ) - + if return_contiguous is None: + return_contiguous = not self._has_dynamic_specs if policy is not None: policy = _make_compatible_policy( policy, diff --git a/torchrl/envs/libs/__init__.py b/torchrl/envs/libs/__init__.py index 7ea113ce46d..1cff97c1d49 100644 --- a/torchrl/envs/libs/__init__.py +++ b/torchrl/envs/libs/__init__.py @@ -12,6 +12,7 @@ GymWrapper, MOGymEnv, MOGymWrapper, + register_gym_spec_conversion, set_gym_backend, ) from .habitat import HabitatEnv diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index e4cab0edc80..4538b226d4a 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -234,6 +234,80 @@ def gym_backend(submodule=None): __all__ = ["GymWrapper", "GymEnv"] +# Define a dictionary to store conversion functions for each spec type +class _ConversionRegistry(collections.UserDict): + def __getitem__(self, cls): + if cls not in super().keys(): + # We want to find the closest parent + parents = {} + for k in self.keys(): + if not isinstance(k, str): + parents[k] = k + continue + try: + space_cls = gym_backend("spaces") + for sbsp in k.split("."): + space_cls = getattr(space_cls, sbsp) + except AttributeError: + # Some specs may be too recent + continue + parents[space_cls] = k + mro = cls.mro() + for base in mro: + for p in parents: + if issubclass(base, p): + return self[parents[p]] + else: + raise KeyError( + f"No conversion tool could be found with the gym space {cls}. 
" + f"You can register your own with `torchrl.envs.libs.register_gym_spec_conversion.`" + ) + return super().__getitem__(cls) + + +_conversion_registry = _ConversionRegistry() + + +def register_gym_spec_conversion(spec_type): + """Decorator to register a conversion function for a specific spec type. + + The method must have the following signature: + + >>> @register_gym_spec_conversion("spec.name") + ... def convert_specname( + ... spec, + ... dtype=None, + ... device=None, + ... categorical_action_encoding=None, + ... remap_state_to_observation=None, + ... batch_size=None, + ... ): + + where `gym(nasium).spaces.spec.name` is the location of the spec in gym. + + If the spec type is accessible, this will also work: + + >>> @register_gym_spec_conversion(SpecType) + ... def convert_specname( + ... spec, + ... dtype=None, + ... device=None, + ... categorical_action_encoding=None, + ... remap_state_to_observation=None, + ... batch_size=None, + ... ): + + ..note:: The wrapped function can be simplified, and unused kwargs can be wrapped in `**kwargs`. + + """ + + def decorator(conversion_func): + _conversion_registry[spec_type] = conversion_func + return conversion_func + + return decorator + + def _gym_to_torchrl_spec_transform( spec, dtype=None, @@ -256,7 +330,6 @@ def _gym_to_torchrl_spec_transform( Dict specs to "observation". Default is true. batch_size (torch.Size): batch size to which expand the spec. Defaults to ``torch.Size([])``. - """ if batch_size: return _gym_to_torchrl_spec_transform( @@ -267,139 +340,239 @@ def _gym_to_torchrl_spec_transform( remap_state_to_observation=remap_state_to_observation, batch_size=None, ).expand(batch_size) - gym_spaces = gym_backend("spaces") - if isinstance(spec, gym_spaces.tuple.Tuple): - result = torch.stack( - [ - _gym_to_torchrl_spec_transform( - s, - device=device, - categorical_action_encoding=categorical_action_encoding, - remap_state_to_observation=remap_state_to_observation, - ) - for s in spec - ], - dim=0, - ) - return result - if isinstance(spec, gym_spaces.discrete.Discrete): - action_space_cls = Categorical if categorical_action_encoding else OneHot + + # Get the conversion function from the registry + conversion_func = _conversion_registry[type(spec)] + # Call the conversion function with the provided arguments + return conversion_func( + spec, + dtype=dtype, + device=device, + categorical_action_encoding=categorical_action_encoding, + remap_state_to_observation=remap_state_to_observation, + batch_size=batch_size, + ) + + +# Register conversion functions for each spec type +@register_gym_spec_conversion("tuple.Tuple") +def convert_tuple_spec( + spec, + dtype=None, + device=None, + categorical_action_encoding=None, + remap_state_to_observation=None, + batch_size=None, +): + # Implementation for Tuple spec type + result = torch.stack( + [ + _gym_to_torchrl_spec_transform( + s, + device=device, + categorical_action_encoding=categorical_action_encoding, + remap_state_to_observation=remap_state_to_observation, + ) + for s in spec + ], + dim=0, + ) + return result + + +@register_gym_spec_conversion("discrete.Discrete") +def convert_discrete_spec( + spec, + dtype=None, + device=None, + categorical_action_encoding=None, + remap_state_to_observation=None, + batch_size=None, +): + # Implementation for Discrete spec type + action_space_cls = Categorical if categorical_action_encoding else OneHot + dtype = ( + numpy_to_torch_dtype_dict[spec.dtype] + if categorical_action_encoding + else torch.long + ) + return action_space_cls(spec.n, device=device, 
dtype=dtype) + + +@register_gym_spec_conversion("multi_binary.MultiBinary") +def convert_multi_binary_spec( + spec, + dtype=None, + device=None, + categorical_action_encoding=None, + remap_state_to_observation=None, + batch_size=None, +): + # Implementation for MultiBinary spec type + return Binary(spec.n, device=device, dtype=numpy_to_torch_dtype_dict[spec.dtype]) + + +@register_gym_spec_conversion("multi_discrete.MultiDiscrete") +def convert_multidiscrete_spec( + spec, + dtype=None, + device=None, + categorical_action_encoding=None, + remap_state_to_observation=None, + batch_size=None, +): + if len(spec.nvec.shape) == 1 and len(np.unique(spec.nvec)) > 1: dtype = ( numpy_to_torch_dtype_dict[spec.dtype] if categorical_action_encoding else torch.long ) - return action_space_cls(spec.n, device=device, dtype=dtype) - elif isinstance(spec, gym_spaces.multi_binary.MultiBinary): - return Binary( - spec.n, device=device, dtype=numpy_to_torch_dtype_dict[spec.dtype] - ) - # a spec type cannot be a string, so we're sure that versions of gym that don't have Sequence will just skip through this - elif isinstance(spec, getattr(gym_spaces, "Sequence", str)): - if not hasattr(spec, "stack"): - # gym does not have a stack attribute in sequence - raise ValueError( - "gymnasium should be used whenever a Sequence is present, as it needs to be stacked. " - "If you need the gym backend at all price, please raise an issue on the TorchRL GitHub repository." - ) - if not getattr(spec, "stack", False): - raise ValueError( - "Sequence spaces must have the stack argument set to ``True``. " - ) - space = spec.feature_space - out = _gym_to_torchrl_spec_transform(space, device=device, dtype=dtype) - out = out.unsqueeze(0) - out.make_neg_dim(0) - return out - elif isinstance(spec, gym_spaces.multi_discrete.MultiDiscrete): - if len(spec.nvec.shape) == 1 and len(np.unique(spec.nvec)) > 1: - dtype = ( - numpy_to_torch_dtype_dict[spec.dtype] - if categorical_action_encoding - else torch.long - ) - - return ( - MultiCategorical(spec.nvec, device=device, dtype=dtype) - if categorical_action_encoding - else MultiOneHot(spec.nvec, device=device, dtype=dtype) - ) - return torch.stack( - [ - _gym_to_torchrl_spec_transform( - spec[i], - device=device, - categorical_action_encoding=categorical_action_encoding, - remap_state_to_observation=remap_state_to_observation, - ) - for i in range(len(spec.nvec)) - ], - 0, - ) - elif isinstance(spec, gym_spaces.Box): - shape = spec.shape - if not len(shape): - shape = torch.Size([1]) - if dtype is None: - dtype = numpy_to_torch_dtype_dict[spec.dtype] - low = torch.as_tensor(spec.low, device=device, dtype=dtype) - high = torch.as_tensor(spec.high, device=device, dtype=dtype) - is_unbounded = low.isinf().all() and high.isinf().all() - - minval, maxval = _minmax_dtype(dtype) - minval = torch.as_tensor(minval).to(low.device, dtype) - maxval = torch.as_tensor(maxval).to(low.device, dtype) - is_unbounded = is_unbounded or ( - torch.isclose(low, torch.as_tensor(minval, dtype=dtype)).all() - and torch.isclose(high, torch.as_tensor(maxval, dtype=dtype)).all() - ) return ( - Unbounded(shape, device=device, dtype=dtype) - if is_unbounded - else Bounded( - low, - high, - shape, - dtype=dtype, - device=device, - ) + MultiCategorical(spec.nvec, device=device, dtype=dtype) + if categorical_action_encoding + else MultiOneHot(spec.nvec, device=device, dtype=dtype) ) - elif isinstance(spec, (Dict,)): - spec_out = {} - for k in spec.keys(): - key = k - if ( - remap_state_to_observation - and k == "state" - and 
"observation" not in spec.keys() - ): - # we rename "state" in "observation" as "observation" is the conventional name - # for single observation in torchrl. - # naming it 'state' will result in envs that have a different name for the state vector - # when queried with and without pixels - key = "observation" - spec_out[key] = _gym_to_torchrl_spec_transform( - spec[k], + + return torch.stack( + [ + _gym_to_torchrl_spec_transform( + spec[i], device=device, categorical_action_encoding=categorical_action_encoding, remap_state_to_observation=remap_state_to_observation, ) - # the batch-size must be set later - return Composite(spec_out, device=device) - elif isinstance(spec, gym_spaces.dict.Dict): - return _gym_to_torchrl_spec_transform( - spec.spaces, + for i in range(len(spec.nvec)) + ], + 0, + ) + + +@register_gym_spec_conversion("Box") +def convert_box_spec( + spec, + dtype=None, + device=None, + categorical_action_encoding=None, + remap_state_to_observation=None, + batch_size=None, +): + shape = spec.shape + if not len(shape): + shape = torch.Size([1]) + if dtype is None: + dtype = numpy_to_torch_dtype_dict[spec.dtype] + low = torch.as_tensor(spec.low, device=device, dtype=dtype) + high = torch.as_tensor(spec.high, device=device, dtype=dtype) + is_unbounded = low.isinf().all() and high.isinf().all() + + minval, maxval = _minmax_dtype(dtype) + minval = torch.as_tensor(minval).to(low.device, dtype) + maxval = torch.as_tensor(maxval).to(low.device, dtype) + is_unbounded = is_unbounded or ( + torch.isclose(low, torch.as_tensor(minval, dtype=dtype)).all() + and torch.isclose(high, torch.as_tensor(maxval, dtype=dtype)).all() + ) + return ( + Unbounded(shape, device=device, dtype=dtype) + if is_unbounded + else Bounded( + low, + high, + shape, + dtype=dtype, + device=device, + ) + ) + + +@register_gym_spec_conversion("Sequence") +def convert_sequence_spec( + spec, + dtype=None, + device=None, + categorical_action_encoding=None, + remap_state_to_observation=None, + batch_size=None, +): + if not hasattr(spec, "stack"): + # gym does not have a stack attribute in sequence + raise ValueError( + "gymnasium should be used whenever a Sequence is present, as it needs to be stacked. " + "If you need the gym backend at all price, please raise an issue on the TorchRL GitHub repository." + ) + if not getattr(spec, "stack", False): + raise ValueError( + "Sequence spaces must have the stack argument set to ``True``. " + ) + space = spec.feature_space + out = _gym_to_torchrl_spec_transform(space, device=device, dtype=dtype) + out = out.unsqueeze(0) + out.make_neg_dim(0) + return out + + +@register_gym_spec_conversion(Dict) +def convert_dict_spec( + spec, + dtype=None, + device=None, + categorical_action_encoding=None, + remap_state_to_observation=None, + batch_size=None, +): + spec_out = {} + for k in spec.keys(): + key = k + if ( + remap_state_to_observation + and k == "state" + and "observation" not in spec.keys() + ): + # we rename "state" in "observation" as "observation" is the conventional name + # for single observation in torchrl. 
+ # naming it 'state' will result in envs that have a different name for the state vector + # when queried with and without pixels + key = "observation" + spec_out[key] = _gym_to_torchrl_spec_transform( + spec[k], device=device, categorical_action_encoding=categorical_action_encoding, remap_state_to_observation=remap_state_to_observation, + batch_size=batch_size, ) - elif _has_minigrid and isinstance(spec, _minigrid_lib().core.mission.MissionSpace): - return NonTensor((), device=device) - else: - raise NotImplementedError( - f"spec of type {type(spec).__name__} is currently unaccounted for" - ) + # the batch-size must be set later + return Composite(spec_out, device=device) + + +@register_gym_spec_conversion("Text") +def convert_text_spec( + spec, + dtype=None, + device=None, + categorical_action_encoding=None, + remap_state_to_observation=None, + batch_size=None, +): + return NonTensor((), device=device, example_data="a string") + + +@register_gym_spec_conversion("dict.Dict") +def convert_dict_spec2( + spec, + dtype=None, + device=None, + categorical_action_encoding=None, + remap_state_to_observation=None, + batch_size=None, +): + return _gym_to_torchrl_spec_transform( + spec.spaces, + device=device, + categorical_action_encoding=categorical_action_encoding, + remap_state_to_observation=remap_state_to_observation, + batch_size=batch_size, + ) @implement_for("gym", None, "0.18")
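Reviewer note: to close the loop on the new conversion registry, here is a minimal sketch of registering a converter for a user-defined space from user code. It mirrors ``test_gym_new_spec_reg`` in ``test_libs.py``; the ``MyCustomSpace`` class and its converter are hypothetical, not part of this diff.

```python
# Hypothetical example of the new registration hook; MyCustomSpace is a stand-in.
from torchrl.data import NonTensor
from torchrl.envs import register_gym_spec_conversion
from torchrl.envs.libs.gym import gym_backend

class MyCustomSpace(gym_backend("spaces").Space):
    """A stand-in for a user-defined gym/gymnasium space."""

@register_gym_spec_conversion(MyCustomSpace)
def convert_my_custom_space(spec, device=None, **kwargs):
    # Unused conversion kwargs (dtype, batch_size, ...) are absorbed by **kwargs.
    return NonTensor((), device=device, example_data="custom data")
```

Subclasses of ``MyCustomSpace`` resolve to this converter through the registry's MRO walk unless a more specific conversion is registered, which is what the child/parent ordering in the new test verifies.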