diff --git a/README.md b/README.md
index 4b2fb4cdf..28fbd5b4c 100644
--- a/README.md
+++ b/README.md
@@ -98,8 +98,9 @@ problems.
| 🎨 GraphColoring | Logic | `GraphColoring-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/graph_coloring/) | [doc](https://instadeepai.github.io/jumanji/environments/graph_coloring/) |
| 💣 Minesweeper | Logic | `Minesweeper-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/minesweeper/) | [doc](https://instadeepai.github.io/jumanji/environments/minesweeper/) |
| 🎲 RubiksCube | Logic | `RubiksCube-v0`
`RubiksCube-partly-scrambled-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/rubiks_cube/) | [doc](https://instadeepai.github.io/jumanji/environments/rubiks_cube/) |
-| ✏️ Sudoku | Logic | `Sudoku-v0`
`Sudoku-very-easy-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/sudoku/) | [doc](https://instadeepai.github.io/jumanji/environments/sudoku/) |
-| 📦 BinPack (3D BinPacking Problem) | Packing | `BinPack-v2` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/bin_pack/) | [doc](https://instadeepai.github.io/jumanji/environments/bin_pack/) |
+| ✏️ Sudoku | Logic | `Sudoku-v0`
`Sudoku-very-easy-v0`| [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/sudoku/) | [doc](https://instadeepai.github.io/jumanji/environments/sudoku/) |
+| 📦 BinPack (3D BinPacking Problem) | Packing | `BinPack-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/bin_pack/) | [doc](https://instadeepai.github.io/jumanji/environments/bin_pack/) |
+| 🧩 FlatPack (2D Grid Filling Problem) | Packing | `FlatPack-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/flat_pack/) | [doc](https://instadeepai.github.io/jumanji/environments/flat_pack/) |
| 🏭 JobShop (Job Shop Scheduling Problem) | Packing | `JobShop-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/job_shop/) | [doc](https://instadeepai.github.io/jumanji/environments/job_shop/) |
| 🎒 Knapsack | Packing | `Knapsack-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/knapsack/) | [doc](https://instadeepai.github.io/jumanji/environments/knapsack/) |
| ▒ Tetris | Packing | `Tetris-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/tetris/) | [doc](https://instadeepai.github.io/jumanji/environments/tetris/) |
diff --git a/docs/api/environments/flat_pack.md b/docs/api/environments/flat_pack.md
new file mode 100644
index 000000000..1a372e7fb
--- /dev/null
+++ b/docs/api/environments/flat_pack.md
@@ -0,0 +1,8 @@
+::: jumanji.environments.packing.flat_pack.env.FlatPack
+ selection:
+ members:
+ - __init__
+ - reset
+ - step
+ - observation_spec
+ - action_spec
diff --git a/docs/env_anim/flat_pack.gif b/docs/env_anim/flat_pack.gif
new file mode 100644
index 000000000..561414ad7
Binary files /dev/null and b/docs/env_anim/flat_pack.gif differ
diff --git a/docs/env_img/flat_pack.png b/docs/env_img/flat_pack.png
new file mode 100644
index 000000000..ef1906825
Binary files /dev/null and b/docs/env_img/flat_pack.png differ
diff --git a/docs/environments/flat_pack.md b/docs/environments/flat_pack.md
new file mode 100644
index 000000000..67af383c8
--- /dev/null
+++ b/docs/environments/flat_pack.md
@@ -0,0 +1,57 @@
+# FlatPack Environment
+
+
+ +
+ +We provide here a Jax JIT-able implementation of a packing environment named _flat pack_. The goal of +the agent is to place all the available blocks on an empty 2D grid. +Each time an episode resets a new set of blocks is created and the grid is emptied. Blocks are randomly +shuffled and rotated and all have shape (3, 3). + +## Observation +The observation given to the agent gives a view of the current state of the grid as well as +all blocks that can be placed. + +- `current_grid`: jax array (float32) of shape `(num_rows, num_cols)` with values in the range + `[0, num_blocks]` (corresponding to the number of each block). This grid will have zeros + where no blocks have been placed and numbers corresponding to each block where that particular + block has been placed. + +- `blocks`: jax array (float32) of shape `(num_blocks, 3, 3)` of all possible blocks in + that can fit in the current grid. These blocks are shuffled, rotated and will always have shape `(3, 3)`. + +- `action_mask`: jax array (bool) of shape `(num_blocks, 4, num_rows-2, num_cols-2)`, representing + which actions are possible given the current state of the grid. The first index indicates the + number of blocks associated with a given grid. The second index indicates the number of times a block may be rotated. + The third and fourth indices indicate the row and column coordinate of where a blocks top left-most corner may be placed + respectively. Blocks are placed by an agent by specifying the row and column coordinate on the grid where the top left corner + of the selected block should be placed. These values will always be `num_rows-2` and `num_cols-2` + respectively to make it impossible for an agent to place a block outside the current grid. + + +## Action +The action space is a `MultiDiscreteArray`, specifically a tuple of an index between 0 and `num_blocks - 1`, +an index between 0 and 4 (since there are 4 possible rotations), an index between 0 and `num_rows-2` +(the possible row coordinates for placing a block) and an index between 0 and `num_cols-2` +(the possible column coordinates for placing a block). An action thus consists of four pieces of +information: + +- Block to place, + +- Number of 90 degree rotations to make to a chosen block ({0, 90, 180, 270} degrees), + +- Row coordinate for placing the rotated block's top left corner, + +- Column coordinate for placing the rotated block's top left corner. + + +## Reward +The reward function is configurable, but by default is a fully dense reward giving the sum of the number of non-zero +cells in a placed block normalised by the total number of cells in the grid at each timestep. The episode +terminates if either the grid is filled or `num_blocks` steps have been taken by an agent. + + +## Registered Versions 📖 +- `FlatPack-v0`, a flat pack environment grid with 11 rows and 11 columns containing 5 row blocks and 5 column blocks + for a total of 25 blocks that can be placed on the grid. This version has a dense reward. diff --git a/examples/load_checkpoints.ipynb b/examples/load_checkpoints.ipynb index 1f3c82226..f99cea250 100644 --- a/examples/load_checkpoints.ipynb +++ b/examples/load_checkpoints.ipynb @@ -111,8 +111,11 @@ ] }, { + "attachments": {}, "cell_type": "markdown", - "metadata": {}, + "metadata": { + "collapsed": false + }, "source": [ "## Load configs" ] @@ -194,6 +197,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -243,6 +247,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -279,6 +284,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ diff --git a/jumanji/__init__.py b/jumanji/__init__.py index 49b3fdcfd..cfa526965 100644 --- a/jumanji/__init__.py +++ b/jumanji/__init__.py @@ -81,6 +81,10 @@ # given in the observation. register(id="BinPack-v2", entry_point="jumanji.environments:BinPack") +# 2D grid filling problem with 25 blocks, an 11x11 grid and a random grid generator. +# The grid must be filled in `num_blocks` steps. +register(id="FlatPack-v0", entry_point="jumanji.environments:FlatPack") + # Job-shop scheduling problem with 20 jobs, 10 machines, at most # 8 operations per job, and a max operation duration of 6 timesteps. register(id="JobShop-v0", entry_point="jumanji.environments:JobShop") diff --git a/jumanji/environments/__init__.py b/jumanji/environments/__init__.py index 239ef8f51..4e69e2c2a 100644 --- a/jumanji/environments/__init__.py +++ b/jumanji/environments/__init__.py @@ -20,8 +20,9 @@ from jumanji.environments.logic.minesweeper import Minesweeper from jumanji.environments.logic.rubiks_cube import RubiksCube from jumanji.environments.logic.sudoku import Sudoku -from jumanji.environments.packing import bin_pack, job_shop, knapsack, tetris +from jumanji.environments.packing import bin_pack, flat_pack, job_shop, knapsack, tetris from jumanji.environments.packing.bin_pack.env import BinPack +from jumanji.environments.packing.flat_pack.env import FlatPack from jumanji.environments.packing.job_shop.env import JobShop from jumanji.environments.packing.knapsack.env import Knapsack from jumanji.environments.packing.tetris.env import Tetris diff --git a/jumanji/environments/packing/flat_pack/__init__.py b/jumanji/environments/packing/flat_pack/__init__.py new file mode 100644 index 000000000..252accca3 --- /dev/null +++ b/jumanji/environments/packing/flat_pack/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jumanji.environments.packing.flat_pack.env import FlatPack +from jumanji.environments.packing.flat_pack.types import Observation, State diff --git a/jumanji/environments/packing/flat_pack/conftest.py b/jumanji/environments/packing/flat_pack/conftest.py new file mode 100644 index 000000000..9749488c3 --- /dev/null +++ b/jumanji/environments/packing/flat_pack/conftest.py @@ -0,0 +1,162 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import chex +import jax +import jax.numpy as jnp +import pytest + + +@pytest.fixture +def key() -> chex.PRNGKey: + """A determinstic key.""" + + return jax.random.PRNGKey(0) + + +@pytest.fixture +def block() -> chex.Array: + """A mock block for testing.""" + + return jnp.array( + [ + [0, 1, 1], + [0, 1, 1], + [0, 0, 1], + ] + ) + + +@pytest.fixture +def solved_grid() -> chex.Array: + """A mock solved grid for testing.""" + + return jnp.array( + [ + [1, 1, 1, 2, 2], + [1, 1, 2, 2, 2], + [3, 1, 4, 4, 2], + [3, 3, 4, 4, 4], + [3, 3, 3, 4, 4], + ], + ) + + +@pytest.fixture +def grid_with_block_one_placed() -> chex.Array: + """A grid with only block one placed.""" + + return jnp.array( + [ + [1, 1, 1, 0, 0], + [1, 1, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + ], + ) + + +@pytest.fixture() +def block_one_placed_at_0_0(grid_with_block_one_placed: chex.Array) -> chex.Array: + """A 2D array of zeros where block one has been placed with it left top-most + corner at position (0, 0). + """ + + return grid_with_block_one_placed + + +@pytest.fixture() +def block_one_placed_at_1_1(grid_with_block_one_placed: chex.Array) -> chex.Array: + """A 2D array of zeros where block one has been placed with it left top-most + corner at position (1, 1). + """ + + # Shift all elements in the array one down and one to the right + partially_placed_block = jnp.roll(grid_with_block_one_placed, shift=1, axis=0) + partially_placed_block = jnp.roll(partially_placed_block, shift=1, axis=1) + + return partially_placed_block + + +@pytest.fixture() +def action_mask_with_block_1_placed() -> chex.Array: + """Action mask for a 4 piece grid where only block 1 has been placed with its + left top-most corner at (1, 1). + """ + + return jnp.array( + [ + [ + [[False, False, False], [False, False, False], [False, False, False]], + [[False, False, False], [False, False, False], [False, False, False]], + [[False, False, False], [False, False, False], [False, False, False]], + [[False, False, False], [False, False, False], [False, False, False]], + ], + [ + [[False, False, True], [False, False, True], [False, True, True]], + [[False, False, True], [False, True, True], [False, True, True]], + [[False, False, False], [False, False, True], [True, False, True]], + [[False, False, False], [False, False, True], [False, False, True]], + ], + [ + [[False, False, False], [False, False, True], [True, False, True]], + [[False, False, False], [False, False, True], [False, False, True]], + [[False, False, False], [False, False, True], [False, False, True]], + [[False, False, True], [False, True, True], [True, True, True]], + ], + [ + [[False, False, False], [False, False, True], [False, False, True]], + [[False, False, True], [False, False, True], [False, True, True]], + [[False, False, False], [False, False, True], [False, False, True]], + [[False, False, True], [False, False, True], [False, True, True]], + ], + ] + ) + + +@pytest.fixture() +def action_mask_without_only_block_1_placed() -> chex.Array: + """Action mask for a 4 piece grid where only block 1 can be placed with its + left top-most corner at (1, 1). + """ + + return jnp.array( + [ + [ + [[True, False, False], [False, False, False], [False, False, False]], + [[False, False, False], [False, False, False], [False, False, False]], + [[False, False, False], [False, False, False], [False, False, False]], + [[False, False, False], [False, False, False], [False, False, False]], + ], + [ + [[False, False, False], [False, False, False], [False, False, False]], + [[False, False, False], [False, False, False], [False, False, False]], + [[False, False, False], [False, False, False], [False, False, False]], + [[False, False, False], [False, False, False], [False, False, False]], + ], + [ + [[False, False, False], [False, False, False], [False, False, False]], + [[False, False, False], [False, False, False], [False, False, False]], + [[False, False, False], [False, False, False], [False, False, False]], + [[False, False, False], [False, False, False], [False, False, False]], + ], + [ + [[False, False, False], [False, False, False], [False, False, False]], + [[False, False, False], [False, False, False], [False, False, False]], + [[False, False, False], [False, False, False], [False, False, False]], + [[False, False, False], [False, False, False], [False, False, False]], + ], + ] + ) diff --git a/jumanji/environments/packing/flat_pack/env.py b/jumanji/environments/packing/flat_pack/env.py new file mode 100644 index 000000000..573486a73 --- /dev/null +++ b/jumanji/environments/packing/flat_pack/env.py @@ -0,0 +1,518 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Sequence, Tuple + +import chex +import jax +import jax.numpy as jnp +import matplotlib +from numpy.typing import NDArray + +from jumanji import specs +from jumanji.env import Environment +from jumanji.environments.packing.flat_pack.generator import ( + InstanceGenerator, + RandomFlatPackGenerator, +) +from jumanji.environments.packing.flat_pack.reward import CellDenseReward, RewardFn +from jumanji.environments.packing.flat_pack.types import Observation, State +from jumanji.environments.packing.flat_pack.utils import compute_grid_dim, rotate_block +from jumanji.environments.packing.flat_pack.viewer import FlatPackViewer +from jumanji.types import TimeStep, restart, termination, transition +from jumanji.viewer import Viewer + + +class FlatPack(Environment[State]): + + """The FlatPack environment with a configurable number of row and column blocks. + Here the goal of an agent is to completely fill an empty grid by placing all + available blocks. It can be thought of as a discrete 2D version of the `BinPack` + environment. + + - observation: `Observation` + - grid: jax array (int) of shape (num_rows, num_cols) with the + current state of the grid. + - blocks: jax array (int) of shape (num_blocks, 3, 3) with the blocks to + be placed on the grid. Here each block is a 2D array with shape (3, 3). + - action_mask: jax array (bool) showing where which blocks can be placed on the grid. + this mask includes all possible rotations and possible placement locations + for each block on the grid. + + - action: jax array (int32) of shape (4,) + multi discrete array containing the move to perform + (block to place, number of rotations, row coordinate, column coordinate). + + - reward: jax array (float) of shape (), could be either: + - cell dense: the number of non-zero cells in a placed block normalised by the + total number of cells in a grid. this will be a value in the range [0, 1]. + that is to say that the agent will optimise for the maximum area to fill on + the grid. + - block dense: each placed block will receive a reward of 1./num_blocks. this will + be a value in the range [0, 1]. that is to say that the agent will optimise + for the maximum number of blocks placed on the grid. + - sparse: 1 if the grid is completely filled, otherwise 0 at each timestep. + + - episode termination: + - if all blocks have been placed on the board. + - if the agent has taken `num_blocks` steps in the environment. + + - state: `State` + - num_blocks: jax array (int32) of shape () with the + number of blocks in the environment. + - blocks: jax array (int32) of shape (num_blocks, 3, 3) with the blocks to + be placed on the grid. Here each block is a 2D array with shape (3, 3). + - action_mask: jax array (bool) showing where which blocks can be placed on the grid. + this mask includes all possible rotations and possible placement locations + for each block on the grid. + - placed_blocks: jax array (bool) of shape (num_blocks,) showing which blocks + have been placed on the grid. + - grid: jax array (int32) of shape (num_rows, num_cols) with the + current state of the grid. + - step_count: jax array (int32) of shape () with the number of steps taken + in the environment. + - key: jax array of shape (2,) with the random key used for board + generation. + + ```python + from jumanji.environments import FlatPack + env = FlatPack() + key = jax.random.PRNGKey(0) + state, timestep = jax.jit(env.reset)(key) + env.render(state) + action = env.action_spec().generate_value() + state, timestep = jax.jit(env.step)(state, action) + env.render(state) + ``` + """ + + def __init__( + self, + generator: Optional[InstanceGenerator] = None, + reward_fn: Optional[RewardFn] = None, + viewer: Optional[Viewer[State]] = None, + ): + """Initializes the FlatPack environment. + + Args: + generator: Instance generator for the environment, default to `RandomFlatPackGenerator` + with a grid of 5 blocks per row and column. + reward_fn: Reward function for the environment, default to `CellDenseReward`. + viewer: Viewer for rendering the environment. + """ + + default_generator = RandomFlatPackGenerator( + num_row_blocks=5, + num_col_blocks=5, + ) + + self.generator = generator or default_generator + self.num_row_blocks = self.generator.num_row_blocks + self.num_col_blocks = self.generator.num_col_blocks + self.num_blocks = self.num_row_blocks * self.num_col_blocks + self.num_rows, self.num_cols = ( + compute_grid_dim(self.num_row_blocks), + compute_grid_dim(self.num_col_blocks), + ) + self.reward_fn = reward_fn or CellDenseReward() + self.viewer = viewer or FlatPackViewer( + "FlatPack", self.num_blocks, render_mode="human" + ) + + def __repr__(self) -> str: + return ( + f"FlatPack environment with a grid size of ({self.num_rows}x{self.num_cols}) " + f"with {self.num_row_blocks} row blocks, {self.num_col_blocks} column " + f"blocks. Each block has dimension (3x3)." + ) + + def reset( + self, + key: chex.PRNGKey, + ) -> Tuple[State, TimeStep[Observation]]: + + """Resets the environment. + + Args: + key: PRNG key for generating a new instance. + + Returns: + a tuple of the initial environment state and a time step. + """ + + grid_state = self.generator(key) + + obs = self._observation_from_state(grid_state) + timestep = restart(observation=obs) + + return grid_state, timestep + + def step( + self, state: State, action: chex.Array + ) -> Tuple[State, TimeStep[Observation]]: + """Steps the environment. + + Args: + state: current state of the environment. + action: action to take. + + Returns: + a tuple of the next environment state and a time step. + """ + + # Unpack and use actions + block_idx, rotation, row_idx, col_idx = action + + chosen_block = state.blocks[block_idx] + + # Rotate chosen block + chosen_block = rotate_block(chosen_block, rotation) + + grid_block = self._expand_block_to_grid(chosen_block, row_idx, col_idx) + + action_is_legal = state.action_mask[block_idx, rotation, row_idx, col_idx] + + # If the action is legal create a new grid and update the placed blocks array + new_grid = jax.lax.cond( + action_is_legal, + lambda: state.grid + grid_block, + lambda: state.grid, + ) + placed_blocks = jax.lax.cond( + action_is_legal, + lambda: state.placed_blocks.at[block_idx].set(True), + lambda: state.placed_blocks, + ) + + new_action_mask = self._make_action_mask(new_grid, state.blocks, placed_blocks) + + next_state = State( + grid=new_grid, + blocks=state.blocks, + action_mask=new_action_mask, + num_blocks=state.num_blocks, + key=state.key, + step_count=state.step_count + 1, + placed_blocks=placed_blocks, + ) + + done = self._is_done(next_state) + next_obs = self._observation_from_state(next_state) + reward = self.reward_fn(state, grid_block, next_state, action_is_legal, done) + + timestep = jax.lax.cond( + done, + termination, + transition, + reward, + next_obs, + ) + + return next_state, timestep + + def render(self, state: State) -> Optional[NDArray]: + """Render a given state of the environment. + + Args: + state: `State` object containing the current environment state. + """ + + return self.viewer.render(state) + + def animate( + self, + states: Sequence[State], + interval: int = 200, + save_path: Optional[str] = None, + ) -> matplotlib.animation.FuncAnimation: + """Create an animation from a sequence of states. + + Args: + states: sequence of `State` corresponding to subsequent timesteps. + interval: delay between frames in milliseconds, default to 200. + save_path: the path where the animation file should be saved. If it is None, the plot + will not be saved. + + Returns: + animation that can export to gif, mp4, or render with HTML. + """ + + return self.viewer.animate(states, interval, save_path) + + def close(self) -> None: + """Perform any necessary cleanup. + + Environments will automatically `close()` themselves when + garbage collected or when the program exits. + """ + + self.viewer.close() + + def observation_spec(self) -> specs.Spec[Observation]: + """Returns the observation spec of the environment. + + Returns: + Spec for each filed in the observation: + - grid: BoundedArray (int) of shape (num_rows, num_cols). + - blocks: BoundedArray (int) of shape (num_blocks, 3, 3). + - action_mask: BoundedArray (bool) of shape + (num_blocks, 4, num_rows-2, num_cols-2). + """ + + grid = specs.BoundedArray( + shape=(self.num_rows, self.num_cols), + minimum=0, + maximum=self.num_blocks, + dtype=jnp.int32, + name="grid", + ) + + blocks = specs.BoundedArray( + shape=(self.num_blocks, 3, 3), + minimum=0, + maximum=self.num_blocks, + dtype=jnp.int32, + name="blocks", + ) + + action_mask = specs.BoundedArray( + shape=( + self.num_blocks, + 4, + self.num_rows - 2, + self.num_cols - 2, + ), + minimum=False, + maximum=True, + dtype=bool, + name="action_mask", + ) + + return specs.Spec( + Observation, + "ObservationSpec", + grid=grid, + blocks=blocks, + action_mask=action_mask, + ) + + def action_spec(self) -> specs.MultiDiscreteArray: + """Specifications of the action expected by the `FlatPack` environment. + + Returns: + MultiDiscreteArray (int32) of shape (num_blocks, num_rotations, + num_rows-2, num_cols-2). + - num_blocks: int between 0 and num_blocks - 1 (inclusive). + - num_rotations: int between 0 and 3 (inclusive). + - max_row_position: int between 0 and num_rows - 3 (inclusive). + - max_col_position: int between 0 and num_cols - 3 (inclusive). + """ + + max_row_position = self.num_rows - 2 + max_col_position = self.num_cols - 2 + + return specs.MultiDiscreteArray( + num_values=jnp.array( + [self.num_blocks, 4, max_row_position, max_col_position] + ), + name="action", + ) + + def _is_done(self, state: State) -> bool: + """Checks if the environment is done by checking whether the number of + steps is equal to the number of blocks. + + Args: + state: current state of the environment. + + Returns: + True if the environment is done, False otherwise. + """ + + done: bool = state.step_count >= state.num_blocks + + return done + + def _is_legal_action( + self, + action: chex.Numeric, + grid: chex.Array, + placed_blocks: chex.Array, + grid_mask_block: chex.Array, + ) -> bool: + """Checks if the action is legal by considering the action mask and the + current grid. An action is legal if the action mask is True for that action + and the there is no overlap with blocks already placed. + + Args: + action: action taken. + grid: current state of the grid. + placed_blocks: array indicating which blocks have been placed. + grid_mask_block: grid with ones where current block should be placed. + + Returns: + True if the action is legal, False otherwise. + """ + + block_idx, _, _, _ = action + + placed_mask = (grid > 0.0) + grid_mask_block + + legal: bool = (~placed_blocks[block_idx]) & (jnp.max(placed_mask) <= 1) + + return legal + + def _get_ones_like_expanded_block(self, grid_block: chex.Array) -> chex.Array: + """Makes a grid of zeroes with ones where the block is placed.""" + + return (grid_block != 0).astype(jnp.int32) + + def _expand_block_to_grid( + self, + block: chex.Array, + row_coord: chex.Numeric, + col_coord: chex.Numeric, + ) -> chex.Array: + """Places a block on a grid of zeroes with the same size as the grid. + + Args: + block: block to place on the grid. + row_coord: row coordinate on the grid where the top left corner + of the block will be placed. + col_coord: column coordinate on the grid where the top left corner + of the block will be placed. + + Returns: + Grid of zeroes with values where the block is placed. + """ + + # Make an empty grid for placing the block on. + grid_with_block = jnp.zeros((self.num_rows, self.num_cols), dtype=jnp.int32) + place_location = (row_coord, col_coord) + + grid_with_block = jax.lax.dynamic_update_slice( + grid_with_block, block, place_location + ) + + return grid_with_block + + def _observation_from_state(self, state: State) -> Observation: + """Creates an observation from a state. + + Args: + state: State to create an observation from. + + Returns: + An observation. + """ + + return Observation( + grid=state.grid, + action_mask=state.action_mask, + blocks=state.blocks, + ) + + def _expand_all_blocks_to_grids( + self, + blocks: chex.Array, + block_idxs: chex.Array, + rotations: chex.Array, + row_coords: chex.Array, + col_coords: chex.Array, + ) -> chex.Array: + """Takes multiple blocks and their corresponding rotations and positions, + and generates a grid for each block. + + Args: + blocks: array of possible blocks. + block_idxs: array of indices of the blocks to place. + rotations: array of all possible rotations for each block. + row_coords: array of row coordinates. + col_coords: array of column coordinates. + """ + + batch_expand_block_to_board = jax.vmap(self._expand_block_to_grid) + + all_possible_blocks = blocks[block_idxs] + rotated_blocks = jax.vmap(rotate_block)(all_possible_blocks, rotations) + grids = batch_expand_block_to_board(rotated_blocks, row_coords, col_coords) + + batch_get_ones_like_expanded_block = jax.vmap( + self._get_ones_like_expanded_block, in_axes=(0) + ) + grids = batch_get_ones_like_expanded_block(grids) + return grids + + def _make_action_mask( + self, grid: chex.Array, blocks: chex.Array, placed_blocks: chex.Array + ) -> chex.Array: + """Create a mask of possible actions based on the current state of the grid. + + Args: + grid: current state of the grid. + blocks: array of all blocks. + placed_blocks: array of blocks that have already been placed. + """ + + num_blocks, num_rotations, num_placement_rows, num_placement_cols = ( + self.num_blocks, + 4, + self.num_rows - 2, + self.num_cols - 2, + ) + + blocks_grid, rotations_grid, rows_grid, cols_grid = jnp.meshgrid( + jnp.arange(num_blocks), + jnp.arange(num_rotations), + jnp.arange(num_placement_rows), + jnp.arange(num_placement_cols), + indexing="ij", + ) + + grid_mask_pieces = self._expand_all_blocks_to_grids( + blocks, + blocks_grid.flatten(), + rotations_grid.flatten(), + rows_grid.flatten(), + cols_grid.flatten(), + ) + + batch_is_legal_action = jax.vmap( + self._is_legal_action, in_axes=(0, None, None, 0) + ) + + all_actions = jnp.stack( + (blocks_grid, rotations_grid, rows_grid, cols_grid), axis=-1 + ).reshape(-1, 4) + + legal_actions = batch_is_legal_action( + all_actions, + grid, + placed_blocks, + grid_mask_pieces, + ) + + legal_actions = legal_actions.reshape( + num_blocks, num_rotations, num_placement_rows, num_placement_cols + ) + + # Now set all current placed blocks to false in the mask. + placed_blocks_array = placed_blocks.reshape((self.num_blocks, 1, 1, 1)) + placed_blocks_mask = jnp.tile( + placed_blocks_array, + (1, num_rotations, num_placement_rows, num_placement_cols), + ) + legal_actions = jnp.where(placed_blocks_mask, False, legal_actions) + + return legal_actions diff --git a/jumanji/environments/packing/flat_pack/env_test.py b/jumanji/environments/packing/flat_pack/env_test.py new file mode 100644 index 000000000..923306349 --- /dev/null +++ b/jumanji/environments/packing/flat_pack/env_test.py @@ -0,0 +1,340 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import chex +import jax +import jax.numpy as jnp +import pytest + +from jumanji.environments.packing.flat_pack.env import FlatPack +from jumanji.environments.packing.flat_pack.generator import ( + RandomFlatPackGenerator, + ToyFlatPackGeneratorNoRotation, + ToyFlatPackGeneratorWithRotation, +) +from jumanji.environments.packing.flat_pack.reward import ( + BlockDenseReward, + CellDenseReward, +) +from jumanji.environments.packing.flat_pack.types import State +from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.pytrees import assert_is_jax_array_tree +from jumanji.types import StepType, TimeStep + + +@pytest.fixture(scope="module") +def flat_pack() -> FlatPack: + """Creates a simple FlatPack environment for testing.""" + return FlatPack( + generator=RandomFlatPackGenerator( + num_col_blocks=3, + num_row_blocks=3, + ), + ) + + +@pytest.fixture +def simple_env_grid_state_1() -> chex.Array: + """The state of the grid in the simplified example after 1 correct action.""" + # fmt: off + return jnp.array( + [ + [1, 1, 1, 0, 0], + [1, 1, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + ] + ) + # fmt: on + + +@pytest.fixture +def simple_env_grid_state_2() -> chex.Array: + """The state of the grid in the simplified example after 2 correct actions.""" + # fmt: off + return jnp.array( + [ + [1, 1, 1, 2, 2], + [1, 1, 2, 2, 2], + [0, 1, 0, 0, 2], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + ] + ) + # fmt: on + + +@pytest.fixture +def simple_env_grid_state_3() -> chex.Array: + """The state of the grid in the simplified example after 3 correct actions.""" + # fmt: off + return jnp.array( + [ + [1, 1, 1, 2, 2], + [1, 1, 2, 2, 2], + [3, 1, 0, 0, 2], + [3, 3, 0, 0, 0], + [3, 3, 3, 0, 0], + ] + ) + # fmt: on + + +@pytest.fixture +def simple_env_grid_state_4() -> chex.Array: + """The state of the grdi in the simplified example after 4 correct actions.""" + # fmt: off + return jnp.array( + [ + [1, 1, 1, 2, 2], + [1, 1, 2, 2, 2], + [3, 1, 4, 4, 2], + [3, 3, 4, 4, 4], + [3, 3, 3, 4, 4], + ] + ) + # fmt: on + + +@pytest.fixture +def simple_env_placed_blocks_1() -> chex.Array: + """Placed blocks array in the simplified env after 1 block has been placed.""" + return jnp.array([True, False, False, False]) + + +@pytest.fixture +def simple_env_placed_blocks_2() -> chex.Array: + """Placed blocks array in the simplified env after 2 blocks have been placed.""" + return jnp.array([True, True, False, False]) + + +@pytest.fixture +def simple_env_placed_blocks_3() -> chex.Array: + """Placed blocks array in the simplified env after 3 blocks have been placed.""" + return jnp.array([True, True, True, False]) + + +def test_flat_pack__reset_jit(flat_pack: FlatPack, key: chex.PRNGKey) -> None: + """Test that the environment reset only compiles once.""" + chex.clear_trace_counter() + reset_fn = jax.jit(chex.assert_max_traces(flat_pack.reset, n=1)) + state, timestep = reset_fn(key) + + # Check the types of the outputs + assert isinstance(state, State) + assert isinstance(timestep, TimeStep) + + # Check that the state contains DeviceArrays to verify that it is jitted. + assert_is_jax_array_tree(state) + + # Call the reset method again to ensure it is not compiling twice. + key, new_key = jax.random.split(key) + state, timestep = reset_fn(new_key) + assert isinstance(state, State) + assert isinstance(timestep, TimeStep) + + +def test_flat_pack__step_jit(flat_pack: FlatPack, key: chex.PRNGKey) -> None: + """Test that the step function is only compiled once.""" + state_0, timestep_0 = flat_pack.reset(key) + action_0 = jnp.array([0, 0, 0, 0]) + + chex.clear_trace_counter() + step_fn = jax.jit(chex.assert_max_traces(flat_pack.step, n=1)) + + state_1, timestep_1 = step_fn(state_0, action_0) + + # Check that the state has changed and that the step has incremented. + assert not jnp.array_equal(state_1.grid, state_0.grid) + assert state_1.step_count == state_0.step_count + 1 + assert isinstance(timestep_1, TimeStep) + + # Check that the state contains DeviceArrays to verify that it is jitted. + assert_is_jax_array_tree(state_1) + + # Call the step method again to ensure it is not compiling twice. + action_1 = jnp.array([1, 0, 3, 3]) + state_2, timestep_2 = step_fn(state_1, action_1) + + # Check that the state contains DeviceArrays to verify that it is jitted. + assert_is_jax_array_tree(state_2) + + # Check that the state has changed and that the step has incremented. + assert not jnp.array_equal(state_2.grid, state_1.grid) + assert state_2.step_count == state_1.step_count + 1 + assert isinstance(timestep_2, TimeStep) + + +def test_flat_pack__does_not_smoke(flat_pack: FlatPack) -> None: + """Test that we can run an episode without any errors.""" + check_env_does_not_smoke(flat_pack) + + +def test_flat_pack__is_done(flat_pack: FlatPack, key: chex.PRNGKey) -> None: + """Test that the is_done method works as expected.""" + + state, _ = flat_pack.reset(key) + assert not flat_pack._is_done(state) + + # Manually set step count equal to the number of blocks. + state.step_count = 9 + assert flat_pack._is_done(state) + + +def test_flat_pack__expand_block_to_grid( + flat_pack: FlatPack, key: chex.PRNGKey, block: chex.Array +) -> None: + """Test that a block is correctly set on a grid of zeros.""" + _, _ = flat_pack.reset(key) + expanded_grid_with_block = flat_pack._expand_block_to_grid(block, 2, 1) + # fmt: off + expected_expanded_grid = jnp.array( + [ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + ] + ) + # fmt: on + assert jnp.array_equal(expanded_grid_with_block, expected_expanded_grid) + + +def test_flat_pack__completed_episode_with_cell_dense_reward( + key: chex.PRNGKey, + simple_env_grid_state_1: chex.Array, + simple_env_grid_state_2: chex.Array, + simple_env_grid_state_3: chex.Array, + simple_env_grid_state_4: chex.Array, + simple_env_placed_blocks_1: chex.Array, + simple_env_placed_blocks_2: chex.Array, + simple_env_placed_blocks_3: chex.Array, +) -> None: + """This test will step a simplified version of the FlatPack environment + with a cell dense reward until completion. It will check that the reward is + correctly computed and that the environment transitions as expected until + done. + """ + + simple_env = FlatPack( + generator=ToyFlatPackGeneratorNoRotation(), + reward_fn=CellDenseReward(), + ) + chex.clear_trace_counter() + step_fn = jax.jit(chex.assert_max_traces(simple_env.step, n=1)) + + # Intialize the environment + state, timestep = simple_env.reset(key) + assert isinstance(state, State) + assert isinstance(timestep, TimeStep) + assert timestep.step_type == StepType.FIRST + + # Check that the reset board contains only zeros + assert jnp.all(state.grid == 0) + assert jnp.all(state.action_mask) + + # Step the environment + state, timestep = step_fn(state, jnp.array([0, 0, 0, 0])) + assert timestep.step_type == StepType.MID + assert jnp.all(state.grid == simple_env_grid_state_1) + assert timestep.reward == 6 / 25 + assert jnp.all(state.placed_blocks == simple_env_placed_blocks_1) + + # Step the environment + state, timestep = step_fn(state, jnp.array([1, 0, 0, 2])) + assert timestep.step_type == StepType.MID + assert jnp.all(state.grid == simple_env_grid_state_2) + assert timestep.reward == 6 / 25 + assert jnp.all(state.placed_blocks == simple_env_placed_blocks_2) + + # Step the environment + state, timestep = step_fn(state, jnp.array([2, 0, 2, 0])) + assert timestep.step_type == StepType.MID + assert jnp.all(state.grid == simple_env_grid_state_3) + assert timestep.reward == 6 / 25 + assert jnp.all(state.placed_blocks == simple_env_placed_blocks_3) + + # Step the environment + state, timestep = step_fn(state, jnp.array([3, 0, 2, 2])) + assert timestep.step_type == StepType.LAST + assert jnp.all(state.grid == simple_env_grid_state_4) + assert timestep.reward == 7 / 25 + assert jnp.all(~state.action_mask) + + +def test_flat_pack__completed_episode_with_sparse_block_dense_reward( + key: chex.PRNGKey, + simple_env_grid_state_1: chex.Array, + simple_env_grid_state_2: chex.Array, + simple_env_grid_state_3: chex.Array, + simple_env_grid_state_4: chex.Array, + simple_env_placed_blocks_1: chex.Array, + simple_env_placed_blocks_2: chex.Array, + simple_env_placed_blocks_3: chex.Array, +) -> None: + """This test will step a simplified version of the FlatPack environment + with a block dense reward until completion. It will check that the reward is + correctly computed and that the environment transitions as expected until + done. + """ + + simple_env = FlatPack( + generator=ToyFlatPackGeneratorWithRotation(), + reward_fn=BlockDenseReward(), + ) + chex.clear_trace_counter() + step_fn = jax.jit(chex.assert_max_traces(simple_env.step, n=1)) + + # Intialize the environment + state, timestep = simple_env.reset(key) + assert isinstance(state, State) + assert isinstance(timestep, TimeStep) + assert timestep.step_type == StepType.FIRST + + # Check that the reset board contains only zeros + assert jnp.all(state.grid == 0) + assert jnp.all(state.action_mask) + + # Step the environment + state, timestep = step_fn(state, jnp.array([0, 2, 0, 0])) + assert timestep.step_type == StepType.MID + assert jnp.all(state.grid == simple_env_grid_state_1) + assert timestep.reward == 1 / 4 + assert jnp.all(state.placed_blocks == simple_env_placed_blocks_1) + + # Step the environment + state, timestep = step_fn(state, jnp.array([1, 2, 0, 2])) + assert timestep.step_type == StepType.MID + + assert jnp.all(state.grid == simple_env_grid_state_2) + assert timestep.reward == 1 / 4 + assert jnp.all(state.placed_blocks == simple_env_placed_blocks_2) + + # Step the environment + state, timestep = step_fn(state, jnp.array([2, 1, 2, 0])) + assert timestep.step_type == StepType.MID + assert jnp.all(state.grid == simple_env_grid_state_3) + assert timestep.reward == 1 / 4 + assert jnp.all(state.placed_blocks == simple_env_placed_blocks_3) + + # Step the environment + state, timestep = step_fn(state, jnp.array([3, 0, 2, 2])) + assert timestep.step_type == StepType.LAST + assert jnp.all(state.grid == simple_env_grid_state_4) + assert timestep.reward == 1 / 4 + assert jnp.all(~state.action_mask) diff --git a/jumanji/environments/packing/flat_pack/generator.py b/jumanji/environments/packing/flat_pack/generator.py new file mode 100644 index 000000000..7ea8495d5 --- /dev/null +++ b/jumanji/environments/packing/flat_pack/generator.py @@ -0,0 +1,386 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Tuple + +import chex +import jax +import jax.numpy as jnp + +from jumanji.environments.packing.flat_pack.types import State +from jumanji.environments.packing.flat_pack.utils import ( + compute_grid_dim, + get_significant_idxs, + rotate_block, +) + + +class InstanceGenerator(abc.ABC): + """Base class for generators for the flat_pack environment. An `InstanceGenerator` is responsible + for generating a problem instance when the environment is reset. + """ + + def __init__( + self, + num_row_blocks: int, + num_col_blocks: int, + ) -> None: + """Initialises a flat_pack generator, used to generate grids for the FlatPack environment. + + Args: + num_row_blocks: Number of row blocks in flat_pack environment. + num_col_blocks: Number of column blocks in flat_pack environment. + """ + + self.num_row_blocks = num_row_blocks + self.num_col_blocks = num_col_blocks + + @abc.abstractmethod + def __call__(self, key: chex.PRNGKey) -> State: + """Call method responsible for generating a new state. + + Args: + key: jax random key in case stochasticity is used in the instance generation process. + + Returns: + A `FlatPack` environment state. + """ + + +class RandomFlatPackGenerator(InstanceGenerator): + """Random flat_pack generator. This generator will generate a random flat_pack grid.""" + + def _fill_grid_columns( + self, carry: Tuple[chex.Array, int], arr_value: int + ) -> Tuple[Tuple[chex.Array, int], int]: + """Fills the grid columns with a value. + This function will fill the grid columns with a value that + is incremented by 1 each time it is called. + """ + + grid = carry[0] + grid_x, _ = grid.shape + fill_value = carry[1] + + fill_value += 1 + + edit_grid = jax.lax.dynamic_slice(grid, (0, arr_value), (grid_x, 3)) + edit_grid = jnp.ones_like(edit_grid) + edit_grid *= fill_value + + grid = jax.lax.dynamic_update_slice(grid, edit_grid, (0, arr_value)) + + return (grid, fill_value), arr_value + + def _fill_grid_rows( + self, carry: Tuple[chex.Array, int, int], arr_value: int + ) -> Tuple[Tuple[chex.Array, int, int], int]: + """Fills the grid rows with a value. + This function will fill the grid rows with a value that + is incremented by `num_col_blocks` each time it is called. + """ + + grid = carry[0] + _, grid_y = grid.shape + sum_value = carry[1] + num_col_blocks = carry[2] + + edit_grid = jax.lax.dynamic_slice(grid, (arr_value, 0), (3, grid_y)) + edit_grid += sum_value + + sum_value += num_col_blocks + + grid = jax.lax.dynamic_update_slice(grid, edit_grid, (arr_value, 0)) + + return (grid, sum_value, num_col_blocks), arr_value + + def _select_sides(self, array: chex.Array, key: chex.PRNGKey) -> chex.Array: + """Randomly selects a value to replace the center value of an array + containing three values. + """ + + selector = jax.random.uniform(key, shape=()) + + center_val = jax.lax.cond( + selector > 0.5, + lambda: array[0], + lambda: array[2], + ) + + array = array.at[1].set(center_val) + + return array + + def _select_col_interlocks( + self, carry: Tuple[chex.Array, chex.PRNGKey], col: int + ) -> Tuple[Tuple[chex.Array, chex.PRNGKey], int]: + """Creates interlocks in adjacent blocks along columns by randomly + selecting a value from the left and right side of the column. + """ + + grid = carry[0] + key = carry[1] + rows = grid.shape[0] + + grid_slice = jax.lax.dynamic_slice(grid, (0, col - 1), (rows, 3)) + all_keys = jax.random.split(key, rows + 1) + key = all_keys[0] + select_keys = all_keys[1:] + filled_grid_slice = jax.vmap(self._select_sides)(grid_slice, select_keys) + + grid = jax.lax.dynamic_update_slice(grid, filled_grid_slice, (0, col - 1)) + + return (grid, key), col + + def _select_row_interlocks( + self, carry: Tuple[chex.Array, chex.PRNGKey], row: int + ) -> Tuple[Tuple[chex.Array, chex.PRNGKey], int]: + """Creates interlocks in adjacent blocks along rows by randomly + selecting a value from the block above and below the current + block. + """ + + grid = carry[0] + key = carry[1] + cols = grid.shape[1] + + grid_slice = jax.lax.dynamic_slice(grid, (row - 1, 0), (3, cols)) + + grid_slice = grid_slice.T + + all_keys = jax.random.split(key, cols + 1) + key = all_keys[0] + select_keys = all_keys[1:] + + filled_grid_slice = jax.vmap(self._select_sides)(grid_slice, select_keys) + filled_grid_slice = filled_grid_slice.T + + grid = jax.lax.dynamic_update_slice(grid, filled_grid_slice, (row - 1, 0)) + + return (grid, key), row + + def _first_nonzero( + self, arr: chex.Array, axis: int, invalid_val: int = 1000 + ) -> chex.Numeric: + """Returns the index of the first non-zero value in an array.""" + + mask = arr != 0 + return jnp.min( + jnp.where(mask.any(axis=axis), mask.argmax(axis=axis), invalid_val) + ) + + def _crop_nonzero(self, arr_: chex.Array) -> chex.Array: + """Crops a block to be of shape (3, 3).""" + + row_roll, col_roll = self._first_nonzero(arr_, axis=0), self._first_nonzero( + arr_, axis=1 + ) + + arr_ = jnp.roll(arr_, -row_roll, axis=0) + arr_ = jnp.roll(arr_, -col_roll, axis=1) + + cropped_arr = jnp.zeros((3, 3), dtype=jnp.int32) + + cropped_arr = cropped_arr.at[:, :].set(arr_[:3, :3]) + + return cropped_arr + + def _extract_block( + self, carry: Tuple[chex.Array, chex.PRNGKey], block_num: int + ) -> Tuple[Tuple[chex.Array, chex.PRNGKey], chex.Array]: + """Extracts a block from a solved grid according to its block number + and rotates it by a random amount of degrees. + """ + + grid, key = carry + + # create a boolean mask for the current block number + mask = grid == block_num + # use the mask to extract the block from the grid + block = jnp.where(mask, grid, 0) + + # Crop block + block = self._crop_nonzero(block) + + # Rotate block by random amount of degrees {0, 90, 180, 270} + key, rot_key = jax.random.split(key) + rotation_value = jax.random.randint(key=rot_key, shape=(), minval=0, maxval=4) + rotated_block = rotate_block(block, rotation_value) + + return (grid, key), rotated_block + + def __call__(self, key: chex.PRNGKey) -> State: + """Generates a random flat_pack grid. + + Args: + key: jax random key in case stochasticity is used in the instance generation process. + + Returns: + A `FlatPack` environment state. + """ + + num_blocks = self.num_row_blocks * self.num_col_blocks + + # Compute the size of the grid. + grid_row_dim = compute_grid_dim(self.num_row_blocks) + grid_col_dim = compute_grid_dim(self.num_col_blocks) + + # Get indices of grid where interlocks will be. + row_interlock_idxs = get_significant_idxs(grid_row_dim) + col_interlock_idxs = get_significant_idxs(grid_col_dim) + + # Create an empty grid. + grid = jnp.ones((grid_row_dim, grid_col_dim), dtype=jnp.int32) + + # Fill grid columns with block numbers + (grid, _), _ = jax.lax.scan( + f=self._fill_grid_columns, + init=(grid, 1), + xs=col_interlock_idxs, + ) + + # Fill grid rows with block numbers + (grid, _, _), _ = jax.lax.scan( + f=self._fill_grid_rows, + init=( + grid, + self.num_col_blocks, + self.num_col_blocks, + ), + xs=row_interlock_idxs, + ) + + # Create block interlocks at relevant rows and columns. + (grid, key), _ = jax.lax.scan( + f=self._select_col_interlocks, init=(grid, key), xs=col_interlock_idxs + ) + + (solved_grid, key), _ = jax.lax.scan( + f=self._select_row_interlocks, init=(grid, key), xs=row_interlock_idxs + ) + + # Extract blocks from the filled grid + _, blocks = jax.lax.scan( + f=self._extract_block, + init=(solved_grid, key), + xs=jnp.arange(1, num_blocks + 1), + ) + + # Finally shuffle the blocks along the leading dimension to + # untangle a block's number from its position in the blocks array. + key, shuffle_blocks_key = jax.random.split(key) + blocks = jax.random.permutation( + key=shuffle_blocks_key, x=blocks, axis=0, independent=False + ) + + return State( + blocks=blocks, + num_blocks=num_blocks, + action_mask=jnp.ones( + (num_blocks, 4, grid_row_dim - 2, grid_col_dim - 2), dtype=bool + ), + grid=jnp.zeros_like(solved_grid), + step_count=0, + key=key, + placed_blocks=jnp.zeros(num_blocks, dtype=bool), + ) + + +class ToyFlatPackGeneratorWithRotation(InstanceGenerator): + """Generates a deterministic toy FlatPack environment with 4 blocks. The blocks + are rotated by a random amount of degrees {0, 90, 180, 270} but not shuffled. + """ + + def __init__(self) -> None: + super().__init__(num_row_blocks=2, num_col_blocks=2) + + def __call__(self, key: chex.PRNGKey) -> State: + + del key + + solved_grid = jnp.array( + [ + [1, 1, 1, 2, 2], + [1, 1, 2, 2, 2], + [3, 1, 4, 4, 2], + [3, 3, 4, 4, 4], + [3, 3, 3, 4, 4], + ], + dtype=jnp.int32, + ) + + blocks = jnp.array( + [ + [[0, 1, 0], [0, 1, 1], [1, 1, 1]], + [[2, 0, 0], [2, 2, 2], [2, 2, 0]], + [[0, 0, 3], [0, 3, 3], [3, 3, 3]], + [[4, 4, 0], [4, 4, 4], [0, 4, 4]], + ], + dtype=jnp.int32, + ) + + return State( + blocks=blocks, + grid=jnp.zeros_like(solved_grid), + action_mask=jnp.ones((4, 4, 3, 3), dtype=bool), + num_blocks=jnp.int32(4), + key=jax.random.PRNGKey(0), + step_count=0, + placed_blocks=jnp.zeros(4, dtype=bool), + ) + + +class ToyFlatPackGeneratorNoRotation(InstanceGenerator): + """Generates a deterministic toy FlatPack environment with 4 blocks. The + blocks are not rotated and not shuffled. + """ + + def __init__(self) -> None: + super().__init__(num_row_blocks=2, num_col_blocks=2) + + def __call__(self, key: chex.PRNGKey) -> State: + + del key + + solved_grid = jnp.array( + [ + [1, 1, 1, 2, 2], + [1, 1, 2, 2, 2], + [3, 1, 4, 4, 2], + [3, 3, 4, 4, 4], + [3, 3, 3, 4, 4], + ], + dtype=jnp.int32, + ) + + blocks = jnp.array( + [ + [[1, 1, 1], [1, 1, 0], [0, 1, 0]], + [[0, 2, 2], [2, 2, 2], [0, 0, 2]], + [[3, 0, 0], [3, 3, 0], [3, 3, 3]], + [[4, 4, 0], [4, 4, 4], [0, 4, 4]], + ], + dtype=jnp.int32, + ) + + return State( + blocks=blocks, + num_blocks=jnp.int32(4), + key=jax.random.PRNGKey(0), + action_mask=jnp.ones((4, 4, 3, 3), dtype=bool), + grid=jnp.zeros_like(solved_grid), + step_count=0, + placed_blocks=jnp.zeros(4, dtype=bool), + ) diff --git a/jumanji/environments/packing/flat_pack/generator_test.py b/jumanji/environments/packing/flat_pack/generator_test.py new file mode 100644 index 000000000..5047c0eae --- /dev/null +++ b/jumanji/environments/packing/flat_pack/generator_test.py @@ -0,0 +1,251 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import chex +import jax +import jax.numpy as jnp +import pytest + +from jumanji.environments.packing.flat_pack.generator import RandomFlatPackGenerator + + +@pytest.fixture +def random_flat_pack_generator() -> RandomFlatPackGenerator: + """Creates a generator with two row blocks and two column blocks.""" + return RandomFlatPackGenerator( + num_col_blocks=2, + num_row_blocks=2, + ) + + +@pytest.fixture +def grid_only_ones() -> chex.Array: + """A grid with only ones.""" + return jnp.ones((5, 5)) + + +@pytest.fixture +def grid_columns_partially_filled() -> chex.Array: + """A grid after one iteration of _fill_grid_columns.""" + # fmt: off + return jnp.array( + [ + [1.0, 1.0, 2.0, 2.0, 2.0], + [1.0, 1.0, 2.0, 2.0, 2.0], + [1.0, 1.0, 2.0, 2.0, 2.0], + [1.0, 1.0, 2.0, 2.0, 2.0], + [1.0, 1.0, 2.0, 2.0, 2.0], + ] + ) + # fmt: on + + +@pytest.fixture +def grid_rows_partially_filled() -> chex.Array: + """A grid after one iteration of _fill_grid_rows.""" + # fmt: off + return jnp.array( + [ + [1.0, 1.0, 2.0, 2.0, 2.0], + [1.0, 1.0, 2.0, 2.0, 2.0], + [3.0, 3.0, 4.0, 4.0, 4.0], + [3.0, 3.0, 4.0, 4.0, 4.0], + [3.0, 3.0, 4.0, 4.0, 4.0], + ] + ) + # fmt: on + + +def test_random_flat_pack_generator__call( + random_flat_pack_generator: RandomFlatPackGenerator, key: chex.PRNGKey +) -> None: + """Test that generator generates a valid state.""" + state = random_flat_pack_generator(key) + assert state.num_blocks == 4 + assert state.blocks.shape == (4, 3, 3) + assert all(state.blocks[i].shape == (3, 3) for i in range(4)) + assert state.action_mask.shape == (4, 4, 3, 3) + assert state.step_count == 0 + + +def test_random_flat_pack_generator__no_retrace( + random_flat_pack_generator: RandomFlatPackGenerator, key: chex.PRNGKey +) -> None: + """Checks that generator call method is only traced once when jitted.""" + keys = jax.random.split(key, 2) + jitted_generator = jax.jit( + chex.assert_max_traces((random_flat_pack_generator.__call__), n=1) + ) + + for key in keys: + jitted_generator(key) + + +def test_random_flat_pack_generator__fill_grid_columns( + random_flat_pack_generator: RandomFlatPackGenerator, + grid_only_ones: chex.Array, + grid_columns_partially_filled: chex.Array, +) -> None: + """Checks that _fill_grid_columns method does a single + step correctly. + """ + + (grid, fill_value), arr_value = random_flat_pack_generator._fill_grid_columns( + (grid_only_ones, 1), 2 + ) + + assert grid.shape == (5, 5) + assert jnp.array_equal(grid, grid_columns_partially_filled) + assert fill_value == 2 + assert arr_value == 2 + + +def test_random_flat_pack_generator__fill_grid_rows( + random_flat_pack_generator: RandomFlatPackGenerator, + grid_columns_partially_filled: chex.Array, + grid_rows_partially_filled: chex.Array, +) -> None: + """Checks that _fill_grid_columns method does a single + step correctly. + """ + + ( + grid, + sum_value, + num_col_blocks, + ), arr_value = random_flat_pack_generator._fill_grid_rows( + (grid_columns_partially_filled, 2, 2), 2 + ) + + assert grid.shape == (5, 5) + assert jnp.array_equal(grid, grid_rows_partially_filled) + assert sum_value == 4 + assert num_col_blocks == 2 + assert arr_value == 2 + + +def test_random_flat_pack_generator__select_sides( + random_flat_pack_generator: RandomFlatPackGenerator, key: chex.PRNGKey +) -> None: + """Checks that _select_sides method correctly assigns the + middle value in an array with shape (3,) to either the value + at index 0 or 2. + """ + + side_chosen_array = random_flat_pack_generator._select_sides( + jnp.array([1.0, 2.0, 3.0]), key + ) + + assert side_chosen_array.shape == (3,) + # check that the output is different from the input + assert jnp.not_equal(jnp.array([1.0, 2.0, 3.0]), side_chosen_array).any() + + +def test_random_flat_pack_generator__select_col_interlocks( + random_flat_pack_generator: RandomFlatPackGenerator, + grid_rows_partially_filled: chex.Array, + key: chex.PRNGKey, +) -> None: + """Checks that interlocks are created along a given column of the grid.""" + + ( + grid_with_interlocks_selected, + new_key, + ), column = random_flat_pack_generator._select_col_interlocks( + (grid_rows_partially_filled, key), 2 + ) + + assert grid_with_interlocks_selected.shape == (5, 5) + assert jnp.not_equal(key, new_key).all() + assert column == 2 + + selected_col_interlocks = grid_with_interlocks_selected[:, 2] + before_selected_interlocks_col = grid_rows_partially_filled[:, 2] + + # check that the interlocks are different from the column before + assert jnp.not_equal(selected_col_interlocks, before_selected_interlocks_col).any() + + +def test_random_flat_pack_generator__select_row_interlocks( + random_flat_pack_generator: RandomFlatPackGenerator, + grid_rows_partially_filled: chex.Array, + key: chex.PRNGKey, +) -> None: + """Checks that interlocks are created along a given row of the grid.""" + + ( + grid_with_interlocks_selected, + new_key, + ), row = random_flat_pack_generator._select_row_interlocks( + (grid_rows_partially_filled, key), 2 + ) + + assert grid_with_interlocks_selected.shape == (5, 5) + assert jnp.not_equal(key, new_key).all() + assert row == 2 + + selected_row_interlocks = grid_with_interlocks_selected[2, :] + before_selected_interlocks_row = grid_rows_partially_filled[2, :] + + # check that the interlocks are different from the row before + assert jnp.not_equal(selected_row_interlocks, before_selected_interlocks_row).any() + + +def test_random_flat_pack_generator__first_nonzero( + random_flat_pack_generator: RandomFlatPackGenerator, + block_one_placed_at_1_1: chex.Array, +) -> None: + """Checks that the indices of the first non-zero value in a grid is found correctly.""" + + first_nonzero_row = random_flat_pack_generator._first_nonzero( + block_one_placed_at_1_1, 0 + ) + first_nonzero_col = random_flat_pack_generator._first_nonzero( + block_one_placed_at_1_1, 1 + ) + + assert first_nonzero_row == 1 + assert first_nonzero_col == 1 + + +def test_random_flat_pack_generator__crop_nonzero( + random_flat_pack_generator: RandomFlatPackGenerator, + block_one_placed_at_1_1: chex.Array, +) -> None: + """Checks a block is correctly extracted from a grid of zeros.""" + + cropped_block = random_flat_pack_generator._crop_nonzero(block_one_placed_at_1_1) + + assert cropped_block.shape == (3, 3) + assert jnp.array_equal( + cropped_block, jnp.array([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0], [0.0, 1.0, 0.0]]) + ) + + +def test_random_flat_pack_generator__extract_block( + random_flat_pack_generator: RandomFlatPackGenerator, + solved_grid: chex.Array, + key: chex.PRNGKey, +) -> None: + """Checks that a block is correctly extracted from a solved grid.""" + + # extract block number 3 + (_, new_key), block = random_flat_pack_generator._extract_block( + (solved_grid, key), 3 + ) + + assert block.shape == (3, 3) + assert jnp.not_equal(key, new_key).all() + # check that the block only contains 3s or 0s + assert jnp.isin(block, jnp.array([0.0, 3.0])).all() diff --git a/jumanji/environments/packing/flat_pack/reward.py b/jumanji/environments/packing/flat_pack/reward.py new file mode 100644 index 000000000..74ac1166e --- /dev/null +++ b/jumanji/environments/packing/flat_pack/reward.py @@ -0,0 +1,111 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc + +import chex +import jax +import jax.numpy as jnp + +from jumanji.environments.packing.flat_pack.types import State + + +class RewardFn(abc.ABC): + @abc.abstractmethod + def __call__( + self, + state: State, + placed_block: chex.Numeric, + next_state: State, + is_valid: bool, + is_done: bool, + ) -> chex.Numeric: + """Compute the reward based on the current state, the chosen action, + whether the action is valid and whether the episode is terminated. + """ + + +class CellDenseReward(RewardFn): + """Reward function for the dense reward setting. + + This reward returns the number of non-zero cells in a placed block normalised + by the total number of cells in the grid. This means that the maximum possible + episode return is 1. That is to say that, in the case of this reward, an agent + will optimise for maximal area coverage in the the grid. + """ + + def __call__( + self, + state: State, + placed_block: chex.Numeric, + next_state: State, + is_valid: bool, + is_done: bool, + ) -> chex.Numeric: + """Compute the reward based on the current state, the chosen action, + whether the action is valid and whether the episode is terminated. + """ + + del is_done + del next_state + del state + + num_rows, num_cols = placed_block.shape + + reward = jax.lax.cond( + is_valid, + lambda: jnp.sum(placed_block != 0.0, dtype=jnp.float32) + / (num_rows * num_cols), + lambda: jnp.float32(0.0), + ) + + return reward + + +class BlockDenseReward(RewardFn): + """Reward function for the dense reward setting. + + This reward will give a normalised reward for each block placed on the grid + with each block being equally weighted. This implies that each placed block + will have a reward of `1 / num_blocks` and the maximum possible episode return + is 1. That is to say that, in the case of this reward, an agent will optimise + for placing as many blocks as possible on the grid. + """ + + def __call__( + self, + state: State, + placed_block: chex.Numeric, + next_state: State, + is_valid: bool, + is_done: bool, + ) -> chex.Numeric: + """Compute the reward based on the current state, the chosen action, + whether the action is valid and whether the episode is terminated. + """ + + del is_done + del next_state + del placed_block + + num_blocks = state.num_blocks + del state + + reward = jax.lax.cond( + is_valid, + lambda: jnp.float32(1.0 / num_blocks), + lambda: jnp.float32(0.0), + ) + + return reward diff --git a/jumanji/environments/packing/flat_pack/reward_test.py b/jumanji/environments/packing/flat_pack/reward_test.py new file mode 100644 index 000000000..dc846a8b7 --- /dev/null +++ b/jumanji/environments/packing/flat_pack/reward_test.py @@ -0,0 +1,259 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import chex +import jax +import jax.numpy as jnp +import pytest + +from jumanji.environments.packing.flat_pack.reward import ( + BlockDenseReward, + CellDenseReward, +) +from jumanji.environments.packing.flat_pack.types import State + + +@pytest.fixture +def blocks() -> chex.Array: + """An array containing 4 blocks.""" + + return jnp.array( + [ + [[1.0, 1.0, 1.0], [1.0, 1.0, 0.0], [0.0, 1.0, 0.0]], + [[0.0, 2.0, 2.0], [2.0, 2.0, 2.0], [0.0, 0.0, 2.0]], + [[3.0, 0.0, 0.0], [3.0, 3.0, 0.0], [3.0, 3.0, 3.0]], + [[4.0, 4.0, 0.0], [4.0, 4.0, 4.0], [0.0, 4.0, 4.0]], + ], + dtype=jnp.float32, + ) + + +@pytest.fixture() +def state_with_no_blocks_placed( + solved_grid: chex.Array, key: chex.PRNGKey, blocks: chex.Array +) -> State: + """A grid state with no blocks placed.""" + + return State( + num_blocks=4, + blocks=blocks, + action_mask=jnp.ones((4, 4, 2, 2), dtype=bool), + placed_blocks=jnp.zeros(4, dtype=bool), + grid=jnp.zeros_like(solved_grid), + step_count=0, + key=key, + ) + + +@pytest.fixture() +def state_with_block_one_placed( + action_mask_with_block_1_placed: chex.Array, + grid_with_block_one_placed: chex.Array, + blocks: chex.Array, + key: chex.PRNGKey, +) -> State: + """A grid state with block one placed.""" + + key, new_key = jax.random.split(key) + return State( + num_blocks=4, + action_mask=action_mask_with_block_1_placed, + placed_blocks=jnp.array( + [ + True, + False, + False, + False, + ] + ), + grid=grid_with_block_one_placed, + step_count=0, + key=new_key, + blocks=blocks, + ) + + +@pytest.fixture() +def state_needing_only_block_one( + action_mask_without_only_block_1_placed: chex.Array, + solved_grid: chex.Array, + grid_with_block_one_placed: chex.Array, + blocks: chex.Array, + key: chex.PRNGKey, +) -> State: + """A grid state that one needs block one to be fully completed.""" + + key, new_key = jax.random.split(key) + + grid = solved_grid - grid_with_block_one_placed + + return State( + num_blocks=4, + action_mask=action_mask_without_only_block_1_placed, + placed_blocks=jnp.array( + [ + True, + False, + False, + False, + ] + ), + grid=grid, + step_count=3, + blocks=blocks, + key=new_key, + ) + + +@pytest.fixture() +def solved_state( + solved_grid: chex.Array, + blocks: chex.Array, + key: chex.PRNGKey, +) -> State: + """A solved grid state.""" + + key, new_key = jax.random.split(key) + + return State( + num_blocks=4, + action_mask=jnp.zeros((4, 4, 2, 2), dtype=bool), + placed_blocks=jnp.array( + [ + True, + True, + True, + True, + ] + ), + grid=solved_grid, + step_count=4, + blocks=blocks, + key=new_key, + ) + + +@pytest.fixture() +def block_one_placed_at_2_2(grid_with_block_one_placed: chex.Array) -> chex.Array: + """A 2D array of zeros where block one has been placed with it left top-most + corner at position (2, 2). + """ + + # Shift all elements in the array two down and two to the right + placed_block = jnp.roll(grid_with_block_one_placed, shift=2, axis=0) + placed_block = jnp.roll(placed_block, shift=2, axis=1) + + return placed_block + + +def test_cell_dense_reward( + state_with_no_blocks_placed: State, + state_with_block_one_placed: State, + block_one_placed_at_0_0: chex.Array, + block_one_placed_at_1_1: chex.Array, + block_one_placed_at_2_2: chex.Array, +) -> None: + + dense_reward = jax.jit(CellDenseReward()) + + # Test placing block one completely correctly + reward = dense_reward( + state=state_with_no_blocks_placed, + placed_block=block_one_placed_at_0_0, + is_valid=True, + is_done=False, + next_state=state_with_block_one_placed, + ) + assert reward == 6.0 / 25.0 + + # Test placing block one partially correct + reward = dense_reward( + state=state_with_no_blocks_placed, + placed_block=block_one_placed_at_1_1, + is_valid=True, + is_done=False, + next_state=state_with_block_one_placed, + ) + assert reward == 6.0 / 25.0 + + # Test placing a completely incorrect block + reward = dense_reward( + state=state_with_no_blocks_placed, + placed_block=block_one_placed_at_2_2, + is_valid=True, + is_done=False, + next_state=state_with_block_one_placed, + ) + assert reward == 6.0 / 25.0 + + # Test invalid action returns 0 reward. + reward = dense_reward( + state=state_with_no_blocks_placed, + placed_block=block_one_placed_at_0_0, + is_valid=False, + is_done=False, + next_state=state_with_block_one_placed, + ) + assert reward == 0.0 + + +def test_block_dense_reward( + state_with_no_blocks_placed: State, + state_with_block_one_placed: State, + block_one_placed_at_0_0: chex.Array, + block_one_placed_at_1_1: chex.Array, + block_one_placed_at_2_2: chex.Array, +) -> None: + + dense_reward = jax.jit(BlockDenseReward()) + + # Test placing block one completely correctly + reward = dense_reward( + state=state_with_no_blocks_placed, + placed_block=block_one_placed_at_0_0, + is_valid=True, + is_done=False, + next_state=state_with_block_one_placed, + ) + assert reward == 1.0 / 4.0 + + # Test placing block one partially correct + reward = dense_reward( + state=state_with_no_blocks_placed, + placed_block=block_one_placed_at_1_1, + is_valid=True, + is_done=False, + next_state=state_with_block_one_placed, + ) + assert reward == 1.0 / 4.0 + + # Test placing a completely incorrect block + reward = dense_reward( + state=state_with_no_blocks_placed, + placed_block=block_one_placed_at_2_2, + is_valid=True, + is_done=False, + next_state=state_with_block_one_placed, + ) + assert reward == 1.0 / 4.0 + + # Test invalid action returns 0 reward. + reward = dense_reward( + state=state_with_no_blocks_placed, + placed_block=block_one_placed_at_0_0, + is_valid=False, + is_done=False, + next_state=state_with_block_one_placed, + ) + assert reward == 0.0 diff --git a/jumanji/environments/packing/flat_pack/types.py b/jumanji/environments/packing/flat_pack/types.py new file mode 100644 index 000000000..e41a1b0a8 --- /dev/null +++ b/jumanji/environments/packing/flat_pack/types.py @@ -0,0 +1,61 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, NamedTuple + +import chex + +if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239 + from dataclasses import dataclass +else: + from chex import dataclass + + +class Observation(NamedTuple): + """ + grid: 2D array with the current state of grid. + blocks: 3D array with the blocks to be placed on the board. Here each block is a + 2D array with shape (3, 3). + action_mask: 4D array showing where blocks can be placed on the grid. + this mask includes all possible rotations and possible placement locations + for each block on the grid. + """ + + grid: chex.Array # (num_rows, num_cols) + blocks: chex.Array # (num_blocks, 3, 3) + action_mask: chex.Array # (num_blocks, num_rotations, num_rows-3, num_cols-3) + + +@dataclass +class State: + """ + grid: 2D array with the current state of grid. + num_blocks: number of blocks in the full grid. + blocks: 3D array with the blocks to be placed on the board. Here each block is a + 2D array with shape (3, 3). + action_mask: 4D array showing where blocks can be placed on the grid. + this mask includes all possible rotations and possible placement locations + for each block on the grid. + placed_blocks: 1D boolean array showing which blocks have been placed on the board. + step_count: number of steps taken in the environment. + key: random key used for board generation. + """ + + grid: chex.Array # (num_rows, num_cols) + num_blocks: chex.Numeric # () + blocks: chex.Array # (num_blocks, 3, 3) + action_mask: chex.Array # (num_blocks, num_rotations, num_rows-3, num_cols-3) + placed_blocks: chex.Array # (num_blocks,) + step_count: chex.Numeric # () + key: chex.PRNGKey # (2,) diff --git a/jumanji/environments/packing/flat_pack/utils.py b/jumanji/environments/packing/flat_pack/utils.py new file mode 100644 index 000000000..1611aadf8 --- /dev/null +++ b/jumanji/environments/packing/flat_pack/utils.py @@ -0,0 +1,58 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A general utils file for the flat_pack environment.""" + +import chex +import jax +import jax.numpy as jnp + + +def compute_grid_dim(num_blocks: int) -> int: + """Computes the grid dimension given the number of blocks. + + Args: + num_blocks: The number of blocks. + """ + return 3 * num_blocks - (num_blocks - 1) + + +def get_significant_idxs(grid_dim: int) -> chex.Array: + """Returns the indices of the grid that are significant. These will be used + to create interlocks between adjacent blocks. + + Args: + grid_dim: The dimension of the grid. + """ + return jnp.arange(grid_dim)[:: 3 - 1][1:-1] + + +def rotate_block(block: chex.Array, rotation_value: int) -> chex.Array: + """Rotates a block by {0, 90, 180, 270} degrees. + + Args: + block: The block to rotate. + rotation: The number of rotations to rotate the block by. + """ + rotated_block = jax.lax.switch( + index=rotation_value, + branches=( + lambda arr: arr, + lambda arr: jnp.flip(jnp.transpose(arr), axis=1), + lambda arr: jnp.flip(jnp.flip(arr, axis=0), axis=1), + lambda arr: jnp.flip(jnp.transpose(arr), axis=0), + ), + operand=block, + ) + + return rotated_block diff --git a/jumanji/environments/packing/flat_pack/utils_test.py b/jumanji/environments/packing/flat_pack/utils_test.py new file mode 100644 index 000000000..1981a3010 --- /dev/null +++ b/jumanji/environments/packing/flat_pack/utils_test.py @@ -0,0 +1,92 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import chex +import jax.numpy as jnp +import pytest + +from jumanji.environments.packing.flat_pack.utils import ( + compute_grid_dim, + get_significant_idxs, + rotate_block, +) + + +@pytest.mark.parametrize( + "num_blocks, expected_grid_dim", + [ + (1, 3), + (2, 5), + (3, 7), + (4, 9), + (5, 11), + ], +) +def test_compute_grid_dim(num_blocks: int, expected_grid_dim: int) -> None: + """Test that grid dimension is correctly computed given a number of blocks.""" + assert compute_grid_dim(num_blocks) == expected_grid_dim + + +@pytest.mark.parametrize( + "grid_dim, expected_idxs", + [ + (5, jnp.array([2])), + (7, jnp.array([2, 4])), + (9, jnp.array([2, 4, 6])), + (11, jnp.array([2, 4, 6, 8])), + ], +) +def test_get_significant_idxs(grid_dim: int, expected_idxs: chex.Array) -> None: + """Test that significant indices are correctly computed given a grid dimension.""" + assert jnp.all(get_significant_idxs(grid_dim) == expected_idxs) + + +def test_rotate_block(block: chex.Array) -> None: + + # Test with no rotation. + rotated_block = rotate_block(block, 0) + assert jnp.array_equal(rotated_block, block) + + # Test 90 degree rotation. + expected_rotated_block = jnp.array( + [ + [0.0, 0.0, 0.0], + [0.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + ] + ) + rotated_block = rotate_block(block, 1) + assert jnp.array_equal(rotated_block, expected_rotated_block) + + # Test 180 degree rotation. + expected_rotated_block = jnp.array( + [ + [1.0, 0.0, 0.0], + [1.0, 1.0, 0.0], + [1.0, 1.0, 0.0], + ] + ) + rotated_block = rotate_block(block, 2) + assert jnp.array_equal(rotated_block, expected_rotated_block) + + # Test 270 degree rotation. + expected_rotated_block = jnp.array( + [ + [1.0, 1.0, 1.0], + [1.0, 1.0, 0.0], + [0.0, 0.0, 0.0], + ] + ) + rotated_block = rotate_block(block, 3) + assert jnp.array_equal(rotated_block, expected_rotated_block) diff --git a/jumanji/environments/packing/flat_pack/viewer.py b/jumanji/environments/packing/flat_pack/viewer.py new file mode 100644 index 000000000..639435dd1 --- /dev/null +++ b/jumanji/environments/packing/flat_pack/viewer.py @@ -0,0 +1,190 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, Optional, Sequence, Tuple + +import chex +import matplotlib.animation +import matplotlib.cm +import matplotlib.pyplot as plt +import numpy as np +from numpy.typing import NDArray + +import jumanji.environments +from jumanji.environments.packing.flat_pack.types import State +from jumanji.viewer import Viewer + + +class FlatPackViewer(Viewer): + FIGURE_SIZE = (10, 10) + + def __init__(self, name: str, num_blocks: int, render_mode: str = "human") -> None: + """Viewer for a `FlatPack` environment. + + Args: + name: the window name to be used when initialising the window. + num_blocks: number of blocks in the environment. + render_mode: return a numpy array frame representing the environment. + """ + self._name = name + + # Pick display method + self._display: Callable[[plt.Figure], Optional[NDArray]] + if render_mode == "rgb_array": + self._display = self._display_rgb_array + elif render_mode == "human": + self._display = self._display_human + else: + raise ValueError(f"Invalid render mode: {render_mode}") + + # Create a color for each block. + colormap_indices = np.arange(0, 1, 1 / num_blocks) + colormap = matplotlib.cm.get_cmap("hsv", num_blocks + 1) + + self.colors = [(1.0, 1.0, 1.0, 1.0)] # Empty grid colour should be white. + for colormap_idx in colormap_indices: + # Give the blocks an alpha of 0.7. + r, g, b, _ = colormap(colormap_idx) + self.colors.append((r, g, b, 0.7)) + + # The animation must be stored in a variable that lives as long as the + # animation should run. Otherwise, the animation will get garbage-collected. + self._animation: Optional[matplotlib.animation.Animation] = None + + def render(self, state: State) -> Optional[NDArray]: + """Render a FlatPack environment state. + + Args: + state: the flat_pack environment state to be rendered. + + Returns: + RGB array if the render_mode is RenderMode.RGB_ARRAY. + """ + self._clear_display() + fig, ax = self._get_fig_ax() + ax.clear() + self._add_grid_image(state.grid, ax) + return self._display(fig) + + def animate( + self, + states: Sequence[State], + interval: int = 200, + save_path: Optional[str] = None, + ) -> matplotlib.animation.FuncAnimation: + """Create an animation from a sequence of FlatPack states. + + Args: + states: sequence of FlatPack states corresponding to consecutive timesteps. + interval: delay between frames in milliseconds, default to 200. + save_path: the path where the animation file should be saved. If it is None, the plot + will not be saved. + + Returns: + Animation that can be saved as a GIF, MP4, or rendered with HTML. + """ + fig, ax = plt.subplots( + num=f"{self._name}Animation", figsize=FlatPackViewer.FIGURE_SIZE + ) + plt.close(fig) + + def make_frame(state_index: int) -> None: + ax.clear() + state = states[state_index] + self._add_grid_image(state.grid, ax) + + # Create the animation object. + self._animation = matplotlib.animation.FuncAnimation( + fig, + make_frame, + frames=len(states), + interval=interval, + ) + + # Save the animation as a gif. + if save_path: + self._animation.save(save_path) + + return self._animation + + def close(self) -> None: + plt.close(self._name) + + def _display_human(self, fig: plt.Figure) -> None: + if plt.isinteractive(): + # Required to update render when using Jupyter Notebook. + fig.canvas.draw() + if jumanji.environments.is_notebook(): + plt.show(self._name) + else: + # Required to update render when not using Jupyter Notebook. + fig.canvas.draw_idle() + fig.canvas.flush_events() + + def _display_rgb_array(self, fig: plt.Figure) -> NDArray: + fig.canvas.draw() + return np.asarray(fig.canvas.buffer_rgba()) + + def _clear_display(self) -> None: + if jumanji.environments.is_notebook(): + import IPython.display + + IPython.display.clear_output(True) + + def _get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]: + recreate = not plt.fignum_exists(self._name) + fig = plt.figure(self._name, FlatPackViewer.FIGURE_SIZE) + if recreate: + if not plt.isinteractive(): + fig.show() + ax = fig.add_subplot() + else: + ax = fig.get_axes()[0] + return fig, ax + + def _add_grid_image(self, grid: chex.Array, ax: plt.Axes) -> None: + self._draw_grid(grid, ax) + ax.set_axis_off() + ax.set_aspect(1) + ax.relim() + ax.autoscale_view() + + def _draw_grid(self, grid: chex.Array, ax: plt.Axes) -> None: + # Flip the grid upside down to match the coordinate system of matplotlib. + grid = np.flipud(grid) + rows, cols = grid.shape + + for row in range(rows): + for col in range(cols): + self._draw_grid_cell(grid[row, col], row, col, ax) + + def _draw_grid_cell( + self, cell_value: int, row: int, col: int, ax: plt.Axes + ) -> None: + cell = plt.Rectangle((col, row), 1, 1, **self._get_cell_attributes(cell_value)) + ax.add_patch(cell) + if cell_value != 0: + ax.text( + col + 0.5, + row + 0.5, + str(int(cell_value)), + color="#606060", + ha="center", + va="center", + fontsize="xx-large", + ) + + def _get_cell_attributes(self, cell_value: int) -> Dict[str, Any]: + color = self.colors[int(cell_value)] + return {"facecolor": color, "edgecolor": "black", "linewidth": 1} diff --git a/jumanji/training/configs/config.yaml b/jumanji/training/configs/config.yaml index 458d60b07..6ad8f62fc 100644 --- a/jumanji/training/configs/config.yaml +++ b/jumanji/training/configs/config.yaml @@ -1,6 +1,6 @@ defaults: - _self_ - - env: snake # [bin_pack, cleaner, connector, cvrp, game_2048, graph_coloring, job_shop, knapsack, maze, minesweeper, mmst, multi_cvrp, pac_man, robot_warehouse, rubiks_cube, snake, sokoban, sudoku, tetris, tsp] + - env: snake # [bin_pack, cleaner, connector, cvrp, flat_pack, game_2048, graph_coloring, job_shop, knapsack, maze, minesweeper, mmst, multi_cvrp, pac_man, robot_warehouse, rubiks_cube, snake, sokoban, sudoku, tetris, tsp] agent: random # [random, a2c] diff --git a/jumanji/training/configs/env/flat_pack.yaml b/jumanji/training/configs/env/flat_pack.yaml new file mode 100644 index 000000000..ca3a61519 --- /dev/null +++ b/jumanji/training/configs/env/flat_pack.yaml @@ -0,0 +1,28 @@ +name: flat_pack +registered_version: FlatPack-v0 + +network: + num_transformer_layers: 2 + transformer_num_heads: 8 + transformer_key_size: 16 + transformer_mlp_units: [512] + hidden_size: 8 + +training: + num_epochs: 1000 + num_learner_steps_per_epoch: 100 + n_steps: 20 + total_batch_size: 64 + +evaluation: + eval_total_batch_size: 5000 + greedy_eval_total_batch_size: 5000 + +a2c: + normalize_advantage: False + discount_factor: 0.99 + bootstrapping_factor: 0.95 + l_pg: 1.0 + l_td: 1.0 + l_en: 0.01 + learning_rate: 2e-4 diff --git a/jumanji/training/networks/__init__.py b/jumanji/training/networks/__init__.py index 82ad0ae65..956d8dbb3 100644 --- a/jumanji/training/networks/__init__.py +++ b/jumanji/training/networks/__init__.py @@ -30,6 +30,10 @@ from jumanji.training.networks.connector.random import make_random_policy_connector from jumanji.training.networks.cvrp.actor_critic import make_actor_critic_networks_cvrp from jumanji.training.networks.cvrp.random import make_random_policy_cvrp +from jumanji.training.networks.flat_pack.actor_critic import ( + make_actor_critic_networks_flat_pack, +) +from jumanji.training.networks.flat_pack.random import make_random_policy_flat_pack from jumanji.training.networks.game_2048.actor_critic import ( make_actor_critic_networks_game_2048, ) diff --git a/jumanji/training/networks/flat_pack/__init__.py b/jumanji/training/networks/flat_pack/__init__.py new file mode 100644 index 000000000..21db9ec1c --- /dev/null +++ b/jumanji/training/networks/flat_pack/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jumanji/training/networks/flat_pack/actor_critic.py b/jumanji/training/networks/flat_pack/actor_critic.py new file mode 100644 index 000000000..5c6923b4d --- /dev/null +++ b/jumanji/training/networks/flat_pack/actor_critic.py @@ -0,0 +1,325 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Sequence, Tuple + +import chex +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np + +from jumanji.environments.packing.flat_pack import FlatPack, Observation +from jumanji.training.networks.actor_critic import ( + ActorCriticNetworks, + FeedForwardNetwork, +) +from jumanji.training.networks.parametric_distribution import ( + FactorisedActionSpaceParametricDistribution, +) +from jumanji.training.networks.transformer_block import TransformerBlock + + +def make_actor_critic_networks_flat_pack( + flat_pack: FlatPack, + num_transformer_layers: int, + transformer_num_heads: int, + transformer_key_size: int, + transformer_mlp_units: Sequence[int], + hidden_size: int, +) -> ActorCriticNetworks: + """Make actor-critic networks for the `FlatPack` environment.""" + num_values = np.asarray(flat_pack.action_spec().num_values) + parametric_action_distribution = FactorisedActionSpaceParametricDistribution( + action_spec_num_values=num_values + ) + num_blocks = flat_pack.num_blocks + policy_network = make_actor_network_flat_pack( + num_transformer_layers=num_transformer_layers, + transformer_num_heads=transformer_num_heads, + transformer_key_size=transformer_key_size, + transformer_mlp_units=transformer_mlp_units, + num_blocks=num_blocks, + hidden_size=hidden_size, + ) + value_network = make_critic_network_flat_pack( + num_transformer_layers=num_transformer_layers, + transformer_num_heads=transformer_num_heads, + transformer_key_size=transformer_key_size, + transformer_mlp_units=transformer_mlp_units, + num_blocks=num_blocks, + hidden_size=hidden_size, + ) + return ActorCriticNetworks( + policy_network=policy_network, + value_network=value_network, + parametric_action_distribution=parametric_action_distribution, + ) + + +class UNet(hk.Module): + """A simple module based on the UNet architecture. + Please note that all shapes assume an 11x11 grid observation to match the + default grid size of the FlatPack environment. + """ + + def __init__( + self, + model_size: int, + name: Optional[str] = None, + hidden_size: int = 8, + ) -> None: + super().__init__(name=name) + self.hidden_size = hidden_size + self.model_size = model_size + + def __call__(self, grid_observation: chex.Array) -> chex.Array: + # Grid observation is of shape (B, num_rows, num_cols) + + # Add a channel dimension + grid_observation = grid_observation[..., jnp.newaxis] + + # Down colvolve with strided convolutions + down_1 = hk.Conv2D(32, kernel_shape=3, stride=2, padding="SAME")( + grid_observation + ) + down_1 = jax.nn.relu(down_1) # (B, 6, 6, 32) + down_2 = hk.Conv2D(32, kernel_shape=3, stride=2, padding="SAME")(down_1) + down_2 = jax.nn.relu(down_2) # (B, 3, 3, 32) + + # Up convolve + up_1 = hk.Conv2DTranspose(32, kernel_shape=3, stride=2, padding="SAME")(down_2) + up_1 = jax.nn.relu(up_1) # (B, 6, 6, 32) + up_1 = jnp.concatenate([up_1, down_1], axis=-1) + up_2 = hk.Conv2DTranspose(32, kernel_shape=3, stride=2, padding="SAME")(up_1) + up_2 = jax.nn.relu(up_2) # (B, 12, 12, 32) + up_2 = up_2[:, :-1, :-1] + up_2 = jnp.concatenate( + [up_2, grid_observation], axis=-1 + ) # (B, num_rows, num_cols, 33) + + output = hk.Conv2D(self.hidden_size, kernel_shape=1, stride=1, padding="SAME")( + up_2 + ) + + # Crop the upconvolved output to be the same size as the action mask. + output = output[:, 1:-1, 1:-1] # (B, num_rows-2, num_cols-2, hidden_size) + + # Flatten down_2 to be (B, ...) + grid_conv_encoding = jnp.reshape( + down_2, + (down_2.shape[0], -1), + ) + + # Linear mapping to transformer model size. + grid_conv_encoding = hk.Linear(self.model_size)( + grid_conv_encoding + ) # (B, model_size) + + return grid_conv_encoding, output + + +class FlatPackTorso(hk.Module): + def __init__( + self, + num_transformer_layers: int, + transformer_num_heads: int, + transformer_key_size: int, + transformer_mlp_units: Sequence[int], + num_blocks: int, + hidden_size: int, + name: Optional[str] = None, + ) -> None: + super().__init__(name=name) + self.num_transformer_layers = num_transformer_layers + self.transformer_num_heads = transformer_num_heads + self.transformer_key_size = transformer_key_size + self.transformer_mlp_units = transformer_mlp_units + self.model_size = transformer_num_heads * transformer_key_size + self.num_blocks = num_blocks + self.hidden_size = hidden_size + + def __call__(self, observation: Observation) -> Tuple[chex.Array, chex.Array]: + # observation.blocks (B, num_blocks, 3, 3) + # observation.grid (B, num_rows, num_cols) + + # Flatten the blocks + flattened_blocks = jnp.reshape( + observation.blocks, (-1, self.num_blocks, 9) + ) # (B, num_blocks, 9) + + # Encode the blocks with an MLP + block_encoder = hk.nets.MLP(output_sizes=[self.model_size]) + blocks_embedding = jax.vmap(block_encoder)( + flattened_blocks + ) # (B, num_blocks, model_size) + + unet = UNet(hidden_size=self.hidden_size, model_size=self.model_size) + grid_conv_encoding, grid_encoding = unet( + observation.grid + ) # (B, model_size), (B, num_rows-2, num_cols-2, hidden_size) + + for block_id in range(self.num_transformer_layers): + + ( + self_attention_mask, # (B, 1, num_blocks, num_blocks) + cross_attention_mask, # (B, 1, num_blocks, 1) + ) = make_flatpack_masks(observation) + + self_attention = TransformerBlock( + num_heads=self.transformer_num_heads, + key_size=self.transformer_key_size, + mlp_units=self.transformer_mlp_units, + model_size=self.model_size, + w_init_scale=2 / self.num_transformer_layers, + name=f"self_attention_block_{block_id}", + ) + blocks_embedding = self_attention( + query=blocks_embedding, + key=blocks_embedding, + value=blocks_embedding, + mask=self_attention_mask, + ) + + cross_attention = TransformerBlock( + num_heads=self.transformer_num_heads, + key_size=self.transformer_key_size, + mlp_units=self.transformer_mlp_units, + model_size=self.model_size, + w_init_scale=2 / self.num_transformer_layers, + name=f"cross_attention_block_{block_id}", + ) + blocks_embedding = cross_attention( + query=blocks_embedding, + key=grid_conv_encoding, + value=grid_conv_encoding, + mask=cross_attention_mask, + ) + + # Map blocks embedding from (num_blocks, 128) to (num_blocks, num_rotations, hidden_size) + blocks_head = hk.nets.MLP(output_sizes=[4 * self.hidden_size]) + blocks_embedding = jax.vmap(blocks_head)(blocks_embedding) + blocks_embedding = jnp.reshape( + blocks_embedding, (-1, self.num_blocks, 4, self.hidden_size) + ) + + return blocks_embedding, grid_encoding + + +def make_actor_network_flat_pack( + num_transformer_layers: int, + transformer_num_heads: int, + transformer_key_size: int, + transformer_mlp_units: Sequence[int], + num_blocks: int, + hidden_size: int, +) -> FeedForwardNetwork: + def network_fn(observation: Observation) -> chex.Array: + torso = FlatPackTorso( + num_transformer_layers=num_transformer_layers, + transformer_num_heads=transformer_num_heads, + transformer_key_size=transformer_key_size, + transformer_mlp_units=transformer_mlp_units, + num_blocks=num_blocks, + hidden_size=hidden_size, + name="policy_torso", + ) + blocks_embedding, grid_embedding = torso(observation) + outer_product = jnp.einsum( + "...ijh,...klh->...ijkl", blocks_embedding, grid_embedding + ) + + logits = jnp.where( + observation.action_mask, outer_product, jnp.finfo(jnp.float32).min + ) + + logits = logits.reshape(*logits.shape[:-4], -1) + return logits + + init, apply = hk.without_apply_rng(hk.transform(network_fn)) + return FeedForwardNetwork(init=init, apply=apply) + + +def make_critic_network_flat_pack( + num_transformer_layers: int, + transformer_num_heads: int, + transformer_key_size: int, + transformer_mlp_units: Sequence[int], + num_blocks: int, + hidden_size: int, +) -> FeedForwardNetwork: + def network_fn(observation: Observation) -> chex.Array: + torso = FlatPackTorso( + num_transformer_layers=num_transformer_layers, + transformer_num_heads=transformer_num_heads, + transformer_key_size=transformer_key_size, + transformer_mlp_units=transformer_mlp_units, + num_blocks=num_blocks, + hidden_size=hidden_size, + name="critic_torso", + ) + + ( + blocks_embedding, # (B, num_blocks, 4, hidden_size) + grid_embedding, # (B, num_rows-2, num_cols-2, hidden_size) + ) = torso(observation) + + # Flatten the blocks embedding + blocks_embedding = jnp.reshape( + blocks_embedding, + (*blocks_embedding.shape[0:2], -1), + ) + + # Sum over blocks for permutation invariance + blocks_embedding = jnp.sum(blocks_embedding, axis=1) # (B, 4*hidden_size) + + # Flatten grid embedding while keeping batch dimension + grid_embedding = jnp.reshape( # (B, hidden_size * num_rows-2 * num_cols-2) + grid_embedding, + (grid_embedding.shape[0], -1), + ) + + grid_embedding = hk.Linear(blocks_embedding.shape[-1])(grid_embedding) + grid_embedding = jax.nn.relu(grid_embedding) + + # Concatenate along the second dimension + torso_output = jnp.concatenate([blocks_embedding, grid_embedding], axis=-1) + + value = hk.Linear(1)(torso_output) + + return jnp.squeeze(value, axis=-1) + + init, apply = hk.without_apply_rng(hk.transform(network_fn)) + return FeedForwardNetwork(init=init, apply=apply) + + +def make_flatpack_masks(observation: Observation) -> Tuple[chex.Array, chex.Array]: + """Return: + - self_attention_mask: mask of non-placed blocks. + - cross_attention_mask: action mask, i.e. blocks that can be placed. + """ + + mask = jnp.any(observation.action_mask, axis=(2, 3, 4)) + + # Replicate the mask on the query and key dimensions. + self_attention_mask = jnp.einsum("...i,...j->...ij", mask, mask) + # Expand on the head dimension. + self_attention_mask = jnp.expand_dims(self_attention_mask, axis=-3) + + # Expand on the query dimension. + cross_attention_mask = jnp.expand_dims(mask, axis=-2) + # Expand on the head dimension. + cross_attention_mask = jnp.expand_dims(cross_attention_mask, axis=-1) + + return self_attention_mask, cross_attention_mask diff --git a/jumanji/training/networks/flat_pack/random.py b/jumanji/training/networks/flat_pack/random.py new file mode 100644 index 000000000..a81ba43f0 --- /dev/null +++ b/jumanji/training/networks/flat_pack/random.py @@ -0,0 +1,28 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jumanji.environments.packing.flat_pack.env import FlatPack +from jumanji.training.networks.masked_categorical_random import ( + make_masked_categorical_random_ndim, +) +from jumanji.training.networks.protocols import RandomPolicy + + +def make_random_policy_flat_pack(flat_pack: FlatPack) -> RandomPolicy: + """Make random policy for FlatPack.""" + action_spec_num_values = flat_pack.action_spec().num_values + + return make_masked_categorical_random_ndim( + action_spec_num_values=action_spec_num_values + ) diff --git a/jumanji/training/setup_train.py b/jumanji/training/setup_train.py index e2d2b9890..a8b3ed6f1 100644 --- a/jumanji/training/setup_train.py +++ b/jumanji/training/setup_train.py @@ -29,6 +29,7 @@ BinPack, Cleaner, Connector, + FlatPack, Game2048, GraphColoring, JobShop, @@ -197,6 +198,11 @@ def _setup_random_policy( # noqa: CCR001 elif cfg.env.name == "graph_coloring": assert isinstance(env.unwrapped, GraphColoring) random_policy = networks.make_random_policy_graph_coloring() + elif cfg.env.name == "flat_pack": + assert isinstance(env.unwrapped, FlatPack) + random_policy = networks.make_random_policy_flat_pack( + flat_pack=env.unwrapped, + ) elif cfg.env.name == "pac_man": assert isinstance(env.unwrapped, PacMan) random_policy = networks.make_random_policy_pacman() @@ -245,6 +251,16 @@ def _setup_actor_critic_neworks( # noqa: CCR001 transformer_key_size=cfg.env.network.transformer_key_size, transformer_mlp_units=cfg.env.network.transformer_mlp_units, ) + elif cfg.env.name == "flat_pack": + assert isinstance(env.unwrapped, FlatPack) + actor_critic_networks = networks.make_actor_critic_networks_flat_pack( + flat_pack=env.unwrapped, + num_transformer_layers=cfg.env.network.num_transformer_layers, + transformer_num_heads=cfg.env.network.transformer_num_heads, + transformer_key_size=cfg.env.network.transformer_key_size, + transformer_mlp_units=cfg.env.network.transformer_mlp_units, + hidden_size=cfg.env.network.hidden_size, + ) elif cfg.env.name == "job_shop": assert isinstance(env.unwrapped, JobShop) actor_critic_networks = networks.make_actor_critic_networks_job_shop( diff --git a/mkdocs.yml b/mkdocs.yml index 39dbca1cd..fe048d0c4 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -25,6 +25,7 @@ nav: - Sudoku: environments/sudoku.md - Packing: - BinPack: environments/bin_pack.md + - FlatPack: environments/flat_pack.md - JobShop: environments/job_shop.md - Knapsack: environments/knapsack.md - Tetris: environments/tetris.md @@ -56,6 +57,7 @@ nav: - Sudoku: api/environments/sudoku.md - Packing: - BinPack: api/environments/bin_pack.md + - FlatPack: api/environments/flat_pack.md - JobShop: api/environments/job_shop.md - Knapsack: api/environments/knapsack.md - Tetris: api/environments/tetris.md