Skip to content

Commit

Permalink
Add pickle tests (#53)
Browse files Browse the repository at this point in the history
  • Loading branch information
elliottower authored Mar 31, 2023
1 parent 5eef8c6 commit 5dbaa9a
Show file tree
Hide file tree
Showing 16 changed files with 459 additions and 30 deletions.
2 changes: 2 additions & 0 deletions bin/dm_lab.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,5 @@ RUN git clone https://github.com/deepmind/lab.git \
&& rm -rf lab

ENTRYPOINT ["/usr/local/shimmy/bin/docker_entrypoint"]

RUN ls
2 changes: 2 additions & 0 deletions scripts/install_dm_lab.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ fi

pip3 install numpy

# TODO: fix installation issues on macOS
# Build
if [ ! -d "lab" ]; then
git clone https://github.com/deepmind/lab.git
fi
cd lab
echo 'build --cxxopt=-std=c++17' > .bazelrc
bazel build -c opt //python/pip_package:build_pip_package
Expand Down
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ def get_version():
"dm-control>=1.0.10",
"imageio",
"h5py>=3.7.0",
"pettingzoo>=1.22.4",
"pettingzoo>=1.22.3",
],
"dm-lab": [],
"openspiel": ["open_spiel>=1.2", "pettingzoo>=1.22.4"],
"meltingpot": ["pettingzoo>=1.22.4"],
"dm-lab": ["dm-env>=1.6"],
"openspiel": ["open_spiel>=1.2", "pettingzoo>=1.22.3"],
"meltingpot": ["pettingzoo>=1.22.3"],
"bsuite": ["bsuite>=0.3.5"],
}
extras["all"] = list({lib for libs in extras.values() for lib in libs})
Expand Down
4 changes: 3 additions & 1 deletion shimmy/bsuite_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from bsuite.environments import Environment
from gymnasium.core import ObsType
from gymnasium.error import UnsupportedMode
from gymnasium.utils import EzPickle

from shimmy.utils.dm_env import dm_env_step2gym_step, dm_spec2gym_space

Expand All @@ -17,7 +18,7 @@
np.int = int # pyright: ignore[reportGeneralTypeIssues]


class BSuiteCompatibilityV0(gymnasium.Env[ObsType, np.ndarray]):
class BSuiteCompatibilityV0(gymnasium.Env[ObsType, np.ndarray], EzPickle):
"""A compatibility wrapper that converts a BSuite environment into a gymnasium environment.
Note:
Expand All @@ -33,6 +34,7 @@ def __init__(
render_mode: str | None = None,
):
"""Initialises the environment with a render mode along with render information."""
EzPickle.__init__(self, env, render_mode)
self._env = env

self.observation_space = dm_spec2gym_space(env.observation_spec())
Expand Down
6 changes: 5 additions & 1 deletion shimmy/dm_control_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dm_control.rl import control
from gymnasium.core import ObsType
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer
from gymnasium.utils import EzPickle

from shimmy.utils.dm_env import dm_env_step2gym_step, dm_spec2gym_space

Expand All @@ -27,7 +28,7 @@ class EnvType(Enum):
RL_CONTROL = 1


class DmControlCompatibilityV0(gymnasium.Env[ObsType, np.ndarray]):
class DmControlCompatibilityV0(gymnasium.Env[ObsType, np.ndarray], EzPickle):
"""This compatibility wrapper converts a dm-control environment into a gymnasium environment.
Dm-control is DeepMind's software stack for physics-based simulation and Reinforcement Learning environments, using MuJoCo physics.
Expand Down Expand Up @@ -57,6 +58,9 @@ def __init__(
camera_id: int = 0,
):
"""Initialises the environment with a render mode along with render information."""
EzPickle.__init__(
self, env, render_mode, render_height, render_width, camera_id
)
self._env = env
self.env_type = self._find_env_type(env)

Expand Down
6 changes: 4 additions & 2 deletions shimmy/dm_control_multiagent_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import gymnasium
import numpy as np
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer
from gymnasium.utils import EzPickle
from pettingzoo.utils.env import ActionDict, AgentID, ObsDict, ParallelEnv

from shimmy.utils.dm_env import dm_obs2gym_obs, dm_spec2gym_space
Expand Down Expand Up @@ -62,7 +63,7 @@ def _unravel_ma_timestep(
)


class DmControlMultiAgentCompatibilityV0(ParallelEnv):
class DmControlMultiAgentCompatibilityV0(ParallelEnv, EzPickle):
"""This compatibility wrapper converts multi-agent dm-control environments, primarily soccer, into a Pettingzoo environment.
Dm-control is DeepMind's software stack for physics-based simulation and Reinforcement Learning environments,
Expand All @@ -84,7 +85,8 @@ def __init__(
env (dm_env.Environment): dm control multi-agent environment
render_mode (Optional[str]): render_mode
"""
super().__init__()
EzPickle.__init__(self, env=env, render_mode=render_mode)
ParallelEnv.__init__(self)
self._env = env
self.render_mode = render_mode

Expand Down
4 changes: 4 additions & 0 deletions shimmy/dm_lab_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def reset(
self._env.reset(seed=seed)
info = {}

if seed is not None:
print(
"Warning: DM-lab environments must be seeded in initialization, rather than with reset(seed)."
)
return (
self._env.observations(),
info,
Expand Down
5 changes: 3 additions & 2 deletions shimmy/meltingpot_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@
and modified to modern pettingzoo API
"""
# pyright: reportOptionalSubscript=false

# isort: skip_file
from __future__ import annotations

import functools
from typing import Optional

import gymnasium
import meltingpot.python
import numpy as np
import pygame
from gymnasium.utils.ezpickle import EzPickle
from ml_collections import config_dict
from pettingzoo.utils.env import ActionDict, AgentID, ObsDict, ParallelEnv

import meltingpot.python
import shimmy.utils.meltingpot as utils


Expand Down Expand Up @@ -89,6 +89,7 @@ def __init__(
for index in range(self._num_players)
]
self.agents = [agent for agent in self.possible_agents]
self.num_cycles = 0

# Set up pygame rendering
if self.render_mode == "human":
Expand Down
5 changes: 3 additions & 2 deletions shimmy/openspiel_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
import pettingzoo as pz
import pyspiel
from gymnasium import spaces
from gymnasium.utils import seeding
from gymnasium.utils import EzPickle, seeding
from pettingzoo.utils.env import AgentID, ObsType


class OpenspielCompatibilityV0(pz.AECEnv):
class OpenspielCompatibilityV0(pz.AECEnv, EzPickle):
"""This compatibility wrapper converts an openspiel environment into a pettingzoo environment.
OpenSpiel is a collection of environments and algorithms for research in general reinforcement learning
Expand All @@ -35,6 +35,7 @@ def __init__(
game (pyspiel.Game): game
render_mode (Optional[str]): render_mode
"""
EzPickle.__init__(self, game, render_mode)
super().__init__()
self.game = game
self.possible_agents = [
Expand Down
63 changes: 62 additions & 1 deletion tests/test_atari.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests the ale-py environments are correctly registered."""
import pickle
import warnings

import gymnasium as gym
Expand All @@ -7,7 +8,7 @@
from ale_py.roms import utils as rom_utils
from gymnasium.envs.registration import registry
from gymnasium.error import Error
from gymnasium.utils.env_checker import check_env
from gymnasium.utils.env_checker import check_env, data_equivalence

from shimmy.utils.envs_configs import ALL_ATARI_GAMES

Expand Down Expand Up @@ -47,3 +48,63 @@ def test_atari_envs(env_id):
assert isinstance(warning_message.message, Warning)
if warning_message.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS:
raise Error(f"Unexpected warning: {warning_message.message}")


@pytest.mark.parametrize(
    "env_id",
    [
        env_id
        for env_id, env_spec in registry.items()
        if "Pong" in env_id and env_spec.entry_point == "shimmy.atari_env:AtariEnv"
    ],
)
def test_atari_pickle(env_id):
    """Tests that an atari env restored from a pickle behaves like the original.

    As there are 1000 possible environments, only the Pong variants are tested.

    Args:
        env_id: id of a registered atari environment, supplied by ``parametrize``.
    """
    env_1 = gym.make(env_id)
    env_2 = pickle.loads(pickle.dumps(env_1))

    # Close both envs even when an assertion fails, so a failing case does not leak emulators.
    try:
        # Seed the action space so the sampled action sequence is reproducible between runs.
        env_1.action_space.seed(42)

        obs_1, info_1 = env_1.reset(seed=42)
        obs_2, info_2 = env_2.reset(seed=42)
        assert data_equivalence(obs_1, obs_2)
        assert data_equivalence(info_1, info_2)

        for _ in range(100):
            # The same action is applied to both envs, which must then stay in lockstep.
            actions = int(env_1.action_space.sample())
            obs_1, reward_1, term_1, trunc_1, info_1 = env_1.step(actions)
            obs_2, reward_2, term_2, trunc_2, info_2 = env_2.step(actions)
            assert data_equivalence(obs_1, obs_2)
            assert reward_1 == reward_2
            assert term_1 == term_2 and trunc_1 == trunc_2
            assert data_equivalence(info_1, info_2)
    finally:
        env_1.close()
        env_2.close()


@pytest.mark.parametrize(
    "env_id",
    [
        env_id
        for env_id, env_spec in registry.items()
        if "Pong" in env_id and env_spec.entry_point == "shimmy.atari_env:AtariEnv"
    ],
)
def test_atari_seeding(env_id):
    """Tests the seeding of the atari conversion wrapper.

    Two independently created envs are reset with the same seed and stepped with
    the same actions; every observation, reward, flag and info must match.

    Args:
        env_id: id of a registered atari environment, supplied by ``parametrize``.
    """
    env_1 = gym.make(env_id)
    env_2 = gym.make(env_id)

    # Close both envs even when an assertion fails, so a failing case does not leak emulators.
    try:
        # Seed the action space so the sampled action sequence is reproducible between runs.
        env_1.action_space.seed(42)

        obs_1, info_1 = env_1.reset(seed=42)
        obs_2, info_2 = env_2.reset(seed=42)
        assert data_equivalence(obs_1, obs_2)
        assert data_equivalence(info_1, info_2)

        for _ in range(100):
            # The same action is applied to both envs, which must then stay in lockstep.
            actions = int(env_1.action_space.sample())
            obs_1, reward_1, term_1, trunc_1, info_1 = env_1.step(actions)
            obs_2, reward_2, term_2, trunc_2, info_2 = env_2.step(actions)
            assert data_equivalence(obs_1, obs_2)
            assert reward_1 == reward_2
            assert term_1 == term_2 and trunc_1 == trunc_2
            assert data_equivalence(info_1, info_2)
    finally:
        env_1.close()
        env_2.close()
56 changes: 56 additions & 0 deletions tests/test_bsuite.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests the functionality of the BSuiteCompatibilityV0 on bsuite envs."""
import pickle
import warnings

import bsuite
Expand Down Expand Up @@ -109,3 +110,58 @@ def test_seeding(env_id):

env_1.close()
env_2.close()


# Environments that currently cannot be pickled:
#   - Without EzPickle: `_register_bsuite_envs.<locals>._make_bsuite_env` cannot be pickled.
#   - With EzPickle: the maximum recursion limit is reached.
FAILING_PICKLE_ENVS = [
    "bsuite/bandit_noise-v0",
    "bsuite/bandit_scale-v0",
    "bsuite/cartpole-v0",
    "bsuite/cartpole_noise-v0",
    "bsuite/cartpole_scale-v0",
    "bsuite/cartpole_swingup-v0",
    "bsuite/catch_noise-v0",
    "bsuite/catch_scale-v0",
    "bsuite/mnist_noise-v0",
    "bsuite/mnist_scale-v0",
    "bsuite/mountain_car_noise-v0",
    "bsuite/mountain_car_scale-v0",
]

# Environments for which pickling is expected to work; the pickle test is parametrized over this list.
PASSING_PICKLE_ENVS = [
    "bsuite/mnist-v0",
    "bsuite/umbrella_length-v0",
    "bsuite/discounting_chain-v0",
    "bsuite/deep_sea-v0",
    "bsuite/umbrella_distract-v0",
    "bsuite/catch-v0",
    "bsuite/memory_len-v0",
    "bsuite/mountain_car-v0",
    "bsuite/memory_size-v0",
    "bsuite/deep_sea_stochastic-v0",
    "bsuite/bandit-v0",
]


@pytest.mark.parametrize("env_id", PASSING_PICKLE_ENVS)
def test_pickle(env_id):
    """Test that a bsuite env restored from a pickle behaves like the original.

    The original and the unpickled copy are reset with the same seed and stepped
    with the same actions; every observation, reward, flag and info must match.

    Args:
        env_id: id of a bsuite environment that is expected to be picklable.
    """
    env_1 = gym.make(env_id, **BSUITE_ENV_SETTINGS[env_id])
    env_2 = pickle.loads(pickle.dumps(env_1))

    # Close both envs even when an assertion fails, so a failing case does not leak resources.
    try:
        # Seed the action space so the sampled action sequence is reproducible between runs.
        env_1.action_space.seed(42)

        obs_1, info_1 = env_1.reset(seed=42)
        obs_2, info_2 = env_2.reset(seed=42)
        assert data_equivalence(obs_1, obs_2)
        assert data_equivalence(info_1, info_2)

        for _ in range(100):
            # The same action is applied to both envs, which must then stay in lockstep.
            actions = int(env_1.action_space.sample())
            obs_1, reward_1, term_1, trunc_1, info_1 = env_1.step(actions)
            obs_2, reward_2, term_2, trunc_2, info_2 = env_2.step(actions)
            assert data_equivalence(obs_1, obs_2)
            assert reward_1 == reward_2
            assert term_1 == term_2 and trunc_1 == trunc_2
            assert data_equivalence(info_1, info_2)
    finally:
        env_1.close()
        env_2.close()
31 changes: 31 additions & 0 deletions tests/test_dm_control.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests the functionality of the DmControlCompatibility Wrapper on dm_control envs."""
import pickle
import warnings
from typing import Callable

Expand Down Expand Up @@ -82,6 +83,36 @@ def test_seeding(env_id):
env_1 = gym.make(env_id)
env_2 = gym.make(env_id)

if "lqr" in env_id or (env_1.spec is not None and env_1.spec.nondeterministic):
# LQR fails this test currently.
return

obs_1, info_1 = env_1.reset(seed=42)
obs_2, info_2 = env_2.reset(seed=42)
assert data_equivalence(obs_1, obs_2)
assert data_equivalence(info_1, info_2)
for _ in range(10):
actions = env_1.action_space.sample()
obs_1, reward_1, term_1, trunc_1, info_1 = env_1.step(actions)
obs_2, reward_2, term_2, trunc_2, info_2 = env_2.step(actions)
assert data_equivalence(obs_1, obs_2)
assert reward_1 == reward_2
assert term_1 == term_2 and trunc_1 == trunc_2
assert data_equivalence(info_1, info_2)

env_1.close()
env_2.close()


@pytest.mark.skip(
reason="Fatal Python error: Segmentation fault (with or without EzPickle)"
)
@pytest.mark.parametrize("env_id", DM_CONTROL_ENV_IDS[0])
def test_pickle(env_id):
"""Test that pickling a dm-control env works."""
env_1 = gym.make(env_id)
env_2 = pickle.loads(pickle.dumps(env_1))

if "lqr" in env_id or (env_1.spec is not None and env_1.spec.nondeterministic):
# LQR fails this test currently.
return
Expand Down
Loading

0 comments on commit 5dbaa9a

Please sign in to comment.