Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jigsaw): Implement the Jigsaw env #147

Closed
wants to merge 28 commits into from
Closed
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
960ead0
feat: initial jigsaw commit.
RuanJohn May 21, 2023
0e6c994
feat: added puzzle numbers to env viewer
RuanJohn May 22, 2023
19f21f1
feat: initial code for random agent network.
RuanJohn May 22, 2023
3988f4d
feat: remove board action mask.
RuanJohn May 22, 2023
14a53ed
feat: add jigsaw random agent.
RuanJohn May 22, 2023
62625c0
chore: change board_dim to num_rows and num_cols
RuanJohn May 22, 2023
6828cfa
feat: register environment and add random networks
RuanJohn May 23, 2023
b81dda2
feat: full action mask working.
RuanJohn May 25, 2023
f67315c
feat: cleaner action mask generation.
RuanJohn May 25, 2023
941fee1
feat: added jigsaw documentation
RuanJohn May 28, 2023
2839da0
chore: typo fix.
RuanJohn May 28, 2023
3474efa
feat: added class doctring to env.
RuanJohn May 28, 2023
6378477
feat: import jigsaw actor critic network.
RuanJohn May 28, 2023
478b504
wip: work on actor critic networks.
RuanJohn May 28, 2023
316e733
chore: better variable naming
RuanJohn May 28, 2023
4e35b86
chore: variable renaming in jigsaw networks.
RuanJohn May 29, 2023
b681f83
chore: variable renaming in jigsaw networks.
RuanJohn May 29, 2023
ee87ec3
feat: jigsaw networks implemented.
RuanJohn May 29, 2023
8416969
fix: fix action spec off by one.
RuanJohn May 29, 2023
5868835
feat: added jigsaw training config.
RuanJohn May 29, 2023
4a8caad
chore: minor fixes.
RuanJohn May 29, 2023
d0aa02c
chore: fix action space in docs.
RuanJohn May 29, 2023
e876908
chore: docs action mask fix.
RuanJohn May 29, 2023
57b6a81
chore: action mask fix in docs.
RuanJohn May 29, 2023
0267299
chore: action mask fix in docs.
RuanJohn May 29, 2023
6adfaf3
chore: indent docstrings.
RuanJohn May 29, 2023
6206819
Merge branch 'main' into 143-implement-jigsaw-env
RuanJohn May 29, 2023
b344add
fix: action mask indexing bugfix.
RuanJohn May 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ problems.
| 💣 Minesweeper | Logic | `Minesweeper-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/minesweeper/) | [doc](https://instadeepai.github.io/jumanji/environments/minesweeper/) |
| 🎲 RubiksCube | Logic | `RubiksCube-v0`<br/>`RubiksCube-partly-scrambled-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/rubiks_cube/) | [doc](https://instadeepai.github.io/jumanji/environments/rubiks_cube/) |
| 📦 BinPack (3D BinPacking Problem) | Packing | `BinPack-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/bin_pack/) | [doc](https://instadeepai.github.io/jumanji/environments/bin_pack/) |
| 🧩 Jigsaw (Jigsaw Puzzle Solving) | Packing | `Jigsaw-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/jigsaw/) | [doc](https://instadeepai.github.io/jumanji/environments/jigsaw/) |
| 🏭 JobShop (Job Shop Scheduling Problem) | Packing | `JobShop-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/job_shop/) | [doc](https://instadeepai.github.io/jumanji/environments/job_shop/) |
| 🎒 Knapsack | Packing | `Knapsack-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/knapsack/) | [doc](https://instadeepai.github.io/jumanji/environments/knapsack/) |
| 🧹 Cleaner | Routing | `Cleaner-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/cleaner/) | [doc](https://instadeepai.github.io/jumanji/environments/cleaner/) |
Expand Down
8 changes: 8 additions & 0 deletions docs/api/environments/jigsaw.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
::: jumanji.environments.packing.jigsaw.env.Jigsaw
selection:
members:
- __init__
- reset
- step
- observation_spec
- action_spec
Binary file added docs/env_anim/jigsaw.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/env_img/jigsaw.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
55 changes: 55 additions & 0 deletions docs/environments/jigsaw.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Jigsaw Environment

<p align="center">
<img src="../env_anim/jigsaw.gif" width="500"/>
</p>

We provide here a Jax JIT-able implementation of a simple _jigsaw_ puzzle. The goal of the agent is to place
all the jigsaw pieces in the correct locations on an empty 2D puzzle board. Each time an episode resets a
new puzzle and set of piece is created. Pieces are randomly shuffled and rotated.

## Observation
The observation given to the agent gives a view of the current state of the puzzle as well as
all pieces that can be placed.

- `current_board`: jax array (float32) of shape `(num_rows, num_cols)` with values in the range
`[1, num_pieces]` (corresponding to the number of each piece). This board will have zeros
RuanJohn marked this conversation as resolved.
Show resolved Hide resolved
where no pieces have been placed and numbers corresponding to each piece where that particular
pieces has been paced.
RuanJohn marked this conversation as resolved.
Show resolved Hide resolved

- `pieces`: jax array (float32) of shape `(num_pieces, 3, 3)` of all possible pieces in the
current puzzle. These pieces are shuffled and rotated. Pieces will always have shape `(3, 3)`.

- `action_mask`: jax array (bool) of shape `(num_pieces, 4, num_rows-3, num_cols-3)`, representing
RuanJohn marked this conversation as resolved.
Show resolved Hide resolved
which actions are possible given the current state of the board. The first index indicates the
number of pieces in a given puzzle. The second index indicates the number of times a piece may be rotated.
The third and fourth indices indicate the row and column coordinate of where a piece may be placed respectively.
These values will always be `num_rows-3` and `num_cols-3` respectively to make it impossible for an agent to
RuanJohn marked this conversation as resolved.
Show resolved Hide resolved
place a piece outside the current board.


## Action
The action space is a `MultiDiscreteArray`, specifically a tuple of an index between 0 and `num_pieces`,
an index between 0 and 4 (since there are 4 possible rotations), an index between 0 and `num_rows-3`
(the possible row coordinates for placing a piece) and an index between 0 and `num_cols-3`
RuanJohn marked this conversation as resolved.
Show resolved Hide resolved
(the possible column coordinates for placing a piece). An action thus consists of four pieces of
information:

- Piece to place,

- Number of rotations to make to a chosen piece ({0, 90, 180, 270} degrees),
RuanJohn marked this conversation as resolved.
Show resolved Hide resolved

- Row coordinate for placing the rotated piece,

- Column coordinate for placed the rotated piece.


## Reward
The reward function is configurable, but by default is a fully dense reward giving `+1` for
each cell of a placed piece that overlaps with its correct position on the solved board. The episode
terminates if either the puzzle is solved or `num_pieces` steps have been taken by an agent.


## Registered Versions 📖
- `Jigsaw-v0`, a jigsaw puzzle with 7 rows and 7 columns containing 3 row pieces and 3 column pieces
for a total of 9 pieces in the puzzle. This version has a dense reward.
16 changes: 16 additions & 0 deletions jumanji/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from jumanji.env import Environment
from jumanji.environments.logic.rubiks_cube import generator as rubik_generator
from jumanji.environments.packing.jigsaw import generator as jigsaw_generator
from jumanji.registration import make, register, registered_environments
from jumanji.version import __version__

Expand Down Expand Up @@ -51,6 +52,21 @@
# largest ones are given in the observation.
register(id="BinPack-v1", entry_point="jumanji.environments:BinPack")

# Jigsaw puzzle with 9 pieces, a 7x7 grid and a random puzzle generator.
# The puzzle must be completed in `num_pieces` steps.
register(id="Jigsaw-v0", entry_point="jumanji.environments:Jigsaw")

# Simplified jigsaw puzzle with a 5x5 grid, 4 pieces and a deterministic
# puzzle generator.
deterministic_jigsaw_generator_with_rotation = (
jigsaw_generator.ToyJigsawGeneratorWithRotation()
)
register(
id="Jigsaw-deterministic-rotation-v0",
entry_point="jumanji.environments:Jigsaw",
kwargs={"generator": deterministic_jigsaw_generator_with_rotation},
)

# Job-shop scheduling problem with 20 jobs, 10 machines, at most
# 8 operations per job, and a max operation duration of 6 timesteps.
register(id="JobShop-v0", entry_point="jumanji.environments:JobShop")
Expand Down
1 change: 1 addition & 0 deletions jumanji/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from jumanji.environments.logic.rubiks_cube import RubiksCube
from jumanji.environments.packing import bin_pack, job_shop, knapsack
from jumanji.environments.packing.bin_pack.env import BinPack
from jumanji.environments.packing.jigsaw.env import Jigsaw
from jumanji.environments.packing.job_shop.env import JobShop
from jumanji.environments.packing.knapsack.env import Knapsack
from jumanji.environments.routing import cleaner, connector, cvrp, maze, snake, tsp
Expand Down
16 changes: 16 additions & 0 deletions jumanji/environments/packing/jigsaw/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jumanji.environments.packing.jigsaw.env import Jigsaw
from jumanji.environments.packing.jigsaw.types import Observation, State
85 changes: 85 additions & 0 deletions jumanji/environments/packing/jigsaw/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import chex
import jax
import jax.numpy as jnp
import pytest


@pytest.fixture
def key() -> chex.PRNGKey:
"""A determinstic key."""
return jax.random.PRNGKey(0)


@pytest.fixture
def piece() -> chex.Array:
return jnp.array(
[
[0.0, 1.0, 1.0],
[0.0, 1.0, 1.0],
[0.0, 0.0, 1.0],
]
)


@pytest.fixture
def solved_board() -> chex.Array:
"""A mock solved puzzle board for testing."""

return jnp.array(
[
[1.0, 1.0, 1.0, 2.0, 2.0],
[1.0, 1.0, 2.0, 2.0, 2.0],
[3.0, 1.0, 4.0, 4.0, 2.0],
[3.0, 3.0, 4.0, 4.0, 4.0],
[3.0, 3.0, 3.0, 4.0, 4.0],
],
)


@pytest.fixture
def board_with_piece_one_placed() -> chex.Array:
"""A board with only piece one placed."""

return jnp.array(
[
[1.0, 1.0, 1.0, 0.0, 0.0],
[1.0, 1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
],
)


@pytest.fixture()
def piece_one_correctly_placed(board_with_piece_one_placed: chex.Array) -> chex.Array:
"""A 2D array of zeros where piece one has been placed correctly."""

return board_with_piece_one_placed


@pytest.fixture()
def piece_one_partially_placed(board_with_piece_one_placed: chex.Array) -> chex.Array:
"""A 2D array of zeros where piece one has been placed partially correctly.
That is to say that there is overlap between where the piece has been placed and
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add the correct tabs throughout the codebase, please :)

where it should be placed to solve the puzzle."""

# Shift all elements in the array one down and one to the right
partially_placed_piece = jnp.roll(board_with_piece_one_placed, shift=1, axis=0)
partially_placed_piece = jnp.roll(partially_placed_piece, shift=1, axis=1)

return partially_placed_piece
Loading