Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dm-lab (no tests yet) and dm-control-multiagent #7

Merged
merged 27 commits into from
Dec 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_version():
"gym": ["gym>=0.26"],
"atari": ["ale-py~=0.8.0"],
# "imageio" should be "gymnasium[mujoco]>=0.26" but there are install conflicts
"dm-control": ["dm-control>=1.0.8", "imageio"],
"dm-control": ["dm-control>=1.0.8", "imageio", "h5py>=3.7.0"],
"openspiel": ["open_spiel>=1.2", "pettingzoo>=1.22"],
}
extras["all"] = list({lib for libs in extras.values() for lib in libs})
Expand Down
18 changes: 14 additions & 4 deletions shimmy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
"""API for converting popular non-gymnasium environments to a gymnasium compatible environment."""

__version__ = "0.1.0"
__version__ = "0.2.0"


try:
from shimmy.dm_control_compatibility import DmControlCompatibility
from shimmy.dm_control_compatibility import DmControlCompatibilityV0
except ImportError:
pass

try:
from shimmy.openspiel_wrapper import OpenspielWrapperV0
from shimmy.dm_control_multiagent_compatibility import (
DmControlMultiAgentCompatibilityV0,
)
except ImportError:
pass

from shimmy.openai_gym_compatibility import GymV22Compatibility, GymV26Compatibility
try:
from shimmy.openspiel_compatibility import OpenspielCompatibilityV0
except ImportError:
pass

try:
from shimmy.dm_lab_compatibility import DmLabCompatibilityV0
except ImportError:
pass

jjshoots marked this conversation as resolved.
Show resolved Hide resolved
__all__ = [
"DmControlCompatibility",
Expand Down
30 changes: 6 additions & 24 deletions shimmy/dm_control_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from gymnasium.core import ObsType
from gymnasium.envs.mujoco.mujoco_rendering import Viewer

from shimmy.utils.dm_env import dm_obs2gym_obs, dm_spec2gym_space
from shimmy.utils.dm_env import dm_control_step2gym_step, dm_spec2gym_space


class EnvType(Enum):
Expand All @@ -27,7 +27,7 @@ class EnvType(Enum):
RL_CONTROL = 1


class DmControlCompatibility(gymnasium.Env[ObsType, np.ndarray]):
class DmControlCompatibilityV0(gymnasium.Env[ObsType, np.ndarray]):
"""A compatibility wrapper that converts a dm-control environment into a gymnasium environment.

Dm-control actually has two Environments classes, `dm_control.composer.Environment` and
Expand Down Expand Up @@ -81,36 +81,18 @@ def reset(
self.np_random = np.random.RandomState(seed=seed)

timestep = self._env.reset()
obs = dm_obs2gym_obs(timestep.observation)
info = {
"timestep.discount": timestep.discount,
"timestep.step_type": timestep.step_type,
}

obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep)

return obs, info # pyright: ignore[reportGeneralTypeIssues]

def step(
self, action: np.ndarray
) -> tuple[ObsType, float, bool, bool, dict[str, Any]]:
"""Steps through the dm-control environment."""
# Step through the dm-control environment
timestep = self._env.step(action)

# open up the timestep and process reward and observation
obs = dm_obs2gym_obs(timestep.observation)
reward = timestep.reward or 0

# set terminated and truncated
terminated, truncated = False, False
if timestep.last():
if timestep.discount == 0:
truncated = True
else:
terminated = True

info = {
"timestep.discount": timestep.discount,
"timestep.step_type": timestep.step_type,
}
obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep)

if self.render_mode == "human":
self.viewer.render()
Expand Down
167 changes: 167 additions & 0 deletions shimmy/dm_control_multiagent_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""Wrapper to convert a dm_env multiagent environment into a pettingzoo compatible environment."""
from __future__ import annotations

import functools
from itertools import repeat
from typing import Any

import dm_control.composer
import dm_env
import gymnasium
from gymnasium.envs.mujoco.mujoco_rendering import Viewer
from pettingzoo import ParallelEnv

from shimmy.utils.dm_env import dm_obs2gym_obs, dm_spec2gym_space


def _unravel_ma_timestep(
    timestep: dm_env.TimeStep, agents: list[str]
) -> tuple[
    dict[str, Any],
    dict[str, float],
    dict[str, bool],
    dict[str, bool],
    dict[str, Any],
]:
    """Opens up the timestep to return obs, reward, terminated, truncated, info.

    Args:
        timestep (dm_env.TimeStep): multi-agent timestep; its `observation` (and
            `reward`, when present) are paired positionally with `agents`.
        agents (list[str]): agent names, in the same order as the per-agent
            entries of the timestep.

    Returns:
        A tuple `(observations, rewards, terminations, truncations, infos)`,
        each a dict keyed by agent name. The termination/truncation flags are
        identical for every agent, and every agent gets its own info dict.
    """
    # dm_env signals a real termination with a discount of 0 on the LAST step;
    # a LAST step with any other discount is a truncation (e.g. a time limit).
    is_last = timestep.last()
    term = bool(is_last and timestep.discount == 0)
    trunc = bool(is_last and timestep.discount != 0)

    # expand the observations, one converted observation per agent
    observations: dict[str, Any] = {
        agent: dm_obs2gym_obs(obs)
        for agent, obs in zip(agents, timestep.observation)
    }

    # sometimes deepmind decides not to reward people (missing/empty reward),
    # in which case every agent receives 0.0
    rewards: dict[str, float] = dict.fromkeys(agents, 0.0)
    if timestep.reward:
        rewards = dict(zip(agents, timestep.reward))

    # the episode-end flags are shared by all agents
    terminations: dict[str, bool] = dict.fromkeys(agents, term)
    truncations: dict[str, bool] = dict.fromkeys(agents, trunc)

    # build a separate info dict per agent so that mutating one agent's info
    # cannot leak into the others
    infos: dict[str, Any] = {
        agent: {
            "timestep.discount": timestep.discount,
            "timestep.step_type": timestep.step_type,
        }
        for agent in agents
    }

    return (
        observations,
        rewards,
        terminations,
        truncations,
        infos,
    )


class DmControlMultiAgentCompatibilityV0(ParallelEnv):
    """Compatibility environment for multi-agent dm-control environments, primarily soccer."""

    metadata = {"render_modes": ["human"]}

    def __init__(
        self,
        env: dm_control.composer.Environment,
        render_mode: str | None = None,
    ):
        """Wrapper that converts a dm control multi-agent environment into a pettingzoo environment.

        Due to how the underlying environment is setup, this environment is nondeterministic, so seeding doesn't work.

        Args:
            env (dm_env.Environment): dm control multi-agent environment
            render_mode (Optional[str]): render_mode
        """
        super().__init__()
        self._env = env
        self.render_mode = render_mode

        # get action and observation specs first; both are per-player sequences
        all_obs_spaces = [
            dm_spec2gym_space(spec) for spec in self._env.observation_spec()
        ]
        all_act_spaces = [dm_spec2gym_space(spec) for spec in self._env.action_spec()]
        num_players = len(all_obs_spaces)

        # agent definitions: "player_<index>", plus index <-> name lookups
        self.possible_agents = [f"player_{idx}" for idx in range(num_players)]
        self.agent_id_name_mapping = dict(enumerate(self.possible_agents))
        self.agent_name_id_mapping = {
            name: idx for idx, name in enumerate(self.possible_agents)
        }

        # the official spaces
        self.obs_spaces = dict(zip(self.possible_agents, all_obs_spaces))
        self.act_spaces = dict(zip(self.possible_agents, all_act_spaces))

        if self.render_mode == "human":
            self.viewer = Viewer(
                self._env.physics.model.ptr, self._env.physics.data.ptr
            )

    # NOTE: no `functools.lru_cache` on these two accessors: caching on an
    # instance method keys on `self` and keeps every environment alive for the
    # lifetime of the cache, while the lookup below already returns the same
    # space object on every call.
    def observation_space(self, agent):
        """The observation space for agent."""
        return self.obs_spaces[agent]

    def action_space(self, agent):
        """The action space for agent."""
        return self.act_spaces[agent]

    def render(self):
        """Renders the environment.

        Only warns when no render mode was given; in "human" mode the viewer
        is updated from `step` instead.
        """
        if self.render_mode is None:
            gymnasium.logger.warn(
                "You are calling render method without specifying any render mode."
            )
            return

    def close(self):
        """Closes the environment."""
        self._env.physics.free()
        self._env.close()

        # the viewer only exists when the env was built with render_mode="human"
        if hasattr(self, "viewer"):
            self.viewer.close()

    def reset(self, seed=None, return_info=False, options=None):
        """Resets the dm-control environment.

        Args:
            seed: unused, the underlying environment is not seedable from here
            return_info (bool): when True, also return the per-agent infos
            options: unused

        Returns:
            The per-agent observations, or `(observations, infos)` when
            `return_info` is True.
        """
        self.agents = self.possible_agents[:]
        self.num_moves = 0

        timestep = self._env.reset()

        observations, _, _, _, infos = _unravel_ma_timestep(timestep, self.agents)

        if return_info:
            return observations, infos
        return observations

    def step(self, actions):
        """Steps through all agents with the actions.

        Args:
            actions (dict): action for every live agent, keyed by agent name

        Returns:
            Per-agent dicts `(obs, rewards, terminations, truncations, infos)`.
        """
        # assert that the actions _must_ have actions for all agents
        assert len(actions) == len(
            self.agents
        ), f"Must have actions for all {len(self.agents)} agents, currently only found {len(actions)}."

        # The underlying env consumes actions positionally, so order them by
        # agent rather than by the caller's dict insertion order; otherwise a
        # differently-ordered dict would drive the wrong players.
        ordered_actions = [actions[agent] for agent in self.agents]
        timestep = self._env.step(ordered_actions)

        obs, rewards, terminations, truncations, infos = _unravel_ma_timestep(
            timestep, self.agents
        )

        if self.render_mode == "human":
            self.viewer.render()

        # the episode ends for everyone at once, so clear the live-agent list
        if any(terminations.values()) or any(truncations.values()):
            self.agents = []

        return obs, rewards, terminations, truncations, infos
81 changes: 81 additions & 0 deletions shimmy/dm_lab_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Wrapper to convert a dm_lab environment into a gymnasium compatible environment."""
from __future__ import annotations

from typing import Any, TypeVar

import gymnasium
import numpy as np
from gymnasium.core import ObsType

from shimmy.utils.dm_lab import dm_lab_obs2gym_obs_space, dm_lab_spec2gym_space


class DmLabCompatibilityV0(gymnasium.Env[ObsType, np.ndarray]):
    """A compatibility wrapper that converts a dm_lab environment into a gymnasium environment."""

    metadata = {"render_modes": [], "render_fps": 10}

    def __init__(
        self,
        env: Any,
        render_mode: None = None,
    ):
        """Initialises the environment with a render mode along with render information.

        Args:
            env (Any): the dm_lab environment to wrap
            render_mode (None): must be None; rendering is configured on the
                base dm_lab environment instead
        """
        self._env = env

        # need to do this to figure out what observation spec the user used
        self._env.reset()
        self.observation_space = dm_lab_obs2gym_obs_space(self._env.observations())
        self.action_space = dm_lab_spec2gym_space(env.action_spec())

        assert (
            render_mode is None
        ), "Render mode must be set on dm_lab environment init. Pass `renderer='sdl'` to the config of the base env to enable human rendering."
        self.render_mode = render_mode

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the dm-lab environment.

        Args:
            seed (int | None): forwarded to both gymnasium and the dm_lab env
            options (dict | None): unused

        Returns:
            The first observation and an empty info dict.
        """
        super().reset(seed=seed)

        self._env.reset(seed=seed)
        info = {}

        return (
            self._env.observations(),
            info,
        )  # pyright: ignore[reportGeneralTypeIssues]

    def step(
        self, action: dict[str, np.ndarray]
    ) -> tuple[ObsType, float, bool, bool, dict[str, Any]]:
        """Steps through the dm-lab environment.

        Args:
            action (dict[str, np.ndarray]): one array per action component;
                only the first element of each array is used

        Returns:
            `(obs, reward, terminated, truncated, info)`; `truncated` is always
            False and `info` is always empty.
        """
        # there's some funky quantization happening here, dm_lab only accepts ints as actions
        action = np.array([a[0] for a in action.values()], dtype=np.intc)
        reward = self._env.step(action)

        # NOTE(review): observations are read even when the step just ended the
        # episode -- confirm dm_lab allows `observations()` once `is_running()`
        # is False.
        obs = self._env.observations()
        terminated = not self._env.is_running()
        truncated = False
        info = {}

        return (  # pyright: ignore[reportGeneralTypeIssues]
            obs,
            reward,
            terminated,
            truncated,
            info,
        )

    def render(self) -> None:
        """Renders the dm_lab env."""
        raise NotImplementedError

    def close(self):
        """Closes the environment."""
        self._env.close()

    def __getattr__(self, item: str):
        """If the attribute is missing, try getting the attribute from dm_lab env."""
        # `__getattr__` only runs when normal lookup fails, so being asked for
        # `_env` means it was never set (e.g. an instance created without
        # running `__init__`, as copy/pickle do). Delegating in that case would
        # recurse forever, so fail with a regular AttributeError instead.
        if item == "_env":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )
        return getattr(self._env, item)
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pettingzoo.utils.env import AgentID


class OpenspielWrapperV0(pz.AECEnv):
class OpenspielCompatibilityV0(pz.AECEnv):
"""Wrapper that converts an openspiel environment into a pettingzoo environment."""

metadata = {"render_modes": []}
Expand Down
Loading