diff --git a/.github/unittest/linux/scripts/run_all.sh b/.github/unittest/linux/scripts/run_all.sh index 22cdad1b479..6c9c1c43cf5 100755 --- a/.github/unittest/linux/scripts/run_all.sh +++ b/.github/unittest/linux/scripts/run_all.sh @@ -80,7 +80,7 @@ export DISPLAY=:0 export SDL_VIDEODRIVER=dummy # legacy from bash scripts: remove? -conda env config vars set MUJOCO_GL=$MUJOCO_GL PYOPENGL_PLATFORM=$MUJOCO_GL DISPLAY=:0 SDL_VIDEODRIVER=dummy LAZY_LEGACY_OP=False RL_LOGGING_LEVEL=DEBUG +conda env config vars set MUJOCO_GL=$MUJOCO_GL PYOPENGL_PLATFORM=$MUJOCO_GL DISPLAY=:0 SDL_VIDEODRIVER=dummy LAZY_LEGACY_OP=False RL_LOGGING_LEVEL=DEBUG TOKENIZERS_PARALLELISM=true pip3 install pip --upgrade pip install virtualenv diff --git a/.github/unittest/linux_distributed/scripts/setup_env.sh b/.github/unittest/linux_distributed/scripts/setup_env.sh index 2a48ab21459..4344c136994 100755 --- a/.github/unittest/linux_distributed/scripts/setup_env.sh +++ b/.github/unittest/linux_distributed/scripts/setup_env.sh @@ -69,7 +69,8 @@ conda env config vars set MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 \ LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$root_dir/.mujoco/mujoco210/bin \ SDL_VIDEODRIVER=dummy \ MUJOCO_GL=$PRIVATE_MUJOCO_GL \ - PYOPENGL_PLATFORM=$PRIVATE_MUJOCO_GL + PYOPENGL_PLATFORM=$PRIVATE_MUJOCO_GL \ + TOKENIZERS_PARALLELISM=true # Software rendering requires GLX and OSMesa. if [ $PRIVATE_MUJOCO_GL == 'egl' ] || [ $PRIVATE_MUJOCO_GL == 'osmesa' ] ; then diff --git a/.github/unittest/linux_libs/scripts_d4rl/setup_env.sh b/.github/unittest/linux_libs/scripts_d4rl/setup_env.sh index 58ec8becf2e..f1775a0375a 100755 --- a/.github/unittest/linux_libs/scripts_d4rl/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_d4rl/setup_env.sh @@ -92,6 +92,7 @@ conda env config vars set \ MUJOCO_PY_MJKEY_PATH=$root_dir/.mujoco/mjkey.txt \ SDL_VIDEODRIVER=dummy \ MUJOCO_GL=$PRIVATE_MUJOCO_GL \ - PYOPENGL_PLATFORM=$PRIVATE_MUJOCO_GL + PYOPENGL_PLATFORM=$PRIVATE_MUJOCO_GL \ + TOKENIZERS_PARALLELISM=true conda env update --file "${this_dir}/environment.yml" --prune diff --git a/.github/unittest/linux_libs/scripts_gym/setup_env.sh b/.github/unittest/linux_libs/scripts_gym/setup_env.sh index 8804370aa6d..163a26fbdf8 100755 --- a/.github/unittest/linux_libs/scripts_gym/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_gym/setup_env.sh @@ -80,6 +80,7 @@ conda env config vars set \ MUJOCO_PY_MJKEY_PATH=${root_dir}/mujoco-py/mujoco_py/binaries/mjkey.txt \ MUJOCO_PY_MUJOCO_PATH=${root_dir}/mujoco-py/mujoco_py/binaries/linux/mujoco210 \ LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/pytorch/rl/mujoco-py/mujoco_py/binaries/linux/mujoco210/bin + TOKENIZERS_PARALLELISM=true # LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/circleci/project/mujoco-py/mujoco_py/binaries/linux/mujoco210/bin # make env variables apparent diff --git a/.github/unittest/linux_libs/scripts_habitat/run_test.sh b/.github/unittest/linux_libs/scripts_habitat/run_test.sh index a60fffd8f45..b03fd0823a9 100755 --- a/.github/unittest/linux_libs/scripts_habitat/run_test.sh +++ b/.github/unittest/linux_libs/scripts_habitat/run_test.sh @@ -10,7 +10,7 @@ conda activate ./env # https://stackoverflow.com/questions/72540359/glibcxx-3-4-30-not-found-for-librosa-in-conda-virtual-environment-after-tryin #conda install -y -c conda-forge gcc=12.1.0 conda install -y -c conda-forge libstdcxx-ng=12 -conda env config vars set LD_PRELOAD=$LD_PRELOAD:$STDC_LOC +conda env config vars set LD_PRELOAD=$LD_PRELOAD:$STDC_LOC TOKENIZERS_PARALLELISM=true ## find libstdc STDC_LOC=$(find conda/ -name "libstdc++.so.6" | head -1) @@ -36,7 +36,7 @@ export MKL_THREADING_LAYER=GNU #wget https://github.com/openai/mujoco-py/blob/master/vendor/10_nvidia.json #mv 10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json -conda env config vars set MAGNUM_LOG=quiet HABITAT_SIM_LOG=quiet +conda env config vars set MAGNUM_LOG=quiet HABITAT_SIM_LOG=quiet TOKENIZERS_PARALLELISM=true conda deactivate && conda activate ./env diff --git a/.github/unittest/linux_libs/scripts_habitat/setup_env.sh b/.github/unittest/linux_libs/scripts_habitat/setup_env.sh index 6ad970c3f47..e436a0c9bf0 100755 --- a/.github/unittest/linux_libs/scripts_habitat/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_habitat/setup_env.sh @@ -41,7 +41,7 @@ fi conda activate "${env_dir}" # set debug variables -conda env config vars set MAGNUM_LOG=debug HABITAT_SIM_LOG=debug +conda env config vars set MAGNUM_LOG=debug HABITAT_SIM_LOG=debug TOKENIZERS_PARALLELISM=true conda deactivate && conda activate "${env_dir}" pip3 install "cython<3" diff --git a/.github/unittest/linux_libs/scripts_robohive/setup_env.sh b/.github/unittest/linux_libs/scripts_robohive/setup_env.sh index 38e6d350354..4e3bc93bf03 100755 --- a/.github/unittest/linux_libs/scripts_robohive/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_robohive/setup_env.sh @@ -67,7 +67,8 @@ conda env config vars set \ PYOPENGL_PLATFORM=egl \ NVIDIA_PATH=/usr/src/nvidia-470.63.01 \ sim_backend=MUJOCO \ - LAZY_LEGACY_OP=False + LAZY_LEGACY_OP=False \ + TOKENIZERS_PARALLELISM=true # make env variables apparent conda deactivate && conda activate "${env_dir}" diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/setup_env.sh b/.github/unittest/linux_olddeps/scripts_gym_0_13/setup_env.sh index 5df37723d6a..56803aadf49 100755 --- a/.github/unittest/linux_olddeps/scripts_gym_0_13/setup_env.sh +++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/setup_env.sh @@ -85,7 +85,8 @@ conda env config vars set \ NVIDIA_PATH=/usr/src/nvidia-470.63.01 \ MUJOCO_PY_MJKEY_PATH=${root_dir}/mujoco-py/mujoco_py/binaries/mjkey.txt \ MUJOCO_PY_MUJOCO_PATH=${root_dir}/mujoco-py/mujoco_py/binaries/linux/mujoco210 \ - LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/circleci/project/mujoco-py/mujoco_py/binaries/linux/mujoco210/bin + LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/circleci/project/mujoco-py/mujoco_py/binaries/linux/mujoco210/bin \ + TOKENIZERS_PARALLELISM=true # make env variables apparent conda deactivate && conda activate "${env_dir}" diff --git a/.github/unittest/linux_sota/scripts/run_all.sh b/.github/unittest/linux_sota/scripts/run_all.sh index 5da5256de99..e8a7423c9d3 100755 --- a/.github/unittest/linux_sota/scripts/run_all.sh +++ b/.github/unittest/linux_sota/scripts/run_all.sh @@ -83,7 +83,8 @@ conda env config vars set MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 \ SDL_VIDEODRIVER=dummy \ MUJOCO_GL=egl \ PYOPENGL_PLATFORM=egl \ - BATCHED_PIPE_TIMEOUT=60 + BATCHED_PIPE_TIMEOUT=60 \ + TOKENIZERS_PARALLELISM=true pip install pip --upgrade @@ -100,7 +101,7 @@ pip install git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl conda install -y -c conda-forge libstdcxx-ng=12 ## find libstdc STDC_LOC=$(find conda/ -name "libstdc++.so.6" | head -1) -conda env config vars set LD_PRELOAD=${root_dir}/$STDC_LOC +conda env config vars set LD_PRELOAD=${root_dir}/$STDC_LOC TOKENIZERS_PARALLELISM=true # compile mujoco-py (bc it's done at runtime for whatever reason someone thought it was a good idea) python -c """import gym;import d4rl""" diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index d75a0e67c54..20b2802591c 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -41,19 +41,35 @@ Each env will have the following attributes: the done-flag spec. See the section on trajectory termination below. - :obj:`env.input_spec`: a :class:`~torchrl.data.Composite` object containing all the input keys (:obj:`"full_action_spec"` and :obj:`"full_state_spec"`). - It is locked and should not be modified directly. - :obj:`env.output_spec`: a :class:`~torchrl.data.Composite` object containing all the output keys (:obj:`"full_observation_spec"`, :obj:`"full_reward_spec"` and :obj:`"full_done_spec"`). - It is locked and should not be modified directly. -If the environment carries non-tensor data, a :class:`~torchrl.data.NonTensorSpec` +If the environment carries non-tensor data, a :class:`~torchrl.data.NonTensor` instance can be used. +Env specs: locks and batch size +------------------------------- + +.. _Environment-lock: + +Environment specs are locked by default (through a ``spec_locked`` arg passed to the env constructor). +Locking specs means that any modification of the spec (or its children if it is a :class:`~torchrl.data.Composite` +instance) will require to unlock it. This can be done via the :meth:`~torchrl.envs.EnvBase.set_spec_lock_`. +The reason specs are locked by default is that it makes it easy to cache values such as action or reset keys and the +likes. +Unlocking an env should only be done if it expected that the specs will be modified often (which, in principle, should +be avoided). +Modifications of the specs such as `env.observation_spec = new_spec` are allowed: under the hood, TorchRL will erase +the cache, unlock the specs, make the modification and relock the specs if the env was previously locked. + Importantly, the environment spec shapes should contain the batch size, e.g. an environment with :obj:`env.batch_size == torch.Size([4])` should have an :obj:`env.action_spec` with shape :obj:`torch.Size([4, action_size])`. This is helpful when preallocation tensors, checking shape consistency etc. +Env methods +----------- + With these, the following methods are implemented: - :meth:`env.reset`: a reset method that may (but not necessarily requires to) take diff --git a/test/test_env.py b/test/test_env.py index b3c528d720f..f45aa7c4668 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -9,11 +9,13 @@ import gc import importlib import os.path +import pickle import random import re from collections import defaultdict from functools import partial from sys import platform +from typing import Optional import numpy as np import pytest @@ -246,6 +248,41 @@ def test_run_type_checks(self): with pytest.raises(TypeError): check_env_specs(env) + class MyEnv(EnvBase): + def __init__(self): + super().__init__() + self.observation_spec = Unbounded(()) + self.action_spec = Unbounded(()) + + def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: + ... + + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + ... + + def _set_seed(self, seed: Optional[int]): + ... + + def test_env_lock(self): + + env = self.MyEnv() + for _ in range(2): + assert env.is_spec_locked + assert env.output_spec.is_locked + assert env.input_spec.is_locked + with pytest.raises(RuntimeError, match="lock"): + env.input_spec["full_action_spec", "action"] = Unbounded(()) + env = pickle.loads(pickle.dumps(env)) + + env = self.MyEnv(spec_locked=False) + assert not env.is_spec_locked + assert not env.output_spec.is_locked + assert not env.input_spec.is_locked + env.input_spec["full_action_spec", "action"] = Unbounded(()) + def test_single_env_spec(self): env = NestedCountingEnv(batch_size=[3, 1, 7]) assert not env.full_action_spec_unbatched.shape @@ -2294,15 +2331,14 @@ def test_multi_purpose_env(self, serial): env = SerialEnv(2, ContinuousActionVecMockEnv) else: env = ContinuousActionVecMockEnv() + env.set_spec_lock_() env.rollout(10) - assert env._step_mdp.validate(None) c = SyncDataCollector( env, env.rand_action, frames_per_batch=10, total_frames=20 ) for data in c: # noqa: B007 pass assert ("collector", "traj_ids") in data.keys(True) - assert env._step_mdp.validate(None) env.rollout(10) # An exception will be raised when the collector sees extra keys @@ -3387,6 +3423,10 @@ def policy(td): class TestEnvWithDynamicSpec: def test_dynamic_rollout(self): env = EnvWithDynamicSpec() + rollout = env.rollout(4) + assert isinstance(rollout, LazyStackedTensorDict) + rollout = env.rollout(4, return_contiguous=False) + assert isinstance(rollout, LazyStackedTensorDict) with pytest.raises( RuntimeError, match="The environment specs are dynamic. Call rollout with return_contiguous=False", diff --git a/test/test_libs.py b/test/test_libs.py index 8db543c146a..57358b732a4 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -1779,7 +1779,7 @@ def test_jumanji_rendering(self, envname, batch_size): # check that this works with a batch-size env = JumanjiEnv(envname, from_pixels=True, batch_size=batch_size, jit=True) env.set_seed(0) - env.transform.transform_observation_spec(env.base_env.observation_spec) + env.transform.transform_observation_spec(env.base_env.observation_spec.clone()) r = env.rollout(10) pixels = r["pixels"] diff --git a/test/test_specs.py b/test/test_specs.py index 340afaa449a..c05f3604563 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -410,6 +410,30 @@ def test_setitem_matches_device(self, shape, is_complete, device, dtype, dest): ) assert ts["bad"].device == (device if device is not None else dest) + def test_setitem_nested(self, shape, is_complete, device, dtype): + f = Unbounded(shape=shape, device=device, dtype=dtype) + g = ( + None + if not is_complete + else Unbounded(shape=shape, device=device, dtype=dtype) + ) + test = Composite( + a=Composite(b=Composite(c=Composite(d=Composite(e=Composite(f=f, g=g))))), + shape=shape, + device=device, + ) + trials = Composite(shape=shape, device=device) + assert trials != test + trials["a", "b", "c", "d", "e", "f"] = Unbounded( + shape=shape, device=device, dtype=dtype + ) + trials["a", "b", "c", "d", "e", "g"] = ( + None + if not is_complete + else Unbounded(shape=shape, device=device, dtype=dtype) + ) + assert trials == test + def test_del(self, shape, is_complete, device, dtype): ts = self._composite_spec(shape, is_complete, device, dtype) assert "obs" in ts.keys() diff --git a/test/test_transforms.py b/test/test_transforms.py index aba41ba614f..f57bc58221d 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1100,7 +1100,7 @@ def test_catframes_transform_observation_spec(self): } ) - result = cat_frames.transform_observation_spec(observation_spec) + result = cat_frames.transform_observation_spec(observation_spec.clone()) observation_spec = Composite( { key: Bounded(space_min, space_max, (1, 3, 3), dtype=torch.double) @@ -1665,7 +1665,9 @@ def test_r3mnet_transform_observation_spec( {key: Unbounded(r3m_net.outdim, device) for key in out_keys} ) - observation_spec_out = r3m_net.transform_observation_spec(observation_spec) + observation_spec_out = r3m_net.transform_observation_spec( + observation_spec.clone() + ) for key in in_keys: assert key not in observation_spec_out @@ -1681,7 +1683,9 @@ def test_r3mnet_transform_observation_spec( ts_dict[key] = Unbounded(r3m_net.outdim, device) exp_ts = Composite(ts_dict) - observation_spec_out = r3m_net.transform_observation_spec(observation_spec) + observation_spec_out = r3m_net.transform_observation_spec( + observation_spec.clone() + ) for key in in_keys + out_keys: assert observation_spec_out[key].shape == exp_ts[key].shape @@ -2059,7 +2063,7 @@ class TestTrajCounter(TransformBase): def test_single_trans_env_check(self): torch.manual_seed(0) env = TransformedEnv(CountingEnv(max_steps=4), TrajCounter()) - env.transform.transform_observation_spec(env.base_env.observation_spec) + env.transform.transform_observation_spec(env.base_env.observation_spec.clone()) check_env_specs(env) @pytest.mark.parametrize("predefined", [True, False]) @@ -2073,7 +2077,9 @@ def make_env(max_steps=4, t=t): if t is None: t = TrajCounter() env = TransformedEnv(CountingEnv(max_steps=max_steps), t.clone()) - env.transform.transform_observation_spec(env.base_env.observation_spec) + env.transform.transform_observation_spec( + env.base_env.observation_spec.clone() + ) return env if predefined: @@ -2109,7 +2115,9 @@ def make_env(max_steps=4, t=t): else: t = t.clone() env = TransformedEnv(CountingEnv(max_steps=max_steps), t) - env.transform.transform_observation_spec(env.base_env.observation_spec) + env.transform.transform_observation_spec( + env.base_env.observation_spec.clone() + ) return env if predefined: @@ -2137,7 +2145,7 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): ), TrajCounter(), ) - env.transform.transform_observation_spec(env.base_env.observation_spec) + env.transform.transform_observation_spec(env.base_env.observation_spec.clone()) r = env.rollout( 100, lambda td: td.set("action", torch.ones(env.shape + (1,))), @@ -2153,7 +2161,7 @@ def test_trans_serial_env_check(self): ), TrajCounter(), ) - env.transform.transform_observation_spec(env.base_env.observation_spec) + env.transform.transform_observation_spec(env.base_env.observation_spec.clone()) r = env.rollout( 100, lambda td: td.set("action", torch.ones(env.shape + (1,))), @@ -2165,7 +2173,7 @@ def test_trans_serial_env_check(self): def test_transform_env(self): torch.manual_seed(0) env = TransformedEnv(CountingEnv(max_steps=4), TrajCounter()) - env.transform.transform_observation_spec(env.base_env.observation_spec) + env.transform.transform_observation_spec(env.base_env.observation_spec.clone()) r = env.rollout(100, lambda td: td.set("action", 1), break_when_any_done=False) assert r["traj_count"].max() == 19 @@ -2178,7 +2186,7 @@ def test_nested(self): TrajCounter(out_key=(("nested"), (("traj_count",),))), ), ) - env.transform.transform_observation_spec(env.base_env.observation_spec) + env.transform.transform_observation_spec(env.base_env.observation_spec.clone()) r = env.rollout(100, lambda td: td.set("action", 1), break_when_any_done=False) assert r["nested", "traj_count"].max() == 19 @@ -2210,7 +2218,9 @@ def test_collector_match(self): def make_env(max_steps=4): env = TransformedEnv(CountingEnv(max_steps=max_steps), t.clone()) - env.transform.transform_observation_spec(env.base_env.observation_spec) + env.transform.transform_observation_spec( + env.base_env.observation_spec.clone() + ) return env collector = MultiSyncDataCollector( @@ -3283,13 +3293,17 @@ def test_transform_no_env(self, keys, device, out_key): if len(keys) == 1: observation_spec = Bounded(0, 1, (1, 4, 32)) - observation_spec = cattensors.transform_observation_spec(observation_spec) + observation_spec = cattensors.transform_observation_spec( + observation_spec.clone() + ) assert observation_spec.shape == torch.Size([1, len(keys) * 4, 32]) else: observation_spec = Composite( {key: Bounded(0, 1, (1, 4, 32)) for key in keys} ) - observation_spec = cattensors.transform_observation_spec(observation_spec) + observation_spec = cattensors.transform_observation_spec( + observation_spec.clone() + ) assert observation_spec[out_key].shape == torch.Size([1, len(keys) * 4, 32]) @pytest.mark.parametrize("device", get_default_devices()) @@ -3429,13 +3443,13 @@ def test_transform_no_env(self, keys, h, nchannels, batch, device): if len(keys) == 1: observation_spec = Bounded(-1, 1, (nchannels, 16, 16)) - observation_spec = crop.transform_observation_spec(observation_spec) + observation_spec = crop.transform_observation_spec(observation_spec.clone()) assert observation_spec.shape == torch.Size([nchannels, 20, h]) else: observation_spec = Composite( {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys} ) - observation_spec = crop.transform_observation_spec(observation_spec) + observation_spec = crop.transform_observation_spec(observation_spec.clone()) for key in keys: assert observation_spec[key].shape == torch.Size([nchannels, 20, h]) @@ -3636,13 +3650,13 @@ def test_transform_no_env(self, keys, h, nchannels, batch, device): if len(keys) == 1: observation_spec = Bounded(-1, 1, (nchannels, 16, 16)) - observation_spec = cc.transform_observation_spec(observation_spec) + observation_spec = cc.transform_observation_spec(observation_spec.clone()) assert observation_spec.shape == torch.Size([nchannels, 20, h]) else: observation_spec = Composite( {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys} ) - observation_spec = cc.transform_observation_spec(observation_spec) + observation_spec = cc.transform_observation_spec(observation_spec.clone()) for key in keys: assert observation_spec[key].shape == torch.Size([nchannels, 20, h]) @@ -3994,7 +4008,9 @@ def test_double2float(self, keys, keys_inv, device): observation_spec = Composite( {key: Bounded(0, 1, (1, 3, 3), dtype=torch.double) for key in keys} ) - observation_spec = double2float.transform_observation_spec(observation_spec) + observation_spec = double2float.transform_observation_spec( + observation_spec.clone() + ) for key in keys: assert observation_spec[key].dtype == torch.float, key @@ -4041,7 +4057,7 @@ def test_double2float_auto(self, keys, keys_inv, device): assert td_modif.get(key).dtype == torch.double def test_single_env_no_inkeys(self): - base_env = ContinuousActionVecMockEnv() + base_env = ContinuousActionVecMockEnv(spec_locked=False) for key, spec in list(base_env.observation_spec.items(True, True)): base_env.observation_spec[key] = spec.to(torch.float64) for key, spec in list(base_env.state_spec.items(True, True)): @@ -4052,6 +4068,7 @@ def test_single_env_no_inkeys(self): env = TransformedEnv( base_env, DoubleToFloat(), + spec_locked=False, ) for spec in env.observation_spec.values(True, True): assert spec.dtype == torch.float32 @@ -4773,13 +4790,17 @@ def test_transform_no_env(self, keys, size, nchannels, batch, device): if len(keys) == 1: observation_spec = Bounded(-1, 1, (*size, nchannels, 16, 16)) - observation_spec = flatten.transform_observation_spec(observation_spec) + observation_spec = flatten.transform_observation_spec( + observation_spec.clone() + ) assert observation_spec.shape[-3] == expected_size else: observation_spec = Composite( {key: Bounded(-1, 1, (*size, nchannels, 16, 16)) for key in keys} ) - observation_spec = flatten.transform_observation_spec(observation_spec) + observation_spec = flatten.transform_observation_spec( + observation_spec.clone() + ) for key in keys: assert observation_spec[key].shape[-3] == expected_size @@ -4813,13 +4834,17 @@ def test_transform_compose(self, keys, size, nchannels, batch, device): if len(keys) == 1: observation_spec = Bounded(-1, 1, (*size, nchannels, 16, 16)) - observation_spec = flatten.transform_observation_spec(observation_spec) + observation_spec = flatten.transform_observation_spec( + observation_spec.clone() + ) assert observation_spec.shape[-3] == expected_size else: observation_spec = Composite( {key: Bounded(-1, 1, (*size, nchannels, 16, 16)) for key in keys} ) - observation_spec = flatten.transform_observation_spec(observation_spec) + observation_spec = flatten.transform_observation_spec( + observation_spec.clone() + ) for key in keys: assert observation_spec[key].shape[-3] == expected_size @@ -5055,13 +5080,13 @@ def test_transform_no_env(self, keys, device): if len(keys) == 1: observation_spec = Bounded(-1, 1, (nchannels, 16, 16)) - observation_spec = gs.transform_observation_spec(observation_spec) + observation_spec = gs.transform_observation_spec(observation_spec.clone()) assert observation_spec.shape == torch.Size([1, 16, 16]) else: observation_spec = Composite( {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys} ) - observation_spec = gs.transform_observation_spec(observation_spec) + observation_spec = gs.transform_observation_spec(observation_spec.clone()) for key in keys: assert observation_spec[key].shape == torch.Size([1, 16, 16]) @@ -5092,13 +5117,13 @@ def test_transform_compose(self, keys, device): if len(keys) == 1: observation_spec = Bounded(-1, 1, (nchannels, 16, 16)) - observation_spec = gs.transform_observation_spec(observation_spec) + observation_spec = gs.transform_observation_spec(observation_spec.clone()) assert observation_spec.shape == torch.Size([1, 16, 16]) else: observation_spec = Composite( {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys} ) - observation_spec = gs.transform_observation_spec(observation_spec) + observation_spec = gs.transform_observation_spec(observation_spec.clone()) for key in keys: assert observation_spec[key].shape == torch.Size([1, 16, 16]) @@ -5702,7 +5727,7 @@ def test_observationnorm( if len(keys) == 1: observation_spec = Bounded(0, 1, (nchannels, 16, 16), device=device) - observation_spec = on.transform_observation_spec(observation_spec) + observation_spec = on.transform_observation_spec(observation_spec.clone()) if standard_normal: assert (observation_spec.space.low == -loc / scale).all() assert (observation_spec.space.high == (1 - loc) / scale).all() @@ -5714,7 +5739,7 @@ def test_observationnorm( observation_spec = Composite( {key: Bounded(0, 1, (nchannels, 16, 16), device=device) for key in keys} ) - observation_spec = on.transform_observation_spec(observation_spec) + observation_spec = on.transform_observation_spec(observation_spec.clone()) for key in keys: if standard_normal: assert (observation_spec[key].space.low == -loc / scale).all() @@ -5919,13 +5944,17 @@ def test_transform_no_env(self, interpolation, keys, nchannels, batch, device): if len(keys) == 1: observation_spec = Bounded(-1, 1, (nchannels, 16, 16)) - observation_spec = resize.transform_observation_spec(observation_spec) + observation_spec = resize.transform_observation_spec( + observation_spec.clone() + ) assert observation_spec.shape == torch.Size([nchannels, 20, 21]) else: observation_spec = Composite( {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys} ) - observation_spec = resize.transform_observation_spec(observation_spec) + observation_spec = resize.transform_observation_spec( + observation_spec.clone() + ) for key in keys: assert observation_spec[key].shape == torch.Size([nchannels, 20, 21]) @@ -5956,13 +5985,17 @@ def test_transform_compose(self, interpolation, keys, nchannels, batch, device): if len(keys) == 1: observation_spec = Bounded(-1, 1, (nchannels, 16, 16)) - observation_spec = resize.transform_observation_spec(observation_spec) + observation_spec = resize.transform_observation_spec( + observation_spec.clone() + ) assert observation_spec.shape == torch.Size([nchannels, 20, 21]) else: observation_spec = Composite( {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys} ) - observation_spec = resize.transform_observation_spec(observation_spec) + observation_spec = resize.transform_observation_spec( + observation_spec.clone() + ) for key in keys: assert observation_spec[key].shape == torch.Size([nchannels, 20, 21]) @@ -6951,7 +6984,9 @@ def test_transform_no_env(self, keys, size, nchannels, batch, device, dim): if len(keys) == 1: observation_spec = Bounded(-1, 1, (*batch, *size, nchannels, 16, 16)) - observation_spec = unsqueeze.transform_observation_spec(observation_spec) + observation_spec = unsqueeze.transform_observation_spec( + observation_spec.clone() + ) assert observation_spec.shape == expected_size else: observation_spec = Composite( @@ -6960,7 +6995,9 @@ def test_transform_no_env(self, keys, size, nchannels, batch, device, dim): for key in keys } ) - observation_spec = unsqueeze.transform_observation_spec(observation_spec) + observation_spec = unsqueeze.transform_observation_spec( + observation_spec.clone() + ) for key in keys: assert observation_spec[key].shape == expected_size @@ -7107,7 +7144,9 @@ def test_transform_compose(self, keys, size, nchannels, batch, device, dim): if len(keys) == 1: observation_spec = Bounded(-1, 1, (*batch, *size, nchannels, 16, 16)) - observation_spec = unsqueeze.transform_observation_spec(observation_spec) + observation_spec = unsqueeze.transform_observation_spec( + observation_spec.clone() + ) assert observation_spec.shape == expected_size else: observation_spec = Composite( @@ -7116,7 +7155,9 @@ def test_transform_compose(self, keys, size, nchannels, batch, device, dim): for key in keys } ) - observation_spec = unsqueeze.transform_observation_spec(observation_spec) + observation_spec = unsqueeze.transform_observation_spec( + observation_spec.clone() + ) for key in keys: assert observation_spec[key].shape == expected_size @@ -7713,7 +7754,7 @@ def test_transform_no_env(self, keys, batch, device): if len(keys) == 1: observation_spec = Bounded(0, 255, (16, 16, 3), dtype=torch.uint8) observation_spec = totensorimage.transform_observation_spec( - observation_spec + observation_spec.clone() ) assert observation_spec.shape == torch.Size([3, 16, 16]) assert (observation_spec.space.low == 0).all() @@ -7723,7 +7764,7 @@ def test_transform_no_env(self, keys, batch, device): {key: Bounded(0, 255, (16, 16, 3), dtype=torch.uint8) for key in keys} ) observation_spec = totensorimage.transform_observation_spec( - observation_spec + observation_spec.clone() ) for key in keys: assert observation_spec[key].shape == torch.Size([3, 16, 16]) @@ -7759,7 +7800,7 @@ def test_transform_compose(self, keys, batch, device): if len(keys) == 1: observation_spec = Bounded(0, 255, (16, 16, 3), dtype=torch.uint8) observation_spec = totensorimage.transform_observation_spec( - observation_spec + observation_spec.clone() ) assert observation_spec.shape == torch.Size([3, 16, 16]) assert (observation_spec.space.low == 0).all() @@ -7769,7 +7810,7 @@ def test_transform_compose(self, keys, batch, device): {key: Bounded(0, 255, (16, 16, 3), dtype=torch.uint8) for key in keys} ) observation_spec = totensorimage.transform_observation_spec( - observation_spec + observation_spec.clone() ) for key in keys: assert observation_spec[key].shape == torch.Size([3, 16, 16]) @@ -9048,7 +9089,9 @@ def test_vipnet_transform_observation_spec( if del_keys: exp_ts = Composite({key: Unbounded(1024, device) for key in out_keys}) - observation_spec_out = vip_net.transform_observation_spec(observation_spec) + observation_spec_out = vip_net.transform_observation_spec( + observation_spec.clone() + ) for key in in_keys: assert key not in observation_spec_out @@ -9064,7 +9107,9 @@ def test_vipnet_transform_observation_spec( ts_dict[key] = Unbounded(1024, device) exp_ts = Composite(ts_dict) - observation_spec_out = vip_net.transform_observation_spec(observation_spec) + observation_spec_out = vip_net.transform_observation_spec( + observation_spec.clone() + ) for key in in_keys + out_keys: assert observation_spec_out[key].shape == exp_ts[key].shape @@ -9528,7 +9573,8 @@ def test_parallelenv_vecnorm(self): lambda: TransformedEnv( GymEnv(PENDULUM_VERSIONED()), Compose( - self.rename_t, VecNorm(in_keys=[("some", "obs"), "reward"]) + self.rename_t, + VecNorm(in_keys=[("some", "obs"), "reward"]), ), ) ) @@ -9537,7 +9583,8 @@ def test_parallelenv_vecnorm(self): lambda: TransformedEnv( ContinuousActionVecMockEnv(), Compose( - self.rename_t, VecNorm(in_keys=[("some", "obs"), "reward"]) + self.rename_t, + VecNorm(in_keys=[("some", "obs"), "reward"]), ), ) ) @@ -9934,13 +9981,17 @@ def test_compose(self, keys, batch, device, nchannels=1, N=4): if len(keys) == 1: observation_spec = Bounded(0, 255, (nchannels, 16, 16)) # StepCounter does not want non composite specs - observation_spec = compose[:2].transform_observation_spec(observation_spec) + observation_spec = compose[:2].transform_observation_spec( + observation_spec.clone() + ) assert observation_spec.shape == torch.Size([nchannels * N, 16, 16]) else: observation_spec = Composite( {key: Bounded(0, 255, (nchannels, 16, 16)) for key in keys} ) - observation_spec = compose.transform_observation_spec(observation_spec) + observation_spec = compose.transform_observation_spec( + observation_spec.clone() + ) for key in keys: assert observation_spec[key].shape == torch.Size( [nchannels * N, 16, 16] diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 0296b55f972..f971fe478a3 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -4545,6 +4545,18 @@ def separates(self, *keys: NestedKey, default: Any = None) -> Composite: def set(self, name, spec): if self.locked: raise RuntimeError("Cannot modify a locked Composite.") + if spec is not None and self.device is not None and spec.device != self.device: + if isinstance(spec, Composite) and spec.device is None: + # We make a clone not to mess up the spec that was provided. + # in set() we do the same for shape - these two ops should be grouped. + # we don't care about the overhead of cloning twice though because in theory + # we don't set specs often. + spec = spec.clone().to(self._device) + else: + raise RuntimeError( + f"Setting a new attribute ({name}) on another device ({spec.device} against {self.device}). " + f"All devices of Composite must match." + ) if spec is not None: shape = spec.shape if shape[: self.ndim] != self.shape: @@ -4578,28 +4590,10 @@ def __init__( shape = _size(()) self._shape = _size(shape) self._specs = {} - for key, value in kwargs.items(): - self.set(key, value) _device = ( _make_ordinal_device(torch.device(device)) if device is not None else device ) - if len(kwargs): - for key, item in self.items(): - if item is None: - continue - if ( - isinstance(item, Composite) - and item.device is None - and _device is not None - ): - item = item.clone().to(_device) - elif (_device is not None) and (item.device != _device): - raise RuntimeError( - f"Setting a new attribute ({key}) on another device " - f"({item.device} against {_device}). All devices of " - "Composite must match." - ) self._device = _device if len(args): if len(args) > 1: @@ -4615,6 +4609,8 @@ def __init__( if isinstance(item, dict): item = Composite(item, shape=shape, device=_device) self[k] = item + for k, item in kwargs.items(): + self[k] = item @property def device(self) -> DEVICE_TYPING: @@ -4697,10 +4693,17 @@ def get(self, item, default=NO_DEFAULT): raise def __setitem__(self, key, value): + dest = self if isinstance(key, tuple) and len(key) > 1: - if key[0] not in self.keys(True): - self[key[0]] = Composite(shape=self.shape, device=self.device) - self[key[0]][key[1:]] = value + while key[0] not in self.keys(): + dest[key[0]] = dest = Composite(shape=self.shape, device=self.device) + if len(key) > 2: + key = key[1:] + else: + break + else: + dest = self[key[0]] + dest[key[1:]] = value return elif isinstance(key, tuple): self[key[0]] = value @@ -4711,22 +4714,6 @@ def __setitem__(self, key, value): raise AttributeError(f"Composite[{key}] cannot be set") if isinstance(value, dict): value = Composite(value, device=self._device, shape=self.shape) - if ( - value is not None - and self.device is not None - and value.device != self.device - ): - if isinstance(value, Composite) and value.device is None: - # We make a clone not to mess up the spec that was provided. - # in set() we do the same for shape - these two ops should be grouped. - # we don't care about the overhead of cloning twice though because in theory - # we don't set specs often. - value = value.clone().to(self.device) - else: - raise RuntimeError( - f"Setting a new attribute ({key}) on another device ({value.device} against {self.device}). " - f"All devices of Composite must match." - ) self.set(key, value) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 51331a86346..9db6949cb37 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -12,7 +12,7 @@ import os import weakref from collections import OrderedDict -from copy import copy, deepcopy +from copy import deepcopy from functools import wraps from multiprocessing import connection from multiprocessing.synchronize import Lock as MpLock @@ -371,6 +371,8 @@ def __init__( ) self._mp_start_method = mp_start_method + is_spec_locked = EnvBase.is_spec_locked + @property def non_blocking(self): nb = self._non_blocking @@ -471,7 +473,7 @@ def find_all_worker_devices(item): return _do_nothing, _do_nothing def __getstate__(self): - out = copy(self.__dict__) + out = self.__dict__.copy() out["_sync_m2w_value"] = None out["_sync_w2m_value"] = None return out @@ -933,8 +935,9 @@ def _start_workers(self) -> None: "environments!" ) weakref_set.add(wr) - self._envs.append(env) + self._envs.append(env.set_spec_lock_()) self.is_closed = False + self.set_spec_lock_() @_check_start def state_dict(self) -> OrderedDict: @@ -1458,6 +1461,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda): for channel in self.parent_channels: channel.send(("init", None)) self.is_closed = False + self.set_spec_lock_() @_check_start def state_dict(self) -> OrderedDict: @@ -2164,6 +2168,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda): ) env = env_fun del env_fun + env.set_spec_lock_() i = -1 import torchrl diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 78987d2df57..18258807521 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -20,7 +20,7 @@ TensorDictBase, unravel_key, ) -from tensordict.base import _is_leaf_nontensor +from tensordict.base import _is_leaf_nontensor, NO_DEFAULT from tensordict.utils import is_non_tensor, NestedKey from torchrl._utils import ( _ends_with, @@ -63,24 +63,37 @@ def _tensor_to_np(t): } +def _maybe_unlock(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + is_locked = self.is_spec_locked + try: + if is_locked: + self.set_spec_lock_(False) + result = func(self, *args, **kwargs) + finally: + if is_locked: + if not hasattr(self, "_cache"): + self._cache = {} + self._cache.clear() + self.set_spec_lock_(True) + return result + + return wrapper + + def _cache_value(func): """Caches the result of the decorated function in env._cache dictionary.""" - # func_name = func.__name__ + func_name = func.__name__ @wraps(func) def wrapper(self, *args, **kwargs): - # result = self.__dict__.setdefault("_cache", {}).get(func_name, NO_DEFAULT) - # if result is NO_DEFAULT: - result = func(self, *args, **kwargs) - # Ideally we'd like to cache all the `_keys` attributes but there's a catch: one can modify the specs at - # any time so this will not run as expected. - # The solution should be: - # - optionally lock the specs in the env, like we do with tensordict. - # - Locked specs will behave like locked tensordict: we lock the root spec, meaning that all the sub-specs - # will be locked, and no __setattr__ will be allowed within the env unless it's unlocked. - # We cannot just guard spec.__setattr__ because `spec[key0][key1] = smth` will not call a setattr - # on the root spec so there's a chance we miss it. - # self.__dict__.setdefault("_cache", {})[func_name] = result + if not self.is_spec_locked: + return func(self, *args, **kwargs) + result = self.__dict__.setdefault("_cache", {}).get(func_name, NO_DEFAULT) + if result is NO_DEFAULT: + result = func(self, *args, **kwargs) + self.__dict__.setdefault("_cache", {})[func_name] = result return result return wrapper @@ -219,11 +232,18 @@ def to(self, device: DEVICE_TYPING) -> EnvMetaData: class _EnvPostInit(abc.ABCMeta): def __call__(cls, *args, **kwargs): + spec_locked = kwargs.pop("spec_locked", True) auto_reset = kwargs.pop("auto_reset", False) auto_reset_replace = kwargs.pop("auto_reset_replace", True) instance: EnvBase = super().__call__(*args, **kwargs) if "_cache" not in instance.__dict__: instance._cache = {} + + if spec_locked: + instance.input_spec.lock_(recurse=True) + instance.output_spec.lock_(recurse=True) + instance._is_spec_locked = spec_locked + # we create the done spec by adding a done/terminated entry if one is missing instance._create_done_specs() # we access lazy attributed to make sure they're built properly. @@ -282,6 +302,20 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): at every reset and every step. Defaults to ``False``. allow_done_after_reset (bool, optional): if ``True``, an environment can be done after a call to :meth:`~.reset` is made. Defaults to ``False``. + spec_locked (bool, optional): if ``True``, the specs are locked and can only be + modified if :meth:`~torchrl.envs.EnvBase.set_spec_lock_` is called. + + .. note:: The locking is achieved by the `EnvBase` metaclass. It does not appear in the + `__init__` method and is included in the keyword arguments strictly for type-hinting purpose. + + .. seealso:: :ref:`Locking environment specs `. + + Defaults to ``True``. + auto_reset (bool, optional): if ``True``, the env is assumed to reset automatically + when done. Defaults to ``False``. + + .. note:: The auto-resetting is achieved by the `EnvBase` metaclass. It does not appear in the + `__init__` method and is included in the keyword arguments strictly for type-hinting purpose. Attributes: done_spec (Composite): equivalent to ``full_done_spec`` as all @@ -312,6 +346,8 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): batch_size (torch.Size): The batch-size of the environment. device (torch.device): the device where the input/outputs of the environment are to be expected. Can be ``None``. + is_spec_locked (bool): returns ``True`` if the specs are locked. See the :attr:`spec_locked` + argument above. Methods: step (TensorDictBase -> TensorDictBase): step in the environment @@ -412,6 +448,7 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): _batch_size: torch.Size | None _device: torch.device | None + _is_spec_locked: bool = False def __init__( self, @@ -420,6 +457,8 @@ def __init__( batch_size: Optional[torch.Size] = None, run_type_checks: bool = False, allow_done_after_reset: bool = False, + spec_locked: bool = True, + auto_reset: bool = False, ): if "_cache" not in self.__dict__: self._cache = {} @@ -444,7 +483,7 @@ def __init__( if output_spec is None: output_spec = self.__dict__["_output_spec"] = Composite( shape=batch_size, device=device - ).lock_() + ) elif self._output_spec.device != device and device is not None: self.__dict__["_output_spec"] = self.__dict__["_output_spec"].to( self.device @@ -453,12 +492,12 @@ def __init__( if input_spec is None: input_spec = self.__dict__["_input_spec"] = Composite( shape=batch_size, device=device - ).lock_() + ) elif self._input_spec.device != device and device is not None: self.__dict__["_input_spec"] = self.__dict__["_input_spec"].to(self.device) - output_spec.unlock_() - input_spec.unlock_() + output_spec.unlock_(recurse=True) + input_spec.unlock_(recurse=True) if "full_observation_spec" not in output_spec: output_spec["full_observation_spec"] = Composite() if "full_done_spec" not in output_spec: @@ -469,14 +508,58 @@ def __init__( input_spec["full_state_spec"] = Composite() if "full_action_spec" not in input_spec: input_spec["full_action_spec"] = Composite() - output_spec.lock_() - input_spec.lock_() if "is_closed" not in self.__dir__(): self.is_closed = True self._run_type_checks = run_type_checks self._allow_done_after_reset = allow_done_after_reset + def set_spec_lock_(self, mode: bool = True) -> EnvBase: + """Locks or unlocks the environment's specs. + + Args: + mode (bool): Whether to lock (`True`) or unlock (`False`) the specs. Defaults to `True`. + + Returns: + EnvBase: The environment instance itself. + + .. seealso:: :ref:`Locking environment specs `. + + """ + output_spec = self.__dict__.get("_output_spec") + input_spec = self.__dict__.get("_input_spec") + if mode: + if output_spec is not None: + output_spec.lock_(recurse=True) + if input_spec is not None: + input_spec.lock_(recurse=True) + else: + self._cache.clear() + if output_spec is not None: + output_spec.unlock_(recurse=True) + if input_spec is not None: + input_spec.unlock_(recurse=True) + self.__dict__["_is_spec_locked"] = mode + return self + + @property + def is_spec_locked(self): + """Gets whether the environment's specs are locked. + + This property can be modified directly. + + Returns: + bool: True if the specs are locked, False otherwise. + + .. seealso:: :ref:`Locking environment specs `. + + """ + return self.__dict__.get("_is_spec_locked", False) + + @is_spec_locked.setter + def is_spec_locked(self, value: bool): + self.set_spec_lock_(value) + def auto_specs_( self, policy: Callable[[TensorDictBase], TensorDictBase], @@ -707,19 +790,16 @@ def batch_size(self) -> torch.Size: return _batch_size @batch_size.setter + @_maybe_unlock def batch_size(self, value: torch.Size) -> None: self._batch_size = torch.Size(value) if ( hasattr(self, "output_spec") and self.output_spec.shape[: len(value)] != value ): - self.output_spec.unlock_() self.output_spec.shape = value - self.output_spec.lock_() if hasattr(self, "input_spec") and self.input_spec.shape[: len(value)] != value: - self.input_spec.unlock_() self.input_spec.shape = value - self.input_spec.lock_() @property def shape(self): @@ -810,12 +890,17 @@ def input_spec(self) -> TensorSpec: """ input_spec = self.__dict__.get("_input_spec") if input_spec is None: + is_locked = self.is_spec_locked + if is_locked: + self.set_spec_lock_(False) input_spec = Composite( full_state_spec=None, shape=self.batch_size, device=self.device, - ).lock_() + ) self.__dict__["_input_spec"] = input_spec + if is_locked: + self.set_spec_lock_(True) return input_spec @input_spec.setter @@ -870,11 +955,16 @@ def output_spec(self) -> TensorSpec: """ output_spec = self.__dict__.get("_output_spec") if output_spec is None: + is_locked = self.is_spec_locked + if is_locked: + self.set_spec_lock_(False) output_spec = Composite( shape=self.batch_size, device=self.device, - ).lock_() + ) self.__dict__["_output_spec"] = output_spec + if is_locked: + self.set_spec_lock_(True) return output_spec @output_spec.setter @@ -1031,29 +1121,25 @@ def action_spec(self) -> TensorSpec: return out @action_spec.setter - @_clear_cache_when_set + @_maybe_unlock def action_spec(self, value: TensorSpec) -> None: - try: - self.input_spec.unlock_() - device = self.input_spec._device - if not hasattr(value, "shape"): - raise TypeError( - f"action_spec of type {type(value)} do not have a shape attribute." - ) - if value.shape[: len(self.batch_size)] != self.batch_size: - raise ValueError( - f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size}). " - "Please use `env.action_spec_unbatched = value` to set unbatched versions instead." - ) + device = self.input_spec._device + if not hasattr(value, "shape"): + raise TypeError( + f"action_spec of type {type(value)} do not have a shape attribute." + ) + if value.shape[: len(self.batch_size)] != self.batch_size: + raise ValueError( + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size}). " + "Please use `env.action_spec_unbatched = value` to set unbatched versions instead." + ) - if not isinstance(value, Composite): - value = Composite( - action=value.to(device), shape=self.batch_size, device=device - ) + if not isinstance(value, Composite): + value = Composite( + action=value.to(device), shape=self.batch_size, device=device + ) - self.input_spec["full_action_spec"] = value.to(device) - finally: - self.input_spec.lock_() + self.input_spec["full_action_spec"] = value.to(device) @property def full_action_spec(self) -> Composite: @@ -1081,10 +1167,13 @@ def full_action_spec(self) -> Composite: """ full_action_spec = self.input_spec.get("full_action_spec", None) if full_action_spec is None: + is_locked = self.is_spec_locked + if is_locked: + self.set_spec_lock_(False) full_action_spec = Composite(shape=self.batch_size, device=self.device) - self.input_spec.unlock_() self.input_spec["full_action_spec"] = full_action_spec - self.input_spec.lock_() + if is_locked: + self.set_spec_lock_(True) return full_action_spec @full_action_spec.setter @@ -1219,36 +1308,31 @@ def reward_spec(self) -> TensorSpec: return reward_spec[self.reward_keys[0]] @reward_spec.setter - @_clear_cache_when_set + @_maybe_unlock def reward_spec(self, value: TensorSpec) -> None: - try: - self.output_spec.unlock_() - device = self.output_spec._device - if not hasattr(value, "shape"): - raise TypeError( - f"reward_spec of type {type(value)} do not have a shape " - f"attribute." - ) - if value.shape[: len(self.batch_size)] != self.batch_size: - raise ValueError( - f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size}). " - "Please use `env.reward_spec_unbatched = value` to set unbatched versions instead." - ) - if not isinstance(value, Composite): - value = Composite( - reward=value.to(device), shape=self.batch_size, device=device + device = self.output_spec._device + if not hasattr(value, "shape"): + raise TypeError( + f"reward_spec of type {type(value)} do not have a shape " f"attribute." + ) + if value.shape[: len(self.batch_size)] != self.batch_size: + raise ValueError( + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size}). " + "Please use `env.reward_spec_unbatched = value` to set unbatched versions instead." + ) + if not isinstance(value, Composite): + value = Composite( + reward=value.to(device), shape=self.batch_size, device=device + ) + for leaf in value.values(True, True): + if len(leaf.shape) == 0: + raise RuntimeError( + "the reward_spec's leafs shape cannot be empty (this error" + " usually comes from trying to set a reward_spec" + " with a null number of dimensions. Try using a multidimensional" + " spec instead, for instance with a singleton dimension at the tail)." ) - for leaf in value.values(True, True): - if len(leaf.shape) == 0: - raise RuntimeError( - "the reward_spec's leafs shape cannot be empty (this error" - " usually comes from trying to set a reward_spec" - " with a null number of dimensions. Try using a multidimensional" - " spec instead, for instance with a singleton dimension at the tail)." - ) - self.output_spec["full_reward_spec"] = value.to(device) - finally: - self.output_spec.lock_() + self.output_spec["full_reward_spec"] = value.to(device) @property def full_reward_spec(self) -> Composite: @@ -1289,7 +1373,7 @@ def full_reward_spec(self) -> Composite: return self.output_spec["full_reward_spec"] @full_reward_spec.setter - @_clear_cache_when_set + @_maybe_unlock def full_reward_spec(self, spec: Composite) -> None: self.reward_spec = spec.to(self.device) if self.device is not None else spec @@ -1353,7 +1437,7 @@ def full_done_spec(self) -> Composite: return self.output_spec["full_done_spec"] @full_done_spec.setter - @_clear_cache_when_set + @_maybe_unlock def full_done_spec(self, spec: Composite) -> None: self.done_spec = spec.to(self.device) if self.device is not None else spec @@ -1427,6 +1511,7 @@ def done_spec(self) -> TensorSpec: done_spec = self.output_spec["full_done_spec"] return done_spec + @_maybe_unlock def _create_done_specs(self): """Reads through the done specs and makes it so that it's complete. @@ -1455,9 +1540,7 @@ def _create_done_specs(self): dtype=torch.bool, device=self.device, ) - self.output_spec.unlock_() self.output_spec["full_done_spec"] = full_done_spec - self.output_spec.lock_() return def check_local_done(spec): @@ -1491,46 +1574,44 @@ def check_local_done(spec): n=2, shape=(*spec.shape, 1), dtype=torch.bool, device=self.device ) - self.output_spec.unlock_() + if_locked = self.is_spec_locked + if if_locked: + self.is_spec_locked = False check_local_done(full_done_spec) self.output_spec["full_done_spec"] = full_done_spec - self.output_spec.lock_() + if if_locked: + self.is_spec_locked = True return @done_spec.setter - @_clear_cache_when_set + @_maybe_unlock def done_spec(self, value: TensorSpec) -> None: - try: - self.output_spec.unlock_() - device = self.output_spec.device - if not hasattr(value, "shape"): - raise TypeError( - f"done_spec of type {type(value)} do not have a shape " - f"attribute." - ) - if value.shape[: len(self.batch_size)] != self.batch_size: - raise ValueError( - f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." - ) - if not isinstance(value, Composite): - value = Composite( - done=value.to(device), - terminated=value.to(device), - shape=self.batch_size, - device=device, + device = self.output_spec.device + if not hasattr(value, "shape"): + raise TypeError( + f"done_spec of type {type(value)} do not have a shape " f"attribute." + ) + if value.shape[: len(self.batch_size)] != self.batch_size: + raise ValueError( + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." + ) + if not isinstance(value, Composite): + value = Composite( + done=value.to(device), + terminated=value.to(device), + shape=self.batch_size, + device=device, + ) + for leaf in value.values(True, True): + if len(leaf.shape) == 0: + raise RuntimeError( + "the done_spec's leafs shape cannot be empty (this error" + " usually comes from trying to set a reward_spec" + " with a null number of dimensions. Try using a multidimensional" + " spec instead, for instance with a singleton dimension at the tail)." ) - for leaf in value.values(True, True): - if len(leaf.shape) == 0: - raise RuntimeError( - "the done_spec's leafs shape cannot be empty (this error" - " usually comes from trying to set a reward_spec" - " with a null number of dimensions. Try using a multidimensional" - " spec instead, for instance with a singleton dimension at the tail)." - ) - self.output_spec["full_done_spec"] = value.to(device) - self._create_done_specs() - finally: - self.output_spec.lock_() + self.output_spec["full_done_spec"] = value.to(device) + self._create_done_specs() # observation spec: observation specs belong to output_spec @property @@ -1564,40 +1645,44 @@ def observation_spec(self) -> Composite: """ observation_spec = self.output_spec.get("full_observation_spec", default=None) if observation_spec is None: + is_locked = self.is_spec_locked + if is_locked: + self.set_spec_lock_(False) observation_spec = Composite(shape=self.batch_size, device=self.device) - self.output_spec.unlock_() self.output_spec["full_observation_spec"] = observation_spec - self.output_spec.lock_() + if is_locked: + self.set_spec_lock_(True) + return observation_spec @observation_spec.setter - @_clear_cache_when_set + @_maybe_unlock def observation_spec(self, value: TensorSpec) -> None: - try: - self.output_spec.unlock_() - if not isinstance(value, Composite): - raise TypeError("The type of an observation_spec must be Composite.") - elif value.shape[: len(self.batch_size)] != self.batch_size: - raise ValueError( - f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." - ) - if value.shape[: len(self.batch_size)] != self.batch_size: - raise ValueError( - f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." - ) - device = self.output_spec._device - self.output_spec["full_observation_spec"] = ( - value.to(device) if device is not None else value + if not isinstance(value, Composite): + value = Composite( + observation=value, + device=self.device, + batch_size=self.output_spec.batch_size, ) - finally: - self.output_spec.lock_() + elif value.shape[: len(self.batch_size)] != self.batch_size: + raise ValueError( + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." + ) + if value.shape[: len(self.batch_size)] != self.batch_size: + raise ValueError( + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." + ) + device = self.output_spec._device + self.output_spec["full_observation_spec"] = ( + value.to(device) if device is not None else value + ) @property def full_observation_spec(self) -> Composite: return self.observation_spec @full_observation_spec.setter - @_clear_cache_when_set + @_maybe_unlock def full_observation_spec(self, spec: Composite): self.observation_spec = spec @@ -1637,38 +1722,37 @@ def state_spec(self) -> Composite: """ state_spec = self.input_spec["full_state_spec"] if state_spec is None: + is_locked = self.is_spec_locked + if is_locked: + self.set_spec_lock_(False) state_spec = Composite(shape=self.batch_size, device=self.device) - self.input_spec.unlock_() self.input_spec["full_state_spec"] = state_spec - self.input_spec.lock_() + if is_locked: + self.set_spec_lock_(True) return state_spec @state_spec.setter - @_clear_cache_when_set + @_maybe_unlock def state_spec(self, value: Composite) -> None: - try: - self.input_spec.unlock_() - if value is None: - self.input_spec["full_state_spec"] = Composite( - device=self.device, shape=self.batch_size + if value is None: + self.input_spec["full_state_spec"] = Composite( + device=self.device, shape=self.batch_size + ) + else: + device = self.input_spec.device + if not isinstance(value, Composite): + raise TypeError("The type of an state_spec must be Composite.") + elif value.shape[: len(self.batch_size)] != self.batch_size: + raise ValueError( + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) - else: - device = self.input_spec.device - if not isinstance(value, Composite): - raise TypeError("The type of an state_spec must be Composite.") - elif value.shape[: len(self.batch_size)] != self.batch_size: - raise ValueError( - f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." - ) - if value.shape[: len(self.batch_size)] != self.batch_size: - raise ValueError( - f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." - ) - self.input_spec["full_state_spec"] = ( - value.to(device) if device is not None else value + if value.shape[: len(self.batch_size)] != self.batch_size: + raise ValueError( + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) - finally: - self.input_spec.lock_() + self.input_spec["full_state_spec"] = ( + value.to(device) if device is not None else value + ) @property def full_state_spec(self) -> Composite: @@ -1698,7 +1782,7 @@ def full_state_spec(self) -> Composite: return self.state_spec @full_state_spec.setter - @_clear_cache_when_set + @_maybe_unlock def full_state_spec(self, spec: Composite) -> None: self.state_spec = spec @@ -1720,6 +1804,7 @@ def full_action_spec_unbatched(self) -> Composite: return self._make_single_env_spec(self.full_action_spec) @full_action_spec_unbatched.setter + @_maybe_unlock def full_action_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.full_action_spec = spec @@ -1730,6 +1815,7 @@ def action_spec_unbatched(self) -> TensorSpec: return self._make_single_env_spec(self.action_spec) @action_spec_unbatched.setter + @_maybe_unlock def action_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.action_spec = spec @@ -1740,6 +1826,7 @@ def full_observation_spec_unbatched(self) -> Composite: return self._make_single_env_spec(self.full_observation_spec) @full_observation_spec_unbatched.setter + @_maybe_unlock def full_observation_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.full_observation_spec = spec @@ -1750,6 +1837,7 @@ def observation_spec_unbatched(self) -> Composite: return self._make_single_env_spec(self.observation_spec) @observation_spec_unbatched.setter + @_maybe_unlock def observation_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.observation_spec = spec @@ -1760,6 +1848,7 @@ def full_reward_spec_unbatched(self) -> Composite: return self._make_single_env_spec(self.full_reward_spec) @full_reward_spec_unbatched.setter + @_maybe_unlock def full_reward_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.full_reward_spec = spec @@ -1770,6 +1859,7 @@ def reward_spec_unbatched(self) -> TensorSpec: return self._make_single_env_spec(self.reward_spec) @reward_spec_unbatched.setter + @_maybe_unlock def reward_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.reward_spec = spec @@ -1780,6 +1870,7 @@ def full_done_spec_unbatched(self) -> Composite: return self._make_single_env_spec(self.full_done_spec) @full_done_spec_unbatched.setter + @_maybe_unlock def full_done_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.full_done_spec = spec @@ -1790,6 +1881,7 @@ def done_spec_unbatched(self) -> TensorSpec: return self._make_single_env_spec(self.done_spec) @done_spec_unbatched.setter + @_maybe_unlock def done_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.done_spec = spec @@ -1800,6 +1892,7 @@ def output_spec_unbatched(self) -> Composite: return self._make_single_env_spec(self.output_spec) @output_spec_unbatched.setter + @_maybe_unlock def output_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.output_spec = spec @@ -1810,6 +1903,7 @@ def input_spec_unbatched(self) -> Composite: return self._make_single_env_spec(self.input_spec) @input_spec_unbatched.setter + @_maybe_unlock def input_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.input_spec = spec @@ -1820,6 +1914,7 @@ def full_state_spec_unbatched(self) -> Composite: return self._make_single_env_spec(self.full_state_spec) @full_state_spec_unbatched.setter + @_maybe_unlock def full_state_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.full_state_spec = spec @@ -1830,6 +1925,7 @@ def state_spec_unbatched(self) -> TensorSpec: return self._make_single_env_spec(self.state_spec) @state_spec_unbatched.setter + @_maybe_unlock def state_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.state_spec = spec @@ -2725,7 +2821,7 @@ def specs(self) -> Composite: output_spec=self.output_spec, input_spec=self.input_spec, shape=self.batch_size, - ).lock_() + ) @property @_cache_value @@ -3073,7 +3169,7 @@ def rollout( out_td.refine_names(..., "time") return out_td - @_clear_cache_when_set + @_maybe_unlock def add_truncated_keys(self) -> EnvBase: """Adds truncated keys to the environment.""" i = 0 @@ -3133,13 +3229,9 @@ def step_mdp(self, next_tensordict: TensorDictBase) -> TensorDictBase: return self._step_mdp(next_tensordict) @property - # @_cache_value + @_cache_value def _step_mdp(self): - step_func = self._cache.get("_step_mdp_value") - if step_func is None: - step_func = _StepMDP(self, exclude_action=False) - self._cache["_step_mdp_value"] = step_func - return step_func + return _StepMDP(self, exclude_action=False) def _rollout_stop_early( self, @@ -3463,12 +3555,13 @@ def __del__(self): # __del__ will not affect the program. pass + @_maybe_unlock def to(self, device: DEVICE_TYPING) -> EnvBase: device = _make_ordinal_device(torch.device(device)) if device == self.device: return self - self.__dict__["_input_spec"] = self.input_spec.to(device).lock_() - self.__dict__["_output_spec"] = self.output_spec.to(device).lock_() + self.__dict__["_input_spec"] = self.input_spec.to(device) + self.__dict__["_output_spec"] = self.output_spec.to(device) self._device = device return super().to(device) @@ -3539,12 +3632,14 @@ def __init__( device: DEVICE_TYPING = None, batch_size: Optional[torch.Size] = None, allow_done_after_reset: bool = False, + spec_locked: bool = True, **kwargs, ): super().__init__( device=device, batch_size=batch_size, allow_done_after_reset=allow_done_after_reset, + spec_locked=spec_locked, ) if len(args): raise ValueError( diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index bb849847f3a..5e6dd55be8b 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -16,7 +16,7 @@ from torchrl._utils import logger as torchrl_logger from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded -from torchrl.envs.common import _EnvWrapper, EnvBase +from torchrl.envs.common import _EnvWrapper, _maybe_unlock, EnvBase class BaseInfoDictReader(metaclass=abc.ABCMeta): @@ -434,6 +434,7 @@ def _output_transform( def _reset_output_transform(self, reset_outputs_tuple: Tuple) -> Tuple: ... + @_maybe_unlock def set_info_dict_reader( self, info_dict_reader: BaseInfoDictReader | None = None, diff --git a/torchrl/envs/libs/_gym_utils.py b/torchrl/envs/libs/_gym_utils.py index b95bfb335c6..a68e5101ff6 100644 --- a/torchrl/envs/libs/_gym_utils.py +++ b/torchrl/envs/libs/_gym_utils.py @@ -39,7 +39,7 @@ def __init__( self.observation_space = _torchrl_to_gym_spec_transform( Composite( { - key: self.torchrl_env.full_observation_spec[key] + key: self.torchrl_env.full_observation_spec[key].clone() for key in self._observation_keys } ), diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py index 8045f8e0ab4..d187dc10c91 100644 --- a/torchrl/envs/libs/robohive.py +++ b/torchrl/envs/libs/robohive.py @@ -13,6 +13,7 @@ import torch from tensordict import TensorDict from torchrl.data.tensor_specs import Unbounded +from torchrl.envs.common import _maybe_unlock from torchrl.envs.libs.gym import ( _AsyncMeta, _gym_to_torchrl_spec_transform, @@ -251,6 +252,7 @@ def register_visual_env(cls, env_name, cams, from_depths): cls.env_list += [env_name] return env_name + @_maybe_unlock def _refine_specs(self) -> None: # noqa: F821 env = self._env self.action_spec = _gym_to_torchrl_spec_transform( diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index e8b0a744d5c..b0a6a84e5c5 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -60,7 +60,6 @@ _ends_with, _make_ordinal_device, _replace_last, - implement_for, logger as torchrl_logger, ) @@ -78,7 +77,13 @@ Unbounded, UnboundedContinuous, ) -from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, make_tensordict +from torchrl.envs.common import ( + _do_nothing, + _EnvPostInit, + _maybe_unlock, + EnvBase, + make_tensordict, +) from torchrl.envs.transforms import functional as F from torchrl.envs.transforms.utils import ( _get_reset, @@ -850,36 +855,43 @@ def _inplace_update(self): @property def output_spec(self) -> TensorSpec: """Observation spec of the transformed environment.""" - if not self.cache_specs or self.__dict__.get("_output_spec", None) is None: - output_spec = self.base_env.output_spec.clone() - - # remove cached key values, but not _input_spec - super().empty_cache() - output_spec = output_spec.unlock_() - output_spec = self.transform.transform_output_spec(output_spec) - output_spec.lock_() - if self.cache_specs: - self.__dict__["_output_spec"] = output_spec - else: - output_spec = self.__dict__.get("_output_spec", None) + if self.cache_specs: + output_spec = self.__dict__.get("_output_spec") + if output_spec is not None: + return output_spec + output_spec = self._make_output_spec() + return output_spec + + @_maybe_unlock + def _make_output_spec(self): + output_spec = self.base_env.output_spec.clone() + + # remove cached key values, but not _input_spec + super().empty_cache() + output_spec = self.transform.transform_output_spec(output_spec) + if self.cache_specs: + self.__dict__["_output_spec"] = output_spec return output_spec @property def input_spec(self) -> TensorSpec: - """Action spec of the transformed environment.""" - if self.__dict__.get("_input_spec", None) is None or not self.cache_specs: - input_spec = self.base_env.input_spec.clone() - - # remove cached key values but not _output_spec - super().empty_cache() - - input_spec.unlock_() - input_spec = self.transform.transform_input_spec(input_spec) - input_spec.lock_() - if self.cache_specs: - self.__dict__["_input_spec"] = input_spec - else: - input_spec = self.__dict__.get("_input_spec", None) + """Observation spec of the transformed environment.""" + if self.cache_specs: + input_spec = self.__dict__.get("_input_spec") + if input_spec is not None: + return input_spec + input_spec = self._make_input_spec() + return input_spec + + @_maybe_unlock + def _make_input_spec(self): + input_spec = self.base_env.input_spec.clone() + + # remove cached key values, but not _input_spec + super().empty_cache() + input_spec = self.transform.transform_input_spec(input_spec) + if self.cache_specs: + self.__dict__["_input_spec"] = input_spec return input_spec def rand_action(self, tensordict: Optional[TensorDictBase] = None) -> TensorDict: @@ -6507,7 +6519,7 @@ def __repr__(self) -> str: ) def __getstate__(self) -> Dict[str, Any]: - state = self.__dict__.copy() + state = super().__getstate__() _lock = state.pop("lock", None) if _lock is not None: state["lock_placeholder"] = None @@ -6518,7 +6530,7 @@ def __setstate__(self, state: Dict[str, Any]): state.pop("lock_placeholder") _lock = mp.Lock() state["lock"] = _lock - self.__dict__.update(state) + super().__setstate__(state) @_apply_to_composite def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: @@ -9950,14 +9962,7 @@ def __init__(self, out_key: NestedKey = "traj_count"): def _make_shared_value(self): self._traj_count = mp.Value("i", 0) - @implement_for("torch", None, "2.1") def __getstate__(self): - state = self.__dict__.copy() - state["_traj_count"] = None - return state - - @implement_for("torch", "2.1") - def __getstate__(self): # noqa: F811 state = super().__getstate__() state["_traj_count"] = None return state diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 02d13c4924a..39b0faa9692 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -148,13 +148,14 @@ def __init__( self.exclude_from_root = self._repr_key_list_as_tree(self.exclude_from_root) self.keys_from_root = self._repr_key_list_as_tree(self.keys_from_root) self.keys_from_next = self._repr_key_list_as_tree(self.keys_from_next) - self.validated = None + self.validated = True # Model based envs can have missing keys # TODO: do we want to always allow this? check_env_specs should catch these or downstream ops self._allow_absent_keys = True def validate(self, tensordict): + # Deprecated - leaving dormant if self.validated: return True if self.validated is None: @@ -180,14 +181,16 @@ def _is_reset(key: NestedKey): if not _is_reset(key) } expected = set(expected) + # Actual (the input td) can have more keys, like loc and scale etc + # But we cannot have keys missing: if there's a key in expected that is not in actual + # it is a problem. self.validated = expected.intersection(actual) == expected if not self.validated: warnings.warn( - "The expected key set and actual key set differ. " - "This will work but with a slower throughput than " - "when the specs match exactly the actual key set " - "in the data. " - f"{{Expected keys}}-{{Actual keys}}={set(expected) - actual}, \n" + "The expected key set and actual key set differ (all expected keys must be present, " + "extra keys can be present in the input TensorDict). " + "As a result, step_mdp will need to run extra key checks at each iteration. " + f"{{Expected keys}}-{{Actual keys}}={set(expected) - actual} (<= this set should be empty), \n" f"{{Actual keys}}-{{Expected keys}}={actual- set(expected)}." ) return self.validated @@ -285,52 +288,38 @@ def __call__(self, tensordict): ) return out next_td = tensordict._get_str("next", None) - if self.validate(tensordict): - if self.keep_other: - out = self._exclude(self.exclude_from_root, tensordict, out=None) - if out is None: - out = tensordict.empty() - else: - out = next_td.empty() - self._grab_and_place( - self.keys_from_root, - tensordict, - out, - _allow_absent_keys=self._allow_absent_keys, + if self.keep_other: + out = self._exclude(self.exclude_from_root, tensordict, out=None) + if out is None: + out = tensordict.empty() + else: + out = next_td.empty() + self._grab_and_place( + self.keys_from_root, + tensordict, + out, + _allow_absent_keys=self._allow_absent_keys, + ) + if isinstance(next_td, LazyStackedTensorDict): + if not isinstance(out, LazyStackedTensorDict): + out = LazyStackedTensorDict( + *out.unbind(next_td.stack_dim), stack_dim=next_td.stack_dim ) - if isinstance(next_td, LazyStackedTensorDict): - if not isinstance(out, LazyStackedTensorDict): - out = LazyStackedTensorDict( - *out.unbind(next_td.stack_dim), stack_dim=next_td.stack_dim - ) - for _next_td, _out in zip(next_td.tensordicts, out.tensordicts): - self._grab_and_place( - self.keys_from_next, - _next_td, - _out, - _allow_absent_keys=self._allow_absent_keys, - ) - else: + for _next_td, _out in zip(next_td.tensordicts, out.tensordicts): self._grab_and_place( self.keys_from_next, - next_td, - out, + _next_td, + _out, _allow_absent_keys=self._allow_absent_keys, ) - return out else: - out = next_td.empty() - total_key = () - if self.keep_other: - for key in tensordict.keys(): - if key != "next": - _set(tensordict, out, key, total_key, self.excluded) - elif not self.exclude_action: - for action_key in self.action_keys: - _set_single_key(tensordict, out, action_key) - for key in next_td.keys(): - _set(next_td, out, key, total_key, self.excluded) - return out + self._grab_and_place( + self.keys_from_next, + next_td, + out, + _allow_absent_keys=self._allow_absent_keys, + ) + return out def step_mdp(