Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dm-lab (no tests yet) and dm-control-multiagent #7

Merged
merged 27 commits into from
Dec 5, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions shimmy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from shimmy.dm_control_compatibility import (
DmControlCompatibility as DmControlCompatibilityV0,
)
from shimmy.dm_lab_compatibility import DmLabCompatibility as DmLabCompatibilityV0
from shimmy.openspiel_wrapper import OpenspielWrapper as OpenspielWrapperV0

__version__ = "0.0.1a"
28 changes: 5 additions & 23 deletions shimmy/dm_control_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from gymnasium.core import ObsType
from numpy.random import RandomState

from shimmy.utils import dm_obs2gym_obs, dm_spec2gym_space
from shimmy.utils.dm_env_utils import dm_spec2gym_space, expose_timestep


class DmControlCompatibility(gymnasium.Env[ObsType, np.ndarray]):
Expand Down Expand Up @@ -57,36 +57,18 @@ def reset(
self.np_random = RandomState(seed=seed)

timestep = self._env.reset()
obs = dm_obs2gym_obs(timestep.observation)
info = {
"timestep.discount": timestep.discount,
"timestep.step_type": timestep.step_type,
}

obs, reward, terminated, truncated, info = expose_timestep(timestep)
pseudo-rnd-thoughts marked this conversation as resolved.
Show resolved Hide resolved

return obs, info # pyright: ignore[reportGeneralTypeIssues]

def step(
self, action: np.ndarray
) -> tuple[ObsType, float, bool, bool, dict[str, Any]]:
"""Steps through the dm-control environment."""
# Step through the dm-control environment
timestep = self._env.step(action)

# open up the timestep and process reward and observation
obs = dm_obs2gym_obs(timestep.observation)
reward = timestep.reward or 0

# set terminated and truncated
terminated, truncated = False, False
if timestep.last():
if timestep.discount == 0:
truncated = True
else:
terminated = True

info = {
"timestep.discount": timestep.discount,
"timestep.step_type": timestep.step_type,
}
obs, reward, terminated, truncated, info = expose_timestep(timestep)

if self.render_mode == "human":
self.viewer.render()
Expand Down
97 changes: 97 additions & 0 deletions shimmy/dm_lab_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""Wrapper to convert a dm_lab environment into a gymnasium compatible environment.

Taken from
https://github.com/ikostrikov/dmcgym/blob/main/dmcgym/env.py
and modified to modern gymnasium API
"""
from __future__ import annotations

from typing import Any

import gymnasium
import numpy as np
from dm_env import Environment
from gymnasium.core import ObsType

from shimmy.utils.dm_env_utils import expose_timestep
from shimmy.utils.dm_lab_utils import dm_lab_spec2gym_space


class DmLabCompatibility(gymnasium.Env[ObsType, np.ndarray]):
    """A compatibility wrapper that converts a dm_lab environment into a gymnasium environment.

    The wrapped environment's ``observation_spec()`` and ``action_spec()`` are
    converted to gymnasium spaces with ``dm_lab_spec2gym_space``, and every
    value returned by the wrapped ``reset``/``step`` is unpacked with
    ``expose_timestep``.

    NOTE(review): ``expose_timestep`` expects a dm_env-style timestep; confirm
    that the wrapped environment's ``reset``/``step`` actually return one.
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 10}

    def __init__(
        self,
        env: Environment,
        render_mode: str | None = None,
        render_height: int = 84,
        render_width: int = 84,
        camera_id: int = 0,
    ):
        """Initialises the environment with a render mode along with render information.

        Args:
            env: The environment to wrap.
            render_mode: ``None`` or one of ``metadata["render_modes"]``.
            render_height: Stored on the instance; not used elsewhere in this class.
            render_width: Stored on the instance; not used elsewhere in this class.
            camera_id: Stored on the instance; not used elsewhere in this class.
        """
        self._env = env

        self.observation_space = dm_lab_spec2gym_space(env.observation_spec())
        self.action_space = dm_lab_spec2gym_space(env.action_spec())

        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode
        self.render_height, self.render_width = render_height, render_width
        self.camera_id = camera_id

        if self.render_mode == "human":
            # Imported lazily so the mujoco viewer is only required for human rendering.
            # NOTE(review): this reads ``env.physics``, which looks like a
            # dm_control attribute -- confirm it exists on dm_lab environments.
            from gymnasium.envs.mujoco.mujoco_rendering import Viewer

            self.viewer = Viewer(
                self._env.physics.model.ptr, self._env.physics.data.ptr
            )

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the dm-lab environment.

        Args:
            seed: Forwarded to ``gymnasium.Env.reset``.
            options: Accepted for API compatibility; currently unused.

        Returns:
            The observation and info dict unpacked from the wrapped
            environment's reset result.
        """
        super().reset(seed=seed)

        timestep = self._env.reset()

        # Only the observation and info are part of the reset contract; the
        # reward and episode-end flags of the first timestep are discarded.
        obs, _, _, _, info = expose_timestep(timestep)

        return obs, info  # pyright: ignore[reportGeneralTypeIssues]

    def step(
        self, action: dict[str, np.ndarray]
    ) -> tuple[ObsType, float, bool, bool, dict[str, Any]]:
        """Steps through the dm-lab environment.

        Args:
            action: Mapping whose values are indexable; the first element of
                each value is sent to the wrapped environment.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` as produced by
            ``expose_timestep``.
        """
        # there's some funky quantization happening here, dm_lab only accepts ints as actions
        action_array = np.array([a[0] for a in action.values()], dtype=np.intc)
        timestep = self._env.step(action_array)

        obs, reward, terminated, truncated, info = expose_timestep(timestep)

        if self.render_mode == "human":
            self.viewer.render()

        return (  # pyright: ignore[reportGeneralTypeIssues]
            obs,
            reward,
            terminated,
            truncated,
            info,
        )

    def render(self) -> np.ndarray | None:
        """Renders the dm_lab env (not implemented yet)."""
        raise NotImplementedError

    def close(self):
        """Closes the environment."""
        self._env.close()

    def __getattr__(self, item: str):
        """If the attribute is missing, try getting the attribute from dm_lab env."""
        # __getattr__ only runs after normal lookup failed. If ``_env`` itself
        # is missing (e.g. before __init__ has run), delegating to
        # ``self._env`` would re-enter this method forever.
        if item == "_env":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )
        return getattr(self._env, item)
31 changes: 31 additions & 0 deletions shimmy/utils.py → shimmy/utils/dm_env_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
from dm_env.specs import Array, BoundedArray, DiscreteArray
from gymnasium import spaces
from gymnasium.core import ObsType


def dm_spec2gym_space(spec) -> spaces.Space[Any]:
Expand Down Expand Up @@ -46,3 +47,33 @@ def dm_obs2gym_obs(obs) -> np.ndarray | dict[str, Any]:
return {key: dm_obs2gym_obs(value) for key, value in copy.copy(obs).items()}
else:
return np.asarray(obs)


def expose_timestep(timestep) -> tuple[ObsType, float, bool, bool, dict[str, Any]]:
    """Opens up the timestep to return obs, reward, terminated, truncated, info.

    Args:
        timestep: A dm_env-style timestep exposing ``observation``, ``reward``,
            ``discount``, ``step_type`` and ``last()``.

    Returns:
        A gymnasium-style ``(obs, reward, terminated, truncated, info)`` tuple.
        ``info`` carries the raw ``discount`` and ``step_type`` of the timestep.

    Raises:
        AssertionError: If a ``bool`` is passed instead of a timestep.
    """
    if isinstance(timestep, bool):
        raise AssertionError(
            f"Expected a dm_env timestep but received a bool: {timestep}."
        )

    obs = dm_obs2gym_obs(timestep.observation)
    # A missing reward (None) is reported as zero; real rewards pass through
    # untouched instead of going through a truthiness test.
    reward = 0.0 if timestep.reward is None else timestep.reward

    # On the last step dm_env distinguishes the two ways an episode can end by
    # the discount: a true termination has discount 0, while a truncation
    # (e.g. time limit) keeps a non-zero discount.
    terminated, truncated = False, False
    if timestep.last():
        if timestep.discount == 0:
            terminated = True
        else:
            truncated = True

    info = {
        "timestep.discount": timestep.discount,
        "timestep.step_type": timestep.step_type,
    }

    return (  # pyright: ignore[reportGeneralTypeIssues]
        obs,
        reward,
        terminated,
        truncated,
        info,
    )
65 changes: 65 additions & 0 deletions shimmy/utils/dm_lab_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Utility functions for the compatibility wrappers."""
from __future__ import annotations

from collections import OrderedDict
from typing import Any

import numpy as np
from gymnasium import spaces


def dm_lab_spec2gym_space(spec) -> spaces.Space[Any]:
    """Translate a dm_lab spec into the matching gymnasium space.

    Three shapes of ``spec`` are understood:

    * a ``list`` of named descriptions -> ``spaces.Dict`` keyed by each
      description's ``"name"`` (string-typed descriptions are left out);
    * a mapping with ``"shape"`` and ``"dtype"`` -> ``spaces.Box`` bounded by
      the dtype's range (integers) or by +/- infinity (inexact types);
    * a mapping with ``"min"`` and ``"max"`` -> ``spaces.Box`` of ``float64``.

    Anything else raises ``NotImplementedError``.
    """
    if isinstance(spec, list):
        members = {}
        for entry in spec:
            assert (
                "name" in entry
            ), f"Can't find name for the description: {entry} in spec."

            # string-valued observations cannot be expressed as a Box; skip them for now
            if "dtype" in entry and entry["dtype"] == str:
                continue

            members[entry["name"]] = dm_lab_spec2gym_space(entry)

        return spaces.Dict(members)

    if not isinstance(spec, (OrderedDict, dict)):
        raise NotImplementedError(
            f"Cannot convert dm_spec to gymnasium space, unknown spec: {spec}, please report."
        )

    # observation description: bounds are derived from the dtype alone
    if "shape" in spec and "dtype" in spec:
        dtype = spec["dtype"]
        if np.issubdtype(dtype, np.integer):
            limits = np.iinfo(dtype)
            low, high = limits.min, limits.max
        elif np.issubdtype(dtype, np.inexact):
            low, high = float("-inf"), float("inf")
        else:
            # neither integer nor inexact: both bounds stay at zero
            low, high = 0, 0

        return spaces.Box(low=low, high=high, shape=spec["shape"], dtype=dtype)

    # action description: explicit bounds are supplied by the spec
    if "min" in spec and "max" in spec:
        return spaces.Box(low=spec["min"], high=spec["max"], dtype=np.float64)

    # a mapping that matches neither known layout
    raise NotImplementedError(f"Unknown spec definition: {spec}, please report.")


def dm_lab_obs2gym_obs(obs) -> np.ndarray | dict[str, Any]:
    """Converts a dm_lab observation to a gymnasium observation.

    Args:
        obs: Either a mapping of observation names to values, or a single
            array-like observation.

    Returns:
        For a mapping, a plain ``dict`` with each value converted recursively;
        otherwise ``np.asarray(obs)``.
    """
    if isinstance(obs, dict):
        # String-valued entries are dropped so the result lines up with
        # ``dm_lab_spec2gym_space``, which omits string-typed descriptions
        # from the observation space.
        return {
            key: dm_lab_obs2gym_obs(value)
            for key, value in obs.items()
            if not isinstance(value, str)
        }

    return np.asarray(obs)
20 changes: 20 additions & 0 deletions src/dmlab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# import deepmind_lab
from shimmy import DmLabCompatibilityV0

# observations = ['RGBD']
# env = deepmind_lab.Lab('lt_chasm', observations,
# config={'width': '640', # screen size, in pixels
# 'height': '480', # screen size, in pixels
# 'botCount': '2'}, # lt_chasm option.
# renderer='hardware') # select renderer.

# env.reset()
# env = DmLabCompatibilityV0(env=env, render_mode=None)

# # reset and begin test
# env.reset()
# term, trunc = False, False

# # run until termination
# while not term and not trunc:
# obs, rew, term, trunc, info = env.step(env.action_space.sample())
11 changes: 7 additions & 4 deletions tests/test_openspiel.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,16 @@ def test_passing_games(game):
@pytest.mark.parametrize("game", _FAILING_GAMES)
def test_failing_games(game):
"""Ensures that failing games are still failing."""
with pytest.raises((pyspiel.SpielError, NotImplementedError)):
with pytest.raises(pyspiel.SpielError):
test_passing_games(game)


def test_seeding():
@pytest.mark.parametrize("game", _PASSING_GAMES)
def test_seeding(game):
"""Tests the seeding of the openspiel conversion wrapper."""
# load envs
env1 = pyspiel.load_game("2048")
env2 = pyspiel.load_game("2048")
env1 = pyspiel.load_game(game)
env2 = pyspiel.load_game(game)

# convert the environment
env1 = OpenspielWrapperV0(env1, render_mode=None)
Expand Down Expand Up @@ -173,3 +174,5 @@ def test_seeding():
assert stuff1 == stuff2, "Incorrect returns on iteration."
elif isinstance(stuff1, np.ndarray):
assert (stuff1 == stuff2).all(), "Incorrect returns on iteration."
elif isinstance(stuff1, str):
assert stuff1 == stuff2, "Incorrect returns on iteration."