Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow custom rendering and instance generation methods for Minesweeper #85

Merged
merged 79 commits into from
Apr 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
79 commits
Select commit Hold shift + click to select a range
ccda776
Changes from patch
Mar 27, 2023
cd85f11
Generator and viewer
Mar 27, 2023
f4d07f3
Changes from patch
Mar 27, 2023
e13bd18
Fix: naming
Mar 27, 2023
af9d865
Merge branch 'main' into tristan/minesweeper
Mar 27, 2023
f5076c2
Some more generic fixes but more work required
Mar 27, 2023
793af1d
Separate reward definitions
Mar 27, 2023
8497f84
Remove action history
Mar 27, 2023
c15e722
Cleanup
Mar 27, 2023
8b85633
Update jumanji/environments/logic/rubiks_cube/env.py
TristanKalloniatis Mar 27, 2023
7bb1c40
Update jumanji/environments/logic/rubiks_cube/env_viewer.py
TristanKalloniatis Mar 27, 2023
470ba66
Update jumanji/environments/logic/rubiks_cube/generator.py
TristanKalloniatis Mar 27, 2023
f3a7be4
Update jumanji/environments/logic/rubiks_cube/utils.py
TristanKalloniatis Mar 27, 2023
ee3def9
Update jumanji/environments/logic/rubiks_cube/generator.py
TristanKalloniatis Mar 27, 2023
8b9356c
Update jumanji/environments/logic/rubiks_cube/utils.py
TristanKalloniatis Mar 27, 2023
afdc1c9
Update jumanji/environments/logic/rubiks_cube/utils.py
TristanKalloniatis Mar 27, 2023
0d0abdb
Update jumanji/environments/logic/rubiks_cube/generator.py
TristanKalloniatis Mar 27, 2023
72846b1
Spacing
Mar 27, 2023
dba2334
Merge branch 'main' into tristan/minesweeper
Mar 27, 2023
14904ff
More generic
Mar 27, 2023
e87ddeb
Minor cleaning
Mar 27, 2023
86f1852
Don't repeat attributes in env and generator
Mar 27, 2023
a2472d5
Merge remote-tracking branch 'origin/main'
Mar 27, 2023
9473883
Lint
Mar 27, 2023
1c44c71
Remove minesweeper changes for separate PR
Mar 27, 2023
8d83792
Simplify
Mar 27, 2023
5c5f8da
Merge remote-tracking branch 'origin/main'
Mar 28, 2023
0bc3f33
To sync
Mar 28, 2023
cebb72f
To sync
Mar 28, 2023
140df27
Merge branch 'main' into tristan/minesweeper
Mar 28, 2023
379e483
To sync
Mar 28, 2023
b8f4c8a
To sync
Mar 28, 2023
77e0b59
Imports
Mar 28, 2023
7a20280
Some review comments
Mar 28, 2023
2816e31
Return
Mar 28, 2023
9e78204
Merge branch 'main' into tristan/minesweeper
Mar 28, 2023
78b3da4
Return
Mar 28, 2023
40c5a53
Merge branch 'instadeepai:main' into main
TristanKalloniatis Mar 29, 2023
454de2a
Merge branch 'main' of https://github.com/TristanKalloniatis/jumanji
Mar 29, 2023
3f08619
Import
Mar 29, 2023
f869670
Merge branch 'main' into tristan/minesweeper
Mar 29, 2023
616832e
Import
Mar 29, 2023
c4d8e7c
Generic types
Mar 29, 2023
f3c6935
Generic types
Mar 29, 2023
74a6d76
Typing
Mar 29, 2023
ba01b17
Merge branch 'main' into tristan/minesweeper
Mar 29, 2023
2e6b5b3
Typing
Mar 29, 2023
9650fa1
Empty
Mar 29, 2023
3a8493c
Empty
Mar 29, 2023
f8ced45
Clement suggestions
Mar 29, 2023
7dcaf77
Merge branch 'main' into tristan/minesweeper
Mar 29, 2023
27c1241
Tidy
Mar 29, 2023
3338299
Rename
Mar 29, 2023
4b4f0c1
Rename more
Mar 29, 2023
40fcc88
Merge branch 'main' into tristan/minesweeper
Mar 29, 2023
2177837
Rename more
Mar 29, 2023
60bbc34
Lint
Mar 29, 2023
cbda5ca
test commit
TristanKalloniatis Mar 29, 2023
e4f906b
Merge branch 'main' into main
TristanKalloniatis Mar 29, 2023
d3636e6
test commit
TristanKalloniatis Mar 30, 2023
7fbf79f
Merge branch 'main' of https://github.com/TristanKalloniatis/jumanji
TristanKalloniatis Mar 30, 2023
33936b7
test commit
TristanKalloniatis Mar 30, 2023
ee4e1a1
Merge branch 'main' of https://github.com/TristanKalloniatis/jumanji …
TristanKalloniatis Mar 30, 2023
30a722d
Merge branch 'main' into main
TristanKalloniatis Mar 30, 2023
c0cc2cb
Merge remote-tracking branch 'origin/main' into tristan/minesweeper
TristanKalloniatis Mar 30, 2023
1f1edd8
Daniel changes
TristanKalloniatis Mar 30, 2023
2618d1b
test commit
TristanKalloniatis Mar 30, 2023
925e4f9
Merge branch 'instadeepai:main' into main
TristanKalloniatis Mar 30, 2023
1b7a0f4
Merge branch 'main' of https://github.com/TristanKalloniatis/jumanji …
TristanKalloniatis Mar 30, 2023
6a0c504
Undo
TristanKalloniatis Mar 30, 2023
fee60f5
Redo
TristanKalloniatis Mar 30, 2023
562b2e8
Formatting
TristanKalloniatis Mar 30, 2023
226acca
Naming
TristanKalloniatis Mar 30, 2023
e4fd5d4
Docstring
TristanKalloniatis Mar 30, 2023
856b220
Lint
TristanKalloniatis Mar 30, 2023
51015da
Clement suggestions
TristanKalloniatis Mar 30, 2023
859ccf4
Docstrings
TristanKalloniatis Mar 31, 2023
8fd5583
Merge branch 'instadeepai:main' into main
TristanKalloniatis Mar 31, 2023
91a1524
Merge branch 'main' of https://github.com/TristanKalloniatis/jumanji …
TristanKalloniatis Mar 31, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions jumanji/environments/logic/minesweeper/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@

from jumanji.environments.logic.minesweeper.constants import UNEXPLORED_ID
from jumanji.environments.logic.minesweeper.env import Minesweeper
from jumanji.environments.logic.minesweeper.generator import UniformSamplingGenerator
from jumanji.environments.logic.minesweeper.types import State


@pytest.fixture
def minesweeper_env() -> Minesweeper:
"""Fixture for a default minesweeper env"""
return Minesweeper()
"""Fixture for a default minesweeper environment with 10 rows and columns, and 10 mines."""
return Minesweeper(
generator=UniformSamplingGenerator(num_rows=10, num_cols=10, num_mines=10)
)


@pytest.fixture
Expand Down
4 changes: 1 addition & 3 deletions jumanji/environments/logic/minesweeper/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@
UNEXPLORED_ID: int = -1
IS_MINE: int = 1
PATCH_SIZE: int = 3
REVEALED_EMPTY_SQUARE_REWARD: float = 1.0
REVEALED_MINE_OR_INVALID_ACTION_REWARD: float = 0.0
COLOUR_MAPPING: list = [
DEFAULT_COLOR_MAPPING: list = [
"orange",
"blue",
"green",
Expand Down
224 changes: 53 additions & 171 deletions jumanji/environments/logic/minesweeper/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Sequence, Tuple
from typing import Optional, Sequence, Tuple

import chex
import jax
import jax.numpy as jnp
import matplotlib.animation
import matplotlib.pyplot as plt
from numpy.typing import NDArray

import jumanji.environments
from jumanji import specs
from jumanji.env import Environment
from jumanji.environments.logic.minesweeper.constants import (
COLOUR_MAPPING,
PATCH_SIZE,
UNEXPLORED_ID,
)
from jumanji.environments.logic.minesweeper.constants import PATCH_SIZE, UNEXPLORED_ID
from jumanji.environments.logic.minesweeper.done import DefaultDoneFn, DoneFn
from jumanji.environments.logic.minesweeper.generator import (
Generator,
UniformSamplingGenerator,
)
from jumanji.environments.logic.minesweeper.reward import DefaultRewardFn, RewardFn
from jumanji.environments.logic.minesweeper.types import Observation, State
from jumanji.environments.logic.minesweeper.utils import (
count_adjacent_mines,
create_flat_mine_locations,
explored_mine,
)
from jumanji.environments.logic.minesweeper.utils import count_adjacent_mines
from jumanji.environments.logic.minesweeper.viewer import MinesweeperViewer
from jumanji.types import TimeStep, restart, termination, transition
from jumanji.viewer import Viewer


class Minesweeper(Environment[State]):
Expand All @@ -53,7 +50,7 @@ class Minesweeper(Environment[State]):
specifies how many timesteps have elapsed since environment reset.

- action:
multi discrete array containing the square to explore (height and width).
multi discrete array containing the square to explore (row and col).

- reward: jax array (float32):
Configurable function of state and action. By default:
Expand Down Expand Up @@ -92,46 +89,47 @@ class Minesweeper(Environment[State]):

def __init__(
self,
num_rows: int = 10,
num_cols: int = 10,
num_mines: int = 10,
generator: Optional[Generator] = None,
reward_function: Optional[RewardFn] = None,
done_function: Optional[DoneFn] = None,
color_mapping: Optional[List[str]] = None,
viewer: Optional[Viewer[State]] = None,
):
"""Instantiate a `Minesweeper` environment.

Args:
num_rows: number of rows, i.e. height of the board. Defaults to 10.
num_cols: number of columns, i.e. width of the board. Defaults to 10.
num_mines: number of mines on the board. Defaults to 10.
generator: `Generator` to generate problem instances on environment reset.
Implemented options are [`UniformSamplingGenerator`]. Defaults to `UniformSamplingGenerator`.
The generator will have attributes:
- num_rows: number of rows, i.e. height of the board. Defaults to 10.
- num_cols: number of columns, i.e. width of the board. Defaults to 10.
- num_mines: number of mines generated. Defaults to 10.
reward_function: `RewardFn` whose `__call__` method computes the reward of an
environment transition based on the given current state and selected action.
Implemented options are [`DefaultRewardFn`]. Defaults to `DefaultRewardFn`.
Implemented options are [`DefaultRewardFn`]. Defaults to `DefaultRewardFn`, giving
a reward of 1.0 for revealing an empty square, 0.0 for revealing a mine, and
0.0 for an invalid action (selecting an already revealed square).
done_function: `DoneFn` whose `__call__` method computes the done signal given the
current state, action taken, and next state.
Implemented options are [`DefaultDoneFn`]. Defaults to `DefaultDoneFn`.
color_mapping: colour map used for rendering.
Implemented options are [`DefaultDoneFn`]. Defaults to `DefaultDoneFn`, ending the
episode on solving the board, revealing a mine, or picking an invalid action.
viewer: `Viewer` to support rendering and animation methods.
Implemented options are [`MinesweeperViewer`]. Defaults to `MinesweeperViewer`.
"""
clement-bonnet marked this conversation as resolved.
Show resolved Hide resolved
if num_rows <= 1 or num_cols <= 1:
raise ValueError(
f"Should make a board of height and width greater than 1, "
f"got num_rows={num_rows}, num_cols={num_cols}"
)
if num_mines < 0 or num_mines >= num_rows * num_cols:
raise ValueError(
f"Number of mines should be constrained between 0 and the size of the board, "
f"got {num_mines}"
)
self.num_rows = num_rows
self.num_cols = num_cols
self.num_mines = num_mines
self.reward_function = reward_function or DefaultRewardFn()
self.reward_function = reward_function or DefaultRewardFn(
revealed_empty_square_reward=1.0,
revealed_mine_reward=0.0,
invalid_action_reward=0.0,
)
self.done_function = done_function or DefaultDoneFn()

self.cmap = color_mapping if color_mapping else COLOUR_MAPPING
self.figure_name = f"{num_rows}x{num_cols} Minesweeper"
self.figure_size = (6.0, 6.0)
self.generator = generator or UniformSamplingGenerator(
num_rows=10, num_cols=10, num_mines=10
)
self.num_rows = self.generator.num_rows
self.num_cols = self.generator.num_cols
self.num_mines = self.generator.num_mines
self._viewer = viewer or MinesweeperViewer(
num_rows=self.num_rows, num_cols=self.num_cols
)

def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
"""Resets the environment.
Expand All @@ -144,25 +142,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
timestep: `TimeStep` corresponding to the first timestep returned by the
environment.
"""
key, sample_key = jax.random.split(key)
board = jnp.full(
shape=(self.num_rows, self.num_cols),
fill_value=UNEXPLORED_ID,
dtype=jnp.int32,
)
step_count = jnp.array(0, jnp.int32)
flat_mine_locations = create_flat_mine_locations(
key=sample_key,
num_rows=self.num_rows,
num_cols=self.num_cols,
num_mines=self.num_mines,
)
state = State(
board=board,
step_count=step_count,
key=key,
flat_mine_locations=flat_mine_locations,
)
state = self.generator(key)
observation = self._state_to_observation(state=state)
timestep = restart(observation=observation)
return state, timestep
Expand All @@ -180,9 +160,7 @@ def step(
next_state: `State` corresponding to the next state of the environment,
next_timestep: `TimeStep` corresponding to the timestep returned by the environment.
"""
board = state.board
action_height, action_width = action
board = board.at[action_height, action_width].set(
board = state.board.at[tuple(action)].set(
count_adjacent_mines(state=state, action=action)
)
step_count = state.step_count + 1
Expand Down Expand Up @@ -272,134 +250,38 @@ def _state_to_observation(self, state: State) -> Observation:
step_count=state.step_count,
)

def render(self, state: State) -> None:
"""Render the given environment state using matplotlib.
def render(self, state: State) -> Optional[NDArray]:
"""Renders the current state of the board.

Args:
state: environment state to be rendered.

state: the current state to be rendered.
"""
self._clear_display()
fig, ax = self._get_fig_ax()
self._draw(ax, state)
self._update_display(fig)
return self._viewer.render(state=state)

def animate(
self,
states: Sequence[State],
interval: int = 200,
save_path: Optional[str] = None,
) -> matplotlib.animation.FuncAnimation:
"""Create an animation from a sequence of environment states.
"""Creates an animated gif of the board based on the sequence of states.

Args:
states: sequence of environment states corresponding to consecutive timesteps.
interval: delay between frames in milliseconds, default to 200.
states: a list of `State` objects representing the sequence of states.
interval: the delay between frames in milliseconds, defaults to 200.
save_path: the path where the animation file should be saved. If it is None, the plot
will not be saved.

Returns:
Animation object that can be saved as a GIF, MP4, or rendered with HTML.
animation.FuncAnimation: the animation object that was created.
"""
fig, ax = self._get_fig_ax()
plt.tight_layout()
plt.close(fig)

def make_frame(state_index: int) -> None:
state = states[state_index]
self._draw(ax, state)

# Create the animation object.
self._animation = matplotlib.animation.FuncAnimation(
fig,
make_frame,
frames=len(states),
interval=interval,
return self._viewer.animate(
states=states, interval=interval, save_path=save_path
)

# Save the animation as a GIF.
if save_path:
self._animation.save(save_path)

return self._animation

def close(self) -> None:
"""Perform any necessary cleanup.

Environments will automatically :meth:`close()` themselves when
garbage collected or when the program exits.
"""
plt.close(self.figure_name)

def _get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]:
exists = plt.fignum_exists(self.figure_name)
if exists:
fig = plt.figure(self.figure_name)
ax = fig.get_axes()[0]
else:
fig = plt.figure(self.figure_name, figsize=self.figure_size)
plt.suptitle(self.figure_name)
plt.tight_layout()
if not plt.isinteractive():
fig.show()
ax = fig.add_subplot()
return fig, ax

def _draw(self, ax: plt.Axes, state: State) -> None:
ax.clear()
ax.set_xticks(jnp.arange(-0.5, self.num_cols - 1, 1))
ax.set_yticks(jnp.arange(-0.5, self.num_rows - 1, 1))
ax.tick_params(
top=False,
bottom=False,
left=False,
right=False,
labelleft=False,
labelbottom=False,
labeltop=False,
labelright=False,
)
background = jnp.ones_like(state.board)
for i in range(self.num_rows):
for j in range(self.num_cols):
background = self._render_grid_square(
state=state, ax=ax, i=i, j=j, background=background
)
ax.imshow(background, cmap="gray", vmin=0, vmax=1)
ax.grid(color="black", linestyle="-", linewidth=2)

def _render_grid_square(
self, state: State, ax: plt.Axes, i: int, j: int, background: chex.Array
) -> chex.Array:
board_value = state.board[i, j]
if board_value != UNEXPLORED_ID:
if explored_mine(state=state, action=jnp.array([i, j], dtype=jnp.int32)):
background = background.at[i, j].set(0)
else:
ax.text(
j,
i,
str(board_value),
color=self.cmap[board_value],
ha="center",
va="center",
fontsize="xx-large",
)
return background

def _update_display(self, fig: plt.Figure) -> None:
if plt.isinteractive():
# Required to update render when using Jupyter Notebook.
fig.canvas.draw()
if jumanji.environments.is_colab():
plt.show(self.figure_name)
else:
# Required to update render when not using Jupyter Notebook.
fig.canvas.draw_idle()
fig.canvas.flush_events()

def _clear_display(self) -> None:
if jumanji.environments.is_colab():
import IPython.display

IPython.display.clear_output(True)
self._viewer.close()
Loading