Skip to content

Commit

Permalink
AtariPreprocess - Add an option for tuple[int, int] screen-size (#1105)
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts authored Jul 3, 2024
1 parent c3af58e commit b064b68
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 12 deletions.
28 changes: 16 additions & 12 deletions gymnasium/wrappers/atari_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(
env: gym.Env,
noop_max: int = 30,
frame_skip: int = 4,
screen_size: int = 84,
screen_size: int | tuple[int, int] = 84,
terminal_on_life_loss: bool = False,
grayscale_obs: bool = True,
grayscale_newaxis: bool = False,
Expand All @@ -67,7 +67,7 @@ def __init__(
env (Env): The environment to apply the preprocessing
noop_max (int): For No-op reset, the max number no-ops actions are taken at reset, to turn off, set to 0.
frame_skip (int): The number of frames between new observation the agents observations effecting the frequency at which the agent experiences the game.
screen_size (int): resize Atari frame.
screen_size (int | tuple[int, int]): resize Atari frame.
terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `terminated=True` whenever a
life is lost.
grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation
Expand Down Expand Up @@ -101,7 +101,11 @@ def __init__(
)

assert frame_skip > 0
assert screen_size > 0
assert (isinstance(screen_size, int) and screen_size > 0) or (
isinstance(screen_size, tuple)
and len(screen_size) == 2
and all(isinstance(size, int) and size > 0 for size in screen_size)
), f"Expect the `screen_size` to be positive, actually: {screen_size}"
assert noop_max >= 0
if frame_skip > 1 and getattr(env.unwrapped, "_frameskip", None) != 1:
raise ValueError(
Expand All @@ -111,7 +115,11 @@ def __init__(
assert env.unwrapped.get_action_meanings()[0] == "NOOP"

self.frame_skip = frame_skip
self.screen_size = screen_size
self.screen_size: tuple[int, int] = (
screen_size
if isinstance(screen_size, tuple)
else (screen_size, screen_size)
)
self.terminal_on_life_loss = terminal_on_life_loss
self.grayscale_obs = grayscale_obs
self.grayscale_newaxis = grayscale_newaxis
Expand All @@ -133,15 +141,11 @@ def __init__(
self.lives = 0
self.game_over = False

_low, _high, _obs_dtype = (
(0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
)
_shape = (screen_size, screen_size, 1 if grayscale_obs else 3)
_low, _high, _dtype = (0, 1, np.float32) if scale_obs else (0, 255, np.uint8)
_shape = self.screen_size + (1 if grayscale_obs else 3,)
if grayscale_obs and not grayscale_newaxis:
_shape = _shape[:-1] # Remove channel axis
self.observation_space = Box(
low=_low, high=_high, shape=_shape, dtype=_obs_dtype
)
self.observation_space = Box(low=_low, high=_high, shape=_shape, dtype=_dtype)

@property
def ale(self):
Expand Down Expand Up @@ -214,7 +218,7 @@ def _get_obs(self):

obs = cv2.resize(
self.obs_buffer[0],
(self.screen_size, self.screen_size),
self.screen_size,
interpolation=cv2.INTER_AREA,
)

Expand Down
23 changes: 23 additions & 0 deletions tests/wrappers/test_atari_preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Test suite for AtariProcessing wrapper."""

import re

import numpy as np
import pytest

Expand Down Expand Up @@ -84,3 +86,24 @@ def test_atari_preprocessing_scale(grayscale, scaled, max_test_steps=10):

step_i += 1
env.close()


def test_screen_size():
env = gym.make("ALE/Pong-v5", frameskip=1)

assert AtariPreprocessing(env).screen_size == (84, 84)
assert AtariPreprocessing(env, screen_size=50).screen_size == (50, 50)
assert AtariPreprocessing(env, screen_size=(100, 120)).screen_size == (100, 120)

with pytest.raises(
AssertionError, match="Expect the `screen_size` to be positive, actually: -1"
):
AtariPreprocessing(env, screen_size=-1)

with pytest.raises(
AssertionError,
match=re.escape("Expect the `screen_size` to be positive, actually: (-1, 10)"),
):
AtariPreprocessing(env, screen_size=(-1, 10))

env.close()

0 comments on commit b064b68

Please sign in to comment.