From 2a244007bf60269a3c1ed0a8fc68122c8aec9677 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 15 Apr 2024 13:18:04 +0100 Subject: [PATCH] amend --- test/test_libs.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index a14e2626b14..5055fa51688 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -7,6 +7,8 @@ from contextlib import nullcontext from pathlib import Path +import robohive + from torchrl._utils import logger as torchrl_logger from torchrl.data.datasets.gen_dgrl import GenDGRLExperienceReplay @@ -3341,25 +3343,26 @@ class TestRoboHive: @pytest.mark.parametrize("from_pixels", [False, True]) @pytest.mark.parametrize("envname", RoboHiveEnv.available_envs) def test_robohive(self, envname, from_pixels): - torchrl_logger.info(f"{envname}-{from_pixels}") - if any(substr in envname for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s")): - torchrl_logger.info("not testing envs with prebuilt rendering") - return - if "Adroit" in envname: - torchrl_logger.info("tcdm are broken") - return - if from_pixels and len(RoboHiveEnv.get_available_cams(env_name=envname)) == 0: - torchrl_logger.info("no camera") - return - try: - env = RoboHiveEnv(envname, from_pixels=from_pixels) - except AttributeError as err: - if "'MjData' object has no attribute 'get_body_xipos'" in str(err): + with set_gym_backend("gymnasium"): + torchrl_logger.info(f"{envname}-{from_pixels}") + if any(substr in envname for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s")): + torchrl_logger.info("not testing envs with prebuilt rendering") + return + if "Adroit" in envname: torchrl_logger.info("tcdm are broken") return - else: - raise err - check_env_specs(env) + if from_pixels and len(RoboHiveEnv.get_available_cams(env_name=envname)) == 0: + torchrl_logger.info("no camera") + return + try: + env = RoboHiveEnv(envname, from_pixels=from_pixels) + except AttributeError as err: + if "'MjData' object has no attribute 'get_body_xipos'" in str(err): + torchrl_logger.info("tcdm are broken") + return + else: + raise err + check_env_specs(env) @pytest.mark.skipif(not _has_smacv2, reason="SMACv2 not found")