diff --git a/test/test_libs.py b/test/test_libs.py
index ce90ed96d03..b97aef47fe3 100644
--- a/test/test_libs.py
+++ b/test/test_libs.py
@@ -3083,22 +3083,26 @@ def test_data(self, dataset):
 )
 @pytest.mark.parametrize("num_envs", [10, 20])
 @pytest.mark.parametrize("device", get_default_devices())
+@pytest.mark.parametrize("from_pixels", [True, False])
 class TestIsaacGym:
     @classmethod
-    def _run_on_proc(cls, q, task, num_envs, device):
+    def _run_on_proc(cls, q, task, num_envs, device, from_pixels):
         try:
-            env = IsaacGymEnv(task=task, num_envs=num_envs, device=device)
+            env = IsaacGymEnv(task=task, num_envs=num_envs, device=device, from_pixels=from_pixels)
+            # Smoke-test a short rollout before validating the specs.
+            env.rollout(3)
             check_env_specs(env)
             q.put(("succeeded!", None))
         except Exception as err:
             q.put(("failed!", err))
             raise err
 
     def test_env(self, task, num_envs, device, from_pixels):
-    def test_env(self, task, num_envs, device):
+    def test_env(self, task, num_envs, device, from_pixels):
         from torch import multiprocessing as mp
 
         q = mp.Queue(1)
-        proc = mp.Process(target=self._run_on_proc, args=(q, task, num_envs, device))
+        # The env is built only in the child process; its result comes back through `q`.
+        proc = mp.Process(target=self._run_on_proc, args=(q, task, num_envs, device, from_pixels))
         try:
             proc.start()
             msg, error = q.get()
diff --git a/torchrl/envs/libs/isaacgym.py b/torchrl/envs/libs/isaacgym.py
index e130775bf5b..ad0d79d8d12 100644
--- a/torchrl/envs/libs/isaacgym.py
+++ b/torchrl/envs/libs/isaacgym.py
@@ -86,9 +86,16 @@ def _make_specs(self, env: "gym.Env") -> None: # noqa: F821
 
     def _output_transform(self, output):
         obs, reward, done, info = output
+        if self.from_pixels:
+            # Attach the rendered frame alongside the state observations.
+            obs["pixels"] = self._env.render(mode='rgb_array')
         return obs, reward, done ^ done, done, done, info
 
     def _reset_output_transform(self, reset_data):
+        # Drop any reward entry from the reset data if one is present.
+        reset_data.pop("reward", None)
+        if self.from_pixels:
+            reset_data["pixels"] = self._env.render(mode='rgb_array')
         return reset_data, {}
 
     @classmethod
@@ -187,6 +194,8 @@ def __init__(self, task=None, *, env=None, num_envs, device, **kwargs):
             raise RuntimeError("Cannot provide both `task` and `env` arguments.")
         elif env is not None:
             task = env
-        envs = self._make_envs(task=task, num_envs=num_envs, device=device, **kwargs)
+        from_pixels = kwargs.pop("from_pixels", False)
+        # Screen capture is only needed when pixel observations are requested.
+        envs = self._make_envs(task=task, num_envs=num_envs, device=device, virtual_screen_capture=from_pixels, **kwargs)
         self.task = task
-        super().__init__(envs, **kwargs)
+        super().__init__(envs, from_pixels=from_pixels, **kwargs)