From 93641f1dc83cfdec2c9e9eb33888cd632d76be0a Mon Sep 17 00:00:00 2001 From: Jet <38184875+jjshoots@users.noreply.github.com> Date: Mon, 5 Dec 2022 19:40:00 +0000 Subject: [PATCH] Add dm-lab (no tests yet) and dm-control-multiagent (#7) --- setup.py | 2 +- shimmy/__init__.py | 18 +- shimmy/dm_control_compatibility.py | 30 +--- shimmy/dm_control_multiagent_compatibility.py | 167 ++++++++++++++++++ shimmy/dm_lab_compatibility.py | 81 +++++++++ ..._wrapper.py => openspiel_compatibility.py} | 2 +- shimmy/registration.py | 10 +- shimmy/utils/dm_env.py | 33 ++++ shimmy/utils/dm_lab.py | 68 +++++++ tests/test_dm_control.py | 4 +- tests/test_dm_control_multiagent.py | 36 ++++ tests/test_openspiel.py | 19 +- 12 files changed, 425 insertions(+), 45 deletions(-) create mode 100644 shimmy/dm_control_multiagent_compatibility.py create mode 100644 shimmy/dm_lab_compatibility.py rename shimmy/{openspiel_wrapper.py => openspiel_compatibility.py} (99%) create mode 100644 shimmy/utils/dm_lab.py create mode 100644 tests/test_dm_control_multiagent.py diff --git a/setup.py b/setup.py index 0fd6d138..dbb18220 100644 --- a/setup.py +++ b/setup.py @@ -36,7 +36,7 @@ def get_version(): "gym": ["gym>=0.26"], "atari": ["ale-py~=0.8.0"], # "imageio" should be "gymnasium[mujoco]>=0.26" but there are install conflicts - "dm-control": ["dm-control>=1.0.8", "imageio"], + "dm-control": ["dm-control>=1.0.8", "imageio", "h5py>=3.7.0"], "openspiel": ["open_spiel>=1.2", "pettingzoo>=1.22"], } extras["all"] = list({lib for libs in extras.values() for lib in libs}) diff --git a/shimmy/__init__.py b/shimmy/__init__.py index f0899fca..9ef5aedd 100644 --- a/shimmy/__init__.py +++ b/shimmy/__init__.py @@ -1,19 +1,29 @@ """API for converting popular non-gymnasium environments to a gymnasium compatible environment.""" -__version__ = "0.1.0" +__version__ = "0.2.0" try: - from shimmy.dm_control_compatibility import DmControlCompatibility + from shimmy.dm_control_compatibility import DmControlCompatibilityV0 except ImportError: pass try: - from shimmy.openspiel_wrapper import OpenspielWrapperV0 + from shimmy.dm_control_multiagent_compatibility import ( + DmControlMultiAgentCompatibilityV0, + ) except ImportError: pass -from shimmy.openai_gym_compatibility import GymV22Compatibility, GymV26Compatibility +try: + from shimmy.openspiel_compatibility import OpenspielCompatibilityV0 +except ImportError: + pass + +try: + from shimmy.dm_lab_compatibility import DmLabCompatibilityV0 +except ImportError: + pass __all__ = [ "DmControlCompatibility", diff --git a/shimmy/dm_control_compatibility.py b/shimmy/dm_control_compatibility.py index e8d906d3..67dd040b 100644 --- a/shimmy/dm_control_compatibility.py +++ b/shimmy/dm_control_compatibility.py @@ -17,7 +17,7 @@ from gymnasium.core import ObsType from gymnasium.envs.mujoco.mujoco_rendering import Viewer -from shimmy.utils.dm_env import dm_obs2gym_obs, dm_spec2gym_space +from shimmy.utils.dm_env import dm_control_step2gym_step, dm_spec2gym_space class EnvType(Enum): @@ -27,7 +27,7 @@ class EnvType(Enum): RL_CONTROL = 1 -class DmControlCompatibility(gymnasium.Env[ObsType, np.ndarray]): +class DmControlCompatibilityV0(gymnasium.Env[ObsType, np.ndarray]): """A compatibility wrapper that converts a dm-control environment into a gymnasium environment. Dm-control actually has two Environments classes, `dm_control.composer.Environment` and @@ -81,36 +81,18 @@ def reset( self.np_random = np.random.RandomState(seed=seed) timestep = self._env.reset() - obs = dm_obs2gym_obs(timestep.observation) - info = { - "timestep.discount": timestep.discount, - "timestep.step_type": timestep.step_type, - } + + obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep) + return obs, info # pyright: ignore[reportGeneralTypeIssues] def step( self, action: np.ndarray ) -> tuple[ObsType, float, bool, bool, dict[str, Any]]: """Steps through the dm-control environment.""" - # Step through the dm-control environment timestep = self._env.step(action) - # open up the timestep and process reward and observation - obs = dm_obs2gym_obs(timestep.observation) - reward = timestep.reward or 0 - - # set terminated and truncated - terminated, truncated = False, False - if timestep.last(): - if timestep.discount == 0: - truncated = True - else: - terminated = True - - info = { - "timestep.discount": timestep.discount, - "timestep.step_type": timestep.step_type, - } + obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep) if self.render_mode == "human": self.viewer.render() diff --git a/shimmy/dm_control_multiagent_compatibility.py b/shimmy/dm_control_multiagent_compatibility.py new file mode 100644 index 00000000..c73c3cb5 --- /dev/null +++ b/shimmy/dm_control_multiagent_compatibility.py @@ -0,0 +1,167 @@ +"""Wrapper to convert a dm_env multiagent environment into a pettingzoo compatible environment.""" +from __future__ import annotations + +import functools +from itertools import repeat +from typing import Any + +import dm_control.composer +import dm_env +import gymnasium +from gymnasium.envs.mujoco.mujoco_rendering import Viewer +from pettingzoo import ParallelEnv + +from shimmy.utils.dm_env import dm_obs2gym_obs, dm_spec2gym_space + + +def _unravel_ma_timestep( + timestep: dm_env.TimeStep, agents: list[str] +) -> tuple[ + dict[str, Any], + dict[str, float], + dict[str, bool], + dict[str, bool], + dict[str, Any], +]: + """Opens up the timestep to return obs, reward, terminated, truncated, info.""" + # set terminated and truncated + term, trunc = False, False + if timestep.last(): + if timestep.discount == 0: + trunc = True + else: + term = True + + # expand the observations + list_observations = [dm_obs2gym_obs(obs) for obs in timestep.observation] + observations: dict[str, Any] = dict(zip(agents, list_observations)) + + # sometimes deepmind decides not to reward people + rewards: dict[str, float] = dict(zip(agents, repeat(0.0))) + if timestep.reward: + rewards = dict(zip(agents, timestep.reward)) + + # expand everything else + terminations: dict[str, bool] = dict(zip(agents, repeat(term))) + truncations: dict[str, bool] = dict(zip(agents, repeat(trunc))) + + # duplicate infos + info = { + "timestep.discount": timestep.discount, + "timestep.step_type": timestep.step_type, + } + info: dict[str, Any] = dict(zip(agents, repeat(info))) + + return ( + observations, + rewards, + terminations, + truncations, + info, + ) + + +class DmControlMultiAgentCompatibilityV0(ParallelEnv): + """Compatibility environment for multi-agent dm-control environments, primarily soccer.""" + + metadata = {"render_modes": ["human"]} + + def __init__( + self, + env: dm_control.composer.Environment, + render_mode: str | None = None, + ): + """Wrapper that converts a dm control multi-agent environment into a pettingzoo environment. + + Due to how the underlying environment is setup, this environment is nondeterministic, so seeding doesn't work. + + Args: + env (dm_env.Environment): dm control multi-agent environment + render_mode (Optional[str]): render_mode + """ + super().__init__() + self._env = env + self.render_mode = render_mode + + # get action and observation specs first + all_obs_spaces = [ + dm_spec2gym_space(spec) for spec in self._env.observation_spec() + ] + all_act_spaces = [dm_spec2gym_space(spec) for spec in self._env.action_spec()] + num_players = len(all_obs_spaces) + + # agent definitions + self.possible_agents = ["player_" + str(r) for r in range(num_players)] + self.agent_id_name_mapping = dict(zip(range(num_players), self.possible_agents)) + self.agent_name_id_mapping = dict(zip(self.possible_agents, range(num_players))) + + # the official spaces + self.obs_spaces = dict(zip(self.possible_agents, all_obs_spaces)) + self.act_spaces = dict(zip(self.possible_agents, all_act_spaces)) + + if self.render_mode == "human": + self.viewer = Viewer( + self._env.physics.model.ptr, self._env.physics.data.ptr + ) + + @functools.lru_cache(maxsize=None) + def observation_space(self, agent): + """The observation space for agent.""" + return self.obs_spaces[agent] + + @functools.lru_cache(maxsize=None) + def action_space(self, agent): + """The action space for agent.""" + return self.act_spaces[agent] + + def render(self): + """Renders the environment.""" + if self.render_mode is None: + gymnasium.logger.warn( + "You are calling render method without specifying any render mode." + ) + return + + def close(self): + """Closes the environment.""" + self._env.physics.free() + self._env.close() + + if hasattr(self, "viewer"): + self.viewer.close() + + def reset(self, seed=None, return_info=False, options=None): + """Resets the dm-control environment.""" + self.agents = self.possible_agents[:] + self.num_moves = 0 + + timestep = self._env.reset() + + observations, _, _, _, infos = _unravel_ma_timestep(timestep, self.agents) + + if not return_info: + return observations + else: + return observations, infos + + def step(self, actions): + """Steps through all agents with the actions.""" + # assert that the actions _must_ have actions for all agents + assert len(actions) == len( + self.agents + ), f"Must have actions for all {len(self.agents)} agents, currently only found {len(actions)}." + + actions = actions.values() + timestep = self._env.step(actions) + + obs, rewards, terminations, truncations, infos = _unravel_ma_timestep( + timestep, self.agents + ) + + if self.render_mode == "human": + self.viewer.render() + + if any(terminations.values()) or any(truncations.values()): + self.agents = [] + + return obs, rewards, terminations, truncations, infos diff --git a/shimmy/dm_lab_compatibility.py b/shimmy/dm_lab_compatibility.py new file mode 100644 index 00000000..7034aa6d --- /dev/null +++ b/shimmy/dm_lab_compatibility.py @@ -0,0 +1,81 @@ +"""Wrapper to convert a dm_lab environment into a gymnasium compatible environment.""" +from __future__ import annotations + +from typing import Any, TypeVar + +import gymnasium +import numpy as np +from gymnasium.core import ObsType + +from shimmy.utils.dm_lab import dm_lab_obs2gym_obs_space, dm_lab_spec2gym_space + + +class DmLabCompatibilityV0(gymnasium.Env[ObsType, np.ndarray]): + """A compatibility wrapper that converts a dm_lab-control environment into a gymnasium environment.""" + + metadata = {"render_modes": [], "render_fps": 10} + + def __init__( + self, + env: Any, + render_mode: None = None, + ): + """Initialises the environment with a render mode along with render information.""" + self._env = env + + # need to do this to figure out what observation spec the user used + self._env.reset() + self.observation_space = dm_lab_obs2gym_obs_space(self._env.observations()) + self.action_space = dm_lab_spec2gym_space(env.action_spec()) + + assert ( + render_mode is None + ), "Render mode must be set on dm_lab environment init. Pass `renderer='sdl'` to the config of the base env to enable human rendering." + self.render_mode = render_mode + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[ObsType, dict[str, Any]]: + """Resets the dm-lab environment.""" + super().reset(seed=seed) + + self._env.reset(seed=seed) + info = {} + + return ( + self._env.observations(), + info, + ) # pyright: ignore[reportGeneralTypeIssues] + + def step( + self, action: dict[str, np.ndarray] + ) -> tuple[ObsType, float, bool, bool, dict[str, Any]]: + """Steps through the dm-lab environment.""" + # there's some funky quantization happening here, dm_lab only accepts ints as actions + action = np.array([a[0] for a in action.values()], dtype=np.intc) + reward = self._env.step(action) + + obs = self._env.observations() + terminated = not self._env.is_running() + truncated = False + info = {} + + return ( # pyright: ignore[reportGeneralTypeIssues] + obs, + reward, + terminated, + truncated, + info, + ) + + def render(self) -> None: + """Renders the dm_lab env.""" + raise NotImplementedError + + def close(self): + """Closes the environment.""" + self._env.close() + + def __getattr__(self, item: str): + """If the attribute is missing, try getting the attribute from dm_lab env.""" + return getattr(self._env, item) diff --git a/shimmy/openspiel_wrapper.py b/shimmy/openspiel_compatibility.py similarity index 99% rename from shimmy/openspiel_wrapper.py rename to shimmy/openspiel_compatibility.py index 8272677c..c0056884 100644 --- a/shimmy/openspiel_wrapper.py +++ b/shimmy/openspiel_compatibility.py @@ -12,7 +12,7 @@ from pettingzoo.utils.env import AgentID -class OpenspielWrapperV0(pz.AECEnv): +class OpenspielCompatibilityV0(pz.AECEnv): """Wrapper that converts an openspiel environment into a pettingzoo environment.""" metadata = {"render_modes": []} diff --git a/shimmy/registration.py b/shimmy/registration.py index f10f24c1..7def9c73 100644 --- a/shimmy/registration.py +++ b/shimmy/registration.py @@ -23,11 +23,11 @@ def _register_dm_control_envs(): except ImportError: return - from shimmy.dm_control_compatibility import DmControlCompatibility + from shimmy.dm_control_compatibility import DmControlCompatibilityV0 # Add generic environment support def _make_dm_control_generic_env(env, **render_kwargs): - return DmControlCompatibility(env, **render_kwargs) + return DmControlCompatibilityV0(env, **render_kwargs) register("dm_control/compatibility-env-v0", _make_dm_control_generic_env) @@ -50,7 +50,7 @@ def _make_dm_control_suite_env( environment_kwargs=environment_kwargs, visualize_reward=visualize_reward, ) - return DmControlCompatibility(env, **render_kwargs) + return DmControlCompatibilityV0(env, **render_kwargs) for _domain_name, _task_name in DM_CONTROL_SUITE_ENVS: register( @@ -72,7 +72,7 @@ def _make_dm_control_example_locomotion_env( random_state: np.random.RandomState | None = None, **render_kwargs, ): - return DmControlCompatibility(env_fn(random_state), **render_kwargs) + return DmControlCompatibilityV0(env_fn(random_state), **render_kwargs) for locomotion_env, nondeterministic in ( (basic_cmu_2019.cmu_humanoid_run_walls, False), @@ -97,7 +97,7 @@ def _make_dm_control_example_locomotion_env( def _make_dm_control_manipulation_env(env_name: str, **render_kwargs): env = dm_control.manipulation.load(env_name) - return DmControlCompatibility(env, **render_kwargs) + return DmControlCompatibilityV0(env, **render_kwargs) for env_name in DM_CONTROL_MANIPULATION_ENVS: register( diff --git a/shimmy/utils/dm_env.py b/shimmy/utils/dm_env.py index cd238e67..d9ce367d 100644 --- a/shimmy/utils/dm_env.py +++ b/shimmy/utils/dm_env.py @@ -8,6 +8,7 @@ import numpy as np from dm_env.specs import Array, BoundedArray, DiscreteArray from gymnasium import spaces +from gymnasium.core import ObsType def dm_spec2gym_space(spec) -> spaces.Space[Any]: @@ -28,6 +29,9 @@ def dm_spec2gym_space(spec) -> spaces.Space[Any]: elif np.issubdtype(spec.dtype, np.inexact): low = float("-inf") high = float("inf") + elif spec.dtype == "bool": + low = int(0) + high = int(1) else: raise ValueError(f"Unknown dtype {spec.dtype} for spec {spec}.") @@ -46,3 +50,32 @@ def dm_obs2gym_obs(obs) -> np.ndarray | dict[str, Any]: return {key: dm_obs2gym_obs(value) for key, value in copy.copy(obs).items()} else: return np.asarray(obs) + + +def dm_control_step2gym_step( + timestep, +) -> tuple[ObsType, float, bool, bool, dict[str, Any]]: + """Opens up the timestep to return obs, reward, terminated, truncated, info.""" + obs = dm_obs2gym_obs(timestep.observation) + reward = timestep.reward or 0 + + # set terminated and truncated + terminated, truncated = False, False + if timestep.last(): + if timestep.discount == 0: + truncated = True + else: + terminated = True + + info = { + "timestep.discount": timestep.discount, + "timestep.step_type": timestep.step_type, + } + + return ( # pyright: ignore[reportGeneralTypeIssues] + obs, + reward, + terminated, + truncated, + info, + ) diff --git a/shimmy/utils/dm_lab.py b/shimmy/utils/dm_lab.py new file mode 100644 index 00000000..3db5ad64 --- /dev/null +++ b/shimmy/utils/dm_lab.py @@ -0,0 +1,68 @@ +"""Utility functions for the compatibility wrappers.""" +from __future__ import annotations + +from collections import OrderedDict +from typing import Any + +import numpy as np +from gymnasium import spaces + + +def dm_lab_obs2gym_obs_space(observation: dict) -> spaces.Space[Any]: + """Gets the observation spec from a single observation.""" + assert isinstance( + observation, (OrderedDict, dict) + ), f"Observation must be a dict, got {observation}" + + all_spaces = dict() + for key, value in observation.items(): + dtype = value.dtype + + low = None + high = None + if np.issubdtype(dtype, np.integer): + low = np.iinfo(dtype).min + high = np.iinfo(dtype).max + elif np.issubdtype(dtype, np.inexact): + low = float("-inf") + high = float("inf") + else: + raise ValueError(f"Unknown dtype {dtype}.") + + all_spaces[key] = spaces.Box(low=low, high=high, shape=value.shape, dtype=dtype) + + return spaces.Dict(all_spaces) + + +def dm_lab_spec2gym_space(spec) -> spaces.Space[Any]: + """Converts a dm_lab spec to a gymnasium space.""" + if isinstance(spec, list): + expanded = {} + for desc in spec: + assert ( + "name" in desc + ), f"Can't find name for the description: {desc} in spec." + + # some observation spaces have a string description, we ignore those for now + if "dtype" in desc: + if desc["dtype"] == str: + continue + + expanded[desc["name"]] = dm_lab_spec2gym_space(desc) + + return spaces.Dict(expanded) + if isinstance(spec, (OrderedDict, dict)): + # this is an action space + if "min" in spec and "max" in spec: + return spaces.Box(low=spec["min"], high=spec["max"], dtype=np.float64) + + # we dk wtf it is here + else: + raise NotImplementedError( + f"Unknown spec definition: {spec}, please report." + ) + + else: + raise NotImplementedError( + f"Cannot convert dm_spec to gymnasium space, unknown spec: {spec}, please report." + ) diff --git a/tests/test_dm_control.py b/tests/test_dm_control.py index 5e31faae..eb80df35 100644 --- a/tests/test_dm_control.py +++ b/tests/test_dm_control.py @@ -18,7 +18,7 @@ from gymnasium.error import Error from gymnasium.utils.env_checker import check_env, data_equivalence -from shimmy.dm_control_compatibility import DmControlCompatibility +from shimmy.dm_control_compatibility import DmControlCompatibilityV0 from shimmy.registration import DM_CONTROL_SUITE_ENVS DM_CONTROL_ENV_IDS = [ @@ -155,7 +155,7 @@ def test_dm_control_wrappers( ): return wrapped_env = wrapper_fn(dm_control_env) - env = DmControlCompatibility(wrapped_env) + env = DmControlCompatibilityV0(wrapped_env) with warnings.catch_warnings(record=True) as caught_warnings: check_env(env) diff --git a/tests/test_dm_control_multiagent.py b/tests/test_dm_control_multiagent.py new file mode 100644 index 00000000..d2eaaac8 --- /dev/null +++ b/tests/test_dm_control_multiagent.py @@ -0,0 +1,36 @@ +"""Tests the multi-agent dm-control soccer environment.""" + +import gymnasium +import pytest +from dm_control.locomotion import soccer as dm_soccer +from gymnasium.utils.env_checker import data_equivalence +from pettingzoo.test import parallel_api_test + +from shimmy.dm_control_multiagent_compatibility import ( + DmControlMultiAgentCompatibilityV0, +) + +WALKER_TYPES = [ + dm_soccer.WalkerType.BOXHEAD, + dm_soccer.WalkerType.ANT, + dm_soccer.WalkerType.HUMANOID, +] + + +@pytest.mark.parametrize("walker_type", WALKER_TYPES) +def test_check_env(walker_type): + """Check that environment pass the pettingzoo check_env.""" + env = dm_soccer.load( + team_size=2, + time_limit=10.0, + disable_walker_contacts=False, + enable_field_box=True, + terminate_on_goal=False, + walker_type=walker_type, + ) + + env = DmControlMultiAgentCompatibilityV0(env) + + parallel_api_test(env) + + env.close() diff --git a/tests/test_openspiel.py b/tests/test_openspiel.py index 30fc8b5f..3f6f2735 100644 --- a/tests/test_openspiel.py +++ b/tests/test_openspiel.py @@ -4,7 +4,7 @@ import pyspiel import pytest -from shimmy import OpenspielWrapperV0 +from shimmy.openspiel_compatibility import OpenspielCompatibilityV0 # todo add api_test however chess causes a OOM error # from pettingzoo.test import api_test @@ -116,7 +116,7 @@ def test_passing_games(game): """Tests the conversion of all openspiel envs.""" for _ in range(5): env = pyspiel.load_game(game) - env = OpenspielWrapperV0(game=env, render_mode=None) + env = OpenspielCompatibilityV0(game=env, render_mode=None) # api test the env # api_test(env) @@ -131,19 +131,20 @@ def test_passing_games(game): @pytest.mark.parametrize("game", _FAILING_GAMES) def test_failing_games(game): """Ensures that failing games are still failing.""" - with pytest.raises((pyspiel.SpielError, NotImplementedError)): + with pytest.raises(pyspiel.SpielError): test_passing_games(game) -def test_seeding(): +@pytest.mark.parametrize("game", _PASSING_GAMES) +def test_seeding(game): """Tests the seeding of the openspiel conversion wrapper.""" # load envs - env1 = pyspiel.load_game("2048") - env2 = pyspiel.load_game("2048") + env1 = pyspiel.load_game(game) + env2 = pyspiel.load_game(game) # convert the environment - env1 = OpenspielWrapperV0(env1, render_mode=None) - env2 = OpenspielWrapperV0(env2, render_mode=None) + env1 = OpenspielCompatibilityV0(env1, render_mode=None) + env2 = OpenspielCompatibilityV0(env2, render_mode=None) env1.reset(seed=42) env2.reset(seed=42) @@ -173,3 +174,5 @@ def test_seeding(): assert stuff1 == stuff2, "Incorrect returns on iteration." elif isinstance(stuff1, np.ndarray): assert (stuff1 == stuff2).all(), "Incorrect returns on iteration." + elif isinstance(stuff1, str): + assert stuff1 == stuff2, "Incorrect returns on iteration."