Skip to content

Commit

Permalink
Add dm-lab (no tests yet) and dm-control-multiagent (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
jjshoots authored Dec 5, 2022
1 parent cda71f0 commit 93641f1
Show file tree
Hide file tree
Showing 12 changed files with 425 additions and 45 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_version():
"gym": ["gym>=0.26"],
"atari": ["ale-py~=0.8.0"],
# "imageio" should be "gymnasium[mujoco]>=0.26" but there are install conflicts
"dm-control": ["dm-control>=1.0.8", "imageio"],
"dm-control": ["dm-control>=1.0.8", "imageio", "h5py>=3.7.0"],
"openspiel": ["open_spiel>=1.2", "pettingzoo>=1.22"],
}
extras["all"] = list({lib for libs in extras.values() for lib in libs})
Expand Down
18 changes: 14 additions & 4 deletions shimmy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
"""API for converting popular non-gymnasium environments to a gymnasium compatible environment."""

__version__ = "0.1.0"
__version__ = "0.2.0"


try:
from shimmy.dm_control_compatibility import DmControlCompatibility
from shimmy.dm_control_compatibility import DmControlCompatibilityV0
except ImportError:
pass

try:
from shimmy.openspiel_wrapper import OpenspielWrapperV0
from shimmy.dm_control_multiagent_compatibility import (
DmControlMultiAgentCompatibilityV0,
)
except ImportError:
pass

from shimmy.openai_gym_compatibility import GymV22Compatibility, GymV26Compatibility
try:
from shimmy.openspiel_compatibility import OpenspielCompatibilityV0
except ImportError:
pass

try:
from shimmy.dm_lab_compatibility import DmLabCompatibilityV0
except ImportError:
pass

__all__ = [
"DmControlCompatibility",
Expand Down
30 changes: 6 additions & 24 deletions shimmy/dm_control_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from gymnasium.core import ObsType
from gymnasium.envs.mujoco.mujoco_rendering import Viewer

from shimmy.utils.dm_env import dm_obs2gym_obs, dm_spec2gym_space
from shimmy.utils.dm_env import dm_control_step2gym_step, dm_spec2gym_space


class EnvType(Enum):
Expand All @@ -27,7 +27,7 @@ class EnvType(Enum):
RL_CONTROL = 1


class DmControlCompatibility(gymnasium.Env[ObsType, np.ndarray]):
class DmControlCompatibilityV0(gymnasium.Env[ObsType, np.ndarray]):
"""A compatibility wrapper that converts a dm-control environment into a gymnasium environment.
Dm-control actually has two Environments classes, `dm_control.composer.Environment` and
Expand Down Expand Up @@ -81,36 +81,18 @@ def reset(
self.np_random = np.random.RandomState(seed=seed)

timestep = self._env.reset()
obs = dm_obs2gym_obs(timestep.observation)
info = {
"timestep.discount": timestep.discount,
"timestep.step_type": timestep.step_type,
}

obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep)

return obs, info # pyright: ignore[reportGeneralTypeIssues]

def step(
self, action: np.ndarray
) -> tuple[ObsType, float, bool, bool, dict[str, Any]]:
"""Steps through the dm-control environment."""
# Step through the dm-control environment
timestep = self._env.step(action)

# open up the timestep and process reward and observation
obs = dm_obs2gym_obs(timestep.observation)
reward = timestep.reward or 0

# set terminated and truncated
terminated, truncated = False, False
if timestep.last():
if timestep.discount == 0:
truncated = True
else:
terminated = True

info = {
"timestep.discount": timestep.discount,
"timestep.step_type": timestep.step_type,
}
obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep)

if self.render_mode == "human":
self.viewer.render()
Expand Down
167 changes: 167 additions & 0 deletions shimmy/dm_control_multiagent_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""Wrapper to convert a dm_env multiagent environment into a pettingzoo compatible environment."""
from __future__ import annotations

import functools
from itertools import repeat
from typing import Any

import dm_control.composer
import dm_env
import gymnasium
from gymnasium.envs.mujoco.mujoco_rendering import Viewer
from pettingzoo import ParallelEnv

from shimmy.utils.dm_env import dm_obs2gym_obs, dm_spec2gym_space


def _unravel_ma_timestep(
timestep: dm_env.TimeStep, agents: list[str]
) -> tuple[
dict[str, Any],
dict[str, float],
dict[str, bool],
dict[str, bool],
dict[str, Any],
]:
"""Opens up the timestep to return obs, reward, terminated, truncated, info."""
# set terminated and truncated
term, trunc = False, False
if timestep.last():
if timestep.discount == 0:
trunc = True
else:
term = True

# expand the observations
list_observations = [dm_obs2gym_obs(obs) for obs in timestep.observation]
observations: dict[str, Any] = dict(zip(agents, list_observations))

# sometimes deepmind decides not to reward people
rewards: dict[str, float] = dict(zip(agents, repeat(0.0)))
if timestep.reward:
rewards = dict(zip(agents, timestep.reward))

# expand everything else
terminations: dict[str, bool] = dict(zip(agents, repeat(term)))
truncations: dict[str, bool] = dict(zip(agents, repeat(trunc)))

# duplicate infos
info = {
"timestep.discount": timestep.discount,
"timestep.step_type": timestep.step_type,
}
info: dict[str, Any] = dict(zip(agents, repeat(info)))

return (
observations,
rewards,
terminations,
truncations,
info,
)


class DmControlMultiAgentCompatibilityV0(ParallelEnv):
    """Compatibility environment for multi-agent dm-control environments, primarily soccer."""

    metadata = {"render_modes": ["human"]}

    def __init__(
        self,
        env: dm_control.composer.Environment,
        render_mode: str | None = None,
    ):
        """Wrapper that converts a dm control multi-agent environment into a pettingzoo environment.

        Due to how the underlying environment is setup, this environment is nondeterministic, so seeding doesn't work.

        Args:
            env (dm_env.Environment): dm control multi-agent environment
            render_mode (Optional[str]): render_mode
        """
        super().__init__()
        self._env = env
        self.render_mode = render_mode

        # get action and observation specs first; the env exposes one spec per player
        all_obs_spaces = [
            dm_spec2gym_space(spec) for spec in self._env.observation_spec()
        ]
        all_act_spaces = [dm_spec2gym_space(spec) for spec in self._env.action_spec()]
        num_players = len(all_obs_spaces)

        # agent definitions
        self.possible_agents = ["player_" + str(r) for r in range(num_players)]
        self.agent_id_name_mapping = dict(zip(range(num_players), self.possible_agents))
        self.agent_name_id_mapping = dict(zip(self.possible_agents, range(num_players)))

        # the official spaces
        self.obs_spaces = dict(zip(self.possible_agents, all_obs_spaces))
        self.act_spaces = dict(zip(self.possible_agents, all_act_spaces))

        # the viewer is only created when human rendering was requested
        if self.render_mode == "human":
            self.viewer = Viewer(
                self._env.physics.model.ptr, self._env.physics.data.ptr
            )

    # NOTE: no `lru_cache` here. The spaces are built once in `__init__`, so a
    # plain lookup already returns the same object on every call, and caching
    # a bound method would keep every environment instance alive in the cache.
    def observation_space(self, agent):
        """The observation space for agent."""
        return self.obs_spaces[agent]

    def action_space(self, agent):
        """The action space for agent."""
        return self.act_spaces[agent]

    def render(self):
        """Renders the environment.

        Only warns when no render mode was specified; in human mode the
        viewer is refreshed from `step` instead.
        """
        if self.render_mode is None:
            gymnasium.logger.warn(
                "You are calling render method without specifying any render mode."
            )
            return

    def close(self):
        """Closes the environment, and the viewer if one was created."""
        self._env.physics.free()
        self._env.close()

        if hasattr(self, "viewer"):
            self.viewer.close()

    def reset(self, seed=None, return_info=False, options=None):
        """Resets the dm-control environment.

        Args:
            seed: unused, the underlying environment cannot be seeded from here.
            return_info: whether to also return the per-agent info dictionaries.
            options: unused.

        Returns:
            The per-agent observations, or `(observations, infos)` when
            `return_info` is true.
        """
        self.agents = self.possible_agents[:]
        self.num_moves = 0

        timestep = self._env.reset()

        observations, _, _, _, infos = _unravel_ma_timestep(timestep, self.agents)

        if not return_info:
            return observations
        else:
            return observations, infos

    def step(self, actions):
        """Steps through all agents with the actions.

        Args:
            actions: dictionary mapping every live agent name to its action.

        Returns:
            `(observations, rewards, terminations, truncations, infos)`, each a
            dictionary keyed by agent name.
        """
        # assert that the actions _must_ have actions for all agents
        assert len(actions) == len(
            self.agents
        ), f"Must have actions for all {len(self.agents)} agents, currently only found {len(actions)}."

        # hand the actions over in agent order; relying on the dictionary's
        # own ordering would give players the wrong action whenever the
        # caller built the dictionary in a different order
        ordered_actions = [actions[agent] for agent in self.agents]
        timestep = self._env.step(ordered_actions)

        obs, rewards, terminations, truncations, infos = _unravel_ma_timestep(
            timestep, self.agents
        )

        if self.render_mode == "human":
            self.viewer.render()

        # the episode ends for everybody at once, so drop all agents together
        if any(terminations.values()) or any(truncations.values()):
            self.agents = []

        return obs, rewards, terminations, truncations, infos
81 changes: 81 additions & 0 deletions shimmy/dm_lab_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Wrapper to convert a dm_lab environment into a gymnasium compatible environment."""
from __future__ import annotations

from typing import Any, TypeVar

import gymnasium
import numpy as np
from gymnasium.core import ObsType

from shimmy.utils.dm_lab import dm_lab_obs2gym_obs_space, dm_lab_spec2gym_space


class DmLabCompatibilityV0(gymnasium.Env[ObsType, np.ndarray]):
    """A compatibility wrapper that converts a dm_lab environment into a gymnasium environment."""

    metadata = {"render_modes": [], "render_fps": 10}

    def __init__(
        self,
        env: Any,
        render_mode: None = None,
    ):
        """Initialises the environment with a render mode along with render information.

        Args:
            env: the dm_lab environment to wrap.
            render_mode: must be `None`; rendering has to be configured on the
                dm_lab environment itself.
        """
        self._env = env

        # need to do this to figure out what observation spec the user used
        self._env.reset()
        self.observation_space = dm_lab_obs2gym_obs_space(self._env.observations())
        self.action_space = dm_lab_spec2gym_space(env.action_spec())

        assert (
            render_mode is None
        ), "Render mode must be set on dm_lab environment init. Pass `renderer='sdl'` to the config of the base env to enable human rendering."
        self.render_mode = render_mode

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the dm-lab environment.

        Args:
            seed: forwarded to both gymnasium's `reset` and the dm_lab `reset`.
            options: unused.

        Returns:
            The current observations of the dm_lab environment and an empty
            info dictionary.
        """
        super().reset(seed=seed)

        self._env.reset(seed=seed)
        info = {}

        return (
            self._env.observations(),
            info,
        )  # pyright: ignore[reportGeneralTypeIssues]

    def step(
        self, action: dict[str, np.ndarray]
    ) -> tuple[ObsType, float, bool, bool, dict[str, Any]]:
        """Steps through the dm-lab environment.

        Args:
            action: dictionary of per-action arrays; the first element of each
                value is taken, in the dictionary's iteration order.

        Returns:
            `(obs, reward, terminated, truncated, info)`; `truncated` is always
            `False` and `info` is always empty.
        """
        # there's some funky quantization happening here, dm_lab only accepts ints as actions
        action = np.array([a[0] for a in action.values()], dtype=np.intc)
        reward = self._env.step(action)

        obs = self._env.observations()
        terminated = not self._env.is_running()
        truncated = False
        info = {}

        return (  # pyright: ignore[reportGeneralTypeIssues]
            obs,
            reward,
            terminated,
            truncated,
            info,
        )

    def render(self) -> None:
        """Renders the dm_lab env."""
        raise NotImplementedError

    def close(self):
        """Closes the environment."""
        self._env.close()

    def __getattr__(self, item: str):
        """If the attribute is missing, try getting the attribute from dm_lab env."""
        # `__getattr__` only runs after normal lookup has failed, so being
        # asked for `_env` means it was never set (e.g. while the object is
        # being copied or unpickled). Forwarding would look up `self._env`
        # again and recurse until the interpreter gives up.
        if item == "_env":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )
        return getattr(self._env, item)
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pettingzoo.utils.env import AgentID


class OpenspielWrapperV0(pz.AECEnv):
class OpenspielCompatibilityV0(pz.AECEnv):
"""Wrapper that converts an openspiel environment into a pettingzoo environment."""

metadata = {"render_modes": []}
Expand Down
Loading

0 comments on commit 93641f1

Please sign in to comment.