From b16cf5dcde88a73b0c8f56f93b97813105cb99ec Mon Sep 17 00:00:00 2001 From: aar65537 <115365716+aar65537@users.noreply.github.com> Date: Fri, 26 May 2023 12:47:09 -0500 Subject: [PATCH] fix(2048): incorrect action mask (#144) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Clément Bonnet <56230714+clement-bonnet@users.noreply.github.com> --- README.md | 2 +- docs/environments/game_2048.md | 2 +- jumanji/__init__.py | 2 +- jumanji/environments/logic/game_2048/env.py | 18 ++++----- .../environments/logic/game_2048/env_test.py | 37 ++++++++++++++++++- 5 files changed, 48 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index d7e1ff424..8ef413b7d 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ problems. | Environment | Category | Registered Version(s) | Source | Description | |------------------------------------------|----------|------------------------------------------------------|--------------------------------------------------------------------------------------------------|------------------------------------------------------------------------| -| 🔢 Game2048 | Logic | `Game2048-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/game_2048/) | [doc](https://instadeepai.github.io/jumanji/environments/game_2048/) | +| 🔢 Game2048 | Logic | `Game2048-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/game_2048/) | [doc](https://instadeepai.github.io/jumanji/environments/game_2048/) | | 💣 Minesweeper | Logic | `Minesweeper-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/minesweeper/) | [doc](https://instadeepai.github.io/jumanji/environments/minesweeper/) | | 🎲 RubiksCube | Logic | `RubiksCube-v0`
`RubiksCube-partly-scrambled-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/rubiks_cube/) | [doc](https://instadeepai.github.io/jumanji/environments/rubiks_cube/) | | 📦 BinPack (3D BinPacking Problem) | Packing | `BinPack-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/bin_pack/) | [doc](https://instadeepai.github.io/jumanji/environments/bin_pack/) | diff --git a/docs/environments/game_2048.md b/docs/environments/game_2048.md index 7dba35f4f..1b16c8868 100644 --- a/docs/environments/game_2048.md +++ b/docs/environments/game_2048.md @@ -60,4 +60,4 @@ tile, the total reward from these actions is 1536 (i.e., 1024 + 512). ## Registered Versions 📖 -- `Game2048-v0`, the default settings for 2048 with a board of size 4x4. +- `Game2048-v1`, the default settings for 2048 with a board of size 4x4. diff --git a/jumanji/__init__.py b/jumanji/__init__.py index 751b21a78..d717a8664 100644 --- a/jumanji/__init__.py +++ b/jumanji/__init__.py @@ -24,7 +24,7 @@ ### # Game2048 - the game of 2048 with the default board size of 4x4. -register(id="Game2048-v0", entry_point="jumanji.environments:Game2048") +register(id="Game2048-v1", entry_point="jumanji.environments:Game2048") # Minesweeper on a board of size 10x10 with 10 mines. register(id="Minesweeper-v0", entry_point="jumanji.environments:Minesweeper") diff --git a/jumanji/environments/logic/game_2048/env.py b/jumanji/environments/logic/game_2048/env.py index 185565cbb..323995bdd 100644 --- a/jumanji/environments/logic/game_2048/env.py +++ b/jumanji/environments/logic/game_2048/env.py @@ -187,25 +187,22 @@ def step( state.board, ) - # Generate action mask to keep in the state for the next step and - # to provide to the agent in the observation. - action_mask = self._get_action_mask(board=updated_board) - - # Check if the episode terminates (i.e. there are no legal actions). - done = ~jnp.any(action_mask) - # Generate new key. random_cell_key, new_state_key = jax.random.split(state.key) # Update the state of the board by adding a new random cell. updated_board = jax.lax.cond( - done, - lambda board, key: board, + state.action_mask[action], self._add_random_cell, + lambda board, key: board, updated_board, random_cell_key, ) + # Generate action mask to keep in the state for the next step and + # to provide to the agent in the observation. + action_mask = self._get_action_mask(board=updated_board) + # Build the state. state = State( board=updated_board, @@ -221,6 +218,9 @@ def step( action_mask=action_mask, ) + # Check if the episode terminates (i.e. there are no legal actions). + done = ~jnp.any(action_mask) + # Return either a MID or a LAST timestep depending on done. highest_tile = 2 ** jnp.max(updated_board) extras = {"highest_tile": highest_tile} diff --git a/jumanji/environments/logic/game_2048/env_test.py b/jumanji/environments/logic/game_2048/env_test.py index 59d1a1eab..7985e0278 100644 --- a/jumanji/environments/logic/game_2048/env_test.py +++ b/jumanji/environments/logic/game_2048/env_test.py @@ -61,7 +61,7 @@ def test_game_2048__step_jit(game_2048: Game2048) -> None: """Confirm that the step is only compiled once when jitted.""" key = jax.random.PRNGKey(0) state, timestep = game_2048.reset(key) - action = jnp.array(0) + action = jnp.argmax(state.action_mask) chex.clear_trace_counter() step_fn = jax.jit(chex.assert_max_traces(game_2048.step, n=1)) @@ -75,12 +75,47 @@ def test_game_2048__step_jit(game_2048: Game2048) -> None: # New step state = new_state + action = jnp.argmax(state.action_mask) new_state, next_timestep = step_fn(state, action) # Check that the state has changed assert not jnp.array_equal(new_state.board, state.board) +def test_game_2048__step_invalid(game_2048: Game2048) -> None: + """Confirm that performing step on an invalid action does nothing.""" + state = State( + board=jnp.array([[1, 1, 1, 1], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]), + step_count=jnp.array(0), + action_mask=jnp.array([False, True, True, True]), + score=jnp.array(0), + key=jax.random.PRNGKey(0), + ) + action = jnp.array(0) + step_fn = jax.jit(game_2048.step) + new_state, next_timestep = step_fn(state, action) + assert jnp.array_equal(state.board, new_state.board) + assert jnp.array_equal(state.step_count + 1, new_state.step_count) + assert jnp.array_equal(state.action_mask, new_state.action_mask) + assert jnp.array_equal(state.score, new_state.score) + + +def test_game_2048__step_action_mask(game_2048: Game2048) -> None: + """Verify that the action mask returned from `step` is correct.""" + state = State( + board=jnp.array([[0, 1, 2, 3], [3, 1, 2, 3], [1, 2, 3, 4], [4, 3, 2, 1]]), + step_count=jnp.array(0), + action_mask=jnp.array([True, False, True, True]), + score=jnp.array(0), + key=jax.random.PRNGKey(0), + ) + action = jnp.array(3) + step_fn = jax.jit(game_2048.step) + new_state, next_timestep = step_fn(state, action) + expected_action_mask = jnp.array([False, False, False, False]) + assert jnp.array_equal(new_state.action_mask, expected_action_mask) + + def test_game_2048__generate_board(game_2048: Game2048) -> None: """Confirm that `generate_board` method creates an initial board that follows the rules of the 2048 game."""