Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dm-lab (no tests yet) and dm-control-multiagent #7

Merged
merged 27 commits into from
Dec 5, 2022
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion shimmy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,4 @@
except ImportError:
pass


jjshoots marked this conversation as resolved.
Show resolved Hide resolved
__version__ = "0.0.1a"
26 changes: 4 additions & 22 deletions shimmy/dm_control_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,36 +81,18 @@ def reset(
self.np_random = np.random.RandomState(seed=seed)

timestep = self._env.reset()
obs = dm_obs2gym_obs(timestep.observation)
info = {
"timestep.discount": timestep.discount,
"timestep.step_type": timestep.step_type,
}

obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep)

return obs, info # pyright: ignore[reportGeneralTypeIssues]

def step(
self, action: np.ndarray
) -> tuple[ObsType, float, bool, bool, dict[str, Any]]:
"""Steps through the dm-control environment."""
# Step through the dm-control environment
timestep = self._env.step(action)

# open up the timestep and process reward and observation
obs = dm_obs2gym_obs(timestep.observation)
reward = timestep.reward or 0

# set terminated and truncated
terminated, truncated = False, False
if timestep.last():
if timestep.discount == 0:
truncated = True
else:
terminated = True

info = {
"timestep.discount": timestep.discount,
"timestep.step_type": timestep.step_type,
}
obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep)

if self.render_mode == "human":
self.viewer.render()
Expand Down
81 changes: 81 additions & 0 deletions shimmy/dm_lab_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Wrapper to convert a dm_lab environment into a gymnasium compatible environment."""
from __future__ import annotations

from typing import Any, TypeVar

import gymnasium
import numpy as np
from gymnasium.core import ObsType

from shimmy.utils.dm_lab import dm_lab_obs2gym_obs_space, dm_lab_spec2gym_space


class DmLabCompatibility(gymnasium.Env[ObsType, np.ndarray]):
    """A compatibility wrapper that converts a dm_lab environment into a gymnasium environment.

    The wrapped environment is stored on ``self._env``; any attribute that is
    not found on the wrapper is looked up on the wrapped environment instead.
    """

    # No gymnasium-side render modes are offered: `__init__` rejects any
    # `render_mode` other than None. An empty list (rather than None) keeps
    # membership checks such as `mode in metadata["render_modes"]` valid.
    metadata = {"render_modes": [], "render_fps": 10}

    def __init__(
        self,
        env: Any,
        render_mode: str | None = None,
    ):
        """Initialises the environment with a render mode along with render information.

        Args:
            env: the dm_lab environment to wrap. It is reset once here so that
                an observation is available to derive the observation space.
            render_mode: must be None; rendering has to be configured on the
                base dm_lab environment instead.

        Raises:
            AssertionError: if ``render_mode`` is not None.
        """
        self._env = env

        # need to do this to figure out what observation spec the user used
        self._env.reset()
        self.observation_space = dm_lab_obs2gym_obs_space(self._env.observations())
        self.action_space = dm_lab_spec2gym_space(env.action_spec())

        assert (
            render_mode is None
        ), "Render mode must be set on dm_lab environment init. Pass `renderer='sdl'` to the config of the base env to enable human rendering."
        self.render_mode = render_mode

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the dm-lab environment.

        Args:
            seed: forwarded both to ``gymnasium.Env.reset`` and to the wrapped
                environment's ``reset``.
            options: accepted for gymnasium API compatibility; not used.

        Returns:
            The wrapped environment's current observations and an empty info dict.
        """
        super().reset(seed=seed)

        self._env.reset(seed=seed)
        info = {}

        return (
            self._env.observations(),
            info,
        )  # pyright: ignore[reportGeneralTypeIssues]

    def step(
        self, action: np.ndarray
    ) -> tuple[ObsType, float, bool, bool, dict[str, Any]]:
        """Steps through the dm-lab environment.

        Args:
            action: consumed through ``action.values()``, taking element ``[0]``
                of every value.
                NOTE(review): the annotation says ``np.ndarray`` but the body
                needs a mapping -- confirm and align the annotation.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` where ``terminated``
            is True once the wrapped environment is no longer running,
            ``truncated`` is always False and ``info`` is empty.
        """
        # there's some funky quantization happening here, dm_lab only accepts ints as actions
        action = np.array([a[0] for a in action.values()], dtype=np.intc)
        reward = self._env.step(action)

        obs = self._env.observations()
        terminated = not self._env.is_running()
        truncated = False
        info = {}

        return (  # pyright: ignore[reportGeneralTypeIssues]
            obs,
            reward,
            terminated,
            truncated,
            info,
        )

    def render(self) -> np.ndarray | None:
        """Renders the dm_lab env.

        Raises:
            NotImplementedError: always; rendering is not provided by this wrapper.
        """
        raise NotImplementedError

    def close(self):
        """Closes the environment."""
        self._env.close()

    def __getattr__(self, item: str):
        """If the attribute is missing, try getting the attribute from dm_lab env.

        Raises:
            AttributeError: if ``_env`` itself has not been set yet, or if the
                wrapped environment does not have the attribute either.
        """
        # `__getattr__` only runs after normal lookup failed. If `_env` is the
        # missing name (instance built without `__init__`, e.g. while copying
        # or unpickling), reading `self._env` below would call `__getattr__`
        # again and recurse until RecursionError.
        if item == "_env":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )
        return getattr(self._env, item)
33 changes: 33 additions & 0 deletions shimmy/utils/dm_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
from dm_env.specs import Array, BoundedArray, DiscreteArray
from gymnasium import spaces
from gymnasium.core import ObsType


def dm_spec2gym_space(spec) -> spaces.Space[Any]:
Expand Down Expand Up @@ -46,3 +47,35 @@ def dm_obs2gym_obs(obs) -> np.ndarray | dict[str, Any]:
return {key: dm_obs2gym_obs(value) for key, value in copy.copy(obs).items()}
else:
return np.asarray(obs)


def dm_control_step2gym_step(
    timestep,
) -> tuple[ObsType, float, bool, bool, dict[str, Any]]:
    """Opens up the timestep to return obs, reward, terminated, truncated, info.

    Args:
        timestep: a dm_env style timestep exposing ``observation``, ``reward``,
            ``discount``, ``step_type`` and ``last()``.

    Returns:
        ``(obs, reward, terminated, truncated, info)``. ``reward`` falls back to
        0 when the timestep carries a falsy reward (e.g. None). ``info`` holds
        the raw ``discount`` and ``step_type`` under the keys
        ``"timestep.discount"`` and ``"timestep.step_type"``.

    Raises:
        AssertionError: if a bool is passed instead of a timestep.
    """
    if isinstance(timestep, bool):
        raise AssertionError(
            f"Expected a dm_env TimeStep, got a bool ({timestep}) instead."
        )

    obs = dm_obs2gym_obs(timestep.observation)
    reward = timestep.reward or 0

    # set terminated and truncated
    # dm_env convention: the final step of an episode that really ended has
    # discount 0 (termination); a final step with a non-zero discount means
    # the episode was cut short, e.g. by a time limit (truncation).
    terminated, truncated = False, False
    if timestep.last():
        if timestep.discount == 0:
            terminated = True
        else:
            truncated = True

    info = {
        "timestep.discount": timestep.discount,
        "timestep.step_type": timestep.step_type,
    }

    return (  # pyright: ignore[reportGeneralTypeIssues]
        obs,
        reward,
        terminated,
        truncated,
        info,
    )
68 changes: 68 additions & 0 deletions shimmy/utils/dm_lab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Utility functions for the compatibility wrappers."""
from __future__ import annotations

from collections import OrderedDict
from typing import Any

import numpy as np
from gymnasium import spaces


def dm_lab_obs2gym_obs_space(observation: dict) -> spaces.Space[Any]:
    """Gets the observation spec from a single observation.

    Every entry of ``observation`` becomes a ``spaces.Box`` with the entry's
    shape and dtype. Integer dtypes are bounded by the dtype's representable
    range; inexact (floating/complex) dtypes are unbounded.

    Args:
        observation: mapping from observation name to an array-like exposing
            ``dtype`` and ``shape``.

    Returns:
        A ``spaces.Dict`` with one ``spaces.Box`` per observation key.

    Raises:
        AssertionError: if ``observation`` is not a dict.
        ValueError: if an entry has a dtype that is neither integer nor inexact.
    """
    # OrderedDict is a dict subclass, so a single isinstance check covers both.
    assert isinstance(
        observation, dict
    ), f"Observation must be a dict, got {observation}"

    all_spaces = {}
    for key, value in observation.items():
        dtype = value.dtype

        if np.issubdtype(dtype, np.integer):
            limits = np.iinfo(dtype)
            low, high = limits.min, limits.max
        elif np.issubdtype(dtype, np.inexact):
            low, high = float("-inf"), float("inf")
        else:
            raise ValueError(f"Unknown dtype {dtype}.")

        all_spaces[key] = spaces.Box(low=low, high=high, shape=value.shape, dtype=dtype)

    return spaces.Dict(all_spaces)


def dm_lab_spec2gym_space(spec) -> spaces.Space[Any]:
    """Converts a dm_lab spec to a gymnasium space.

    Args:
        spec: either a list of description dicts (each needs a ``"name"``
            key), or a single description dict with ``"min"`` and ``"max"``.

    Returns:
        For a list: a ``spaces.Dict`` keyed by each description's name, built
        by converting every description recursively (descriptions whose
        ``"dtype"`` equals ``str`` are skipped). For a dict with ``"min"`` and
        ``"max"``: a float64 ``spaces.Box`` bounded by those values.

    Raises:
        AssertionError: if a description inside a list has no ``"name"``.
        NotImplementedError: if ``spec`` is a dict without ``"min"``/``"max"``,
            or is neither a list nor a dict.
    """
    if isinstance(spec, list):
        expanded = {}
        for desc in spec:
            assert (
                "name" in desc
            ), f"Can't find name for the description: {desc} in spec."

            # some observation spaces have a string description, we ignore those for now
            # (`and` short-circuits, so desc["dtype"] is only read when the key exists)
            if "dtype" in desc and desc["dtype"] == str:
                continue

            expanded[desc["name"]] = dm_lab_spec2gym_space(desc)

        return spaces.Dict(expanded)

    # OrderedDict is a dict subclass, so a single isinstance check covers both.
    if isinstance(spec, dict):
        # this is an action space
        if "min" in spec and "max" in spec:
            return spaces.Box(low=spec["min"], high=spec["max"], dtype=np.float64)

        # any other dict layout is one we do not know how to convert
        raise NotImplementedError(
            f"Unknown spec definition: {spec}, please report."
        )

    raise NotImplementedError(
        f"Cannot convert dm_spec to gymnasium space, unknown spec: {spec}, please report."
    )
11 changes: 7 additions & 4 deletions tests/test_openspiel.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,16 @@ def test_passing_games(game):
@pytest.mark.parametrize("game", _FAILING_GAMES)
def test_failing_games(game):
"""Ensures that failing games are still failing."""
with pytest.raises((pyspiel.SpielError, NotImplementedError)):
with pytest.raises(pyspiel.SpielError):
test_passing_games(game)


def test_seeding():
@pytest.mark.parametrize("game", _PASSING_GAMES)
def test_seeding(game):
"""Tests the seeding of the openspiel conversion wrapper."""
# load envs
env1 = pyspiel.load_game("2048")
env2 = pyspiel.load_game("2048")
env1 = pyspiel.load_game(game)
env2 = pyspiel.load_game(game)

# convert the environment
env1 = OpenspielWrapperV0(env1, render_mode=None)
Expand Down Expand Up @@ -173,3 +174,5 @@ def test_seeding():
assert stuff1 == stuff2, "Incorrect returns on iteration."
elif isinstance(stuff1, np.ndarray):
assert (stuff1 == stuff2).all(), "Incorrect returns on iteration."
elif isinstance(stuff1, str):
assert stuff1 == stuff2, "Incorrect returns on iteration."