diff --git a/README.md b/README.md index 8365b1fe7..3e6a4a8a9 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,6 @@ | [**Docs**](https://instadeepai.github.io/jumanji) --- -
@@ -28,12 +27,11 @@ +
- - ## Welcome to the Jungle! 🌴 Jumanji is a suite of diverse and challenging reinforcement learning (RL) environments written in @@ -70,7 +68,6 @@ JAX-based environments. - 🏎️ **Training:** example agents that can be used as inspiration for the agents one may implement in their research. - ## Environments 🌍 Jumanji provides a diverse range of environments ranging from simple games to NP-hard combinatorial @@ -79,6 +76,7 @@ problems. | Environment | Category | Registered Version(s) | Source | Description | |------------------------------------------|----------|------------------------------------------------------|--------------------------------------------------------------------------------------------------|------------------------------------------------------------------------| | 🔢 Game2048 | Logic | `Game2048-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/game_2048/) | [doc](https://instadeepai.github.io/jumanji/environments/game_2048/) | +| 🔵🔗🟡🔗🔴 GraphColoring | Logic | `GraphColoring-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/graph_coloring/) | [doc](https://instadeepai.github.io/jumanji/environments/graph_coloring/) | | 💣 Minesweeper | Logic | `Minesweeper-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/minesweeper/) | [doc](https://instadeepai.github.io/jumanji/environments/minesweeper/) | | 🎲 RubiksCube | Logic | `RubiksCube-v0`+ +
+ +We provide here a JAX jit-able implementation of the Graph Coloring environment. + +Graph coloring is a combinatorial optimization problem where the objective is to assign a color to each vertex of a graph in such a way that no two adjacent vertices share the same color. The problem is usually formulated as minimizing the number of colors used. The `GraphColoring` environment is an episodic, single-agent setting that allows for the exploration of graph coloring algorithms and reinforcement learning methods. + +## Observation + +The observation in the `GraphColoring` environment includes information about the graph, the colors assigned to the vertices, the action mask, and the current node index. + +- `adj_matrix`: a JAX array (bool) of shape `(num_nodes, num_nodes)`, representing the adjacency matrix of the graph. + - For example, an adjacency matrix for a graph with 4 nodes: + + ```[[False, True, False, True], + [ True, False, True, False], + [False, True, False, True], + [ True, False, True, False]]``` + +- `colors`: a JAX array (int32) of shape `(num_nodes,)`, representing the current color assignments for the vertices. Initially, all elements are set to -1, indicating that no colors have been assigned yet. + - For example, an initial color assignment: + ```[-1, -1, -1, -1]``` + +- `action_mask`: a JAX array (bool) of shape `(num_colors,)`, where `num_colors` equals `num_nodes`, which indicates the valid actions in the current state of the environment. Each position in the array corresponds to a color. `True` at a position signifies that the corresponding color can be assigned to the current node, while `False` indicates that it cannot. + - For example, with 4 colors available: + ```[True, False, True, False]``` + +- `current_node_index`: an integer representing the current node being colored. + - For example, the initial `current_node_index` is 0. + +## Action + +The action space is a `DiscreteArray` of integer values in `[0, 1, ..., num_colors - 1]`. 
Each action corresponds to assigning a color to the current node. + +## Reward + +The reward in the `GraphColoring` environment is given as follows: + +- `sparse reward`: a reward is provided at the end of the episode and equals the negative of the number of unique colors used to color all vertices in the graph. + +The agent's goal is to find a valid coloring using as few colors as possible while avoiding conflicts with adjacent nodes. + +## Episode Termination + +An episode in the graph coloring environment can terminate under two conditions: + +1. All nodes have been assigned a color: the agent colors the nodes one at a time. When all nodes have a color assigned (i.e., there are no nodes with a color value of -1), the episode ends. This is the natural termination condition and ideally the one we'd like the agent to achieve. + +2. An invalid action is taken: an action is considered invalid if it assigns to the current node a color that is not in the allowed color set for that node at that time. The allowed color set is recomputed for the next node after every action. If an invalid action is taken, the episode terminates immediately and the agent receives a penalty equal to the negative of the number of nodes. This discourages the agent from taking invalid actions. + +## Registered Versions 📖 + +- `GraphColoring-v0`: The default settings for the `GraphColoring` problem with a configurable number of nodes and edge probability. The default number of nodes is 20, and the default edge probability is 0.8. diff --git a/docs/environments/robot_warehouse.md b/docs/environments/robot_warehouse.md new file mode 100644 index 000000000..799e8d442 --- /dev/null +++ b/docs/environments/robot_warehouse.md @@ -0,0 +1,46 @@ +# RobotWarehouse Environment + ++ +
+ +We provide a JAX jit-able implementation of the [Robotic Warehouse](https://github.com/semitable/robotic-warehouse/tree/master) +environment. + +The Robot Warehouse (RWARE) environment simulates a warehouse with robots moving and delivering requested goods. The simulator is inspired by real-world applications in which robots pick up shelves and deliver them to a workstation. Humans access the contents of a shelf, and robots can then return it to an empty shelf location. + +The goal is to successfully deliver as many requested shelves as possible within a given time budget. + +Once a shelf has been delivered, a new shelf is requested at random. Agents start each episode at random locations within the warehouse. + +## Observation + +The **observation** seen by the agent is a `NamedTuple` containing the following: + +- `agents_view`: jax array (int32) of shape `(num_agents, num_obs_features)`, array representing each agent's view of other agents + and shelves. + +- `action_mask`: jax array (bool) of shape `(num_agents, 5)`, array specifying, for each agent, + which actions (noop, forward, left, right, toggle_load) are legal. + +- `step_count`: jax array (int32) of shape `()`, number of steps elapsed in the current episode. + +## Action + +The action space is a `MultiDiscreteArray` containing an integer value in `[0, 1, 2, 3, 4]` for each +agent. Each agent can take one of five actions: noop (`0`), forward (`1`), turn left (`2`), turn right (`3`), or toggle_load (`4`). + +The episode terminates under the following conditions: + +- An invalid action is taken, or + +- An agent collides with another agent. + +## Reward + +The reward is global and shared among the agents. It is equal to the number of shelves that were +delivered successfully during the time step (i.e., +1 for each shelf). 
+ +## Registered Versions 📖 + +- `RobotWarehouse-v0`: a warehouse with 4 agents, each with a sensor range of 1, a floor with 2 shelf rows and 3 shelf columns, a column height of 8, and a shelf request queue of size 8. diff --git a/jumanji/__init__.py b/jumanji/__init__.py index 41651de2b..daba4f7fe 100644 --- a/jumanji/__init__.py +++ b/jumanji/__init__.py @@ -32,6 +32,10 @@ # Game2048 - the game of 2048 with the default board size of 4x4. register(id="Game2048-v1", entry_point="jumanji.environments:Game2048") +# GraphColoring - the graph coloring problem with a default graph of +# 20 nodes and an edge probability of 0.8. +register(id="GraphColoring-v0", entry_point="jumanji.environments:GraphColoring") + # Minesweeper on a board of size 10x10 with 10 mines. register(id="Minesweeper-v0", entry_point="jumanji.environments:Minesweeper") @@ -104,6 +108,10 @@ # Maze with 10 rows and 10 columns, a time limit of 100 and a random maze generator. register(id="Maze-v0", entry_point="jumanji.environments:Maze") +# RobotWarehouse with a random generator with 2 shelf rows, 3 shelf columns, a column height of 8, +# 4 agents, a sensor range of 1, and a request queue of size 8. +register(id="RobotWarehouse-v0", entry_point="jumanji.environments:RobotWarehouse") + # Snake game on a board of size 12x12 with a time limit of 4000. 
register(id="Snake-v1", entry_point="jumanji.environments:Snake") diff --git a/jumanji/environments/__init__.py b/jumanji/environments/__init__.py index bf72e6e42..031ad9c8c 100644 --- a/jumanji/environments/__init__.py +++ b/jumanji/environments/__init__.py @@ -16,6 +16,7 @@ from jumanji.environments.logic import game_2048, minesweeper, rubiks_cube from jumanji.environments.logic.game_2048.env import Game2048 +from jumanji.environments.logic.graph_coloring.env import GraphColoring from jumanji.environments.logic.minesweeper import Minesweeper from jumanji.environments.logic.rubiks_cube import RubiksCube from jumanji.environments.logic.sudoku import Sudoku @@ -23,11 +24,20 @@ from jumanji.environments.packing.bin_pack.env import BinPack from jumanji.environments.packing.job_shop.env import JobShop from jumanji.environments.packing.knapsack.env import Knapsack -from jumanji.environments.routing import cleaner, connector, cvrp, maze, snake, tsp +from jumanji.environments.routing import ( + cleaner, + connector, + cvrp, + maze, + robot_warehouse, + snake, + tsp, +) from jumanji.environments.routing.cleaner.env import Cleaner from jumanji.environments.routing.connector.env import Connector from jumanji.environments.routing.cvrp.env import CVRP from jumanji.environments.routing.maze.env import Maze +from jumanji.environments.routing.robot_warehouse.env import RobotWarehouse from jumanji.environments.routing.snake.env import Snake from jumanji.environments.routing.tsp.env import TSP diff --git a/jumanji/environments/logic/graph_coloring/__init__.py b/jumanji/environments/logic/graph_coloring/__init__.py new file mode 100644 index 000000000..5fbf60a42 --- /dev/null +++ b/jumanji/environments/logic/graph_coloring/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jumanji.environments.logic.graph_coloring.env import GraphColoring +from jumanji.environments.logic.graph_coloring.types import Observation, State diff --git a/jumanji/environments/logic/graph_coloring/conftest.py b/jumanji/environments/logic/graph_coloring/conftest.py new file mode 100644 index 000000000..1a669c224 --- /dev/null +++ b/jumanji/environments/logic/graph_coloring/conftest.py @@ -0,0 +1,23 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from jumanji.environments.logic.graph_coloring import GraphColoring + + +@pytest.fixture +def graph_coloring() -> GraphColoring: + """Instantiates a default GraphColoring environment.""" + return GraphColoring() diff --git a/jumanji/environments/logic/graph_coloring/env.py b/jumanji/environments/logic/graph_coloring/env.py new file mode 100644 index 000000000..36970d7da --- /dev/null +++ b/jumanji/environments/logic/graph_coloring/env.py @@ -0,0 +1,316 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Sequence, Tuple + +import chex +import jax +import jax.numpy as jnp +import matplotlib.animation as animation +from jax import lax +from numpy.typing import NDArray + +from jumanji import specs +from jumanji.env import Environment +from jumanji.environments.logic.graph_coloring.generator import ( + Generator, + RandomGenerator, +) +from jumanji.environments.logic.graph_coloring.types import Observation, State +from jumanji.environments.logic.graph_coloring.viewer import GraphColoringViewer +from jumanji.types import TimeStep, restart, termination, transition +from jumanji.viewer import Viewer + + +class GraphColoring(Environment[State]): + """Environment for the GraphColoring problem. + The problem is a combinatorial optimization task where the goal is + to assign a color to each vertex of a graph + in such a way that no two adjacent vertices share the same color. + The problem is usually formulated as minimizing the number of colors used. + + - observation: `Observation` + - adj_matrix: jax array (bool) of shape (num_nodes, num_nodes), + representing the adjacency matrix of the graph. + - colors: jax array (int32) of shape (num_nodes,), + representing the current color assignments for the vertices. + - action_mask: jax array (bool) of shape (num_colors,), + indicating which actions are valid in the current state of the environment. 
+ - current_node_index: integer representing the current node being colored. + + - action: int, the color to be assigned to the current node (0 to num_nodes - 1) + + - reward: float, a sparse reward is provided at the end of the episode. + Equals the negative of the number of unique colors used to color all vertices in the graph. + If an invalid action is taken, the reward is the negative of the total number of colors. + + - episode termination: + - if all nodes have been assigned a color or if an invalid action is taken. + + - state: `State` + - adj_matrix: jax array (bool) of shape (num_nodes, num_nodes), + representing the adjacency matrix of the graph. + - colors: jax array (int32) of shape (num_nodes,), + color assigned to each node, -1 if not assigned. + - current_node_index: jax array (int) with shape (), + index of the current node. + - action_mask: jax array (bool) of shape (num_colors,), + indicating which actions are valid in the current state of the environment. + - key: jax array (uint32) of shape (2,), + random key used to generate random numbers at each step and for auto-reset. + + ```python + from jumanji.environments import GraphColoring + env = GraphColoring() + key = jax.random.key(0) + state, timestep = jax.jit(env.reset)(key) + env.render(state) + action = env.action_spec().generate_value() + state, timestep = jax.jit(env.step)(state, action) + env.render(state) + ``` + """ + + def __init__( + self, + generator: Optional[Generator] = None, + viewer: Optional[Viewer[State]] = None, + ): + """Instantiate a `GraphColoring` environment. + + Args: + generator: callable to instantiate environment instances. + Defaults to `RandomGenerator` which generates graphs with + 20 `num_nodes` and `edge_probability` equal to 0.8. + viewer: environment viewer for rendering. + Defaults to `GraphColoringViewer`. 
+ """ + self.generator = generator or RandomGenerator( + num_nodes=20, edge_probability=0.8 + ) + self.num_nodes = self.generator.num_nodes + + # Create viewer used for rendering + self._env_viewer = viewer or GraphColoringViewer(name="GraphColoring") + + def __repr__(self) -> str: + return repr(self.generator) + + def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: + """Resets the environment to an initial state. + + Returns: + The initial state and timestep. + """ + colors = jnp.full(self.num_nodes, -1, dtype=jnp.int32) + key, subkey = jax.random.split(key) + adj_matrix = self.generator(subkey) + + action_mask = jnp.ones(self.num_nodes, dtype=bool) + current_node_index = jnp.array(0, jnp.int32) + state = State( + adj_matrix=adj_matrix, + colors=colors, + current_node_index=current_node_index, + action_mask=action_mask, + key=key, + ) + obs = Observation( + adj_matrix=adj_matrix, + colors=colors, + action_mask=action_mask, + current_node_index=current_node_index, + ) + timestep = restart(observation=obs) + + return state, timestep + + def step( + self, state: State, action: chex.Array + ) -> Tuple[State, TimeStep[Observation]]: + """Updates the environment state after the agent takes an action. + + Specifically, this function allows the agent to choose + a color for the current node (based on the action taken) + in a graph coloring problem. + It then updates the state of the environment based on + the color chosen and calculates the reward based on + the validity of the action and the completion of the coloring task. + + Args: + state: the current state of the environment. + action: the action taken by the agent. + + Returns: + state: the new state of the environment. + timestep: the next timestep. + """ + # Get the valid actions for the current state. + valid_actions = state.action_mask + + # Check if the chosen action is invalid (not in valid_actions). 
+ invalid_action_taken = jnp.logical_not(valid_actions[action]) + + # Update the colors array with the chosen action. + colors = state.colors.at[state.current_node_index].set(action) + + # Determine if all nodes have been assigned a color + all_nodes_colored = jnp.all(colors >= 0) + + # Calculate the reward + unique_colors_used = jnp.unique(colors, size=self.num_nodes, fill_value=-1) + num_unique_colors = jnp.count_nonzero(unique_colors_used >= 0) + reward = jnp.where(all_nodes_colored, -num_unique_colors, 0.0) + + # Apply the maximum penalty when an invalid action is taken and terminate the episode + reward = jnp.where(invalid_action_taken, -self.num_nodes, reward) + done = jnp.logical_or(all_nodes_colored, invalid_action_taken) + + # Update the current node index + next_node_index = (state.current_node_index + 1) % self.num_nodes + + next_action_mask = self._get_valid_actions( + next_node_index, state.adj_matrix, colors + ) + + next_state = State( + adj_matrix=state.adj_matrix, + colors=colors, + current_node_index=next_node_index, + action_mask=next_action_mask, + key=state.key, + ) + obs = Observation( + adj_matrix=state.adj_matrix, + colors=colors, + action_mask=next_state.action_mask, + current_node_index=next_node_index, + ) + timestep = lax.cond( + done, + termination, + transition, + reward, + obs, + ) + return next_state, timestep + + def observation_spec(self) -> specs.Spec[Observation]: + """Returns the observation spec. + + Returns: + Spec for the `Observation` whose fields are: + - adj_matrix: BoundedArray (bool) of shape (num_nodes, num_nodes). + Represents the adjacency matrix of the graph. + - action_mask: BoundedArray (bool) of shape (num_nodes,). + Represents the valid actions in the current state. + - colors: BoundedArray (int32) of shape (num_nodes,). + Represents the colors assigned to each node. + - current_node_index: BoundedArray (int32) of shape (). + Represents the index of the current node. 
+ """ + return specs.Spec( + Observation, + "ObservationSpec", + adj_matrix=specs.BoundedArray( + shape=(self.num_nodes, self.num_nodes), + dtype=bool, + minimum=False, + maximum=True, + name="adj_matrix", + ), + action_mask=specs.BoundedArray( + shape=(self.num_nodes,), + dtype=bool, + minimum=False, + maximum=True, + name="action_mask", + ), + colors=specs.BoundedArray( + shape=(self.num_nodes,), + dtype=jnp.int32, + minimum=-1, + maximum=self.num_nodes - 1, + name="colors", + ), + current_node_index=specs.BoundedArray( + shape=(), + dtype=jnp.int32, + minimum=0, + maximum=self.num_nodes - 1, + name="current_node_index", + ), + ) + + def action_spec(self) -> specs.DiscreteArray: + """Specification of the action for the `GraphColoring` environment. + + Returns: + action_spec: specs.DiscreteArray object + """ + return specs.DiscreteArray( + num_values=self.num_nodes, name="action", dtype=jnp.int32 + ) + + def _get_valid_actions( + self, current_node_index: int, adj_matrix: chex.Array, colors: chex.Array + ) -> chex.Array: + """Returns a boolean array indicating the valid colors for the current node.""" + # Create a boolean array of size (num_nodes + 1) set to True. + # The extra element is to accommodate for the -1 index + # which represents nodes that have not been colored yet. + valid_actions = jnp.ones(self.num_nodes + 1, dtype=bool) + row = adj_matrix[current_node_index, :] + action_mask = jnp.where(row, colors, -1) + valid_actions = valid_actions.at[action_mask].set(False) + + # Exclude the last element (which corresponds to -1 index) + return valid_actions[:-1] + + def render(self, state: State) -> Optional[NDArray]: + """Renders the current state of the `GraphColoring` environment. + + Args: + state: is the current game state to be rendered. 
+ """ + return self._env_viewer.render(state=state) + + def animate( + self, + states: Sequence[State], + interval: int = 200, + save_path: Optional[str] = None, + ) -> animation.FuncAnimation: + """Creates an animated gif of the `GraphColoring` environment based on the sequence of game states. + + Args: + states: is a list of `State` objects representing the sequence of game states. + interval: the delay between frames in milliseconds, default to 200. + save_path: the path where the animation file should be saved. If it is None, the plot + will not be stored. + + Returns: + animation.FuncAnimation: the animation object that was created. + """ + return self._env_viewer.animate( + states=states, interval=interval, save_path=save_path + ) + + def close(self) -> None: + """Perform any necessary cleanup. + + Environments will automatically :meth:`close()` themselves when + garbage collected or when the program exits. + """ + self._env_viewer.close() diff --git a/jumanji/environments/logic/graph_coloring/env_test.py b/jumanji/environments/logic/graph_coloring/env_test.py new file mode 100644 index 000000000..d0418da77 --- /dev/null +++ b/jumanji/environments/logic/graph_coloring/env_test.py @@ -0,0 +1,92 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import chex +import jax +import jax.numpy as jnp + +from jumanji.environments.logic.graph_coloring import GraphColoring +from jumanji.environments.logic.graph_coloring.types import State +from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.pytrees import assert_is_jax_array_tree +from jumanji.types import TimeStep + + +def test_graph_coloring_reset_jit(graph_coloring: GraphColoring) -> None: + """Confirm that the reset method is only compiled once when jitted.""" + chex.clear_trace_counter() + reset_fn = jax.jit(chex.assert_max_traces(graph_coloring.reset, n=1)) + key = jax.random.PRNGKey(0) + state, timestep = reset_fn(key) + + # Verify the data type of the output. + assert isinstance(timestep, TimeStep) + assert isinstance(state, State) + + # Check that the state is made of DeviceArrays, this is false for the non-jitted. + assert_is_jax_array_tree(state.adj_matrix) + assert_is_jax_array_tree(state.colors) + + # Call again to check it does not compile twice. + state, timestep = reset_fn(key) + assert isinstance(timestep, TimeStep) + assert isinstance(state, State) + + +def test_graph_coloring_step_jit(graph_coloring: GraphColoring) -> None: + """Confirm that the step is only compiled once when jitted.""" + key = jax.random.PRNGKey(0) + state, timestep = jax.jit(graph_coloring.reset)(key) + action = jnp.array(0) + + chex.clear_trace_counter() + step_fn = jax.jit(chex.assert_max_traces(graph_coloring.step, n=1)) + + new_state, next_timestep = step_fn(state, action) + + # Check that the state has changed. + assert not jnp.array_equal(new_state.colors, state.colors) + + # Check that the state is made of DeviceArrays, this is false for the non-jitted. 
+ assert_is_jax_array_tree(new_state) + + # New step + state = new_state + new_state, next_timestep = step_fn(state, action) + + # Check that the state has changed + assert not jnp.array_equal(new_state.colors, state.colors) + + +def test_graph_coloring_get_action_mask(graph_coloring: GraphColoring) -> None: + """Verify that the action mask generated by `_get_valid_actions` is correct.""" + key = jax.random.PRNGKey(0) + state, _ = graph_coloring.reset(key) + num_nodes = graph_coloring.generator.num_nodes + get_valid_actions_fn = jax.jit(graph_coloring._get_valid_actions) + action_mask = get_valid_actions_fn( + state.current_node_index, state.adj_matrix, state.colors + ) + + # Check that the action mask is a boolean array with the correct shape. + assert action_mask.dtype == jnp.bool_ + assert action_mask.shape == (num_nodes,) + + # For this specific test case, we don't have any pre-defined expected action_mask, + # as the graph and colors are randomly generated. + + +def test_graph_coloring_does_not_smoke(graph_coloring: GraphColoring) -> None: + """Test that we can run an episode without any errors.""" + check_env_does_not_smoke(graph_coloring) diff --git a/jumanji/environments/logic/graph_coloring/generator.py b/jumanji/environments/logic/graph_coloring/generator.py new file mode 100644 index 000000000..4aab23c10 --- /dev/null +++ b/jumanji/environments/logic/graph_coloring/generator.py @@ -0,0 +1,108 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import abc + +import chex +import jax +from jax import numpy as jnp + + +class Generator(abc.ABC): + @property + @abc.abstractmethod + def num_nodes(self) -> int: + """Number of nodes of the problem instances generated. + + Returns: + `num_nodes` of the generated instances. + """ + + @abc.abstractmethod + def __call__(self, key: chex.PRNGKey) -> chex.Array: + """Generate a problem instance. + + Args: + key: jax random key for any stochasticity used in the instance generation process. + + Returns: + An `adj_matrix` representing a problem instance. + """ + + +class RandomGenerator(Generator): + """A generator for random graphs in the context of graph coloring problems, + based on the Erdős-Rényi model (G(n, p)). + + The adjacency matrix is generated such that the graph is undirected and loop-less. + The graph is generated with a specified number of nodes and edge probability: + each possible edge is included independently with probability `edge_probability`. + """ + + def __init__(self, num_nodes: int, edge_probability: float): + """Initialize the `RandomGenerator`. + + Args: + num_nodes: The number of nodes in the graph. The number of colors available for + coloring is equal to the number of nodes. This means that the graph is always + colorable with the given colors. + edge_probability: A float strictly between 0 and 1 giving the probability that + any given pair of distinct nodes is connected by an edge. + """ + + self._num_nodes = num_nodes + self.edge_probability = edge_probability + assert ( + 0 < self.edge_probability < 1 + ), f"edge_probability={self.edge_probability} must be between 0 and 1." 
+ + @property + def num_nodes(self) -> int: + return self._num_nodes + + def __repr__(self) -> str: + return ( + f"GraphColoring(number of nodes={self.num_nodes}, " + f"percent connected={self.edge_probability * 100}%)" + ) + + def __call__(self, key: chex.PRNGKey) -> chex.Array: + """Generate a random graph adjacency matrix representing + the edges of an undirected graph using the Erdős-Rényi model G(n, p). + + Args: + key: PRNGKey used for stochasticity in the generation process. + + Returns: + adj_matrix: a boolean array of shape (num_nodes, num_nodes) representing + the adjacency matrix of the graph, where adj_matrix[i, j] is True if + there is an edge between nodes i and j, and False otherwise. + """ + key, edge_key = jax.random.split(key) + + # Generate a random adjacency matrix with probabilities of connections. + p_matrix = jax.random.uniform( + key=edge_key, shape=(self.num_nodes, self.num_nodes) + ) + + # Threshold the probabilities to create a boolean adjacency matrix. + adj_matrix = p_matrix < self.edge_probability + + # Make sure the graph is undirected (symmetric) and without self-loops. + adj_matrix = jnp.tril(adj_matrix, k=-1) # Keep only the lower triangular part. + + # Copy the lower triangular part to the upper triangular part. + adj_matrix += adj_matrix.T + + return adj_matrix diff --git a/jumanji/environments/logic/graph_coloring/types.py b/jumanji/environments/logic/graph_coloring/types.py new file mode 100644 index 000000000..90b558de2 --- /dev/null +++ b/jumanji/environments/logic/graph_coloring/types.py @@ -0,0 +1,54 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +import chex + +if TYPE_CHECKING: + from dataclasses import dataclass +else: + from chex import dataclass +from typing import NamedTuple + + +@dataclass +class State: + """ + adj_matrix: adjacency matrix used to represent the graph. + colors: array giving the color index of each node. + current_node_index: current node being colored. + action_mask: binary mask indicating the validity of assigning a color to the current node. + key: random key used for auto-reset. + """ + + adj_matrix: chex.Array # (num_nodes, num_nodes) + colors: chex.Array # (num_nodes,) + current_node_index: chex.Numeric # () + action_mask: chex.Array # (num_colors,) + key: chex.PRNGKey # (2,) + + +class Observation(NamedTuple): + """ + adj_matrix: adjacency matrix used to represent the graph. + colors: array giving the color index of each node. + current_node_index: current node being colored. + action_mask: binary mask indicating the validity of assigning a color to the current node. + """ + + adj_matrix: chex.Array # (num_nodes, num_nodes) + colors: chex.Array # (num_nodes,) + current_node_index: chex.Numeric # () + action_mask: chex.Array # (num_colors,) diff --git a/jumanji/environments/logic/graph_coloring/viewer.py b/jumanji/environments/logic/graph_coloring/viewer.py new file mode 100644 index 000000000..af5d5f179 --- /dev/null +++ b/jumanji/environments/logic/graph_coloring/viewer.py @@ -0,0 +1,244 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Sequence, Tuple + +import chex +import matplotlib.animation as animation +import matplotlib.cm as cm +import matplotlib.pyplot as plt +import numpy as np + +import jumanji.environments +from jumanji.environments.logic.graph_coloring.types import State +from jumanji.viewer import Viewer + + +class GraphColoringViewer(Viewer): + def __init__( + self, + name: str = "GraphColoring", + ) -> None: + self._name = name + self._animation: Optional[animation.Animation] = None + + def render( + self, + state: State, + save_path: Optional[str] = None, + ax: Optional[plt.Axes] = None, + ) -> None: + num_nodes = state.adj_matrix.shape[0] + self.node_scale = self._calculate_node_scale(num_nodes) + self._color_mapping = self._create_color_mapping(num_nodes) + + self._clear_display() + fig, ax = self._get_fig_ax(ax) + pos = self._spring_layout(state.adj_matrix, num_nodes) + self._render_nodes(ax, pos, state.colors) + self._render_edges(ax, pos, state.adj_matrix, num_nodes) + + ax.set_xlim(-0.5, 0.50) + ax.set_ylim(-0.50, 0.50) + ax.set_aspect("equal") + ax.axis("off") + + if save_path: + fig.savefig(save_path, bbox_inches="tight", pad_inches=0.2) + + self._display_human(fig) + + def animate( + self, + states: Sequence[State], + interval: int = 500, + save_path: Optional[str] = None, + ) -> animation.FuncAnimation: + num_nodes = states[0].adj_matrix.shape[0] + self.node_scale = 
self._calculate_node_scale(num_nodes) + self._color_mapping = self._create_color_mapping(num_nodes) + + fig, ax = self._get_fig_ax(ax=None) + plt.title(f"{self._name}") + + def make_frame(state_index: int) -> None: + state = states[state_index] + self.render(state, ax=ax) + + _animation = animation.FuncAnimation( + fig, make_frame, frames=len(states), interval=interval, blit=False + ) + + if save_path: + _animation.save(save_path) + + return _animation + + def close(self) -> None: + plt.close(self._name) + + def _display_human(self, fig: plt.Figure) -> None: + if plt.isinteractive(): + # Required to update render when using Jupyter Notebook. + fig.canvas.draw() + if jumanji.environments.is_colab(): + plt.show(self._name) + else: + # Required to update render when not using Jupyter Notebook. + fig.canvas.draw_idle() + fig.canvas.flush_events() + + def _clear_display(self) -> None: + if jumanji.environments.is_colab(): + import IPython.display + + IPython.display.clear_output(True) + + def _compute_repulsive_forces( + self, repulsive_forces: np.ndarray, pos: np.ndarray, k: float, num_nodes: int + ) -> np.ndarray: + for i in range(num_nodes): + for j in range(i + 1, num_nodes): + delta = pos[i] - pos[j] + distance = np.linalg.norm(delta) + direction = delta / (distance + 1e-6) + force = k * k / (distance + 1e-6) + repulsive_forces[i] += direction * force + repulsive_forces[j] -= direction * force + + return repulsive_forces + + def _compute_attractive_forces( + self, + graph: chex.Array, + attractive_forces: np.ndarray, + pos: np.ndarray, + k: float, + num_nodes: int, + ) -> np.ndarray: + for i in range(num_nodes): + for j in range(num_nodes): + if graph[i, j]: + delta = pos[i] - pos[j] + distance = np.linalg.norm(delta) + direction = delta / (distance + 1e-6) + force = distance * distance / k + attractive_forces[i] -= direction * force + attractive_forces[j] += direction * force + + return attractive_forces + + def _spring_layout( + self, graph: chex.Array, 
num_nodes: int, seed: int = 42 + ) -> List[Tuple[float, float]]: + """ + Compute a 2D spring layout for the given graph using + the Fruchterman-Reingold force-directed algorithm. + + The algorithm computes a layout by simulating the graph as a physical system, + where nodes repel each other and edges attract the nodes they connect. + The method minimizes the energy of the system over several iterations. + + Args: + graph: the adjacency matrix of the graph; num_nodes is its number of nodes. + seed: An integer used to seed the random number generator for reproducibility. + + Returns: + A list of tuples representing the 2D positions of nodes in the graph. + """ + rng = np.random.default_rng(seed) + pos = rng.random((num_nodes, 2)) * 2 - 1 + + iterations = 100 + k = np.sqrt(5 / num_nodes) + temperature = 2.0 # Initial step scale, cooled at every iteration + + for _ in range(iterations): + repulsive_forces = self._compute_repulsive_forces( + np.zeros((num_nodes, 2)), pos, k, num_nodes + ) + attractive_forces = self._compute_attractive_forces( + graph, np.zeros((num_nodes, 2)), pos, k, num_nodes + ) + + pos += (repulsive_forces + attractive_forces) * temperature + # Reduce the temperature (cooling factor) to refine the layout. + temperature *= 0.9 + + pos = np.clip(pos, -1, 1) # Keep positions within the [-1, 1] range + + return [(float(p[0]), float(p[1])) for p in pos] + + def _get_fig_ax(self, ax: Optional[plt.Axes]) -> Tuple[plt.Figure, plt.Axes]: + if ax is None: + fig, ax = plt.subplots(figsize=(self.node_scale, self.node_scale)) + plt.title(f"{self._name}") + else: + fig = ax.figure + ax.clear() + return fig, ax + + def _render_nodes( + self, ax: plt.Axes, pos: List[Tuple[float, float]], colors: chex.Array + ) -> None: + # Set the radius of the nodes as a fraction of the scale, + # so nodes appear smaller when there are more of them. 
+ node_radius = 0.05 * 5 / self.node_scale + + for i, (x, y) in enumerate(pos): + ax.add_artist( + plt.Circle( + (x, y), + node_radius, + color=self._color_mapping[colors[i]], + fill=(colors[i] != -1), + ) + ) + ax.text( + x, y, str(i), color="white", ha="center", va="center", weight="bold" + ) + + def _render_edges( + self, + ax: plt.Axes, + pos: List[Tuple[float, float]], + adj_matrix: chex.Array, + num_nodes: int, + ) -> None: + for i in range(num_nodes): + for j in range(i + 1, num_nodes): + if adj_matrix[i, j]: + ax.plot( + [pos[i][0], pos[j][0]], + [pos[i][1], pos[j][1]], + color=self._color_mapping[-1], + linewidth=0.5, + ) + + def _calculate_node_scale(self, num_nodes: int) -> int: + # Set the scale of the graph based on the number of nodes, + # so the graph grows (at a decelerating rate) with more nodes. + return 5 + int(np.sqrt(num_nodes)) + + def _create_color_mapping( + self, + num_nodes: int, + ) -> List[Tuple[float, float, float, float]]: + colormap_indices = np.arange(0, 1, 1 / num_nodes) + colormap = cm.get_cmap("hsv", num_nodes + 1) + color_mapping = [] + for colormap_idx in colormap_indices: + color_mapping.append(colormap(colormap_idx)) + color_mapping.append((0.0, 0.0, 0.0, 1.0)) # Adding black to the color mapping + return color_mapping diff --git a/jumanji/environments/logic/graph_coloring/viewer_test.py b/jumanji/environments/logic/graph_coloring/viewer_test.py new file mode 100644 index 000000000..657eb5da3 --- /dev/null +++ b/jumanji/environments/logic/graph_coloring/viewer_test.py @@ -0,0 +1,68 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pathlib + +import jax.numpy as jnp +import jax.random as random +import matplotlib.animation +import matplotlib.pyplot as plt +import pytest + +from jumanji.environments.logic.graph_coloring import GraphColoring + + +def test_render(monkeypatch: pytest.MonkeyPatch, graph_coloring: GraphColoring) -> None: + """Check that the render method builds the figure but does not display it.""" + monkeypatch.setattr(plt, "show", lambda fig: None) + key = random.PRNGKey(0) + state, _ = graph_coloring.reset(key) + + graph_coloring.render(state) + graph_coloring.close() + + +def test_animate(graph_coloring: GraphColoring) -> None: + """Check that the animate method returns a matplotlib animation for a sequence of states.""" + key = random.PRNGKey(0) + state, _ = graph_coloring.reset(key) + + num_steps = 5 + states = [state] + for _ in range(num_steps - 1): + action = jnp.array(0) + new_state, _ = graph_coloring.step(state, action) + states.append(new_state) + state = new_state + + animation = graph_coloring.animate(states, interval=500) + assert isinstance(animation, matplotlib.animation.Animation) + + +def test_save_animation(tmp_path: pathlib.Path, graph_coloring: GraphColoring) -> None: + key = random.PRNGKey(0) + state, _ = graph_coloring.reset(key) + + num_steps = 5 + states = [state] + for _ in range(num_steps - 1): + action = jnp.array(0) + new_state, _ = graph_coloring.step(state, action) + states.append(new_state) + state = new_state + + save_path = tmp_path / "animation_test.gif" + graph_coloring.animate(states, interval=500, save_path=str(save_path)) 
+ + assert save_path.exists() diff --git a/jumanji/environments/routing/robot_warehouse/__init__.py b/jumanji/environments/routing/robot_warehouse/__init__.py new file mode 100644 index 000000000..734a3994e --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jumanji.environments.routing.robot_warehouse.env import RobotWarehouse +from jumanji.environments.routing.robot_warehouse.types import Observation, State diff --git a/jumanji/environments/routing/robot_warehouse/conftest.py b/jumanji/environments/routing/robot_warehouse/conftest.py new file mode 100644 index 000000000..90270450e --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/conftest.py @@ -0,0 +1,99 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Tuple + +import jax +import jax.numpy as jnp +import pytest + +from jumanji.environments.routing.robot_warehouse import RobotWarehouse +from jumanji.environments.routing.robot_warehouse.generator import RandomGenerator +from jumanji.environments.routing.robot_warehouse.types import ( + Agent, + Position, + Shelf, + State, +) +from jumanji.types import TimeStep + + +@pytest.fixture(scope="module") +def robot_warehouse_env() -> RobotWarehouse: + """Instantiates a default RobotWarehouse environment with 2 agents, 1 shelf row, 3 shelf columns, + a column height of 2, sensor range of 1 and a request queue size of 4.""" + generator = RandomGenerator( + shelf_rows=1, + shelf_columns=3, + column_height=2, + num_agents=2, + sensor_range=1, + request_queue_size=4, + ) + + env = RobotWarehouse( + generator=generator, + time_limit=5, + ) + return env + + +@pytest.fixture +def deterministic_robot_warehouse_env( + robot_warehouse_env: RobotWarehouse, +) -> Tuple[RobotWarehouse, State, TimeStep]: + """Instantiates a RobotWarehouse environment with 2 agents and 8 shelves + with a step limit of 5.""" + state, timestep = robot_warehouse_env.reset(jax.random.PRNGKey(42)) + + # create agents, shelves and grid + def make_agent(x: int, y: int, direction: int, is_carrying: int) -> Agent: + return Agent(Position(x=x, y=y), direction=direction, is_carrying=is_carrying) + + def make_shelf(x: int, y: int, is_requested: int) -> Shelf: + return Shelf(Position(x=x, y=y), is_requested=is_requested) + + # agent information + xs = jnp.array([3, 1]) + ys = jnp.array([4, 7]) + dirs = jnp.array([2, 3]) + carries = jnp.array([0, 0]) + state.agents = jax.vmap(make_agent)(xs, ys, dirs, carries) + + # shelf information + xs = jnp.array([1, 1, 1, 1, 2, 2, 2, 2]) + ys = jnp.array([1, 2, 7, 8, 1, 2, 7, 8]) + requested = jnp.array([0, 1, 1, 0, 0, 0, 1, 1]) + state.shelves = jax.vmap(make_shelf)(xs, ys, requested) + + # create grid + state.grid = jnp.array( + [ + [ + [0, 0, 0, 0, 0, 0, 
0, 0, 0, 0], + [0, 1, 2, 0, 0, 0, 0, 3, 4, 0], + [0, 5, 6, 0, 0, 0, 0, 7, 8, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 2, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + return robot_warehouse_env, state, timestep diff --git a/jumanji/environments/routing/robot_warehouse/constants.py b/jumanji/environments/routing/robot_warehouse/constants.py new file mode 100644 index 000000000..aae7f3b08 --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/constants.py @@ -0,0 +1,37 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import jax.numpy as jnp + +from jumanji.environments.routing.robot_warehouse.types import Direction + +# grid channels +_SHELVES = 0 +_AGENTS = 1 + +# agent directions +_POSSIBLE_DIRECTIONS = jnp.array([d.value for d in Direction]) + +# viewer constants +_FIGURE_SIZE = (5, 5) +_SHELF_PADDING = 2 + +# colors +_GRID_COLOR = (0, 0, 0) # black +_SHELF_COLOR = (72 / 255.0, 61 / 255.0, 139 / 255.0) # dark slate blue +_SHELF_REQ_COLOR = (0, 128 / 255.0, 128 / 255.0) # teal +_AGENT_COLOR = (1, 140 / 255.0, 0) # dark orange +_AGENT_LOADED_COLOR = (1, 0, 0) # red +_AGENT_DIR_COLOR = (0, 0, 0) # black +_GOAL_COLOR = (60 / 255.0, 60 / 255.0, 60 / 255.0) diff --git a/jumanji/environments/routing/robot_warehouse/env.py b/jumanji/environments/routing/robot_warehouse/env.py new file mode 100644 index 000000000..ad9d4ec80 --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/env.py @@ -0,0 +1,545 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import functools +from typing import List, Optional, Sequence, Tuple + +import chex +import jax +import jax.numpy as jnp +import matplotlib +from numpy.typing import NDArray + +from jumanji import specs +from jumanji.env import Environment +from jumanji.environments.routing.robot_warehouse import utils +from jumanji.environments.routing.robot_warehouse.constants import _SHELVES +from jumanji.environments.routing.robot_warehouse.generator import ( + Generator, + RandomGenerator, +) +from jumanji.environments.routing.robot_warehouse.types import ( + Action, + Agent, + Direction, + Observation, + Shelf, + State, +) +from jumanji.environments.routing.robot_warehouse.utils_agent import ( + set_new_direction_after_turn, + set_new_position_after_forward, +) +from jumanji.environments.routing.robot_warehouse.utils_shelf import update_shelf +from jumanji.environments.routing.robot_warehouse.viewer import RobotWarehouseViewer +from jumanji.tree_utils import tree_slice +from jumanji.types import TimeStep, restart, termination, transition +from jumanji.viewer import Viewer + + +class RobotWarehouse(Environment[State]): + """A JAX implementation of the 'Robotic warehouse' environment: + https://github.com/semitable/robotic-warehouse + which is described in the paper [1]. + + Creates a grid world where multiple agents (robots) + are supposed to collect shelves, bring them to a goal + and then return them. + + Below is an example warehouse floor grid: + the grid layout is instantiated using three arguments + + - shelf_rows: number of vertical shelf clusters + - shelf_columns: odd number of horizontal shelf clusters + - column_height: height of each cluster + + A cluster is a set of grouped shelves (two cells wide) represented + below as + + XX + Shelf cluster -> XX (this cluster is of height 3) + XX + + Grid Layout: + + shelf columns (here set to 3, i.e. 
+ v v v shelf_columns=3, must be an odd number) + ---------- + > -XX-XX-XX- ^ + Shelf Row 1 -> -XX-XX-XX- Column Height (here set to 3, i.e. + > -XX-XX-XX- v column_height=3) + ---------- + -XX----XX- < + -XX----XX- <- Shelf Row 2 (here set to 2, i.e. + -XX----XX- < shelf_rows=2) + ---------- + ----GG---- + + - G: the goal positions, where agents are rewarded if + they successfully deliver a requested shelf (i.e., toggle the load action + inside the goal position while carrying a requested shelf). + + The final grid size will be + - height: (column_height + 1) * shelf_rows + 2 + - width: (2 + 1) * shelf_columns + 1 + + The bottom-middle column is removed to allow + agents to queue in front of the goal positions. + + - action: jax array (int) of shape (num_agents,) containing the action for each agent. + (0: noop, 1: forward, 2: left, 3: right, 4: toggle_load) + + - reward: jax array (float) of shape (), global reward shared by all agents, +1 + for every successful delivery of a requested shelf to the goal position. + + - episode termination: + - The number of steps reaches the time limit. + - Any agent selects an action which causes two agents to collide. + + - state: State + - grid: an array representing the warehouse floor as a 2D grid with two separate channels, + one for the agents, and one for the shelves. + - agents: a pytree of Agent type with per agent leaves: [position, direction, is_carrying] + - shelves: a pytree of Shelf type with per shelf leaves: [position, is_requested] + - request_queue: the queue of requested shelves (by ID). + - step_count: an integer representing the current step of the episode. + - action_mask: an array of shape (num_agents, 5) containing the valid actions + for each agent. + - key: a pseudorandom number generator key. 
+ + [1] Papoudakis et al., Benchmarking Multi-Agent Deep Reinforcement Learning Algorithms + in Cooperative Tasks (2021) + + ```python + from jumanji.environments import RobotWarehouse + env = RobotWarehouse() + key = jax.random.PRNGKey(0) + state, timestep = jax.jit(env.reset)(key) + env.render(state) + action = env.action_spec().generate_value() + state, timestep = jax.jit(env.step)(state, action) + env.render(state) + ``` + """ + + def __init__( + self, + generator: Optional[Generator] = None, + time_limit: int = 500, + viewer: Optional[Viewer[State]] = None, + ): + """Instantiates a `RobotWarehouse` environment. + + Args: + generator: callable to instantiate environment instances. + Defaults to `RandomGenerator` with parameters: + `shelf_rows = 2`, + `shelf_columns = 3`, + `column_height = 8`, + `num_agents = 4`, + `sensor_range = 1`, + `request_queue_size = 8`. + time_limit: the maximum step limit allowed within the environment. + Defaults to 500. + viewer: viewer to render the environment. Defaults to `RobotWarehouseViewer`. 
+ """ + + # default generator is: robot_warehouse-tiny-4ag-easy (in original implementation) + self._generator = generator or RandomGenerator( + column_height=8, + shelf_rows=2, + shelf_columns=3, + num_agents=4, + sensor_range=1, + request_queue_size=8, + ) + + self.goals: List[Tuple[int, int]] = [] + self.grid_size = self._generator.grid_size + self.request_queue_size = self._generator.request_queue_size + + self.num_agents = self._generator.num_agents + self.sensor_range = self._generator.sensor_range + self.highways = self._generator.highways + self.shelf_ids = self._generator.shelf_ids + self.not_in_queue_size = self._generator.not_in_queue_size + + self.agent_ids = jnp.arange(self.num_agents) + self.directions = jnp.array([d.value for d in Direction]) + self.num_obs_features = utils.calculate_num_observation_features( + self.sensor_range + ) + self.goals = self._generator.goals + self.time_limit = time_limit + + # create viewer for rendering environment + self._viewer = viewer or RobotWarehouseViewer( + self.grid_size, self.goals, "RobotWarehouse" + ) + + def __repr__(self) -> str: + return ( + f"RobotWarehouse(\n" + f"\tgrid_width={self.grid_size[1]!r},\n" + f"\tgrid_height={self.grid_size[0]!r},\n" + f"\tnum_agents={self.num_agents!r}, \n" + ")" + ) + + def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: + """Resets the environment. + + Args: + key: random key used to reset the environment since it is stochastic. + + Returns: + state: State object corresponding to the new state of the environment. + timestep: TimeStep object corresponding to the first timestep returned by the environment. 
+ """ + # create environment state + state = self._generator(key) + + # collect first observations and create timestep + agents_view = self._make_observations(state.grid, state.agents, state.shelves) + observation = Observation( + agents_view=agents_view, + action_mask=state.action_mask, + step_count=state.step_count, + ) + timestep = restart(observation=observation) + return state, timestep + + def step( + self, + state: State, + action: chex.Array, + ) -> Tuple[State, TimeStep[Observation]]: + """Perform an environment step. + + Args: + state: State object containing the dynamics of the environment. + action: Array containing the action to take. + - 0 no op + - 1 move forward + - 2 turn left + - 3 turn right + - 4 toggle load + + Returns: + state: State object corresponding to the next state of the environment. + timestep: TimeStep object corresponding to the timestep returned by the environment. + """ + + # unpack state + key = state.key + grid = state.grid + agents = state.agents + shelves = state.shelves + request_queue = state.request_queue + + # check for invalid action -> turn into noops + actions = utils.get_valid_actions(action, state.action_mask) + + # check for agent collisions + collisions = jax.vmap(functools.partial(utils.is_collision, grid))( + agents, actions + ) + collision = jnp.any(collisions) + + # update agents, shelves and grid + def update_state_scan( + carry_info: Tuple[chex.Array, chex.Array, chex.Array, int], action: int + ) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array, int], None]: + grid, agents, shelves, agent_id = carry_info + grid, agents, shelves = self._update_state( + grid, agents, shelves, action, agent_id + ) + return (grid, agents, shelves, agent_id + 1), None + + (grid, agents, shelves, _), _ = jax.lax.scan( + update_state_scan, (grid, agents, shelves, 0), actions + ) + + # compute shared reward for all agents and update request queue + # if a requested shelf has been successfully delivered to the goal + reward = 
jnp.array(0, dtype=jnp.float32) + + def update_reward_and_request_queue_scan( + carry_info: Tuple[ + chex.PRNGKey, chex.Array, chex.Array, chex.Array, chex.Array + ], + goal: chex.Array, + ) -> Tuple[ + Tuple[chex.PRNGKey, chex.Array, chex.Array, chex.Array, chex.Array], None + ]: + key, reward, request_queue, grid, shelves = carry_info + ( + key, + reward, + request_queue, + shelves, + ) = self._update_reward_and_request_queue( + key, reward, request_queue, grid, shelves, goal + ) + carry_info = (key, reward, request_queue, grid, shelves) + return carry_info, None + + update_info, _ = jax.lax.scan( + update_reward_and_request_queue_scan, + (key, reward, request_queue, grid, shelves), + self.goals, + ) + key, reward, request_queue, grid, shelves = update_info + + # construct timestep and check environment termination + steps = state.step_count + 1 + horizon_reached = steps >= self.time_limit + done = collision | horizon_reached + + # compute next observation + agents_view = self._make_observations(grid, agents, shelves) + action_mask = utils.compute_action_mask(grid, agents) + next_observation = Observation( + agents_view=agents_view, + action_mask=action_mask, + step_count=steps, + ) + + timestep = jax.lax.cond( + done, + termination, + transition, + reward, + next_observation, + ) + next_state = State( + grid=grid, + agents=agents, + shelves=shelves, + request_queue=request_queue, + step_count=steps, + action_mask=action_mask, + key=key, + ) + return next_state, timestep + + def observation_spec(self) -> specs.Spec[Observation]: + """Specification of the observation of the `RobotWarehouse` environment. + Returns: + Spec for the `Observation`, consisting of the fields: + - agents_view: Array (int32) of shape (num_agents, num_obs_features). + - action_mask: BoundedArray (bool) of shape (num_agents, 5). + - step_count: BoundedArray (int32) of shape (). 
+ """ + agents_view = specs.Array( + (self.num_agents, self.num_obs_features), jnp.int32, "agents_view" + ) + action_mask = specs.BoundedArray( + (self.num_agents, 5), bool, False, True, "action_mask" + ) + step_count = specs.BoundedArray((), jnp.int32, 0, self.time_limit, "step_count") + return specs.Spec( + Observation, + "ObservationSpec", + agents_view=agents_view, + action_mask=action_mask, + step_count=step_count, + ) + + def action_spec(self) -> specs.MultiDiscreteArray: + """Returns the action spec. 5 actions: [0,1,2,3,4] -> [No Op, Forward, Left, Right, Toggle_load]. + Since this is a multi-agent environment, the environment expects an array of actions. + This array is of shape (num_agents,). + """ + return specs.MultiDiscreteArray( + num_values=jnp.array([len(Action)] * self.num_agents, jnp.int32), + name="action", + ) + + def _make_observations( + self, + grid: chex.Array, + agents: Agent, + shelves: Shelf, + ) -> chex.Array: + """Create an observation for each agent based on its view of other + agents and shelves + + Args: + grid: the warehouse floor grid array. + agents: a pytree of Agent type containing agents information. + shelves: a pytree of Shelf type containing shelves information. + + Returns: + an array containing agents observations. + """ + return jax.vmap( + functools.partial( + utils.make_agent_observation, + grid, + agents, + shelves, + self.sensor_range, + self.num_obs_features, + self.highways, + ) + )(self.agent_ids) + + def _update_state( + self, + grid: chex.Array, + agents: chex.Array, + shelves: chex.Array, + action: int, + agent_id: int, + ) -> Tuple[chex.Array, chex.Array, chex.Array]: + """Update the state of the environment after an action is performed. + + Args: + grid: the warehouse floor grid array. + agents: a pytree of Agent type containing agents information. + shelves: a pytree of Shelf type containing shelves information. + action: the action performed by the agent. + agent_id: the id of the agent performing the action. 
+ Returns: + the updated warehouse floor grid array, agents and shelves. + """ + agent = tree_slice(agents, agent_id) + is_highway = self.highways[agent.position.x, agent.position.y] + grid, agents, shelves = jax.lax.cond( + jnp.equal(action, Action.FORWARD.value), + set_new_position_after_forward, + set_new_direction_after_turn, + grid, + agents, + shelves, + action, + agent_id, + is_highway, + ) + + return grid, agents, shelves + + def _update_reward_and_request_queue( + self, + key: chex.PRNGKey, + reward: chex.Array, + request_queue: chex.Array, + grid: chex.Array, + shelves: chex.Array, + goal: chex.Array, + ) -> Tuple[chex.PRNGKey, int, chex.Array, chex.Array]: + """Check if a shelf has been delivered successfully to a goal state, + if so reward the agents and update the request queue: removing the ID + of the delivered shelf and replacing it with a new shelf ID. + + Args: + key: a pseudorandom number generator key. + reward: the array of shared reward for each agent. + request_queue: the queue of requested shelves. + grid: the warehouse floor grid array. + shelves: a pytree of Shelf type containing shelves information. + goal: array of goal positions. + Returns: + a random key, updated reward, request queue and shelves. + """ + x, y = goal + shelf_id = grid[_SHELVES, x, y] + + def reward_and_update_request_queue_if_shelf_in_goal( + key: chex.PRNGKey, + reward: jnp.int32, + request_queue: chex.Array, + shelves: chex.Array, + shelf_id: int, + ) -> Tuple[chex.PRNGKey, int, chex.Array, chex.Array]: + "Reward the agents and update the request queue." 
+ + # remove from queue and replace it + key, request_key = jax.random.split(key) + + not_in_queue = jnp.setdiff1d( + self.shelf_ids, + request_queue, + size=self.not_in_queue_size, + ) + new_request_id = jax.random.choice( + request_key, + not_in_queue, + replace=False, + ) + replace_index = jnp.argwhere(jnp.equal(request_queue, shelf_id - 1), size=1) + request_queue = request_queue.at[replace_index].set(new_request_id) + + # also reward the agents + reward += 1.0 + + # update requested shelf + shelves = update_shelf(shelves, shelf_id - 1, "is_requested", 0) + shelves = update_shelf(shelves, new_request_id, "is_requested", 1) + return key, reward, request_queue, shelves + + # check if shelf is at goal position and in request queue + cond = (shelf_id != 0) & jnp.isin(shelf_id, request_queue + 1) + + key, reward, request_queue, shelves = jax.lax.cond( + cond, + reward_and_update_request_queue_if_shelf_in_goal, + lambda k, r, rq, g, _: (k, r, rq, g), + key, + reward, + request_queue, + shelves, + shelf_id, + ) + return key, reward, request_queue, shelves + + def render(self, state: State) -> Optional[NDArray]: + """Renders the current state of the RobotWarehouse environment. + + Args: + state: the current environment state to be rendered. + Returns: + the value returned by the viewer: an array, or None. + """ + return self._viewer.render(state) + + def animate( + self, + states: Sequence[State], + interval: int = 200, + save_path: Optional[str] = None, + ) -> matplotlib.animation.FuncAnimation: + """Creates an animation from a sequence of RobotWarehouse states. + + Args: + states: sequence of `State` corresponding to subsequent timesteps. + interval: delay between frames in milliseconds, defaults to 200. + save_path: the path where the animation file should be saved. If it is None, + the plot will not be saved. + + Returns: + Animation object that can be saved as a GIF, MP4, or rendered with HTML. 
+ """ + return self._viewer.animate( + states=states, interval=interval, save_path=save_path + ) + + def close(self) -> None: + """Perform any necessary cleanup. + + Environments will automatically :meth:`close()` themselves when + garbage collected or when the program exits. + """ + self._viewer.close() diff --git a/jumanji/environments/routing/robot_warehouse/env_test.py b/jumanji/environments/routing/robot_warehouse/env_test.py new file mode 100644 index 000000000..bf64f5ee7 --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/env_test.py @@ -0,0 +1,219 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Tuple + +import chex +import jax +import jax.numpy as jnp +from jax import random + +from jumanji.environments.routing.robot_warehouse.env import RobotWarehouse +from jumanji.environments.routing.robot_warehouse.types import State +from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.pytrees import assert_is_jax_array_tree +from jumanji.tree_utils import tree_slice +from jumanji.types import TimeStep + + +def test_robot_warehouse__specs(robot_warehouse_env: RobotWarehouse) -> None: + """Validate environment specs conform to the expected shapes and values""" + action_spec = robot_warehouse_env.action_spec() + observation_spec = robot_warehouse_env.observation_spec() + + assert observation_spec.agents_view.shape == (2, 66) # type: ignore + assert action_spec.num_values.shape[0] == robot_warehouse_env.num_agents + assert action_spec.num_values[0] == 5 + + +def test_robot_warehouse__reset(robot_warehouse_env: RobotWarehouse) -> None: + """Validate the jitted reset of the environment.""" + chex.clear_trace_counter() + reset_fn = jax.jit(chex.assert_max_traces(robot_warehouse_env.reset, n=1)) + + key1, key2 = random.PRNGKey(0), random.PRNGKey(1) + state1, timestep1 = reset_fn(key1) + state2, timestep2 = reset_fn(key2) + + assert isinstance(timestep1, TimeStep) + assert isinstance(state1, State) + assert state1.step_count == 0 + assert state1.grid.shape == (2, *robot_warehouse_env.grid_size) + # Check that the state is made of DeviceArrays, this is false for the non-jitted + # reset function since unpacking random.split returns numpy arrays and not device arrays. 
+ assert_is_jax_array_tree(state1) + # Check random initialization + assert not jnp.all(state1.key == state2.key) + assert not jnp.all(state1.grid == state2.grid) + assert state1.step_count == state2.step_count + + +def test_robot_warehouse__agent_observation( + deterministic_robot_warehouse_env: Tuple[RobotWarehouse, State, TimeStep] +) -> None: + """Validate the agent observation function.""" + env, state, timestep = deterministic_robot_warehouse_env + state, timestep = env.step(state, jnp.array([0, 0])) + + # agent 1 obs + agent1_own_view = jnp.array([3, 4, 0, 0, 0, 1, 0, 1]) + agent1_other_agents_view = jnp.array(8 * [0, 0, 0, 0, 0]) + agent1_shelf_view = jnp.array(9 * [0, 0]) + agent1_obs = jnp.hstack( + [agent1_own_view, agent1_other_agents_view, agent1_shelf_view] + ) + + # agent 2 obs + agent2_own_view = jnp.array([1, 7, 0, 0, 0, 0, 1, 0]) + agent2_other_agents_view = jnp.array(8 * [0, 0, 0, 0, 0]) + agent2_shelf_view = jnp.array( + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1] + ) + agent2_obs = jnp.hstack( + [agent2_own_view, agent2_other_agents_view, agent2_shelf_view] + ) + + assert jnp.all(timestep.observation.agents_view[0] == agent1_obs) + assert jnp.all(timestep.observation.agents_view[1] == agent2_obs) + + +def test_robot_warehouse__step(robot_warehouse_env: RobotWarehouse) -> None: + """Validate the jitted step function of the environment.""" + chex.clear_trace_counter() + + step_fn = chex.assert_max_traces(robot_warehouse_env.step, n=1) + step_fn = jax.jit(step_fn) + + state_key, action_key1, action_key2 = random.split(random.PRNGKey(10), 3) + state, timestep = robot_warehouse_env.reset(state_key) + + # Sample two different actions + action1, action2 = random.choice( + key=action_key1, + a=jnp.arange(5), + shape=(2,), + replace=False, + ) + + action1 = jnp.zeros((robot_warehouse_env.num_agents,), int).at[0].set(action1) + action2 = jnp.zeros((robot_warehouse_env.num_agents,), int).at[0].set(action2) + + new_state1, timestep1 = 
step_fn(state, action1) + + # Check that rewards have the correct number of dimensions + assert jnp.ndim(timestep1.reward) == 0 + assert jnp.ndim(timestep.reward) == 0 + # Check that discounts have the correct number of dimensions + assert jnp.ndim(timestep1.discount) == 0 + assert jnp.ndim(timestep.discount) == 0 + # Check that the state is made of DeviceArrays, this is false for the non-jitted + # step function since unpacking random.split returns numpy arrays and not device arrays. + assert_is_jax_array_tree(new_state1) + # Check that the state has changed + assert new_state1.step_count != state.step_count + assert not jnp.all(new_state1.grid != state.grid) + # Check that two different actions lead to two different states + new_state2, timestep2 = step_fn(state, action2) + assert not jnp.all(new_state1.grid != new_state2.grid) + + jax.debug.print("grid: {g}", g=state.grid) + + # Check that the state update and timestep creation work as expected + agents = state.agents + agent = tree_slice(agents, 1) + jax.debug.print("agents: {g}", g=agents) + x = agent.position.x + y = agent.position.y + + # turning and moving actions + actions = [2, 2, 3, 3, 1, 3, 1] + + # Note: starting direction is 3 (facing left) + new_locs = [ + (x, y, 2), # turn left -> facing down + (x, y, 1), # turn left -> facing right + (x, y, 2), # turn right -> facing down + (x, y, 3), # turn right -> face left + (x, y - 1, 3), # move forward -> move left + (x, y - 1, 0), # turn right -> face up + (x - 1, y - 1, 0), # move forward -> move up + ] + + for action, new_loc in zip(actions, new_locs): + state, timestep = step_fn(state, jnp.array([action, action])) + agent1_info = tree_slice(state.agents, 1) + agent1_loc = ( + agent1_info.position.x, + agent1_info.position.y, + agent1_info.direction, + ) + assert agent1_loc == new_loc + + +def test_robot_warehouse__does_not_smoke(robot_warehouse_env: RobotWarehouse) -> None: + """Validate that we can run an episode without any errors.""" + 
check_env_does_not_smoke(robot_warehouse_env) + + +def test_robot_warehouse__time_limit(robot_warehouse_env: RobotWarehouse) -> None: + """Validate that the episode ends once the time limit is reached.""" + step_fn = jax.jit(robot_warehouse_env.step) + state_key = random.PRNGKey(10) + state, timestep = robot_warehouse_env.reset(state_key) + assert timestep.first() + + for _ in range(robot_warehouse_env.time_limit - 1): + state, timestep = step_fn(state, jnp.array([0, 0])) + + assert timestep.mid() + state, timestep = step_fn(state, jnp.array([0, 0])) + assert timestep.last() + + +def test_robot_warehouse__truncation( + deterministic_robot_warehouse_env: Tuple[RobotWarehouse, State, TimeStep] +) -> None: + """Validate episode truncation based on set time limit.""" + robot_warehouse_env, state, timestep = deterministic_robot_warehouse_env + step_fn = jax.jit(robot_warehouse_env.step) + + # truncation + for _ in range(robot_warehouse_env.time_limit): + state, timestep = step_fn(state, jnp.array([0, 0])) + + assert timestep.last() + # note: the line below should be used to test for truncation, + # but since we instead use termination inside the env code + # for training compatibility, we omit this check for now + # assert not jnp.all(timestep.discount == 0) + + +def test_robot_warehouse__truncate_upon_collision( + deterministic_robot_warehouse_env: Tuple[RobotWarehouse, State, TimeStep] +) -> None: + """Validate episode terminates upon collision of agents.""" + robot_warehouse_env, state, timestep = deterministic_robot_warehouse_env + step_fn = jax.jit(robot_warehouse_env.step) + + # actions for agent 1 to collide with agent 2 + actions = [3, 1, 1, 3, 1, 1, 1] + + # take actions until collision + for action in actions: + state, timestep = step_fn(state, jnp.array([action, 0])) + + assert timestep.last() + # TODO: uncomment once we have changed termination + # in the env code to truncation (also see above) + # assert not jnp.all(timestep.discount == 0) diff --git 
a/jumanji/environments/routing/robot_warehouse/generator.py b/jumanji/environments/routing/robot_warehouse/generator.py new file mode 100644 index 000000000..2f0b79ec0 --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/generator.py @@ -0,0 +1,275 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Callable + +import chex +import jax +import jax.numpy as jnp + +from jumanji.environments.routing.robot_warehouse.types import State +from jumanji.environments.routing.robot_warehouse.utils import compute_action_mask +from jumanji.environments.routing.robot_warehouse.utils_spawn import ( + place_entities_on_grid, + spawn_random_entities, +) + + +class Generator(abc.ABC): + """Base class for generators for the RobotWarehouse environment.""" + + def __init__( + self, + shelf_rows: int, + shelf_columns: int, + column_height: int, + num_agents: int, + sensor_range: int, + request_queue_size: int, + ) -> None: + """Initializes a robot_warehouse generator, used to generate grids for + the RobotWarehouse environment. + + Args: + shelf_rows: the number of shelf cluster rows, each of height = column_height. + Defaults to 1. + shelf_columns: the number of shelf cluster columns, each of width = 2 cells + (must be an odd number). Defaults to 3. + column_height: the height of each shelf cluster. Defaults to 8. + num_agents: the number of agents (robots) operating on the warehouse floor. 
+ Defaults to 2. + sensor_range: the receptive field around an agent O O O + (e.g. 1 implies a 360-degree view of 1 cell around the -> O x O + agent's position cell) O O O + Defaults to 1. + request_queue_size: the number of shelves requested at any + given time which remains fixed throughout environment steps. Defaults to 4. + """ + if shelf_columns % 2 != 1: + raise ValueError( + "Environment argument: `shelf_columns`, must be an odd number." + ) + + self._shelf_rows = shelf_rows + self._shelf_columns = shelf_columns + self._column_height = column_height + self._num_agents = num_agents + self._sensor_range = sensor_range + self._request_queue_size = request_queue_size + + self._grid_size = ( + (column_height + 1) * shelf_rows + 2, + (2 + 1) * shelf_columns + 1, + ) + self._agent_ids = jnp.arange(num_agents) + + @property + def shelf_rows(self) -> int: + return self._shelf_rows + + @property + def shelf_columns(self) -> int: + return self._shelf_columns + + @property + def column_height(self) -> int: + return self._column_height + + @property + def grid_size(self) -> chex.Array: + return self._grid_size + + @property + def num_agents(self) -> int: + return self._num_agents + + @property + def sensor_range(self) -> int: + return self._sensor_range + + @property + def request_queue_size(self) -> int: + return self._request_queue_size + + @property + def agent_ids(self) -> chex.Array: + return self._agent_ids + + @property + @abc.abstractmethod + def shelf_ids(self) -> chex.Array: + """shelf ids""" + + @property + @abc.abstractmethod + def not_in_queue_size(self) -> chex.Array: + """number of shelves not in queue""" + + @property + @abc.abstractmethod + def highways(self) -> chex.Array: + """highways positions""" + + @property + @abc.abstractmethod + def goals(self) -> chex.Array: + """goals positions""" + + @abc.abstractmethod + def __call__(self, key: chex.PRNGKey) -> State: + """Generates a `RobotWarehouse` state. + + Returns: + A `RobotWarehouse` state. 
+ """ + + +class GeneratorBase(Generator): + """Base class for `RobotWarehouse` environment state generator.""" + + def __init__( + self, + shelf_rows: int, + shelf_columns: int, + column_height: int, + num_agents: int, + sensor_range: int, + request_queue_size: int, + ) -> None: + """Initializes a robot_warehouse generator.""" + super().__init__( + shelf_rows, + shelf_columns, + column_height, + num_agents, + sensor_range, + request_queue_size, + ) + self._make_warehouse() + + def _make_warehouse(self) -> None: + """Create the layout for the warehouse floor, i.e. the grid + + Args: + shelf_rows: the number of shelf cluster rows + shelf_columns: the number of shelf cluster columns + column_height: the height of each shelf cluster + """ + + # create goal positions + self._goals = jnp.array( + [ + (self._grid_size[1] // 2 - 1, self._grid_size[0] - 1), + (self._grid_size[1] // 2, self._grid_size[0] - 1), + ] + ) + # calculate "highways" (these are open spaces/cells between shelves) + highway_func: Callable[[int, int], bool] = lambda x, y: ( + (y % 3 == 0) # vertical highways + | (x % (self.column_height + 1) == 0) # horizontal highways + | (x == self._grid_size[0] - 1) # delivery row + | ( # remove middle cluster to allow agents to queue in front of goals + (x > self._grid_size[0] - (self.column_height + 3)) + & ((y == self._grid_size[1] // 2 - 1) | (y == self._grid_size[1] // 2)) + ) + ) + grid_indices = jnp.indices(jnp.zeros(self._grid_size, dtype=jnp.int32).shape) + self._highways = jax.vmap(highway_func)(grid_indices[0], grid_indices[1]) + + non_highways = jnp.abs(self.highways - 1) + + # shelves information + n_shelves = jnp.sum(non_highways) + self._shelf_positions = jnp.argwhere(non_highways) + self._shelf_ids = jnp.arange(n_shelves) + self._not_in_queue_size = n_shelves - self.request_queue_size + + @property + def shelf_ids(self) -> chex.Array: + return self._shelf_ids + + @property + def not_in_queue_size(self) -> chex.Array: + return self._not_in_queue_size 
+ + @property + def highways(self) -> chex.Array: + return self._highways + + @property + def goals(self) -> chex.Array: + return self._goals + + +class RandomGenerator(GeneratorBase): + """Randomly generates a `RobotWarehouse` environment state. This generator places agents at + starting positions on the grid and selects the requested shelves uniformly at random. + """ + + def __init__( + self, + shelf_rows: int, + shelf_columns: int, + column_height: int, + num_agents: int, + sensor_range: int, + request_queue_size: int, + ) -> None: + """Initializes a robot_warehouse generator, used to generate grids for + the RobotWarehouse environment.""" + super().__init__( + shelf_rows, + shelf_columns, + column_height, + num_agents, + sensor_range, + request_queue_size, + ) + + def __call__(self, key: chex.PRNGKey) -> State: + """Generates a `RobotWarehouse` state that contains the grid and the agents/shelves layout. + + Returns: + A `RobotWarehouse` state. + """ + # empty grid array + grid = jnp.zeros((2, *self._grid_size), dtype=jnp.int32) + + # spawn random agents with random request queue + key, agents, shelves, shelf_request_queue = spawn_random_entities( + key, + self._grid_size, + self._agent_ids, + self._shelf_ids, + self._shelf_positions, + self._request_queue_size, + ) + grid = place_entities_on_grid(grid, agents, shelves) + + # compute action mask + action_mask = compute_action_mask(grid, agents) + + # create environment state + state = State( + grid=grid, + agents=agents, + shelves=shelves, + request_queue=shelf_request_queue, + step_count=jnp.array(0, int), + action_mask=action_mask, + key=key, + ) + + return state diff --git a/jumanji/environments/routing/robot_warehouse/generator_test.py b/jumanji/environments/routing/robot_warehouse/generator_test.py new file mode 100644 index 000000000..925d66b16 --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/generator_test.py @@ -0,0 +1,53 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import chex +import jax +import pytest + +from jumanji.environments.routing.robot_warehouse.generator import RandomGenerator + + +@pytest.fixture +def random_generator() -> RandomGenerator: + """Creates a generator with 2 agents.""" + return RandomGenerator( + shelf_rows=1, + shelf_columns=3, + column_height=2, + num_agents=2, + sensor_range=1, + request_queue_size=4, + ) + + +def test_random_generator__call(random_generator: RandomGenerator) -> None: + """Test that generator generates valid boards.""" + key = jax.random.PRNGKey(42) + state = random_generator(key) + grid_size = (2, 5, 10) + assert state.grid.shape == grid_size + assert state.agents.direction.shape[0] == 2 + + +def test_random_generator__no_retrace( + random_generator: RandomGenerator, +) -> None: + """Checks that generator only traces the function once and works when jitted.""" + key = jax.random.PRNGKey(42) + keys = jax.random.split(key, 2) + jitted_generator = jax.jit(chex.assert_max_traces((random_generator.__call__), n=1)) + + for key in keys: + jitted_generator(key) diff --git a/jumanji/environments/routing/robot_warehouse/types.py b/jumanji/environments/routing/robot_warehouse/types.py new file mode 100644 index 000000000..bc241e09c --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/types.py @@ -0,0 +1,133 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, NamedTuple, Union + +if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239 + from dataclasses import dataclass +else: + from chex import dataclass + +from enum import IntEnum + +import chex + + +class Action(IntEnum): + """An enumeration of possible actions + that an agent can take in the warehouse. + + NOOP - represents no operation. + FORWARD - move forward. + LEFT - turn left. + RIGHT - turn right. + TOGGLE_LOAD - toggle loading/offloading a shelf. + """ + + NOOP = 0 + FORWARD = 1 + LEFT = 2 + RIGHT = 3 + TOGGLE_LOAD = 4 + + +class Direction(IntEnum): + """An enumeration of possible directions + that an agent can take in the warehouse. + + UP - move up. + RIGHT - move right. + DOWN - move down. + LEFT - move left. + """ + + UP = 0 + RIGHT = 1 + DOWN = 2 + LEFT = 3 + + +class Position(NamedTuple): + """A class to represent the 2D coordinate position of entities + + x: the x-position of the entity. + y: the y-position of the entity. + """ + + x: chex.Array # () + y: chex.Array # () + + +class Agent(NamedTuple): + """A class to represent an Agent in the warehouse + + position: the (x,y) position of the agent. + direction: the direction the agent is facing. + is_carrying: whether the agent is carrying a shelf or not. 
+ """ + + position: Position # (2,) + direction: chex.Array # () + is_carrying: chex.Array # () + + +class Shelf(NamedTuple): + """A class to represent a Shelf in the warehouse. + + position: the (x,y) position of the shelf. + is_requested: whether the shelf is requested for delivery. + """ + + position: Position # (2,) + is_requested: chex.Array # () + + +Entity = Union[Agent, Shelf] + + +@dataclass +class State: + """A dataclass representing the state of the simulated warehouse. + + grid: an array representing the warehouse floor as a 2D grid with two separate channels, + one for the agents and one for the shelves. + agents: a pytree of Agent type with per agent leaves: [position, direction, is_carrying] + shelves: a pytree of Shelf type with per shelf leaves: [position, is_requested] + request_queue: the queue of requested shelves (by ID). + step_count: an integer representing the current step of the episode. + key: a pseudorandom number generator key. + """ + + grid: chex.Array # (2, grid_width, grid_height) + agents: Agent # (num_agents, ...) + shelves: Shelf # (num_shelves, ...) + request_queue: chex.Array # (num_requested,) + step_count: chex.Array # () + action_mask: chex.Array # (num_agents, 5) + key: chex.PRNGKey # (2,) + + +class Observation(NamedTuple): + """The observation that the agent sees. + agents_view: the agents' view of other agents and shelves within their + sensor range. The number of features in the observation array + depends on the sensor range of the agent. + action_mask: boolean array specifying, for each agent, which action + (noop, forward, left, right, toggle_load) is legal. + step_count: the number of steps elapsed since the beginning of the episode. 
+ """ + + agents_view: chex.Array # (num_agents, num_obs_features) + action_mask: chex.Array # (num_agents, 5) + step_count: chex.Array # () diff --git a/jumanji/environments/routing/robot_warehouse/utils.py b/jumanji/environments/routing/robot_warehouse/utils.py new file mode 100644 index 000000000..87616d419 --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/utils.py @@ -0,0 +1,341 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from typing import Tuple + +import chex +import jax +import jax.numpy as jnp + +from jumanji.environments.routing.robot_warehouse.constants import _AGENTS, _SHELVES +from jumanji.environments.routing.robot_warehouse.types import Action, Agent, Entity +from jumanji.environments.routing.robot_warehouse.utils_agent import ( + get_agent_view, + get_new_position_after_forward, +) +from jumanji.tree_utils import tree_slice + + +def get_entity_ids(entities: Entity) -> chex.Array: + """Get ids for agents/shelves. + + Args: + entities: a pytree of Agent or Shelf type. + + Returns: + an array of ids. + """ + return jnp.arange(entities[1].shape[0]) + + +def is_valid_action(grid: chex.Array, agent: Agent, action: chex.Array) -> chex.Array: + """If the agent is carrying a shelf and collides with another + shelf based on its current action, this action is deemed invalid. + + Args: + grid: the warehouse floor grid array. 
+ agent: the agent for which the action is being checked. + action: the action the agent is about to take. + + Returns: + a boolean indicating whether the action is valid or not. + """ + + # get start and target positions + start = agent.position + target = get_new_position_after_forward(grid, start, agent.direction) + + # check if carrying and walking into another shelf + cond = jnp.logical_and(jnp.equal(action, Action.FORWARD), agent.is_carrying) + cond = jnp.logical_and(cond, jnp.logical_not(jnp.array_equal(start, target))) + cond = jnp.logical_and(cond, grid[_SHELVES, target.x, target.y]) + + return ~cond + + +def get_valid_actions(actions: chex.Array, action_mask: chex.Array) -> chex.Array: + """Get the valid action the agent should take given its action mask. + + Args: + actions: the actions the agents are about to take. + action_mask: the mask of valid actions. + + Returns: + the action the agent should take given its current state. + """ + + def get_valid_action(action_mask: chex.Array, action: chex.Array) -> chex.Array: + return jax.lax.cond(action_mask[action], lambda: action, lambda: 0) + + return jax.vmap(get_valid_action)(action_mask, actions) + + +def is_collision(grid: chex.Array, agent: Agent, action: int) -> chex.Array: + """Calculate whether an agent is about to collide with another + entity. If the agent instead collides with another agent, the + episode terminates (this behavior is specific to this JAX version). + + Args: + grid: the warehouse floor grid array. + agent: the agent for which collisions are being checked. + + Returns: + a boolean indicating whether the agent collided with another + agent or not. 
+ """ + + def check_collision() -> chex.Array: + # get start and target positions + start = agent.position + target = get_new_position_after_forward(grid, start, agent.direction) + + agent_id_at_target_pos = grid[_AGENTS, target.x, target.y] + check_forward = ~jnp.array_equal(start, target) + return check_forward & (agent_id_at_target_pos > 0) + + return jax.lax.cond( + jnp.equal(action, Action.FORWARD), check_collision, lambda: False + ) + + +def calculate_num_observation_features(sensor_range: chex.Array) -> chex.Array: + """Calculates the 1-d size of the agent observations array based on the + environment parameters at instantiation. + + Below is a receptive field for an agent x with a sensor range of 1: + + O O O + O x O + O O O + + For the sensor on the agent's own position, + we have the following features: + 1. the agent's position -> dim 2 + 2. is the agent carrying a shelf? -> binary {0, 1} with dim 1 + 3. the direction of the agent -> one-hot with dim 4 + 4. is the agent on the warehouse "highway" or not? -> binary with dim 1 + Total dim for agent in focus = 2 + 1 + 4 + 1 = 8 + + Then, for each sensor position (other than the agent's own position, 8 in total), + we have the following features based on other agents: + 1. is there an agent? -> binary {0, 1} with dim 1 + 2. if yes, the agent's direction -> one-hot with dim 4, if no, fill all zeros + Therefore, the total number of dimensions for other agent features is + (1 + 4) * (num_obs_sensors - 1) + + Finally, for each sensor position (9 in total) in the agent's receptive field, + we have the following features based on shelves: + 1. is there a shelf? -> binary {0, 1} with dim 1 + 2. if so, has this shelf been requested -> binary {0, 1} with dim 1, if no, zero + Therefore, the total number of dimensions for shelf features is + (1 + 1) * num_obs_sensors + + Args: + sensor_range: the range of the agent's sensors. + + Returns: + the size of the agent's 1-d observation array. 
+ """ + num_obs_sensors = (1 + 2 * sensor_range) ** 2 + obs_features = 8 # agent's own features + obs_features += (num_obs_sensors - 1) * 5 # other agent features + obs_features += num_obs_sensors * 2 # shelf features + return jnp.array(obs_features, jnp.int32) + + +def write_to_observation( + observation: chex.Array, idx: chex.Array, data: chex.Array +) -> Tuple[chex.Array, chex.Array]: + """Write data to the given observation vector at a specified index + + Args: + observation: an observation to which data will be written. + idx: an integer representing the index at which the data will be inserted. + data: the data that will be inserted into the observation array. + + Returns: + the updated observation array and the new index + of where to insert the next data. + """ + data_size = len(data) + observation = jax.lax.dynamic_update_slice(observation, data, (idx,)) + return observation, idx + data_size + + +def move_writer_index(idx: chex.Array, bits: chex.Array) -> chex.Array: + """Skip an indicated number of bits in the observation array being written. + + Args: + idx: an integer representing the index at which to skip bits. + bits: the number of bits to skip. + + Returns: + the new index at which to insert data. + """ + return idx + bits + + +def make_agent_observation( + grid: chex.Array, + agents: chex.Array, + shelves: chex.Array, + sensor_range: int, + num_obs_features: int, + highways: chex.Array, + agent_id: int, +) -> chex.Array: + """Create an observation for a single agent based on its view + of other agents and shelves. + + Args: + grid: the warehouse floor grid array. + agents: a pytree of either Agent type containing agent information. + shelves: a pytree of Shelf type containing shelf information. + sensor_range: the range of the agent's sensors. + num_obs_features: the number of features in the observation array. + highways: binary array indicating highway positions. + agent_id: unique ID identifying a specific agent. 
+ + Returns: + a 1-d array containing the agent's observation. + """ + agent = tree_slice(agents, agent_id) + agents_grid, shelves_grid = get_agent_view(grid, agent, sensor_range) + + # write flattened observations + obs = jnp.zeros(num_obs_features, dtype=jnp.int32) + idx = 0 + + # write current agent position and whether carrying a shelf or not + obs, idx = write_to_observation( + obs, + idx, + jnp.array( + [agent.position.x, agent.position.y, agent.is_carrying], + dtype=jnp.int32, + ), + ) + + # write current agent direction + direction = jax.nn.one_hot(agent.direction, 4, dtype=jnp.int32) + obs, idx = write_to_observation(obs, idx, direction) + + # write if agent is on highway or not + obs, idx = write_to_observation( + obs, + idx, + jnp.array( + [jnp.array(highways[agent.position.x, agent.position.y], int)], + dtype=jnp.int32, + ), + ) + + # function for writing receptive field cells + def write_no_agent( + obs: chex.Array, idx: int, _: int, is_self: bool + ) -> Tuple[chex.Array, int]: + "Write information for empty agent cell." + # if there is no agent we set a 0 and all zeros + # for the direction as well, i.e. [0, 0, 0, 0, 0] + idx = jax.lax.cond(is_self, lambda i: i, lambda i: move_writer_index(i, 5), idx) + return obs, idx + + def write_agent( + obs: chex.Array, idx: int, id_agent: int, _: bool + ) -> Tuple[chex.Array, int]: + "Write information for cell containing an agent." + obs, idx = write_to_observation(obs, idx, jnp.array([1], dtype=jnp.int32)) + direction = jax.nn.one_hot( + tree_slice(agents, id_agent - 1).direction, 4, dtype=jnp.int32 + ) + obs, idx = write_to_observation(obs, idx, direction) + return obs, idx + + def write_no_shelf(obs: chex.Array, idx: int, _: int) -> Tuple[chex.Array, int]: + "write information for empty shelf cell." + idx = move_writer_index(idx, 2) + return obs, idx + + def write_shelf(obs: chex.Array, idx: int, shelf_id: int) -> Tuple[chex.Array, int]: + "Write information for cell containing a shelf." 
+ requested = tree_slice(shelves, shelf_id - 1).is_requested + shelf = jnp.array([1, requested], dtype=jnp.int32) + obs, idx = write_to_observation(obs, idx, shelf) + return obs, idx + + def agent_sensor_scan( + obs_idx_and_agent_id: Tuple[chex.Array, chex.Array, chex.Array], + agent_sensor: chex.Array, + ) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], None]: + """Write agent observation with agent sensor information + of other agents. + """ + obs, idx, agent_id = obs_idx_and_agent_id + sensor_check_for_self = jnp.equal(agent_sensor, agent_id + 1) + sensor_check_for_self_or_no_other = jnp.logical_or( + jnp.equal(agent_sensor, 0), + sensor_check_for_self, + ) + obs, idx = jax.lax.cond( + sensor_check_for_self_or_no_other, + write_no_agent, + write_agent, + obs, + idx, + agent_sensor, + sensor_check_for_self, + ) + return (obs, idx, agent_id), None + + def shelf_sensor_scan( + obs_and_idx: Tuple[chex.Array, chex.Array], shelf_sensor: chex.Array + ) -> Tuple[Tuple[chex.Array, chex.Array], None]: + """Write agent observation with agent sensor information + of other shelves. + """ + obs, idx = obs_and_idx + obs, idx = jax.lax.cond( + jnp.equal(shelf_sensor, 0), + write_no_shelf, + write_shelf, + obs, + idx, + shelf_sensor, + ) + return (obs, idx), None + + (obs, idx, _), _ = jax.lax.scan( + agent_sensor_scan, (obs, idx, agent_id), agents_grid + ) + (obs, _), _ = jax.lax.scan(shelf_sensor_scan, (obs, idx), shelves_grid) + return obs + + +def compute_action_mask(grid: chex.Array, agents: Agent) -> chex.Array: + """Compute the action mask for the environment. + + Args: + grid: the warehouse floor grid array. + agents: a pytree of either Agent type containing agent information. + + Returns: + the action mask for the environment. 
+ """ + # vmap over agents and possible actions + action_mask = jax.vmap( + jax.vmap(functools.partial(is_valid_action, grid), in_axes=(None, 0)), + in_axes=(0, None), + )(agents, jnp.arange(5)) + return action_mask diff --git a/jumanji/environments/routing/robot_warehouse/utils_agent.py b/jumanji/environments/routing/robot_warehouse/utils_agent.py new file mode 100644 index 000000000..1e568f7fb --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/utils_agent.py @@ -0,0 +1,330 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, Union + +import chex +import jax +import jax.numpy as jnp + +from jumanji.environments.routing.robot_warehouse.constants import _AGENTS, _SHELVES +from jumanji.environments.routing.robot_warehouse.types import Action, Agent, Position +from jumanji.environments.routing.robot_warehouse.utils_shelf import ( + set_new_shelf_position_if_carrying, +) +from jumanji.tree_utils import tree_add_element, tree_slice + + +def update_agent( + agents: Agent, + agent_id: chex.Array, + attr: str, + value: Union[chex.Array, Position], +) -> Agent: + """Update the attribute information of a specific agent. + + Args: + agents: a pytree of either Agent type containing agent information. + agent_id: unique ID identifying a specific agent. + attr: the attribute to update, e.g. `direction`, or `is_requested`. + value: the new value to which the attribute is to be set. 
+ + Returns: + the agent with the specified attribute updated to the given value. + """ + params = {attr: value} + agent = tree_slice(agents, agent_id) + agent = agent._replace(**params) + agents: Agent = tree_add_element(agents, agent_id, agent) + return agents + + +def get_new_direction_after_turn( + action: chex.Array, agent_direction: chex.Array +) -> chex.Array: + """Get the correct direction the agent should face given + the turn action it took. E.g. if the agent is facing LEFT + and turns RIGHT it should now be facing UP, etc. + + Args: + action: the agent's action. + agent_direction: the agent's current direction. + + Returns: + the direction the agent should be facing given the action it took. + """ + change_in_direction = jnp.array([0, 0, -1, 1, 0])[action] + return (agent_direction + change_in_direction) % 4 + + +def get_new_position_after_forward( + grid: chex.Array, agent_position: chex.Array, agent_direction: chex.Array +) -> Position: + """Get the correct position the agent will be in after moving forward + in its current direction. E.g. if the agent is facing LEFT and turns + RIGHT it should stay in the same position. If instead it moves FORWARD + it should move left by one cell. + + Args: + grid: the warehouse floor grid array. + agent_position: the agent's current position. + agent_direction: the agent's current direction. + + Returns: + the position the agent should be in given the action it took. 
+ """ + _, grid_width, grid_height = grid.shape + x, y = agent_position.x, agent_position.y + move_up = lambda x, y: Position(jnp.max(jnp.array([0, x - 1])), y) + move_right = lambda x, y: Position(x, jnp.min(jnp.array([grid_height - 1, y + 1]))) + move_down = lambda x, y: Position(jnp.min(jnp.array([grid_width - 1, x + 1])), y) + move_left = lambda x, y: Position(x, jnp.max(jnp.array([0, y - 1]))) + new_position: Position = jax.lax.switch( + agent_direction, [move_up, move_right, move_down, move_left], x, y + ) + return new_position + + +def get_agent_view( + grid: chex.Array, agent: chex.Array, sensor_range: chex.Array +) -> Tuple[chex.Array, chex.Array]: + """Get an agent's view of other agents and shelves within its + sensor range. + + Below is an example of the agent's view of other agents from + the perspective of agent 1 with a sensor range of 1: + + 0, 0, 0 + 0, 1, 2 + 0, 0, 0 + + It sees agent 2 to its right. Separately, the view of shelves + is shown below: + + 0, 0, 0 + 0, 3, 4 + 0, 7, 8 + + Agent 1 is on top of shelf 3 and has shelves 4, 7 and 8 around it + in the bottom right corner of its view. Before returning these + views they are flattened into 1-d arrays, i.e. + + View of agents: [0, 0, 0, 0, 1, 2, 0, 0, 0] + View of shelves: [0, 0, 0, 0, 3, 4, 0, 7, 8] + + + Args: + grid: the warehouse floor grid array. + agent: the agent for which the view of their receptive field + is to be calculated. + sensor_range: the range of the agent's sensors. + + Returns: + a view of the agent's receptive field separated into two arrays: + one for other agents and one for shelves. 
+ """ + receptive_field = sensor_range * 2 + 1 + padded_agents_layer = jnp.pad(grid[_AGENTS], sensor_range, mode="constant") + padded_shelves_layer = jnp.pad(grid[_SHELVES], sensor_range, mode="constant") + agent_view_of_agents = jax.lax.dynamic_slice( + padded_agents_layer, + (agent.position.x, agent.position.y), + (receptive_field, receptive_field), + ).reshape(-1) + agent_view_of_shelves = jax.lax.dynamic_slice( + padded_shelves_layer, + (agent.position.x, agent.position.y), + (receptive_field, receptive_field), + ).reshape(-1) + return agent_view_of_agents, agent_view_of_shelves + + +def set_agent_carrying_if_at_shelf_position( + grid: chex.Array, agents: chex.Array, agent_id: int, is_highway: chex.Array +) -> chex.Array: + """Set the agent as carrying a shelf if it is at a shelf position. + + Args: + grid: the warehouse floor grid array. + agents: a pytree of Agent type containing agent information. + agent_id: unique ID identifying a specific agent. + is_highway: binary value indicating highway position. + + Returns: + updated agents pytree. + """ + agent = tree_slice(agents, agent_id) + shelf_id = grid[_SHELVES, agent.position.x, agent.position.y] + + return jax.lax.cond( + shelf_id > 0, + lambda: update_agent(agents, agent_id, "is_carrying", 1), + lambda: agents, + ) + + +def offload_shelf_if_position_is_open( + grid: chex.Array, agents: chex.Array, agent_id: int, is_highway: chex.Array +) -> chex.Array: + """Set the agent as not carrying a shelf if it is not on a highway. + + Args: + grid: the warehouse floor grid array. + agents: a pytree of Agent type containing agent information. + agent_id: unique ID identifying a specific agent. + is_highway: binary value indicating highway position. + + Returns: + updated agents pytree. 
+ """ + return jax.lax.cond( + jnp.logical_not(is_highway), + lambda: update_agent(agents, agent_id, "is_carrying", 0), + lambda: agents, + ) + + +def set_carrying_shelf_if_load_toggled_and_not_carrying( + grid: chex.Array, + agents: chex.Array, + action: int, + agent_id: int, + is_highway: chex.Array, +) -> chex.Array: + """Set the agent as carrying a shelf if the load toggle action is + performed and the agent is not carrying a shelf. + + Args: + grid: the warehouse floor grid array. + agents: a pytree of either Agent type containing agent information. + action: the agent's action. + agent_id: unique ID identifying a specific agent. + is_highway: binary value indicating highway position. + + Returns: + updated agents pytree. + """ + agent = tree_slice(agents, agent_id) + + agents = jax.lax.cond( + (action == Action.TOGGLE_LOAD.value) & ~agent.is_carrying, + set_agent_carrying_if_at_shelf_position, + offload_shelf_if_position_is_open, + grid, + agents, + agent_id, + is_highway, + ) + return agents + + +def rotate_agent( + grid: chex.Array, + agents: chex.Array, + action: int, + agent_id: int, + is_highway: chex.Array, +) -> chex.Array: + """Rotate the agent in the direction of the action. + + Args: + grid: the warehouse floor grid array. + agents: a pytree of either Agent type containing agent information. + action: the agent's action. + agent_id: unique ID identifying a specific agent. + is_highway: binary value indicating highway position. + + Returns: + updated agents pytree. + """ + agent = tree_slice(agents, agent_id) + new_direction = get_new_direction_after_turn(action, agent.direction) + return update_agent(agents, agent_id, "direction", new_direction) + + +def set_new_position_after_forward( + grid: chex.Array, + agents: chex.Array, + shelves: chex.Array, + action: int, + agent_id: int, + is_highway: chex.Array, +) -> Tuple[chex.Array, chex.Array, chex.Array]: + """Set the new position of the agent after a forward action. 
+ + Args: + grid: the warehouse floor grid array. + agents: a pytree of either Agent type containing agent information. + shelves: a pytree of Shelf type containing shelf information. + action: the agent's action. + agent_id: unique ID identifying a specific agent. + is_highway: binary value indicating highway position. + + Returns: + updated grid array, agents and shelves pytrees. + """ + # update agent position + agent = tree_slice(agents, agent_id) + current_position = agent.position + new_position = get_new_position_after_forward(grid, agent.position, agent.direction) + agents = update_agent(agents, agent_id, "position", new_position) + + # update agent grid placement + grid = grid.at[_AGENTS, current_position.x, current_position.y].set(0) + grid = grid.at[_AGENTS, new_position.x, new_position.y].set(agent_id + 1) + + grid, shelves = jax.lax.cond( + agent.is_carrying, + set_new_shelf_position_if_carrying, + lambda g, s, p, np: (g, s), + grid, + shelves, + current_position, + new_position, + ) + return grid, agents, shelves + + +def set_new_direction_after_turn( + grid: chex.Array, + agents: chex.Array, + shelves: chex.Array, + action: int, + agent_id: int, + is_highway: chex.Array, +) -> Tuple[chex.Array, chex.Array, chex.Array]: + """Set the new direction of the agent after a turning action. + + Args: + grid: the warehouse floor grid array. + agents: a pytree of either Agent type containing agent information. + shelves: a pytree of Shelf type containing shelf information. + action: the agent's action. + agent_id: unique ID identifying a specific agent. + is_highway: binary value indicating highway position. + + Returns: + updated grid array, agents and shelves pytrees. 
+ """ + agents = jax.lax.cond( + jnp.isin(action, jnp.array([Action.LEFT.value, Action.RIGHT.value])), + rotate_agent, + set_carrying_shelf_if_load_toggled_and_not_carrying, + grid, + agents, + action, + agent_id, + is_highway, + ) + return grid, agents, shelves diff --git a/jumanji/environments/routing/robot_warehouse/utils_shelf.py b/jumanji/environments/routing/robot_warehouse/utils_shelf.py new file mode 100644 index 000000000..91fb59ba2 --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/utils_shelf.py @@ -0,0 +1,72 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, Union + +import chex + +from jumanji.environments.routing.robot_warehouse.constants import _SHELVES +from jumanji.environments.routing.robot_warehouse.types import Position, Shelf +from jumanji.tree_utils import tree_add_element, tree_slice + + +def update_shelf( + shelves: Shelf, + shelf_id: chex.Array, + attr: str, + value: Union[chex.Array, Position], +) -> Shelf: + """Update the attribute information of a specific shelf. + + Args: + shelves: a pytree of Shelf type containing shelf information. + shelf_id: unique ID identifying a specific shelf. + attr: the attribute to update, e.g. `direction`, or `is_requested`. + value: the new value to which the attribute is to be set. + + Returns: + the shelf with the specified attribute updated to the given value. 
+ """ + params = {attr: value} + shelf = tree_slice(shelves, shelf_id) + shelf = shelf._replace(**params) + shelves: Shelf = tree_add_element(shelves, shelf_id, shelf) + return shelves + + +def set_new_shelf_position_if_carrying( + grid: chex.Array, + shelves: Shelf, + cur_pos: chex.Array, + new_pos: chex.Array, +) -> Tuple[chex.Array, chex.Array]: + """Set the new position of the shelf if the agent is carrying one. + + Args: + grid: the warehouse floor grid array. + shelves: a pytree of Shelf type containing shelf information. + cur_pos: the current position of the shelf. + new_pos: the new position of the shelf. + + Returns: + updated grid array and shelves pytree. + """ + # update shelf position + shelf_id = grid[_SHELVES, cur_pos.x, cur_pos.y] + shelves = update_shelf(shelves, shelf_id - 1, "position", new_pos) + + # update shelf grid placement + grid = grid.at[_SHELVES, cur_pos.x, cur_pos.y].set(0) + grid = grid.at[_SHELVES, new_pos.x, new_pos.y].set(shelf_id) + return grid, shelves diff --git a/jumanji/environments/routing/robot_warehouse/utils_spawn.py b/jumanji/environments/routing/robot_warehouse/utils_spawn.py new file mode 100644 index 000000000..2da8b91c2 --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/utils_spawn.py @@ -0,0 +1,192 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Tuple + +import chex +import jax +import jax.numpy as jnp + +from jumanji.environments.routing.robot_warehouse.constants import ( + _AGENTS, + _POSSIBLE_DIRECTIONS, + _SHELVES, +) +from jumanji.environments.routing.robot_warehouse.types import ( + Agent, + Entity, + Position, + Shelf, +) +from jumanji.environments.routing.robot_warehouse.utils import get_entity_ids +from jumanji.tree_utils import tree_slice + + +def spawn_agent( + agent_coordinates: chex.Array, + direction: chex.Array, +) -> Agent: + """Spawn an agent (robot) at a given position and direction. + + Args: + agent_coordinates: x, y coordinates of the agent. + direction: direction of the agent. + + Returns: + spawned agent. + """ + x, y = agent_coordinates + agent_pos = Position(x=x, y=y) + agent = Agent(position=agent_pos, direction=direction, is_carrying=0) + return agent + + +def spawn_shelf( + shelf_coordinates: chex.Array, + requested: chex.Array, +) -> Shelf: + """Spawn a shelf at a specific shelf position and label the shelf + as requested or not. + + Args: + shelf_coordinates: x, y coordinates of the shelf. + requested: whether the shelf has been requested or not. + + Returns: + spawned shelf. + """ + x, y = shelf_coordinates + shelf_pos = Position(x=x, y=y) + shelf = Shelf(position=shelf_pos, is_requested=requested) + return shelf + + +def spawn_random_entities( + key: chex.PRNGKey, + grid_size: chex.Array, + agent_ids: chex.Array, + shelf_ids: chex.Array, + shelf_coordinates: chex.Array, + request_queue_size: chex.Array, +) -> Tuple[chex.PRNGKey, Agent, Shelf, chex.Array]: + """Spawn agents and shelves on the warehouse floor grid. + + Args: + key: pseudo random number key. + grid_size: the size of the warehouse floor grid. + agent_ids: array of agent ids. + shelf_ids: array of shelf ids. + shelf_coordinates: x,y coordinates of shelf positions. + request_queue_size: the number of shelves to be delivered. 
+ + Returns: + new key, spawned agents, shelves and the request queue. + """ + + # random agent positions + num_agents = len(agent_ids) + key, position_key = jax.random.split(key) + grid_cells = jnp.array(jnp.arange(grid_size[0] * grid_size[1])) + agent_coords = jax.random.choice( + position_key, + grid_cells, + shape=(num_agents,), + replace=False, + ) + agent_coords = jnp.transpose( + jnp.asarray(jnp.unravel_index(agent_coords, grid_size)) + ) + + # random agent directions + key, direction_key = jax.random.split(key) + + agent_dirs = jax.random.choice( + direction_key, _POSSIBLE_DIRECTIONS, shape=(num_agents,) + ) + + # sample request queue + key, queue_key = jax.random.split(key) + shelf_request_queue = jax.random.choice( + queue_key, + shelf_ids, + shape=(request_queue_size,), + replace=False, + ) + requested_ids = jnp.zeros(shelf_ids.shape) + requested_ids = requested_ids.at[shelf_request_queue].set(1) + + # spawn agents and shelves + agents = jax.vmap(spawn_agent)(agent_coords, agent_dirs) + shelves = jax.vmap(spawn_shelf)(shelf_coordinates, requested_ids) + return key, agents, shelves, shelf_request_queue + + +def place_entity_on_grid( + grid: chex.Array, + channel: chex.Array, + entities: Entity, + entity_id: chex.Array, +) -> chex.Array: + """Places an entity (Agent/Shelf) on the grid based on its + (x, y) position defined once spawned. + + Args: + grid: the warehouse floor grid array. + channel: the grid channel index, either agents or shelves. + entities: a pytree of Agent or Shelf type containing entity information. + entity_id: unique ID identifying a specific entity. + + Returns: + the warehouse grid with the specific entity in its position. + """ + entity = tree_slice(entities, entity_id) + x, y = entity.position.x, entity.position.y + return grid.at[channel, x, y].set(entity_id + 1) + + +def place_entities_on_grid( + grid: chex.Array, agents: Agent, shelves: Shelf +) -> chex.Array: + """Place agents and shelves on the grid. 
+ + Args: + grid: the warehouse floor grid array. + agents: a pytree of Agent type containing agent information. + shelves: a pytree of Shelf type containing shelf information. + + Returns: + the warehouse grid with all agents and shelves placed in their + positions. + """ + agent_ids = get_entity_ids(agents) + shelf_ids = get_entity_ids(shelves) + + # place agents and shelves on warehouse grid + def place_agents_scan( + grid_and_agents: Tuple[chex.Array, chex.Array], agent_id: chex.Array + ) -> Tuple[Tuple[chex.Array, chex.Array], None]: + grid, agents = grid_and_agents + grid = place_entity_on_grid(grid, _AGENTS, agents, agent_id) + return (grid, agents), None + + def place_shelves_scan( + grid_and_shelves: Tuple[chex.Array, chex.Array], shelf_id: chex.Array + ) -> Tuple[Tuple[chex.Array, chex.Array], None]: + grid, shelves = grid_and_shelves + grid = place_entity_on_grid(grid, _SHELVES, shelves, shelf_id) + return (grid, shelves), None + + (grid, _), _ = jax.lax.scan(place_agents_scan, (grid, agents), agent_ids) + (grid, _), _ = jax.lax.scan(place_shelves_scan, (grid, shelves), shelf_ids) + return grid diff --git a/jumanji/environments/routing/robot_warehouse/utils_test.py b/jumanji/environments/routing/robot_warehouse/utils_test.py new file mode 100644 index 000000000..91b04d4b0 --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/utils_test.py @@ -0,0 +1,436 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import chex +import jax +import jax.numpy as jnp +import pytest + +from jumanji.environments.routing.robot_warehouse.types import ( + Action, + Agent, + Position, + Shelf, + State, +) +from jumanji.environments.routing.robot_warehouse.utils import ( + calculate_num_observation_features, + compute_action_mask, + get_valid_actions, + is_collision, + is_valid_action, + move_writer_index, + write_to_observation, +) +from jumanji.environments.routing.robot_warehouse.utils_agent import ( + get_agent_view, + get_new_direction_after_turn, + get_new_position_after_forward, + update_agent, +) +from jumanji.environments.routing.robot_warehouse.utils_shelf import update_shelf +from jumanji.environments.routing.robot_warehouse.utils_spawn import ( + place_entities_on_grid, +) +from jumanji.tree_utils import tree_slice + + +@pytest.fixture +def fake_robot_warehouse_env_state() -> State: + """Create a fake robot_warehouse environment state.""" + + # create agents, shelves and grid + def make_agent( + x: chex.Array, y: chex.Array, direction: chex.Array, is_carrying: chex.Array + ) -> Agent: + return Agent(Position(x=x, y=y), direction=direction, is_carrying=is_carrying) + + def make_shelf(x: chex.Array, y: chex.Array, is_requested: chex.Array) -> Shelf: + return Shelf(Position(x=x, y=y), is_requested=is_requested) + + # agent information + xs = jnp.array([3, 1]) + ys = jnp.array([4, 7]) + dirs = jnp.array([2, 3]) + carries = jnp.array([0, 0]) + agents = jax.vmap(make_agent)(xs, ys, dirs, carries) + + # shelf information + xs = jnp.array([1, 1, 1, 1, 2, 2, 2, 2]) + ys = jnp.array([1, 2, 7, 8, 1, 2, 7, 8]) + requested = jnp.array([0, 1, 1, 0, 0, 0, 1, 1]) + shelves = jax.vmap(make_shelf)(xs, ys, requested) + + # create grid + grid = jnp.array( + [ + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 0, 0, 0, 0, 3, 4, 0], + [0, 5, 6, 0, 0, 0, 0, 7, 8, 0], + [0, 0, 0, 0, 0, 0, 0, 
0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 2, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + action_mask = jnp.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]) + state = State( + grid=grid, + agents=agents, + shelves=shelves, + request_queue=jnp.array([1, 2, 6, 7]), + step_count=jnp.array([0], dtype=jnp.int32), + action_mask=action_mask, + key=jax.random.PRNGKey(42), + ) + return state + + +def test_robot_warehouse_utils__entity_placement( + fake_robot_warehouse_env_state: State, +) -> None: + """Test entity placement on the warehouse grid floor.""" + state = fake_robot_warehouse_env_state + agents, shelves = state.agents, state.shelves + empty_grid = jnp.zeros(state.grid.shape, dtype=jnp.int32) + grid_with_agents_and_shelves = place_entities_on_grid(empty_grid, agents, shelves) + + # check that placement is the same as in fake grid + assert jnp.all(grid_with_agents_and_shelves == state.grid) + + +def test_robot_warehouse_utils__entity_update( + fake_robot_warehouse_env_state: State, +) -> None: + """Test entity attribute (e.g. position, direction etc.) 
updating.""" + state = fake_robot_warehouse_env_state + agents = state.agents + shelves = state.shelves + + # test updating agent position + new_position = Position(x=2, y=4) + agents_with_new_agent_0_position = update_agent(agents, 0, "position", new_position) + agent_0 = tree_slice(agents_with_new_agent_0_position, 0) + assert agent_0.position == new_position + + # test updating agent direction + new_direction = 3 + agents_with_new_agent_0_direction = update_agent( + agents, 0, "direction", new_direction + ) + agent_0 = tree_slice(agents_with_new_agent_0_direction, 0) + assert agent_0.direction == new_direction + + # test updating agent carrying + new_is_carrying = 1 + agents_with_new_agent_0_carrying = update_agent( + agents, 0, "is_carrying", new_is_carrying + ) + agent_0 = tree_slice(agents_with_new_agent_0_carrying, 0) + assert agent_0.is_carrying == new_is_carrying + + # test updating shelf position + new_position = Position(x=1, y=3) + shelves_with_new_shelf_0_position = update_shelf( + shelves, 0, "position", new_position + ) + shelf_0 = tree_slice(shelves_with_new_shelf_0_position, 0) + assert shelf_0.position == new_position + + # test updating shelf requested + new_is_requested = 1 + shelves_with_new_shelf_0_requested = update_shelf( + shelves, 0, "is_requested", new_is_requested + ) + shelf_0 = tree_slice(shelves_with_new_shelf_0_requested, 0) + assert shelf_0.is_requested == new_is_requested + + +def test_robot_warehouse_utils__get_new_direction( + fake_robot_warehouse_env_state: State, +) -> None: + """Test the calculation of the new direction for an agent after turning.""" + state = fake_robot_warehouse_env_state + agents = state.agents + agent = tree_slice(agents, 0) + direction = agent.direction # 2 (facing down) + + # turning: left, left, right, right, right + actions = [2, 2, 3, 3, 3] + expected_directions = [ + 1, # turn left -> facing right + 0, # turn left -> facing up + 1, # turn right -> facing right + 2, # turn right -> face down + 3, # 
turn right -> face left + ] + + for action, expected_direction in zip(actions, expected_directions): + new_direction = get_new_direction_after_turn(action, direction) + assert new_direction == expected_direction + direction = new_direction + + +def test_robot_warehouse_utils__get_new_position( + fake_robot_warehouse_env_state: State, +) -> None: + """Test the calculation of the new position for an agent after moving + forward in a specific direction.""" + state = fake_robot_warehouse_env_state + grid = state.grid + agents = state.agents + agent = tree_slice(agents, 0) + position = agent.position # [x=3, y=4] + directions = jnp.arange(4) + + # move forward once in each direction + expected_positions = [ + Position(2, 4), # facing up move forward + Position(3, 5), # facing right move forward + Position(4, 4), # facing down move forward + Position(3, 3), # facing left move forward + ] + + for direction, expected_position in zip(directions, expected_positions): + new_position = get_new_position_after_forward(grid, position, direction) + assert new_position == expected_position + + +def test_robot_warehouse_utils__is_collision( + fake_robot_warehouse_env_state: State, +) -> None: + """Test the calculation of collisions between agents and other agents as well as + agents carrying shelves and other shelves.""" + state = fake_robot_warehouse_env_state + grid = state.grid + agents = state.agents + + # check no collision with original grid + agent = tree_slice(agents, 0) + action = 1 # forward + collision = is_collision(grid, agent, action) + assert bool(collision) is False + + # create grid with agents next to each other + grid_prior_to_collision = jnp.array( + [ + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 0, 0, 0, 0, 3, 4, 0], + [0, 5, 6, 0, 0, 0, 0, 7, 8, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 2, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 
0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + + # update agent zero to face agent 2 + agents = update_agent(agents, 0, "position", Position(1, 6)) + agents = update_agent(agents, 0, "direction", 1) + agent = tree_slice(agents, 0) + + # check collision if moving forward + collision = is_collision(grid_prior_to_collision, agent, action) + assert bool(collision) is True + + +def test_robot_warehouse_utils__is_valid_action( + fake_robot_warehouse_env_state: State, +) -> None: + """Test the calculation of collisions between agents and other agents as well as + agents carrying shelves and other shelves.""" + state = fake_robot_warehouse_env_state + grid = state.grid + agents = state.agents + + # turn agent 2 around, and check no collision with shelf + # i.e. agent is moving underneath shelf rack via highway + agents = update_agent(agents, 1, "direction", 1) + agent = tree_slice(agents, 1) + action = 1 # forward + action = is_valid_action(grid, agent, action) + assert action == Action.FORWARD.value + + # Let agent 2 pick up shelf and move forward + # to test collision with shelf when carrying + # and convert to NOOP action + agents = update_agent(agents, 1, "is_carrying", 1) + agent = tree_slice(agents, 1) + action = is_valid_action(grid, agent, action) + assert action == Action.NOOP.value + + +def test_robot_warehouse_utils__get_agent_view( + fake_robot_warehouse_env_state: State, +) -> None: + """Test extracting the agent's view of other agents and shelves within + its receptive field as set via a given sensor range.""" + state = fake_robot_warehouse_env_state + grid = jnp.array( + [ + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 0, 0, 0, 0, 3, 4, 0], + [0, 5, 6, 0, 0, 0, 0, 7, 8, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 2, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + agents = 
state.agents + agents = update_agent(agents, 0, "position", Position(1, 6)) + agent = tree_slice(agents, 0) + + # get agent view with sensor range of 1 + sensor_range = 1 + agent_view_of_agents, agent_view_of_shelves = get_agent_view( + grid, agent, sensor_range + ) + + # flattened agent view of other agents and shelves + flat_agents = jnp.array([0, 0, 0, 0, 1, 2, 0, 0, 0]) + flat_shelves = jnp.array([0, 0, 0, 0, 0, 3, 0, 0, 7]) + + assert jnp.array_equal(agent_view_of_agents, flat_agents) + assert jnp.array_equal(agent_view_of_shelves, flat_shelves) + + # get agent view with sensor range of 2 + sensor_range = 2 + agent_view_of_agents, agent_view_of_shelves = get_agent_view( + grid, agent, sensor_range + ) + + # flattened agent view of other agents and shelves + flat_agents = jnp.array( + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ) + flat_shelves = jnp.array( + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 7, 8, 0, 0, 0, 0, 0] + ) + + assert jnp.array_equal(agent_view_of_agents, flat_agents) + assert jnp.array_equal(agent_view_of_shelves, flat_shelves) + + +def test_robot_warehouse_utils__calculate_num_observation_features() -> None: + """Test the calculation of the size of the agent's observation + vector based on sensor range.""" + sensor_range = 1 + num_obs_features = calculate_num_observation_features(sensor_range) + assert num_obs_features == 66 + + sensor_range = 2 + num_obs_features = calculate_num_observation_features(sensor_range) + assert num_obs_features == 178 + + +def test_robot_warehouse_utils__observation_writer( + fake_robot_warehouse_env_state: State, +) -> None: + """Test observation writer to write data to 1-d observation vector. + Note that this test does not construct a full observation vector. 
It + only tests basic functionality by writing the agent's view of itself + and does not include writing agent view data from other agents/shelves.""" + state = fake_robot_warehouse_env_state + agents = state.agents + agent = tree_slice(agents, 0) + + # write flattened observation just for the agent's own view + obs = jnp.zeros(8, dtype=jnp.int32) + idx = 0 + + # write current agent position and whether carrying a shelf or not + obs, idx = write_to_observation( + obs, + idx, + jnp.array( + [agent.position.x, agent.position.y, agent.is_carrying], + dtype=jnp.int32, + ), + ) + + # write current agent direction + direction = jax.nn.one_hot(agent.direction, 4, dtype=jnp.int32) + obs, idx = write_to_observation(obs, idx, direction) + + # move index by one (keeping zero to indicate agent not on highway) + idx = move_writer_index(idx, 1) + + assert jnp.array_equal(obs, jnp.array([3, 4, 0, 0, 0, 1, 0, 0])) + assert idx == 8 + + +def test_robot_warehouse_utils__compute_action_mask( + fake_robot_warehouse_env_state: State, +) -> None: + state = fake_robot_warehouse_env_state + grid = state.grid + agents = state.agents + + action_mask = compute_action_mask(grid, agents) + assert jnp.array_equal(action_mask[1], jnp.array([1, 1, 1, 1, 1])) + + # Let agent 2 turn around, pick up shelf and move forward + # to test collision with shelf when carrying + # which is an illegal action + agents = update_agent(agents, 1, "direction", 1) + agents = update_agent(agents, 1, "is_carrying", 1) + + action_mask = compute_action_mask(grid, agents) + assert jnp.array_equal(action_mask[1], jnp.array([1, 0, 1, 1, 1])) + + +def test_robot_warehouse_utils__get_valid_action( + fake_robot_warehouse_env_state: State, +) -> None: + state = fake_robot_warehouse_env_state + grid = state.grid + agents = state.agents + actions = jnp.array([1, 1]) # forward + + action_mask = compute_action_mask(grid, agents) + actions = get_valid_actions(actions, action_mask) + jax.debug.print("action, {a}", a=actions) + 
assert jnp.array_equal(actions, jnp.array([1, 1])) + + # Let agent 2 turn around, pick up shelf and move forward + # to test collision with shelf when carrying + # which is an illegal action + agents = update_agent(agents, 1, "direction", 1) + agents = update_agent(agents, 1, "is_carrying", 1) + + action_mask = compute_action_mask(grid, agents) + actions = get_valid_actions(actions, action_mask) + assert jnp.array_equal(actions, jnp.array([1, 0])) # turn into noop action diff --git a/jumanji/environments/routing/robot_warehouse/viewer.py b/jumanji/environments/routing/robot_warehouse/viewer.py new file mode 100644 index 000000000..5ba135748 --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/viewer.py @@ -0,0 +1,320 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# flake8: noqa: CCR001 + +from typing import Callable, Optional, Sequence, Tuple + +import chex +import matplotlib.animation as animation +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.collections import LineCollection +from numpy.typing import NDArray + +import jumanji +import jumanji.environments.routing.robot_warehouse.constants as constants +from jumanji.environments.routing.robot_warehouse.types import Direction, State +from jumanji.tree_utils import tree_slice +from jumanji.viewer import Viewer + + +class RobotWarehouseViewer(Viewer): + def __init__( + self, + grid_size: Tuple[int, int], + goals: chex.Array, + name: str = "RobotWarehouse", + render_mode: str = "human", + ) -> None: + """Viewer for the RobotWarehouse environment. + + Args: + grid_size: the size of the warehouse floor grid (width, height) + goals: x,y coordinates of goal locations (where shelves + should be delivered) + name: custom name for the Viewer. Defaults to `RobotWarehouse`. + """ + self._name = name + self.goals = goals + self.rows, self.cols = grid_size + + self.grid_size = 30 + self.icon_size = 20 + + self.width = 1 + self.cols * (self.grid_size + 1) + self.height = 1 + self.rows * (self.grid_size + 1) + self._display: Callable[[plt.Figure], Optional[NDArray]] + if render_mode == "rgb_array": + self._display = self._display_rgb_array + elif render_mode == "human": + self._display = self._display_human + else: + raise ValueError(f"Invalid render mode: {render_mode}") + + # The animation must be stored in a variable that lives as long as the + # animation should run. Otherwise, the animation will get garbage-collected. + self._animation: Optional[animation.Animation] = None + + def render(self, state: State) -> Optional[NDArray]: + """Render the given state of the `RobotWarehouse` environment. + + Args: + state: the environment state to render. 
+ """ + self._clear_display() + fig, ax = self._get_fig_ax() + ax.clear() + self._prepare_figure(ax) + self._draw_state(ax, state) + return self._display(fig) + + def animate( + self, + states: Sequence[State], + interval: int = 200, + save_path: Optional[str] = None, + ) -> animation.FuncAnimation: + """Create an animation from a sequence of environment states. + + Args: + states: sequence of environment states corresponding to consecutive timesteps. + interval: delay between frames in milliseconds, default to 200. + save_path: the path where the animation file should be saved. If it is None, the plot + will not be saved. + + Returns: + Animation that can be saved as a GIF, MP4, or rendered with HTML. + """ + fig = plt.figure(f"{self._name}Animation", figsize=constants._FIGURE_SIZE) + fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0) + ax = fig.add_subplot(111) + plt.close(fig) + self._prepare_figure(ax) + + def make_frame(state: State) -> None: + ax.clear() + self._prepare_figure(ax) + self._draw_state(ax, state) + + # Create the animation object. + self._animation = animation.FuncAnimation( + fig, + make_frame, + frames=states, + interval=interval, + ) + + # Save the animation as a gif. 
+ if save_path: + self._animation.save(save_path) + + return self._animation + + def close(self) -> None: + plt.close(self._name) + + def _clear_display(self) -> None: + if jumanji.environments.is_colab(): + import IPython.display + + IPython.display.clear_output(True) + + def _get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]: + recreate = not plt.fignum_exists(self._name) + fig = plt.figure(self._name, figsize=constants._FIGURE_SIZE) + fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0) + + if recreate: + fig.tight_layout() + if not plt.isinteractive(): + fig.show() + ax = fig.add_subplot(111) + else: + ax = fig.get_axes()[0] + return fig, ax + + def _prepare_figure(self, ax: plt.Axes) -> None: + ax.set_xlim(0, self.width) + ax.set_ylim(0, self.height) + ax.patch.set_alpha(0.0) + ax.set_axis_off() + + ax.set_aspect("equal", "box") + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_frame_on(False) + + def _draw_state(self, ax: plt.Axes, state: State) -> None: + self.n_agents = state.agents.position.x.shape[0] + self.n_shelves = state.shelves.position.x.shape[0] + self._draw_grid(ax) + self._draw_goals(ax) + self._draw_shelves(ax, state.shelves) + self._draw_agents(ax, state.agents) + + def _draw_grid(self, ax: plt.Axes) -> None: + """Draw grid of warehouse floor.""" + lines = [] + # HORIZONTAL LINES (one per row boundary, constant y) + for r in range(self.rows + 1): + lines.append( + [ + (0, (self.grid_size + 1) * r + 1), + ((self.grid_size + 1) * self.cols, (self.grid_size + 1) * r + 1), + ] + ) + + # VERTICAL LINES (one per column boundary, constant x) + for c in range(self.cols + 1): + lines.append( + [ + ((self.grid_size + 1) * c + 1, 0), + ((self.grid_size + 1) * c + 1, (self.grid_size + 1) * self.rows), + ] + ) + + lc = LineCollection(lines, colors=(constants._GRID_COLOR,)) + ax.add_collection(lc) + + def _draw_goals(self, ax: plt.Axes) -> None: + """Draw goals, i.e.
positions where shelves should be delivered.""" + for goal in self.goals: + x, y = goal + y = self.rows - y - 1 # pyglet rendering is reversed + ax.fill( # changed to ax, from plt, check if still works! + [ + x * (self.grid_size + 1) + 1, + (x + 1) * (self.grid_size + 1), + (x + 1) * (self.grid_size + 1), + x * (self.grid_size + 1) + 1, + ], + [ + y * (self.grid_size + 1) + 1, + y * (self.grid_size + 1) + 1, + (y + 1) * (self.grid_size + 1), + (y + 1) * (self.grid_size + 1), + ], + color=constants._GOAL_COLOR, + alpha=1, + ) + + def _draw_shelves(self, ax: plt.Axes, shelves: chex.Array) -> None: + """Draw shelves at their respective positions. + + Args: + shelves: a pytree of Shelf type containing shelves information. + """ + for shelf_id in range(self.n_shelves): + shelf = tree_slice(shelves, shelf_id) + y, x = shelf.position.x, shelf.position.y + y = self.rows - y - 1 # pyglet rendering is reversed + shelf_color = ( + constants._SHELF_REQ_COLOR + if shelf.is_requested + else constants._SHELF_COLOR + ) + shelf_padding = constants._SHELF_PADDING + + x_points = [ + (self.grid_size + 1) * x + shelf_padding + 1, + (self.grid_size + 1) * (x + 1) - shelf_padding, + (self.grid_size + 1) * (x + 1) - shelf_padding, + (self.grid_size + 1) * x + shelf_padding + 1, + ] + + y_points = [ + (self.grid_size + 1) * y + shelf_padding + 1, + (self.grid_size + 1) * y + shelf_padding + 1, + (self.grid_size + 1) * (y + 1) - shelf_padding, + (self.grid_size + 1) * (y + 1) - shelf_padding, + ] + + ax.fill(x_points, y_points, color=shelf_color) + + def _draw_agents(self, ax: plt.Axes, agents: chex.Array) -> None: + """Draw agents at their respective positions. + + Args: + agents: a pytree of Agent type containing agents information.
+ """ + radius = self.grid_size / 3 + + resolution = 6 + + for agent_id in range(self.n_agents): + agent = tree_slice(agents, agent_id) + row, col = agent.position.x, agent.position.y + row = self.rows - row - 1 # pyglet rendering is reversed + x_center = (self.grid_size + 1) * col + self.grid_size // 2 + 1 + y_center = (self.grid_size + 1) * row + self.grid_size // 2 + 1 + + # make a circle + verts = [] + for i in range(resolution): + angle = 2 * np.pi * i / resolution + + x_radius = radius * np.cos(angle) + x = x_radius + x_center + 1 + + y_radius = radius * np.sin(angle) + 1 + y = y_radius + y_center + verts += [[x, y]] + facecolor = ( + constants._AGENT_LOADED_COLOR + if agent.is_carrying + else constants._AGENT_COLOR + ) + circle = plt.Polygon( + verts, + edgecolor="none", + facecolor=facecolor, + ) + + ax.add_patch(circle) + + agent_dir = agent.direction + + x_dir = ( + x_center + + (radius if agent_dir == Direction.RIGHT.value else 0) + - (radius if agent_dir == Direction.LEFT.value else 0) + ) + y_dir = ( + y_center + + (radius if agent_dir == Direction.UP.value else 0) + - (radius if agent_dir == Direction.DOWN.value else 0) + ) + + ax.plot( + [x_center, x_dir], + [y_center, y_dir], + color=constants._AGENT_DIR_COLOR, + linewidth=2, + ) + + def _display_human(self, fig: plt.Figure) -> None: + if plt.isinteractive(): + # Required to update render when using Jupyter Notebook. + fig.canvas.draw() + if jumanji.environments.is_colab(): + plt.show(self._name) + else: + # Required to update render when not using Jupyter Notebook. 
+ fig.canvas.draw_idle() + fig.canvas.flush_events() + + def _display_rgb_array(self, fig: plt.Figure) -> NDArray: + fig.canvas.draw() + return np.asarray(fig.canvas.buffer_rgba()) diff --git a/jumanji/environments/routing/robot_warehouse/viewer_test.py b/jumanji/environments/routing/robot_warehouse/viewer_test.py new file mode 100644 index 000000000..e4416db29 --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/viewer_test.py @@ -0,0 +1,82 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import jax +import jax.random as random +import matplotlib +import matplotlib.pyplot as plt +import numpy as jnp +import py +import pytest + +from jumanji.environments.routing.robot_warehouse import RobotWarehouse +from jumanji.environments.routing.robot_warehouse.viewer import RobotWarehouseViewer + + +def test_robot_warehouse_viewer__render( + robot_warehouse_env: RobotWarehouse, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.setattr(plt, "show", lambda fig: None) + key = random.PRNGKey(0) + state, _ = robot_warehouse_env.reset(key) + grid_size = robot_warehouse_env._generator.grid_size + goals = robot_warehouse_env._generator.goals + + viewer = RobotWarehouseViewer(grid_size, goals) + viewer.render(state) + viewer.close() + + +def test_robot_warehouse_viewer__animate(robot_warehouse_env: RobotWarehouse) -> None: + key = random.PRNGKey(0) + state, _ = jax.jit(robot_warehouse_env.reset)(key) + grid_size = robot_warehouse_env._generator.grid_size + goals = robot_warehouse_env._generator.goals + + num_steps = 5 + states = [state] + for _ in range(num_steps - 1): + key, subkey = jax.random.split(key) + action = jax.random.choice(subkey, jnp.arange(5), shape=(2,)) + state, _ = jax.jit(robot_warehouse_env.step)(state, action) + states.append(state) + + viewer = RobotWarehouseViewer(grid_size, goals) + viewer.animate(states) + viewer.close() + + +def test_robot_warehouse_viewer__save_animation( + robot_warehouse_env: RobotWarehouse, tmpdir: py.path.local +) -> None: + key = random.PRNGKey(0) + state, _ = jax.jit(robot_warehouse_env.reset)(key) + grid_size = robot_warehouse_env._generator.grid_size + goals = robot_warehouse_env._generator.goals + + num_steps = 5 + states = [state] + for _ in range(num_steps - 1): + key, subkey = jax.random.split(key) + action = jax.random.choice(subkey, jnp.arange(5), shape=(2,)) + state, _ = jax.jit(robot_warehouse_env.step)(state, action) + states.append(state) + + viewer = RobotWarehouseViewer(grid_size, goals) + 
animation = viewer.animate(states) + assert isinstance(animation, matplotlib.animation.Animation) + + save_path = str(tmpdir.join("/robot_warehouse_animation_test.gif")) + animation.save(save_path) + viewer.close() diff --git a/jumanji/training/configs/config.yaml b/jumanji/training/configs/config.yaml index 79ee4e830..47a02c704 100644 --- a/jumanji/training/configs/config.yaml +++ b/jumanji/training/configs/config.yaml @@ -1,6 +1,6 @@ defaults: - _self_ - - env: snake # [bin_pack, cleaner, connector, cvrp, game_2048, job_shop, knapsack, maze, minesweeper, rubiks_cube, snake, sudoku, tsp] + - env: snake # [bin_pack, cleaner, connector, cvrp, game_2048, job_shop, knapsack, maze, minesweeper, rubiks_cube, robot_warehouse, snake, sudoku, tsp] agent: random # [random, a2c] diff --git a/jumanji/training/configs/env/graph_coloring.yaml b/jumanji/training/configs/env/graph_coloring.yaml new file mode 100644 index 000000000..d6e73fdec --- /dev/null +++ b/jumanji/training/configs/env/graph_coloring.yaml @@ -0,0 +1,27 @@ +name: graph_coloring +registered_version: GraphColoring-v0 + +network: + num_transformer_layers: 2 + transformer_num_heads: 8 + transformer_key_size: 16 + transformer_mlp_units: [512] + +training: + num_epochs: 500 + num_learner_steps_per_epoch: 100 + n_steps: 20 + total_batch_size: 64 + +evaluation: + eval_total_batch_size: 50000 + greedy_eval_total_batch_size: 50000 + +a2c: + normalize_advantage: False + discount_factor: 1.0 + bootstrapping_factor: 0.95 + l_pg: 1.0 + l_td: 1.0 + l_en: 1e-2 + learning_rate: 1e-4 diff --git a/jumanji/training/configs/env/rware.yaml b/jumanji/training/configs/env/rware.yaml new file mode 100644 index 000000000..7d52c8f01 --- /dev/null +++ b/jumanji/training/configs/env/rware.yaml @@ -0,0 +1,27 @@ +name: robot_warehouse +registered_version: RobotWarehouse-v0 + +network: + transformer_num_blocks: 4 + transformer_num_heads: 8 + transformer_key_size: 16 + transformer_mlp_units: [512] + +training: + num_epochs: 500 + 
num_learner_steps_per_epoch: 100 + n_steps: 20 + total_batch_size: 128 + +evaluation: + eval_total_batch_size: 5000 + greedy_eval_total_batch_size: 5000 + +a2c: + normalize_advantage: False + discount_factor: 0.99 + bootstrapping_factor: 0.95 + l_pg: 1.0 + l_td: 1.0 + l_en: 0.01 + learning_rate: 1e-4 diff --git a/jumanji/training/networks/__init__.py b/jumanji/training/networks/__init__.py index c13c99562..1ed0cd476 100644 --- a/jumanji/training/networks/__init__.py +++ b/jumanji/training/networks/__init__.py @@ -34,6 +34,12 @@ make_actor_critic_networks_game_2048, ) from jumanji.training.networks.game_2048.random import make_random_policy_game_2048 +from jumanji.training.networks.graph_coloring.actor_critic import ( + make_actor_critic_networks_graph_coloring, +) +from jumanji.training.networks.graph_coloring.random import ( + make_random_policy_graph_coloring, +) from jumanji.training.networks.job_shop.actor_critic import ( make_actor_critic_networks_job_shop, ) @@ -48,6 +54,12 @@ make_actor_critic_networks_minesweeper, ) from jumanji.training.networks.minesweeper.random import make_random_policy_minesweeper +from jumanji.training.networks.robot_warehouse.actor_critic import ( + make_actor_critic_networks_robot_warehouse, +) +from jumanji.training.networks.robot_warehouse.random import ( + make_random_policy_robot_warehouse, +) from jumanji.training.networks.rubiks_cube.actor_critic import ( make_actor_critic_networks_rubiks_cube, ) diff --git a/jumanji/training/networks/graph_coloring/__init__.py b/jumanji/training/networks/graph_coloring/__init__.py new file mode 100644 index 000000000..21db9ec1c --- /dev/null +++ b/jumanji/training/networks/graph_coloring/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jumanji/training/networks/graph_coloring/actor_critic.py b/jumanji/training/networks/graph_coloring/actor_critic.py new file mode 100644 index 000000000..2833061c0 --- /dev/null +++ b/jumanji/training/networks/graph_coloring/actor_critic.py @@ -0,0 +1,277 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from typing import Optional, Sequence + +import chex +import haiku as hk +import jax.numpy as jnp + +from jumanji.environments.logic.graph_coloring import GraphColoring, Observation +from jumanji.training.networks.actor_critic import ( + ActorCriticNetworks, + FeedForwardNetwork, +) +from jumanji.training.networks.parametric_distribution import ( + CategoricalParametricDistribution, +) +from jumanji.training.networks.transformer_block import TransformerBlock + + +def make_actor_critic_networks_graph_coloring( + graph_coloring: GraphColoring, + num_transformer_layers: int, + transformer_num_heads: int, + transformer_key_size: int, + transformer_mlp_units: Sequence[int], +) -> ActorCriticNetworks: + """Make actor-critic networks for the `GraphColoring` environment.""" + num_actions = graph_coloring.action_spec().num_values + parametric_action_distribution = CategoricalParametricDistribution( + num_actions=num_actions + ) + policy_network = make_actor_network_graph_coloring( + num_actions=num_actions, + num_transformer_layers=num_transformer_layers, + transformer_num_heads=transformer_num_heads, + transformer_key_size=transformer_key_size, + transformer_mlp_units=transformer_mlp_units, + ) + value_network = make_critic_network_graph_coloring( + num_actions=num_actions, + num_transformer_layers=num_transformer_layers, + transformer_num_heads=transformer_num_heads, + transformer_key_size=transformer_key_size, + transformer_mlp_units=transformer_mlp_units, + ) + return ActorCriticNetworks( + policy_network=policy_network, + value_network=value_network, + parametric_action_distribution=parametric_action_distribution, + ) + + +class GraphColoringTorso(hk.Module): + def __init__( + self, + num_transformer_layers: int, + transformer_num_heads: int, + transformer_key_size: int, + transformer_mlp_units: Sequence[int], + name: Optional[str] = None, + ): + super().__init__(name=name) + self.num_transformer_layers = num_transformer_layers + self.transformer_num_heads = 
transformer_num_heads + self.transformer_key_size = transformer_key_size + self.transformer_mlp_units = transformer_mlp_units + self.model_size = transformer_num_heads * transformer_key_size + + def _make_self_attention_mask(self, adj_matrix: chex.Array) -> chex.Array: + # Expand on the head dimension. + mask = jnp.expand_dims(adj_matrix, axis=-3) + return mask + + def __call__(self, observation: Observation) -> chex.Array: + """Transforms the observation using a series of transformations. + + The observation is composed of the following components: + - observation.action_mask: Represents the colors that are available for the current node. + - observation.colors: Represents the colors assigned to each node. + Nodes without an assigned color are marked with -1. + - observation.current_node_index: Represents the node currently being considered. + + The function first determines which colors are used and which nodes are colored. + Then it embeds the colors and nodes, and creates a mask for the adjacency matrix. + It further creates two masks to track the relation between colors and nodes. + Then the function applies a series of transformer blocks to compute: + self-attention on nodes and colors and the cross-attention between nodes and colors. + Finally, it extracts the embedding for the current node and computes a new embedding. + + Args: + observation: the observation to be transformed. + + Returns: + new_embedding: the transformed observation. 
+ """ + + batch_size, num_nodes = observation.colors.shape + colors_used = jnp.isin(observation.colors, jnp.arange(num_nodes)) + color_embeddings = hk.Linear(self.model_size)( + colors_used[..., None].astype(float) + ) # Shape (batch_size, num_colors, model_size) + + nodes_colored = observation.colors >= 0 + + color_embeddings = hk.Linear(self.model_size)( + colors_used[..., None].astype(float) + ) # Shape (batch_size, num_colors, model_size) + + node_embeddings = hk.Linear(self.model_size)( + nodes_colored[..., None].astype(float) + ) # Shape (batch_size, num_nodes, model_size) + + mask = self._make_self_attention_mask( + observation.adj_matrix + ) # Shape (batch_size, 1, num_nodes, num_nodes) + + colors_array = jnp.expand_dims( + observation.colors, axis=1 + ) # Shape (batch_size, 1, num_nodes) + color_indices = jnp.arange(observation.colors.shape[1]) # Shape (num_colors,) + color_indices = color_indices[None, :, None] # Shape (1, num_colors, 1) + + colors_cross_nodes_mask = ( + colors_array == color_indices + ) # Shape (batch_size, num_colors, num_nodes) + + # Expand along the transformer_num_heads axis + colors_cross_nodes_mask = jnp.expand_dims( + colors_cross_nodes_mask, axis=1 + ) # Shape (batch_size, 1, num_colors, num_nodes) + + colors_array = jnp.expand_dims( + observation.colors, axis=-1 + ) # Shape (batch_size, num_nodes, 1) + color_indices = jnp.arange(observation.colors.shape[1]) # Shape (num_colors,) + color_indices = color_indices[None, None] # Shape (1, 1, num_colors) + + nodes_cross_colors_mask = ( + colors_array == color_indices + ) # Shape (batch_size, num_nodes, num_colors) + + # Expand along the transformer_num_heads axis + nodes_cross_colors_mask = jnp.expand_dims( + nodes_cross_colors_mask, axis=1 + ) # Shape (batch_size, 1, num_nodes, num_colors) + + for block_id in range(self.num_transformer_layers): + # Self-attention on nodes.
+ node_embeddings = TransformerBlock( + num_heads=self.transformer_num_heads, + key_size=self.transformer_key_size, + mlp_units=self.transformer_mlp_units, + w_init_scale=2 / self.num_transformer_layers, + model_size=self.model_size, + name=f"self_attention_nodes_block_{block_id}", + )(node_embeddings, node_embeddings, node_embeddings, mask) + + # Self-attention on colors. + color_embeddings = TransformerBlock( + num_heads=self.transformer_num_heads, + key_size=self.transformer_key_size, + mlp_units=self.transformer_mlp_units, + w_init_scale=2 / self.num_transformer_layers, + model_size=self.model_size, + name=f"self_attention_colors_block_{block_id}", + )(color_embeddings, color_embeddings, color_embeddings) + + # Cross-attention between nodes and colors. + new_node_embeddings = TransformerBlock( + num_heads=self.transformer_num_heads, + key_size=self.transformer_key_size, + mlp_units=self.transformer_mlp_units, + w_init_scale=2 / self.num_transformer_layers, + model_size=self.model_size, + name=f"cross_attention_node_color_block_{block_id}", + )( + node_embeddings, + color_embeddings, + color_embeddings, + nodes_cross_colors_mask, + ) + + # Cross-attention between colors and nodes. 
+ color_embeddings = TransformerBlock( + num_heads=self.transformer_num_heads, + key_size=self.transformer_key_size, + mlp_units=self.transformer_mlp_units, + w_init_scale=2 / self.num_transformer_layers, + model_size=self.model_size, + name=f"cross_attention_color_node_block_{block_id}", + )( + color_embeddings, + node_embeddings, + node_embeddings, + colors_cross_nodes_mask, + ) + + node_embeddings = new_node_embeddings + + current_node_embeddings = jnp.take( + node_embeddings, observation.current_node_index, axis=1 + ) + new_embedding = TransformerBlock( + num_heads=self.transformer_num_heads, + key_size=self.transformer_key_size, + mlp_units=self.transformer_mlp_units, + w_init_scale=2 / self.num_transformer_layers, + model_size=self.model_size, + name=f"cross_attention_color_node_block_{block_id+1}", + )(color_embeddings, current_node_embeddings, current_node_embeddings) + + return new_embedding + + +def make_actor_network_graph_coloring( + num_actions: int, + num_transformer_layers: int, + transformer_num_heads: int, + transformer_key_size: int, + transformer_mlp_units: Sequence[int], +) -> FeedForwardNetwork: + def network_fn(observation: Observation) -> chex.Array: + torso = GraphColoringTorso( + num_transformer_layers=num_transformer_layers, + transformer_num_heads=transformer_num_heads, + transformer_key_size=transformer_key_size, + transformer_mlp_units=transformer_mlp_units, + name="policy_torso", + ) + embeddings = torso(observation) # (B, N, H) + logits = hk.nets.MLP((torso.model_size, 1), name="policy_head")( + embeddings + ) # (B, N, 1) + logits = jnp.squeeze(logits, axis=-1) # (B, N) + logits = jnp.where(observation.action_mask, logits, jnp.finfo(jnp.float32).min) + return logits + + init, apply = hk.without_apply_rng(hk.transform(network_fn)) + return FeedForwardNetwork(init=init, apply=apply) + + +def make_critic_network_graph_coloring( + num_actions: int, + num_transformer_layers: int, + transformer_num_heads: int, + transformer_key_size: int, + 
transformer_mlp_units: Sequence[int], +) -> FeedForwardNetwork: + def network_fn(observation: Observation) -> chex.Array: + torso = GraphColoringTorso( + num_transformer_layers=num_transformer_layers, + transformer_num_heads=transformer_num_heads, + transformer_key_size=transformer_key_size, + transformer_mlp_units=transformer_mlp_units, + name="critic_torso", + ) + embeddings = torso(observation) + + embedding = jnp.mean(embeddings, axis=-2) + value = hk.nets.MLP((torso.model_size, 1), name="critic_head")(embedding) + return jnp.squeeze(value, axis=-1) + + init, apply = hk.without_apply_rng(hk.transform(network_fn)) + return FeedForwardNetwork(init=init, apply=apply) diff --git a/jumanji/training/networks/graph_coloring/random.py b/jumanji/training/networks/graph_coloring/random.py new file mode 100644 index 000000000..17e9ce89f --- /dev/null +++ b/jumanji/training/networks/graph_coloring/random.py @@ -0,0 +1,23 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from jumanji.training.networks.masked_categorical_random import ( + masked_categorical_random, +) +from jumanji.training.networks.protocols import RandomPolicy + + +def make_random_policy_graph_coloring() -> RandomPolicy: + """Make random policy for GraphColoring.""" + return masked_categorical_random diff --git a/jumanji/training/networks/robot_warehouse/__init__.py b/jumanji/training/networks/robot_warehouse/__init__.py new file mode 100644 index 000000000..21db9ec1c --- /dev/null +++ b/jumanji/training/networks/robot_warehouse/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jumanji/training/networks/robot_warehouse/actor_critic.py b/jumanji/training/networks/robot_warehouse/actor_critic.py new file mode 100644 index 000000000..a1aca10cd --- /dev/null +++ b/jumanji/training/networks/robot_warehouse/actor_critic.py @@ -0,0 +1,168 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Sequence + +import chex +import haiku as hk +import jax.numpy as jnp +import numpy as np + +from jumanji.environments import RobotWarehouse +from jumanji.environments.routing.robot_warehouse.types import Observation +from jumanji.training.networks.actor_critic import ( + ActorCriticNetworks, + FeedForwardNetwork, +) +from jumanji.training.networks.parametric_distribution import ( + MultiCategoricalParametricDistribution, +) +from jumanji.training.networks.transformer_block import TransformerBlock + + +def make_actor_critic_networks_robot_warehouse( + robot_warehouse: RobotWarehouse, + transformer_num_blocks: int, + transformer_num_heads: int, + transformer_key_size: int, + transformer_mlp_units: Sequence[int], +) -> ActorCriticNetworks: + """Make actor-critic networks for the `RobotWarehouse` environment.""" + num_values = np.asarray(robot_warehouse.action_spec().num_values) + parametric_action_distribution = MultiCategoricalParametricDistribution( + num_values=num_values + ) + policy_network = make_actor_network( + time_limit=robot_warehouse.time_limit, + transformer_num_blocks=transformer_num_blocks, + transformer_num_heads=transformer_num_heads, + transformer_key_size=transformer_key_size, + transformer_mlp_units=transformer_mlp_units, + ) + value_network = make_critic_network( + time_limit=robot_warehouse.time_limit, + transformer_num_blocks=transformer_num_blocks, + transformer_num_heads=transformer_num_heads, + transformer_key_size=transformer_key_size, + transformer_mlp_units=transformer_mlp_units, + ) + return ActorCriticNetworks( + policy_network=policy_network, + value_network=value_network, + parametric_action_distribution=parametric_action_distribution, + ) + + +class RobotWarehouseTorso(hk.Module): + def __init__( + self, + transformer_num_blocks: int, + transformer_num_heads: int, + transformer_key_size: int, + 
transformer_mlp_units: Sequence[int], + env_time_limit: int, + name: Optional[str] = None, + ): + super().__init__(name=name) + self.transformer_num_blocks = transformer_num_blocks + self.transformer_num_heads = transformer_num_heads + self.transformer_key_size = transformer_key_size + self.transformer_mlp_units = transformer_mlp_units + self.model_size = transformer_num_heads * transformer_key_size + self.env_time_limit = env_time_limit + + def __call__(self, observation: Observation) -> chex.Array: + # Shape names: + # B: batch size + # N: number of agents + # O: observation size + # H: hidden/embedding size + # (B, N, O) + _, num_agents, _ = observation.agents_view.shape + + percent_done = observation.step_count / self.env_time_limit + step = jnp.repeat(percent_done[:, None], num_agents, axis=-1)[..., None] + agents_view = observation.agents_view + + # join step count and agent view to embed both at the same time + # (B, N, O + 1) + obs = jnp.concatenate((agents_view, step), axis=-1) + # (B, N, O + 1) -> (B, N, H) + embeddings = hk.Linear(self.model_size)(obs) + + # (B, N, H) -> (B, N, H) + for block_id in range(self.transformer_num_blocks): + transformer_block = TransformerBlock( + num_heads=self.transformer_num_heads, + key_size=self.transformer_key_size, + mlp_units=self.transformer_mlp_units, + w_init_scale=2 / self.transformer_num_blocks, + model_size=self.model_size, + name=f"self_attention_block_{block_id}", + ) + embeddings = transformer_block( + query=embeddings, key=embeddings, value=embeddings + ) + return embeddings # (B, N, H) + + +def make_critic_network( + time_limit: int, + transformer_num_blocks: int, + transformer_num_heads: int, + transformer_key_size: int, + transformer_mlp_units: Sequence[int], +) -> FeedForwardNetwork: + def network_fn(observation: Observation) -> chex.Array: + torso = RobotWarehouseTorso( + transformer_num_blocks, + transformer_num_heads, + transformer_key_size, + transformer_mlp_units, + time_limit, + ) + embeddings = 
torso(observation) + embeddings = jnp.sum(embeddings, axis=-2) + + head = hk.nets.MLP((*transformer_mlp_units, 1), activate_final=False) + values = head(embeddings) # (B, 1) + return jnp.squeeze(values, axis=-1) # (B,) + + init, apply = hk.without_apply_rng(hk.transform(network_fn)) + return FeedForwardNetwork(init=init, apply=apply) + + +def make_actor_network( + time_limit: int, + transformer_num_blocks: int, + transformer_num_heads: int, + transformer_key_size: int, + transformer_mlp_units: Sequence[int], +) -> FeedForwardNetwork: + def network_fn(observation: Observation) -> chex.Array: + torso = RobotWarehouseTorso( + transformer_num_blocks, + transformer_num_heads, + transformer_key_size, + transformer_mlp_units, + time_limit, + ) + output = torso(observation) + + head = hk.nets.MLP((*transformer_mlp_units, 5), activate_final=False) + logits = head(output) # (B, N, 5) + return jnp.where(observation.action_mask, logits, jnp.finfo(jnp.float32).min) + + init, apply = hk.without_apply_rng(hk.transform(network_fn)) + return FeedForwardNetwork(init=init, apply=apply) diff --git a/jumanji/training/networks/robot_warehouse/random.py b/jumanji/training/networks/robot_warehouse/random.py new file mode 100644 index 000000000..eb0ba826e --- /dev/null +++ b/jumanji/training/networks/robot_warehouse/random.py @@ -0,0 +1,23 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from jumanji.training.networks.masked_categorical_random import ( + masked_categorical_random, +) +from jumanji.training.networks.protocols import RandomPolicy + + +def make_random_policy_robot_warehouse() -> RandomPolicy: + """Make random policy for RobotWarehouse.""" + return masked_categorical_random diff --git a/jumanji/training/setup_train.py b/jumanji/training/setup_train.py index b64f4f8cd..1f1f9a458 100644 --- a/jumanji/training/setup_train.py +++ b/jumanji/training/setup_train.py @@ -29,10 +29,12 @@ Cleaner, Connector, Game2048, + GraphColoring, JobShop, Knapsack, Maze, Minesweeper, + RobotWarehouse, RubiksCube, Snake, Sudoku, @@ -168,6 +170,12 @@ def _setup_random_policy( # noqa: CCR001 elif cfg.env.name == "connector": assert isinstance(env.unwrapped, Connector) random_policy = networks.make_random_policy_connector() + elif cfg.env.name == "robot_warehouse": + assert isinstance(env.unwrapped, RobotWarehouse) + random_policy = networks.make_random_policy_robot_warehouse() + elif cfg.env.name == "graph_coloring": + assert isinstance(env.unwrapped, GraphColoring) + random_policy = networks.make_random_policy_graph_coloring() else: raise ValueError(f"Environment name not found. 
Got {cfg.env.name}.") return random_policy @@ -258,6 +266,15 @@ def _setup_actor_critic_neworks( # noqa: CCR001 key_size=cfg.env.network.key_size, policy_layers=cfg.env.network.policy_layers, value_layers=cfg.env.network.value_layers, + ) + elif cfg.env.name == "robot_warehouse": + assert isinstance(env.unwrapped, RobotWarehouse) + actor_critic_networks = networks.make_actor_critic_networks_robot_warehouse( + robot_warehouse=env.unwrapped, + transformer_num_blocks=cfg.env.network.transformer_num_blocks, + transformer_num_heads=cfg.env.network.transformer_num_heads, + transformer_key_size=cfg.env.network.transformer_key_size, + transformer_mlp_units=cfg.env.network.transformer_mlp_units, ) elif cfg.env.name == "minesweeper": assert isinstance(env.unwrapped, Minesweeper) @@ -295,6 +312,15 @@ def _setup_actor_critic_neworks( # noqa: CCR001 transformer_mlp_units=cfg.env.network.transformer_mlp_units, conv_n_channels=cfg.env.network.conv_n_channels, ) + elif cfg.env.name == "graph_coloring": + assert isinstance(env.unwrapped, GraphColoring) + actor_critic_networks = networks.make_actor_critic_networks_graph_coloring( + graph_coloring=env.unwrapped, + num_transformer_layers=cfg.env.network.num_transformer_layers, + transformer_num_heads=cfg.env.network.transformer_num_heads, + transformer_key_size=cfg.env.network.transformer_key_size, + transformer_mlp_units=cfg.env.network.transformer_mlp_units, + ) else: raise ValueError(f"Environment name not found. Got {cfg.env.name}.") return actor_critic_networks
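
The two array manipulations that `RobotWarehouseTorso` and `make_actor_network` rely on in the diff above can be checked in isolation. The following is a minimal sketch, not part of the PR: `append_step_fraction` and `mask_logits` are hypothetical helper names, and the shapes (2 environments, 3 agents, 4 observation features, 5 actions) and `time_limit=100` are made up for illustration. It uses only `jax.numpy` and `jax.nn.softmax`, with no Haiku or Jumanji dependency.

```python
import jax
import jax.numpy as jnp


def append_step_fraction(agents_view, step_count, time_limit):
    """Hypothetical helper: (B, N, O) view + (B,) step count -> (B, N, O + 1)."""
    num_agents = agents_view.shape[1]
    percent_done = step_count / time_limit  # (B,)
    # Broadcast the per-environment fraction to every agent: (B,) -> (B, N, 1).
    step = jnp.repeat(percent_done[:, None], num_agents, axis=-1)[..., None]
    return jnp.concatenate((agents_view, step), axis=-1)


def mask_logits(logits, action_mask):
    """Hypothetical helper: replace logits of invalid actions with float32 min."""
    return jnp.where(action_mask, logits, jnp.finfo(jnp.float32).min)


# Illustrative shapes only: B=2 environments, N=3 agents, O=4 features.
agents_view = jnp.zeros((2, 3, 4))
step_count = jnp.array([0, 50])
obs = append_step_fraction(agents_view, step_count, time_limit=100)

# Uniform logits over 5 actions, with action 0 marked invalid for every agent.
logits = jnp.zeros((2, 3, 5))
action_mask = jnp.ones((2, 3, 5), dtype=bool).at[:, :, 0].set(False)
probs = jax.nn.softmax(mask_logits(logits, action_mask), axis=-1)

print(obs.shape)    # (2, 3, 5)
print(probs[0, 0])  # invalid action gets probability 0, the rest share it equally
```

Filling masked entries with the most negative `float32` value rather than `-inf` keeps every logit finite while still driving the softmax probability of invalid actions to zero; whether that was the authors' motivation is an assumption.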