from __future__ import annotations

import sys
from typing import Any, Literal, Optional, Sequence, Union

import ale_py
import gymnasium
import gymnasium.logger as logger
import numpy as np
from ale_py import roms
from gymnasium import error, spaces, utils
from gymnasium.utils import seeding

if sys.version_info < (3, 11):
    from typing_extensions import NotRequired, TypedDict
else:
    from typing import NotRequired, TypedDict


class AtariEnvStepMetadata(TypedDict):
    lives: int
    episode_frame_number: int
    frame_number: int
    seeds: NotRequired[Sequence[int]]


class AtariEnv(gymnasium.Env, utils.EzPickle):
    """
    (A)rcade (L)earning (Gym) (Env)ironment.

    A Gym wrapper around the Arcade Learning Environment (ALE).
    """

    # Supported render modes
    metadata = {"render_modes": ["human", "rgb_array"]}

    def __init__(
        self,
        game: str = "pong",
        mode: Optional[int] = None,
        difficulty: Optional[int] = None,
        obs_type: Literal["rgb", "grayscale", "ram"] = "rgb",
        frameskip: Union[tuple[int, int], int] = 4,
        repeat_action_probability: float = 0.25,
        full_action_space: bool = False,
        max_num_frames_per_episode: Optional[int] = None,
        render_mode: Optional[Literal["human", "rgb_array"]] = None,
    ) -> None:
        """
        Initialize the ALE for Gymnasium.
        Default parameters are taken from Machado et al., 2018.

        Args:
            game: str => Game to initialize env with.
            mode: Optional[int] => Game mode, see Machado et al., 2018.
            difficulty: Optional[int] => Game difficulty, see Machado et al., 2018.
            obs_type: str => Observation type in { 'rgb', 'grayscale', 'ram' }.
            frameskip: Union[tuple[int, int], int] =>
                Stochastic frameskip as tuple or fixed.
            repeat_action_probability: float =>
                Probability to repeat actions, see Machado et al., 2018.
            full_action_space: bool => Use full action space?
            max_num_frames_per_episode: int => Max number of frames per episode.
                Once `max_num_frames_per_episode` is reached the episode is
                truncated.
            render_mode: str => One of { 'human', 'rgb_array' }.
                If `human`, the screen is displayed interactively and game
                sounds are enabled; this locks emulation to the ROM's
                specified FPS. If `rgb_array`, `render()` returns the current
                environment RGB frame.

        Note:
            - The game must be installed, see ale-import-roms, or ale-py-roms.
            - Frameskip values of (low, high) will enable stochastic frame skip
              which will sample a random frameskip uniformly each action.
            - It is recommended to enable full action space.
              See Machado et al., 2018 for more details.

        References:
            `Revisiting the Arcade Learning Environment: Evaluation Protocols
            and Open Problems for General Agents`, Machado et al., 2018, JAIR
            URL: https://jair.org/index.php/jair/article/view/11182
        """
        if obs_type not in {"rgb", "grayscale", "ram"}:
            raise error.Error(
                f"Invalid observation type: {obs_type}. Expecting: rgb, grayscale, ram."
            )

        if type(frameskip) not in (int, tuple):
            raise error.Error(f"Invalid frameskip type: {type(frameskip)}.")
        if isinstance(frameskip, int) and frameskip <= 0:
            raise error.Error(
                f"Invalid frameskip of {frameskip}, frameskip must be positive."
            )
        elif isinstance(frameskip, tuple) and len(frameskip) != 2:
            raise error.Error(
                f"Invalid stochastic frameskip length of {len(frameskip)}, expected length 2."
            )
        elif isinstance(frameskip, tuple) and frameskip[0] > frameskip[1]:
            raise error.Error(
                "Invalid stochastic frameskip, lower bound is greater than upper bound."
            )
        elif isinstance(frameskip, tuple) and frameskip[0] <= 0:
            raise error.Error(
                "Invalid stochastic frameskip, lower bound must be positive."
            )

        if render_mode is not None and render_mode not in {"rgb_array", "human"}:
            raise error.Error(
                f"Render mode {render_mode} not supported (rgb_array, human)."
            )

        utils.EzPickle.__init__(
            self,
            game,
            mode,
            difficulty,
            obs_type,
            frameskip,
            repeat_action_probability,
            full_action_space,
            max_num_frames_per_episode,
            render_mode,
        )

        # Initialize ALE
        self.ale = ale_py.ALEInterface()

        self._game = game
        self._game_mode = mode
        self._game_difficulty = difficulty

        self._frameskip = frameskip
        self._obs_type = obs_type
        self.render_mode = render_mode

        # Set logger mode to error only
        self.ale.setLoggerMode(ale_py.LoggerMode.Error)
        # Config sticky action prob.
        self.ale.setFloat("repeat_action_probability", repeat_action_probability)

        if max_num_frames_per_episode is not None:
            self.ale.setInt("max_num_frames_per_episode", max_num_frames_per_episode)

        # If render mode is human we can display screen and sound
        if render_mode == "human":
            self.ale.setBool("display_screen", True)
            self.ale.setBool("sound", True)

        # seed + load
        self.seed_game()
        self.load_game()

        # initialize action space
        self._action_set = (
            self.ale.getLegalActionSet()
            if full_action_space
            else self.ale.getMinimalActionSet()
        )
        self.action_space = spaces.Discrete(len(self._action_set))

        # initialize observation space
        if self._obs_type == "ram":
            self.observation_space = spaces.Box(
                low=0, high=255, dtype=np.uint8, shape=(self.ale.getRAMSize(),)
            )
        elif self._obs_type == "rgb" or self._obs_type == "grayscale":
            (screen_height, screen_width) = self.ale.getScreenDims()
            image_shape = (
                screen_height,
                screen_width,
            )
            if self._obs_type == "rgb":
                image_shape += (3,)
            self.observation_space = spaces.Box(
                low=0, high=255, dtype=np.uint8, shape=image_shape
            )
        else:
            raise error.Error(f"Unrecognized observation type: {self._obs_type}")

    def seed_game(self, seed: Optional[int] = None) -> tuple[int, int]:
        """Seeds the internal and ALE RNG."""
        ss = np.random.SeedSequence(seed)
        np_seed, ale_seed = ss.generate_state(n_words=2)
        self._np_random, seed = seeding.np_random(int(np_seed))
        self.ale.setInt("random_seed", np.int32(ale_seed))
        return np_seed, ale_seed
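
    # A sketch of the seeding scheme: one user-supplied seed is split by
    # `np.random.SeedSequence` into two independent words, one for the NumPy
    # RNG (used for stochastic frameskip) and one for the ALE emulator:
    #
    #     np_seed, ale_seed = env.seed_game(seed=42)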

    def load_game(self) -> None:
        """This function initializes the ROM and sets the corresponding mode and difficulty."""
        self.ale.loadROM(getattr(roms, self._game))

        if self._game_mode is not None:
            self.ale.setMode(self._game_mode)
        if self._game_difficulty is not None:
            self.ale.setDifficulty(self._game_difficulty)

    def step(  # pyright: ignore[reportIncompatibleMethodOverride]
        self,
        action: int,
    ) -> tuple[np.ndarray, float, bool, bool, AtariEnvStepMetadata]:
        """
        Perform one agent step, i.e., repeats `action` frameskip # of steps.

        Args:
            action: int => Action index to execute.

        Returns:
            tuple[np.ndarray, float, bool, bool, AtariEnvStepMetadata] =>
                observation, reward, terminal, truncation, metadata

        Note: `metadata` contains the keys "lives", "episode_frame_number"
            and "frame_number".
        """
        # If frameskip is a length-2 tuple then the frameskip is stochastic,
        # sampled uniformly from [frameskip[0], frameskip[1]).
        if isinstance(self._frameskip, int):
            frameskip = self._frameskip
        elif isinstance(self._frameskip, tuple):
            frameskip = self.np_random.integers(*self._frameskip)
        else:
            raise error.Error(f"Invalid frameskip type: {self._frameskip}")

        # Frameskip
        reward = 0.0
        for _ in range(frameskip):
            reward += self.ale.act(self._action_set[action])
        is_terminal = self.ale.game_over(with_truncation=False)
        is_truncated = self.ale.game_truncated()

        return self._get_obs(), reward, is_terminal, is_truncated, self._get_info()
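
    # A minimal rollout sketch against this step API (assumes a constructed
    # `env` and an installed ROM):
    #
    #     obs, info = env.reset(seed=0)
    #     terminated = truncated = False
    #     while not (terminated or truncated):
    #         action = env.action_space.sample()
    #         obs, reward, terminated, truncated, info = env.step(action)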

    def reset(  # pyright: ignore[reportIncompatibleMethodOverride]
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[dict[str, Any]] = None,
    ) -> tuple[np.ndarray, AtariEnvStepMetadata]:
        """Resets environment and returns initial observation."""
        # Seed both ALE and the frameskip RNG only when a seed is specified,
        # so that an unseeded reset does not discard previous RNG state,
        # statistics, etc.
        seeded_with = None
        if seed is not None:
            seeded_with = self.seed_game(seed)
            self.load_game()

        self.ale.reset_game()

        obs = self._get_obs()
        info = self._get_info()
        if seeded_with is not None:
            info["seeds"] = seeded_with

        return obs, info
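
    # Reproducibility sketch: two resets with the same seed should produce
    # identical initial observations (assuming the ROM is installed):
    #
    #     obs_a, _ = env.reset(seed=123)
    #     obs_b, _ = env.reset(seed=123)
    #     assert (obs_a == obs_b).all()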

    def render(self) -> Optional[np.ndarray]:
        """
        Per-call rendering is not supported by ALE. We use a paradigm similar
        to Gym3, which allows you to specify `render_mode` during
        construction. For example,

            gym.make("ale-py:Pong-v0", render_mode="human")

        will display the ALE and maintain the proper interval to match the
        FPS target set by the ROM.
        """
        if self.render_mode == "rgb_array":
            return self.ale.getScreenRGB()
        elif self.render_mode == "human":
            return
        else:
            raise error.Error(
                f"Invalid render mode `{self.render_mode}`. "
                "Supported modes: `human`, `rgb_array`."
            )

    def _get_obs(self) -> np.ndarray:
        """
        Retrieves the current observation.
        This is dependent on `self._obs_type`.
        """
        if self._obs_type == "ram":
            return self.ale.getRAM()
        elif self._obs_type == "rgb":
            return self.ale.getScreenRGB()
        elif self._obs_type == "grayscale":
            return self.ale.getScreenGrayscale()
        else:
            raise error.Error(f"Unrecognized observation type: {self._obs_type}")

    def _get_info(self) -> AtariEnvStepMetadata:
        return {
            "lives": self.ale.lives(),
            "episode_frame_number": self.ale.getEpisodeFrameNumber(),
            "frame_number": self.ale.getFrameNumber(),
        }

    def get_keys_to_action(self) -> dict[tuple[int, ...], int]:
        """
        Return keymapping -> actions for human play.
        """
        UP = ord("w")
        LEFT = ord("a")
        RIGHT = ord("d")
        DOWN = ord("s")
        FIRE = ord(" ")

        mapping = {
            ale_py.Action.NOOP: (None,),
            ale_py.Action.UP: (UP,),
            ale_py.Action.FIRE: (FIRE,),
            ale_py.Action.DOWN: (DOWN,),
            ale_py.Action.LEFT: (LEFT,),
            ale_py.Action.RIGHT: (RIGHT,),
            ale_py.Action.UPFIRE: (UP, FIRE),
            ale_py.Action.DOWNFIRE: (DOWN, FIRE),
            ale_py.Action.LEFTFIRE: (LEFT, FIRE),
            ale_py.Action.RIGHTFIRE: (RIGHT, FIRE),
            ale_py.Action.UPLEFT: (UP, LEFT),
            ale_py.Action.UPRIGHT: (UP, RIGHT),
            ale_py.Action.DOWNLEFT: (DOWN, LEFT),
            ale_py.Action.DOWNRIGHT: (DOWN, RIGHT),
            ale_py.Action.UPLEFTFIRE: (UP, LEFT, FIRE),
            ale_py.Action.UPRIGHTFIRE: (UP, RIGHT, FIRE),
            ale_py.Action.DOWNLEFTFIRE: (DOWN, LEFT, FIRE),
            ale_py.Action.DOWNRIGHTFIRE: (DOWN, RIGHT, FIRE),
        }

        # Map
        #   (key, key, ...) -> action_idx
        # where action_idx is the index of the action in the environment's
        # (possibly minimal) action set, i.e. the Discrete action value.
        return dict(
            zip(
                map(lambda action: tuple(sorted(mapping[action])), self._action_set),
                range(len(self._action_set)),
            )
        )
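
    # This mapping matches the format expected by `gymnasium.utils.play`,
    # e.g. (a sketch, assuming the Pong ROM is installed):
    #
    #     from gymnasium.utils.play import play
    #     env = AtariEnv(game="pong", render_mode="rgb_array")
    #     play(env, keys_to_action=env.get_keys_to_action())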

    def get_action_meanings(self) -> list[str]:
        """
        Return the meaning of each integer action.
        """
        # Map each Action enum member to its name, e.g. Action.NOOP -> "NOOP".
        mapping = {action: name for name, action in ale_py.Action.__members__.items()}
        return [mapping[action] for action in self._action_set]
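
    # For example (a sketch; the exact list depends on the ROM's minimal
    # action set):
    #
    #     AtariEnv(game="pong").get_action_meanings()
    #     # ['NOOP', 'FIRE', 'RIGHT', 'LEFT', 'RIGHTFIRE', 'LEFTFIRE']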

    def clone_state(self, include_rng=False) -> ale_py.ALEState:
        """Clone emulator state w/o system state. Restoring this state will
        *not* give an identical environment. For complete cloning and restoring
        of the full state, see `{clone,restore}_full_state()`."""
        return self.ale.cloneState(include_rng=include_rng)

    def restore_state(self, state: ale_py.ALEState) -> None:
        """Restore emulator state w/o system state."""
        self.ale.restoreState(state)
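
    # Typical save/explore/restore pattern (a sketch):
    #
    #     state = env.clone_state(include_rng=True)
    #     ...take some exploratory steps...
    #     env.restore_state(state)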

    def clone_full_state(self) -> ale_py.ALEState:
        """Deprecated method which would clone the emulator and system state."""
        logger.warn(
            "`clone_full_state()` is deprecated and will be removed in a future release of `ale-py`. "
            "Please use `clone_state(include_rng=True)` which is equivalent to `clone_full_state`."
        )
        return self.ale.cloneSystemState()

    def restore_full_state(self, state: ale_py.ALEState) -> None:
        """Restore emulator state w/ system state including pseudorandomness."""
        logger.warn(
            "`restore_full_state()` is deprecated and will be removed in a future release of `ale-py`. "
            "Please use `restore_state(state)`, which will restore the state regardless of it being a full or partial state."
        )
        self.ale.restoreSystemState(state)