Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Apr 15, 2024
1 parent 07c0c11 commit 2a24400
Showing 1 changed file with 20 additions and 17 deletions.
37 changes: 20 additions & 17 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from contextlib import nullcontext
from pathlib import Path

import robohive

from torchrl._utils import logger as torchrl_logger

from torchrl.data.datasets.gen_dgrl import GenDGRLExperienceReplay
Expand Down Expand Up @@ -3341,25 +3343,26 @@ class TestRoboHive:
@pytest.mark.parametrize("from_pixels", [False, True])
@pytest.mark.parametrize("envname", RoboHiveEnv.available_envs)
def test_robohive(self, envname, from_pixels):
torchrl_logger.info(f"{envname}-{from_pixels}")
if any(substr in envname for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s")):
torchrl_logger.info("not testing envs with prebuilt rendering")
return
if "Adroit" in envname:
torchrl_logger.info("tcdm are broken")
return
if from_pixels and len(RoboHiveEnv.get_available_cams(env_name=envname)) == 0:
torchrl_logger.info("no camera")
return
try:
env = RoboHiveEnv(envname, from_pixels=from_pixels)
except AttributeError as err:
if "'MjData' object has no attribute 'get_body_xipos'" in str(err):
with set_gym_backend("gymnasium"):
torchrl_logger.info(f"{envname}-{from_pixels}")
if any(substr in envname for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s")):
torchrl_logger.info("not testing envs with prebuilt rendering")
return
if "Adroit" in envname:
torchrl_logger.info("tcdm are broken")
return
else:
raise err
check_env_specs(env)
if from_pixels and len(RoboHiveEnv.get_available_cams(env_name=envname)) == 0:
torchrl_logger.info("no camera")
return
try:
env = RoboHiveEnv(envname, from_pixels=from_pixels)
except AttributeError as err:
if "'MjData' object has no attribute 'get_body_xipos'" in str(err):
torchrl_logger.info("tcdm are broken")
return
else:
raise err
check_env_specs(env)


@pytest.mark.skipif(not _has_smacv2, reason="SMACv2 not found")
Expand Down

0 comments on commit 2a24400

Please sign in to comment.