From ccda7768c46b9fa5d34079d0be2d727ca1e278fa Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Mon, 27 Mar 2023 10:44:11 +0100 Subject: [PATCH 01/58] Changes from patch --- jumanji/environments/logic/rubiks_cube/env.py | 248 ++++-------------- .../logic/rubiks_cube/env_test.py | 70 +---- .../environments/logic/rubiks_cube/utils.py | 184 ++++++++++++- .../logic/rubiks_cube/utils_test.py | 97 ++++++- 4 files changed, 318 insertions(+), 281 deletions(-) diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index 7a80c81e4..fbea3f2d0 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -12,29 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Sequence, Tuple +from typing import Optional, Sequence, Tuple import chex import jax import jax.numpy as jnp -import matplotlib import matplotlib.animation -import matplotlib.pyplot as plt -import jumanji.environments from jumanji import specs from jumanji.env import Environment -from jumanji.environments.logic.rubiks_cube.constants import ( - DEFAULT_STICKER_COLORS, - CubeMovementAmount, - Face, +from jumanji.environments.logic.rubiks_cube.constants import Face +from jumanji.environments.logic.rubiks_cube.env_viewer import ( + DefaultRubiksCubeViewer, + RubiksCubeViewer, +) +from jumanji.environments.logic.rubiks_cube.generator import ( + Generator, + ScramblingGenerator, ) from jumanji.environments.logic.rubiks_cube.reward import RewardFn, SparseRewardFn -from jumanji.environments.logic.rubiks_cube.types import Cube, Observation, State +from jumanji.environments.logic.rubiks_cube.types import Observation, State from jumanji.environments.logic.rubiks_cube.utils import ( - generate_all_moves, + flatten_action, is_solved, - make_solved_cube, + rotate_cube, ) from jumanji.types import TimeStep, restart, termination, transition @@ -92,7 +93,8 @@ def __init__( time_limit: int = 200, num_scrambles_on_reset: int = 100, reward_fn: Optional[RewardFn] = None, - sticker_colors: Optional[list] = None, + env_viewer: Optional[RubiksCubeViewer] = None, + generator: Optional[Generator] = None, ): """Instantiate a `RubiksCube` environment. @@ -102,10 +104,15 @@ def __init__( num_scrambles_on_reset: the number of scrambles done from a solved Rubik's Cube in the generation of a random instance. The lower, the closer to a solved cube the reset state is. Defaults to 100. + Note that this argument will be ignored if a custom generator is passed. reward_fn: `RewardFn` whose `__call__` method computes the reward given the new state. Implemented options are [`SparseRewardFn`]. Defaults to `SparseRewardFn`. - sticker_colors: colors used in rendering the faces of the rubiks cube. - Defaults to `DEFAULT_STICKER_COLORS`. + env_viewer: RubiksCubeViewer to support rendering and animation methods. + Implemented options are [`DefaultRubiksCubeViewer`]. + Defaults to `DefaultRubiksCubeViewer`. + generator: Generator to generate problem instances on environment reset. + Implemented options are [`ScramblingGenerator`]. + Defaults to `ScramblingGenerator`. """ if cube_size < 2: raise ValueError( @@ -125,13 +132,12 @@ def __init__( self.time_limit = time_limit self.num_scrambles_on_reset = num_scrambles_on_reset self.reward_function = reward_fn or SparseRewardFn() - sticker_colors = sticker_colors or DEFAULT_STICKER_COLORS - self.sticker_colors_cmap = matplotlib.colors.ListedColormap(sticker_colors) - self.num_actions = len(Face) * (cube_size // 2) * len(CubeMovementAmount) - self.all_moves = generate_all_moves(cube_size=cube_size) - - self.figure_name = f"{cube_size}x{cube_size}x{cube_size} Rubik's Cube" - self.figure_size = (6.0, 6.0) + self._env_viewer = env_viewer or DefaultRubiksCubeViewer(cube_size=cube_size) + self._generator = generator or ScramblingGenerator( + cube_size=cube_size, + num_scrambles_on_reset=num_scrambles_on_reset, + time_limit=self.time_limit, + ) def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: """Resets the environment. @@ -144,30 +150,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep: `TimeStep` corresponding to the first timestep returned by the environment. """ - key, scramble_key = jax.random.split(key) - flat_actions_in_scramble = jax.random.randint( - scramble_key, - minval=0, - maxval=self.num_actions, - shape=(self.num_scrambles_on_reset,), - dtype=jnp.int32, - ) - cube = self._scramble_solved_cube( - flat_actions_in_scramble=flat_actions_in_scramble - ) - action_history = jnp.zeros( - shape=(self.num_scrambles_on_reset + self.time_limit, 3), dtype=jnp.int32 - ) - action_history = action_history.at[: self.num_scrambles_on_reset].set( - self._unflatten_action(flat_actions_in_scramble).transpose() - ) - step_count = jnp.array(0, jnp.int32) - state = State( - cube=cube, - step_count=step_count, - key=key, - action_history=action_history, - ) + state = self._generator(key) observation = self._state_to_observation(state=state) timestep = restart(observation=observation) return state, timestep @@ -186,8 +169,13 @@ def step( next_state: `State` corresponding to the next state of the environment. next_timestep: `TimeStep` corresponding to the timestep returned by the environment. """ - flat_action = self._flatten_action(action) - cube = self._rotate_cube(cube=state.cube, flat_action=flat_action) + flattened_action = flatten_action( + unflattened_action=action, cube_size=self.cube_size + ) + cube = rotate_cube( + cube=state.cube, + flattened_action=flattened_action, + ) action_history = state.action_history.at[ self.num_scrambles_on_reset + state.step_count ].set(action) @@ -258,76 +246,18 @@ def action_spec(self) -> specs.MultiDiscreteArray: dtype=jnp.int32, ) - def _unflatten_action(self, action: chex.Array) -> chex.Array: - """Turn a flat action (index into the sequence of all moves) into a tuple: - - face (0-5). This indicates the face on which the layer will turn. - - depth (0-cube_size//2). This indicates how many layers down from the face - the turn will take place. - - amount (0-2). This indicates the amount of turning (see below). - - Convention: - - 0 = up face - - 1 = front face - - 2 = right face - - 3 = back face - - 4 = left face - - 5 = down face - All read in reading order when looking directly at a face. - - To look directly at the faces: - - UP: LEFT face on the left and BACK face pointing up - - FRONT: LEFT face on the left and UP face pointing up - - RIGHT: FRONT face on the left and UP face pointing up - - BACK: RIGHT face on the left and UP face pointing up - - LEFT: BACK face on the left and UP face pointing up - - DOWN: LEFT face on the left and FRONT face pointing up - - Turning amounts are when looking directly at a face: - - 0 = clockwise turn - - 1 = anticlockwise turn - - 2 = half turn - """ - face_and_depth, amount = jnp.divmod(action, len(CubeMovementAmount)) - face, depth = jnp.divmod(face_and_depth, self.cube_size // 2) - return jnp.stack([face, depth, amount], axis=0) - - def _flatten_action(self, action: chex.Array) -> chex.Array: - """Inverse of the `_flatten_action` method.""" - face, depth, amount = action - return ( - face * len(CubeMovementAmount) * (self.cube_size // 2) - + depth * len(CubeMovementAmount) - + amount - ) - - def _rotate_cube(self, cube: Cube, flat_action: chex.Array) -> Cube: - """Apply a flattened action (index into the sequence of all moves) to a cube.""" - moved_cube = jax.lax.switch(flat_action, self.all_moves, cube) - return moved_cube - - def _scramble_solved_cube(self, flat_actions_in_scramble: chex.Array) -> Cube: - """Return a scrambled cube according to a given sequence of flat actions.""" - cube = make_solved_cube(cube_size=self.cube_size) - cube, _ = jax.lax.scan( - lambda *args: (self._rotate_cube(*args), None), - cube, - flat_actions_in_scramble, - ) - return cube - def _state_to_observation(self, state: State) -> Observation: return Observation(cube=state.cube, step_count=state.step_count) - def render(self, state: State) -> None: - """Render frames of the environment for a given state using matplotlib. + def render(self, state: State, save_path: Optional[str] = None) -> None: + """Renders the current state of the game board. Args: - state: `State` object corresponding to the new state of the environment. + state: is the current game state to be rendered. + save_path: the path where the image should be saved. If it is None, the plot + will not be stored. """ - self._clear_display() - fig, ax = self._get_fig_ax() - self._draw(ax, state) - self._update_display(fig) + self._env_viewer.render(state=state) def animate( self, @@ -335,105 +265,25 @@ def animate( interval: int = 200, save_path: Optional[str] = None, ) -> matplotlib.animation.FuncAnimation: - """Create an animation from a sequence of environment states. + """Creates an animated gif of the 2048 game board based on the sequence of game states. Args: - states: sequence of environment states corresponding to consecutive timesteps. - interval: delay between frames in milliseconds, default to 200. + states: is a list of `State` objects representing the sequence of game states. + interval: the delay between frames in milliseconds, default to 200. save_path: the path where the animation file should be saved. If it is None, the plot - will not be saved. + will not be stored. Returns: - Animation that can be saved as a GIF, MP4, or rendered with HTML. + animation.FuncAnimation: the animation object that was created. """ - fig, ax = plt.subplots(nrows=3, ncols=2, figsize=self.figure_size) - fig.suptitle(self.figure_name) - plt.tight_layout() - ax = ax.flatten() - plt.close(fig) - - def make_frame(state_index: int) -> None: - state = states[state_index] - self._draw(ax, state) - - # Create the animation object. - self._animation = matplotlib.animation.FuncAnimation( - fig, - make_frame, - frames=len(states), - interval=interval, + return self._env_viewer.animate( + states=states, interval=interval, save_path=save_path ) - # Save the animation as a gif. - if save_path: - self._animation.save(save_path) - - return self._animation - def close(self) -> None: """Perform any necessary cleanup. Environments will automatically :meth:`close()` themselves when garbage collected or when the program exits. """ - plt.close(self.figure_name) - - def _get_fig_ax(self) -> Tuple[plt.Figure, List[plt.Axes]]: - exists = plt.fignum_exists(self.figure_name) - if exists: - fig = plt.figure(self.figure_name) - ax = fig.get_axes() - else: - fig, ax = plt.subplots( - nrows=3, ncols=2, figsize=self.figure_size, num=self.figure_name - ) - fig.suptitle(self.figure_name) - ax = ax.flatten() - plt.tight_layout() - plt.axis("off") - if not plt.isinteractive(): - fig.show() - return fig, ax - - def _draw(self, ax: List[plt.Axes], state: State) -> None: - i = 0 - for face in Face: - ax[i].clear() - ax[i].set_title(label=f"{face}") - ax[i].set_xticks(jnp.arange(-0.5, self.cube_size - 1, 1)) - ax[i].set_yticks(jnp.arange(-0.5, self.cube_size - 1, 1)) - ax[i].tick_params( - top=False, - bottom=False, - left=False, - right=False, - labelleft=False, - labelbottom=False, - labeltop=False, - labelright=False, - ) - ax[i].imshow( - state.cube[i], - cmap=self.sticker_colors_cmap, - vmin=0, - vmax=len(Face) - 1, - ) - ax[i].grid(color="black", linestyle="-", linewidth=2) - i += 1 - - def _update_display(self, fig: plt.Figure) -> None: - if plt.isinteractive(): - # Required to update render when using Jupyter Notebook. - fig.canvas.draw() - if jumanji.environments.is_colab(): - plt.show(self.figure_name) - else: - # Required to update render when not using Jupyter Notebook. - fig.canvas.draw_idle() - fig.canvas.flush_events() - - def _clear_display(self) -> None: - if jumanji.environments.is_colab(): - import IPython.display - - IPython.display.clear_output(True) + self._env_viewer.close() diff --git a/jumanji/environments/logic/rubiks_cube/env_test.py b/jumanji/environments/logic/rubiks_cube/env_test.py index 1abfa7e43..7b8aa08d5 100644 --- a/jumanji/environments/logic/rubiks_cube/env_test.py +++ b/jumanji/environments/logic/rubiks_cube/env_test.py @@ -28,72 +28,6 @@ from jumanji.types import TimeStep -@pytest.mark.parametrize("cube_size", [2, 3, 4, 5]) -def test_flatten_action(cube_size: int) -> None: - """Test that flattening and unflattening actions are inverse to each other.""" - env = RubiksCube(cube_size=cube_size) - flat_actions = jnp.arange( - len(Face) * (cube_size // 2) * len(CubeMovementAmount), dtype=jnp.int32 - ) - faces = jnp.arange(len(Face), dtype=jnp.int32) - depths = jnp.arange(cube_size // 2, dtype=jnp.int32) - amounts = jnp.arange(len(CubeMovementAmount), dtype=jnp.int32) - unflat_actions = jnp.stack( - [ - jnp.repeat(faces, len(CubeMovementAmount) * (cube_size // 2)), - jnp.concatenate( - [jnp.repeat(depths, len(CubeMovementAmount)) for _ in Face] - ), - jnp.concatenate([amounts for _ in range(len(Face) * (cube_size // 2))]), - ] - ) - assert jnp.array_equal(unflat_actions, env._unflatten_action(flat_actions)) - assert jnp.array_equal(flat_actions, env._flatten_action(unflat_actions)) - - -def test_scramble_on_reset( - rubiks_cube: RubiksCube, expected_scramble_result: chex.Array -) -> None: - """Test that the environment reset is performing correctly when given a particular scramble - (chosen manually). - """ - amount_to_index = { - CubeMovementAmount.CLOCKWISE: 0, - CubeMovementAmount.ANTI_CLOCKWISE: 1, - CubeMovementAmount.HALF_TURN: 2, - } - unflattened_sequence = jnp.array( - [ - [Face.UP.value, 0, amount_to_index[CubeMovementAmount.CLOCKWISE]], - [Face.LEFT.value, 0, amount_to_index[CubeMovementAmount.HALF_TURN]], - [Face.DOWN.value, 0, amount_to_index[CubeMovementAmount.ANTI_CLOCKWISE]], - [Face.UP.value, 0, amount_to_index[CubeMovementAmount.HALF_TURN]], - [Face.BACK.value, 0, amount_to_index[CubeMovementAmount.ANTI_CLOCKWISE]], - [Face.RIGHT.value, 0, amount_to_index[CubeMovementAmount.CLOCKWISE]], - [Face.FRONT.value, 0, amount_to_index[CubeMovementAmount.CLOCKWISE]], - [Face.RIGHT.value, 0, amount_to_index[CubeMovementAmount.ANTI_CLOCKWISE]], - [Face.LEFT.value, 0, amount_to_index[CubeMovementAmount.ANTI_CLOCKWISE]], - [Face.BACK.value, 0, amount_to_index[CubeMovementAmount.HALF_TURN]], - [Face.FRONT.value, 0, amount_to_index[CubeMovementAmount.ANTI_CLOCKWISE]], - [Face.UP.value, 0, amount_to_index[CubeMovementAmount.CLOCKWISE]], - [Face.DOWN.value, 0, amount_to_index[CubeMovementAmount.CLOCKWISE]], - ], - dtype=jnp.int32, - ) - flat_sequence = jnp.array( - [0, 14, 16, 2, 10, 6, 3, 7, 13, 11, 4, 0, 15], dtype=jnp.int32 - ) - assert jnp.array_equal( - unflattened_sequence.transpose(), - rubiks_cube._unflatten_action(action=flat_sequence), - ) - assert jnp.array_equal( - flat_sequence, jax.vmap(rubiks_cube._flatten_action)(unflattened_sequence) - ) - cube = rubiks_cube._scramble_solved_cube(flat_actions_in_scramble=flat_sequence) - assert jnp.array_equal(expected_scramble_result, cube) - - def test_rubiks_cube__reset(rubiks_cube: RubiksCube) -> None: """Validates the jitted reset of the environment.""" chex.clear_trace_counter() @@ -165,7 +99,7 @@ def test_rubiks_cube__does_not_smoke(cube_size: int) -> None: def test_rubiks_cube__render( monkeypatch: pytest.MonkeyPatch, rubiks_cube: RubiksCube ) -> None: - """Check that the render method builds the figure but does not display it.""" + """Test that the render method builds the figure (but does not display it).""" monkeypatch.setattr(plt, "show", lambda fig: None) state, timestep = rubiks_cube.reset(jax.random.PRNGKey(0)) rubiks_cube.render(state) @@ -196,7 +130,7 @@ def test_rubiks_cube__done(time_limit: int) -> None: def test_rubiks_cube__animate( rubiks_cube: RubiksCube, mocker: pytest_mock.MockerFixture ) -> None: - """Check that the `animate` method creates the animation correctly.""" + """Test that the `animate` method creates the animation correctly (but does not display it).""" states = mocker.MagicMock() animation = rubiks_cube.animate(states) assert isinstance(animation, matplotlib.animation.Animation) diff --git a/jumanji/environments/logic/rubiks_cube/utils.py b/jumanji/environments/logic/rubiks_cube/utils.py index e53defbeb..e8a16bbce 100644 --- a/jumanji/environments/logic/rubiks_cube/utils.py +++ b/jumanji/environments/logic/rubiks_cube/utils.py @@ -15,6 +15,7 @@ from typing import Callable, List import chex +import jax from jax import numpy as jnp from jumanji.environments.logic.rubiks_cube.constants import CubeMovementAmount, Face @@ -39,19 +40,32 @@ # Turn amounts (eg clockwise) are when looking directly at the face -def make_solved_cube(cube_size: int) -> chex.Array: +def make_solved_cube(cube_size: int) -> Cube: + """Make a solved cube of a given size. + Args: + cube_size: the size of the cube to generate. + Returns: + A solved cube, ie with all faces a uniform id (sticker color). + """ return jnp.stack( [face.value * jnp.ones((cube_size, cube_size), dtype=jnp.int8) for face in Face] ) def is_solved(cube: Cube) -> chex.Array: + """Check if a cube is solved + Args: + cube: the cube to check. + Returns: + Whether or not the cube is solved (all faces have a unique id). + """ max_sticker_by_side = jnp.max(cube, axis=(-1, -2)) min_sticker_by_side = jnp.min(cube, axis=(-1, -2)) return jnp.array_equal(max_sticker_by_side, min_sticker_by_side) def sparse_reward_function(state: State) -> chex.Array: + """A sparse reward function: +1 if the cube is solved, otherwise 0""" solved = is_solved(state.cube) return jnp.array(solved, float) @@ -98,6 +112,14 @@ def do_rotation( def generate_up_move(amount: CubeMovementAmount, depth: int) -> Callable[[Cube], Cube]: + """Generate the move corresponding to turning the up face. + Args: + amount: how much to turn the face by. + depth: the number of layers into the cube where the move is performed. + Returns: + A callable that performs the specified up move. + """ + def up_move_function(cube: Cube) -> Cube: cube_size = cube.shape[-1] adjacent_faces = jnp.array( @@ -135,6 +157,14 @@ def up_move_function(cube: Cube) -> Cube: def generate_front_move( amount: CubeMovementAmount, depth: int ) -> Callable[[Cube], Cube]: + """Generate the move corresponding to turning the front face. + Args: + amount: how much to turn the face by. + depth: the number of layers into the cube where the move is performed. + Returns: + A callable that performs the specified front move. + """ + def front_move_function(cube: Cube) -> Cube: cube_size = cube.shape[-1] adjacent_faces = jnp.array( @@ -172,6 +202,14 @@ def front_move_function(cube: Cube) -> Cube: def generate_right_move( amount: CubeMovementAmount, depth: int ) -> Callable[[Cube], Cube]: + """Generate the move corresponding to turning the right face. + Args: + amount: how much to turn the face by. + depth: the number of layers into the cube where the move is performed. + Returns: + A callable that performs the specified right move. + """ + def right_move_function(cube: Cube) -> Cube: cube_size = cube.shape[-1] adjacent_faces = jnp.array( @@ -209,6 +247,14 @@ def right_move_function(cube: Cube) -> Cube: def generate_back_move( amount: CubeMovementAmount, depth: int ) -> Callable[[Cube], Cube]: + """Generate the move corresponding to turning the back face. + Args: + amount: how much to turn the face by. + depth: the number of layers into the cube where the move is performed. + Returns: + A callable that performs the specified back move. + """ + def back_move_function(cube: Cube) -> Cube: cube_size = cube.shape[-1] adjacent_faces = jnp.array( @@ -246,6 +292,14 @@ def back_move_function(cube: Cube) -> Cube: def generate_left_move( amount: CubeMovementAmount, depth: int ) -> Callable[[Cube], Cube]: + """Generate the move corresponding to turning the left face. + Args: + amount: how much to turn the face by. + depth: the number of layers into the cube where the move is performed. + Returns: + A callable that performs the specified left move. + """ + def left_move_function(cube: Cube) -> Cube: cube_size = cube.shape[-1] adjacent_faces = jnp.array( @@ -283,6 +337,14 @@ def left_move_function(cube: Cube) -> Cube: def generate_down_move( amount: CubeMovementAmount, depth: int ) -> Callable[[Cube], Cube]: + """Generate the move corresponding to turning the down face. + Args: + amount: how much to turn the face by. + depth: the number of layers into the cube where the move is performed. + Returns: + A callable that performs the specified down move. + """ + def down_move_function(cube: Cube) -> Cube: cube_size = cube.shape[-1] adjacent_faces = jnp.array( @@ -318,6 +380,7 @@ def down_move_function(cube: Cube) -> Cube: def generate_all_moves(cube_size: int) -> List[Callable[[Cube], Cube]]: + """Generate a list of all moves for the given cube size.""" return [ f(amount, depth) for f in [ @@ -331,3 +394,122 @@ def generate_all_moves(cube_size: int) -> List[Callable[[Cube], Cube]]: for depth in range(cube_size // 2) for amount in CubeMovementAmount ] + + +def unflatten_action(flattened_action: chex.Array, cube_size: int) -> chex.Array: + """Translate from the flat action representation to the unflattened representation. + Args: + flattened_action: index into the sequence of all moves. + cube_size: the size of the cube in question. + Returns: + Unflattened action, ie a tuple: + - face (0-5). This indicates the face on which the layer will turn. + - depth (0-cube_size//2). This indicates how many layers down from the face + the turn will take place. + - amount (0-2). This indicates the amount of turning (see below). + + Convention: + - 0 = up face + - 1 = front face + - 2 = right face + - 3 = back face + - 4 = left face + - 5 = down face + All read in reading order when looking directly at a face. + + To look directly at the faces: + - UP: LEFT face on the left and BACK face pointing up + - FRONT: LEFT face on the left and UP face pointing up + - RIGHT: FRONT face on the left and UP face pointing up + - BACK: RIGHT face on the left and UP face pointing up + - LEFT: BACK face on the left and UP face pointing up + - DOWN: LEFT face on the left and FRONT face pointing up + + Turning amounts are when looking directly at a face: + - 0 = clockwise turn + - 1 = anticlockwise turn + - 2 = half turn + """ + face_and_depth, amount = jnp.divmod(flattened_action, len(CubeMovementAmount)) + face, depth = jnp.divmod(face_and_depth, cube_size // 2) + return jnp.stack([face, depth, amount], axis=0) + + +def flatten_action(unflattened_action: chex.Array, cube_size: int) -> chex.Array: + """Inverse of the `unflatten_action` method. + Args: + unflattened_action: flattened action representation, a tuple: + - face (0-5). This indicates the face on which the layer will turn. + - depth (0-cube_size//2). This indicates how many layers down from the face + the turn will take place. + - amount (0-2). This indicates the amount of turning. + cube_size: the size of the cube in question. + Returns: + The flattened action representation, ie an index into the sequence of all moves. + + Convention: + - 0 = up face + - 1 = front face + - 2 = right face + - 3 = back face + - 4 = left face + - 5 = down face + All read in reading order when looking directly at a face. + + To look directly at the faces: + - UP: LEFT face on the left and BACK face pointing up + - FRONT: LEFT face on the left and UP face pointing up + - RIGHT: FRONT face on the left and UP face pointing up + - BACK: RIGHT face on the left and UP face pointing up + - LEFT: BACK face on the left and UP face pointing up + - DOWN: LEFT face on the left and FRONT face pointing up + + Turning amounts are when looking directly at a face: + - 0 = clockwise turn + - 1 = anticlockwise turn + - 2 = half turn + """ + face, depth, amount = unflattened_action + return ( + face * len(CubeMovementAmount) * (cube_size // 2) + + depth * len(CubeMovementAmount) + + amount + ) + + +def rotate_cube(cube: Cube, flattened_action: chex.Array) -> Cube: + """Apply a flattened action (index into the sequence of all moves) to a cube. + Args: + cube: the cube on which to perform the move. + flattened_action: the action to perform, in the flattened representation. + Returns: + The rotated cube. + """ + all_moves = generate_all_moves(cube_size=cube.shape[-1]) + moved_cube = jax.lax.switch(flattened_action, all_moves, cube) + return moved_cube + + +def scramble_solved_cube( + flattened_actions_in_scramble: chex.Array, + cube_size: int, +) -> Cube: + """Return a scrambled cube according to a given sequence of flat actions. + Args: + flattened_actions_in_scramble: the sequence of moves to perform, + in their flat representation. + cube_size: the size of the cube to return. + Returns: + The scrambled cube. + """ + + def rotate_cube_fn(cube: Cube, flattened_action: chex.Array) -> Cube: + return rotate_cube(cube=cube, flattened_action=flattened_action) + + cube = make_solved_cube(cube_size=cube_size) + cube, _ = jax.lax.scan( + lambda *args: (rotate_cube_fn(*args), None), + cube, + flattened_actions_in_scramble, + ) + return cube diff --git a/jumanji/environments/logic/rubiks_cube/utils_test.py b/jumanji/environments/logic/rubiks_cube/utils_test.py index bbd6e9517..bcb4f0c16 100644 --- a/jumanji/environments/logic/rubiks_cube/utils_test.py +++ b/jumanji/environments/logic/rubiks_cube/utils_test.py @@ -24,6 +24,7 @@ from jumanji.environments.logic.rubiks_cube.types import Cube, State from jumanji.environments.logic.rubiks_cube.utils import ( CubeMovementAmount, + flatten_action, generate_all_moves, generate_back_move, generate_down_move, @@ -32,6 +33,8 @@ generate_right_move, generate_up_move, make_solved_cube, + scramble_solved_cube, + unflatten_action, ) # 3x3x3 moves, for testing purposes @@ -56,7 +59,7 @@ def is_face_turn(cube_size: int) -> List[bool]: - """Says whether each action from generate_all_moves is a face turn (ie with depth 0)""" + """Says whether each action from generate_all_moves is a face turn (ie with depth 0).""" per_face_result_true = [True] * len(CubeMovementAmount) per_face_result_false = [False] * (len(CubeMovementAmount) * ((cube_size // 2) - 1)) per_face_result = per_face_result_true + per_face_result_false @@ -64,6 +67,34 @@ def is_face_turn(cube_size: int) -> List[bool]: return [item for r in result for item in r] +@pytest.mark.parametrize("cube_size", [2, 3, 4, 5]) +def test_flatten_and_unflatten_action(cube_size: int) -> None: + """Test that flattening and unflattening actions are inverse to each other.""" + flattened_actions = jnp.arange( + len(Face) * (cube_size // 2) * len(CubeMovementAmount), dtype=jnp.int32 + ) + faces = jnp.arange(len(Face), dtype=jnp.int32) + depths = jnp.arange(cube_size // 2, dtype=jnp.int32) + amounts = jnp.arange(len(CubeMovementAmount), dtype=jnp.int32) + unflattened_actions = jnp.stack( + [ + jnp.repeat(faces, len(CubeMovementAmount) * (cube_size // 2)), + jnp.concatenate( + [jnp.repeat(depths, len(CubeMovementAmount)) for _ in Face] + ), + jnp.concatenate([amounts for _ in range(len(Face) * (cube_size // 2))]), + ] + ) + assert jnp.array_equal( + unflattened_actions, + unflatten_action(flattened_action=flattened_actions, cube_size=cube_size), + ) + assert jnp.array_equal( + flattened_actions, + flatten_action(unflattened_action=unflattened_actions, cube_size=cube_size), + ) + + @pytest.mark.parametrize( "move, inverse_move", [ @@ -80,7 +111,7 @@ def test_inverses( move: Callable[[Cube], Cube], inverse_move: Callable[[Cube], Cube], ) -> None: - """Test that applying a move followed by its inverse leads back to the original""" + """Test that applying a move followed by its inverse leads back to the original.""" cube = move(differently_stickered_cube) cube = inverse_move(cube) assert jnp.array_equal(cube, differently_stickered_cube) @@ -102,7 +133,7 @@ def test_half_turns( move: Callable[[Cube], Cube], half_turn_move: Callable[[Cube], Cube], ) -> None: - """Test that 2 applications of a move followed by its half turn leads back to the original""" + """Test that 2 applications of a move followed by its half turn leads back to the original.""" cube = move(differently_stickered_cube) cube = move(cube) cube = half_turn_move(cube) @@ -112,17 +143,17 @@ def test_half_turns( def test_solved_reward( solved_cube: chex.Array, differently_stickered_cube: chex.Array ) -> None: - """Test that the cube fixtures have the expected rewards""" + """Test that the cube fixtures have the expected rewards.""" solved_state = State( cube=solved_cube, step_count=jnp.array(0, jnp.int32), - action_history=None, + action_history=jnp.array(0, jnp.int32), key=jax.random.PRNGKey(0), ) differently_stickered_state = State( cube=differently_stickered_cube, step_count=jnp.array(0, jnp.int32), - action_history=None, + action_history=jnp.array(0, jnp.int32), key=jax.random.PRNGKey(0), ) assert jnp.equal(SparseRewardFn()(solved_state), 1.0) @@ -142,12 +173,12 @@ def test_moves_nontrivial( move: Callable[[Cube], Cube], move_is_face_turn: bool, ) -> None: - """Test that all moves leave the cube in a non-solved state""" + """Test that all moves leave the cube in a non-solved state.""" move_solved_cube = move(solved_cube) move_solved_state = State( cube=move_solved_cube, step_count=jnp.array(0, jnp.int32), - action_history=None, + action_history=jnp.array(0, jnp.int32), key=jax.random.PRNGKey(0), ) assert jnp.equal(SparseRewardFn()(move_solved_state), 0.0) @@ -196,8 +227,8 @@ def test_commuting_moves( first_move: Callable[[Cube], Cube], second_move: Callable[[Cube], Cube], ) -> None: - """Check that moves that should commute, do in fact commute - (on a differently stickered cube)""" + """Test that moves that should commute, do in fact commute + (on a differently stickered cube).""" first_then_second = second_move(first_move(differently_stickered_cube)) second_then_first = first_move(second_move(differently_stickered_cube)) assert jnp.array_equal(first_then_second, second_then_first) @@ -225,7 +256,7 @@ def test_non_commuting_moves( first_move: Callable[[Cube], Cube], second_move: Callable[[Cube], Cube], ) -> None: - """Check that moves that should not commute, do not (on a solved cube)""" + """Test that moves that should not commute, do not (on a solved cube).""" first_then_second = second_move(first_move(solved_cube)) second_then_first = first_move(second_move(solved_cube)) assert ~jnp.array_equal(first_then_second, second_then_first) @@ -240,7 +271,7 @@ def test_non_commuting_moves( ], ) def test_checkerboard(cube_size: int, indices: List[int]) -> None: - """Check that the checkerboard scramble gives the expected result""" + """Test that the checkerboard scramble gives the expected result.""" cube = make_solved_cube(cube_size=cube_size) all_moves = generate_all_moves(cube_size=cube_size) for index in indices: @@ -261,7 +292,7 @@ def test_manual_scramble( solved_cube: chex.Array, expected_scramble_result: chex.Array ) -> None: """Testing a particular scramble manually. - Scramble chosen to have all faces touched at least once""" + Scramble chosen to have all faces touched at least once.""" scramble = [ up_move, left_move_half_turn, @@ -277,6 +308,46 @@ def test_manual_scramble( up_move, down_move, ] + amount_to_index = { + CubeMovementAmount.CLOCKWISE: 0, + CubeMovementAmount.ANTI_CLOCKWISE: 1, + CubeMovementAmount.HALF_TURN: 2, + } + unflattened_sequence = jnp.array( + [ + [Face.UP.value, 0, amount_to_index[CubeMovementAmount.CLOCKWISE]], + [Face.LEFT.value, 0, amount_to_index[CubeMovementAmount.HALF_TURN]], + [Face.DOWN.value, 0, amount_to_index[CubeMovementAmount.ANTI_CLOCKWISE]], + [Face.UP.value, 0, amount_to_index[CubeMovementAmount.HALF_TURN]], + [Face.BACK.value, 0, amount_to_index[CubeMovementAmount.ANTI_CLOCKWISE]], + [Face.RIGHT.value, 0, amount_to_index[CubeMovementAmount.CLOCKWISE]], + [Face.FRONT.value, 0, amount_to_index[CubeMovementAmount.CLOCKWISE]], + [Face.RIGHT.value, 0, amount_to_index[CubeMovementAmount.ANTI_CLOCKWISE]], + [Face.LEFT.value, 0, amount_to_index[CubeMovementAmount.ANTI_CLOCKWISE]], + [Face.BACK.value, 0, amount_to_index[CubeMovementAmount.HALF_TURN]], + [Face.FRONT.value, 0, amount_to_index[CubeMovementAmount.ANTI_CLOCKWISE]], + [Face.UP.value, 0, amount_to_index[CubeMovementAmount.CLOCKWISE]], + [Face.DOWN.value, 0, amount_to_index[CubeMovementAmount.CLOCKWISE]], + ], + dtype=jnp.int32, + ) + flattened_sequence = jnp.array( + [0, 14, 16, 2, 10, 6, 3, 7, 13, 11, 4, 0, 15], dtype=jnp.int32 + ) + assert jnp.array_equal( + unflattened_sequence.transpose(), + unflatten_action(flattened_action=flattened_sequence, cube_size=3), + ) + flatten_fn = lambda x: flatten_action(x, 3) + assert jnp.array_equal( + flattened_sequence, jax.vmap(flatten_fn)(unflattened_sequence) + ) + cube = scramble_solved_cube( + flattened_actions_in_scramble=flattened_sequence, + cube_size=3, + ) + assert jnp.array_equal(expected_scramble_result, cube) + cube = solved_cube for move in scramble: cube = move(cube) From cd85f1147078c430887d3326495efb103261a735 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Mon, 27 Mar 2023 10:45:31 +0100 Subject: [PATCH 02/58] Generator and viewer --- .../logic/rubiks_cube/env_viewer.py | 200 ++++++++++++++++++ .../logic/rubiks_cube/generator.py | 94 ++++++++ 2 files changed, 294 insertions(+) create mode 100644 jumanji/environments/logic/rubiks_cube/env_viewer.py create mode 100644 jumanji/environments/logic/rubiks_cube/generator.py diff --git a/jumanji/environments/logic/rubiks_cube/env_viewer.py b/jumanji/environments/logic/rubiks_cube/env_viewer.py new file mode 100644 index 000000000..f40ed196d --- /dev/null +++ b/jumanji/environments/logic/rubiks_cube/env_viewer.py @@ -0,0 +1,200 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Sequence, Tuple + +import jax.numpy as jnp +import matplotlib +from matplotlib import pyplot as plt + +import jumanji.environments +from jumanji.environments.logic.rubiks_cube.constants import ( + DEFAULT_STICKER_COLORS, + Face, +) +from jumanji.environments.logic.rubiks_cube.types import State + + +class RubiksCubeViewer: + """Abstract viewer class to support rendering and animation""" + + def render(self, state: State) -> None: + """Render frames of the environment for a given state using matplotlib. + + Args: + state: `State` object corresponding to the new state of the environment. + """ + raise NotImplementedError + + def animate( + self, + states: Sequence[State], + interval: int = 200, + save_path: Optional[str] = None, + ) -> matplotlib.animation.FuncAnimation: + """Create an animation from a sequence of environment states. + + Args: + states: sequence of environment states corresponding to consecutive timesteps. + interval: delay between frames in milliseconds, default to 200. + save_path: the path where the animation file should be saved. If it is None, the plot + will not be saved. + + Returns: + Animation that can be saved as a GIF, MP4, or rendered with HTML. + """ + raise NotImplementedError + + def close(self) -> None: + """Perform any necessary cleanup. + + Environments will automatically :meth:`close()` themselves when + garbage collected or when the program exits. + """ + raise NotImplementedError + + +class DefaultRubiksCubeViewer(RubiksCubeViewer): + def __init__(self, sticker_colors: Optional[list] = None, cube_size: int = 3): + """ + Args: + sticker_colors: colors used in rendering the faces of the Rubik's cube. + Defaults to `DEFAULT_STICKER_COLORS`. + cube_size: size of cube to view + """ + self.cube_size = cube_size + sticker_colors = sticker_colors or DEFAULT_STICKER_COLORS + self.sticker_colors_cmap = matplotlib.colors.ListedColormap(sticker_colors) + self.figure_name = f"{cube_size}x{cube_size}x{cube_size} Rubik's Cube" + self.figure_size = (6.0, 6.0) + + def render(self, state: State) -> None: + """Render frames of the environment for a given state using matplotlib. + + Args: + state: `State` object corresponding to the new state of the environment. + """ + self._clear_display() + fig, ax = self._get_fig_ax() + self._draw(ax, state) + self._update_display(fig) + + def animate( + self, + states: Sequence[State], + interval: int = 200, + save_path: Optional[str] = None, + ) -> matplotlib.animation.FuncAnimation: + """Create an animation from a sequence of environment states. + + Args: + states: sequence of environment states corresponding to consecutive timesteps. + interval: delay between frames in milliseconds, default to 200. + save_path: the path where the animation file should be saved. If it is None, the plot + will not be saved. + + Returns: + Animation that can be saved as a GIF, MP4, or rendered with HTML. + """ + fig, ax = plt.subplots(nrows=3, ncols=2, figsize=self.figure_size) + fig.suptitle(self.figure_name) + plt.tight_layout() + ax = ax.flatten() + plt.close(fig) + + def make_frame(state_index: int) -> None: + state = states[state_index] + self._draw(ax, state) + + # Create the animation object. + self._animation = matplotlib.animation.FuncAnimation( + fig, + make_frame, + frames=len(states), + interval=interval, + ) + + # Save the animation as a gif. + if save_path: + self._animation.save(save_path) + + return self._animation + + def _get_fig_ax(self) -> Tuple[plt.Figure, List[plt.Axes]]: + exists = plt.fignum_exists(self.figure_name) + if exists: + fig = plt.figure(self.figure_name) + ax = fig.get_axes() + else: + fig, ax = plt.subplots( + nrows=3, ncols=2, figsize=self.figure_size, num=self.figure_name + ) + fig.suptitle(self.figure_name) + ax = ax.flatten() + plt.tight_layout() + plt.axis("off") + if not plt.isinteractive(): + fig.show() + return fig, ax + + def _draw(self, ax: List[plt.Axes], state: State) -> None: + i = 0 + for face in Face: + ax[i].clear() + ax[i].set_title(label=f"{face}") + ax[i].set_xticks(jnp.arange(-0.5, self.cube_size - 1, 1)) + ax[i].set_yticks(jnp.arange(-0.5, self.cube_size - 1, 1)) + ax[i].tick_params( + top=False, + bottom=False, + left=False, + right=False, + labelleft=False, + labelbottom=False, + labeltop=False, + labelright=False, + ) + ax[i].imshow( + state.cube[i], + cmap=self.sticker_colors_cmap, + vmin=0, + vmax=len(Face) - 1, + ) + ax[i].grid(color="black", linestyle="-", linewidth=2) + i += 1 + + def _update_display(self, fig: plt.Figure) -> None: + if plt.isinteractive(): + # Required to update render when using Jupyter Notebook. + fig.canvas.draw() + if jumanji.environments.is_colab(): + plt.show(self.figure_name) + else: + # Required to update render when not using Jupyter Notebook. + fig.canvas.draw_idle() + fig.canvas.flush_events() + + def _clear_display(self) -> None: + if jumanji.environments.is_colab(): + import IPython.display + + IPython.display.clear_output(True) + + def close(self) -> None: + """Perform any necessary cleanup. + + Environments will automatically :meth:`close()` themselves when + garbage collected or when the program exits. + """ + plt.close(self.figure_name) diff --git a/jumanji/environments/logic/rubiks_cube/generator.py b/jumanji/environments/logic/rubiks_cube/generator.py new file mode 100644 index 000000000..bd801ab52 --- /dev/null +++ b/jumanji/environments/logic/rubiks_cube/generator.py @@ -0,0 +1,94 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc + +import chex +import jax +import jax.numpy as jnp + +from jumanji.environments.logic.rubiks_cube.constants import CubeMovementAmount, Face +from jumanji.environments.logic.rubiks_cube.types import State +from jumanji.environments.logic.rubiks_cube.utils import ( + scramble_solved_cube, + unflatten_action, +) + + +class Generator(abc.ABC): + """Base class for generators for the RubiksCube environment.""" + + def __init__(self, cube_size: int): + """Initialises a RubiksCube generator for resetting the environment. + + Args: + cube_size: the size of the cube to generate instances for. + """ + self.cube_size = cube_size + + @abc.abstractmethod + def __call__(self, key: chex.PRNGKey) -> State: + """Generates a `RubiksCube` state. + + Returns: + A `RubiksCube` state. + """ + + +class ScramblingGenerator(Generator): + """Generates instances by applying a given number of scrambles to a solved cube""" + + def __init__( + self, + cube_size: int, + num_scrambles_on_reset: int, + time_limit: int, + ): + self.num_scrambles_on_reset = num_scrambles_on_reset + self.time_limit = time_limit + super().__init__(cube_size=cube_size) + + def __call__(self, key: chex.PRNGKey) -> State: + """Generates a `RubiksCube` state by scrambling a solved cube a fixed number of times. + + Returns: + A `RubiksCube` state. + """ + key, scramble_key = jax.random.split(key) + flattened_actions_in_scramble = jax.random.randint( + scramble_key, + minval=0, + maxval=len(Face) * (self.cube_size // 2) * len(CubeMovementAmount), + shape=(self.num_scrambles_on_reset,), + dtype=jnp.int32, + ) + cube = scramble_solved_cube( + flattened_actions_in_scramble=flattened_actions_in_scramble, + cube_size=self.cube_size, + ) + action_history = jnp.zeros( + shape=(self.num_scrambles_on_reset + self.time_limit, 3), dtype=jnp.int32 + ) + action_history = action_history.at[: self.num_scrambles_on_reset].set( + unflatten_action( + flattened_action=flattened_actions_in_scramble, cube_size=self.cube_size + ).transpose() + ) + step_count = jnp.array(0, jnp.int32) + return State( + cube=cube, + step_count=step_count, + key=key, + action_history=action_history, + ) From f4d07f342ef612bf105724e7b86529c9f0409bb2 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Mon, 27 Mar 2023 11:32:52 +0100 Subject: [PATCH 03/58] Changes from patch --- .../logic/minesweeper/constants.py | 2 +- jumanji/environments/logic/minesweeper/env.py | 193 ++++------------ .../logic/minesweeper/env_viewer.py | 216 ++++++++++++++++++ .../logic/minesweeper/generator.py | 82 +++++++ 4 files changed, 342 insertions(+), 151 deletions(-) create mode 100644 jumanji/environments/logic/minesweeper/env_viewer.py create mode 100644 jumanji/environments/logic/minesweeper/generator.py diff --git a/jumanji/environments/logic/minesweeper/constants.py b/jumanji/environments/logic/minesweeper/constants.py index 21f4b4824..63f4580d7 100644 --- a/jumanji/environments/logic/minesweeper/constants.py +++ b/jumanji/environments/logic/minesweeper/constants.py @@ -17,7 +17,7 @@ PATCH_SIZE: int = 3 REVEALED_EMPTY_SQUARE_REWARD: float = 1.0 REVEALED_MINE_OR_INVALID_ACTION_REWARD: float = 0.0 -COLOUR_MAPPING: list = [ +DEFAULT_COLOR_MAPPING: list = [ "orange", "blue", "green", diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index b9ec6a0d5..d0325aa39 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -12,30 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Sequence, Tuple +from typing import Optional, Sequence, Tuple import chex import jax import jax.numpy as jnp import matplotlib.animation -import matplotlib.pyplot as plt -import jumanji.environments from jumanji import specs from jumanji.env import Environment -from jumanji.environments.logic.minesweeper.constants import ( - COLOUR_MAPPING, - PATCH_SIZE, - UNEXPLORED_ID, -) +from jumanji.environments.logic.minesweeper.constants import PATCH_SIZE, UNEXPLORED_ID from jumanji.environments.logic.minesweeper.done import DefaultDoneFn, DoneFn +from jumanji.environments.logic.minesweeper.env_viewer import ( + DefaultMinesweeperViewer, + MinesweeperViewer, +) +from jumanji.environments.logic.minesweeper.generator import ( + Generator, + SamplingGenerator, +) from jumanji.environments.logic.minesweeper.reward import DefaultRewardFn, RewardFn from jumanji.environments.logic.minesweeper.types import Observation, State -from jumanji.environments.logic.minesweeper.utils import ( - count_adjacent_mines, - create_flat_mine_locations, - explored_mine, -) +from jumanji.environments.logic.minesweeper.utils import count_adjacent_mines from jumanji.types import TimeStep, restart, termination, transition @@ -97,7 +95,8 @@ def __init__( num_mines: int = 10, reward_function: Optional[RewardFn] = None, done_function: Optional[DoneFn] = None, - color_mapping: Optional[List[str]] = None, + env_viewer: Optional[MinesweeperViewer] = None, + generator: Optional[Generator] = None, ): """Instantiate a `Minesweeper` environment. @@ -105,13 +104,19 @@ def __init__( num_rows: number of rows, i.e. height of the board. Defaults to 10. num_cols: number of columns, i.e. width of the board. Defaults to 10. num_mines: number of mines on the board. Defaults to 10. + Note that this argument will be ignored if a custom generator is passed. reward_function: `RewardFn` whose `__call__` method computes the reward of an environment transition based on the given current state and selected action. Implemented options are [`DefaultRewardFn`]. Defaults to `DefaultRewardFn`. done_function: `DoneFn` whose `__call__` method computes the done signal given the current state, action taken, and next state. Implemented options are [`DefaultDoneFn`]. Defaults to `DefaultDoneFn`. - color_mapping: colour map used for rendering. + env_viewer: MinesweeperViewer to support rendering and animation methods. + Implemented options are [`DefaultMinesweeperViewer`]. + Defaults to `DefaultMinesweeperViewer`. + generator: Generator to generate problem instances on environment reset. + Implemented options are [`SamplingGenerator`]. + Defaults to `SamplingGenerator`. """ if num_rows <= 1 or num_cols <= 1: raise ValueError( @@ -129,9 +134,12 @@ def __init__( self.reward_function = reward_function or DefaultRewardFn() self.done_function = done_function or DefaultDoneFn() - self.cmap = color_mapping if color_mapping else COLOUR_MAPPING - self.figure_name = f"{num_rows}x{num_cols} Minesweeper" - self.figure_size = (6.0, 6.0) + self._env_viewer = env_viewer or DefaultMinesweeperViewer( + num_rows=num_rows, num_cols=num_cols + ) + self._generator = generator or SamplingGenerator( + num_rows=num_rows, num_cols=num_cols, num_mines=num_mines + ) def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: """Resets the environment. @@ -144,25 +152,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep: `TimeStep` corresponding to the first timestep returned by the environment. """ - key, sample_key = jax.random.split(key) - board = jnp.full( - shape=(self.num_rows, self.num_cols), - fill_value=UNEXPLORED_ID, - dtype=jnp.int32, - ) - step_count = jnp.array(0, jnp.int32) - flat_mine_locations = create_flat_mine_locations( - key=sample_key, - num_rows=self.num_rows, - num_cols=self.num_cols, - num_mines=self.num_mines, - ) - state = State( - board=board, - step_count=step_count, - key=key, - flat_mine_locations=flat_mine_locations, - ) + state = self._generator(key) observation = self._state_to_observation(state=state) timestep = restart(observation=observation) return state, timestep @@ -272,17 +262,14 @@ def _state_to_observation(self, state: State) -> Observation: step_count=state.step_count, ) - def render(self, state: State) -> None: - """Render the given environment state using matplotlib. - + def render(self, state: State, save_path: Optional[str] = None) -> None: + """Renders the current state of the board. Args: - state: environment state to be rendered. - + state: the current state to be rendered. + save_path: the path where the image should be saved. If it is None, the plot + will not be stored. """ - self._clear_display() - fig, ax = self._get_fig_ax() - self._draw(ax, state) - self._update_display(fig) + self._env_viewer.render(state=state) def animate( self, @@ -290,116 +277,22 @@ def animate( interval: int = 200, save_path: Optional[str] = None, ) -> matplotlib.animation.FuncAnimation: - """Create an animation from a sequence of environment states. - - Args: - states: sequence of environment states corresponding to consecutive timesteps. - interval: delay between frames in milliseconds, default to 200. - save_path: the path where the animation file should be saved. If it is None, the plot - will not be saved. - + """Creates an animated gif of the board based on the sequence of states. + Args: + states: a list of `State` objects representing the sequence of states. + interval: the delay between frames in milliseconds, default to 200. + save_path: the path where the animation file should be saved. If it is None, the plot + will not be stored. Returns: - Animation object that can be saved as a GIF, MP4, or rendered with HTML. + animation.FuncAnimation: the animation object that was created. """ - fig, ax = self._get_fig_ax() - plt.tight_layout() - plt.close(fig) - - def make_frame(state_index: int) -> None: - state = states[state_index] - self._draw(ax, state) - - # Create the animation object. - self._animation = matplotlib.animation.FuncAnimation( - fig, - make_frame, - frames=len(states), - interval=interval, + return self._env_viewer.animate( + states=states, interval=interval, save_path=save_path ) - # Save the animation as a GIF. - if save_path: - self._animation.save(save_path) - - return self._animation - def close(self) -> None: """Perform any necessary cleanup. - Environments will automatically :meth:`close()` themselves when garbage collected or when the program exits. """ - plt.close(self.figure_name) - - def _get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]: - exists = plt.fignum_exists(self.figure_name) - if exists: - fig = plt.figure(self.figure_name) - ax = fig.get_axes()[0] - else: - fig = plt.figure(self.figure_name, figsize=self.figure_size) - plt.suptitle(self.figure_name) - plt.tight_layout() - if not plt.isinteractive(): - fig.show() - ax = fig.add_subplot() - return fig, ax - - def _draw(self, ax: plt.Axes, state: State) -> None: - ax.clear() - ax.set_xticks(jnp.arange(-0.5, self.num_cols - 1, 1)) - ax.set_yticks(jnp.arange(-0.5, self.num_rows - 1, 1)) - ax.tick_params( - top=False, - bottom=False, - left=False, - right=False, - labelleft=False, - labelbottom=False, - labeltop=False, - labelright=False, - ) - background = jnp.ones_like(state.board) - for i in range(self.num_rows): - for j in range(self.num_cols): - background = self._render_grid_square( - state=state, ax=ax, i=i, j=j, background=background - ) - ax.imshow(background, cmap="gray", vmin=0, vmax=1) - ax.grid(color="black", linestyle="-", linewidth=2) - - def _render_grid_square( - self, state: State, ax: plt.Axes, i: int, j: int, background: chex.Array - ) -> chex.Array: - board_value = state.board[i, j] - if board_value != UNEXPLORED_ID: - if explored_mine(state=state, action=jnp.array([i, j], dtype=jnp.int32)): - background = background.at[i, j].set(0) - else: - ax.text( - j, - i, - str(board_value), - color=self.cmap[board_value], - ha="center", - va="center", - fontsize="xx-large", - ) - return background - - def _update_display(self, fig: plt.Figure) -> None: - if plt.isinteractive(): - # Required to update render when using Jupyter Notebook. - fig.canvas.draw() - if jumanji.environments.is_colab(): - plt.show(self.figure_name) - else: - # Required to update render when not using Jupyter Notebook. - fig.canvas.draw_idle() - fig.canvas.flush_events() - - def _clear_display(self) -> None: - if jumanji.environments.is_colab(): - import IPython.display - - IPython.display.clear_output(True) + self._env_viewer.close() diff --git a/jumanji/environments/logic/minesweeper/env_viewer.py b/jumanji/environments/logic/minesweeper/env_viewer.py new file mode 100644 index 000000000..3115f491c --- /dev/null +++ b/jumanji/environments/logic/minesweeper/env_viewer.py @@ -0,0 +1,216 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Sequence, Tuple + +import chex +import jax.numpy as jnp +import matplotlib +from matplotlib import pyplot as plt + +import jumanji.environments +from jumanji.environments.logic.minesweeper.constants import ( + DEFAULT_COLOR_MAPPING, + UNEXPLORED_ID, +) +from jumanji.environments.logic.minesweeper.types import State +from jumanji.environments.logic.minesweeper.utils import explored_mine + + +class MinesweeperViewer: + """Abstract viewer class to support rendering and animation""" + + def render(self, state: State) -> None: + """Render frames of the environment for a given state using matplotlib. + Args: + state: `State` object corresponding to the new state of the environment. + """ + raise NotImplementedError + + def animate( + self, + states: Sequence[State], + interval: int = 200, + save_path: Optional[str] = None, + ) -> matplotlib.animation.FuncAnimation: + """Create an animation from a sequence of environment states. + Args: + states: sequence of environment states corresponding to consecutive timesteps. + interval: delay between frames in milliseconds, default to 200. + save_path: the path where the animation file should be saved. If it is None, the plot + will not be saved. + Returns: + Animation that can be saved as a GIF, MP4, or rendered with HTML. + """ + raise NotImplementedError + + def close(self) -> None: + """Perform any necessary cleanup. + Environments will automatically :meth:`close()` themselves when + garbage collected or when the program exits. + """ + raise NotImplementedError + + +class DefaultMinesweeperViewer(MinesweeperViewer): + def __init__( + self, + color_mapping: Optional[List[str]] = None, + num_rows: int = 10, + num_cols: int = 10, + ): + """ + Args: + color_mapping: colors used in rendering the cells in Minesweeper. + Defaults to `DEFAULT_COLOR_MAPPING`. + num_rows: number of rows, i.e. height of the board. Defaults to 10. + num_cols: number of columns, i.e. width of the board. Defaults to 10. + """ + self.cmap = color_mapping if color_mapping else DEFAULT_COLOR_MAPPING + self.num_rows = num_rows + self.num_cols = num_cols + self.figure_name = f"{num_rows}x{num_cols} Minesweeper" + self.figure_size = (6.0, 6.0) + + def render(self, state: State) -> None: + """Render the given environment state using matplotlib. + + Args: + state: environment state to be rendered. + + """ + self._clear_display() + fig, ax = self._get_fig_ax() + self._draw(ax, state) + self._update_display(fig) + + def animate( + self, + states: Sequence[State], + interval: int = 200, + save_path: Optional[str] = None, + ) -> matplotlib.animation.FuncAnimation: + """Create an animation from a sequence of environment states. + + Args: + states: sequence of environment states corresponding to consecutive timesteps. + interval: delay between frames in milliseconds, default to 200. + save_path: the path where the animation file should be saved. If it is None, the plot + will not be saved. + + Returns: + Animation object that can be saved as a GIF, MP4, or rendered with HTML. + """ + fig, ax = self._get_fig_ax() + plt.tight_layout() + plt.close(fig) + + def make_frame(state_index: int) -> None: + state = states[state_index] + self._draw(ax, state) + + # Create the animation object. + self._animation = matplotlib.animation.FuncAnimation( + fig, + make_frame, + frames=len(states), + interval=interval, + ) + + # Save the animation as a GIF. + if save_path: + self._animation.save(save_path) + + return self._animation + + def close(self) -> None: + """Perform any necessary cleanup. + + Environments will automatically :meth:`close()` themselves when + garbage collected or when the program exits. + """ + plt.close(self.figure_name) + + def _get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]: + exists = plt.fignum_exists(self.figure_name) + if exists: + fig = plt.figure(self.figure_name) + ax = fig.get_axes()[0] + else: + fig = plt.figure(self.figure_name, figsize=self.figure_size) + plt.suptitle(self.figure_name) + plt.tight_layout() + if not plt.isinteractive(): + fig.show() + ax = fig.add_subplot() + return fig, ax + + def _draw(self, ax: plt.Axes, state: State) -> None: + ax.clear() + ax.set_xticks(jnp.arange(-0.5, self.num_cols - 1, 1)) + ax.set_yticks(jnp.arange(-0.5, self.num_rows - 1, 1)) + ax.tick_params( + top=False, + bottom=False, + left=False, + right=False, + labelleft=False, + labelbottom=False, + labeltop=False, + labelright=False, + ) + background = jnp.ones_like(state.board) + for i in range(self.num_rows): + for j in range(self.num_cols): + background = self._render_grid_square( + state=state, ax=ax, i=i, j=j, background=background + ) + ax.imshow(background, cmap="gray", vmin=0, vmax=1) + ax.grid(color="black", linestyle="-", linewidth=2) + + def _render_grid_square( + self, state: State, ax: plt.Axes, i: int, j: int, background: chex.Array + ) -> chex.Array: + board_value = state.board[i, j] + if board_value != UNEXPLORED_ID: + if explored_mine(state=state, action=jnp.array([i, j], dtype=jnp.int32)): + background = background.at[i, j].set(0) + else: + ax.text( + j, + i, + str(board_value), + color=self.cmap[board_value], + ha="center", + va="center", + fontsize="xx-large", + ) + return background + + def _update_display(self, fig: plt.Figure) -> None: + if plt.isinteractive(): + # Required to update render when using Jupyter Notebook. + fig.canvas.draw() + if jumanji.environments.is_colab(): + plt.show(self.figure_name) + else: + # Required to update render when not using Jupyter Notebook. + fig.canvas.draw_idle() + fig.canvas.flush_events() + + def _clear_display(self) -> None: + if jumanji.environments.is_colab(): + import IPython.display + + IPython.display.clear_output(True) diff --git a/jumanji/environments/logic/minesweeper/generator.py b/jumanji/environments/logic/minesweeper/generator.py new file mode 100644 index 000000000..7db1f9657 --- /dev/null +++ b/jumanji/environments/logic/minesweeper/generator.py @@ -0,0 +1,82 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc + +import chex +import jax +import jax.numpy as jnp + +from jumanji.environments.logic.minesweeper.constants import UNEXPLORED_ID +from jumanji.environments.logic.minesweeper.types import State +from jumanji.environments.logic.minesweeper.utils import create_flat_mine_locations + + +class Generator(abc.ABC): + """Base class for generators for the Minesweeper environment.""" + + def __init__(self, num_rows: int, num_cols: int): + """Initialises a Minesweeper generator for resetting the environment. + Args: + num_rows: number of rows, i.e. height of the board. + num_cols: number of columns, i.e. width of the board. + """ + self.num_rows = num_rows + self.num_cols = num_cols + + @abc.abstractmethod + def __call__(self, key: chex.PRNGKey) -> State: + """Generates a `Minesweeper` state. + Returns: + A `Minesweeper` state. + """ + + +class SamplingGenerator(Generator): + """Generates instances by sampling a given number of mines (without replacement).""" + + def __init__( + self, + num_rows: int, + num_cols: int, + num_mines: int, + ): + self.num_mines = num_mines + super().__init__(num_rows=num_rows, num_cols=num_cols) + + def __call__(self, key: chex.PRNGKey) -> State: + """Generates a `Minesweeper` state by placing a fixed number of mines on the board. + Returns: + A `Minesweeper` state. + """ + key, sample_key = jax.random.split(key) + board = jnp.full( + shape=(self.num_rows, self.num_cols), + fill_value=UNEXPLORED_ID, + dtype=jnp.int32, + ) + step_count = jnp.array(0, jnp.int32) + flat_mine_locations = create_flat_mine_locations( + key=sample_key, + num_rows=self.num_rows, + num_cols=self.num_cols, + num_mines=self.num_mines, + ) + state = State( + board=board, + step_count=step_count, + key=key, + flat_mine_locations=flat_mine_locations, + ) + return state From e13bd18c821fd78dc9038d198616a6772bea1a77 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Mon, 27 Mar 2023 11:35:22 +0100 Subject: [PATCH 04/58] Fix: naming --- jumanji/environments/logic/rubiks_cube/env.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index fbea3f2d0..e1855be61 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -250,10 +250,10 @@ def _state_to_observation(self, state: State) -> Observation: return Observation(cube=state.cube, step_count=state.step_count) def render(self, state: State, save_path: Optional[str] = None) -> None: - """Renders the current state of the game board. + """Renders the current state of the cube. Args: - state: is the current game state to be rendered. + state: the current state to be rendered. save_path: the path where the image should be saved. If it is None, the plot will not be stored. """ @@ -265,10 +265,10 @@ def animate( interval: int = 200, save_path: Optional[str] = None, ) -> matplotlib.animation.FuncAnimation: - """Creates an animated gif of the 2048 game board based on the sequence of game states. + """Creates an animated gif of the cube based on the sequence of states. Args: - states: is a list of `State` objects representing the sequence of game states. + states: a list of `State` objects representing the sequence of game states. interval: the delay between frames in milliseconds, default to 200. save_path: the path where the animation file should be saved. If it is None, the plot will not be stored. From f5076c2d105a7126b6a15e6402e5cadd951a2e25 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Mon, 27 Mar 2023 13:02:19 +0100 Subject: [PATCH 05/58] Some more generic fixes but more work required --- .../logic/minesweeper/constants.py | 3 +- jumanji/environments/logic/minesweeper/env.py | 29 ++++++------ .../logic/minesweeper/generator.py | 47 ++++++++++--------- .../environments/logic/minesweeper/reward.py | 26 ++++++---- 4 files changed, 60 insertions(+), 45 deletions(-) diff --git a/jumanji/environments/logic/minesweeper/constants.py b/jumanji/environments/logic/minesweeper/constants.py index 63f4580d7..e7f0a9244 100644 --- a/jumanji/environments/logic/minesweeper/constants.py +++ b/jumanji/environments/logic/minesweeper/constants.py @@ -16,7 +16,8 @@ IS_MINE: int = 1 PATCH_SIZE: int = 3 REVEALED_EMPTY_SQUARE_REWARD: float = 1.0 -REVEALED_MINE_OR_INVALID_ACTION_REWARD: float = 0.0 +REVEALED_MINE_REWARD: float = 0.0 +INVALID_ACTION_REWARD: float = 0.0 DEFAULT_COLOR_MAPPING: list = [ "orange", "blue", diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index d0325aa39..632750e61 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -21,7 +21,13 @@ from jumanji import specs from jumanji.env import Environment -from jumanji.environments.logic.minesweeper.constants import PATCH_SIZE, UNEXPLORED_ID +from jumanji.environments.logic.minesweeper.constants import ( + INVALID_ACTION_REWARD, + PATCH_SIZE, + REVEALED_EMPTY_SQUARE_REWARD, + REVEALED_MINE_REWARD, + UNEXPLORED_ID, +) from jumanji.environments.logic.minesweeper.done import DefaultDoneFn, DoneFn from jumanji.environments.logic.minesweeper.env_viewer import ( DefaultMinesweeperViewer, @@ -92,7 +98,6 @@ def __init__( self, num_rows: int = 10, num_cols: int = 10, - num_mines: int = 10, reward_function: Optional[RewardFn] = None, done_function: Optional[DoneFn] = None, env_viewer: Optional[MinesweeperViewer] = None, @@ -103,8 +108,6 @@ def __init__( Args: num_rows: number of rows, i.e. height of the board. Defaults to 10. num_cols: number of columns, i.e. width of the board. Defaults to 10. - num_mines: number of mines on the board. Defaults to 10. - Note that this argument will be ignored if a custom generator is passed. reward_function: `RewardFn` whose `__call__` method computes the reward of an environment transition based on the given current state and selected action. Implemented options are [`DefaultRewardFn`]. Defaults to `DefaultRewardFn`. @@ -123,22 +126,20 @@ def __init__( f"Should make a board of height and width greater than 1, " f"got num_rows={num_rows}, num_cols={num_cols}" ) - if num_mines < 0 or num_mines >= num_rows * num_cols: - raise ValueError( - f"Number of mines should be constrained between 0 and the size of the board, " - f"got {num_mines}" - ) self.num_rows = num_rows self.num_cols = num_cols - self.num_mines = num_mines - self.reward_function = reward_function or DefaultRewardFn() + self.reward_function = reward_function or DefaultRewardFn( + revealed_empty_square_reward=REVEALED_EMPTY_SQUARE_REWARD, + revealed_mine_reward=REVEALED_MINE_REWARD, + invalid_action_reward=INVALID_ACTION_REWARD, + ) self.done_function = done_function or DefaultDoneFn() self._env_viewer = env_viewer or DefaultMinesweeperViewer( num_rows=num_rows, num_cols=num_cols ) self._generator = generator or SamplingGenerator( - num_rows=num_rows, num_cols=num_cols, num_mines=num_mines + num_rows=num_rows, num_cols=num_cols, num_mines=10 ) def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: @@ -229,7 +230,7 @@ def observation_spec(self) -> specs.Spec[Observation]: shape=(), dtype=jnp.int32, minimum=0, - maximum=self.num_rows * self.num_cols - self.num_mines, + maximum=self.num_rows * self.num_cols, name="step_count", ) return specs.Spec( @@ -258,7 +259,7 @@ def _state_to_observation(self, state: State) -> Observation: return Observation( board=state.board, action_mask=jnp.equal(state.board, UNEXPLORED_ID), - num_mines=jnp.array(self.num_mines, jnp.int32), + num_mines=jnp.array(10, jnp.int32), # todo: make this more generic step_count=state.step_count, ) diff --git a/jumanji/environments/logic/minesweeper/generator.py b/jumanji/environments/logic/minesweeper/generator.py index 7db1f9657..3d39e77f5 100644 --- a/jumanji/environments/logic/minesweeper/generator.py +++ b/jumanji/environments/logic/minesweeper/generator.py @@ -36,11 +36,29 @@ def __init__(self, num_rows: int, num_cols: int): self.num_cols = num_cols @abc.abstractmethod + def generate_flat_mine_locations(self, key: chex.PRNGKey) -> chex.Array: + """Generates positions (in flattened coordinates) of the mines in the board""" + def __call__(self, key: chex.PRNGKey) -> State: """Generates a `Minesweeper` state. Returns: A `Minesweeper` state. """ + key, sample_key = jax.random.split(key) + board = jnp.full( + shape=(self.num_rows, self.num_cols), + fill_value=UNEXPLORED_ID, + dtype=jnp.int32, + ) + step_count = jnp.array(0, jnp.int32) + flat_mine_locations = self.generate_flat_mine_locations(key=sample_key) + state = State( + board=board, + step_count=step_count, + key=key, + flat_mine_locations=flat_mine_locations, + ) + return state class SamplingGenerator(Generator): @@ -52,31 +70,18 @@ def __init__( num_cols: int, num_mines: int, ): + if num_mines < 0 or num_mines >= num_rows * num_cols: + raise ValueError( + f"Number of mines should be constrained between 0 and the size of the board, " + f"got {num_mines}" + ) self.num_mines = num_mines super().__init__(num_rows=num_rows, num_cols=num_cols) - def __call__(self, key: chex.PRNGKey) -> State: - """Generates a `Minesweeper` state by placing a fixed number of mines on the board. - Returns: - A `Minesweeper` state. - """ - key, sample_key = jax.random.split(key) - board = jnp.full( - shape=(self.num_rows, self.num_cols), - fill_value=UNEXPLORED_ID, - dtype=jnp.int32, - ) - step_count = jnp.array(0, jnp.int32) - flat_mine_locations = create_flat_mine_locations( - key=sample_key, + def generate_flat_mine_locations(self, key: chex.PRNGKey) -> chex.Array: + return create_flat_mine_locations( + key=key, num_rows=self.num_rows, num_cols=self.num_cols, num_mines=self.num_mines, ) - state = State( - board=board, - step_count=step_count, - key=key, - flat_mine_locations=flat_mine_locations, - ) - return state diff --git a/jumanji/environments/logic/minesweeper/reward.py b/jumanji/environments/logic/minesweeper/reward.py index 336d0325c..45f6a6141 100644 --- a/jumanji/environments/logic/minesweeper/reward.py +++ b/jumanji/environments/logic/minesweeper/reward.py @@ -17,10 +17,6 @@ import chex import jax.numpy as jnp -from jumanji.environments.logic.minesweeper.constants import ( - REVEALED_EMPTY_SQUARE_REWARD, - REVEALED_MINE_OR_INVALID_ACTION_REWARD, -) from jumanji.environments.logic.minesweeper.types import State from jumanji.environments.logic.minesweeper.utils import explored_mine, is_valid_action @@ -32,17 +28,29 @@ def __call__(self, state: State, action: chex.Array) -> chex.Array: class DefaultRewardFn(RewardFn): - """A dense reward function: 1 for every timestep on which a mine is not explored - (or a small penalty if action is invalid), otherwise 0. + """A dense reward function corresponding to the 3 possible events: + - Revealing an empty square + - Revealing a mine + - Choosing an invalid action (an already revealed square) """ + def __init__( + self, + revealed_empty_square_reward: float, + revealed_mine_reward: float, + invalid_action_reward: float, + ): + self.revelead_empty_square_reward = revealed_empty_square_reward + self.revelead_mine_reward = revealed_mine_reward + self.invalid_action_reward = invalid_action_reward + def __call__(self, state: State, action: chex.Array) -> chex.Array: return jnp.where( is_valid_action(state=state, action=action), jnp.where( explored_mine(state=state, action=action), - jnp.array(REVEALED_MINE_OR_INVALID_ACTION_REWARD, float), - jnp.array(REVEALED_EMPTY_SQUARE_REWARD, float), + jnp.array(self.revelead_mine_reward, float), + jnp.array(self.revelead_empty_square_reward, float), ), - jnp.array(REVEALED_MINE_OR_INVALID_ACTION_REWARD, float), + jnp.array(self.invalid_action_reward, float), ) From 793af1d7c75649a8449a5d1fb9aaaaff2d6aaa26 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Mon, 27 Mar 2023 14:00:51 +0100 Subject: [PATCH 06/58] Separate reward definitions --- jumanji/environments/logic/minesweeper/env_test.py | 7 ++++--- jumanji/environments/logic/minesweeper/reward.py | 6 +++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/jumanji/environments/logic/minesweeper/env_test.py b/jumanji/environments/logic/minesweeper/env_test.py index 129d81583..0c720d60d 100644 --- a/jumanji/environments/logic/minesweeper/env_test.py +++ b/jumanji/environments/logic/minesweeper/env_test.py @@ -24,8 +24,9 @@ from jax import random from jumanji.environments.logic.minesweeper.constants import ( + INVALID_ACTION_REWARD, REVEALED_EMPTY_SQUARE_REWARD, - REVEALED_MINE_OR_INVALID_ACTION_REWARD, + REVEALED_MINE_REWARD, ) from jumanji.environments.logic.minesweeper.env import Minesweeper from jumanji.environments.logic.minesweeper.types import State @@ -69,12 +70,12 @@ def play_and_get_episode_stats( ), ( [[0, 3], [0, 2]], - [REVEALED_EMPTY_SQUARE_REWARD, REVEALED_MINE_OR_INVALID_ACTION_REWARD], + [REVEALED_EMPTY_SQUARE_REWARD, REVEALED_MINE_REWARD], [StepType.MID, StepType.LAST], ), ( [[0, 3], [0, 3]], - [REVEALED_EMPTY_SQUARE_REWARD, REVEALED_MINE_OR_INVALID_ACTION_REWARD], + [REVEALED_EMPTY_SQUARE_REWARD, INVALID_ACTION_REWARD], [StepType.MID, StepType.LAST], ), ], diff --git a/jumanji/environments/logic/minesweeper/reward.py b/jumanji/environments/logic/minesweeper/reward.py index 45f6a6141..7e8e6f61c 100644 --- a/jumanji/environments/logic/minesweeper/reward.py +++ b/jumanji/environments/logic/minesweeper/reward.py @@ -29,9 +29,9 @@ def __call__(self, state: State, action: chex.Array) -> chex.Array: class DefaultRewardFn(RewardFn): """A dense reward function corresponding to the 3 possible events: - - Revealing an empty square - - Revealing a mine - - Choosing an invalid action (an already revealed square) + - Revealing an empty square + - Revealing a mine + - Choosing an invalid action (an already revealed square) """ def __init__( From 8497f847b956f82780e83fbfa847cff538290dd0 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Mon, 27 Mar 2023 14:19:38 +0100 Subject: [PATCH 07/58] Remove action history --- jumanji/environments/logic/rubiks_cube/env.py | 17 ------- .../logic/rubiks_cube/env_test.py | 16 ------ .../logic/rubiks_cube/generator.py | 50 ++++++++----------- .../environments/logic/rubiks_cube/types.py | 4 -- .../logic/rubiks_cube/utils_test.py | 3 -- 5 files changed, 20 insertions(+), 70 deletions(-) diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index e1855be61..57220dbd4 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -66,14 +66,6 @@ class RubiksCube(Environment[State]): specifies how many timesteps have elapsed since environment reset. - key: jax array (uint) of shape (2,) used for seeding the sampling for scrambling on reset. - - action_history: jax array (int32) of shape (num_scrambles_on_reset + time_limit, 3): - indicates the entire history of applied moves (including those taken on scrambling the - cube in the environment reset). This is useful for debugging purposes, providing a - method to solve the cube from any position without relying on the agent, by just - inverting the action history. The first axis indexes over the length of the sequence - The second axis indexes over the component of the action (face, depth, amount). The - number of scrambles applied for each state is given by - `env.num_scrambles_on_reset + state.step_count`. ```python from jumanji.environments import RubiksCube @@ -176,15 +168,11 @@ def step( cube=state.cube, flattened_action=flattened_action, ) - action_history = state.action_history.at[ - self.num_scrambles_on_reset + state.step_count - ].set(action) step_count = state.step_count + 1 next_state = State( cube=cube, step_count=step_count, key=state.key, - action_history=action_history, ) reward = self.reward_function(state=next_state) solved = is_solved(cube) @@ -199,11 +187,6 @@ def step( ) return next_state, next_timestep - def get_action_history(self, state: State) -> chex.Array: - """Slice and return the action history from the state.""" - action_history_index = self.num_scrambles_on_reset + state.step_count - return state.action_history[:action_history_index] - def observation_spec(self) -> specs.Spec[Observation]: """Specifications of the observation of the `RubiksCube` environment. diff --git a/jumanji/environments/logic/rubiks_cube/env_test.py b/jumanji/environments/logic/rubiks_cube/env_test.py index 7b8aa08d5..a70f04972 100644 --- a/jumanji/environments/logic/rubiks_cube/env_test.py +++ b/jumanji/environments/logic/rubiks_cube/env_test.py @@ -20,7 +20,6 @@ import pytest import pytest_mock -from jumanji.environments.logic.rubiks_cube.constants import CubeMovementAmount, Face from jumanji.environments.logic.rubiks_cube.env import RubiksCube from jumanji.environments.logic.rubiks_cube.types import State from jumanji.testing.env_not_smoke import check_env_does_not_smoke @@ -38,14 +37,6 @@ def test_rubiks_cube__reset(rubiks_cube: RubiksCube) -> None: assert isinstance(timestep, TimeStep) assert isinstance(state, State) assert state.step_count == 0 - expected_shape = (rubiks_cube.num_scrambles_on_reset + rubiks_cube.time_limit, 3) - assert state.action_history.shape == expected_shape - action_history_index = rubiks_cube.num_scrambles_on_reset - assert jnp.all(jnp.equal(state.action_history[action_history_index:], 0)) - assert state.action_history.min() >= 0 - assert state.action_history[:, 0].max() < len(Face) - assert state.action_history[:, 1].max() < rubiks_cube.cube_size // 2 - assert state.action_history[:, 2].max() < len(CubeMovementAmount) assert jnp.array_equal(state.cube, timestep.observation.cube) assert timestep.observation.step_count == 0 # Check that the state is made of DeviceArrays, this is false for the non-jitted @@ -67,10 +58,6 @@ def test_rubiks_cube__step(rubiks_cube: RubiksCube) -> None: assert next_state.step_count == 1 assert next_timestep.observation.step_count == 1 assert jnp.array_equal(next_state.cube, next_timestep.observation.cube) - expected_shape = (rubiks_cube.num_scrambles_on_reset + rubiks_cube.time_limit, 3) - assert next_state.action_history.shape == expected_shape - action_history_index = rubiks_cube.num_scrambles_on_reset + 1 - assert jnp.all(next_state.action_history[action_history_index:] == 0) # Check that the state is made of DeviceArrays, this is false for the non-jitted # step function since unpacking random.split returns numpy arrays and not device arrays. @@ -84,9 +71,6 @@ def test_rubiks_cube__step(rubiks_cube: RubiksCube) -> None: assert next_next_state.step_count == 2 assert next_next_timestep.observation.step_count == 2 assert jnp.array_equal(next_next_state.cube, next_next_timestep.observation.cube) - assert next_next_state.action_history.shape == expected_shape - action_history_index = rubiks_cube.num_scrambles_on_reset + 2 - assert jnp.all(next_next_state.action_history[action_history_index:] == 0) @pytest.mark.parametrize("cube_size", [3, 4, 5]) diff --git a/jumanji/environments/logic/rubiks_cube/generator.py b/jumanji/environments/logic/rubiks_cube/generator.py index bd801ab52..b5c23dcff 100644 --- a/jumanji/environments/logic/rubiks_cube/generator.py +++ b/jumanji/environments/logic/rubiks_cube/generator.py @@ -19,11 +19,8 @@ import jax.numpy as jnp from jumanji.environments.logic.rubiks_cube.constants import CubeMovementAmount, Face -from jumanji.environments.logic.rubiks_cube.types import State -from jumanji.environments.logic.rubiks_cube.utils import ( - scramble_solved_cube, - unflatten_action, -) +from jumanji.environments.logic.rubiks_cube.types import Cube, State +from jumanji.environments.logic.rubiks_cube.utils import scramble_solved_cube class Generator(abc.ABC): @@ -38,12 +35,23 @@ def __init__(self, cube_size: int): self.cube_size = cube_size @abc.abstractmethod + def generate_cube(self, key: chex.PRNGKey) -> Cube: + """Generate a cube for this instance""" + def __call__(self, key: chex.PRNGKey) -> State: """Generates a `RubiksCube` state. Returns: A `RubiksCube` state. """ + key, scramble_key = jax.random.split(key) + cube = self.generate_cube(key=scramble_key) + step_count = jnp.array(0, jnp.int32) + return State( + cube=cube, + step_count=step_count, + key=key, + ) class ScramblingGenerator(Generator): @@ -59,36 +67,18 @@ def __init__( self.time_limit = time_limit super().__init__(cube_size=cube_size) - def __call__(self, key: chex.PRNGKey) -> State: - """Generates a `RubiksCube` state by scrambling a solved cube a fixed number of times. - - Returns: - A `RubiksCube` state. - """ - key, scramble_key = jax.random.split(key) - flattened_actions_in_scramble = jax.random.randint( - scramble_key, + def generate_actions_for_scramble(self, key: chex.PRNGKey) -> chex.Array: + return jax.random.randint( + key=key, minval=0, maxval=len(Face) * (self.cube_size // 2) * len(CubeMovementAmount), shape=(self.num_scrambles_on_reset,), dtype=jnp.int32, ) - cube = scramble_solved_cube( + + def generate_cube(self, key: chex.PRNGKey) -> Cube: + flattened_actions_in_scramble = self.generate_actions_for_scramble(key=key) + return scramble_solved_cube( flattened_actions_in_scramble=flattened_actions_in_scramble, cube_size=self.cube_size, ) - action_history = jnp.zeros( - shape=(self.num_scrambles_on_reset + self.time_limit, 3), dtype=jnp.int32 - ) - action_history = action_history.at[: self.num_scrambles_on_reset].set( - unflatten_action( - flattened_action=flattened_actions_in_scramble, cube_size=self.cube_size - ).transpose() - ) - step_count = jnp.array(0, jnp.int32) - return State( - cube=cube, - step_count=step_count, - key=key, - action_history=action_history, - ) diff --git a/jumanji/environments/logic/rubiks_cube/types.py b/jumanji/environments/logic/rubiks_cube/types.py index cb2975f43..507b610a7 100644 --- a/jumanji/environments/logic/rubiks_cube/types.py +++ b/jumanji/environments/logic/rubiks_cube/types.py @@ -15,7 +15,6 @@ from typing import TYPE_CHECKING, NamedTuple import chex -from chex import Array from typing_extensions import TypeAlias if TYPE_CHECKING: @@ -32,14 +31,11 @@ class State: cube: 3D array whose cells contain the index of the corresponding colour of the sticker in the scramble. step_count: specifies how many timesteps have elapsed since environment reset. - action_history: array that indicates the entire history of applied moves (including those taken - on scrambling the cube in the environment reset). key: random key used for auto-reset. """ cube: Cube # (6, cube_size, cube_size) step_count: chex.Numeric # () - action_history: Array # (num_scrambles_on_reset + time_limit, 3) key: chex.PRNGKey # (2,) diff --git a/jumanji/environments/logic/rubiks_cube/utils_test.py b/jumanji/environments/logic/rubiks_cube/utils_test.py index bcb4f0c16..c28154002 100644 --- a/jumanji/environments/logic/rubiks_cube/utils_test.py +++ b/jumanji/environments/logic/rubiks_cube/utils_test.py @@ -147,13 +147,11 @@ def test_solved_reward( solved_state = State( cube=solved_cube, step_count=jnp.array(0, jnp.int32), - action_history=jnp.array(0, jnp.int32), key=jax.random.PRNGKey(0), ) differently_stickered_state = State( cube=differently_stickered_cube, step_count=jnp.array(0, jnp.int32), - action_history=jnp.array(0, jnp.int32), key=jax.random.PRNGKey(0), ) assert jnp.equal(SparseRewardFn()(solved_state), 1.0) @@ -178,7 +176,6 @@ def test_moves_nontrivial( move_solved_state = State( cube=move_solved_cube, step_count=jnp.array(0, jnp.int32), - action_history=jnp.array(0, jnp.int32), key=jax.random.PRNGKey(0), ) assert jnp.equal(SparseRewardFn()(move_solved_state), 0.0) From c15e7227015150737809b8f0ed445a47ea8f5748 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Mon, 27 Mar 2023 14:33:03 +0100 Subject: [PATCH 08/58] Cleanup --- .../logic/rubiks_cube/conftest.py | 5 +++- jumanji/environments/logic/rubiks_cube/env.py | 24 +++++++------------ .../logic/rubiks_cube/env_test.py | 7 +++++- .../logic/rubiks_cube/env_viewer.py | 22 ++++++++--------- .../logic/rubiks_cube/generator.py | 13 ++++++++-- 5 files changed, 40 insertions(+), 31 deletions(-) diff --git a/jumanji/environments/logic/rubiks_cube/conftest.py b/jumanji/environments/logic/rubiks_cube/conftest.py index 047c46958..c1e7ff671 100644 --- a/jumanji/environments/logic/rubiks_cube/conftest.py +++ b/jumanji/environments/logic/rubiks_cube/conftest.py @@ -18,6 +18,7 @@ from jumanji.environments.logic.rubiks_cube.constants import Face from jumanji.environments.logic.rubiks_cube.env import RubiksCube +from jumanji.environments.logic.rubiks_cube.generator import ScramblingGenerator from jumanji.environments.logic.rubiks_cube.utils import make_solved_cube @@ -76,4 +77,6 @@ def expected_scramble_result() -> chex.Array: @pytest.fixture def rubiks_cube() -> RubiksCube: """Instantiates a `RubiksCube` environment with 10 scrambles on reset.""" - return RubiksCube(num_scrambles_on_reset=10) + return RubiksCube( + generator=ScramblingGenerator(cube_size=3, num_scrambles_on_reset=10) + ) diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index 57220dbd4..de7077b0e 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -21,7 +21,10 @@ from jumanji import specs from jumanji.env import Environment -from jumanji.environments.logic.rubiks_cube.constants import Face +from jumanji.environments.logic.rubiks_cube.constants import ( + DEFAULT_STICKER_COLORS, + Face, +) from jumanji.environments.logic.rubiks_cube.env_viewer import ( DefaultRubiksCubeViewer, RubiksCubeViewer, @@ -83,7 +86,6 @@ def __init__( self, cube_size: int = 3, time_limit: int = 200, - num_scrambles_on_reset: int = 100, reward_fn: Optional[RewardFn] = None, env_viewer: Optional[RubiksCubeViewer] = None, generator: Optional[Generator] = None, @@ -93,10 +95,6 @@ def __init__( Args: cube_size: the size of the cube, i.e. length of an edge. Defaults to 3. time_limit: the number of steps allowed before an episode terminates. Defaults to 200. - num_scrambles_on_reset: the number of scrambles done from a solved Rubik's Cube in the - generation of a random instance. The lower, the closer to a solved cube the reset - state is. Defaults to 100. - Note that this argument will be ignored if a custom generator is passed. reward_fn: `RewardFn` whose `__call__` method computes the reward given the new state. Implemented options are [`SparseRewardFn`]. Defaults to `SparseRewardFn`. env_viewer: RubiksCubeViewer to support rendering and animation methods. @@ -115,20 +113,16 @@ def __init__( raise ValueError( f"The time_limit must be positive, but received time_limit={time_limit}" ) - if num_scrambles_on_reset < 0: - raise ValueError( - f"The num_scrambles_on_reset must be non-negative, " - f"but received num_scrambles_on_reset={num_scrambles_on_reset}" - ) + self.cube_size = cube_size self.time_limit = time_limit - self.num_scrambles_on_reset = num_scrambles_on_reset self.reward_function = reward_fn or SparseRewardFn() - self._env_viewer = env_viewer or DefaultRubiksCubeViewer(cube_size=cube_size) + self._env_viewer = env_viewer or DefaultRubiksCubeViewer( + sticker_colors=DEFAULT_STICKER_COLORS, cube_size=cube_size + ) self._generator = generator or ScramblingGenerator( cube_size=cube_size, - num_scrambles_on_reset=num_scrambles_on_reset, - time_limit=self.time_limit, + num_scrambles_on_reset=100, ) def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: diff --git a/jumanji/environments/logic/rubiks_cube/env_test.py b/jumanji/environments/logic/rubiks_cube/env_test.py index a70f04972..075308716 100644 --- a/jumanji/environments/logic/rubiks_cube/env_test.py +++ b/jumanji/environments/logic/rubiks_cube/env_test.py @@ -21,6 +21,7 @@ import pytest_mock from jumanji.environments.logic.rubiks_cube.env import RubiksCube +from jumanji.environments.logic.rubiks_cube.generator import ScramblingGenerator from jumanji.environments.logic.rubiks_cube.types import State from jumanji.testing.env_not_smoke import check_env_does_not_smoke from jumanji.testing.pytrees import assert_is_jax_array_tree @@ -76,7 +77,11 @@ def test_rubiks_cube__step(rubiks_cube: RubiksCube) -> None: @pytest.mark.parametrize("cube_size", [3, 4, 5]) def test_rubiks_cube__does_not_smoke(cube_size: int) -> None: """Test that we can run an episode without any errors.""" - env = RubiksCube(cube_size=cube_size, time_limit=10, num_scrambles_on_reset=5) + env = RubiksCube( + cube_size=cube_size, + time_limit=10, + generator=ScramblingGenerator(cube_size=cube_size, num_scrambles_on_reset=5), + ) check_env_does_not_smoke(env) diff --git a/jumanji/environments/logic/rubiks_cube/env_viewer.py b/jumanji/environments/logic/rubiks_cube/env_viewer.py index f40ed196d..bc227e3ca 100644 --- a/jumanji/environments/logic/rubiks_cube/env_viewer.py +++ b/jumanji/environments/logic/rubiks_cube/env_viewer.py @@ -19,16 +19,16 @@ from matplotlib import pyplot as plt import jumanji.environments -from jumanji.environments.logic.rubiks_cube.constants import ( - DEFAULT_STICKER_COLORS, - Face, -) +from jumanji.environments.logic.rubiks_cube.constants import Face from jumanji.environments.logic.rubiks_cube.types import State class RubiksCubeViewer: """Abstract viewer class to support rendering and animation""" + def __init__(self, cube_size: int): + self.cube_size = cube_size + def render(self, state: State) -> None: """Render frames of the environment for a given state using matplotlib. @@ -40,8 +40,8 @@ def render(self, state: State) -> None: def animate( self, states: Sequence[State], - interval: int = 200, - save_path: Optional[str] = None, + interval: int, + save_path: Optional[str], ) -> matplotlib.animation.FuncAnimation: """Create an animation from a sequence of environment states. @@ -66,18 +66,16 @@ def close(self) -> None: class DefaultRubiksCubeViewer(RubiksCubeViewer): - def __init__(self, sticker_colors: Optional[list] = None, cube_size: int = 3): + def __init__(self, sticker_colors: Optional[list], cube_size: int): """ Args: sticker_colors: colors used in rendering the faces of the Rubik's cube. - Defaults to `DEFAULT_STICKER_COLORS`. cube_size: size of cube to view """ - self.cube_size = cube_size - sticker_colors = sticker_colors or DEFAULT_STICKER_COLORS self.sticker_colors_cmap = matplotlib.colors.ListedColormap(sticker_colors) self.figure_name = f"{cube_size}x{cube_size}x{cube_size} Rubik's Cube" self.figure_size = (6.0, 6.0) + super().__init__(cube_size=cube_size) def render(self, state: State) -> None: """Render frames of the environment for a given state using matplotlib. @@ -93,8 +91,8 @@ def render(self, state: State) -> None: def animate( self, states: Sequence[State], - interval: int = 200, - save_path: Optional[str] = None, + interval: int, + save_path: Optional[str], ) -> matplotlib.animation.FuncAnimation: """Create an animation from a sequence of environment states. diff --git a/jumanji/environments/logic/rubiks_cube/generator.py b/jumanji/environments/logic/rubiks_cube/generator.py index b5c23dcff..646003754 100644 --- a/jumanji/environments/logic/rubiks_cube/generator.py +++ b/jumanji/environments/logic/rubiks_cube/generator.py @@ -61,10 +61,19 @@ def __init__( self, cube_size: int, num_scrambles_on_reset: int, - time_limit: int, ): + """ + Args: + num_scrambles_on_reset: the number of scrambles done from a solved Rubik's Cube in the + generation of a random instance. The lower, the closer to a solved cube the reset + state is. + """ + if num_scrambles_on_reset < 0: + raise ValueError( + f"The num_scrambles_on_reset must be non-negative, " + f"but received num_scrambles_on_reset={num_scrambles_on_reset}" + ) self.num_scrambles_on_reset = num_scrambles_on_reset - self.time_limit = time_limit super().__init__(cube_size=cube_size) def generate_actions_for_scramble(self, key: chex.PRNGKey) -> chex.Array: From 8b85633e76c6324cdc0c81c2f4799b9e00da541f Mon Sep 17 00:00:00 2001 From: Tristan Kalloniatis Date: Mon, 27 Mar 2023 14:37:14 +0100 Subject: [PATCH 09/58] Update jumanji/environments/logic/rubiks_cube/env.py Co-authored-by: Daniel <57721552+dluo96@users.noreply.github.com> --- jumanji/environments/logic/rubiks_cube/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index de7077b0e..417c2498b 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -100,7 +100,7 @@ def __init__( env_viewer: RubiksCubeViewer to support rendering and animation methods. Implemented options are [`DefaultRubiksCubeViewer`]. Defaults to `DefaultRubiksCubeViewer`. - generator: Generator to generate problem instances on environment reset. + generator: `Generator` used to generate problem instances on environment reset. Implemented options are [`ScramblingGenerator`]. Defaults to `ScramblingGenerator`. """ From 7bb1c40183fbef611f81f2f735f732da0954cc28 Mon Sep 17 00:00:00 2001 From: Tristan Kalloniatis Date: Mon, 27 Mar 2023 14:37:39 +0100 Subject: [PATCH 10/58] Update jumanji/environments/logic/rubiks_cube/env_viewer.py Co-authored-by: Daniel <57721552+dluo96@users.noreply.github.com> --- jumanji/environments/logic/rubiks_cube/env_viewer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jumanji/environments/logic/rubiks_cube/env_viewer.py b/jumanji/environments/logic/rubiks_cube/env_viewer.py index bc227e3ca..57d15f079 100644 --- a/jumanji/environments/logic/rubiks_cube/env_viewer.py +++ b/jumanji/environments/logic/rubiks_cube/env_viewer.py @@ -70,7 +70,7 @@ def __init__(self, sticker_colors: Optional[list], cube_size: int): """ Args: sticker_colors: colors used in rendering the faces of the Rubik's cube. - cube_size: size of cube to view + cube_size: size of cube to view. """ self.sticker_colors_cmap = matplotlib.colors.ListedColormap(sticker_colors) self.figure_name = f"{cube_size}x{cube_size}x{cube_size} Rubik's Cube" From 470ba662d9012d0274e73711909b6759a8a66e03 Mon Sep 17 00:00:00 2001 From: Tristan Kalloniatis Date: Mon, 27 Mar 2023 14:37:51 +0100 Subject: [PATCH 11/58] Update jumanji/environments/logic/rubiks_cube/generator.py Co-authored-by: Daniel <57721552+dluo96@users.noreply.github.com> --- jumanji/environments/logic/rubiks_cube/generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jumanji/environments/logic/rubiks_cube/generator.py b/jumanji/environments/logic/rubiks_cube/generator.py index 646003754..13151142a 100644 --- a/jumanji/environments/logic/rubiks_cube/generator.py +++ b/jumanji/environments/logic/rubiks_cube/generator.py @@ -55,7 +55,7 @@ def __call__(self, key: chex.PRNGKey) -> State: class ScramblingGenerator(Generator): - """Generates instances by applying a given number of scrambles to a solved cube""" + """Generates instances by applying a given number of scrambles to a solved cube.""" def __init__( self, From f3a7be420a340bbc548efb80cad04c513410c8bd Mon Sep 17 00:00:00 2001 From: Tristan Kalloniatis Date: Mon, 27 Mar 2023 14:38:00 +0100 Subject: [PATCH 12/58] Update jumanji/environments/logic/rubiks_cube/utils.py Co-authored-by: Daniel <57721552+dluo96@users.noreply.github.com> --- jumanji/environments/logic/rubiks_cube/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jumanji/environments/logic/rubiks_cube/utils.py b/jumanji/environments/logic/rubiks_cube/utils.py index e8a16bbce..ed7b4a0e1 100644 --- a/jumanji/environments/logic/rubiks_cube/utils.py +++ b/jumanji/environments/logic/rubiks_cube/utils.py @@ -53,9 +53,11 @@ def make_solved_cube(cube_size: int) -> Cube: def is_solved(cube: Cube) -> chex.Array: - """Check if a cube is solved + """Check if a cube is solved. + Args: cube: the cube to check. + Returns: Whether or not the cube is solved (all faces have a unique id). """ From ee3def9b2d3757fb73c26ce0b322a2463abbbd1c Mon Sep 17 00:00:00 2001 From: Tristan Kalloniatis Date: Mon, 27 Mar 2023 14:38:25 +0100 Subject: [PATCH 13/58] Update jumanji/environments/logic/rubiks_cube/generator.py Co-authored-by: Daniel <57721552+dluo96@users.noreply.github.com> --- jumanji/environments/logic/rubiks_cube/generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jumanji/environments/logic/rubiks_cube/generator.py b/jumanji/environments/logic/rubiks_cube/generator.py index 13151142a..6673139dd 100644 --- a/jumanji/environments/logic/rubiks_cube/generator.py +++ b/jumanji/environments/logic/rubiks_cube/generator.py @@ -24,7 +24,7 @@ class Generator(abc.ABC): - """Base class for generators for the RubiksCube environment.""" + """Base class for generators for the `RubiksCube` environment.""" def __init__(self, cube_size: int): """Initialises a RubiksCube generator for resetting the environment. From 8b9356c2554f46298335a54d16ff527912149a66 Mon Sep 17 00:00:00 2001 From: Tristan Kalloniatis Date: Mon, 27 Mar 2023 14:38:34 +0100 Subject: [PATCH 14/58] Update jumanji/environments/logic/rubiks_cube/utils.py Co-authored-by: Daniel <57721552+dluo96@users.noreply.github.com> --- jumanji/environments/logic/rubiks_cube/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jumanji/environments/logic/rubiks_cube/utils.py b/jumanji/environments/logic/rubiks_cube/utils.py index ed7b4a0e1..ab2b1c341 100644 --- a/jumanji/environments/logic/rubiks_cube/utils.py +++ b/jumanji/environments/logic/rubiks_cube/utils.py @@ -67,7 +67,7 @@ def is_solved(cube: Cube) -> chex.Array: def sparse_reward_function(state: State) -> chex.Array: - """A sparse reward function: +1 if the cube is solved, otherwise 0""" + """A sparse reward function: +1 if the cube is solved, otherwise 0.""" solved = is_solved(state.cube) return jnp.array(solved, float) From afdc1c9acedd4f113a02198d4ccedcf732ec25bd Mon Sep 17 00:00:00 2001 From: Tristan Kalloniatis Date: Mon, 27 Mar 2023 14:38:45 +0100 Subject: [PATCH 15/58] Update jumanji/environments/logic/rubiks_cube/utils.py Co-authored-by: Daniel <57721552+dluo96@users.noreply.github.com> --- jumanji/environments/logic/rubiks_cube/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jumanji/environments/logic/rubiks_cube/utils.py b/jumanji/environments/logic/rubiks_cube/utils.py index ab2b1c341..4b31fb9eb 100644 --- a/jumanji/environments/logic/rubiks_cube/utils.py +++ b/jumanji/environments/logic/rubiks_cube/utils.py @@ -42,10 +42,12 @@ def make_solved_cube(cube_size: int) -> Cube: """Make a solved cube of a given size. + Args: cube_size: the size of the cube to generate. + Returns: - A solved cube, ie with all faces a uniform id (sticker color). + A solved cube, i.e. with all faces a uniform id (sticker color). """ return jnp.stack( [face.value * jnp.ones((cube_size, cube_size), dtype=jnp.int8) for face in Face] From 0d0abdb4cad2329a5d75fa27757fe80571959852 Mon Sep 17 00:00:00 2001 From: Tristan Kalloniatis Date: Mon, 27 Mar 2023 14:38:58 +0100 Subject: [PATCH 16/58] Update jumanji/environments/logic/rubiks_cube/generator.py Co-authored-by: Daniel <57721552+dluo96@users.noreply.github.com> --- jumanji/environments/logic/rubiks_cube/generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jumanji/environments/logic/rubiks_cube/generator.py b/jumanji/environments/logic/rubiks_cube/generator.py index 6673139dd..74256d28f 100644 --- a/jumanji/environments/logic/rubiks_cube/generator.py +++ b/jumanji/environments/logic/rubiks_cube/generator.py @@ -27,7 +27,7 @@ class Generator(abc.ABC): """Base class for generators for the `RubiksCube` environment.""" def __init__(self, cube_size: int): - """Initialises a RubiksCube generator for resetting the environment. + """Initialises a `RubiksCube` generator for resetting the environment. Args: cube_size: the size of the cube to generate instances for. From 72846b1d9b01d1a89df6a96a5799e54ec275208e Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Mon, 27 Mar 2023 16:57:29 +0100 Subject: [PATCH 17/58] Spacing --- jumanji/environments/logic/rubiks_cube/env.py | 6 ++-- .../environments/logic/rubiks_cube/utils.py | 28 ++++++++++++++++--- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index 417c2498b..0baa422d4 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -226,13 +226,11 @@ def action_spec(self) -> specs.MultiDiscreteArray: def _state_to_observation(self, state: State) -> Observation: return Observation(cube=state.cube, step_count=state.step_count) - def render(self, state: State, save_path: Optional[str] = None) -> None: + def render(self, state: State) -> None: """Renders the current state of the cube. Args: state: the current state to be rendered. - save_path: the path where the image should be saved. If it is None, the plot - will not be stored. """ self._env_viewer.render(state=state) @@ -248,7 +246,7 @@ def animate( states: a list of `State` objects representing the sequence of game states. interval: the delay between frames in milliseconds, default to 200. save_path: the path where the animation file should be saved. If it is None, the plot - will not be stored. + will not be saved. Returns: animation.FuncAnimation: the animation object that was created. diff --git a/jumanji/environments/logic/rubiks_cube/utils.py b/jumanji/environments/logic/rubiks_cube/utils.py index 4b31fb9eb..0ce1b7186 100644 --- a/jumanji/environments/logic/rubiks_cube/utils.py +++ b/jumanji/environments/logic/rubiks_cube/utils.py @@ -42,10 +42,10 @@ def make_solved_cube(cube_size: int) -> Cube: """Make a solved cube of a given size. - + Args: cube_size: the size of the cube to generate. - + Returns: A solved cube, i.e. with all faces a uniform id (sticker color). """ @@ -56,10 +56,10 @@ def make_solved_cube(cube_size: int) -> Cube: def is_solved(cube: Cube) -> chex.Array: """Check if a cube is solved. - + Args: cube: the cube to check. - + Returns: Whether or not the cube is solved (all faces have a unique id). """ @@ -117,9 +117,11 @@ def do_rotation( def generate_up_move(amount: CubeMovementAmount, depth: int) -> Callable[[Cube], Cube]: """Generate the move corresponding to turning the up face. + Args: amount: how much to turn the face by. depth: the number of layers into the cube where the move is performed. + Returns: A callable that performs the specified up move. """ @@ -162,9 +164,11 @@ def generate_front_move( amount: CubeMovementAmount, depth: int ) -> Callable[[Cube], Cube]: """Generate the move corresponding to turning the front face. + Args: amount: how much to turn the face by. depth: the number of layers into the cube where the move is performed. + Returns: A callable that performs the specified front move. """ @@ -207,9 +211,11 @@ def generate_right_move( amount: CubeMovementAmount, depth: int ) -> Callable[[Cube], Cube]: """Generate the move corresponding to turning the right face. + Args: amount: how much to turn the face by. depth: the number of layers into the cube where the move is performed. + Returns: A callable that performs the specified right move. """ @@ -252,9 +258,11 @@ def generate_back_move( amount: CubeMovementAmount, depth: int ) -> Callable[[Cube], Cube]: """Generate the move corresponding to turning the back face. + Args: amount: how much to turn the face by. depth: the number of layers into the cube where the move is performed. + Returns: A callable that performs the specified back move. """ @@ -297,9 +305,11 @@ def generate_left_move( amount: CubeMovementAmount, depth: int ) -> Callable[[Cube], Cube]: """Generate the move corresponding to turning the left face. + Args: amount: how much to turn the face by. depth: the number of layers into the cube where the move is performed. + Returns: A callable that performs the specified left move. """ @@ -342,9 +352,11 @@ def generate_down_move( amount: CubeMovementAmount, depth: int ) -> Callable[[Cube], Cube]: """Generate the move corresponding to turning the down face. + Args: amount: how much to turn the face by. depth: the number of layers into the cube where the move is performed. + Returns: A callable that performs the specified down move. """ @@ -402,9 +414,11 @@ def generate_all_moves(cube_size: int) -> List[Callable[[Cube], Cube]]: def unflatten_action(flattened_action: chex.Array, cube_size: int) -> chex.Array: """Translate from the flat action representation to the unflattened representation. + Args: flattened_action: index into the sequence of all moves. cube_size: the size of the cube in question. + Returns: Unflattened action, ie a tuple: - face (0-5). This indicates the face on which the layer will turn. @@ -441,6 +455,7 @@ def unflatten_action(flattened_action: chex.Array, cube_size: int) -> chex.Array def flatten_action(unflattened_action: chex.Array, cube_size: int) -> chex.Array: """Inverse of the `unflatten_action` method. + Args: unflattened_action: flattened action representation, a tuple: - face (0-5). This indicates the face on which the layer will turn. @@ -448,6 +463,7 @@ def flatten_action(unflattened_action: chex.Array, cube_size: int) -> chex.Array the turn will take place. - amount (0-2). This indicates the amount of turning. cube_size: the size of the cube in question. + Returns: The flattened action representation, ie an index into the sequence of all moves. @@ -483,9 +499,11 @@ def flatten_action(unflattened_action: chex.Array, cube_size: int) -> chex.Array def rotate_cube(cube: Cube, flattened_action: chex.Array) -> Cube: """Apply a flattened action (index into the sequence of all moves) to a cube. + Args: cube: the cube on which to perform the move. flattened_action: the action to perform, in the flattened representation. + Returns: The rotated cube. """ @@ -499,10 +517,12 @@ def scramble_solved_cube( cube_size: int, ) -> Cube: """Return a scrambled cube according to a given sequence of flat actions. + Args: flattened_actions_in_scramble: the sequence of moves to perform, in their flat representation. cube_size: the size of the cube to return. + Returns: The scrambled cube. """ From 14904ff535649ae47c3e857be8ac044d4b36800b Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Mon, 27 Mar 2023 17:20:54 +0100 Subject: [PATCH 18/58] More generic --- .../logic/minesweeper/conftest.py | 5 ++- jumanji/environments/logic/minesweeper/env.py | 42 +++++++++---------- .../logic/minesweeper/env_test.py | 17 ++++---- .../logic/minesweeper/generator.py | 30 +++++++------ .../environments/logic/minesweeper/reward.py | 4 +- .../networks/minesweeper/actor_critic.py | 4 +- 6 files changed, 51 insertions(+), 51 deletions(-) diff --git a/jumanji/environments/logic/minesweeper/conftest.py b/jumanji/environments/logic/minesweeper/conftest.py index 4893cd765..b032dd70f 100644 --- a/jumanji/environments/logic/minesweeper/conftest.py +++ b/jumanji/environments/logic/minesweeper/conftest.py @@ -18,13 +18,16 @@ from jumanji.environments.logic.minesweeper.constants import UNEXPLORED_ID from jumanji.environments.logic.minesweeper.env import Minesweeper +from jumanji.environments.logic.minesweeper.generator import UniformSamplingGenerator from jumanji.environments.logic.minesweeper.types import State @pytest.fixture def minesweeper_env() -> Minesweeper: """Fixture for a default minesweeper env""" - return Minesweeper() + return Minesweeper( + generator=UniformSamplingGenerator(num_rows=10, num_cols=10, num_mines=10) + ) @pytest.fixture diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index 632750e61..cb3ac335e 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -35,7 +35,7 @@ ) from jumanji.environments.logic.minesweeper.generator import ( Generator, - SamplingGenerator, + UniformSamplingGenerator, ) from jumanji.environments.logic.minesweeper.reward import DefaultRewardFn, RewardFn from jumanji.environments.logic.minesweeper.types import Observation, State @@ -96,8 +96,6 @@ class Minesweeper(Environment[State]): def __init__( self, - num_rows: int = 10, - num_cols: int = 10, reward_function: Optional[RewardFn] = None, done_function: Optional[DoneFn] = None, env_viewer: Optional[MinesweeperViewer] = None, @@ -106,8 +104,7 @@ def __init__( """Instantiate a `Minesweeper` environment. Args: - num_rows: number of rows, i.e. height of the board. Defaults to 10. - num_cols: number of columns, i.e. width of the board. Defaults to 10. + reward_function: `RewardFn` whose `__call__` method computes the reward of an environment transition based on the given current state and selected action. Implemented options are [`DefaultRewardFn`]. Defaults to `DefaultRewardFn`. @@ -120,14 +117,11 @@ def __init__( generator: Generator to generate problem instances on environment reset. Implemented options are [`SamplingGenerator`]. Defaults to `SamplingGenerator`. + The generator will have attributes: + - num_rows: number of rows, i.e. height of the board. Defaults to 10. + - num_cols: number of columns, i.e. width of the board. Defaults to 10. + - num_mines: number of mines generated. Defaults to 10. """ - if num_rows <= 1 or num_cols <= 1: - raise ValueError( - f"Should make a board of height and width greater than 1, " - f"got num_rows={num_rows}, num_cols={num_cols}" - ) - self.num_rows = num_rows - self.num_cols = num_cols self.reward_function = reward_function or DefaultRewardFn( revealed_empty_square_reward=REVEALED_EMPTY_SQUARE_REWARD, revealed_mine_reward=REVEALED_MINE_REWARD, @@ -135,11 +129,11 @@ def __init__( ) self.done_function = done_function or DefaultDoneFn() - self._env_viewer = env_viewer or DefaultMinesweeperViewer( - num_rows=num_rows, num_cols=num_cols + self._generator = generator or UniformSamplingGenerator( + num_rows=10, num_cols=10, num_mines=10 ) - self._generator = generator or SamplingGenerator( - num_rows=num_rows, num_cols=num_cols, num_mines=10 + self._env_viewer = env_viewer or DefaultMinesweeperViewer( + num_rows=self._generator.num_rows, num_cols=self._generator.num_cols ) def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: @@ -206,14 +200,14 @@ def observation_spec(self) -> specs.Spec[Observation]: - step_count: BoundedArray (int32) of shape (). """ board = specs.BoundedArray( - shape=(self.num_rows, self.num_cols), + shape=(self._generator.num_rows, self._generator.num_cols), dtype=jnp.int32, minimum=-1, maximum=PATCH_SIZE * PATCH_SIZE - 1, name="board", ) action_mask = specs.BoundedArray( - shape=(self.num_rows, self.num_cols), + shape=(self._generator.num_rows, self._generator.num_cols), dtype=bool, minimum=False, maximum=True, @@ -223,14 +217,14 @@ def observation_spec(self) -> specs.Spec[Observation]: shape=(), dtype=jnp.int32, minimum=0, - maximum=self.num_rows * self.num_cols - 1, + maximum=self._generator.num_rows * self._generator.num_cols - 1, name="num_mines", ) step_count = specs.BoundedArray( shape=(), dtype=jnp.int32, minimum=0, - maximum=self.num_rows * self.num_cols, + maximum=self._generator.num_rows * self._generator.num_cols, name="step_count", ) return specs.Spec( @@ -250,7 +244,9 @@ def action_spec(self) -> specs.MultiDiscreteArray: action_spec: `specs.MultiDiscreteArray` object. """ return specs.MultiDiscreteArray( - num_values=jnp.array([self.num_rows, self.num_cols], jnp.int32), + num_values=jnp.array( + [self._generator.num_rows, self._generator.num_cols], jnp.int32 + ), name="action", dtype=jnp.int32, ) @@ -259,7 +255,9 @@ def _state_to_observation(self, state: State) -> Observation: return Observation( board=state.board, action_mask=jnp.equal(state.board, UNEXPLORED_ID), - num_mines=jnp.array(10, jnp.int32), # todo: make this more generic + num_mines=jnp.array( + self._generator.num_mines, jnp.int32 + ), step_count=state.step_count, ) diff --git a/jumanji/environments/logic/minesweeper/env_test.py b/jumanji/environments/logic/minesweeper/env_test.py index 0c720d60d..fb75d968f 100644 --- a/jumanji/environments/logic/minesweeper/env_test.py +++ b/jumanji/environments/logic/minesweeper/env_test.py @@ -108,11 +108,11 @@ def test_minesweeper_env_reset(minesweeper_env: Minesweeper) -> None: assert isinstance(timestep, TimeStep) assert isinstance(state, State) assert state.step_count == 0 - assert state.flat_mine_locations.shape == (minesweeper_env.num_mines,) - assert timestep.observation.num_mines == minesweeper_env.num_mines + assert state.flat_mine_locations.shape == (minesweeper_env._generator.num_mines,) + assert timestep.observation.num_mines == minesweeper_env._generator.num_mines assert state.board.shape == ( - minesweeper_env.num_rows, - minesweeper_env.num_cols, + minesweeper_env._generator.num_rows, + minesweeper_env._generator.num_cols, ) assert jnp.array_equal(state.board, timestep.observation.board) assert timestep.observation.step_count == 0 @@ -190,9 +190,9 @@ def test_minesweeper_env_solved(minesweeper_env: Minesweeper) -> None: step_fn = jit(minesweeper_env.step) collected_rewards = [] collected_step_types = [] - for i in range(minesweeper_env.num_rows): - for j in range(minesweeper_env.num_cols): - flat_location = i * minesweeper_env.num_cols + j + for i in range(minesweeper_env._generator.num_rows): + for j in range(minesweeper_env._generator.num_cols): + flat_location = i * minesweeper_env._generator.num_cols + j if flat_location in state.flat_mine_locations: continue action = jnp.array([i, j], dtype=jnp.int32) @@ -200,7 +200,8 @@ def test_minesweeper_env_solved(minesweeper_env: Minesweeper) -> None: collected_rewards.append(timestep.reward) collected_step_types.append(timestep.step_type) expected_episode_length = ( - minesweeper_env.num_rows * minesweeper_env.num_cols - minesweeper_env.num_mines + minesweeper_env._generator.num_rows * minesweeper_env._generator.num_cols + - minesweeper_env._generator.num_mines ) assert collected_rewards == [REVEALED_EMPTY_SQUARE_REWARD] * expected_episode_length assert collected_step_types == [StepType.MID] * (expected_episode_length - 1) + [ diff --git a/jumanji/environments/logic/minesweeper/generator.py b/jumanji/environments/logic/minesweeper/generator.py index 3d39e77f5..b975b2971 100644 --- a/jumanji/environments/logic/minesweeper/generator.py +++ b/jumanji/environments/logic/minesweeper/generator.py @@ -26,14 +26,26 @@ class Generator(abc.ABC): """Base class for generators for the Minesweeper environment.""" - def __init__(self, num_rows: int, num_cols: int): + def __init__(self, num_rows: int, num_cols: int, num_mines: int): """Initialises a Minesweeper generator for resetting the environment. Args: num_rows: number of rows, i.e. height of the board. num_cols: number of columns, i.e. width of the board. + num_mines: number of mines to place on the board. """ + if num_rows <= 1 or num_cols <= 1: + raise ValueError( + f"Should make a board of height and width greater than 1, " + f"got num_rows={num_rows}, num_cols={num_cols}" + ) + if num_mines < 0 or num_mines >= num_rows * num_cols: + raise ValueError( + f"Number of mines should be constrained between 0 and the size of the board, " + f"got {num_mines}" + ) self.num_rows = num_rows self.num_cols = num_cols + self.num_mines = num_mines @abc.abstractmethod def generate_flat_mine_locations(self, key: chex.PRNGKey) -> chex.Array: @@ -61,23 +73,9 @@ def __call__(self, key: chex.PRNGKey) -> State: return state -class SamplingGenerator(Generator): +class UniformSamplingGenerator(Generator): """Generates instances by sampling a given number of mines (without replacement).""" - def __init__( - self, - num_rows: int, - num_cols: int, - num_mines: int, - ): - if num_mines < 0 or num_mines >= num_rows * num_cols: - raise ValueError( - f"Number of mines should be constrained between 0 and the size of the board, " - f"got {num_mines}" - ) - self.num_mines = num_mines - super().__init__(num_rows=num_rows, num_cols=num_cols) - def generate_flat_mine_locations(self, key: chex.PRNGKey) -> chex.Array: return create_flat_mine_locations( key=key, diff --git a/jumanji/environments/logic/minesweeper/reward.py b/jumanji/environments/logic/minesweeper/reward.py index 7e8e6f61c..28ee5db89 100644 --- a/jumanji/environments/logic/minesweeper/reward.py +++ b/jumanji/environments/logic/minesweeper/reward.py @@ -40,7 +40,7 @@ def __init__( revealed_mine_reward: float, invalid_action_reward: float, ): - self.revelead_empty_square_reward = revealed_empty_square_reward + self.revealed_empty_square_reward = revealed_empty_square_reward self.revelead_mine_reward = revealed_mine_reward self.invalid_action_reward = invalid_action_reward @@ -50,7 +50,7 @@ def __call__(self, state: State, action: chex.Array) -> chex.Array: jnp.where( explored_mine(state=state, action=action), jnp.array(self.revelead_mine_reward, float), - jnp.array(self.revelead_empty_square_reward, float), + jnp.array(self.revealed_empty_square_reward, float), ), jnp.array(self.invalid_action_reward, float), ) diff --git a/jumanji/training/networks/minesweeper/actor_critic.py b/jumanji/training/networks/minesweeper/actor_critic.py index a14e04db6..95fcbe6f2 100644 --- a/jumanji/training/networks/minesweeper/actor_critic.py +++ b/jumanji/training/networks/minesweeper/actor_critic.py @@ -40,8 +40,8 @@ def make_actor_critic_networks_minesweeper( final_layer_dims: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `Minesweeper` environment.""" - board_height = minesweeper.num_rows - board_width = minesweeper.num_cols + board_height = minesweeper._generator.num_rows + board_width = minesweeper._generator.num_cols vocab_size = 1 + PATCH_SIZE**2 # unexplored, or 0, 1, ..., 8 parametric_action_distribution = FactorisedActionSpaceParametricDistribution( From e87ddeb96121e463fb4d54fd48787070842ff400 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Mon, 27 Mar 2023 17:23:12 +0100 Subject: [PATCH 19/58] Minor cleaning --- jumanji/environments/logic/minesweeper/env.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index cb3ac335e..89f7b7ea9 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -128,7 +128,6 @@ def __init__( invalid_action_reward=INVALID_ACTION_REWARD, ) self.done_function = done_function or DefaultDoneFn() - self._generator = generator or UniformSamplingGenerator( num_rows=10, num_cols=10, num_mines=10 ) @@ -224,7 +223,7 @@ def observation_spec(self) -> specs.Spec[Observation]: shape=(), dtype=jnp.int32, minimum=0, - maximum=self._generator.num_rows * self._generator.num_cols, + maximum=self._generator.num_rows * self._generator.num_cols - self._generator.num_mines, name="step_count", ) return specs.Spec( @@ -266,7 +265,7 @@ def render(self, state: State, save_path: Optional[str] = None) -> None: Args: state: the current state to be rendered. save_path: the path where the image should be saved. If it is None, the plot - will not be stored. + will not be saved. """ self._env_viewer.render(state=state) @@ -281,7 +280,7 @@ def animate( states: a list of `State` objects representing the sequence of states. interval: the delay between frames in milliseconds, default to 200. save_path: the path where the animation file should be saved. If it is None, the plot - will not be stored. + will not be saved. Returns: animation.FuncAnimation: the animation object that was created. """ From 86f1852a6c8dd2608afab1c0ba712c07d615dd59 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Mon, 27 Mar 2023 17:43:17 +0100 Subject: [PATCH 20/58] Don't repeat attributes in env and generator --- jumanji/__init__.py | 6 ++++- jumanji/environments/logic/rubiks_cube/env.py | 27 ++++++++----------- .../logic/rubiks_cube/env_test.py | 1 - .../logic/rubiks_cube/generator.py | 5 ++++ 4 files changed, 21 insertions(+), 18 deletions(-) diff --git a/jumanji/__init__.py b/jumanji/__init__.py index 643b3fdb0..a2e3a1cfd 100644 --- a/jumanji/__init__.py +++ b/jumanji/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from jumanji.env import Environment +from jumanji.environments.logic.rubiks_cube.generator import ScramblingGenerator from jumanji.registration import make, register, registered_environments from jumanji.version import __version__ @@ -32,10 +33,13 @@ register(id="RubiksCube-v0", entry_point="jumanji.environments:RubiksCube") # RubiksCube - an easier version of the standard Rubik's Cube puzzle with faces of size 3x3 yet only # 7 scrambles at reset time, making it technically maximum 7 actions away from the solution. +partly_scrambled_rubiks_cube_generator = ScramblingGenerator( + cube_size=3, num_scrambles_on_reset=7 +) register( id="RubiksCube-partly-scrambled-v0", entry_point="jumanji.environments:RubiksCube", - kwargs={"cube_size": 3, "time_limit": 20, "num_scrambles_on_reset": 7}, + kwargs={"time_limit": 20, "generator": partly_scrambled_rubiks_cube_generator}, ) diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index 0baa422d4..2349865cb 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -84,7 +84,6 @@ class RubiksCube(Environment[State]): def __init__( self, - cube_size: int = 3, time_limit: int = 200, reward_fn: Optional[RewardFn] = None, env_viewer: Optional[RubiksCubeViewer] = None, @@ -93,7 +92,6 @@ def __init__( """Instantiate a `RubiksCube` environment. Args: - cube_size: the size of the cube, i.e. length of an edge. Defaults to 3. time_limit: the number of steps allowed before an episode terminates. Defaults to 200. reward_fn: `RewardFn` whose `__call__` method computes the reward given the new state. Implemented options are [`SparseRewardFn`]. Defaults to `SparseRewardFn`. @@ -103,27 +101,22 @@ def __init__( generator: `Generator` used to generate problem instances on environment reset. Implemented options are [`ScramblingGenerator`]. Defaults to `ScramblingGenerator`. + The generator will contain an attribute `cube_size`, corresponding to the number of + cubies to an edge, and defaulting to 3. """ - if cube_size < 2: - raise ValueError( - f"Cannot meaningfully construct a cube smaller than 2x2x2, " - f"but received cube_size={cube_size}" - ) if time_limit <= 0: raise ValueError( f"The time_limit must be positive, but received time_limit={time_limit}" ) - - self.cube_size = cube_size self.time_limit = time_limit self.reward_function = reward_fn or SparseRewardFn() - self._env_viewer = env_viewer or DefaultRubiksCubeViewer( - sticker_colors=DEFAULT_STICKER_COLORS, cube_size=cube_size - ) self._generator = generator or ScramblingGenerator( - cube_size=cube_size, + cube_size=3, num_scrambles_on_reset=100, ) + self._env_viewer = env_viewer or DefaultRubiksCubeViewer( + sticker_colors=DEFAULT_STICKER_COLORS, cube_size=self._generator.cube_size + ) def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: """Resets the environment. @@ -156,7 +149,7 @@ def step( next_timestep: `TimeStep` corresponding to the timestep returned by the environment. """ flattened_action = flatten_action( - unflattened_action=action, cube_size=self.cube_size + unflattened_action=action, cube_size=self._generator.cube_size ) cube = rotate_cube( cube=state.cube, @@ -190,7 +183,7 @@ def observation_spec(self) -> specs.Spec[Observation]: - step_count: BoundedArray (jnp.int32) of shape (). """ cube = specs.BoundedArray( - shape=(len(Face), self.cube_size, self.cube_size), + shape=(len(Face), self._generator.cube_size, self._generator.cube_size), dtype=jnp.int8, minimum=0, maximum=len(Face) - 1, @@ -218,7 +211,9 @@ def action_spec(self) -> specs.MultiDiscreteArray: action_spec: `MultiDiscreteArray` object. """ return specs.MultiDiscreteArray( - num_values=jnp.array([len(Face), self.cube_size // 2, 3], jnp.int32), + num_values=jnp.array( + [len(Face), self._generator.cube_size // 2, 3], jnp.int32 + ), name="action", dtype=jnp.int32, ) diff --git a/jumanji/environments/logic/rubiks_cube/env_test.py b/jumanji/environments/logic/rubiks_cube/env_test.py index 075308716..3cbf7ac55 100644 --- a/jumanji/environments/logic/rubiks_cube/env_test.py +++ b/jumanji/environments/logic/rubiks_cube/env_test.py @@ -78,7 +78,6 @@ def test_rubiks_cube__step(rubiks_cube: RubiksCube) -> None: def test_rubiks_cube__does_not_smoke(cube_size: int) -> None: """Test that we can run an episode without any errors.""" env = RubiksCube( - cube_size=cube_size, time_limit=10, generator=ScramblingGenerator(cube_size=cube_size, num_scrambles_on_reset=5), ) diff --git a/jumanji/environments/logic/rubiks_cube/generator.py b/jumanji/environments/logic/rubiks_cube/generator.py index 74256d28f..29f484d17 100644 --- a/jumanji/environments/logic/rubiks_cube/generator.py +++ b/jumanji/environments/logic/rubiks_cube/generator.py @@ -32,6 +32,11 @@ def __init__(self, cube_size: int): Args: cube_size: the size of the cube to generate instances for. """ + if cube_size < 2: + raise ValueError( + f"Cannot meaningfully construct a cube smaller than 2x2x2, " + f"but received cube_size={cube_size}" + ) self.cube_size = cube_size @abc.abstractmethod From 94738832de51393e96762dfea0ed0d6155c2f215 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Mon, 27 Mar 2023 17:52:57 +0100 Subject: [PATCH 21/58] Lint --- jumanji/environments/logic/minesweeper/env.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index 89f7b7ea9..19f8f200e 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -223,7 +223,8 @@ def observation_spec(self) -> specs.Spec[Observation]: shape=(), dtype=jnp.int32, minimum=0, - maximum=self._generator.num_rows * self._generator.num_cols - self._generator.num_mines, + maximum=self._generator.num_rows * self._generator.num_cols + - self._generator.num_mines, name="step_count", ) return specs.Spec( @@ -254,9 +255,7 @@ def _state_to_observation(self, state: State) -> Observation: return Observation( board=state.board, action_mask=jnp.equal(state.board, UNEXPLORED_ID), - num_mines=jnp.array( - self._generator.num_mines, jnp.int32 - ), + num_mines=jnp.array(self._generator.num_mines, jnp.int32), step_count=state.step_count, ) From 1c44c7130ba92a881addf0116feae1862dd486ab Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Mon, 27 Mar 2023 17:56:27 +0100 Subject: [PATCH 22/58] Remove minesweeper changes for separate PR --- .../logic/minesweeper/conftest.py | 5 +- .../logic/minesweeper/constants.py | 5 +- jumanji/environments/logic/minesweeper/env.py | 237 +++++++++++++----- .../logic/minesweeper/env_test.py | 24 +- .../logic/minesweeper/env_viewer.py | 216 ---------------- .../logic/minesweeper/generator.py | 85 ------- .../environments/logic/minesweeper/reward.py | 26 +- .../networks/minesweeper/actor_critic.py | 4 +- 8 files changed, 198 insertions(+), 404 deletions(-) delete mode 100644 jumanji/environments/logic/minesweeper/env_viewer.py delete mode 100644 jumanji/environments/logic/minesweeper/generator.py diff --git a/jumanji/environments/logic/minesweeper/conftest.py b/jumanji/environments/logic/minesweeper/conftest.py index b032dd70f..4893cd765 100644 --- a/jumanji/environments/logic/minesweeper/conftest.py +++ b/jumanji/environments/logic/minesweeper/conftest.py @@ -18,16 +18,13 @@ from jumanji.environments.logic.minesweeper.constants import UNEXPLORED_ID from jumanji.environments.logic.minesweeper.env import Minesweeper -from jumanji.environments.logic.minesweeper.generator import UniformSamplingGenerator from jumanji.environments.logic.minesweeper.types import State @pytest.fixture def minesweeper_env() -> Minesweeper: """Fixture for a default minesweeper env""" - return Minesweeper( - generator=UniformSamplingGenerator(num_rows=10, num_cols=10, num_mines=10) - ) + return Minesweeper() @pytest.fixture diff --git a/jumanji/environments/logic/minesweeper/constants.py b/jumanji/environments/logic/minesweeper/constants.py index e7f0a9244..21f4b4824 100644 --- a/jumanji/environments/logic/minesweeper/constants.py +++ b/jumanji/environments/logic/minesweeper/constants.py @@ -16,9 +16,8 @@ IS_MINE: int = 1 PATCH_SIZE: int = 3 REVEALED_EMPTY_SQUARE_REWARD: float = 1.0 -REVEALED_MINE_REWARD: float = 0.0 -INVALID_ACTION_REWARD: float = 0.0 -DEFAULT_COLOR_MAPPING: list = [ +REVEALED_MINE_OR_INVALID_ACTION_REWARD: float = 0.0 +COLOUR_MAPPING: list = [ "orange", "blue", "green", diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index 89f7b7ea9..b9ec6a0d5 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -12,34 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple +from typing import List, Optional, Sequence, Tuple import chex import jax import jax.numpy as jnp import matplotlib.animation +import matplotlib.pyplot as plt +import jumanji.environments from jumanji import specs from jumanji.env import Environment from jumanji.environments.logic.minesweeper.constants import ( - INVALID_ACTION_REWARD, + COLOUR_MAPPING, PATCH_SIZE, - REVEALED_EMPTY_SQUARE_REWARD, - REVEALED_MINE_REWARD, UNEXPLORED_ID, ) from jumanji.environments.logic.minesweeper.done import DefaultDoneFn, DoneFn -from jumanji.environments.logic.minesweeper.env_viewer import ( - DefaultMinesweeperViewer, - MinesweeperViewer, -) -from jumanji.environments.logic.minesweeper.generator import ( - Generator, - UniformSamplingGenerator, -) from jumanji.environments.logic.minesweeper.reward import DefaultRewardFn, RewardFn from jumanji.environments.logic.minesweeper.types import Observation, State -from jumanji.environments.logic.minesweeper.utils import count_adjacent_mines +from jumanji.environments.logic.minesweeper.utils import ( + count_adjacent_mines, + create_flat_mine_locations, + explored_mine, +) from jumanji.types import TimeStep, restart, termination, transition @@ -96,44 +92,46 @@ class Minesweeper(Environment[State]): def __init__( self, + num_rows: int = 10, + num_cols: int = 10, + num_mines: int = 10, reward_function: Optional[RewardFn] = None, done_function: Optional[DoneFn] = None, - env_viewer: Optional[MinesweeperViewer] = None, - generator: Optional[Generator] = None, + color_mapping: Optional[List[str]] = None, ): """Instantiate a `Minesweeper` environment. Args: - + num_rows: number of rows, i.e. height of the board. Defaults to 10. + num_cols: number of columns, i.e. width of the board. Defaults to 10. + num_mines: number of mines on the board. Defaults to 10. reward_function: `RewardFn` whose `__call__` method computes the reward of an environment transition based on the given current state and selected action. Implemented options are [`DefaultRewardFn`]. Defaults to `DefaultRewardFn`. done_function: `DoneFn` whose `__call__` method computes the done signal given the current state, action taken, and next state. Implemented options are [`DefaultDoneFn`]. Defaults to `DefaultDoneFn`. - env_viewer: MinesweeperViewer to support rendering and animation methods. - Implemented options are [`DefaultMinesweeperViewer`]. - Defaults to `DefaultMinesweeperViewer`. - generator: Generator to generate problem instances on environment reset. - Implemented options are [`SamplingGenerator`]. - Defaults to `SamplingGenerator`. - The generator will have attributes: - - num_rows: number of rows, i.e. height of the board. Defaults to 10. - - num_cols: number of columns, i.e. width of the board. Defaults to 10. - - num_mines: number of mines generated. Defaults to 10. + color_mapping: colour map used for rendering. """ - self.reward_function = reward_function or DefaultRewardFn( - revealed_empty_square_reward=REVEALED_EMPTY_SQUARE_REWARD, - revealed_mine_reward=REVEALED_MINE_REWARD, - invalid_action_reward=INVALID_ACTION_REWARD, - ) + if num_rows <= 1 or num_cols <= 1: + raise ValueError( + f"Should make a board of height and width greater than 1, " + f"got num_rows={num_rows}, num_cols={num_cols}" + ) + if num_mines < 0 or num_mines >= num_rows * num_cols: + raise ValueError( + f"Number of mines should be constrained between 0 and the size of the board, " + f"got {num_mines}" + ) + self.num_rows = num_rows + self.num_cols = num_cols + self.num_mines = num_mines + self.reward_function = reward_function or DefaultRewardFn() self.done_function = done_function or DefaultDoneFn() - self._generator = generator or UniformSamplingGenerator( - num_rows=10, num_cols=10, num_mines=10 - ) - self._env_viewer = env_viewer or DefaultMinesweeperViewer( - num_rows=self._generator.num_rows, num_cols=self._generator.num_cols - ) + + self.cmap = color_mapping if color_mapping else COLOUR_MAPPING + self.figure_name = f"{num_rows}x{num_cols} Minesweeper" + self.figure_size = (6.0, 6.0) def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: """Resets the environment. @@ -146,7 +144,25 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep: `TimeStep` corresponding to the first timestep returned by the environment. """ - state = self._generator(key) + key, sample_key = jax.random.split(key) + board = jnp.full( + shape=(self.num_rows, self.num_cols), + fill_value=UNEXPLORED_ID, + dtype=jnp.int32, + ) + step_count = jnp.array(0, jnp.int32) + flat_mine_locations = create_flat_mine_locations( + key=sample_key, + num_rows=self.num_rows, + num_cols=self.num_cols, + num_mines=self.num_mines, + ) + state = State( + board=board, + step_count=step_count, + key=key, + flat_mine_locations=flat_mine_locations, + ) observation = self._state_to_observation(state=state) timestep = restart(observation=observation) return state, timestep @@ -199,14 +215,14 @@ def observation_spec(self) -> specs.Spec[Observation]: - step_count: BoundedArray (int32) of shape (). """ board = specs.BoundedArray( - shape=(self._generator.num_rows, self._generator.num_cols), + shape=(self.num_rows, self.num_cols), dtype=jnp.int32, minimum=-1, maximum=PATCH_SIZE * PATCH_SIZE - 1, name="board", ) action_mask = specs.BoundedArray( - shape=(self._generator.num_rows, self._generator.num_cols), + shape=(self.num_rows, self.num_cols), dtype=bool, minimum=False, maximum=True, @@ -216,14 +232,14 @@ def observation_spec(self) -> specs.Spec[Observation]: shape=(), dtype=jnp.int32, minimum=0, - maximum=self._generator.num_rows * self._generator.num_cols - 1, + maximum=self.num_rows * self.num_cols - 1, name="num_mines", ) step_count = specs.BoundedArray( shape=(), dtype=jnp.int32, minimum=0, - maximum=self._generator.num_rows * self._generator.num_cols - self._generator.num_mines, + maximum=self.num_rows * self.num_cols - self.num_mines, name="step_count", ) return specs.Spec( @@ -243,9 +259,7 @@ def action_spec(self) -> specs.MultiDiscreteArray: action_spec: `specs.MultiDiscreteArray` object. """ return specs.MultiDiscreteArray( - num_values=jnp.array( - [self._generator.num_rows, self._generator.num_cols], jnp.int32 - ), + num_values=jnp.array([self.num_rows, self.num_cols], jnp.int32), name="action", dtype=jnp.int32, ) @@ -254,20 +268,21 @@ def _state_to_observation(self, state: State) -> Observation: return Observation( board=state.board, action_mask=jnp.equal(state.board, UNEXPLORED_ID), - num_mines=jnp.array( - self._generator.num_mines, jnp.int32 - ), + num_mines=jnp.array(self.num_mines, jnp.int32), step_count=state.step_count, ) - def render(self, state: State, save_path: Optional[str] = None) -> None: - """Renders the current state of the board. + def render(self, state: State) -> None: + """Render the given environment state using matplotlib. + Args: - state: the current state to be rendered. - save_path: the path where the image should be saved. If it is None, the plot - will not be saved. + state: environment state to be rendered. + """ - self._env_viewer.render(state=state) + self._clear_display() + fig, ax = self._get_fig_ax() + self._draw(ax, state) + self._update_display(fig) def animate( self, @@ -275,22 +290,116 @@ def animate( interval: int = 200, save_path: Optional[str] = None, ) -> matplotlib.animation.FuncAnimation: - """Creates an animated gif of the board based on the sequence of states. - Args: - states: a list of `State` objects representing the sequence of states. - interval: the delay between frames in milliseconds, default to 200. - save_path: the path where the animation file should be saved. If it is None, the plot - will not be saved. + """Create an animation from a sequence of environment states. + + Args: + states: sequence of environment states corresponding to consecutive timesteps. + interval: delay between frames in milliseconds, default to 200. + save_path: the path where the animation file should be saved. If it is None, the plot + will not be saved. + Returns: - animation.FuncAnimation: the animation object that was created. + Animation object that can be saved as a GIF, MP4, or rendered with HTML. """ - return self._env_viewer.animate( - states=states, interval=interval, save_path=save_path + fig, ax = self._get_fig_ax() + plt.tight_layout() + plt.close(fig) + + def make_frame(state_index: int) -> None: + state = states[state_index] + self._draw(ax, state) + + # Create the animation object. + self._animation = matplotlib.animation.FuncAnimation( + fig, + make_frame, + frames=len(states), + interval=interval, ) + # Save the animation as a GIF. + if save_path: + self._animation.save(save_path) + + return self._animation + def close(self) -> None: """Perform any necessary cleanup. + Environments will automatically :meth:`close()` themselves when garbage collected or when the program exits. """ - self._env_viewer.close() + plt.close(self.figure_name) + + def _get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]: + exists = plt.fignum_exists(self.figure_name) + if exists: + fig = plt.figure(self.figure_name) + ax = fig.get_axes()[0] + else: + fig = plt.figure(self.figure_name, figsize=self.figure_size) + plt.suptitle(self.figure_name) + plt.tight_layout() + if not plt.isinteractive(): + fig.show() + ax = fig.add_subplot() + return fig, ax + + def _draw(self, ax: plt.Axes, state: State) -> None: + ax.clear() + ax.set_xticks(jnp.arange(-0.5, self.num_cols - 1, 1)) + ax.set_yticks(jnp.arange(-0.5, self.num_rows - 1, 1)) + ax.tick_params( + top=False, + bottom=False, + left=False, + right=False, + labelleft=False, + labelbottom=False, + labeltop=False, + labelright=False, + ) + background = jnp.ones_like(state.board) + for i in range(self.num_rows): + for j in range(self.num_cols): + background = self._render_grid_square( + state=state, ax=ax, i=i, j=j, background=background + ) + ax.imshow(background, cmap="gray", vmin=0, vmax=1) + ax.grid(color="black", linestyle="-", linewidth=2) + + def _render_grid_square( + self, state: State, ax: plt.Axes, i: int, j: int, background: chex.Array + ) -> chex.Array: + board_value = state.board[i, j] + if board_value != UNEXPLORED_ID: + if explored_mine(state=state, action=jnp.array([i, j], dtype=jnp.int32)): + background = background.at[i, j].set(0) + else: + ax.text( + j, + i, + str(board_value), + color=self.cmap[board_value], + ha="center", + va="center", + fontsize="xx-large", + ) + return background + + def _update_display(self, fig: plt.Figure) -> None: + if plt.isinteractive(): + # Required to update render when using Jupyter Notebook. + fig.canvas.draw() + if jumanji.environments.is_colab(): + plt.show(self.figure_name) + else: + # Required to update render when not using Jupyter Notebook. + fig.canvas.draw_idle() + fig.canvas.flush_events() + + def _clear_display(self) -> None: + if jumanji.environments.is_colab(): + import IPython.display + + IPython.display.clear_output(True) diff --git a/jumanji/environments/logic/minesweeper/env_test.py b/jumanji/environments/logic/minesweeper/env_test.py index fb75d968f..129d81583 100644 --- a/jumanji/environments/logic/minesweeper/env_test.py +++ b/jumanji/environments/logic/minesweeper/env_test.py @@ -24,9 +24,8 @@ from jax import random from jumanji.environments.logic.minesweeper.constants import ( - INVALID_ACTION_REWARD, REVEALED_EMPTY_SQUARE_REWARD, - REVEALED_MINE_REWARD, + REVEALED_MINE_OR_INVALID_ACTION_REWARD, ) from jumanji.environments.logic.minesweeper.env import Minesweeper from jumanji.environments.logic.minesweeper.types import State @@ -70,12 +69,12 @@ def play_and_get_episode_stats( ), ( [[0, 3], [0, 2]], - [REVEALED_EMPTY_SQUARE_REWARD, REVEALED_MINE_REWARD], + [REVEALED_EMPTY_SQUARE_REWARD, REVEALED_MINE_OR_INVALID_ACTION_REWARD], [StepType.MID, StepType.LAST], ), ( [[0, 3], [0, 3]], - [REVEALED_EMPTY_SQUARE_REWARD, INVALID_ACTION_REWARD], + [REVEALED_EMPTY_SQUARE_REWARD, REVEALED_MINE_OR_INVALID_ACTION_REWARD], [StepType.MID, StepType.LAST], ), ], @@ -108,11 +107,11 @@ def test_minesweeper_env_reset(minesweeper_env: Minesweeper) -> None: assert isinstance(timestep, TimeStep) assert isinstance(state, State) assert state.step_count == 0 - assert state.flat_mine_locations.shape == (minesweeper_env._generator.num_mines,) - assert timestep.observation.num_mines == minesweeper_env._generator.num_mines + assert state.flat_mine_locations.shape == (minesweeper_env.num_mines,) + assert timestep.observation.num_mines == minesweeper_env.num_mines assert state.board.shape == ( - minesweeper_env._generator.num_rows, - minesweeper_env._generator.num_cols, + minesweeper_env.num_rows, + minesweeper_env.num_cols, ) assert jnp.array_equal(state.board, timestep.observation.board) assert timestep.observation.step_count == 0 @@ -190,9 +189,9 @@ def test_minesweeper_env_solved(minesweeper_env: Minesweeper) -> None: step_fn = jit(minesweeper_env.step) collected_rewards = [] collected_step_types = [] - for i in range(minesweeper_env._generator.num_rows): - for j in range(minesweeper_env._generator.num_cols): - flat_location = i * minesweeper_env._generator.num_cols + j + for i in range(minesweeper_env.num_rows): + for j in range(minesweeper_env.num_cols): + flat_location = i * minesweeper_env.num_cols + j if flat_location in state.flat_mine_locations: continue action = jnp.array([i, j], dtype=jnp.int32) @@ -200,8 +199,7 @@ def test_minesweeper_env_solved(minesweeper_env: Minesweeper) -> None: collected_rewards.append(timestep.reward) collected_step_types.append(timestep.step_type) expected_episode_length = ( - minesweeper_env._generator.num_rows * minesweeper_env._generator.num_cols - - minesweeper_env._generator.num_mines + minesweeper_env.num_rows * minesweeper_env.num_cols - minesweeper_env.num_mines ) assert collected_rewards == [REVEALED_EMPTY_SQUARE_REWARD] * expected_episode_length assert collected_step_types == [StepType.MID] * (expected_episode_length - 1) + [ diff --git a/jumanji/environments/logic/minesweeper/env_viewer.py b/jumanji/environments/logic/minesweeper/env_viewer.py deleted file mode 100644 index 3115f491c..000000000 --- a/jumanji/environments/logic/minesweeper/env_viewer.py +++ /dev/null @@ -1,216 +0,0 @@ -# Copyright 2022 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Optional, Sequence, Tuple - -import chex -import jax.numpy as jnp -import matplotlib -from matplotlib import pyplot as plt - -import jumanji.environments -from jumanji.environments.logic.minesweeper.constants import ( - DEFAULT_COLOR_MAPPING, - UNEXPLORED_ID, -) -from jumanji.environments.logic.minesweeper.types import State -from jumanji.environments.logic.minesweeper.utils import explored_mine - - -class MinesweeperViewer: - """Abstract viewer class to support rendering and animation""" - - def render(self, state: State) -> None: - """Render frames of the environment for a given state using matplotlib. - Args: - state: `State` object corresponding to the new state of the environment. - """ - raise NotImplementedError - - def animate( - self, - states: Sequence[State], - interval: int = 200, - save_path: Optional[str] = None, - ) -> matplotlib.animation.FuncAnimation: - """Create an animation from a sequence of environment states. - Args: - states: sequence of environment states corresponding to consecutive timesteps. - interval: delay between frames in milliseconds, default to 200. - save_path: the path where the animation file should be saved. If it is None, the plot - will not be saved. - Returns: - Animation that can be saved as a GIF, MP4, or rendered with HTML. - """ - raise NotImplementedError - - def close(self) -> None: - """Perform any necessary cleanup. - Environments will automatically :meth:`close()` themselves when - garbage collected or when the program exits. - """ - raise NotImplementedError - - -class DefaultMinesweeperViewer(MinesweeperViewer): - def __init__( - self, - color_mapping: Optional[List[str]] = None, - num_rows: int = 10, - num_cols: int = 10, - ): - """ - Args: - color_mapping: colors used in rendering the cells in Minesweeper. - Defaults to `DEFAULT_COLOR_MAPPING`. - num_rows: number of rows, i.e. height of the board. Defaults to 10. - num_cols: number of columns, i.e. width of the board. Defaults to 10. - """ - self.cmap = color_mapping if color_mapping else DEFAULT_COLOR_MAPPING - self.num_rows = num_rows - self.num_cols = num_cols - self.figure_name = f"{num_rows}x{num_cols} Minesweeper" - self.figure_size = (6.0, 6.0) - - def render(self, state: State) -> None: - """Render the given environment state using matplotlib. - - Args: - state: environment state to be rendered. - - """ - self._clear_display() - fig, ax = self._get_fig_ax() - self._draw(ax, state) - self._update_display(fig) - - def animate( - self, - states: Sequence[State], - interval: int = 200, - save_path: Optional[str] = None, - ) -> matplotlib.animation.FuncAnimation: - """Create an animation from a sequence of environment states. - - Args: - states: sequence of environment states corresponding to consecutive timesteps. - interval: delay between frames in milliseconds, default to 200. - save_path: the path where the animation file should be saved. If it is None, the plot - will not be saved. - - Returns: - Animation object that can be saved as a GIF, MP4, or rendered with HTML. - """ - fig, ax = self._get_fig_ax() - plt.tight_layout() - plt.close(fig) - - def make_frame(state_index: int) -> None: - state = states[state_index] - self._draw(ax, state) - - # Create the animation object. - self._animation = matplotlib.animation.FuncAnimation( - fig, - make_frame, - frames=len(states), - interval=interval, - ) - - # Save the animation as a GIF. - if save_path: - self._animation.save(save_path) - - return self._animation - - def close(self) -> None: - """Perform any necessary cleanup. - - Environments will automatically :meth:`close()` themselves when - garbage collected or when the program exits. - """ - plt.close(self.figure_name) - - def _get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]: - exists = plt.fignum_exists(self.figure_name) - if exists: - fig = plt.figure(self.figure_name) - ax = fig.get_axes()[0] - else: - fig = plt.figure(self.figure_name, figsize=self.figure_size) - plt.suptitle(self.figure_name) - plt.tight_layout() - if not plt.isinteractive(): - fig.show() - ax = fig.add_subplot() - return fig, ax - - def _draw(self, ax: plt.Axes, state: State) -> None: - ax.clear() - ax.set_xticks(jnp.arange(-0.5, self.num_cols - 1, 1)) - ax.set_yticks(jnp.arange(-0.5, self.num_rows - 1, 1)) - ax.tick_params( - top=False, - bottom=False, - left=False, - right=False, - labelleft=False, - labelbottom=False, - labeltop=False, - labelright=False, - ) - background = jnp.ones_like(state.board) - for i in range(self.num_rows): - for j in range(self.num_cols): - background = self._render_grid_square( - state=state, ax=ax, i=i, j=j, background=background - ) - ax.imshow(background, cmap="gray", vmin=0, vmax=1) - ax.grid(color="black", linestyle="-", linewidth=2) - - def _render_grid_square( - self, state: State, ax: plt.Axes, i: int, j: int, background: chex.Array - ) -> chex.Array: - board_value = state.board[i, j] - if board_value != UNEXPLORED_ID: - if explored_mine(state=state, action=jnp.array([i, j], dtype=jnp.int32)): - background = background.at[i, j].set(0) - else: - ax.text( - j, - i, - str(board_value), - color=self.cmap[board_value], - ha="center", - va="center", - fontsize="xx-large", - ) - return background - - def _update_display(self, fig: plt.Figure) -> None: - if plt.isinteractive(): - # Required to update render when using Jupyter Notebook. - fig.canvas.draw() - if jumanji.environments.is_colab(): - plt.show(self.figure_name) - else: - # Required to update render when not using Jupyter Notebook. - fig.canvas.draw_idle() - fig.canvas.flush_events() - - def _clear_display(self) -> None: - if jumanji.environments.is_colab(): - import IPython.display - - IPython.display.clear_output(True) diff --git a/jumanji/environments/logic/minesweeper/generator.py b/jumanji/environments/logic/minesweeper/generator.py deleted file mode 100644 index b975b2971..000000000 --- a/jumanji/environments/logic/minesweeper/generator.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2022 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc - -import chex -import jax -import jax.numpy as jnp - -from jumanji.environments.logic.minesweeper.constants import UNEXPLORED_ID -from jumanji.environments.logic.minesweeper.types import State -from jumanji.environments.logic.minesweeper.utils import create_flat_mine_locations - - -class Generator(abc.ABC): - """Base class for generators for the Minesweeper environment.""" - - def __init__(self, num_rows: int, num_cols: int, num_mines: int): - """Initialises a Minesweeper generator for resetting the environment. - Args: - num_rows: number of rows, i.e. height of the board. - num_cols: number of columns, i.e. width of the board. - num_mines: number of mines to place on the board. - """ - if num_rows <= 1 or num_cols <= 1: - raise ValueError( - f"Should make a board of height and width greater than 1, " - f"got num_rows={num_rows}, num_cols={num_cols}" - ) - if num_mines < 0 or num_mines >= num_rows * num_cols: - raise ValueError( - f"Number of mines should be constrained between 0 and the size of the board, " - f"got {num_mines}" - ) - self.num_rows = num_rows - self.num_cols = num_cols - self.num_mines = num_mines - - @abc.abstractmethod - def generate_flat_mine_locations(self, key: chex.PRNGKey) -> chex.Array: - """Generates positions (in flattened coordinates) of the mines in the board""" - - def __call__(self, key: chex.PRNGKey) -> State: - """Generates a `Minesweeper` state. - Returns: - A `Minesweeper` state. - """ - key, sample_key = jax.random.split(key) - board = jnp.full( - shape=(self.num_rows, self.num_cols), - fill_value=UNEXPLORED_ID, - dtype=jnp.int32, - ) - step_count = jnp.array(0, jnp.int32) - flat_mine_locations = self.generate_flat_mine_locations(key=sample_key) - state = State( - board=board, - step_count=step_count, - key=key, - flat_mine_locations=flat_mine_locations, - ) - return state - - -class UniformSamplingGenerator(Generator): - """Generates instances by sampling a given number of mines (without replacement).""" - - def generate_flat_mine_locations(self, key: chex.PRNGKey) -> chex.Array: - return create_flat_mine_locations( - key=key, - num_rows=self.num_rows, - num_cols=self.num_cols, - num_mines=self.num_mines, - ) diff --git a/jumanji/environments/logic/minesweeper/reward.py b/jumanji/environments/logic/minesweeper/reward.py index 28ee5db89..336d0325c 100644 --- a/jumanji/environments/logic/minesweeper/reward.py +++ b/jumanji/environments/logic/minesweeper/reward.py @@ -17,6 +17,10 @@ import chex import jax.numpy as jnp +from jumanji.environments.logic.minesweeper.constants import ( + REVEALED_EMPTY_SQUARE_REWARD, + REVEALED_MINE_OR_INVALID_ACTION_REWARD, +) from jumanji.environments.logic.minesweeper.types import State from jumanji.environments.logic.minesweeper.utils import explored_mine, is_valid_action @@ -28,29 +32,17 @@ def __call__(self, state: State, action: chex.Array) -> chex.Array: class DefaultRewardFn(RewardFn): - """A dense reward function corresponding to the 3 possible events: - - Revealing an empty square - - Revealing a mine - - Choosing an invalid action (an already revealed square) + """A dense reward function: 1 for every timestep on which a mine is not explored + (or a small penalty if action is invalid), otherwise 0. """ - def __init__( - self, - revealed_empty_square_reward: float, - revealed_mine_reward: float, - invalid_action_reward: float, - ): - self.revealed_empty_square_reward = revealed_empty_square_reward - self.revelead_mine_reward = revealed_mine_reward - self.invalid_action_reward = invalid_action_reward - def __call__(self, state: State, action: chex.Array) -> chex.Array: return jnp.where( is_valid_action(state=state, action=action), jnp.where( explored_mine(state=state, action=action), - jnp.array(self.revelead_mine_reward, float), - jnp.array(self.revealed_empty_square_reward, float), + jnp.array(REVEALED_MINE_OR_INVALID_ACTION_REWARD, float), + jnp.array(REVEALED_EMPTY_SQUARE_REWARD, float), ), - jnp.array(self.invalid_action_reward, float), + jnp.array(REVEALED_MINE_OR_INVALID_ACTION_REWARD, float), ) diff --git a/jumanji/training/networks/minesweeper/actor_critic.py b/jumanji/training/networks/minesweeper/actor_critic.py index 95fcbe6f2..a14e04db6 100644 --- a/jumanji/training/networks/minesweeper/actor_critic.py +++ b/jumanji/training/networks/minesweeper/actor_critic.py @@ -40,8 +40,8 @@ def make_actor_critic_networks_minesweeper( final_layer_dims: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `Minesweeper` environment.""" - board_height = minesweeper._generator.num_rows - board_width = minesweeper._generator.num_cols + board_height = minesweeper.num_rows + board_width = minesweeper.num_cols vocab_size = 1 + PATCH_SIZE**2 # unexplored, or 0, 1, ..., 8 parametric_action_distribution = FactorisedActionSpaceParametricDistribution( From 8d83792b795b7f4178eda13dc9f08aa20420f8fa Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Mon, 27 Mar 2023 18:04:32 +0100 Subject: [PATCH 23/58] Simplify --- jumanji/environments/logic/rubiks_cube/utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/jumanji/environments/logic/rubiks_cube/utils.py b/jumanji/environments/logic/rubiks_cube/utils.py index 0ce1b7186..040aabe45 100644 --- a/jumanji/environments/logic/rubiks_cube/utils.py +++ b/jumanji/environments/logic/rubiks_cube/utils.py @@ -526,13 +526,9 @@ def scramble_solved_cube( Returns: The scrambled cube. """ - - def rotate_cube_fn(cube: Cube, flattened_action: chex.Array) -> Cube: - return rotate_cube(cube=cube, flattened_action=flattened_action) - cube = make_solved_cube(cube_size=cube_size) cube, _ = jax.lax.scan( - lambda *args: (rotate_cube_fn(*args), None), + lambda *args: (rotate_cube(*args), None), cube, flattened_actions_in_scramble, ) From 0bc3f330729052c559332a29203ecf2505546a54 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Tue, 28 Mar 2023 12:17:41 +0100 Subject: [PATCH 24/58] To sync --- .../logic/rubiks_cube/env_viewer.py | 37 +------------------ 1 file changed, 1 insertion(+), 36 deletions(-) diff --git a/jumanji/environments/logic/rubiks_cube/env_viewer.py b/jumanji/environments/logic/rubiks_cube/env_viewer.py index 57d15f079..faf9bcb6c 100644 --- a/jumanji/environments/logic/rubiks_cube/env_viewer.py +++ b/jumanji/environments/logic/rubiks_cube/env_viewer.py @@ -23,47 +23,12 @@ from jumanji.environments.logic.rubiks_cube.types import State -class RubiksCubeViewer: +class RubiksCubeViewer(Viewer): """Abstract viewer class to support rendering and animation""" def __init__(self, cube_size: int): self.cube_size = cube_size - def render(self, state: State) -> None: - """Render frames of the environment for a given state using matplotlib. - - Args: - state: `State` object corresponding to the new state of the environment. - """ - raise NotImplementedError - - def animate( - self, - states: Sequence[State], - interval: int, - save_path: Optional[str], - ) -> matplotlib.animation.FuncAnimation: - """Create an animation from a sequence of environment states. - - Args: - states: sequence of environment states corresponding to consecutive timesteps. - interval: delay between frames in milliseconds, default to 200. - save_path: the path where the animation file should be saved. If it is None, the plot - will not be saved. - - Returns: - Animation that can be saved as a GIF, MP4, or rendered with HTML. - """ - raise NotImplementedError - - def close(self) -> None: - """Perform any necessary cleanup. - - Environments will automatically :meth:`close()` themselves when - garbage collected or when the program exits. - """ - raise NotImplementedError - class DefaultRubiksCubeViewer(RubiksCubeViewer): def __init__(self, sticker_colors: Optional[list], cube_size: int): From cebb72f33cf65e2a4977589f3b626738487699a2 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Tue, 28 Mar 2023 12:42:23 +0100 Subject: [PATCH 25/58] To sync --- jumanji/environments/logic/minesweeper/env_viewer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/jumanji/environments/logic/minesweeper/env_viewer.py b/jumanji/environments/logic/minesweeper/env_viewer.py index 3115f491c..e4aa2a331 100644 --- a/jumanji/environments/logic/minesweeper/env_viewer.py +++ b/jumanji/environments/logic/minesweeper/env_viewer.py @@ -28,8 +28,11 @@ from jumanji.environments.logic.minesweeper.utils import explored_mine -class MinesweeperViewer: +class MinesweeperViewer(Viewer): """Abstract viewer class to support rendering and animation""" + def __init__(self, num_rows: int, num_cols: int): + self.num_rows = num_rows + self.num_cols = num_cols def render(self, state: State) -> None: """Render frames of the environment for a given state using matplotlib. @@ -78,10 +81,9 @@ def __init__( num_cols: number of columns, i.e. width of the board. Defaults to 10. """ self.cmap = color_mapping if color_mapping else DEFAULT_COLOR_MAPPING - self.num_rows = num_rows - self.num_cols = num_cols self.figure_name = f"{num_rows}x{num_cols} Minesweeper" self.figure_size = (6.0, 6.0) + super().__init__(num_rows=num_rows, num_cols=num_cols) def render(self, state: State) -> None: """Render the given environment state using matplotlib. From 379e483894f44a43f0474f406c353b0fd3ae844e Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Tue, 28 Mar 2023 12:47:03 +0100 Subject: [PATCH 26/58] To sync --- .../logic/minesweeper/conftest.py | 5 +- .../logic/minesweeper/constants.py | 5 +- jumanji/environments/logic/minesweeper/env.py | 236 +++++------------- .../logic/minesweeper/env_test.py | 24 +- .../logic/minesweeper/generator.py | 85 +++++++ .../environments/logic/minesweeper/reward.py | 26 +- .../networks/minesweeper/actor_critic.py | 4 +- 7 files changed, 187 insertions(+), 198 deletions(-) create mode 100644 jumanji/environments/logic/minesweeper/generator.py diff --git a/jumanji/environments/logic/minesweeper/conftest.py b/jumanji/environments/logic/minesweeper/conftest.py index 4893cd765..b032dd70f 100644 --- a/jumanji/environments/logic/minesweeper/conftest.py +++ b/jumanji/environments/logic/minesweeper/conftest.py @@ -18,13 +18,16 @@ from jumanji.environments.logic.minesweeper.constants import UNEXPLORED_ID from jumanji.environments.logic.minesweeper.env import Minesweeper +from jumanji.environments.logic.minesweeper.generator import UniformSamplingGenerator from jumanji.environments.logic.minesweeper.types import State @pytest.fixture def minesweeper_env() -> Minesweeper: """Fixture for a default minesweeper env""" - return Minesweeper() + return Minesweeper( + generator=UniformSamplingGenerator(num_rows=10, num_cols=10, num_mines=10) + ) @pytest.fixture diff --git a/jumanji/environments/logic/minesweeper/constants.py b/jumanji/environments/logic/minesweeper/constants.py index 21f4b4824..e7f0a9244 100644 --- a/jumanji/environments/logic/minesweeper/constants.py +++ b/jumanji/environments/logic/minesweeper/constants.py @@ -16,8 +16,9 @@ IS_MINE: int = 1 PATCH_SIZE: int = 3 REVEALED_EMPTY_SQUARE_REWARD: float = 1.0 -REVEALED_MINE_OR_INVALID_ACTION_REWARD: float = 0.0 -COLOUR_MAPPING: list = [ +REVEALED_MINE_REWARD: float = 0.0 +INVALID_ACTION_REWARD: float = 0.0 +DEFAULT_COLOR_MAPPING: list = [ "orange", "blue", "green", diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index b9ec6a0d5..19f8f200e 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -12,30 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Sequence, Tuple +from typing import Optional, Sequence, Tuple import chex import jax import jax.numpy as jnp import matplotlib.animation -import matplotlib.pyplot as plt -import jumanji.environments from jumanji import specs from jumanji.env import Environment from jumanji.environments.logic.minesweeper.constants import ( - COLOUR_MAPPING, + INVALID_ACTION_REWARD, PATCH_SIZE, + REVEALED_EMPTY_SQUARE_REWARD, + REVEALED_MINE_REWARD, UNEXPLORED_ID, ) from jumanji.environments.logic.minesweeper.done import DefaultDoneFn, DoneFn +from jumanji.environments.logic.minesweeper.env_viewer import ( + DefaultMinesweeperViewer, + MinesweeperViewer, +) +from jumanji.environments.logic.minesweeper.generator import ( + Generator, + UniformSamplingGenerator, +) from jumanji.environments.logic.minesweeper.reward import DefaultRewardFn, RewardFn from jumanji.environments.logic.minesweeper.types import Observation, State -from jumanji.environments.logic.minesweeper.utils import ( - count_adjacent_mines, - create_flat_mine_locations, - explored_mine, -) +from jumanji.environments.logic.minesweeper.utils import count_adjacent_mines from jumanji.types import TimeStep, restart, termination, transition @@ -92,46 +96,44 @@ class Minesweeper(Environment[State]): def __init__( self, - num_rows: int = 10, - num_cols: int = 10, - num_mines: int = 10, reward_function: Optional[RewardFn] = None, done_function: Optional[DoneFn] = None, - color_mapping: Optional[List[str]] = None, + env_viewer: Optional[MinesweeperViewer] = None, + generator: Optional[Generator] = None, ): """Instantiate a `Minesweeper` environment. Args: - num_rows: number of rows, i.e. height of the board. Defaults to 10. - num_cols: number of columns, i.e. width of the board. Defaults to 10. - num_mines: number of mines on the board. Defaults to 10. + reward_function: `RewardFn` whose `__call__` method computes the reward of an environment transition based on the given current state and selected action. Implemented options are [`DefaultRewardFn`]. Defaults to `DefaultRewardFn`. done_function: `DoneFn` whose `__call__` method computes the done signal given the current state, action taken, and next state. Implemented options are [`DefaultDoneFn`]. Defaults to `DefaultDoneFn`. - color_mapping: colour map used for rendering. + env_viewer: MinesweeperViewer to support rendering and animation methods. + Implemented options are [`DefaultMinesweeperViewer`]. + Defaults to `DefaultMinesweeperViewer`. + generator: Generator to generate problem instances on environment reset. + Implemented options are [`SamplingGenerator`]. + Defaults to `SamplingGenerator`. + The generator will have attributes: + - num_rows: number of rows, i.e. height of the board. Defaults to 10. + - num_cols: number of columns, i.e. width of the board. Defaults to 10. + - num_mines: number of mines generated. Defaults to 10. """ - if num_rows <= 1 or num_cols <= 1: - raise ValueError( - f"Should make a board of height and width greater than 1, " - f"got num_rows={num_rows}, num_cols={num_cols}" - ) - if num_mines < 0 or num_mines >= num_rows * num_cols: - raise ValueError( - f"Number of mines should be constrained between 0 and the size of the board, " - f"got {num_mines}" - ) - self.num_rows = num_rows - self.num_cols = num_cols - self.num_mines = num_mines - self.reward_function = reward_function or DefaultRewardFn() + self.reward_function = reward_function or DefaultRewardFn( + revealed_empty_square_reward=REVEALED_EMPTY_SQUARE_REWARD, + revealed_mine_reward=REVEALED_MINE_REWARD, + invalid_action_reward=INVALID_ACTION_REWARD, + ) self.done_function = done_function or DefaultDoneFn() - - self.cmap = color_mapping if color_mapping else COLOUR_MAPPING - self.figure_name = f"{num_rows}x{num_cols} Minesweeper" - self.figure_size = (6.0, 6.0) + self._generator = generator or UniformSamplingGenerator( + num_rows=10, num_cols=10, num_mines=10 + ) + self._env_viewer = env_viewer or DefaultMinesweeperViewer( + num_rows=self._generator.num_rows, num_cols=self._generator.num_cols + ) def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: """Resets the environment. @@ -144,25 +146,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep: `TimeStep` corresponding to the first timestep returned by the environment. """ - key, sample_key = jax.random.split(key) - board = jnp.full( - shape=(self.num_rows, self.num_cols), - fill_value=UNEXPLORED_ID, - dtype=jnp.int32, - ) - step_count = jnp.array(0, jnp.int32) - flat_mine_locations = create_flat_mine_locations( - key=sample_key, - num_rows=self.num_rows, - num_cols=self.num_cols, - num_mines=self.num_mines, - ) - state = State( - board=board, - step_count=step_count, - key=key, - flat_mine_locations=flat_mine_locations, - ) + state = self._generator(key) observation = self._state_to_observation(state=state) timestep = restart(observation=observation) return state, timestep @@ -215,14 +199,14 @@ def observation_spec(self) -> specs.Spec[Observation]: - step_count: BoundedArray (int32) of shape (). """ board = specs.BoundedArray( - shape=(self.num_rows, self.num_cols), + shape=(self._generator.num_rows, self._generator.num_cols), dtype=jnp.int32, minimum=-1, maximum=PATCH_SIZE * PATCH_SIZE - 1, name="board", ) action_mask = specs.BoundedArray( - shape=(self.num_rows, self.num_cols), + shape=(self._generator.num_rows, self._generator.num_cols), dtype=bool, minimum=False, maximum=True, @@ -232,14 +216,15 @@ def observation_spec(self) -> specs.Spec[Observation]: shape=(), dtype=jnp.int32, minimum=0, - maximum=self.num_rows * self.num_cols - 1, + maximum=self._generator.num_rows * self._generator.num_cols - 1, name="num_mines", ) step_count = specs.BoundedArray( shape=(), dtype=jnp.int32, minimum=0, - maximum=self.num_rows * self.num_cols - self.num_mines, + maximum=self._generator.num_rows * self._generator.num_cols + - self._generator.num_mines, name="step_count", ) return specs.Spec( @@ -259,7 +244,9 @@ def action_spec(self) -> specs.MultiDiscreteArray: action_spec: `specs.MultiDiscreteArray` object. """ return specs.MultiDiscreteArray( - num_values=jnp.array([self.num_rows, self.num_cols], jnp.int32), + num_values=jnp.array( + [self._generator.num_rows, self._generator.num_cols], jnp.int32 + ), name="action", dtype=jnp.int32, ) @@ -268,21 +255,18 @@ def _state_to_observation(self, state: State) -> Observation: return Observation( board=state.board, action_mask=jnp.equal(state.board, UNEXPLORED_ID), - num_mines=jnp.array(self.num_mines, jnp.int32), + num_mines=jnp.array(self._generator.num_mines, jnp.int32), step_count=state.step_count, ) - def render(self, state: State) -> None: - """Render the given environment state using matplotlib. - + def render(self, state: State, save_path: Optional[str] = None) -> None: + """Renders the current state of the board. Args: - state: environment state to be rendered. - + state: the current state to be rendered. + save_path: the path where the image should be saved. If it is None, the plot + will not be saved. """ - self._clear_display() - fig, ax = self._get_fig_ax() - self._draw(ax, state) - self._update_display(fig) + self._env_viewer.render(state=state) def animate( self, @@ -290,116 +274,22 @@ def animate( interval: int = 200, save_path: Optional[str] = None, ) -> matplotlib.animation.FuncAnimation: - """Create an animation from a sequence of environment states. - - Args: - states: sequence of environment states corresponding to consecutive timesteps. - interval: delay between frames in milliseconds, default to 200. - save_path: the path where the animation file should be saved. If it is None, the plot - will not be saved. - + """Creates an animated gif of the board based on the sequence of states. + Args: + states: a list of `State` objects representing the sequence of states. + interval: the delay between frames in milliseconds, default to 200. + save_path: the path where the animation file should be saved. If it is None, the plot + will not be saved. Returns: - Animation object that can be saved as a GIF, MP4, or rendered with HTML. + animation.FuncAnimation: the animation object that was created. """ - fig, ax = self._get_fig_ax() - plt.tight_layout() - plt.close(fig) - - def make_frame(state_index: int) -> None: - state = states[state_index] - self._draw(ax, state) - - # Create the animation object. - self._animation = matplotlib.animation.FuncAnimation( - fig, - make_frame, - frames=len(states), - interval=interval, + return self._env_viewer.animate( + states=states, interval=interval, save_path=save_path ) - # Save the animation as a GIF. - if save_path: - self._animation.save(save_path) - - return self._animation - def close(self) -> None: """Perform any necessary cleanup. - Environments will automatically :meth:`close()` themselves when garbage collected or when the program exits. """ - plt.close(self.figure_name) - - def _get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]: - exists = plt.fignum_exists(self.figure_name) - if exists: - fig = plt.figure(self.figure_name) - ax = fig.get_axes()[0] - else: - fig = plt.figure(self.figure_name, figsize=self.figure_size) - plt.suptitle(self.figure_name) - plt.tight_layout() - if not plt.isinteractive(): - fig.show() - ax = fig.add_subplot() - return fig, ax - - def _draw(self, ax: plt.Axes, state: State) -> None: - ax.clear() - ax.set_xticks(jnp.arange(-0.5, self.num_cols - 1, 1)) - ax.set_yticks(jnp.arange(-0.5, self.num_rows - 1, 1)) - ax.tick_params( - top=False, - bottom=False, - left=False, - right=False, - labelleft=False, - labelbottom=False, - labeltop=False, - labelright=False, - ) - background = jnp.ones_like(state.board) - for i in range(self.num_rows): - for j in range(self.num_cols): - background = self._render_grid_square( - state=state, ax=ax, i=i, j=j, background=background - ) - ax.imshow(background, cmap="gray", vmin=0, vmax=1) - ax.grid(color="black", linestyle="-", linewidth=2) - - def _render_grid_square( - self, state: State, ax: plt.Axes, i: int, j: int, background: chex.Array - ) -> chex.Array: - board_value = state.board[i, j] - if board_value != UNEXPLORED_ID: - if explored_mine(state=state, action=jnp.array([i, j], dtype=jnp.int32)): - background = background.at[i, j].set(0) - else: - ax.text( - j, - i, - str(board_value), - color=self.cmap[board_value], - ha="center", - va="center", - fontsize="xx-large", - ) - return background - - def _update_display(self, fig: plt.Figure) -> None: - if plt.isinteractive(): - # Required to update render when using Jupyter Notebook. - fig.canvas.draw() - if jumanji.environments.is_colab(): - plt.show(self.figure_name) - else: - # Required to update render when not using Jupyter Notebook. - fig.canvas.draw_idle() - fig.canvas.flush_events() - - def _clear_display(self) -> None: - if jumanji.environments.is_colab(): - import IPython.display - - IPython.display.clear_output(True) + self._env_viewer.close() diff --git a/jumanji/environments/logic/minesweeper/env_test.py b/jumanji/environments/logic/minesweeper/env_test.py index 129d81583..fb75d968f 100644 --- a/jumanji/environments/logic/minesweeper/env_test.py +++ b/jumanji/environments/logic/minesweeper/env_test.py @@ -24,8 +24,9 @@ from jax import random from jumanji.environments.logic.minesweeper.constants import ( + INVALID_ACTION_REWARD, REVEALED_EMPTY_SQUARE_REWARD, - REVEALED_MINE_OR_INVALID_ACTION_REWARD, + REVEALED_MINE_REWARD, ) from jumanji.environments.logic.minesweeper.env import Minesweeper from jumanji.environments.logic.minesweeper.types import State @@ -69,12 +70,12 @@ def play_and_get_episode_stats( ), ( [[0, 3], [0, 2]], - [REVEALED_EMPTY_SQUARE_REWARD, REVEALED_MINE_OR_INVALID_ACTION_REWARD], + [REVEALED_EMPTY_SQUARE_REWARD, REVEALED_MINE_REWARD], [StepType.MID, StepType.LAST], ), ( [[0, 3], [0, 3]], - [REVEALED_EMPTY_SQUARE_REWARD, REVEALED_MINE_OR_INVALID_ACTION_REWARD], + [REVEALED_EMPTY_SQUARE_REWARD, INVALID_ACTION_REWARD], [StepType.MID, StepType.LAST], ), ], @@ -107,11 +108,11 @@ def test_minesweeper_env_reset(minesweeper_env: Minesweeper) -> None: assert isinstance(timestep, TimeStep) assert isinstance(state, State) assert state.step_count == 0 - assert state.flat_mine_locations.shape == (minesweeper_env.num_mines,) - assert timestep.observation.num_mines == minesweeper_env.num_mines + assert state.flat_mine_locations.shape == (minesweeper_env._generator.num_mines,) + assert timestep.observation.num_mines == minesweeper_env._generator.num_mines assert state.board.shape == ( - minesweeper_env.num_rows, - minesweeper_env.num_cols, + minesweeper_env._generator.num_rows, + minesweeper_env._generator.num_cols, ) assert jnp.array_equal(state.board, timestep.observation.board) assert timestep.observation.step_count == 0 @@ -189,9 +190,9 @@ def test_minesweeper_env_solved(minesweeper_env: Minesweeper) -> None: step_fn = jit(minesweeper_env.step) collected_rewards = [] collected_step_types = [] - for i in range(minesweeper_env.num_rows): - for j in range(minesweeper_env.num_cols): - flat_location = i * minesweeper_env.num_cols + j + for i in range(minesweeper_env._generator.num_rows): + for j in range(minesweeper_env._generator.num_cols): + flat_location = i * minesweeper_env._generator.num_cols + j if flat_location in state.flat_mine_locations: continue action = jnp.array([i, j], dtype=jnp.int32) @@ -199,7 +200,8 @@ def test_minesweeper_env_solved(minesweeper_env: Minesweeper) -> None: collected_rewards.append(timestep.reward) collected_step_types.append(timestep.step_type) expected_episode_length = ( - minesweeper_env.num_rows * minesweeper_env.num_cols - minesweeper_env.num_mines + minesweeper_env._generator.num_rows * minesweeper_env._generator.num_cols + - minesweeper_env._generator.num_mines ) assert collected_rewards == [REVEALED_EMPTY_SQUARE_REWARD] * expected_episode_length assert collected_step_types == [StepType.MID] * (expected_episode_length - 1) + [ diff --git a/jumanji/environments/logic/minesweeper/generator.py b/jumanji/environments/logic/minesweeper/generator.py new file mode 100644 index 000000000..b975b2971 --- /dev/null +++ b/jumanji/environments/logic/minesweeper/generator.py @@ -0,0 +1,85 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc + +import chex +import jax +import jax.numpy as jnp + +from jumanji.environments.logic.minesweeper.constants import UNEXPLORED_ID +from jumanji.environments.logic.minesweeper.types import State +from jumanji.environments.logic.minesweeper.utils import create_flat_mine_locations + + +class Generator(abc.ABC): + """Base class for generators for the Minesweeper environment.""" + + def __init__(self, num_rows: int, num_cols: int, num_mines: int): + """Initialises a Minesweeper generator for resetting the environment. + Args: + num_rows: number of rows, i.e. height of the board. + num_cols: number of columns, i.e. width of the board. + num_mines: number of mines to place on the board. + """ + if num_rows <= 1 or num_cols <= 1: + raise ValueError( + f"Should make a board of height and width greater than 1, " + f"got num_rows={num_rows}, num_cols={num_cols}" + ) + if num_mines < 0 or num_mines >= num_rows * num_cols: + raise ValueError( + f"Number of mines should be constrained between 0 and the size of the board, " + f"got {num_mines}" + ) + self.num_rows = num_rows + self.num_cols = num_cols + self.num_mines = num_mines + + @abc.abstractmethod + def generate_flat_mine_locations(self, key: chex.PRNGKey) -> chex.Array: + """Generates positions (in flattened coordinates) of the mines in the board""" + + def __call__(self, key: chex.PRNGKey) -> State: + """Generates a `Minesweeper` state. + Returns: + A `Minesweeper` state. + """ + key, sample_key = jax.random.split(key) + board = jnp.full( + shape=(self.num_rows, self.num_cols), + fill_value=UNEXPLORED_ID, + dtype=jnp.int32, + ) + step_count = jnp.array(0, jnp.int32) + flat_mine_locations = self.generate_flat_mine_locations(key=sample_key) + state = State( + board=board, + step_count=step_count, + key=key, + flat_mine_locations=flat_mine_locations, + ) + return state + + +class UniformSamplingGenerator(Generator): + """Generates instances by sampling a given number of mines (without replacement).""" + + def generate_flat_mine_locations(self, key: chex.PRNGKey) -> chex.Array: + return create_flat_mine_locations( + key=key, + num_rows=self.num_rows, + num_cols=self.num_cols, + num_mines=self.num_mines, + ) diff --git a/jumanji/environments/logic/minesweeper/reward.py b/jumanji/environments/logic/minesweeper/reward.py index 336d0325c..28ee5db89 100644 --- a/jumanji/environments/logic/minesweeper/reward.py +++ b/jumanji/environments/logic/minesweeper/reward.py @@ -17,10 +17,6 @@ import chex import jax.numpy as jnp -from jumanji.environments.logic.minesweeper.constants import ( - REVEALED_EMPTY_SQUARE_REWARD, - REVEALED_MINE_OR_INVALID_ACTION_REWARD, -) from jumanji.environments.logic.minesweeper.types import State from jumanji.environments.logic.minesweeper.utils import explored_mine, is_valid_action @@ -32,17 +28,29 @@ def __call__(self, state: State, action: chex.Array) -> chex.Array: class DefaultRewardFn(RewardFn): - """A dense reward function: 1 for every timestep on which a mine is not explored - (or a small penalty if action is invalid), otherwise 0. + """A dense reward function corresponding to the 3 possible events: + - Revealing an empty square + - Revealing a mine + - Choosing an invalid action (an already revealed square) """ + def __init__( + self, + revealed_empty_square_reward: float, + revealed_mine_reward: float, + invalid_action_reward: float, + ): + self.revealed_empty_square_reward = revealed_empty_square_reward + self.revelead_mine_reward = revealed_mine_reward + self.invalid_action_reward = invalid_action_reward + def __call__(self, state: State, action: chex.Array) -> chex.Array: return jnp.where( is_valid_action(state=state, action=action), jnp.where( explored_mine(state=state, action=action), - jnp.array(REVEALED_MINE_OR_INVALID_ACTION_REWARD, float), - jnp.array(REVEALED_EMPTY_SQUARE_REWARD, float), + jnp.array(self.revelead_mine_reward, float), + jnp.array(self.revealed_empty_square_reward, float), ), - jnp.array(REVEALED_MINE_OR_INVALID_ACTION_REWARD, float), + jnp.array(self.invalid_action_reward, float), ) diff --git a/jumanji/training/networks/minesweeper/actor_critic.py b/jumanji/training/networks/minesweeper/actor_critic.py index a14e04db6..95fcbe6f2 100644 --- a/jumanji/training/networks/minesweeper/actor_critic.py +++ b/jumanji/training/networks/minesweeper/actor_critic.py @@ -40,8 +40,8 @@ def make_actor_critic_networks_minesweeper( final_layer_dims: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `Minesweeper` environment.""" - board_height = minesweeper.num_rows - board_width = minesweeper.num_cols + board_height = minesweeper._generator.num_rows + board_width = minesweeper._generator.num_cols vocab_size = 1 + PATCH_SIZE**2 # unexplored, or 0, 1, ..., 8 parametric_action_distribution = FactorisedActionSpaceParametricDistribution( From b8f4c8acf92b642456d3a598217a4b650c858a5a Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Tue, 28 Mar 2023 12:47:54 +0100 Subject: [PATCH 27/58] To sync --- .../logic/minesweeper/env_viewer.py | 32 +------------------ 1 file changed, 1 insertion(+), 31 deletions(-) diff --git a/jumanji/environments/logic/minesweeper/env_viewer.py b/jumanji/environments/logic/minesweeper/env_viewer.py index e4aa2a331..887d0c423 100644 --- a/jumanji/environments/logic/minesweeper/env_viewer.py +++ b/jumanji/environments/logic/minesweeper/env_viewer.py @@ -30,41 +30,11 @@ class MinesweeperViewer(Viewer): """Abstract viewer class to support rendering and animation""" + def __init__(self, num_rows: int, num_cols: int): self.num_rows = num_rows self.num_cols = num_cols - def render(self, state: State) -> None: - """Render frames of the environment for a given state using matplotlib. - Args: - state: `State` object corresponding to the new state of the environment. - """ - raise NotImplementedError - - def animate( - self, - states: Sequence[State], - interval: int = 200, - save_path: Optional[str] = None, - ) -> matplotlib.animation.FuncAnimation: - """Create an animation from a sequence of environment states. - Args: - states: sequence of environment states corresponding to consecutive timesteps. - interval: delay between frames in milliseconds, default to 200. - save_path: the path where the animation file should be saved. If it is None, the plot - will not be saved. - Returns: - Animation that can be saved as a GIF, MP4, or rendered with HTML. - """ - raise NotImplementedError - - def close(self) -> None: - """Perform any necessary cleanup. - Environments will automatically :meth:`close()` themselves when - garbage collected or when the program exits. - """ - raise NotImplementedError - class DefaultMinesweeperViewer(MinesweeperViewer): def __init__( From 77e0b5949ab775eb455195c4637eb6501f397adf Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Tue, 28 Mar 2023 14:08:52 +0100 Subject: [PATCH 28/58] Imports --- .../logic/minesweeper/env_test.py | 25 +++++++++---------- .../environments/logic/minesweeper/utils.py | 9 +++---- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/jumanji/environments/logic/minesweeper/env_test.py b/jumanji/environments/logic/minesweeper/env_test.py index fb75d968f..e7cf0f61f 100644 --- a/jumanji/environments/logic/minesweeper/env_test.py +++ b/jumanji/environments/logic/minesweeper/env_test.py @@ -19,9 +19,8 @@ import matplotlib.pyplot as plt import pytest import pytest_mock -from jax import jit +import jax from jax import numpy as jnp -from jax import random from jumanji.environments.logic.minesweeper.constants import ( INVALID_ACTION_REWARD, @@ -41,11 +40,11 @@ def play_and_get_episode_stats( time_limit: int, force_start_state: Optional[State] = None, ) -> Tuple[List[float], List[StepType], int]: - state, timestep = jit(env.reset)(random.PRNGKey(0)) + state, timestep = jax.jit(env.reset)(jax.random.PRNGKey(0)) if force_start_state: state = force_start_state episode_length = 0 - step_fn = jit(env.step) + step_fn = jax.jit(env.step) collected_rewards = [] collected_step_types = [] while not timestep.last(): @@ -102,8 +101,8 @@ def test_default_reward_and_done_signals( def test_minesweeper_env_reset(minesweeper_env: Minesweeper) -> None: """Validates the jitted reset of the environment.""" - reset_fn = jit(minesweeper_env.reset) - key = random.PRNGKey(0) + reset_fn = jax.jit(minesweeper_env.reset) + key = jax.random.PRNGKey(0) state, timestep = reset_fn(key) assert isinstance(timestep, TimeStep) assert isinstance(state, State) @@ -125,9 +124,9 @@ def test_minesweeper_env_step(minesweeper_env: Minesweeper) -> None: """Validates the jitted step of the environment.""" chex.clear_trace_counter() step_fn = chex.assert_max_traces(minesweeper_env.step, n=2) - step_fn = jit(step_fn) - key = random.PRNGKey(0) - state, timestep = jit(minesweeper_env.reset)(key) + step_fn = jax.jit(step_fn) + key = jax.random.PRNGKey(0) + state, timestep = jax.jit(minesweeper_env.reset)(key) # For this board, this action will be a non-mined square action = minesweeper_env.action_spec().generate_value() next_state, next_timestep = step_fn(state, action) @@ -165,11 +164,11 @@ def test_minesweeper_env_render( ) -> None: """Check that the render method builds the figure but does not display it.""" monkeypatch.setattr(plt, "show", lambda fig: None) - state, timestep = jit(minesweeper_env.reset)(random.PRNGKey(0)) + state, timestep = jax.jit(minesweeper_env.reset)(jax.random.PRNGKey(0)) minesweeper_env.render(state) minesweeper_env.close() action = minesweeper_env.action_spec().generate_value() - state, timestep = jit(minesweeper_env.step)(state, action) + state, timestep = jax.jit(minesweeper_env.step)(state, action) minesweeper_env.render(state) minesweeper_env.close() @@ -186,8 +185,8 @@ def test_minesweeper_env_done_invalid_action(minesweeper_env: Minesweeper) -> No def test_minesweeper_env_solved(minesweeper_env: Minesweeper) -> None: """Solve the game and verify that things are as expected""" - state, timestep = jit(minesweeper_env.reset)(random.PRNGKey(0)) - step_fn = jit(minesweeper_env.step) + state, timestep = jax.jit(minesweeper_env.reset)(jax.random.PRNGKey(0)) + step_fn = jax.jit(minesweeper_env.step) collected_rewards = [] collected_step_types = [] for i in range(minesweeper_env._generator.num_rows): diff --git a/jumanji/environments/logic/minesweeper/utils.py b/jumanji/environments/logic/minesweeper/utils.py index 61bb213a1..4655ad92e 100644 --- a/jumanji/environments/logic/minesweeper/utils.py +++ b/jumanji/environments/logic/minesweeper/utils.py @@ -14,8 +14,7 @@ import chex import jax.numpy as jnp -from jax import random -from jax.lax import dynamic_slice_in_dim +import jax from jumanji.environments.logic.minesweeper.constants import ( IS_MINE, @@ -34,7 +33,7 @@ def create_flat_mine_locations( """Create locations of mines on a board with a specified height, width, and number of mines. The locations are in flattened coordinates. """ - return random.choice( + return jax.random.choice( key, num_rows * num_cols, shape=(num_mines,), @@ -81,11 +80,11 @@ def count_adjacent_mines(state: State, action: chex.Array) -> chex.Array: state.board.shape[-2], state.board.shape[-1] ) pad_board = jnp.pad(mined_board, pad_width=PATCH_SIZE - 1) - selected_rows = dynamic_slice_in_dim( + selected_rows = jax.lax.dynamic_slice_in_dim( pad_board, start_index=action_height + 1, slice_size=PATCH_SIZE, axis=-2 ) return ( - dynamic_slice_in_dim( + jax.lax.dynamic_slice_in_dim( selected_rows, start_index=action_width + 1, slice_size=PATCH_SIZE, From 7a202801a51d13bbcef31b9c5f5e87e3fee10f22 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Tue, 28 Mar 2023 17:47:59 +0100 Subject: [PATCH 29/58] Some review comments --- jumanji/__init__.py | 4 ++-- jumanji/environments/logic/rubiks_cube/env.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/jumanji/__init__.py b/jumanji/__init__.py index a2e3a1cfd..53bd1438d 100644 --- a/jumanji/__init__.py +++ b/jumanji/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. from jumanji.env import Environment -from jumanji.environments.logic.rubiks_cube.generator import ScramblingGenerator +from jumanji.environments.logic.rubiks_cube import generator as rubik_generator from jumanji.registration import make, register, registered_environments from jumanji.version import __version__ @@ -33,7 +33,7 @@ register(id="RubiksCube-v0", entry_point="jumanji.environments:RubiksCube") # RubiksCube - an easier version of the standard Rubik's Cube puzzle with faces of size 3x3 yet only # 7 scrambles at reset time, making it technically maximum 7 actions away from the solution. -partly_scrambled_rubiks_cube_generator = ScramblingGenerator( +partly_scrambled_rubiks_cube_generator = rubik_generator.ScramblingGenerator( cube_size=3, num_scrambles_on_reset=7 ) register( diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index 2349865cb..20e0cfb3c 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -241,7 +241,7 @@ def animate( states: a list of `State` objects representing the sequence of game states. interval: the delay between frames in milliseconds, default to 200. save_path: the path where the animation file should be saved. If it is None, the plot - will not be saved. + will not be saved. Returns: animation.FuncAnimation: the animation object that was created. From 2816e31a828ae579e5097657beb330c08a5ca922 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Tue, 28 Mar 2023 17:52:57 +0100 Subject: [PATCH 30/58] Return --- jumanji/environments/logic/rubiks_cube/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index 20e0cfb3c..649572a06 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -227,7 +227,7 @@ def render(self, state: State) -> None: Args: state: the current state to be rendered. """ - self._env_viewer.render(state=state) + return self._env_viewer.render(state=state) def animate( self, From 78b3da44a64b663de7bb307ea45f641e8c8b5993 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Tue, 28 Mar 2023 18:09:13 +0100 Subject: [PATCH 31/58] Return --- jumanji/environments/logic/minesweeper/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index 19f8f200e..404c4d662 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -266,7 +266,7 @@ def render(self, state: State, save_path: Optional[str] = None) -> None: save_path: the path where the image should be saved. If it is None, the plot will not be saved. """ - self._env_viewer.render(state=state) + return self._env_viewer.render(state=state) def animate( self, From 3f0861991b29c2951781a7e1b6d8107a8c7e5c2e Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Wed, 29 Mar 2023 17:12:06 +0100 Subject: [PATCH 32/58] Import --- jumanji/environments/logic/rubiks_cube/env_viewer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jumanji/environments/logic/rubiks_cube/env_viewer.py b/jumanji/environments/logic/rubiks_cube/env_viewer.py index faf9bcb6c..e95107b4b 100644 --- a/jumanji/environments/logic/rubiks_cube/env_viewer.py +++ b/jumanji/environments/logic/rubiks_cube/env_viewer.py @@ -21,6 +21,7 @@ import jumanji.environments from jumanji.environments.logic.rubiks_cube.constants import Face from jumanji.environments.logic.rubiks_cube.types import State +from jumanji.viewer import Viewer class RubiksCubeViewer(Viewer): From 616832e5a3c37bc898229fd286091705190cb274 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Wed, 29 Mar 2023 17:12:57 +0100 Subject: [PATCH 33/58] Import --- jumanji/environments/logic/minesweeper/env_test.py | 2 +- jumanji/environments/logic/minesweeper/env_viewer.py | 1 + jumanji/environments/logic/minesweeper/utils.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/jumanji/environments/logic/minesweeper/env_test.py b/jumanji/environments/logic/minesweeper/env_test.py index e7cf0f61f..5eef6bdfe 100644 --- a/jumanji/environments/logic/minesweeper/env_test.py +++ b/jumanji/environments/logic/minesweeper/env_test.py @@ -15,11 +15,11 @@ from typing import List, Optional, Tuple import chex +import jax import matplotlib.animation import matplotlib.pyplot as plt import pytest import pytest_mock -import jax from jax import numpy as jnp from jumanji.environments.logic.minesweeper.constants import ( diff --git a/jumanji/environments/logic/minesweeper/env_viewer.py b/jumanji/environments/logic/minesweeper/env_viewer.py index 887d0c423..ca63fbbef 100644 --- a/jumanji/environments/logic/minesweeper/env_viewer.py +++ b/jumanji/environments/logic/minesweeper/env_viewer.py @@ -26,6 +26,7 @@ ) from jumanji.environments.logic.minesweeper.types import State from jumanji.environments.logic.minesweeper.utils import explored_mine +from jumanji.viewer import Viewer class MinesweeperViewer(Viewer): diff --git a/jumanji/environments/logic/minesweeper/utils.py b/jumanji/environments/logic/minesweeper/utils.py index 4655ad92e..f24c55b32 100644 --- a/jumanji/environments/logic/minesweeper/utils.py +++ b/jumanji/environments/logic/minesweeper/utils.py @@ -13,8 +13,8 @@ # limitations under the License. import chex -import jax.numpy as jnp import jax +import jax.numpy as jnp from jumanji.environments.logic.minesweeper.constants import ( IS_MINE, From c4d8e7c3bed581b04c946d25c0f530187457ef2b Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Wed, 29 Mar 2023 17:22:42 +0100 Subject: [PATCH 34/58] Generic types --- jumanji/environments/logic/rubiks_cube/env.py | 10 ++++------ jumanji/environments/logic/rubiks_cube/env_viewer.py | 11 ++--------- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index 649572a06..b4cb14304 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -25,10 +25,7 @@ DEFAULT_STICKER_COLORS, Face, ) -from jumanji.environments.logic.rubiks_cube.env_viewer import ( - DefaultRubiksCubeViewer, - RubiksCubeViewer, -) +from jumanji.environments.logic.rubiks_cube.env_viewer import DefaultRubiksCubeViewer from jumanji.environments.logic.rubiks_cube.generator import ( Generator, ScramblingGenerator, @@ -41,6 +38,7 @@ rotate_cube, ) from jumanji.types import TimeStep, restart, termination, transition +from jumanji.viewer import Viewer class RubiksCube(Environment[State]): @@ -86,7 +84,7 @@ def __init__( self, time_limit: int = 200, reward_fn: Optional[RewardFn] = None, - env_viewer: Optional[RubiksCubeViewer] = None, + env_viewer: Optional[Viewer[State]] = None, generator: Optional[Generator] = None, ): """Instantiate a `RubiksCube` environment. @@ -95,7 +93,7 @@ def __init__( time_limit: the number of steps allowed before an episode terminates. Defaults to 200. reward_fn: `RewardFn` whose `__call__` method computes the reward given the new state. Implemented options are [`SparseRewardFn`]. Defaults to `SparseRewardFn`. - env_viewer: RubiksCubeViewer to support rendering and animation methods. + env_viewer: Viewer to support rendering and animation methods. Implemented options are [`DefaultRubiksCubeViewer`]. Defaults to `DefaultRubiksCubeViewer`. generator: `Generator` used to generate problem instances on environment reset. diff --git a/jumanji/environments/logic/rubiks_cube/env_viewer.py b/jumanji/environments/logic/rubiks_cube/env_viewer.py index e95107b4b..0040094d3 100644 --- a/jumanji/environments/logic/rubiks_cube/env_viewer.py +++ b/jumanji/environments/logic/rubiks_cube/env_viewer.py @@ -24,24 +24,17 @@ from jumanji.viewer import Viewer -class RubiksCubeViewer(Viewer): - """Abstract viewer class to support rendering and animation""" - - def __init__(self, cube_size: int): - self.cube_size = cube_size - - -class DefaultRubiksCubeViewer(RubiksCubeViewer): +class DefaultRubiksCubeViewer(Viewer[State]): def __init__(self, sticker_colors: Optional[list], cube_size: int): """ Args: sticker_colors: colors used in rendering the faces of the Rubik's cube. cube_size: size of cube to view. """ + self.cube_size = cube_size self.sticker_colors_cmap = matplotlib.colors.ListedColormap(sticker_colors) self.figure_name = f"{cube_size}x{cube_size}x{cube_size} Rubik's Cube" self.figure_size = (6.0, 6.0) - super().__init__(cube_size=cube_size) def render(self, state: State) -> None: """Render frames of the environment for a given state using matplotlib. From f3c693599141ba4b6f12019c41ced727e18a340e Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Wed, 29 Mar 2023 17:25:23 +0100 Subject: [PATCH 35/58] Generic types --- jumanji/environments/logic/minesweeper/env.py | 10 ++++------ .../environments/logic/minesweeper/env_viewer.py | 13 +++---------- 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index 404c4d662..0a97781c0 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -29,10 +29,7 @@ UNEXPLORED_ID, ) from jumanji.environments.logic.minesweeper.done import DefaultDoneFn, DoneFn -from jumanji.environments.logic.minesweeper.env_viewer import ( - DefaultMinesweeperViewer, - MinesweeperViewer, -) +from jumanji.environments.logic.minesweeper.env_viewer import DefaultMinesweeperViewer from jumanji.environments.logic.minesweeper.generator import ( Generator, UniformSamplingGenerator, @@ -41,6 +38,7 @@ from jumanji.environments.logic.minesweeper.types import Observation, State from jumanji.environments.logic.minesweeper.utils import count_adjacent_mines from jumanji.types import TimeStep, restart, termination, transition +from jumanji.viewer import Viewer class Minesweeper(Environment[State]): @@ -98,7 +96,7 @@ def __init__( self, reward_function: Optional[RewardFn] = None, done_function: Optional[DoneFn] = None, - env_viewer: Optional[MinesweeperViewer] = None, + env_viewer: Optional[Viewer[State]] = None, generator: Optional[Generator] = None, ): """Instantiate a `Minesweeper` environment. @@ -111,7 +109,7 @@ def __init__( done_function: `DoneFn` whose `__call__` method computes the done signal given the current state, action taken, and next state. Implemented options are [`DefaultDoneFn`]. Defaults to `DefaultDoneFn`. - env_viewer: MinesweeperViewer to support rendering and animation methods. + env_viewer: Viewer to support rendering and animation methods. Implemented options are [`DefaultMinesweeperViewer`]. Defaults to `DefaultMinesweeperViewer`. generator: Generator to generate problem instances on environment reset. diff --git a/jumanji/environments/logic/minesweeper/env_viewer.py b/jumanji/environments/logic/minesweeper/env_viewer.py index ca63fbbef..e3f2c44a3 100644 --- a/jumanji/environments/logic/minesweeper/env_viewer.py +++ b/jumanji/environments/logic/minesweeper/env_viewer.py @@ -29,15 +29,7 @@ from jumanji.viewer import Viewer -class MinesweeperViewer(Viewer): - """Abstract viewer class to support rendering and animation""" - - def __init__(self, num_rows: int, num_cols: int): - self.num_rows = num_rows - self.num_cols = num_cols - - -class DefaultMinesweeperViewer(MinesweeperViewer): +class DefaultMinesweeperViewer(Viewer[State]): def __init__( self, color_mapping: Optional[List[str]] = None, @@ -51,10 +43,11 @@ def __init__( num_rows: number of rows, i.e. height of the board. Defaults to 10. num_cols: number of columns, i.e. width of the board. Defaults to 10. """ + self.num_rows = num_rows + self.num_cols = num_cols self.cmap = color_mapping if color_mapping else DEFAULT_COLOR_MAPPING self.figure_name = f"{num_rows}x{num_cols} Minesweeper" self.figure_size = (6.0, 6.0) - super().__init__(num_rows=num_rows, num_cols=num_cols) def render(self, state: State) -> None: """Render the given environment state using matplotlib. From 74a6d765a0fe3468b968883765d1e0f07a05e57d Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Wed, 29 Mar 2023 17:33:56 +0100 Subject: [PATCH 36/58] Typing --- jumanji/environments/logic/rubiks_cube/env.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index b4cb14304..12ae2611e 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -18,6 +18,7 @@ import jax import jax.numpy as jnp import matplotlib.animation +from numpy.typing import NDArray from jumanji import specs from jumanji.env import Environment @@ -219,7 +220,7 @@ def action_spec(self) -> specs.MultiDiscreteArray: def _state_to_observation(self, state: State) -> Observation: return Observation(cube=state.cube, step_count=state.step_count) - def render(self, state: State) -> None: + def render(self, state: State) -> Optional[NDArray]: """Renders the current state of the cube. Args: From 2e6b5b37f2bb3ade01fa3bd8045c0d035c08fad9 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Wed, 29 Mar 2023 17:41:34 +0100 Subject: [PATCH 37/58] Typing --- jumanji/environments/logic/minesweeper/env.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index 0a97781c0..74045b3d1 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -18,6 +18,7 @@ import jax import jax.numpy as jnp import matplotlib.animation +from numpy.typing import NDArray from jumanji import specs from jumanji.env import Environment @@ -257,12 +258,10 @@ def _state_to_observation(self, state: State) -> Observation: step_count=state.step_count, ) - def render(self, state: State, save_path: Optional[str] = None) -> None: + def render(self, state: State) -> Optional[NDArray]: """Renders the current state of the board. Args: state: the current state to be rendered. - save_path: the path where the image should be saved. If it is None, the plot - will not be saved. """ return self._env_viewer.render(state=state) From 9650fa153b20e7872c49b05fe775b3c558216b23 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Wed, 29 Mar 2023 17:43:59 +0100 Subject: [PATCH 38/58] Empty --- jumanji/environments/logic/rubiks_cube/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index 12ae2611e..8f06edbe9 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -182,7 +182,7 @@ def observation_spec(self) -> specs.Spec[Observation]: - step_count: BoundedArray (jnp.int32) of shape (). """ cube = specs.BoundedArray( - shape=(len(Face), self._generator.cube_size, self._generator.cube_size), + shape=(len(Face), self._generator.cube_size, self._generator.cube_size), dtype=jnp.int8, minimum=0, maximum=len(Face) - 1, From 3a8493cc2e2e7d02de00ee489774018843e412ce Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Wed, 29 Mar 2023 17:44:05 +0100 Subject: [PATCH 39/58] Empty --- jumanji/environments/logic/rubiks_cube/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index 8f06edbe9..12ae2611e 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -182,7 +182,7 @@ def observation_spec(self) -> specs.Spec[Observation]: - step_count: BoundedArray (jnp.int32) of shape (). """ cube = specs.BoundedArray( - shape=(len(Face), self._generator.cube_size, self._generator.cube_size), + shape=(len(Face), self._generator.cube_size, self._generator.cube_size), dtype=jnp.int8, minimum=0, maximum=len(Face) - 1, From f8ced459d7af2c0194f60a1974edbb5672c8f437 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Wed, 29 Mar 2023 17:48:44 +0100 Subject: [PATCH 40/58] Clement suggestions --- jumanji/environments/logic/rubiks_cube/env.py | 32 +++++++++---------- .../logic/rubiks_cube/env_viewer.py | 2 +- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index 12ae2611e..0860519d0 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -26,7 +26,7 @@ DEFAULT_STICKER_COLORS, Face, ) -from jumanji.environments.logic.rubiks_cube.env_viewer import DefaultRubiksCubeViewer +from jumanji.environments.logic.rubiks_cube.env_viewer import RubiksCubeViewer from jumanji.environments.logic.rubiks_cube.generator import ( Generator, ScramblingGenerator, @@ -83,25 +83,25 @@ class RubiksCube(Environment[State]): def __init__( self, + generator: Optional[Generator] = None, time_limit: int = 200, reward_fn: Optional[RewardFn] = None, - env_viewer: Optional[Viewer[State]] = None, - generator: Optional[Generator] = None, + viewer: Optional[Viewer[State]] = None, ): """Instantiate a `RubiksCube` environment. Args: - time_limit: the number of steps allowed before an episode terminates. Defaults to 200. - reward_fn: `RewardFn` whose `__call__` method computes the reward given the new state. - Implemented options are [`SparseRewardFn`]. Defaults to `SparseRewardFn`. - env_viewer: Viewer to support rendering and animation methods. - Implemented options are [`DefaultRubiksCubeViewer`]. - Defaults to `DefaultRubiksCubeViewer`. generator: `Generator` used to generate problem instances on environment reset. Implemented options are [`ScramblingGenerator`]. Defaults to `ScramblingGenerator`. The generator will contain an attribute `cube_size`, corresponding to the number of cubies to an edge, and defaulting to 3. + time_limit: the number of steps allowed before an episode terminates. Defaults to 200. + reward_fn: `RewardFn` whose `__call__` method computes the reward given the new state. + Implemented options are [`SparseRewardFn`]. Defaults to `SparseRewardFn`. + viewer: Viewer to support rendering and animation methods. + Implemented options are [`RubiksCubeViewer`]. + Defaults to `RubiksCubeViewer`. """ if time_limit <= 0: raise ValueError( @@ -109,12 +109,12 @@ def __init__( ) self.time_limit = time_limit self.reward_function = reward_fn or SparseRewardFn() - self._generator = generator or ScramblingGenerator( + self.generator = generator or ScramblingGenerator( cube_size=3, num_scrambles_on_reset=100, ) - self._env_viewer = env_viewer or DefaultRubiksCubeViewer( - sticker_colors=DEFAULT_STICKER_COLORS, cube_size=self._generator.cube_size + self._env_viewer = viewer or RubiksCubeViewer( + sticker_colors=DEFAULT_STICKER_COLORS, cube_size=self.generator.cube_size ) def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: @@ -128,7 +128,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep: `TimeStep` corresponding to the first timestep returned by the environment. """ - state = self._generator(key) + state = self.generator(key) observation = self._state_to_observation(state=state) timestep = restart(observation=observation) return state, timestep @@ -148,7 +148,7 @@ def step( next_timestep: `TimeStep` corresponding to the timestep returned by the environment. """ flattened_action = flatten_action( - unflattened_action=action, cube_size=self._generator.cube_size + unflattened_action=action, cube_size=self.generator.cube_size ) cube = rotate_cube( cube=state.cube, @@ -182,7 +182,7 @@ def observation_spec(self) -> specs.Spec[Observation]: - step_count: BoundedArray (jnp.int32) of shape (). """ cube = specs.BoundedArray( - shape=(len(Face), self._generator.cube_size, self._generator.cube_size), + shape=(len(Face), self.generator.cube_size, self.generator.cube_size), dtype=jnp.int8, minimum=0, maximum=len(Face) - 1, @@ -211,7 +211,7 @@ def action_spec(self) -> specs.MultiDiscreteArray: """ return specs.MultiDiscreteArray( num_values=jnp.array( - [len(Face), self._generator.cube_size // 2, 3], jnp.int32 + [len(Face), self.generator.cube_size // 2, 3], jnp.int32 ), name="action", dtype=jnp.int32, diff --git a/jumanji/environments/logic/rubiks_cube/env_viewer.py b/jumanji/environments/logic/rubiks_cube/env_viewer.py index 0040094d3..77d493a78 100644 --- a/jumanji/environments/logic/rubiks_cube/env_viewer.py +++ b/jumanji/environments/logic/rubiks_cube/env_viewer.py @@ -24,7 +24,7 @@ from jumanji.viewer import Viewer -class DefaultRubiksCubeViewer(Viewer[State]): +class RubiksCubeViewer(Viewer[State]): def __init__(self, sticker_colors: Optional[list], cube_size: int): """ Args: From 27c1241cfd158dd751a16739d1ab8ed15342e7c2 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Wed, 29 Mar 2023 17:53:55 +0100 Subject: [PATCH 41/58] Tidy --- jumanji/environments/logic/minesweeper/env.py | 46 +++++++++---------- .../logic/minesweeper/env_test.py | 18 ++++---- .../logic/minesweeper/env_viewer.py | 2 +- .../networks/minesweeper/actor_critic.py | 4 +- 4 files changed, 35 insertions(+), 35 deletions(-) diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index 74045b3d1..3155dea1e 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -30,7 +30,7 @@ UNEXPLORED_ID, ) from jumanji.environments.logic.minesweeper.done import DefaultDoneFn, DoneFn -from jumanji.environments.logic.minesweeper.env_viewer import DefaultMinesweeperViewer +from jumanji.environments.logic.minesweeper.env_viewer import MinesweeperViewer from jumanji.environments.logic.minesweeper.generator import ( Generator, UniformSamplingGenerator, @@ -95,24 +95,15 @@ class Minesweeper(Environment[State]): def __init__( self, + generator: Optional[Generator] = None, reward_function: Optional[RewardFn] = None, done_function: Optional[DoneFn] = None, - env_viewer: Optional[Viewer[State]] = None, - generator: Optional[Generator] = None, + viewer: Optional[Viewer[State]] = None, ): """Instantiate a `Minesweeper` environment. Args: - reward_function: `RewardFn` whose `__call__` method computes the reward of an - environment transition based on the given current state and selected action. - Implemented options are [`DefaultRewardFn`]. Defaults to `DefaultRewardFn`. - done_function: `DoneFn` whose `__call__` method computes the done signal given the - current state, action taken, and next state. - Implemented options are [`DefaultDoneFn`]. Defaults to `DefaultDoneFn`. - env_viewer: Viewer to support rendering and animation methods. - Implemented options are [`DefaultMinesweeperViewer`]. - Defaults to `DefaultMinesweeperViewer`. generator: Generator to generate problem instances on environment reset. Implemented options are [`SamplingGenerator`]. Defaults to `SamplingGenerator`. @@ -120,6 +111,15 @@ def __init__( - num_rows: number of rows, i.e. height of the board. Defaults to 10. - num_cols: number of columns, i.e. width of the board. Defaults to 10. - num_mines: number of mines generated. Defaults to 10. + reward_function: `RewardFn` whose `__call__` method computes the reward of an + environment transition based on the given current state and selected action. + Implemented options are [`DefaultRewardFn`]. Defaults to `DefaultRewardFn`. + done_function: `DoneFn` whose `__call__` method computes the done signal given the + current state, action taken, and next state. + Implemented options are [`DefaultDoneFn`]. Defaults to `DefaultDoneFn`. + viewer: Viewer to support rendering and animation methods. + Implemented options are [`MinesweeperViewer`]. + Defaults to `MinesweeperViewer`. """ self.reward_function = reward_function or DefaultRewardFn( revealed_empty_square_reward=REVEALED_EMPTY_SQUARE_REWARD, @@ -127,11 +127,11 @@ def __init__( invalid_action_reward=INVALID_ACTION_REWARD, ) self.done_function = done_function or DefaultDoneFn() - self._generator = generator or UniformSamplingGenerator( + self.generator = generator or UniformSamplingGenerator( num_rows=10, num_cols=10, num_mines=10 ) - self._env_viewer = env_viewer or DefaultMinesweeperViewer( - num_rows=self._generator.num_rows, num_cols=self._generator.num_cols + self._env_viewer = viewer or MinesweeperViewer( + num_rows=self.generator.num_rows, num_cols=self.generator.num_cols ) def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: @@ -145,7 +145,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep: `TimeStep` corresponding to the first timestep returned by the environment. """ - state = self._generator(key) + state = self.generator(key) observation = self._state_to_observation(state=state) timestep = restart(observation=observation) return state, timestep @@ -198,14 +198,14 @@ def observation_spec(self) -> specs.Spec[Observation]: - step_count: BoundedArray (int32) of shape (). """ board = specs.BoundedArray( - shape=(self._generator.num_rows, self._generator.num_cols), + shape=(self.generator.num_rows, self.generator.num_cols), dtype=jnp.int32, minimum=-1, maximum=PATCH_SIZE * PATCH_SIZE - 1, name="board", ) action_mask = specs.BoundedArray( - shape=(self._generator.num_rows, self._generator.num_cols), + shape=(self.generator.num_rows, self.generator.num_cols), dtype=bool, minimum=False, maximum=True, @@ -215,15 +215,15 @@ def observation_spec(self) -> specs.Spec[Observation]: shape=(), dtype=jnp.int32, minimum=0, - maximum=self._generator.num_rows * self._generator.num_cols - 1, + maximum=self.generator.num_rows * self.generator.num_cols - 1, name="num_mines", ) step_count = specs.BoundedArray( shape=(), dtype=jnp.int32, minimum=0, - maximum=self._generator.num_rows * self._generator.num_cols - - self._generator.num_mines, + maximum=self.generator.num_rows * self.generator.num_cols + - self.generator.num_mines, name="step_count", ) return specs.Spec( @@ -244,7 +244,7 @@ def action_spec(self) -> specs.MultiDiscreteArray: """ return specs.MultiDiscreteArray( num_values=jnp.array( - [self._generator.num_rows, self._generator.num_cols], jnp.int32 + [self.generator.num_rows, self.generator.num_cols], jnp.int32 ), name="action", dtype=jnp.int32, @@ -254,7 +254,7 @@ def _state_to_observation(self, state: State) -> Observation: return Observation( board=state.board, action_mask=jnp.equal(state.board, UNEXPLORED_ID), - num_mines=jnp.array(self._generator.num_mines, jnp.int32), + num_mines=jnp.array(self.generator.num_mines, jnp.int32), step_count=state.step_count, ) diff --git a/jumanji/environments/logic/minesweeper/env_test.py b/jumanji/environments/logic/minesweeper/env_test.py index 5eef6bdfe..a05815f9b 100644 --- a/jumanji/environments/logic/minesweeper/env_test.py +++ b/jumanji/environments/logic/minesweeper/env_test.py @@ -107,11 +107,11 @@ def test_minesweeper_env_reset(minesweeper_env: Minesweeper) -> None: assert isinstance(timestep, TimeStep) assert isinstance(state, State) assert state.step_count == 0 - assert state.flat_mine_locations.shape == (minesweeper_env._generator.num_mines,) - assert timestep.observation.num_mines == minesweeper_env._generator.num_mines + assert state.flat_mine_locations.shape == (minesweeper_env.generator.num_mines,) + assert timestep.observation.num_mines == minesweeper_env.generator.num_mines assert state.board.shape == ( - minesweeper_env._generator.num_rows, - minesweeper_env._generator.num_cols, + minesweeper_env.generator.num_rows, + minesweeper_env.generator.num_cols, ) assert jnp.array_equal(state.board, timestep.observation.board) assert timestep.observation.step_count == 0 @@ -189,9 +189,9 @@ def test_minesweeper_env_solved(minesweeper_env: Minesweeper) -> None: step_fn = jax.jit(minesweeper_env.step) collected_rewards = [] collected_step_types = [] - for i in range(minesweeper_env._generator.num_rows): - for j in range(minesweeper_env._generator.num_cols): - flat_location = i * minesweeper_env._generator.num_cols + j + for i in range(minesweeper_env.generator.num_rows): + for j in range(minesweeper_env.generator.num_cols): + flat_location = i * minesweeper_env.generator.num_cols + j if flat_location in state.flat_mine_locations: continue action = jnp.array([i, j], dtype=jnp.int32) @@ -199,8 +199,8 @@ def test_minesweeper_env_solved(minesweeper_env: Minesweeper) -> None: collected_rewards.append(timestep.reward) collected_step_types.append(timestep.step_type) expected_episode_length = ( - minesweeper_env._generator.num_rows * minesweeper_env._generator.num_cols - - minesweeper_env._generator.num_mines + minesweeper_env.generator.num_rows * minesweeper_env.generator.num_cols + - minesweeper_env.generator.num_mines ) assert collected_rewards == [REVEALED_EMPTY_SQUARE_REWARD] * expected_episode_length assert collected_step_types == [StepType.MID] * (expected_episode_length - 1) + [ diff --git a/jumanji/environments/logic/minesweeper/env_viewer.py b/jumanji/environments/logic/minesweeper/env_viewer.py index e3f2c44a3..40e2e66f1 100644 --- a/jumanji/environments/logic/minesweeper/env_viewer.py +++ b/jumanji/environments/logic/minesweeper/env_viewer.py @@ -29,7 +29,7 @@ from jumanji.viewer import Viewer -class DefaultMinesweeperViewer(Viewer[State]): +class MinesweeperViewer(Viewer[State]): def __init__( self, color_mapping: Optional[List[str]] = None, diff --git a/jumanji/training/networks/minesweeper/actor_critic.py b/jumanji/training/networks/minesweeper/actor_critic.py index 95fcbe6f2..b68dd5bc7 100644 --- a/jumanji/training/networks/minesweeper/actor_critic.py +++ b/jumanji/training/networks/minesweeper/actor_critic.py @@ -40,8 +40,8 @@ def make_actor_critic_networks_minesweeper( final_layer_dims: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `Minesweeper` environment.""" - board_height = minesweeper._generator.num_rows - board_width = minesweeper._generator.num_cols + board_height = minesweeper.generator.num_rows + board_width = minesweeper.generator.num_cols vocab_size = 1 + PATCH_SIZE**2 # unexplored, or 0, 1, ..., 8 parametric_action_distribution = FactorisedActionSpaceParametricDistribution( From 3338299bd9e3ad0a4b96150bcd931bb03e0bb2af Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Wed, 29 Mar 2023 17:56:03 +0100 Subject: [PATCH 42/58] Rename --- jumanji/environments/logic/minesweeper/env.py | 2 +- .../environments/logic/minesweeper/{env_viewer.py => viewer.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename jumanji/environments/logic/minesweeper/{env_viewer.py => viewer.py} (100%) diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index 3155dea1e..b6772296d 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -30,7 +30,7 @@ UNEXPLORED_ID, ) from jumanji.environments.logic.minesweeper.done import DefaultDoneFn, DoneFn -from jumanji.environments.logic.minesweeper.env_viewer import MinesweeperViewer +from jumanji.environments.logic.minesweeper.viewer import MinesweeperViewer from jumanji.environments.logic.minesweeper.generator import ( Generator, UniformSamplingGenerator, diff --git a/jumanji/environments/logic/minesweeper/env_viewer.py b/jumanji/environments/logic/minesweeper/viewer.py similarity index 100% rename from jumanji/environments/logic/minesweeper/env_viewer.py rename to jumanji/environments/logic/minesweeper/viewer.py From 4b4f0c1ded80248151f016263ff03f095e2ec0f9 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Wed, 29 Mar 2023 17:57:39 +0100 Subject: [PATCH 43/58] Rename more --- jumanji/environments/logic/rubiks_cube/env.py | 10 +++++----- .../logic/rubiks_cube/{env_viewer.py => viewer.py} | 0 2 files changed, 5 insertions(+), 5 deletions(-) rename jumanji/environments/logic/rubiks_cube/{env_viewer.py => viewer.py} (100%) diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index 0860519d0..da39befd8 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -26,7 +26,6 @@ DEFAULT_STICKER_COLORS, Face, ) -from jumanji.environments.logic.rubiks_cube.env_viewer import RubiksCubeViewer from jumanji.environments.logic.rubiks_cube.generator import ( Generator, ScramblingGenerator, @@ -38,6 +37,7 @@ is_solved, rotate_cube, ) +from jumanji.environments.logic.rubiks_cube.viewer import RubiksCubeViewer from jumanji.types import TimeStep, restart, termination, transition from jumanji.viewer import Viewer @@ -113,7 +113,7 @@ def __init__( cube_size=3, num_scrambles_on_reset=100, ) - self._env_viewer = viewer or RubiksCubeViewer( + self._viewer = viewer or RubiksCubeViewer( sticker_colors=DEFAULT_STICKER_COLORS, cube_size=self.generator.cube_size ) @@ -226,7 +226,7 @@ def render(self, state: State) -> Optional[NDArray]: Args: state: the current state to be rendered. """ - return self._env_viewer.render(state=state) + return self._viewer.render(state=state) def animate( self, @@ -245,7 +245,7 @@ def animate( Returns: animation.FuncAnimation: the animation object that was created. """ - return self._env_viewer.animate( + return self._viewer.animate( states=states, interval=interval, save_path=save_path ) @@ -255,4 +255,4 @@ def close(self) -> None: Environments will automatically :meth:`close()` themselves when garbage collected or when the program exits. """ - self._env_viewer.close() + self._viewer.close() diff --git a/jumanji/environments/logic/rubiks_cube/env_viewer.py b/jumanji/environments/logic/rubiks_cube/viewer.py similarity index 100% rename from jumanji/environments/logic/rubiks_cube/env_viewer.py rename to jumanji/environments/logic/rubiks_cube/viewer.py From 2177837303237edf5cdacea61fe7ff49ac32be24 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Wed, 29 Mar 2023 17:58:15 +0100 Subject: [PATCH 44/58] Rename more --- jumanji/environments/logic/minesweeper/env.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index b6772296d..7518872cd 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -130,7 +130,7 @@ def __init__( self.generator = generator or UniformSamplingGenerator( num_rows=10, num_cols=10, num_mines=10 ) - self._env_viewer = viewer or MinesweeperViewer( + self._viewer = viewer or MinesweeperViewer( num_rows=self.generator.num_rows, num_cols=self.generator.num_cols ) @@ -263,7 +263,7 @@ def render(self, state: State) -> Optional[NDArray]: Args: state: the current state to be rendered. """ - return self._env_viewer.render(state=state) + return self._viewer.render(state=state) def animate( self, @@ -280,7 +280,7 @@ def animate( Returns: animation.FuncAnimation: the animation object that was created. """ - return self._env_viewer.animate( + return self._viewer.animate( states=states, interval=interval, save_path=save_path ) @@ -289,4 +289,4 @@ def close(self) -> None: Environments will automatically :meth:`close()` themselves when garbage collected or when the program exits. """ - self._env_viewer.close() + self._viewer.close() From 60bbc340071ef812782cd5fd467178ac91eeda80 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Wed, 29 Mar 2023 17:59:02 +0100 Subject: [PATCH 45/58] Lint --- jumanji/environments/logic/minesweeper/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index 7518872cd..13e57f0bf 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -30,7 +30,6 @@ UNEXPLORED_ID, ) from jumanji.environments.logic.minesweeper.done import DefaultDoneFn, DoneFn -from jumanji.environments.logic.minesweeper.viewer import MinesweeperViewer from jumanji.environments.logic.minesweeper.generator import ( Generator, UniformSamplingGenerator, @@ -38,6 +37,7 @@ from jumanji.environments.logic.minesweeper.reward import DefaultRewardFn, RewardFn from jumanji.environments.logic.minesweeper.types import Observation, State from jumanji.environments.logic.minesweeper.utils import count_adjacent_mines +from jumanji.environments.logic.minesweeper.viewer import MinesweeperViewer from jumanji.types import TimeStep, restart, termination, transition from jumanji.viewer import Viewer From cbda5ca5ca51ef508f63610c57ae09733703d0cf Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Wed, 29 Mar 2023 18:25:43 +0100 Subject: [PATCH 46/58] test commit From d3636e65d1565399e734b317d75d2ea9229ba0ff Mon Sep 17 00:00:00 2001 From: TristanKalloniatis Date: Thu, 30 Mar 2023 10:09:27 +0100 Subject: [PATCH 47/58] test commit From 33936b726394fe4546b6abf279f8b58d0fbd84dc Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Thu, 30 Mar 2023 10:12:16 +0100 Subject: [PATCH 48/58] test commit From 1f1edd8aae32f6b27bd87a55dbc97589f193c71e Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Thu, 30 Mar 2023 12:17:36 +0100 Subject: [PATCH 49/58] Daniel changes --- jumanji/environments/logic/minesweeper/env.py | 34 +++++++++---------- .../logic/minesweeper/env_test.py | 17 +++++----- .../logic/minesweeper/generator.py | 6 ++-- .../environments/logic/minesweeper/viewer.py | 9 +++-- jumanji/environments/logic/rubiks_cube/env.py | 8 ++--- .../networks/minesweeper/actor_critic.py | 4 +-- 6 files changed, 38 insertions(+), 40 deletions(-) diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index 13e57f0bf..b35fa60f0 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -103,10 +103,8 @@ def __init__( """Instantiate a `Minesweeper` environment. Args: - - generator: Generator to generate problem instances on environment reset. - Implemented options are [`SamplingGenerator`]. - Defaults to `SamplingGenerator`. + generator: `Generator` to generate problem instances on environment reset. + Implemented options are [`SamplingGenerator`]. Defaults to `SamplingGenerator`. The generator will have attributes: - num_rows: number of rows, i.e. height of the board. Defaults to 10. - num_cols: number of columns, i.e. width of the board. Defaults to 10. @@ -117,9 +115,8 @@ def __init__( done_function: `DoneFn` whose `__call__` method computes the done signal given the current state, action taken, and next state. Implemented options are [`DefaultDoneFn`]. Defaults to `DefaultDoneFn`. - viewer: Viewer to support rendering and animation methods. - Implemented options are [`MinesweeperViewer`]. - Defaults to `MinesweeperViewer`. + viewer: `Viewer` to support rendering and animation methods. + Implemented options are [`MinesweeperViewer`]. Defaults to `MinesweeperViewer`. """ self.reward_function = reward_function or DefaultRewardFn( revealed_empty_square_reward=REVEALED_EMPTY_SQUARE_REWARD, @@ -130,8 +127,11 @@ def __init__( self.generator = generator or UniformSamplingGenerator( num_rows=10, num_cols=10, num_mines=10 ) + self.num_rows = self.generator.num_rows + self.num_cols = self.generator.num_cols + self.num_mines = self.generator.num_mines self._viewer = viewer or MinesweeperViewer( - num_rows=self.generator.num_rows, num_cols=self.generator.num_cols + num_rows=self.num_rows, num_cols=self.num_cols ) def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: @@ -198,14 +198,14 @@ def observation_spec(self) -> specs.Spec[Observation]: - step_count: BoundedArray (int32) of shape (). """ board = specs.BoundedArray( - shape=(self.generator.num_rows, self.generator.num_cols), + shape=(self.num_rows, self.num_cols), dtype=jnp.int32, minimum=-1, maximum=PATCH_SIZE * PATCH_SIZE - 1, name="board", ) action_mask = specs.BoundedArray( - shape=(self.generator.num_rows, self.generator.num_cols), + shape=(self.num_rows, self.num_cols), dtype=bool, minimum=False, maximum=True, @@ -215,15 +215,14 @@ def observation_spec(self) -> specs.Spec[Observation]: shape=(), dtype=jnp.int32, minimum=0, - maximum=self.generator.num_rows * self.generator.num_cols - 1, + maximum=self.num_rows * self.num_cols - 1, name="num_mines", ) step_count = specs.BoundedArray( shape=(), dtype=jnp.int32, minimum=0, - maximum=self.generator.num_rows * self.generator.num_cols - - self.generator.num_mines, + maximum=self.num_rows * self.num_cols - self.num_mines, name="step_count", ) return specs.Spec( @@ -243,9 +242,7 @@ def action_spec(self) -> specs.MultiDiscreteArray: action_spec: `specs.MultiDiscreteArray` object. """ return specs.MultiDiscreteArray( - num_values=jnp.array( - [self.generator.num_rows, self.generator.num_cols], jnp.int32 - ), + num_values=jnp.array([self.num_rows, self.num_cols], jnp.int32), name="action", dtype=jnp.int32, ) @@ -254,12 +251,13 @@ def _state_to_observation(self, state: State) -> Observation: return Observation( board=state.board, action_mask=jnp.equal(state.board, UNEXPLORED_ID), - num_mines=jnp.array(self.generator.num_mines, jnp.int32), + num_mines=jnp.array(self.num_mines, jnp.int32), step_count=state.step_count, ) def render(self, state: State) -> Optional[NDArray]: """Renders the current state of the board. + Args: state: the current state to be rendered. """ @@ -272,11 +270,13 @@ def animate( save_path: Optional[str] = None, ) -> matplotlib.animation.FuncAnimation: """Creates an animated gif of the board based on the sequence of states. + Args: states: a list of `State` objects representing the sequence of states. interval: the delay between frames in milliseconds, default to 200. save_path: the path where the animation file should be saved. If it is None, the plot will not be saved. + Returns: animation.FuncAnimation: the animation object that was created. """ diff --git a/jumanji/environments/logic/minesweeper/env_test.py b/jumanji/environments/logic/minesweeper/env_test.py index a05815f9b..74774dae8 100644 --- a/jumanji/environments/logic/minesweeper/env_test.py +++ b/jumanji/environments/logic/minesweeper/env_test.py @@ -107,11 +107,11 @@ def test_minesweeper_env_reset(minesweeper_env: Minesweeper) -> None: assert isinstance(timestep, TimeStep) assert isinstance(state, State) assert state.step_count == 0 - assert state.flat_mine_locations.shape == (minesweeper_env.generator.num_mines,) - assert timestep.observation.num_mines == minesweeper_env.generator.num_mines + assert state.flat_mine_locations.shape == (minesweeper_env.num_mines,) + assert timestep.observation.num_mines == minesweeper_env.num_mines assert state.board.shape == ( - minesweeper_env.generator.num_rows, - minesweeper_env.generator.num_cols, + minesweeper_env.num_rows, + minesweeper_env.num_cols, ) assert jnp.array_equal(state.board, timestep.observation.board) assert timestep.observation.step_count == 0 @@ -189,9 +189,9 @@ def test_minesweeper_env_solved(minesweeper_env: Minesweeper) -> None: step_fn = jax.jit(minesweeper_env.step) collected_rewards = [] collected_step_types = [] - for i in range(minesweeper_env.generator.num_rows): - for j in range(minesweeper_env.generator.num_cols): - flat_location = i * minesweeper_env.generator.num_cols + j + for i in range(minesweeper_env.num_rows): + for j in range(minesweeper_env.num_cols): + flat_location = i * minesweeper_env.num_cols + j if flat_location in state.flat_mine_locations: continue action = jnp.array([i, j], dtype=jnp.int32) @@ -199,8 +199,7 @@ def test_minesweeper_env_solved(minesweeper_env: Minesweeper) -> None: collected_rewards.append(timestep.reward) collected_step_types.append(timestep.step_type) expected_episode_length = ( - minesweeper_env.generator.num_rows * minesweeper_env.generator.num_cols - - minesweeper_env.generator.num_mines + minesweeper_env.num_rows * minesweeper_env.num_cols - minesweeper_env.num_mines ) assert collected_rewards == [REVEALED_EMPTY_SQUARE_REWARD] * expected_episode_length assert collected_step_types == [StepType.MID] * (expected_episode_length - 1) + [ diff --git a/jumanji/environments/logic/minesweeper/generator.py b/jumanji/environments/logic/minesweeper/generator.py index b975b2971..7e5743e30 100644 --- a/jumanji/environments/logic/minesweeper/generator.py +++ b/jumanji/environments/logic/minesweeper/generator.py @@ -24,10 +24,11 @@ class Generator(abc.ABC): - """Base class for generators for the Minesweeper environment.""" + """Base class for generators for the `Minesweeper` environment.""" def __init__(self, num_rows: int, num_cols: int, num_mines: int): """Initialises a Minesweeper generator for resetting the environment. + Args: num_rows: number of rows, i.e. height of the board. num_cols: number of columns, i.e. width of the board. @@ -49,10 +50,11 @@ def __init__(self, num_rows: int, num_cols: int, num_mines: int): @abc.abstractmethod def generate_flat_mine_locations(self, key: chex.PRNGKey) -> chex.Array: - """Generates positions (in flattened coordinates) of the mines in the board""" + """Generates positions (in flattened coordinates) of the mines in the board.""" def __call__(self, key: chex.PRNGKey) -> State: """Generates a `Minesweeper` state. + Returns: A `Minesweeper` state. """ diff --git a/jumanji/environments/logic/minesweeper/viewer.py b/jumanji/environments/logic/minesweeper/viewer.py index 40e2e66f1..fd90f528a 100644 --- a/jumanji/environments/logic/minesweeper/viewer.py +++ b/jumanji/environments/logic/minesweeper/viewer.py @@ -32,20 +32,20 @@ class MinesweeperViewer(Viewer[State]): def __init__( self, - color_mapping: Optional[List[str]] = None, num_rows: int = 10, num_cols: int = 10, + color_mapping: Optional[List[str]] = None, ): """ Args: - color_mapping: colors used in rendering the cells in Minesweeper. - Defaults to `DEFAULT_COLOR_MAPPING`. num_rows: number of rows, i.e. height of the board. Defaults to 10. num_cols: number of columns, i.e. width of the board. Defaults to 10. + color_mapping: colors used in rendering the cells in `Minesweeper`. + Defaults to `DEFAULT_COLOR_MAPPING`. """ self.num_rows = num_rows self.num_cols = num_cols - self.cmap = color_mapping if color_mapping else DEFAULT_COLOR_MAPPING + self.cmap = color_mapping or DEFAULT_COLOR_MAPPING self.figure_name = f"{num_rows}x{num_cols} Minesweeper" self.figure_size = (6.0, 6.0) @@ -54,7 +54,6 @@ def render(self, state: State) -> None: Args: state: environment state to be rendered. - """ self._clear_display() fig, ax = self._get_fig_ax() diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index da39befd8..dc25897b2 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -92,16 +92,14 @@ def __init__( Args: generator: `Generator` used to generate problem instances on environment reset. - Implemented options are [`ScramblingGenerator`]. - Defaults to `ScramblingGenerator`. + Implemented options are [`ScramblingGenerator`]. Defaults to `ScramblingGenerator`. The generator will contain an attribute `cube_size`, corresponding to the number of cubies to an edge, and defaulting to 3. time_limit: the number of steps allowed before an episode terminates. Defaults to 200. reward_fn: `RewardFn` whose `__call__` method computes the reward given the new state. Implemented options are [`SparseRewardFn`]. Defaults to `SparseRewardFn`. - viewer: Viewer to support rendering and animation methods. - Implemented options are [`RubiksCubeViewer`]. - Defaults to `RubiksCubeViewer`. + viewer: `Viewer` to support rendering and animation methods. + Implemented options are [`RubiksCubeViewer`]. Defaults to `RubiksCubeViewer`. """ if time_limit <= 0: raise ValueError( diff --git a/jumanji/training/networks/minesweeper/actor_critic.py b/jumanji/training/networks/minesweeper/actor_critic.py index b68dd5bc7..a14e04db6 100644 --- a/jumanji/training/networks/minesweeper/actor_critic.py +++ b/jumanji/training/networks/minesweeper/actor_critic.py @@ -40,8 +40,8 @@ def make_actor_critic_networks_minesweeper( final_layer_dims: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `Minesweeper` environment.""" - board_height = minesweeper.generator.num_rows - board_width = minesweeper.generator.num_cols + board_height = minesweeper.num_rows + board_width = minesweeper.num_cols vocab_size = 1 + PATCH_SIZE**2 # unexplored, or 0, 1, ..., 8 parametric_action_distribution = FactorisedActionSpaceParametricDistribution( From 2618d1baea94d446012d85481dc3978a32e3bbc2 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Thu, 30 Mar 2023 12:21:49 +0100 Subject: [PATCH 50/58] test commit From 6a0c504a3e1c7db3a0d1792bd5d1bcd338350f51 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Thu, 30 Mar 2023 13:12:30 +0100 Subject: [PATCH 51/58] Undo --- jumanji/environments/logic/minesweeper/env_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jumanji/environments/logic/minesweeper/env_test.py b/jumanji/environments/logic/minesweeper/env_test.py index 105a62b6f..35bf6f845 100644 --- a/jumanji/environments/logic/minesweeper/env_test.py +++ b/jumanji/environments/logic/minesweeper/env_test.py @@ -121,7 +121,7 @@ def test_minesweeper__reset(minesweeper_env: Minesweeper) -> None: def test_minesweeper__step(minesweeper_env: Minesweeper) -> None: - """Validates the jitted step of the environment.""" + """Validates the jitted step of the environment."" chex.clear_trace_counter() step_fn = chex.assert_max_traces(minesweeper_env.step, n=2) step_fn = jax.jit(step_fn) From fee60f50d78d2700f0a840619a55b9237eae0495 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Thu, 30 Mar 2023 13:12:42 +0100 Subject: [PATCH 52/58] Redo --- jumanji/environments/logic/minesweeper/env_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jumanji/environments/logic/minesweeper/env_test.py b/jumanji/environments/logic/minesweeper/env_test.py index 35bf6f845..105a62b6f 100644 --- a/jumanji/environments/logic/minesweeper/env_test.py +++ b/jumanji/environments/logic/minesweeper/env_test.py @@ -121,7 +121,7 @@ def test_minesweeper__reset(minesweeper_env: Minesweeper) -> None: def test_minesweeper__step(minesweeper_env: Minesweeper) -> None: - """Validates the jitted step of the environment."" + """Validates the jitted step of the environment.""" chex.clear_trace_counter() step_fn = chex.assert_max_traces(minesweeper_env.step, n=2) step_fn = jax.jit(step_fn) From 562b2e8d964d10caf0cbedb14d76a000d3e184c7 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Thu, 30 Mar 2023 15:32:37 +0100 Subject: [PATCH 53/58] Formatting --- jumanji/environments/logic/minesweeper/env.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index b35fa60f0..19fbae5cc 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -271,11 +271,11 @@ def animate( ) -> matplotlib.animation.FuncAnimation: """Creates an animated gif of the board based on the sequence of states. - Args: - states: a list of `State` objects representing the sequence of states. - interval: the delay between frames in milliseconds, default to 200. - save_path: the path where the animation file should be saved. If it is None, the plot - will not be saved. + Args: + states: a list of `State` objects representing the sequence of states. + interval: the delay between frames in milliseconds, default to 200. + save_path: the path where the animation file should be saved. If it is None, the plot + will not be saved. Returns: animation.FuncAnimation: the animation object that was created. From 226acca83508fe3a3030e0f733711042509dc818 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Thu, 30 Mar 2023 15:42:24 +0100 Subject: [PATCH 54/58] Naming --- jumanji/environments/logic/minesweeper/env.py | 6 +++--- .../environments/logic/minesweeper/utils.py | 18 ++++++++--------- .../logic/minesweeper/utils_test.py | 8 ++++---- .../networks/minesweeper/actor_critic.py | 20 +++++++++---------- 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index 19fbae5cc..e9db2413f 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -56,7 +56,7 @@ class Minesweeper(Environment[State]): specifies how many timesteps have elapsed since environment reset. - action: - multi discrete array containing the square to explore (height and width). + multi discrete array containing the square to explore (row and col). - reward: jax array (float32): Configurable function of state and action. By default: @@ -164,8 +164,8 @@ def step( next_timestep: `TimeStep` corresponding to the timestep returned by the environment. """ board = state.board - action_height, action_width = action - board = board.at[action_height, action_width].set( + action_row, action_col = action + board = board.at[action_row, action_col].set( count_adjacent_mines(state=state, action=action) ) step_count = state.step_count + 1 diff --git a/jumanji/environments/logic/minesweeper/utils.py b/jumanji/environments/logic/minesweeper/utils.py index f24c55b32..271944484 100644 --- a/jumanji/environments/logic/minesweeper/utils.py +++ b/jumanji/environments/logic/minesweeper/utils.py @@ -30,7 +30,7 @@ def create_flat_mine_locations( num_cols: int, num_mines: int, ) -> Board: - """Create locations of mines on a board with a specified height, width, and number + """Create locations of mines on a board with a specified row, column, and number of mines. The locations are in flattened coordinates. """ return jax.random.choice( @@ -52,8 +52,8 @@ def is_solved(state: State) -> chex.Array: def is_valid_action(state: State, action: chex.Array) -> chex.Array: """Check if an action is exploring a square that has not already been explored.""" - action_height, action_width = action - return state.board[action_height, action_width] == UNEXPLORED_ID + action_row, action_col = action + return state.board[action_row, action_col] == UNEXPLORED_ID def get_mined_board(state: State) -> chex.Array: @@ -67,28 +67,28 @@ def get_mined_board(state: State) -> chex.Array: def explored_mine(state: State, action: chex.Array) -> chex.Array: """Check if an action is exploring a square containing a mine.""" - height, width = action - index = width + height * state.board.shape[-1] + row, col = action + index = col + row * state.board.shape[-1] mined_board = get_mined_board(state=state) return mined_board[index] == IS_MINE def count_adjacent_mines(state: State, action: chex.Array) -> chex.Array: """Count the number of mines in a 3x3 patch surrounding the selected action.""" - action_height, action_width = action + action_row, action_col = action mined_board = get_mined_board(state=state).reshape( state.board.shape[-2], state.board.shape[-1] ) pad_board = jnp.pad(mined_board, pad_width=PATCH_SIZE - 1) selected_rows = jax.lax.dynamic_slice_in_dim( - pad_board, start_index=action_height + 1, slice_size=PATCH_SIZE, axis=-2 + pad_board, start_index=action_row + 1, slice_size=PATCH_SIZE, axis=-2 ) return ( jax.lax.dynamic_slice_in_dim( selected_rows, - start_index=action_width + 1, + start_index=action_col + 1, slice_size=PATCH_SIZE, axis=-1, ).sum() - - mined_board[action_height, action_width] + - mined_board[action_row, action_col] ) diff --git a/jumanji/environments/logic/minesweeper/utils_test.py b/jumanji/environments/logic/minesweeper/utils_test.py index f2d413fa9..5ccde7174 100644 --- a/jumanji/environments/logic/minesweeper/utils_test.py +++ b/jumanji/environments/logic/minesweeper/utils_test.py @@ -55,11 +55,11 @@ def test_explored_mine( expected_explored_mine_result: bool, ) -> None: """Test whether mines are being explored""" - action_height, action_width = action + action_row, action_col = action assert ( explored_mine( manual_start_state, - jnp.array([action_height, action_width], dtype=jnp.int32), + jnp.array([action_row, action_col], dtype=jnp.int32), ) == expected_explored_mine_result ) @@ -75,11 +75,11 @@ def test_count_adjacent_mines( expected_count_adjacent_mines_result: int, ) -> None: """Test whether the mine counting function is working as expected""" - action_height, action_width = action + action_row, action_col = action assert ( count_adjacent_mines( manual_start_state, - jnp.array([action_height, action_width], dtype=jnp.int32), + jnp.array([action_row, action_col], dtype=jnp.int32), ) == expected_count_adjacent_mines_result ) diff --git a/jumanji/training/networks/minesweeper/actor_critic.py b/jumanji/training/networks/minesweeper/actor_critic.py index a14e04db6..3789be5f3 100644 --- a/jumanji/training/networks/minesweeper/actor_critic.py +++ b/jumanji/training/networks/minesweeper/actor_critic.py @@ -40,8 +40,8 @@ def make_actor_critic_networks_minesweeper( final_layer_dims: Sequence[int], ) -> ActorCriticNetworks: """Make actor-critic networks for the `Minesweeper` environment.""" - board_height = minesweeper.num_rows - board_width = minesweeper.num_cols + board_num_rows = minesweeper.num_rows + board_num_cols = minesweeper.num_cols vocab_size = 1 + PATCH_SIZE**2 # unexplored, or 0, 1, ..., 8 parametric_action_distribution = FactorisedActionSpaceParametricDistribution( @@ -49,8 +49,8 @@ def make_actor_critic_networks_minesweeper( ) policy_network = make_network_cnn( vocab_size=vocab_size, - board_height=board_height, - board_width=board_width, + board_num_rows=board_num_rows, + board_num_cols=board_num_cols, board_embed_dim=board_embed_dim, board_conv_channels=board_conv_channels, board_kernel_shape=board_kernel_shape, @@ -60,8 +60,8 @@ def make_actor_critic_networks_minesweeper( ) value_network = make_network_cnn( vocab_size=vocab_size, - board_height=board_height, - board_width=board_width, + board_num_rows=board_num_rows, + board_num_cols=board_num_cols, board_embed_dim=board_embed_dim, board_conv_channels=board_conv_channels, board_kernel_shape=board_kernel_shape, @@ -78,8 +78,8 @@ def make_actor_critic_networks_minesweeper( def make_network_cnn( vocab_size: int, - board_height: int, - board_width: int, + board_num_rows: int, + board_num_cols: int, board_embed_dim: int, board_conv_channels: Sequence[int], board_kernel_shape: int, @@ -106,9 +106,9 @@ def network_fn(observation: Observation) -> chex.Array: x = board_embedder(observation.board + 1) num_mines_embedder = hk.Linear(num_mines_embed_dim) y = num_mines_embedder( - observation.num_mines[:, None] / (board_height * board_width) + observation.num_mines[:, None] / (board_num_rows * board_num_cols) )[:, None, None, :] - y = jnp.tile(y, [1, board_height, board_width, 1]) + y = jnp.tile(y, [1, board_num_rows, board_num_cols, 1]) output = jnp.concatenate([x, y], axis=-1) final_layers = hk.nets.MLP((*final_layer_dims, 1)) output = jnp.squeeze(final_layers(output), axis=-1) From e4fd5d4931f6761485c88a47e4383ef5e3dc219a Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Thu, 30 Mar 2023 15:47:29 +0100 Subject: [PATCH 55/58] Docstring --- jumanji/environments/logic/minesweeper/utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/jumanji/environments/logic/minesweeper/utils.py b/jumanji/environments/logic/minesweeper/utils.py index 271944484..fa266ce7f 100644 --- a/jumanji/environments/logic/minesweeper/utils.py +++ b/jumanji/environments/logic/minesweeper/utils.py @@ -29,9 +29,15 @@ def create_flat_mine_locations( num_rows: int, num_cols: int, num_mines: int, -) -> Board: - """Create locations of mines on a board with a specified row, column, and number +) -> chex.Array: + """Create locations of mines on a board with a specified height, width, and number of mines. The locations are in flattened coordinates. + + Args: + key: used for sampling mine positions. + num_rows: the height of the board. + num_cols: the width of the board. + num_mines: how many mines to place. """ return jax.random.choice( key, From 856b220c48cddae9660c3d28e004c1a7c3c58616 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Thu, 30 Mar 2023 15:53:23 +0100 Subject: [PATCH 56/58] Lint --- jumanji/environments/logic/minesweeper/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jumanji/environments/logic/minesweeper/utils.py b/jumanji/environments/logic/minesweeper/utils.py index fa266ce7f..07411f884 100644 --- a/jumanji/environments/logic/minesweeper/utils.py +++ b/jumanji/environments/logic/minesweeper/utils.py @@ -21,7 +21,7 @@ PATCH_SIZE, UNEXPLORED_ID, ) -from jumanji.environments.logic.minesweeper.types import Board, State +from jumanji.environments.logic.minesweeper.types import State def create_flat_mine_locations( From 51015da4a8864302fd7beea467df543bb25fc531 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Thu, 30 Mar 2023 17:54:01 +0100 Subject: [PATCH 57/58] Clement suggestions --- .../environments/logic/minesweeper/conftest.py | 2 +- .../logic/minesweeper/constants.py | 3 --- jumanji/environments/logic/minesweeper/env.py | 18 +++++------------- .../environments/logic/minesweeper/env_test.py | 13 ++++--------- .../environments/logic/minesweeper/utils.py | 9 +++------ .../logic/minesweeper/utils_test.py | 6 ++---- .../environments/logic/minesweeper/viewer.py | 4 ++-- 7 files changed, 17 insertions(+), 38 deletions(-) diff --git a/jumanji/environments/logic/minesweeper/conftest.py b/jumanji/environments/logic/minesweeper/conftest.py index b032dd70f..4f7cc321c 100644 --- a/jumanji/environments/logic/minesweeper/conftest.py +++ b/jumanji/environments/logic/minesweeper/conftest.py @@ -24,7 +24,7 @@ @pytest.fixture def minesweeper_env() -> Minesweeper: - """Fixture for a default minesweeper env""" + """Fixture for a default minesweeper environment with 10 rows and columns, and 10 mines.""" return Minesweeper( generator=UniformSamplingGenerator(num_rows=10, num_cols=10, num_mines=10) ) diff --git a/jumanji/environments/logic/minesweeper/constants.py b/jumanji/environments/logic/minesweeper/constants.py index e7f0a9244..28744ecdb 100644 --- a/jumanji/environments/logic/minesweeper/constants.py +++ b/jumanji/environments/logic/minesweeper/constants.py @@ -15,9 +15,6 @@ UNEXPLORED_ID: int = -1 IS_MINE: int = 1 PATCH_SIZE: int = 3 -REVEALED_EMPTY_SQUARE_REWARD: float = 1.0 -REVEALED_MINE_REWARD: float = 0.0 -INVALID_ACTION_REWARD: float = 0.0 DEFAULT_COLOR_MAPPING: list = [ "orange", "blue", diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index e9db2413f..3e9864a10 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -22,13 +22,7 @@ from jumanji import specs from jumanji.env import Environment -from jumanji.environments.logic.minesweeper.constants import ( - INVALID_ACTION_REWARD, - PATCH_SIZE, - REVEALED_EMPTY_SQUARE_REWARD, - REVEALED_MINE_REWARD, - UNEXPLORED_ID, -) +from jumanji.environments.logic.minesweeper.constants import PATCH_SIZE, UNEXPLORED_ID from jumanji.environments.logic.minesweeper.done import DefaultDoneFn, DoneFn from jumanji.environments.logic.minesweeper.generator import ( Generator, @@ -119,9 +113,9 @@ def __init__( Implemented options are [`MinesweeperViewer`]. Defaults to `MinesweeperViewer`. """ self.reward_function = reward_function or DefaultRewardFn( - revealed_empty_square_reward=REVEALED_EMPTY_SQUARE_REWARD, - revealed_mine_reward=REVEALED_MINE_REWARD, - invalid_action_reward=INVALID_ACTION_REWARD, + revealed_empty_square_reward=1.0, + revealed_mine_reward=0.0, + invalid_action_reward=0.0, ) self.done_function = done_function or DefaultDoneFn() self.generator = generator or UniformSamplingGenerator( @@ -163,9 +157,7 @@ def step( next_state: `State` corresponding to the next state of the environment, next_timestep: `TimeStep` corresponding to the timestep returned by the environment. """ - board = state.board - action_row, action_col = action - board = board.at[action_row, action_col].set( + board = state.board.at[tuple(action)].set( count_adjacent_mines(state=state, action=action) ) step_count = state.step_count + 1 diff --git a/jumanji/environments/logic/minesweeper/env_test.py b/jumanji/environments/logic/minesweeper/env_test.py index 105a62b6f..197675f0e 100644 --- a/jumanji/environments/logic/minesweeper/env_test.py +++ b/jumanji/environments/logic/minesweeper/env_test.py @@ -22,11 +22,6 @@ import pytest_mock from jax import numpy as jnp -from jumanji.environments.logic.minesweeper.constants import ( - INVALID_ACTION_REWARD, - REVEALED_EMPTY_SQUARE_REWARD, - REVEALED_MINE_REWARD, -) from jumanji.environments.logic.minesweeper.env import Minesweeper from jumanji.environments.logic.minesweeper.types import State from jumanji.testing.env_not_smoke import check_env_does_not_smoke @@ -64,17 +59,17 @@ def play_and_get_episode_stats( [ ( [[0, 3], [1, 1], [1, 3], [2, 3], [3, 0], [3, 1], [3, 2], [3, 3]], - [REVEALED_EMPTY_SQUARE_REWARD] * 8, + [1.0] * 8, [StepType.MID] * 7 + [StepType.LAST], ), ( [[0, 3], [0, 2]], - [REVEALED_EMPTY_SQUARE_REWARD, REVEALED_MINE_REWARD], + [1.0, 0.0], [StepType.MID, StepType.LAST], ), ( [[0, 3], [0, 3]], - [REVEALED_EMPTY_SQUARE_REWARD, INVALID_ACTION_REWARD], + [1.0, 0.0], [StepType.MID, StepType.LAST], ), ], @@ -201,7 +196,7 @@ def test_minesweeper__solved(minesweeper_env: Minesweeper) -> None: expected_episode_length = ( minesweeper_env.num_rows * minesweeper_env.num_cols - minesweeper_env.num_mines ) - assert collected_rewards == [REVEALED_EMPTY_SQUARE_REWARD] * expected_episode_length + assert collected_rewards == [1.0] * expected_episode_length assert collected_step_types == [StepType.MID] * (expected_episode_length - 1) + [ StepType.LAST ] diff --git a/jumanji/environments/logic/minesweeper/utils.py b/jumanji/environments/logic/minesweeper/utils.py index 07411f884..09af0fbd5 100644 --- a/jumanji/environments/logic/minesweeper/utils.py +++ b/jumanji/environments/logic/minesweeper/utils.py @@ -58,8 +58,7 @@ def is_solved(state: State) -> chex.Array: def is_valid_action(state: State, action: chex.Array) -> chex.Array: """Check if an action is exploring a square that has not already been explored.""" - action_row, action_col = action - return state.board[action_row, action_col] == UNEXPLORED_ID + return state.board[tuple(action)] == UNEXPLORED_ID def get_mined_board(state: State) -> chex.Array: @@ -82,9 +81,7 @@ def explored_mine(state: State, action: chex.Array) -> chex.Array: def count_adjacent_mines(state: State, action: chex.Array) -> chex.Array: """Count the number of mines in a 3x3 patch surrounding the selected action.""" action_row, action_col = action - mined_board = get_mined_board(state=state).reshape( - state.board.shape[-2], state.board.shape[-1] - ) + mined_board = get_mined_board(state=state).reshape(*state.board.shape) pad_board = jnp.pad(mined_board, pad_width=PATCH_SIZE - 1) selected_rows = jax.lax.dynamic_slice_in_dim( pad_board, start_index=action_row + 1, slice_size=PATCH_SIZE, axis=-2 @@ -96,5 +93,5 @@ def count_adjacent_mines(state: State, action: chex.Array) -> chex.Array: slice_size=PATCH_SIZE, axis=-1, ).sum() - - mined_board[action_row, action_col] + - mined_board[tuple(action)] ) diff --git a/jumanji/environments/logic/minesweeper/utils_test.py b/jumanji/environments/logic/minesweeper/utils_test.py index 5ccde7174..55b2d2b99 100644 --- a/jumanji/environments/logic/minesweeper/utils_test.py +++ b/jumanji/environments/logic/minesweeper/utils_test.py @@ -55,11 +55,10 @@ def test_explored_mine( expected_explored_mine_result: bool, ) -> None: """Test whether mines are being explored""" - action_row, action_col = action assert ( explored_mine( manual_start_state, - jnp.array([action_row, action_col], dtype=jnp.int32), + jnp.array(action, dtype=jnp.int32), ) == expected_explored_mine_result ) @@ -75,11 +74,10 @@ def test_count_adjacent_mines( expected_count_adjacent_mines_result: int, ) -> None: """Test whether the mine counting function is working as expected""" - action_row, action_col = action assert ( count_adjacent_mines( manual_start_state, - jnp.array([action_row, action_col], dtype=jnp.int32), + jnp.array(action, dtype=jnp.int32), ) == expected_count_adjacent_mines_result ) diff --git a/jumanji/environments/logic/minesweeper/viewer.py b/jumanji/environments/logic/minesweeper/viewer.py index fd90f528a..cc681a0fb 100644 --- a/jumanji/environments/logic/minesweeper/viewer.py +++ b/jumanji/environments/logic/minesweeper/viewer.py @@ -32,8 +32,8 @@ class MinesweeperViewer(Viewer[State]): def __init__( self, - num_rows: int = 10, - num_cols: int = 10, + num_rows: int, + num_cols: int, color_mapping: Optional[List[str]] = None, ): """ From 859ccf4d5d87f2e24d62d3508140248e59fca3b2 Mon Sep 17 00:00:00 2001 From: tristankalloniatis Date: Fri, 31 Mar 2023 11:57:00 +0100 Subject: [PATCH 58/58] Docstrings --- jumanji/environments/logic/minesweeper/env.py | 7 +++++-- jumanji/environments/logic/minesweeper/viewer.py | 4 ++-- jumanji/environments/logic/rubiks_cube/env.py | 6 ++++-- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/jumanji/environments/logic/minesweeper/env.py b/jumanji/environments/logic/minesweeper/env.py index 3e9864a10..a5d9e5f01 100644 --- a/jumanji/environments/logic/minesweeper/env.py +++ b/jumanji/environments/logic/minesweeper/env.py @@ -105,10 +105,13 @@ def __init__( - num_mines: number of mines generated. Defaults to 10. reward_function: `RewardFn` whose `__call__` method computes the reward of an environment transition based on the given current state and selected action. - Implemented options are [`DefaultRewardFn`]. Defaults to `DefaultRewardFn`. + Implemented options are [`DefaultRewardFn`]. Defaults to `DefaultRewardFn`, giving + a reward of 1.0 for revealing an empty square, 0.0 for revealing a mine, and + 0.0 for an invalid action (selecting an already revealed square). done_function: `DoneFn` whose `__call__` method computes the done signal given the current state, action taken, and next state. - Implemented options are [`DefaultDoneFn`]. Defaults to `DefaultDoneFn`. + Implemented options are [`DefaultDoneFn`]. Defaults to `DefaultDoneFn`, ending the + episode on solving the board, revealing a mine, or picking an invalid action. viewer: `Viewer` to support rendering and animation methods. Implemented options are [`MinesweeperViewer`]. Defaults to `MinesweeperViewer`. """ diff --git a/jumanji/environments/logic/minesweeper/viewer.py b/jumanji/environments/logic/minesweeper/viewer.py index cc681a0fb..de8a67fa0 100644 --- a/jumanji/environments/logic/minesweeper/viewer.py +++ b/jumanji/environments/logic/minesweeper/viewer.py @@ -38,8 +38,8 @@ def __init__( ): """ Args: - num_rows: number of rows, i.e. height of the board. Defaults to 10. - num_cols: number of columns, i.e. width of the board. Defaults to 10. + num_rows: number of rows, i.e. height of the board. + num_cols: number of columns, i.e. width of the board. color_mapping: colors used in rendering the cells in `Minesweeper`. Defaults to `DEFAULT_COLOR_MAPPING`. """ diff --git a/jumanji/environments/logic/rubiks_cube/env.py b/jumanji/environments/logic/rubiks_cube/env.py index dc25897b2..bd01f9809 100644 --- a/jumanji/environments/logic/rubiks_cube/env.py +++ b/jumanji/environments/logic/rubiks_cube/env.py @@ -92,12 +92,14 @@ def __init__( Args: generator: `Generator` used to generate problem instances on environment reset. - Implemented options are [`ScramblingGenerator`]. Defaults to `ScramblingGenerator`. + Implemented options are [`ScramblingGenerator`]. Defaults to `ScramblingGenerator`, + with 100 scrambles on reset. The generator will contain an attribute `cube_size`, corresponding to the number of cubies to an edge, and defaulting to 3. time_limit: the number of steps allowed before an episode terminates. Defaults to 200. reward_fn: `RewardFn` whose `__call__` method computes the reward given the new state. - Implemented options are [`SparseRewardFn`]. Defaults to `SparseRewardFn`. + Implemented options are [`SparseRewardFn`]. Defaults to `SparseRewardFn`, giving a + reward of 1.0 if the cube is solved or otherwise 0.0. viewer: `Viewer` to support rendering and animation methods. Implemented options are [`RubiksCubeViewer`]. Defaults to `RubiksCubeViewer`. """