Skip to content

Commit

Permalink
Merge branch 'fix-isaac' of https://github.com/pytorch/rl into fix-isaac
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Apr 24, 2024
2 parents 64996b7 + 0c86ea6 commit b076865
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
12 changes: 7 additions & 5 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3083,22 +3083,24 @@ def test_data(self, dataset):
)
@pytest.mark.parametrize("num_envs", [10, 20])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("from_pixels", [True, False])
class TestIsaacGym:
@classmethod
def _run_on_proc(cls, q, task, num_envs, device):
def _run_on_proc(cls, q, task, num_envs, device, from_pixels):
try:
env = IsaacGymEnv(task=task, num_envs=num_envs, device=device)
env = IsaacGymEnv(task=task, num_envs=num_envs, device=device, from_pixels=from_pixels)
print(env.rollout(3))
check_env_specs(env)
q.put(("succeeded!", None))
except Exception as err:
q.put(("failed!", err))
raise err

def test_env(self, task, num_envs, device):
def test_env(self, task, num_envs, device, from_pixels):
from torch import multiprocessing as mp

q = mp.Queue(1)
proc = mp.Process(target=self._run_on_proc, args=(q, task, num_envs, device))
self._run_on_proc(q, task, num_envs, device, from_pixels)
proc = mp.Process(target=self._run_on_proc, args=(q, task, num_envs, device, from_pixels))
try:
proc.start()
msg, error = q.get()
Expand Down
10 changes: 8 additions & 2 deletions torchrl/envs/libs/isaacgym.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,14 @@ def _make_specs(self, env: "gym.Env") -> None: # noqa: F821

def _output_transform(self, output):
obs, reward, done, info = output
if self.from_pixels:
obs["pixels"] = self._env.render(mode='rgb_array')
return obs, reward, done ^ done, done, done, info

def _reset_output_transform(self, reset_data):
reset_data.pop("reward", None)
if self.from_pixels:
reset_data["pixels"] = self._env.render(mode='rgb_array')
return reset_data, {}

@classmethod
Expand Down Expand Up @@ -187,6 +192,7 @@ def __init__(self, task=None, *, env=None, num_envs, device, **kwargs):
raise RuntimeError("Cannot provide both `task` and `env` arguments.")
elif env is not None:
task = env
envs = self._make_envs(task=task, num_envs=num_envs, device=device, **kwargs)
from_pixels = kwargs.pop("from_pixels", False)
envs = self._make_envs(task=task, num_envs=num_envs, device=device, virtual_screen_capture=True, **kwargs)
self.task = task
super().__init__(envs, **kwargs)
super().__init__(envs, from_pixels=from_pixels, **kwargs)

0 comments on commit b076865

Please sign in to comment.