diff --git a/setup.py b/setup.py index e6e2fb4c..38698d74 100644 --- a/setup.py +++ b/setup.py @@ -37,6 +37,7 @@ def get_version(): "atari": ["ale-py~=0.8.0"], # "imageio" should be "gymnasium[mujoco]>=0.26" but there are install conflicts "dm-control": ["dm-control>=1.0.8", "imageio", "h5py>=3.7.0"], + "dm-control-multi-agent": ["dm-control>=1.0.8", "pettingzoo>=1.22"], "openspiel": ["open_spiel>=1.2", "pettingzoo>=1.22"], } extras["all"] = list({lib for libs in extras.values() for lib in libs}) @@ -63,10 +64,11 @@ def get_version(): tests_require=extras["testing"], extras_require=extras, classifiers=[ - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ], diff --git a/shimmy/__init__.py b/shimmy/__init__.py index 0a3ca8fc..37cc3d2b 100644 --- a/shimmy/__init__.py +++ b/shimmy/__init__.py @@ -1,29 +1,52 @@ """API for converting popular non-gymnasium environments to a gymnasium compatible environment.""" +from __future__ import annotations + +from typing import Any + +from shimmy.dm_lab_compatibility import DmLabCompatibilityV0 +from shimmy.openai_gym_compatibility import GymV22CompatibilityV0, GymV26CompatibilityV0 __version__ = "0.2.0" +class NotInstallClass: + """Rather than an attribute error, this raises a more helpful import error with install instructions for shimmy.""" + + def __init__(self, install_message: str, import_exception: ImportError): + self.install_message = install_message + self.import_exception = import_exception + + def __call__(self, *args: list[Any], **kwargs: Any): + """Acts like the `__init__` for the class.""" + raise ImportError(self.install_message) from self.import_exception + + try: from shimmy.dm_control_compatibility import DmControlCompatibilityV0 -except ImportError: - pass +except ImportError as e: + DmControlCompatibilityV0 = NotInstallClass( + "Dm-control is not installed, run `pip install 'shimmy[dm-control]'`", e + ) + try: from shimmy.dm_control_multiagent_compatibility import ( DmControlMultiAgentCompatibilityV0, ) -except ImportError: - pass +except ImportError as e: + DmControlMultiAgentCompatibilityV0 = NotInstallClass( + "Dm-control or Pettingzoo is not installed, run `pip install 'shimmy[dm-control-multi-agent]'`", + e, + ) try: from shimmy.openspiel_compatibility import OpenspielCompatibilityV0 -except ImportError: - pass +except ImportError as e: + OpenspielCompatibilityV0 = NotInstallClass( + "Openspiel or Pettingzoo is not installed, run `pip install 'shimmy[openspiel]'`", + e, + ) -try: - from shimmy.dm_lab_compatibility import DmLabCompatibilityV0 -except ImportError: - pass __all__ = [ "DmControlCompatibilityV0", diff --git a/shimmy/atari_env.py b/shimmy/atari_env.py index 5ae8f905..a7ddef51 100644 --- a/shimmy/atari_env.py +++ b/shimmy/atari_env.py @@ -1,4 +1,13 @@ -"""ALE-py interface for atari.""" +"""ALE-py interface for atari. + +This file was originally copied from https://github.com/mgbellemare/Arcade-Learning-Environment/blob/master/src/python/env/gym.py +Under the GNU General Public License v2.0 + +Copyright is held by the authors + +Changes +* Added `self.render_mode` which is identical to `self._render_mode` +""" from __future__ import annotations import sys @@ -29,7 +38,7 @@ class AtariEnvStepMetadata(TypedDict): seeds: NotRequired[Sequence[int]] -class AtariEnv(gymnasium.Env[np.ndarray, int], EzPickle): +class AtariEnv(gymnasium.Env[np.ndarray, np.int64], EzPickle): """(A)rcade (L)earning (Gymnasium) (Env)ironment. A Gymnasium wrapper around the Arcade Learning Environment (ALE). diff --git a/shimmy/dm_control_compatibility.py b/shimmy/dm_control_compatibility.py index 90dd9956..b78af1d6 100644 --- a/shimmy/dm_control_compatibility.py +++ b/shimmy/dm_control_compatibility.py @@ -84,7 +84,7 @@ def reset( obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep) - return obs, info # pyright: ignore[reportGeneralTypeIssues] + return obs, info def step( self, action: np.ndarray @@ -95,9 +95,9 @@ def step( obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep) if self.render_mode == "human": - self.viewer.render() + self.viewer.render(self.render_mode) - return ( # pyright: ignore[reportGeneralTypeIssues] + return ( obs, reward, terminated, diff --git a/shimmy/dm_control_multiagent_compatibility.py b/shimmy/dm_control_multiagent_compatibility.py index a8a8cf12..1d3b0f01 100644 --- a/shimmy/dm_control_multiagent_compatibility.py +++ b/shimmy/dm_control_multiagent_compatibility.py @@ -159,7 +159,7 @@ def step(self, actions): ) if self.render_mode == "human": - self.viewer.render() + self.viewer.render(self.render_mode) if any(terminations.values()) or any(truncations.values()): self.agents = [] diff --git a/shimmy/dm_lab_compatibility.py b/shimmy/dm_lab_compatibility.py index 7034aa6d..1997b2d3 100644 --- a/shimmy/dm_lab_compatibility.py +++ b/shimmy/dm_lab_compatibility.py @@ -1,16 +1,16 @@ """Wrapper to convert a dm_lab environment into a gymnasium compatible environment.""" from __future__ import annotations -from typing import Any, TypeVar +from typing import Any, Dict -import gymnasium +import gymnasium as gym import numpy as np from gymnasium.core import ObsType from shimmy.utils.dm_lab import dm_lab_obs2gym_obs_space, dm_lab_spec2gym_space -class DmLabCompatibilityV0(gymnasium.Env[ObsType, np.ndarray]): +class DmLabCompatibilityV0(gym.Env[ObsType, Dict[str, np.ndarray]]): """A compatibility wrapper that converts a dm_lab-control environment into a gymnasium environment.""" metadata = {"render_modes": [], "render_fps": 10} @@ -45,22 +45,22 @@ def reset( return ( self._env.observations(), info, - ) # pyright: ignore[reportGeneralTypeIssues] + ) def step( self, action: dict[str, np.ndarray] ) -> tuple[ObsType, float, bool, bool, dict[str, Any]]: """Steps through the dm-lab environment.""" # there's some funky quantization happening here, dm_lab only accepts ints as actions - action = np.array([a[0] for a in action.values()], dtype=np.intc) - reward = self._env.step(action) + action_array = np.array([a[0] for a in action.values()], dtype=np.intc) + reward = self._env.step(action_array) obs = self._env.observations() terminated = not self._env.is_running() truncated = False info = {} - return ( # pyright: ignore[reportGeneralTypeIssues] + return ( obs, reward, terminated, diff --git a/shimmy/openai_gym_compatibility.py b/shimmy/openai_gym_compatibility.py index faa23d13..6fc03b6c 100644 --- a/shimmy/openai_gym_compatibility.py +++ b/shimmy/openai_gym_compatibility.py @@ -39,8 +39,8 @@ GYM_IMPORT_ERROR = None -class GymV26Compatibility(gymnasium.Env[ObsType, ActType]): - """Converts a gym v26 environment to a gymnasium environment.""" +class GymV26CompatibilityV0(gymnasium.Env[ObsType, ActType]): + """Converts a Gym v26 environment to a Gymnasium environment.""" def __init__( self, @@ -83,6 +83,10 @@ def __init__( self.reward_range = getattr(self.gym_env, "reward_range", None) self.spec = getattr(self.gym_env, "spec", None) + def __getattr__(self, item: str): + """Gets an attribute that only exists in the base environments.""" + return getattr(self.gym_env, item) + def reset( self, seed: int | None = None, options: dict | None = None ) -> tuple[ObsType, dict]: @@ -151,7 +155,7 @@ def seed(self, seed: int | None = None): ... -class GymV22Compatibility(gymnasium.Env[ObsType, ActType]): +class GymV22CompatibilityV0(gymnasium.Env[ObsType, ActType]): r"""A wrapper which can transform an environment from the old API to the new API. Old step API refers to step() method returning (observation, reward, done, info), and reset() only retuning the observation. @@ -201,6 +205,10 @@ def __init__( self.gym_env: LegacyV22Env = gym_env + def __getattr__(self, item: str): + """Gets an attribute that only exists in the base environments.""" + return getattr(self.gym_env, item) + def reset( self, seed: int | None = None, options: dict | None = None ) -> tuple[ObsType, dict]: @@ -236,9 +244,7 @@ def step(self, action: ActType) -> tuple[Any, float, bool, bool, dict]: if self.render_mode == "human": self.render() - return convert_to_terminated_truncated_step_api( - (obs, reward, done, info) - ) # pyright: ignore[reportGeneralTypeIssues] + return convert_to_terminated_truncated_step_api((obs, reward, done, info)) def render(self) -> Any: """Renders the environment. diff --git a/shimmy/registration.py b/shimmy/registration.py index 7def9c73..89e1bfef 100644 --- a/shimmy/registration.py +++ b/shimmy/registration.py @@ -227,14 +227,32 @@ def _register_atari_envs(): ) +def _register_dm_lab(): + try: + import deepmind_lab + except ImportError: + return + + from shimmy.dm_lab_compatibility import DmLabCompatibilityV0 + + def _make_dm_lab_env( + env_id: str, observations, config: dict[str, Any], renderer: str + ): + env = deepmind_lab.Lab(env_id, observations, config=config, renderer=renderer) + return DmLabCompatibilityV0(env) + + register("DmLabCompatibility-v0", _make_dm_lab_env) + + def register_gymnasium_envs(): """This function is called when gymnasium is imported.""" - _register_dm_control_envs() - _register_atari_envs() - register( - "GymV26Environment-v0", "shimmy.openai_gym_compatibility:GymV26Compatibility" + "GymV26Environment-v0", "shimmy.openai_gym_compatibility:GymV26CompatibilityV0" ) register( - "GymV22Environment-v0", "shimmy.openai_gym_compatibility:GymV22Compatibility" + "GymV22Environment-v0", "shimmy.openai_gym_compatibility:GymV22CompatibilityV0" ) + + _register_dm_control_envs() + _register_atari_envs() + _register_dm_lab() diff --git a/shimmy/utils/dm_env.py b/shimmy/utils/dm_env.py index d9ce367d..c8f34878 100644 --- a/shimmy/utils/dm_env.py +++ b/shimmy/utils/dm_env.py @@ -8,7 +8,6 @@ import numpy as np from dm_env.specs import Array, BoundedArray, DiscreteArray from gymnasium import spaces -from gymnasium.core import ObsType def dm_spec2gym_space(spec) -> spaces.Space[Any]: @@ -54,7 +53,7 @@ def dm_obs2gym_obs(obs) -> np.ndarray | dict[str, Any]: def dm_control_step2gym_step( timestep, -) -> tuple[ObsType, float, bool, bool, dict[str, Any]]: +) -> tuple[Any, float, bool, bool, dict[str, Any]]: """Opens up the timestep to return obs, reward, terminated, truncated, info.""" obs = dm_obs2gym_obs(timestep.observation) reward = timestep.reward or 0 @@ -72,7 +71,7 @@ def dm_control_step2gym_step( "timestep.step_type": timestep.step_type, } - return ( # pyright: ignore[reportGeneralTypeIssues] + return ( obs, reward, terminated, diff --git a/tests/test_gym.py b/tests/test_gym.py index b06a469a..5e7ccbc4 100644 --- a/tests/test_gym.py +++ b/tests/test_gym.py @@ -2,12 +2,15 @@ import warnings -import gym +import gym as openai_gym import gymnasium import pytest +from gym.spaces import Box as openai_Box from gymnasium.error import Error from gymnasium.utils.env_checker import check_env +from shimmy import GymV22CompatibilityV0, GymV26CompatibilityV0 + CHECK_ENV_IGNORE_WARNINGS = [ f"\x1b[33mWARN: {message}\x1b[0m" for message in [ @@ -21,7 +24,7 @@ # We do not test Atari environment's here because we check all variants of Pong in test_envs.py (There are too many Atari environments) CLASSIC_CONTROL_ENVS = [ env_id - for env_id, spec in gym.envs.registry.items() # pyright: ignore[reportGeneralTypeIssues] + for env_id, spec in openai_gym.envs.registry.items() # pyright: ignore[reportGeneralTypeIssues] if ("classic_control" in spec.entry_point) ] @@ -51,9 +54,11 @@ def test_gym_conversion_by_id(env_id): ) def test_gym_conversion_instantiated(env_id): """Tests that the gym conversion works with an instantiated gym environment.""" - env = gym.make(env_id) + env = openai_gym.make(env_id) env = gymnasium.make("GymV26Environment-v0", env=env).unwrapped + print("render-mode", env.render_mode) + print("render-modes", env.metadata) with warnings.catch_warnings(record=True) as caught_warnings: check_env(env, skip_render_check=True) @@ -65,3 +70,31 @@ def test_gym_conversion_instantiated(env_id): raise Error(f"Unexpected warning: {warning.message}") env.close() + + +class EnvWithData(openai_gym.Env): + """Environment with data that users might want to access.""" + + def __init__(self): + """Initialises the environment with hidden data.""" + self.observation_space = openai_Box(low=0, high=1) + self.action_space = openai_Box(low=0, high=1) + + self.data = 123 + + def get_env_data(self): + """Gets the environment data.""" + return self.data + + +def test_compatibility_get_attr(): + """Tests that the compatibility environment works with `__getattr__` for those attributes.""" + env = GymV22CompatibilityV0(env=EnvWithData()) + assert env.data == 123 + assert env.get_env_data() == 123 + env.close() + + env = GymV26CompatibilityV0(env=EnvWithData()) + assert env.data == 123 + assert env.get_env_data() == 123 + env.close()