-
Notifications
You must be signed in to change notification settings - Fork 85
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(jigsaw): Implement the Jigsaw env #147
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just had a look at the docs, will check the rest tomorrow 😄
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @RuanJohn 🙌 This looks great 🙂 I will do a more detailed review tomorrow.
Co-authored-by: Sasha <reallysasha@gmail.com>
Co-authored-by: Sasha <reallysasha@gmail.com>
Co-authored-by: Sasha <reallysasha@gmail.com>
Co-authored-by: Sasha <reallysasha@gmail.com>
@pytest.fixture() | ||
def piece_one_partially_placed(board_with_piece_one_placed: chex.Array) -> chex.Array: | ||
"""A 2D array of zeros where piece one has been placed partially correctly. | ||
That is to say that there is overlap between where the piece has been placed and |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add the correct tabs throughout the codebase, please :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Types and reward ✔️
- env: snake # [bin_pack, cleaner, connector, cvrp, game_2048, job_shop, knapsack, maze, minesweeper, rubiks_cube, snake, tsp] | ||
- env: jigsaw # [bin_pack, cleaner, connector, cvrp, game_2048, jigsaw, job_shop, knapsack, maze, minesweeper, rubiks_cube, snake, tsp] | ||
|
||
agent: random # [random, a2c] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Change this back please
num_pieces: chex.Numeric # () | ||
solved_board: chex.Array # (num_rows, num_cols) | ||
pieces: chex.Array # (num_pieces, 3, 3) | ||
action_mask: chex.Array # (num_pieces, num_rotations, num_rows-3, num_cols-3) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you calculating the action mask from scratch each step? If so you don't need it in the state
from jumanji.environments.packing.jigsaw.types import State | ||
|
||
|
||
class RewardFn(abc.ABC): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think a protocol would make more sense here than a base class as it is just the callable? @clement-bonnet
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, thanks Ruan! Really minor style points from my side.
Only thing I haven't looked at is the networks because I assume those are still being tweaked
chosen_piece = rotate_piece(chosen_piece, rotation) | ||
|
||
grid_piece = self._expand_piece_to_board(chosen_piece, row_idx, col_idx) | ||
grid_mask_piece = self._get_ones_like_expanded_piece(grid_piece=grid_piece) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this works and gets rid of the extra method
grid_mask_piece = self._get_ones_like_expanded_piece(grid_piece=grid_piece) | |
grid_mask_piece = grid_piece == piece_idx |
grids = batch_expand_piece_to_board(rotated_pieces, rows, cols) | ||
|
||
batch_get_ones_like_expanded_piece = jax.vmap( | ||
self._get_ones_like_expanded_piece, in_axes=(0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you delete _get_ones_like_expanded_piece
then:
self._get_ones_like_expanded_piece, in_axes=(0) | |
lambda x: x != 0, in_axes=(0) |
There is a problem with the current formulation of the problem. In order to make the problem solvable in its current configuration would require the agent to have access to the solved board which makes the problem non-combinatorial. The way forward will be to rework Jigsaw as a new environment called FlatPack. This will be a 2D, discrete and flattened version of the BinPack problem with potential positive transfer to the Tetris environment since placed blocks will still interlock with each other. |
Implements the full Jigsaw environment with actor critic networks.
closes #143