Skip to content

Commit

Permalink
fix(2048): incorrect action mask (#144)
Browse files Browse the repository at this point in the history
Co-authored-by: Clément Bonnet <56230714+clement-bonnet@users.noreply.github.com>
  • Loading branch information
aar65537 and clement-bonnet authored May 26, 2023
1 parent e1a80fa commit b16cf5d
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ problems.

| Environment | Category | Registered Version(s) | Source | Description |
|------------------------------------------|----------|------------------------------------------------------|--------------------------------------------------------------------------------------------------|------------------------------------------------------------------------|
| 🔢 Game2048 | Logic | `Game2048-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/game_2048/) | [doc](https://instadeepai.github.io/jumanji/environments/game_2048/) |
| 🔢 Game2048 | Logic | `Game2048-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/game_2048/) | [doc](https://instadeepai.github.io/jumanji/environments/game_2048/) |
| 💣 Minesweeper | Logic | `Minesweeper-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/minesweeper/) | [doc](https://instadeepai.github.io/jumanji/environments/minesweeper/) |
| 🎲 RubiksCube | Logic | `RubiksCube-v0`<br/>`RubiksCube-partly-scrambled-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/rubiks_cube/) | [doc](https://instadeepai.github.io/jumanji/environments/rubiks_cube/) |
| 📦 BinPack (3D BinPacking Problem) | Packing | `BinPack-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/bin_pack/) | [doc](https://instadeepai.github.io/jumanji/environments/bin_pack/) |
Expand Down
2 changes: 1 addition & 1 deletion docs/environments/game_2048.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,4 @@ tile, the total reward from these actions is 1536 (i.e., 1024 + 512).
## Registered Versions 📖
- `Game2048-v0`, the default settings for 2048 with a board of size 4x4.
- `Game2048-v1`, the default settings for 2048 with a board of size 4x4.
2 changes: 1 addition & 1 deletion jumanji/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
###

# Game2048 - the game of 2048 with the default board size of 4x4.
register(id="Game2048-v0", entry_point="jumanji.environments:Game2048")
register(id="Game2048-v1", entry_point="jumanji.environments:Game2048")

# Minesweeper on a board of size 10x10 with 10 mines.
register(id="Minesweeper-v0", entry_point="jumanji.environments:Minesweeper")
Expand Down
18 changes: 9 additions & 9 deletions jumanji/environments/logic/game_2048/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,25 +187,22 @@ def step(
state.board,
)

# Generate action mask to keep in the state for the next step and
# to provide to the agent in the observation.
action_mask = self._get_action_mask(board=updated_board)

# Check if the episode terminates (i.e. there are no legal actions).
done = ~jnp.any(action_mask)

# Generate new key.
random_cell_key, new_state_key = jax.random.split(state.key)

# Update the state of the board by adding a new random cell.
updated_board = jax.lax.cond(
done,
lambda board, key: board,
state.action_mask[action],
self._add_random_cell,
lambda board, key: board,
updated_board,
random_cell_key,
)

# Generate action mask to keep in the state for the next step and
# to provide to the agent in the observation.
action_mask = self._get_action_mask(board=updated_board)

# Build the state.
state = State(
board=updated_board,
Expand All @@ -221,6 +218,9 @@ def step(
action_mask=action_mask,
)

# Check if the episode terminates (i.e. there are no legal actions).
done = ~jnp.any(action_mask)

# Return either a MID or a LAST timestep depending on done.
highest_tile = 2 ** jnp.max(updated_board)
extras = {"highest_tile": highest_tile}
Expand Down
37 changes: 36 additions & 1 deletion jumanji/environments/logic/game_2048/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_game_2048__step_jit(game_2048: Game2048) -> None:
"""Confirm that the step is only compiled once when jitted."""
key = jax.random.PRNGKey(0)
state, timestep = game_2048.reset(key)
action = jnp.array(0)
action = jnp.argmax(state.action_mask)

chex.clear_trace_counter()
step_fn = jax.jit(chex.assert_max_traces(game_2048.step, n=1))
Expand All @@ -75,12 +75,47 @@ def test_game_2048__step_jit(game_2048: Game2048) -> None:

# New step
state = new_state
action = jnp.argmax(state.action_mask)
new_state, next_timestep = step_fn(state, action)

# Check that the state has changed
assert not jnp.array_equal(new_state.board, state.board)


def test_game_2048__step_invalid(game_2048: Game2048) -> None:
    """Check that stepping with a masked-out action leaves the game untouched.

    A state is built by hand whose action mask marks action 0 as illegal.
    After stepping with that action, the board, the action mask and the score
    must equal their previous values, while the step count must have grown
    by exactly one.
    """
    initial_state = State(
        board=jnp.array(
            [
                [1, 1, 1, 1],
                [0, 0, 0, 0],
                [0, 0, 0, 0],
                [0, 0, 0, 0],
            ]
        ),
        step_count=jnp.array(0),
        # Entry 0 is False: the action used below is not allowed by the mask.
        action_mask=jnp.array([False, True, True, True]),
        score=jnp.array(0),
        key=jax.random.PRNGKey(0),
    )
    illegal_action = jnp.array(0)

    # Run the step through jit, as the other step tests in this file do.
    jitted_step = jax.jit(game_2048.step)
    stepped_state, _ = jitted_step(initial_state, illegal_action)

    # Nothing about the game itself may change...
    assert jnp.array_equal(initial_state.board, stepped_state.board)
    # ...but the step is still counted.
    assert jnp.array_equal(initial_state.step_count + 1, stepped_state.step_count)
    assert jnp.array_equal(initial_state.action_mask, stepped_state.action_mask)
    assert jnp.array_equal(initial_state.score, stepped_state.score)


def test_game_2048__step_action_mask(game_2048: Game2048) -> None:
    """Check the action mask carried by the state that `step` returns.

    Starting from a hand-built board with a single empty cell, action 3 is
    applied. The state returned by `step` is expected to hold a mask in which
    every one of the four actions is disabled.
    """
    start_board = jnp.array(
        [
            [0, 1, 2, 3],
            [3, 1, 2, 3],
            [1, 2, 3, 4],
            [4, 3, 2, 1],
        ]
    )
    start_state = State(
        board=start_board,
        step_count=jnp.array(0),
        action_mask=jnp.array([True, False, True, True]),
        score=jnp.array(0),
        key=jax.random.PRNGKey(0),
    )
    chosen_action = jnp.array(3)

    jitted_step = jax.jit(game_2048.step)
    resulting_state, _ = jitted_step(start_state, chosen_action)

    # The mask must describe the board *after* the step, not the one before.
    no_legal_actions = jnp.array([False, False, False, False])
    assert jnp.array_equal(resulting_state.action_mask, no_legal_actions)


def test_game_2048__generate_board(game_2048: Game2048) -> None:
"""Confirm that `generate_board` method creates an initial board that
follows the rules of the 2048 game."""
Expand Down

0 comments on commit b16cf5d

Please sign in to comment.