Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pickle tests #53

Merged
merged 34 commits into from
Mar 31, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
7f90eb8
Add multi-agent dm control dockerfile and workflow
elliottower Mar 27, 2023
b30e99a
Fix typo in dm control multiagent workflow
elliottower Mar 27, 2023
6a0e2f2
Merge remote-tracking branch 'upstream/HEAD' into dm-lab-ci
elliottower Mar 28, 2023
9e3dc4e
Add dm-lab dockerfile and workflow
elliottower Mar 28, 2023
e8f9d87
Fix typo in dm_lab dockerfile
elliottower Mar 28, 2023
ae08a50
Add shimmy[dm-lab] pip installation to match other envs
elliottower Mar 28, 2023
1bc5a35
Add pickling tests for meltingpot, openspiel, bsuite, EzPickle for op…
elliottower Mar 28, 2023
2e97a55
Add initial pickle test to all third party environments (besides gym)
elliottower Mar 29, 2023
9f0760d
Merge branch 'main' into pickle-tests
elliottower Mar 29, 2023
f0afeef
Update PZ version after 1.22.4 yank
elliottower Mar 29, 2023
706db5f
Add importorskip for dm_lab so main tests don't fail
elliottower Mar 29, 2023
dc9df9c
Try old import deepmind_lab inside of test_check_env()
elliottower Mar 29, 2023
5917bd5
Add dm-env requirement to dm-lab dockerfile (fix CI error)
elliottower Mar 30, 2023
55e13a7
Fix typo in multiagent dm control test
elliottower Mar 30, 2023
29a0e4f
Update dm-lab tests to correct action type (from int to dict)
elliottower Mar 30, 2023
86fc8bc
Fix dm control multiagent init error (recursion limit)
elliottower Mar 30, 2023
efac866
Add all dm-lab levels to test, comment out obs test (not matching)
elliottower Mar 30, 2023
45de236
Attempt to fix dm-lab seeding, fix pickling test typo
elliottower Mar 30, 2023
6be6af8
Fix typo in dm lab test
elliottower Mar 30, 2023
e6a1c88
Fix meltingpot isort issues (ignore files, works locally just not in CI)
elliottower Mar 30, 2023
ad564f8
Fix dm control to take 10x less time for seed testing (1+hrs currently)
elliottower Mar 30, 2023
320f498
Fix typo in dm lab test
elliottower Mar 30, 2023
4992a4c
Make seed warning a print statement so execution doesn't stop during …
elliottower Mar 30, 2023
084d9b2
Skip dm lab tests
elliottower Mar 30, 2023
ef0f81a
Switch dm lab tests to do lt_chasm (env used in official examples)
elliottower Mar 30, 2023
246c8e1
Skip dm lab tests again due to erros
elliottower Mar 30, 2023
9413b02
Fix typo in install script
elliottower Mar 30, 2023
60f7482
Change dm_control_multi_agent test skip reason (weakref can't be pick…
elliottower Mar 30, 2023
0240502
Fix typo in dm control test
elliottower Mar 30, 2023
492a364
Fix typo in dm control mutliagent test
elliottower Mar 30, 2023
aa0bb38
Fix typo in dm control multiagent test
elliottower Mar 30, 2023
746a823
Remove isort ignore and don't run local pre-commit hooks
elliottower Mar 31, 2023
ca39529
Remove repeated noqa ignores for file-wide ignores, remove extra imports
elliottower Mar 31, 2023
cce032f
Fix pre-commit in meltingpot test
elliottower Mar 31, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions scripts/install_dm_lab.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ fi

pip3 install numpy

# TODO: fix installation issues on MacOS
# Build
if [ ! -d "lab" ]; then
git clone https://github.com/deepmind/lab.git
fi
cd lab
echo 'build --cxxopt=-std=c++17' > .bazelrc
bazel build -c opt //python/pip_package:build_pip_package
Expand Down
6 changes: 5 additions & 1 deletion shimmy/dm_control_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dm_control.rl import control
from gymnasium.core import ObsType
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer
from gymnasium.utils import EzPickle

from shimmy.utils.dm_env import dm_env_step2gym_step, dm_spec2gym_space

Expand All @@ -27,7 +28,7 @@ class EnvType(Enum):
RL_CONTROL = 1


class DmControlCompatibilityV0(gymnasium.Env[ObsType, np.ndarray]):
class DmControlCompatibilityV0(gymnasium.Env[ObsType, np.ndarray], EzPickle):
"""This compatibility wrapper converts a dm-control environment into a gymnasium environment.

Dm-control is DeepMind's software stack for physics-based simulation and Reinforcement Learning environments, using MuJoCo physics.
Expand Down Expand Up @@ -57,6 +58,9 @@ def __init__(
camera_id: int = 0,
):
"""Initialises the environment with a render mode along with render information."""
EzPickle.__init__(
self, env, render_mode, render_height, render_width, camera_id
)
self._env = env
self.env_type = self._find_env_type(env)

Expand Down
4 changes: 3 additions & 1 deletion shimmy/dm_control_multiagent_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import gymnasium
import numpy as np
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer
from gymnasium.utils import EzPickle
from pettingzoo.utils.env import ActionDict, AgentID, ObsDict, ParallelEnv

from shimmy.utils.dm_env import dm_obs2gym_obs, dm_spec2gym_space
Expand Down Expand Up @@ -62,7 +63,7 @@ def _unravel_ma_timestep(
)


class DmControlMultiAgentCompatibilityV0(ParallelEnv):
class DmControlMultiAgentCompatibilityV0(ParallelEnv, EzPickle):
"""This compatibility wrapper converts multi-agent dm-control environments, primarily soccer, into a Pettingzoo environment.

Dm-control is DeepMind's software stack for physics-based simulation and Reinforcement Learning environments,
Expand All @@ -84,6 +85,7 @@ def __init__(
env (dm_env.Environment): dm control multi-agent environment
render_mode (Optional[str]): render_mode
"""
EzPickle.__init__(self, env=env, render_mode=render_mode)
super().__init__()
self._env = env
self.render_mode = render_mode
Expand Down
63 changes: 62 additions & 1 deletion tests/test_atari.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests the ale-py environments are correctly registered."""
import pickle
import warnings

import gymnasium as gym
Expand All @@ -7,7 +8,7 @@
from ale_py.roms import utils as rom_utils
from gymnasium.envs.registration import registry
from gymnasium.error import Error
from gymnasium.utils.env_checker import check_env
from gymnasium.utils.env_checker import check_env, data_equivalence

from shimmy.utils.envs_configs import ALL_ATARI_GAMES

Expand Down Expand Up @@ -47,3 +48,63 @@ def test_atari_envs(env_id):
assert isinstance(warning_message.message, Warning)
if warning_message.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS:
raise Error(f"Unexpected warning: {warning_message.message}")


@pytest.mark.parametrize(
"env_id",
[
env_id
for env_id, env_spec in registry.items()
if "Pong" in env_id and env_spec.entry_point == "shimmy.atari_env:AtariEnv"
],
)
def test_atari_pickle(env_id):
"""Tests the atari envs, as there are 1000 possible environment, we only test the Pong variants."""
env_1 = gym.make(env_id)
env_2 = pickle.loads(pickle.dumps(env_1))

obs_1, info_1 = env_1.reset(seed=42)
obs_2, info_2 = env_2.reset(seed=42)
assert data_equivalence(obs_1, obs_2)
assert data_equivalence(info_1, info_2)
for _ in range(100):
actions = int(env_1.action_space.sample())
obs_1, reward_1, term_1, trunc_1, info_1 = env_1.step(actions)
obs_2, reward_2, term_2, trunc_2, info_2 = env_2.step(actions)
assert data_equivalence(obs_1, obs_2)
assert reward_1 == reward_2
assert term_1 == term_2 and trunc_1 == trunc_2
assert data_equivalence(info_1, info_2)

env_1.close()
env_2.close()


@pytest.mark.parametrize(
"env_id",
[
env_id
for env_id, env_spec in registry.items()
if "Pong" in env_id and env_spec.entry_point == "shimmy.atari_env:AtariEnv"
],
)
def test_atari_seeding(env_id):
"""Tests the seeding of the atari conversion wrapper."""
env_1 = gym.make(env_id)
env_2 = gym.make(env_id)

obs_1, info_1 = env_1.reset(seed=42)
obs_2, info_2 = env_2.reset(seed=42)
assert data_equivalence(obs_1, obs_2)
assert data_equivalence(info_1, info_2)
for _ in range(100):
actions = int(env_1.action_space.sample())
obs_1, reward_1, term_1, trunc_1, info_1 = env_1.step(actions)
obs_2, reward_2, term_2, trunc_2, info_2 = env_2.step(actions)
assert data_equivalence(obs_1, obs_2)
assert reward_1 == reward_2
assert term_1 == term_2 and trunc_1 == trunc_2
assert data_equivalence(info_1, info_2)

env_1.close()
env_2.close()
57 changes: 49 additions & 8 deletions tests/test_bsuite.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,56 @@ def test_seeding(env_id):
env_2.close()


@pytest.mark.parametrize("env_id", BSUITE_ENV_IDS)
# Without EzPickle:_register_bsuite_envs.<locals>._make_bsuite_env cannot be pickled
# With EzPickle: maximum recursion limit reached
FAILING_PICKLE_ENVS = [
"bsuite/bandit_noise-v0",

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I have found this before. It is generally because one of the classes implemented __getattr__.
Is there a particular parameter that is missing?

I suspect the issue is that the environment only defines a variable on reset or another function

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’ll go check that individual env and see. But what would we be able to do to fix it, besides submitting a PR to their repo? I guess we could in the compatibility wrapper specifically check if it’s that env or if an env has that specific variable not defined in init, and then do whatever modifications are required?

Copy link
Member

@pseudo-rnd-thoughts pseudo-rnd-thoughts Mar 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In short, we can't fix it. The environments failing use a wrapper that contains

def __getattr__(self, attr):
    return getattr(self._env, attr)

The problem exists when _env doesn't exist in the wrapper, i.e., when a staticmethod is existed (__setstate__) then this causes an infinite loop to occur of __getattr__("static_method") -> __getattr__("_env") -> __getattr__("_env") -> ad infinitum
The second issue is that dm don't seem to be maintaining the project anymore.

The solution is simple

def __getattr__(self, attr):
    if "_env" in self.__dict__:
          return getattr(self._env, attr)
    else:
          return super().__getattribute__(attr)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you put a PR up by any chance? Even if they don't' end up merging it I feel like we might as well try, you seem to understand this stuff better than me though, I'm not sure I'd be able to explain it well or respond to any questions about it.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"bsuite/bandit_scale-v0",
"bsuite/cartpole-v0",
"bsuite/cartpole_noise-v0",
"bsuite/cartpole_scale-v0",
"bsuite/cartpole_swingup-v0",
"bsuite/catch_noise-v0",
"bsuite/catch_scale-v0",
"bsuite/mnist_noise-v0",
"bsuite/mnist_scale-v0",
"bsuite/mountain_car_noise-v0",
"bsuite/mountain_car_scale-v0",
]

PASSING_PICKLE_ENVS = [
"bsuite/mnist-v0",
"bsuite/umbrella_length-v0",
"bsuite/discounting_chain-v0",
"bsuite/deep_sea-v0",
"bsuite/umbrella_distract-v0",
"bsuite/catch-v0",
"bsuite/memory_len-v0",
"bsuite/mountain_car-v0",
"bsuite/memory_size-v0",
"bsuite/deep_sea_stochastic-v0",
"bsuite/bandit-v0",
]


@pytest.mark.parametrize("env_id", PASSING_PICKLE_ENVS)
def test_pickle(env_id):
"""Test that pickling works."""
env = gym.make(env_id, **BSUITE_ENV_SETTINGS[env_id])
env_1 = gym.make(env_id, **BSUITE_ENV_SETTINGS[env_id])
env_2 = pickle.loads(pickle.dumps(env_1))

pickled_env = pickle.loads(pickle.dumps(env))
data_equivalence(env.reset(seed=42), pickled_env.reset(seed=42))
obs_1, info_1 = env_1.reset(seed=42)
obs_2, info_2 = env_2.reset(seed=42)
assert data_equivalence(obs_1, obs_2)
assert data_equivalence(info_1, info_2)
for _ in range(100):
actions = int(env_1.action_space.sample())
obs_1, reward_1, term_1, trunc_1, info_1 = env_1.step(actions)
obs_2, reward_2, term_2, trunc_2, info_2 = env_2.step(actions)
assert data_equivalence(obs_1, obs_2)
assert reward_1 == reward_2
assert term_1 == term_2 and trunc_1 == trunc_2
assert data_equivalence(info_1, info_2)

action = env.action_space.sample()
data_equivalence(env.step(action), pickled_env.step(action))
env.close()
pickled_env.close()
env_1.close()
env_2.close()
31 changes: 31 additions & 0 deletions tests/test_dm_control.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests the functionality of the DmControlCompatibility Wrapper on dm_control envs."""
import pickle
import warnings
from typing import Callable

Expand Down Expand Up @@ -103,6 +104,36 @@ def test_seeding(env_id):
env_2.close()


@pytest.mark.skip(
reason="Fatal Python error: Segmentation fault (with or without EzPickle"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know the source of this? Mujoco is the probable cause which we can probably get solved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like it's to do with pluggy?

/Users/elliottower/anaconda3/envs/shimmy/bin/python /Applications/PyCharm.app/Contents/plugins/python/helpers/pycharm/_jb_pytest_runner.py --target test_dm_control.py::test_pickle 
Testing started at 1:17 PM ...
Launching pytest with arguments test_dm_control.py::test_pickle --no-header --no-summary -q in /Users/elliottower/Documents/GitHub/Shimmy/tests

============================= test session starts ==============================
collecting ... collected 85 items

test_dm_control.py::test_pickle[dm_control/acrobot-swingup-v0] Fatal Python error: Segmentation fault

Current thread 0x000000011149c600 (most recent call first):
  File "/Users/elliottower/Documents/GitHub/Shimmy/tests/test_dm_control.py", line 114 in test_pickle
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/_pytest/python.py", line 192 in pytest_pyfunc_call
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/_pytest/python.py", line 1761 in runtest
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/_pytest/runner.py", line 166 in pytest_runtest_call
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/_pytest/runner.py", line 259 in <lambda>
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/_pytest/runner.py", line 338 in from_call
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/_pytest/runner.py", line 258 in call_runtest_hook
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/_pytest/runner.py", line 219 in call_and_report
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/_pytest/runner.py", line 130 in runtestprotocol
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/_pytest/runner.py", line 111 in pytest_runtest_protocol
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/_pytest/main.py", line 347 in pytest_runtestloop
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/_pytest/main.py", line 322 in _main
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/_pytest/main.py", line 268 in wrap_session
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/_pytest/main.py", line 315 in pytest_cmdline_main
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/Users/elliottower/anaconda3/envs/shimmy/lib/python3.9/site-packages/_pytest/config/__init__.py", line 164 in main
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pycharm/_jb_pytest_runner.py", line 51 in <module>

Process finished with exit code 139 (interrupted by signal 11: SIGSEGV)

)
@pytest.mark.parametrize("env_id", DM_CONTROL_ENV_IDS)
def test_pickle(env_id):
"""Test that dm-control seeding works."""
env_1 = gym.make(env_id)
env_2 = pickle.loads(pickle.dumps(env_1))

if "lqr" in env_id or (env_1.spec is not None and env_1.spec.nondeterministic):
# LQR fails this test currently.
return

obs_1, info_1 = env_1.reset(seed=42)
obs_2, info_2 = env_2.reset(seed=42)
assert data_equivalence(obs_1, obs_2)
assert data_equivalence(info_1, info_2)
for _ in range(100):
actions = env_1.action_space.sample()
obs_1, reward_1, term_1, trunc_1, info_1 = env_1.step(actions)
obs_2, reward_2, term_2, trunc_2, info_2 = env_2.step(actions)
assert data_equivalence(obs_1, obs_2)
assert reward_1 == reward_2
assert term_1 == term_2 and trunc_1 == trunc_2
assert data_equivalence(info_1, info_2)

env_1.close()
env_2.close()


@pytest.mark.parametrize("camera_id", [0, 1])
def test_rendering_camera_id(camera_id):
"""Test that dm-control rendering works."""
Expand Down
96 changes: 96 additions & 0 deletions tests/test_dm_control_multi_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Tests the multi-agent dm-control soccer environment."""
import pickle

import pytest
from dm_control.locomotion import soccer as dm_soccer
from gym.utils.env_checker import data_equivalence
elliottower marked this conversation as resolved.
Show resolved Hide resolved
from pettingzoo.test import parallel_api_test

from shimmy.dm_control_multiagent_compatibility import (
Expand Down Expand Up @@ -32,3 +34,97 @@ def test_check_env(walker_type):
parallel_api_test(env)

env.close()


@pytest.mark.parametrize("walker_type", WALKER_TYPES)
def test_seeding(walker_type):
"""Tests the seeding of the openspiel conversion wrapper."""
# load envs
env1 = dm_soccer.load(
team_size=2,
time_limit=10.0,
disable_walker_contacts=False,
enable_field_box=True,
terminate_on_goal=False,
walker_type=walker_type,
)
env2 = dm_soccer.load(
team_size=2,
time_limit=10.0,
disable_walker_contacts=False,
enable_field_box=True,
terminate_on_goal=False,
walker_type=walker_type,
)

# convert the environment
env1 = DmControlMultiAgentCompatibilityV0(env1, render_mode=None)
env2 = DmControlMultiAgentCompatibilityV0(env2, render_mode=None)

env1.reset(seed=42)
env2.reset(seed=42)

for agent in env1.possible_agents:
env1.action_space(agent).seed(42)
env2.action_space(agent).seed(42)

while env1.agents:
actions1 = {agent: env1.action_space(agent).sample() for agent in env1.agents}
actions2 = {agent: env2.action_space(agent).sample() for agent in env2.agents}

assert data_equivalence(actions1, actions2), "Incorrect action seeding"

obs1, rewards1, terminations1, truncations1, infos1 = env1.step(actions1)
obs2, rewards2, terminations2, truncations2, infos2 = env2.step(actions2)

assert not data_equivalence(
obs1, obs2
), "Observations are expected to be slightly different (ball position/velocity)"
assert data_equivalence(rewards1, rewards2), "Incorrect values for rewards"
assert data_equivalence(terminations1, terminations2), "Incorrect terminations."
assert data_equivalence(truncations1, truncations2), "Incorrect truncations"
assert data_equivalence(infos1, infos2), "Incorrect infos"
env1.close()
env2.close()


@pytest.mark.skip(
reason="TypeError: __init__() missing 1 required positional argument: 'env'"
elliottower marked this conversation as resolved.
Show resolved Hide resolved
)
@pytest.mark.parametrize("walker_type", WALKER_TYPES)
def test_pickle(walker_type):
"""Tests the seeding of the openspiel conversion wrapper."""
env1 = dm_soccer.load(
team_size=2,
time_limit=10.0,
disable_walker_contacts=False,
enable_field_box=True,
terminate_on_goal=False,
walker_type=walker_type,
)
env1 = DmControlMultiAgentCompatibilityV0(env1, render_mode=None)
env2 = pickle.loads(pickle.dumps(env1))

env1.reset(seed=42)
env2.reset(seed=42)

for agent in env1.possible_agents:
env1.action_space(agent).seed(42)
env2.action_space(agent).seed(42)

while env1.agents:
actions1 = {agent: env1.action_space(agent).sample() for agent in env1.agents}
actions2 = {agent: env2.action_space(agent).sample() for agent in env2.agents}

assert data_equivalence(actions1, actions2), "Incorrect action seeding"

obs1, rewards1, terminations1, truncations1, infos1 = env1.step(actions1)
obs2, rewards2, terminations2, truncations2, infos2 = env2.step(actions2)

assert data_equivalence(obs1, obs2), "Incorrect observations"
assert data_equivalence(rewards1, rewards2), "Incorrect values for rewards"
assert data_equivalence(terminations1, terminations2), "Incorrect terminations."
assert data_equivalence(truncations1, truncations2), "Incorrect truncations"
assert data_equivalence(infos1, infos2), "Incorrect infos"
env1.close()
env2.close()
Loading