diff --git a/README.md b/README.md index 8ef413b7d..42b76f364 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,6 @@ | [**Docs**](https://instadeepai.github.io/jumanji) --- -
@@ -32,8 +31,6 @@
- - ## Welcome to the Jungle! π΄ Jumanji is a suite of diverse and challenging reinforcement learning (RL) environments written in @@ -70,7 +67,6 @@ JAX-based environments. - ποΈ **Training:** example agents that can be used as inspiration for the agents one may implement in their research. - ## Environments π Jumanji provides a diverse range of environments ranging from simple games to NP-hard combinatorial @@ -88,20 +84,24 @@ problems. | :link: Connector | Routing | `Connector-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/connector/) | [doc](https://instadeepai.github.io/jumanji/environments/connector/) | | π CVRP (Capacitated Vehicle Routing Problem) | Routing | `CVRP-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/cvrp/) | [doc](https://instadeepai.github.io/jumanji/environments/cvrp/) | | :mag: Maze | Routing | `Maze-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/maze/) | [doc](https://instadeepai.github.io/jumanji/environments/maze/) | +| :robot: RobotWarehouse | Routing | `RobotWarehouse-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/robot_warehouse/) | [doc](https://instadeepai.github.io/jumanji/environments/robot_warehouse/) | | π Snake | Routing | `Snake-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/snake/) | [doc](https://instadeepai.github.io/jumanji/environments/snake/) | | π¬ TSP (Travelling Salesman Problem) | Routing | `TSP-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/tsp/) | [doc](https://instadeepai.github.io/jumanji/environments/tsp/) | - ## Installation π¬ You can install the latest release of Jumanji from PyPI: + ```bash pip install jumanji ``` + Alternatively, you can install the latest development version directly from GitHub: + ```bash pip install git+https://github.com/instadeepai/jumanji.git ``` + Jumanji has been tested on Python 3.8 and 3.9. Note that because the installation of JAX differs depending on your hardware accelerator, we advise users to explicitly install the correct JAX version (see the @@ -113,7 +113,6 @@ you will need a GUI backend. For example, on Linux, you can install Tk via: [Matplotlib backends](https://matplotlib.org/stable/users/explain/backends.html) for a list of backends you can use. - ## Quickstart β‘ RL practitioners will find Jumanji's interface familiar as it combines the widely adopted @@ -170,7 +169,6 @@ the version number is incremented by one to prevent potential confusion. For a full list of registered versions of each environment, check out [the documentation](https://instadeepai.github.io/jumanji/environments/tsp/). - ## Training ποΈ To showcase how to train RL agents on Jumanji environments, we provide a random agent and a vanilla @@ -191,7 +189,6 @@ actor-critic networks in For more information on how to use the example agents, see the [training guide](https://instadeepai.github.io/jumanji/guides/training/). - ## Contributing π€ Contributions are welcome! See our issue tracker for @@ -199,10 +196,10 @@ Contributions are welcome! See our issue tracker for our [contributing guidelines](https://github.com/instadeepai/jumanji/blob/main/CONTRIBUTING.md) for details on how to submit pull requests, our Contributor License Agreement, and community guidelines. - ## Citing Jumanji βοΈ If you use Jumanji in your work, please cite the library using: + ``` @software{jumanji2023github, author = {ClΓ©ment Bonnet and Daniel Luo and Donal Byrne and Sasha Abramowitz @@ -216,7 +213,6 @@ If you use Jumanji in your work, please cite the library using: } ``` - ## See Also π Other works have embraced the approach of writing RL environments in JAX. diff --git a/docs/api/environments/rware.md b/docs/api/environments/rware.md new file mode 100644 index 000000000..29737c421 --- /dev/null +++ b/docs/api/environments/rware.md @@ -0,0 +1,8 @@ +::: jumanji.environments.routing.robot_warehouse.env.RobotWarehouse + selection: + members: + - __init__ + - reset + - step + - observation_spec + - action_spec diff --git a/docs/env_anim/rware.gif b/docs/env_anim/rware.gif new file mode 100644 index 000000000..d3b0baf84 Binary files /dev/null and b/docs/env_anim/rware.gif differ diff --git a/docs/env_img/rware.png b/docs/env_img/rware.png new file mode 100644 index 000000000..c1fcc8146 Binary files /dev/null and b/docs/env_img/rware.png differ diff --git a/docs/environments/robot_warehouse.md b/docs/environments/robot_warehouse.md new file mode 100644 index 000000000..799e8d442 --- /dev/null +++ b/docs/environments/robot_warehouse.md @@ -0,0 +1,46 @@ +# RobotWarehouse Environment + ++ +
+ +We provide a JAX jit-able implementation of the [Robotic Warehouse](https://github.com/semitable/robotic-warehouse/tree/master) +environment. + +The Robot Warehouse (RWARE) environment simulates a warehouse with robots moving and delivering requested goods. Real-world applications inspire the simulator, in which robots pick up shelves and deliver them to a workstation. Humans access the content of a shelf, and then robots can return them to empty shelf locations. + +The goal is to successfully deliver as many requested shelves in a given time budget. + +Once a shelf has been delivered, a new shelf is requested at random. Agents start each episode at random locations within the warehouse. + +## Observation + +The **observation** seen by the agent is a `NamedTuple` containing the following: + +- `agents_view`: jax array (int32) of shape `(num_agents, num_obs_features)`, array representing the agent's view of other agents + and shelves. + +- `action_mask`: jax array (bool) of shape `(num_agents, 5)`, array specifying, for each agent, + which action (noop, forward, left, right, toggle_load) is legal. + +- `step_count`: jax array (int32) of shape `()`, number of steps elapsed in the current episode. + +## Action + +The action space is a `MultiDiscreteArray` containing an integer value in `[0, 1, 2, 3, 4]` for each +agent. Each agent can take one of five actions: noop (`0`), forward (`1`), turn left (`2`), turn right (`3`), or toggle_load (`4`). + +The episode terminates under the following conditions: + +- An invalid action is taken, or + +- An agent collides with another agent. + +## Reward + +The reward is global and shared among the agents. It is equal to the number of shelves which were +delivered successfully during the time step (i.e., +1 for each shelf). + +## Registered Versions π + +- `RobotWarehouse-v0`, a warehouse with 4 agents each with a sensor range of 1, a warehouse floor with 2 shelf rows, 3 shelf columns, a column height of 8, and a shelf request queue of 8. diff --git a/jumanji/__init__.py b/jumanji/__init__.py index d717a8664..2c2744a35 100644 --- a/jumanji/__init__.py +++ b/jumanji/__init__.py @@ -77,6 +77,10 @@ # Maze with 10 rows and 10 columns, a time limit of 100 and a random maze generator. register(id="Maze-v0", entry_point="jumanji.environments:Maze") +# RobotWarehouse with a random generator with 2 shelf rows, 3 shelf columns, a column height of 8, +# 4 agents, a sensor range of 1, and a request queue of size 8. +register(id="RobotWarehouse-v0", entry_point="jumanji.environments:RobotWarehouse") + # Snake game on a board of size 12x12 with a time limit of 4000. register(id="Snake-v1", entry_point="jumanji.environments:Snake") diff --git a/jumanji/environments/__init__.py b/jumanji/environments/__init__.py index 68e66834b..d8d8402b8 100644 --- a/jumanji/environments/__init__.py +++ b/jumanji/environments/__init__.py @@ -22,11 +22,20 @@ from jumanji.environments.packing.bin_pack.env import BinPack from jumanji.environments.packing.job_shop.env import JobShop from jumanji.environments.packing.knapsack.env import Knapsack -from jumanji.environments.routing import cleaner, connector, cvrp, maze, snake, tsp +from jumanji.environments.routing import ( + cleaner, + connector, + cvrp, + maze, + robot_warehouse, + snake, + tsp, +) from jumanji.environments.routing.cleaner.env import Cleaner from jumanji.environments.routing.connector.env import Connector from jumanji.environments.routing.cvrp.env import CVRP from jumanji.environments.routing.maze.env import Maze +from jumanji.environments.routing.robot_warehouse.env import RobotWarehouse from jumanji.environments.routing.snake.env import Snake from jumanji.environments.routing.tsp.env import TSP diff --git a/jumanji/environments/routing/robot_warehouse/__init__.py b/jumanji/environments/routing/robot_warehouse/__init__.py new file mode 100644 index 000000000..734a3994e --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jumanji.environments.routing.robot_warehouse.env import RobotWarehouse +from jumanji.environments.routing.robot_warehouse.types import Observation, State diff --git a/jumanji/environments/routing/robot_warehouse/conftest.py b/jumanji/environments/routing/robot_warehouse/conftest.py new file mode 100644 index 000000000..90270450e --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/conftest.py @@ -0,0 +1,99 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import jax +import jax.numpy as jnp +import pytest + +from jumanji.environments.routing.robot_warehouse import RobotWarehouse +from jumanji.environments.routing.robot_warehouse.generator import RandomGenerator +from jumanji.environments.routing.robot_warehouse.types import ( + Agent, + Position, + Shelf, + State, +) +from jumanji.types import TimeStep + + +@pytest.fixture(scope="module") +def robot_warehouse_env() -> RobotWarehouse: + """Instantiates a default RobotWarehouse environment with 2 agents, 1 shelf row, 3 shelf columns, + a column height of 2, sensor range of 1 and a request queue size of 4.""" + generator = RandomGenerator( + shelf_rows=1, + shelf_columns=3, + column_height=2, + num_agents=2, + sensor_range=1, + request_queue_size=4, + ) + + env = RobotWarehouse( + generator=generator, + time_limit=5, + ) + return env + + +@pytest.fixture +def deterministic_robot_warehouse_env( + robot_warehouse_env: RobotWarehouse, +) -> Tuple[RobotWarehouse, State, TimeStep]: + """Instantiates a RobotWarehouse environment with 2 agents and 8 shelves + with a step limit of 5.""" + state, timestep = robot_warehouse_env.reset(jax.random.PRNGKey(42)) + + # create agents, shelves and grid + def make_agent(x: int, y: int, direction: int, is_carrying: int) -> Agent: + return Agent(Position(x=x, y=y), direction=direction, is_carrying=is_carrying) + + def make_shelf(x: int, y: int, is_requested: int) -> Shelf: + return Shelf(Position(x=x, y=y), is_requested=is_requested) + + # agent information + xs = jnp.array([3, 1]) + ys = jnp.array([4, 7]) + dirs = jnp.array([2, 3]) + carries = jnp.array([0, 0]) + state.agents = jax.vmap(make_agent)(xs, ys, dirs, carries) + + # shelf information + xs = jnp.array([1, 1, 1, 1, 2, 2, 2, 2]) + ys = jnp.array([1, 2, 7, 8, 1, 2, 7, 8]) + requested = jnp.array([0, 1, 1, 0, 0, 0, 1, 1]) + state.shelves = jax.vmap(make_shelf)(xs, ys, requested) + + # create grid + state.grid = jnp.array( + [ + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 0, 0, 0, 0, 3, 4, 0], + [0, 5, 6, 0, 0, 0, 0, 7, 8, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 2, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + return robot_warehouse_env, state, timestep diff --git a/jumanji/environments/routing/robot_warehouse/constants.py b/jumanji/environments/routing/robot_warehouse/constants.py new file mode 100644 index 000000000..aae7f3b08 --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/constants.py @@ -0,0 +1,37 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax.numpy as jnp + +from jumanji.environments.routing.robot_warehouse.types import Direction + +# grid channels +_SHELVES = 0 +_AGENTS = 1 + +# agent directions +_POSSIBLE_DIRECTIONS = jnp.array([d.value for d in Direction]) + +# viewer constants +_FIGURE_SIZE = (5, 5) +_SHELF_PADDING = 2 + +# colors +_GRID_COLOR = (0, 0, 0) # black +_SHELF_COLOR = (72 / 255.0, 61 / 255.0, 139 / 255.0) # dark slate blue +_SHELF_REQ_COLOR = (0, 128 / 255.0, 128 / 255.0) # teal +_AGENT_COLOR = (1, 140 / 255.0, 0) # dark orange +_AGENT_LOADED_COLOR = (1, 0, 0) # red +_AGENT_DIR_COLOR = (0, 0, 0) # black +_GOAL_COLOR = (60 / 255.0, 60 / 255.0, 60 / 255.0) diff --git a/jumanji/environments/routing/robot_warehouse/env.py b/jumanji/environments/routing/robot_warehouse/env.py new file mode 100644 index 000000000..ad9d4ec80 --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/env.py @@ -0,0 +1,545 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from typing import List, Optional, Sequence, Tuple + +import chex +import jax +import jax.numpy as jnp +import matplotlib +from numpy.typing import NDArray + +from jumanji import specs +from jumanji.env import Environment +from jumanji.environments.routing.robot_warehouse import utils +from jumanji.environments.routing.robot_warehouse.constants import _SHELVES +from jumanji.environments.routing.robot_warehouse.generator import ( + Generator, + RandomGenerator, +) +from jumanji.environments.routing.robot_warehouse.types import ( + Action, + Agent, + Direction, + Observation, + Shelf, + State, +) +from jumanji.environments.routing.robot_warehouse.utils_agent import ( + set_new_direction_after_turn, + set_new_position_after_forward, +) +from jumanji.environments.routing.robot_warehouse.utils_shelf import update_shelf +from jumanji.environments.routing.robot_warehouse.viewer import RobotWarehouseViewer +from jumanji.tree_utils import tree_slice +from jumanji.types import TimeStep, restart, termination, transition +from jumanji.viewer import Viewer + + +class RobotWarehouse(Environment[State]): + """A JAX implementation of the 'Robotic warehouse' environment: + https://github.com/semitable/robotic-warehouse + which is described in the paper [1]. + + Creates a grid world where multiple agents (robots) + are supposed to collect shelves, bring them to a goal + and then return them. + + Below is an example warehouse floor grid: + the grid layout is instantiated using three arguments + + - shelf_rows: number of vertical shelf clusters + - shelf_columns: odd number of horizontal shelf clusters + - column_height: height of each cluster + + A cluster is a set of grouped shelves (two cells wide) represented + below as + + XX + Shelf cluster -> XX (this cluster is of height 3) + XX + + Grid Layout: + + shelf columns (here set to 3, i.e. + v v v shelf_columns=3, must be an odd number) + ---------- + > -XX-XX-XX- ^ + Shelf Row 1 -> -XX-XX-XX- Column Height (here set to 3, i.e. + > -XX-XX-XX- v column_height=3) + ---------- + -XX----XX- < + -XX----XX- <- Shelf Row 2 (here set to 2, i.e. + -XX----XX- < shelf_rows=2) + ---------- + ----GG---- + + - G: is the goal positions where agents are rewarded if + they successfully deliver a requested shelf (i.e toggle the load action + inside the goal position while carrying a requested shelf). + + The final grid size will be + - height: (column_height + 1) * shelf_rows + 2 + - width: (2 + 1) * shelf_columns + 1 + + The bottom-middle column is removed to allow for + agents to queue in front of the goal positions + + - action: jax array (int) of shape (num_agents,) containing the action for each agent. + (0: noop, 1: forward, 2: left, 3: right, 4: toggle_load) + + - reward: jax array (int) of shape (), global reward shared by all agents, +1 + for every successful delivery of a requested shelf to the goal position. + + - episode termination: + - The number of steps is greater than the limit. + - Any agent selects an action which causes two agents to collide. + + - state: State + - grid: an array representing the warehouse floor as a 2D grid with two separate channels + one for the agents, and one for the shelves + - agents: a pytree of Agent type with per agent leaves: [position, direction, is_carrying] + - shelves: a pytree of Shelf type with per shelf leaves: [position, is_requested] + - request_queue: the queue of requested shelves (by ID). + - step_count: an integer representing the current step of the episode. + - action_mask: an array of shape (num_agents, 5) containing the valid actions + for each agent. + - key: a pseudorandom number generator key. + + [1] Papoudakis et al., Benchmarking Multi-Agent Deep Reinforcement Learning Algorithms + in Cooperative Tasks (2021) + + ```python + from jumanji.environments import RobotWarehouse + env = RobotWarehouse() + key = jax.random.PRNGKey(0) + state, timestep = jax.jit(env.reset)(key) + env.render(state) + action = env.action_spec().generate_value() + state, timestep = jax.jit(env.step)(state, action) + env.render(state) + ``` + """ + + def __init__( + self, + generator: Optional[Generator] = None, + time_limit: int = 500, + viewer: Optional[Viewer[State]] = None, + ): + """Instantiates an `RobotWarehouse` environment. + + Args: + generator: callable to instantiate environment instances. + Defaults to `RandomGenerator` with parameters: + `shelf_rows = 2`, + `shelf_columns = 3`, + `column_height = 8`, + `num_agents = 4`, + `sensor_range = 1`, + `request_queue_size = 8`. + time_limit: the maximum step limit allowed within the environment. + Defaults to 500. + viewer: viewer to render the environment. Defaults to `RobotWarehouseViewer`. + """ + + # default generator is: robot_warehouse-tiny-4ag-easy (in original implementation) + self._generator = generator or RandomGenerator( + column_height=8, + shelf_rows=2, + shelf_columns=3, + num_agents=4, + sensor_range=1, + request_queue_size=8, + ) + + self.goals: List[Tuple[int, int]] = [] + self.grid_size = self._generator.grid_size + self.request_queue_size = self._generator.request_queue_size + + self.num_agents = self._generator.num_agents + self.sensor_range = self._generator.sensor_range + self.highways = self._generator.highways + self.shelf_ids = self._generator.shelf_ids + self.not_in_queue_size = self._generator.not_in_queue_size + + self.agent_ids = jnp.arange(self.num_agents) + self.directions = jnp.array([d.value for d in Direction]) + self.num_obs_features = utils.calculate_num_observation_features( + self.sensor_range + ) + self.goals = self._generator.goals + self.time_limit = time_limit + + # create viewer for rendering environment + self._viewer = viewer or RobotWarehouseViewer( + self.grid_size, self.goals, "RobotWarehouse" + ) + + def __repr__(self) -> str: + return ( + f"RobotWarehouse(\n" + f"\tgrid_width={self.grid_size[1]!r},\n" + f"\tgrid_height={self.grid_size[0]!r},\n" + f"\tnum_agents={self.num_agents!r}, \n" + ")" + ) + + def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: + """Resets the environment. + + Args: + key: random key used to reset the environment since it is stochastic. + + Returns: + state: State object corresponding to the new state of the environment. + timestep: TimeStep object corresponding the first timestep returned by the environment. + """ + # create environment state + state = self._generator(key) + + # collect first observations and create timestep + agents_view = self._make_observations(state.grid, state.agents, state.shelves) + observation = Observation( + agents_view=agents_view, + action_mask=state.action_mask, + step_count=state.step_count, + ) + timestep = restart(observation=observation) + return state, timestep + + def step( + self, + state: State, + action: chex.Array, + ) -> Tuple[State, TimeStep[Observation]]: + """Perform an environment step. + + Args: + state: State object containing the dynamics of the environment. + action: Array containing the action to take. + - 0 no op + - 1 move forward + - 2 turn left + - 3 turn right + - 4 toggle load + + Returns: + state: State object corresponding to the next state of the environment. + timestep: TimeStep object corresponding the timestep returned by the environment. + """ + + # unpack state + key = state.key + grid = state.grid + agents = state.agents + shelves = state.shelves + request_queue = state.request_queue + + # check for invalid action -> turn into noops + actions = utils.get_valid_actions(action, state.action_mask) + + # check for agent collisions + collisions = jax.vmap(functools.partial(utils.is_collision, grid))( + agents, actions + ) + collision = jnp.any(collisions) + + # update agents, shelves and grid + def update_state_scan( + carry_info: Tuple[chex.Array, chex.Array, chex.Array, int], action: int + ) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array, int], None]: + grid, agents, shelves, agent_id = carry_info + grid, agents, shelves = self._update_state( + grid, agents, shelves, action, agent_id + ) + return (grid, agents, shelves, agent_id + 1), None + + (grid, agents, shelves, _), _ = jax.lax.scan( + update_state_scan, (grid, agents, shelves, 0), actions + ) + + # compute shared reward for all agents and update request queue + # if a requested shelf has been successfully delivered to the goal + reward = jnp.array(0, dtype=jnp.float32) + + def update_reward_and_request_queue_scan( + carry_info: Tuple[ + chex.PRNGKey, chex.Array, chex.Array, chex.Array, chex.Array + ], + goal: chex.Array, + ) -> Tuple[ + Tuple[chex.PRNGKey, chex.Array, chex.Array, chex.Array, chex.Array], None + ]: + key, reward, request_queue, grid, shelves = carry_info + ( + key, + reward, + request_queue, + shelves, + ) = self._update_reward_and_request_queue( + key, reward, request_queue, grid, shelves, goal + ) + carry_info = (key, reward, request_queue, grid, shelves) + return carry_info, None + + update_info, _ = jax.lax.scan( + update_reward_and_request_queue_scan, + (key, reward, request_queue, grid, shelves), + self.goals, + ) + key, reward, request_queue, grid, shelves = update_info + + # construct timestep and check environment termination + steps = state.step_count + 1 + horizon_reached = steps >= self.time_limit + done = collision | horizon_reached + + # compute next observation + agents_view = self._make_observations(grid, agents, shelves) + action_mask = utils.compute_action_mask(grid, agents) + next_observation = Observation( + agents_view=agents_view, + action_mask=action_mask, + step_count=steps, + ) + + timestep = jax.lax.cond( + done, + termination, + transition, + reward, + next_observation, + ) + next_state = State( + grid=grid, + agents=agents, + shelves=shelves, + request_queue=request_queue, + step_count=steps, + action_mask=action_mask, + key=key, + ) + return next_state, timestep + + def observation_spec(self) -> specs.Spec[Observation]: + """Specification of the observation of the `RobotWarehouse` environment. + Returns: + Spec for the `Observation`, consisting of the fields: + - agents_view: Array (int32) of shape (num_agents, num_obs_features). + - action_mask: BoundedArray (bool) of shape (num_agent, 5). + - step_count: BoundedArray (int32) of shape (). + """ + agents_view = specs.Array( + (self.num_agents, self.num_obs_features), jnp.int32, "agents_view" + ) + action_mask = specs.BoundedArray( + (self.num_agents, 5), bool, False, True, "action_mask" + ) + step_count = specs.BoundedArray((), jnp.int32, 0, self.time_limit, "step_count") + return specs.Spec( + Observation, + "ObservationSpec", + agents_view=agents_view, + action_mask=action_mask, + step_count=step_count, + ) + + def action_spec(self) -> specs.MultiDiscreteArray: + """Returns the action spec. 5 actions: [0,1,2,3,4] -> [No Op, Forward, Left, Right, Toggle_load]. + Since this is a multi-agent environment, the environment expects an array of actions. + This array is of shape (num_agents,). + """ + return specs.MultiDiscreteArray( + num_values=jnp.array([len(Action)] * self.num_agents, jnp.int32), + name="action", + ) + + def _make_observations( + self, + grid: chex.Array, + agents: Agent, + shelves: Shelf, + ) -> chex.Array: + """Create an observation for each agent based on its view of other + agents and shelves + + Args: + grid: the warehouse floor grid array. + agents: a pytree of Agent type containing agents information. + shelves: a pytree of Shelf type containing shelves information. + + Returns: + an array containing agents observations. + """ + return jax.vmap( + functools.partial( + utils.make_agent_observation, + grid, + agents, + shelves, + self.sensor_range, + self.num_obs_features, + self.highways, + ) + )(self.agent_ids) + + def _update_state( + self, + grid: chex.Array, + agents: chex.Array, + shelves: chex.Array, + action: int, + agent_id: int, + ) -> Tuple[chex.Array, chex.Array, chex.Array]: + """Update the state of the environment after an action is performed. + + Args: + grid: the warehouse floor grid array. + agents: a pytree of Agent type containing agents information. + shelves: a pytree of Shelf type containing shelves information. + action: the action performed by the agent. + agent_id: the id of the agent performing the action. + Returns: + the updated warehouse floor grid array, agents and shelves. + """ + agent = tree_slice(agents, agent_id) + is_highway = self.highways[agent.position.x, agent.position.y] + grid, agents, shelves = jax.lax.cond( + jnp.equal(action, Action.FORWARD.value), + set_new_position_after_forward, + set_new_direction_after_turn, + grid, + agents, + shelves, + action, + agent_id, + is_highway, + ) + + return grid, agents, shelves + + def _update_reward_and_request_queue( + self, + key: chex.PRNGKey, + reward: chex.Array, + request_queue: chex.Array, + grid: chex.Array, + shelves: chex.Array, + goal: chex.Array, + ) -> Tuple[chex.PRNGKey, int, chex.Array, chex.Array]: + """Check if a shelf has been delivered successfully to a goal state, + if so reward the agents and update the request queue: removing the ID + of the delivered shelf and replacing it with a new shelf ID. + + Args: + key: a pseudorandom number generator key. + reward: the array of shared reward for each agent. + request_queue: the queue of requested shelves. + grid: the warehouse floor grid array. + shelves: a pytree of Shelf type containing shelves information. + goal: array of goal positions. + Returns: + a random key, updated reward, request queue and shelves. + """ + x, y = goal + shelf_id = grid[_SHELVES, x, y] + + def reward_and_update_request_queue_if_shelf_in_goal( + key: chex.PRNGKey, + reward: jnp.int32, + request_queue: chex.Array, + shelves: chex.Array, + shelf_id: int, + ) -> Tuple[chex.PRNGKey, int, chex.Array, chex.Array]: + "Reward the agents and update the request queue." + + # remove from queue and replace it + key, request_key = jax.random.split(key) + + not_in_queue = jnp.setdiff1d( + self.shelf_ids, + request_queue, + size=self.not_in_queue_size, + ) + new_request_id = jax.random.choice( + request_key, + not_in_queue, + replace=False, + ) + replace_index = jnp.argwhere(jnp.equal(request_queue, shelf_id - 1), size=1) + request_queue = request_queue.at[replace_index].set(new_request_id) + + # also reward the agents + reward += 1.0 + + # update requested shelf + shelves = update_shelf(shelves, shelf_id - 1, "is_requested", 0) + shelves = update_shelf(shelves, new_request_id, "is_requested", 1) + return key, reward, request_queue, shelves + + # check if shelf is at goal position and in request queue + cond = (shelf_id != 0) & jnp.isin(shelf_id, request_queue + 1) + + key, reward, request_queue, shelves = jax.lax.cond( + cond, + reward_and_update_request_queue_if_shelf_in_goal, + lambda k, r, rq, g, _: (k, r, rq, g), + key, + reward, + request_queue, + shelves, + shelf_id, + ) + return key, reward, request_queue, shelves + + def render(self, state: State) -> Optional[NDArray]: + """Renders the current state of the RobotWarehouse environment. + + Args: + state: is the current environment state to be rendered. + save_path: the path where the image should be saved. If it is None, the plot + will not be stored. + """ + return self._viewer.render(state) + + def animate( + self, + states: Sequence[State], + interval: int = 200, + save_path: Optional[str] = None, + ) -> matplotlib.animation.FuncAnimation: + """Creates an animation from a sequence of RobotWarehouse states. + + Args: + states: sequence of `State` corresponding to subsequent timesteps. + interval: delay between frames in milliseconds, default to 200. + save_path: the path where the animation file should be saved. If it is None, + the plot will not be saved. + + Returns: + Animation object that can be saved as a GIF, MP4, or rendered with HTML. + """ + return self._viewer.animate( + states=states, interval=interval, save_path=save_path + ) + + def close(self) -> None: + """Perform any necessary cleanup. + + Environments will automatically :meth:`close()` themselves when + garbage collected or when the program exits. + """ + self._viewer.close() diff --git a/jumanji/environments/routing/robot_warehouse/env_test.py b/jumanji/environments/routing/robot_warehouse/env_test.py new file mode 100644 index 000000000..bf64f5ee7 --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/env_test.py @@ -0,0 +1,219 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import chex +import jax +import jax.numpy as jnp +from jax import random + +from jumanji.environments.routing.robot_warehouse.env import RobotWarehouse +from jumanji.environments.routing.robot_warehouse.types import State +from jumanji.testing.env_not_smoke import check_env_does_not_smoke +from jumanji.testing.pytrees import assert_is_jax_array_tree +from jumanji.tree_utils import tree_slice +from jumanji.types import TimeStep + + +def test_robot_warehouse__specs(robot_warehouse_env: RobotWarehouse) -> None: + """Validate environment specs conform to the expected shapes and values""" + action_spec = robot_warehouse_env.action_spec() + observation_spec = robot_warehouse_env.observation_spec() + + assert observation_spec.agents_view.shape == (2, 66) # type: ignore + assert action_spec.num_values.shape[0] == robot_warehouse_env.num_agents + assert action_spec.num_values[0] == 5 + + +def test_robot_warehouse__reset(robot_warehouse_env: RobotWarehouse) -> None: + """Validate the jitted reset of the environment.""" + chex.clear_trace_counter() + reset_fn = jax.jit(chex.assert_max_traces(robot_warehouse_env.reset, n=1)) + + key1, key2 = random.PRNGKey(0), random.PRNGKey(1) + state1, timestep1 = reset_fn(key1) + state2, timestep2 = reset_fn(key2) + + assert isinstance(timestep1, TimeStep) + assert isinstance(state1, State) + assert state1.step_count == 0 + assert state1.grid.shape == (2, *robot_warehouse_env.grid_size) + # Check that the state is made of DeviceArrays, this is false for the non-jitted + # reset function since unpacking random.split returns numpy arrays and not device arrays. + assert_is_jax_array_tree(state1) + # Check random initialization + assert not jnp.all(state1.key == state2.key) + assert not jnp.all(state1.grid == state2.grid) + assert state1.step_count == state2.step_count + + +def test_robot_warehouse__agent_observation( + deterministic_robot_warehouse_env: Tuple[RobotWarehouse, State, TimeStep] +) -> None: + """Validate the agent observation function.""" + env, state, timestep = deterministic_robot_warehouse_env + state, timestep = env.step(state, jnp.array([0, 0])) + + # agent 1 obs + agent1_own_view = jnp.array([3, 4, 0, 0, 0, 1, 0, 1]) + agent1_other_agents_view = jnp.array(8 * [0, 0, 0, 0, 0]) + agent1_shelf_view = jnp.array(9 * [0, 0]) + agent1_obs = jnp.hstack( + [agent1_own_view, agent1_other_agents_view, agent1_shelf_view] + ) + + # agent 2 obs + agent2_own_view = jnp.array([1, 7, 0, 0, 0, 0, 1, 0]) + agent2_other_agents_view = jnp.array(8 * [0, 0, 0, 0, 0]) + agent2_shelf_view = jnp.array( + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1] + ) + agent2_obs = jnp.hstack( + [agent2_own_view, agent2_other_agents_view, agent2_shelf_view] + ) + + assert jnp.all(timestep.observation.agents_view[0] == agent1_obs) + assert jnp.all(timestep.observation.agents_view[1] == agent2_obs) + + +def test_robot_warehouse__step(robot_warehouse_env: RobotWarehouse) -> None: + """Validate the jitted step function of the environment.""" + chex.clear_trace_counter() + + step_fn = chex.assert_max_traces(robot_warehouse_env.step, n=1) + step_fn = jax.jit(step_fn) + + state_key, action_key1, action_key2 = random.split(random.PRNGKey(10), 3) + state, timestep = robot_warehouse_env.reset(state_key) + + # Sample two different actions + action1, action2 = random.choice( + key=action_key1, + a=jnp.arange(5), + shape=(2,), + replace=False, + ) + + action1 = jnp.zeros((robot_warehouse_env.num_agents,), int).at[0].set(action1) + action2 = jnp.zeros((robot_warehouse_env.num_agents,), int).at[0].set(action2) + + new_state1, timestep1 = step_fn(state, action1) + + # Check that rewards have the correct number of dimensions + assert jnp.ndim(timestep1.reward) == 0 + assert jnp.ndim(timestep.reward) == 0 + # Check that discounts have the correct number of dimensions + assert jnp.ndim(timestep1.discount) == 0 + assert jnp.ndim(timestep.discount) == 0 + # Check that the state is made of DeviceArrays, this is false for the non-jitted + # step function since unpacking random.split returns numpy arrays and not device arrays. + assert_is_jax_array_tree(new_state1) + # Check that the state has changed + assert new_state1.step_count != state.step_count + assert not jnp.all(new_state1.grid != state.grid) + # Check that two different actions lead to two different states + new_state2, timestep2 = step_fn(state, action2) + assert not jnp.all(new_state1.grid != new_state2.grid) + + jax.debug.print("grid: {g}", g=state.grid) + + # Check that the state update and timestep creation work as expected + agents = state.agents + agent = tree_slice(agents, 1) + jax.debug.print("agents: {g}", g=agents) + x = agent.position.x + y = agent.position.y + + # turning and moving actions + actions = [2, 2, 3, 3, 1, 3, 1] + + # Note: starting direction is 3 (facing left) + new_locs = [ + (x, y, 2), # turn left -> facing down + (x, y, 1), # turn left -> facing right + (x, y, 2), # turn right -> facing down + (x, y, 3), # turn right -> face left + (x, y - 1, 3), # move forward -> move left + (x, y - 1, 0), # turn right -> face up + (x - 1, y - 1, 0), # move forward -> move up + ] + + for action, new_loc in zip(actions, new_locs): + state, timestep = step_fn(state, jnp.array([action, action])) + agent1_info = tree_slice(state.agents, 1) + agent1_loc = ( + agent1_info.position.x, + agent1_info.position.y, + agent1_info.direction, + ) + assert agent1_loc == new_loc + + +def test_robot_warehouse__does_not_smoke(robot_warehouse_env: RobotWarehouse) -> None: + """Validate that we can run an episode without any errors.""" + check_env_does_not_smoke(robot_warehouse_env) + + +def test_robot_warehouse__time_limit(robot_warehouse_env: RobotWarehouse) -> None: + """Validate the terminal reward.""" + step_fn = jax.jit(robot_warehouse_env.step) + state_key = random.PRNGKey(10) + state, timestep = robot_warehouse_env.reset(state_key) + assert timestep.first() + + for _ in range(robot_warehouse_env.time_limit - 1): + state, timestep = step_fn(state, jnp.array([0, 0])) + + assert timestep.mid() + state, timestep = step_fn(state, jnp.array([0, 0])) + assert timestep.last() + + +def test_robot_warehouse__truncation( + deterministic_robot_warehouse_env: Tuple[RobotWarehouse, State, TimeStep] +) -> None: + """Validate episode truncation based on set time limit.""" + robot_warehouse_env, state, timestep = deterministic_robot_warehouse_env + step_fn = jax.jit(robot_warehouse_env.step) + + # truncation + for _ in range(robot_warehouse_env.time_limit): + state, timestep = step_fn(state, jnp.array([0, 0])) + + assert timestep.last() + # note the line below should be used to test for truncation + # but since we instead use termination inside the env code + # for training capatibility, we check for omit this check + # assert not jnp.all(timestep.discount == 0) + + +def test_robot_warehouse__truncate_upon_collision( + deterministic_robot_warehouse_env: Tuple[RobotWarehouse, State, TimeStep] +) -> None: + """Validate episode terminates upon collision of agents.""" + robot_warehouse_env, state, timestep = deterministic_robot_warehouse_env + step_fn = jax.jit(robot_warehouse_env.step) + + # actions for agent 1 to collide with agent 2 + actions = [3, 1, 1, 3, 1, 1, 1] + + # take actions until collision + for action in actions: + state, timestep = step_fn(state, jnp.array([action, 0])) + + assert timestep.last() + # TODO: uncomment once we have changed termination + # in the env code to truncation (also see above) + # assert not jnp.all(timestep.discount == 0) diff --git a/jumanji/environments/routing/robot_warehouse/generator.py b/jumanji/environments/routing/robot_warehouse/generator.py new file mode 100644 index 000000000..2f0b79ec0 --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/generator.py @@ -0,0 +1,275 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Callable + +import chex +import jax +import jax.numpy as jnp + +from jumanji.environments.routing.robot_warehouse.types import State +from jumanji.environments.routing.robot_warehouse.utils import compute_action_mask +from jumanji.environments.routing.robot_warehouse.utils_spawn import ( + place_entities_on_grid, + spawn_random_entities, +) + + +class Generator(abc.ABC): + """Base class for generators for the RobotWarehouse environment.""" + + def __init__( + self, + shelf_rows: int, + shelf_columns: int, + column_height: int, + num_agents: int, + sensor_range: int, + request_queue_size: int, + ) -> None: + """Initializes a robot_warehouse generator, used to generate grids for + the RobotWarehouse environment. + + Args: + shelf_rows: the number of shelf cluster rows, each of height = column_height. + Defaults to 1. + shelf_columns: the number of shelf cluster columns, each of width = 2 cells + (must be an odd number). Defaults to 3. + column_height: the height of each shelf cluster. Defaults to 8. + num_agents: the number of agents (robots) operating on the warehouse floor. + Defaults to 2. + sensor_range: the receptive field around an agent O O O + (e.g. 1 implies a 360 view of 1 cell around the -> O x O + agent's position cell) O O O + Defaults to 1. + request_queue_size: the number of shelves requested at any + given time which remains fixed throughout environment steps. Defaults to 4. + """ + if shelf_columns % 2 != 1: + raise ValueError( + "Environment argument: `shelf_columns`, must be an odd number." + ) + + self._shelf_rows = shelf_rows + self._shelf_columns = shelf_columns + self._column_height = column_height + self._num_agents = num_agents + self._sensor_range = sensor_range + self._request_queue_size = request_queue_size + + self._grid_size = ( + (column_height + 1) * shelf_rows + 2, + (2 + 1) * shelf_columns + 1, + ) + self._agent_ids = jnp.arange(num_agents) + + @property + def shelf_rows(self) -> int: + return self._shelf_rows + + @property + def shelf_columns(self) -> int: + return self._shelf_columns + + @property + def column_height(self) -> int: + return self._column_height + + @property + def grid_size(self) -> chex.Array: + return self._grid_size + + @property + def num_agents(self) -> int: + return self._num_agents + + @property + def sensor_range(self) -> int: + return self._sensor_range + + @property + def request_queue_size(self) -> int: + return self._request_queue_size + + @property + def agent_ids(self) -> chex.Array: + return self._agent_ids + + @property + @abc.abstractmethod + def shelf_ids(self) -> chex.Array: + """shelf ids""" + + @property + @abc.abstractmethod + def not_in_queue_size(self) -> chex.Array: + """number of shelves not in queue""" + + @property + @abc.abstractmethod + def highways(self) -> chex.Array: + """highways positions""" + + @property + @abc.abstractmethod + def goals(self) -> chex.Array: + """goals positions""" + + @abc.abstractmethod + def __call__(self, key: chex.PRNGKey) -> State: + """Generates an `RobotWarehouse` state. + + Returns: + An `RobotWarehouse` state. + """ + + +class GeneratorBase(Generator): + """Base class for `RobotWarehouse` environment state generator.""" + + def __init__( + self, + shelf_rows: int, + shelf_columns: int, + column_height: int, + num_agents: int, + sensor_range: int, + request_queue_size: int, + ) -> None: + """Initializes a robot_warehouse generator.""" + super().__init__( + shelf_rows, + shelf_columns, + column_height, + num_agents, + sensor_range, + request_queue_size, + ) + self._make_warehouse() + + def _make_warehouse(self) -> None: + """Create the layout for the warehouse floor, i.e. the grid + + Args: + shelf_rows: the number of shelf cluster rows + shelf_columns: the number of shelf cluster columns + column_height: the height of each shelf cluster + """ + + # create goal positions + self._goals = jnp.array( + [ + (self._grid_size[1] // 2 - 1, self._grid_size[0] - 1), + (self._grid_size[1] // 2, self._grid_size[0] - 1), + ] + ) + # calculate "highways" (these are open spaces/cells between shelves) + highway_func: Callable[[int, int], bool] = lambda x, y: ( + (y % 3 == 0) # vertical highways + | (x % (self.column_height + 1) == 0) # horizontal highways + | (x == self._grid_size[0] - 1) # delivery row + | ( # remove middle cluster to allow agents to queue in front of goals + (x > self._grid_size[0] - (self.column_height + 3)) + & ((y == self._grid_size[1] // 2 - 1) | (y == self._grid_size[1] // 2)) + ) + ) + grid_indices = jnp.indices(jnp.zeros(self._grid_size, dtype=jnp.int32).shape) + self._highways = jax.vmap(highway_func)(grid_indices[0], grid_indices[1]) + + non_highways = jnp.abs(self.highways - 1) + + # shelves information + n_shelves = jnp.sum(non_highways) + self._shelf_positions = jnp.argwhere(non_highways) + self._shelf_ids = jnp.arange(n_shelves) + self._not_in_queue_size = n_shelves - self.request_queue_size + + @property + def shelf_ids(self) -> chex.Array: + return self._shelf_ids + + @property + def not_in_queue_size(self) -> chex.Array: + return self._not_in_queue_size + + @property + def highways(self) -> chex.Array: + return self._highways + + @property + def goals(self) -> chex.Array: + return self._goals + + +class RandomGenerator(GeneratorBase): + """Randomly generates `RobotWarehouse` environment state. This generator places agents at + starting positions on the grid and selects the requested shelves uniformly at random. + """ + + def __init__( + self, + shelf_rows: int, + shelf_columns: int, + column_height: int, + num_agents: int, + sensor_range: int, + request_queue_size: int, + ) -> None: + """Initialises an robot_warehouse generator, used to generate grids for + the RobotWarehouse environment.""" + super().__init__( + shelf_rows, + shelf_columns, + column_height, + num_agents, + sensor_range, + request_queue_size, + ) + + def __call__(self, key: chex.PRNGKey) -> State: + """Generates a `RobotWarehouse` state that contains the grid and the agents/shelves layout. + + Returns: + A `RobotWarehouse` state. + """ + # empty grid array + grid = jnp.zeros((2, *self._grid_size), dtype=jnp.int32) + + # spawn random agents with random request queue + key, agents, shelves, shelf_request_queue = spawn_random_entities( + key, + self._grid_size, + self._agent_ids, + self._shelf_ids, + self._shelf_positions, + self._request_queue_size, + ) + grid = place_entities_on_grid(grid, agents, shelves) + + # compute action mask + action_mask = compute_action_mask(grid, agents) + + # create environment state + state = State( + grid=grid, + agents=agents, + shelves=shelves, + request_queue=shelf_request_queue, + step_count=jnp.array(0, int), + action_mask=action_mask, + key=key, + ) + + return state diff --git a/jumanji/environments/routing/robot_warehouse/generator_test.py b/jumanji/environments/routing/robot_warehouse/generator_test.py new file mode 100644 index 000000000..925d66b16 --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/generator_test.py @@ -0,0 +1,53 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import chex +import jax +import pytest + +from jumanji.environments.routing.robot_warehouse.generator import RandomGenerator + + +@pytest.fixture +def random_generator() -> RandomGenerator: + """Creates a generator with 2 agents.""" + return RandomGenerator( + shelf_rows=1, + shelf_columns=3, + column_height=2, + num_agents=2, + sensor_range=1, + request_queue_size=4, + ) + + +def test_random_generator__call(random_generator: RandomGenerator) -> None: + """Test that generator generates valid boards.""" + key = jax.random.PRNGKey(42) + state = random_generator(key) + grid_size = (2, 5, 10) + assert state.grid.shape == grid_size + assert state.agents.direction.shape[0] == 2 + + +def test_random_generator__no_retrace( + random_generator: RandomGenerator, +) -> None: + """Checks that generator only traces the function once and works when jitted.""" + key = jax.random.PRNGKey(42) + keys = jax.random.split(key, 2) + jitted_generator = jax.jit(chex.assert_max_traces((random_generator.__call__), n=1)) + + for key in keys: + jitted_generator(key) diff --git a/jumanji/environments/routing/robot_warehouse/types.py b/jumanji/environments/routing/robot_warehouse/types.py new file mode 100644 index 000000000..bc241e09c --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/types.py @@ -0,0 +1,133 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, NamedTuple, Union + +if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239 + from dataclasses import dataclass +else: + from chex import dataclass + +from enum import IntEnum + +import chex + + +class Action(IntEnum): + """An enumeration of possible actions + that an agent can take in the warehouse. + + NOOP - represents no operation. + FORWARD - move forward. + LEFT - turn left. + RIGHT - turn right. + TOGGLE_LOAD - toggle loading/offloading a shelf. + """ + + NOOP = 0 + FORWARD = 1 + LEFT = 2 + RIGHT = 3 + TOGGLE_LOAD = 4 + + +class Direction(IntEnum): + """An enumeration of possible directions + that an agent can take in the warehouse. + + UP - move up. + RIGHT - move right. + DOWN - move down. + LEFT - move left. + """ + + UP = 0 + RIGHT = 1 + DOWN = 2 + LEFT = 3 + + +class Position(NamedTuple): + """A class to represent the 2D coordinate position of entities + + x: the x-position of the entity. + y: the y-position of the entity. + """ + + x: chex.Array # () + y: chex.Array # () + + +class Agent(NamedTuple): + """A class to represent an Agent in the warehouse + + position: the (x,y) position of the agent. + direction: the direction the agent is facing. + is_carrying: whether the agent is carrying a shelf or not. + """ + + position: Position # (2,) + direction: chex.Array # () + is_carrying: chex.Array # () + + +class Shelf(NamedTuple): + """A class to represent a Shelf in the warehouse. + + position: the (x,y) position of the shelf. + is_requested: whether the shelf is requested for delivery. + """ + + position: Position # (2,) + is_requested: chex.Array # () + + +Entity = Union[Agent, Shelf] + + +@dataclass +class State: + """A dataclass representing the state of the simulated warehouse. + + grid: an array representing the warehouse floor as a 2D grid with two separate channels + one for the agents, and one for the shelves. + agents: a pytree of Agent type with per agent leaves: [position, direction, is_carrying] + shelves: a pytree of Shelf type with per shelf leaves: [position, is_requested] + request_queue : the queue of requested shelves (by ID). + step_count: an integer representing the current step of the episode. + key: a pseudorandom number generator key. + """ + + grid: chex.Array # (2, grid_width, grid_height) + agents: Agent # (num_agents, ...) + shelves: Shelf # (num_shelves, ...) + request_queue: chex.Array # (num_requested,) + step_count: chex.Array # () + action_mask: chex.Array # (num_agents, 5) + key: chex.PRNGKey # (2,) + + +class Observation(NamedTuple): + """The observation that the agent sees. + agents_view: the agents' view of other agents and shelves within their + sensor range. The number of features in the observation array + depends on the sensor range of the agent. + action_mask: boolean array specifying, for each agent, which action + (up, right, down, left) is legal. + step_count: the number of steps elapsed since the beginning of the episode. + """ + + agents_view: chex.Array # (num_agents, num_obs_features) + action_mask: chex.Array # (num_agents, 5) + step_count: chex.Array # () diff --git a/jumanji/environments/routing/robot_warehouse/utils.py b/jumanji/environments/routing/robot_warehouse/utils.py new file mode 100644 index 000000000..87616d419 --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/utils.py @@ -0,0 +1,341 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from typing import Tuple + +import chex +import jax +import jax.numpy as jnp + +from jumanji.environments.routing.robot_warehouse.constants import _AGENTS, _SHELVES +from jumanji.environments.routing.robot_warehouse.types import Action, Agent, Entity +from jumanji.environments.routing.robot_warehouse.utils_agent import ( + get_agent_view, + get_new_position_after_forward, +) +from jumanji.tree_utils import tree_slice + + +def get_entity_ids(entities: Entity) -> chex.Array: + """Get ids for agents/shelves. + + Args: + entities: a pytree of Agent or Shelf type. + + Returns: + an array of ids. + """ + return jnp.arange(entities[1].shape[0]) + + +def is_valid_action(grid: chex.Array, agent: Agent, action: chex.Array) -> chex.Array: + """If the agent is carrying a shelf and collides with another + shelf based on its current action, this action is deemed invalid. + + Args: + grid: the warehouse floor grid array. + agent: the agent for which the action is being checked. + action: the action the agent is about to take. + + Returns: + a boolean indicating whether the action is valid or not. + """ + + # get start and target positions + start = agent.position + target = get_new_position_after_forward(grid, start, agent.direction) + + # check if carrying and walking into another shelf + cond = jnp.logical_and(jnp.equal(action, Action.FORWARD), agent.is_carrying) + cond = jnp.logical_and(cond, jnp.logical_not(jnp.array_equal(start, target))) + cond = jnp.logical_and(cond, grid[_SHELVES, target.x, target.y]) + + return ~cond + + +def get_valid_actions(actions: chex.Array, action_mask: chex.Array) -> chex.Array: + """Get the valid action the agent should take given its action mask. + + Args: + actions: the actions the agents are about to take. + action_mask: the mask of valid actions. + + Returns: + the action the agent should take given its current state. + """ + + def get_valid_action(action_mask: chex.Array, action: chex.Array) -> chex.Array: + return jax.lax.cond(action_mask[action], lambda: action, lambda: 0) + + return jax.vmap(get_valid_action)(action_mask, actions) + + +def is_collision(grid: chex.Array, agent: Agent, action: int) -> chex.Array: + """Calculate whether an agent is about to collide with another + entity. If the agent instead collides with another agent, the + episode terminates (this behavior is specific to this JAX version). + + Args: + grid: the warehouse floor grid array. + agent: the agent for which collisions are being checked. + + Returns: + a boolean indicating whether the agent collided with another + agent or not. + """ + + def check_collision() -> chex.Array: + # get start and target positions + start = agent.position + target = get_new_position_after_forward(grid, start, agent.direction) + + agent_id_at_target_pos = grid[_AGENTS, target.x, target.y] + check_forward = ~jnp.array_equal(start, target) + return check_forward & (agent_id_at_target_pos > 0) + + return jax.lax.cond( + jnp.equal(action, Action.FORWARD), check_collision, lambda: False + ) + + +def calculate_num_observation_features(sensor_range: chex.Array) -> chex.Array: + """Calculates the 1-d size of the agent observations array based on the + environment parameters at instantiation + + Below is a receptive field for an agent x with a sensor range of 1: + + O O O + O x O + O O O + + For the sensor on the agent's own position, + we have the following features + 1. the agent's position -> dim 2 + 2. is the agent carrying a shelf? -> binary {0, 1} with dim 1 + 3. the direction of the agent -> one-hot with dim 4 + 4. is the agent on the warehouse "highway" or not? -> binary with dim 1 + Total dim for agent in focus = 2 + 1 + 4 + 1 = 8 + + Then, for each sensor position (other than the agent's own position, 8 in total), + we have the following features based on other agents: + 1. is there an agent? -> binary {0, 1} with dim 1 + 2. if yes, the agent's direction -> one-hot with dim 4, if no, fill all zeros + Therefore, the total number of dimensions for other agent features + (1 + 4) * num_obs_sensors + + Finally, for each sensor position (9 in total) in the agent's receptive field, + we have the following features based on shelves: + 1. is there a shelf? -> binary {0, 1} with dim 1 + 2. if so, has this shelf been requested -> binary {0, 1} with dim 1, if no, zero + Therefore, the total number of dimensions for shelf features is + (1 + 1) * num_obs_sensors + + Args: + sensor_range: the range of the agent's sensors. + + Returns: + agent's 1-d observation array. + """ + num_obs_sensors = (1 + 2 * sensor_range) ** 2 + obs_features = 8 # agent's own features + obs_features += (num_obs_sensors - 1) * 5 # other agent features + obs_features += num_obs_sensors * 2 # shelf features + return jnp.array(obs_features, jnp.int32) + + +def write_to_observation( + observation: chex.Array, idx: chex.Array, data: chex.Array +) -> Tuple[chex.Array, chex.Array]: + """Write data to the given observation vector at a specified index + + Args: + observation: an observation to which data will be written. + idx: an integer representing the index at which the data will be inserted. + data: the data that will be inserted into the observation array. + + Returns: + the updated observation array and the new index + of where to insert the next data. + """ + data_size = len(data) + observation = jax.lax.dynamic_update_slice(observation, data, (idx,)) + return observation, idx + data_size + + +def move_writer_index(idx: chex.Array, bits: chex.Array) -> chex.Array: + """Skip an indicated number of bits in the observation array being written. + + Args: + idx: an integer representing the index at which to skip bits. + bits: the number of bits to skip. + + Returns: + the new index at which to insert data. + """ + return idx + bits + + +def make_agent_observation( + grid: chex.Array, + agents: chex.Array, + shelves: chex.Array, + sensor_range: int, + num_obs_features: int, + highways: chex.Array, + agent_id: int, +) -> chex.Array: + """Create an observation for a single agent based on its view + of other agents and shelves. + + Args: + grid: the warehouse floor grid array. + agents: a pytree of either Agent type containing agent information. + shelves: a pytree of Shelf type containing shelf information. + sensor_range: the range of the agent's sensors. + num_obs_features: the number of features in the observation array. + highways: binary array indicating highway positions. + agent_id: unique ID identifying a specific agent. + + Returns: + a 1-d array containing the agent's observation. + """ + agent = tree_slice(agents, agent_id) + agents_grid, shelves_grid = get_agent_view(grid, agent, sensor_range) + + # write flattened observations + obs = jnp.zeros(num_obs_features, dtype=jnp.int32) + idx = 0 + + # write current agent position and whether carrying a shelf or not + obs, idx = write_to_observation( + obs, + idx, + jnp.array( + [agent.position.x, agent.position.y, agent.is_carrying], + dtype=jnp.int32, + ), + ) + + # write current agent direction + direction = jax.nn.one_hot(agent.direction, 4, dtype=jnp.int32) + obs, idx = write_to_observation(obs, idx, direction) + + # write if agent is on highway or not + obs, idx = write_to_observation( + obs, + idx, + jnp.array( + [jnp.array(highways[agent.position.x, agent.position.y], int)], + dtype=jnp.int32, + ), + ) + + # function for writing receptive field cells + def write_no_agent( + obs: chex.Array, idx: int, _: int, is_self: bool + ) -> Tuple[chex.Array, int]: + "Write information for empty agent cell." + # if there is no agent we set a 0 and all zeros + # for the direction as well, i.e. [0, 0, 0, 0, 0] + idx = jax.lax.cond(is_self, lambda i: i, lambda i: move_writer_index(i, 5), idx) + return obs, idx + + def write_agent( + obs: chex.Array, idx: int, id_agent: int, _: bool + ) -> Tuple[chex.Array, int]: + "Write information for cell containing an agent." + obs, idx = write_to_observation(obs, idx, jnp.array([1], dtype=jnp.int32)) + direction = jax.nn.one_hot( + tree_slice(agents, id_agent - 1).direction, 4, dtype=jnp.int32 + ) + obs, idx = write_to_observation(obs, idx, direction) + return obs, idx + + def write_no_shelf(obs: chex.Array, idx: int, _: int) -> Tuple[chex.Array, int]: + "write information for empty shelf cell." + idx = move_writer_index(idx, 2) + return obs, idx + + def write_shelf(obs: chex.Array, idx: int, shelf_id: int) -> Tuple[chex.Array, int]: + "Write information for cell containing a shelf." + requested = tree_slice(shelves, shelf_id - 1).is_requested + shelf = jnp.array([1, requested], dtype=jnp.int32) + obs, idx = write_to_observation(obs, idx, shelf) + return obs, idx + + def agent_sensor_scan( + obs_idx_and_agent_id: Tuple[chex.Array, chex.Array, chex.Array], + agent_sensor: chex.Array, + ) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], None]: + """Write agent observation with agent sensor information + of other agents. + """ + obs, idx, agent_id = obs_idx_and_agent_id + sensor_check_for_self = jnp.equal(agent_sensor, agent_id + 1) + sensor_check_for_self_or_no_other = jnp.logical_or( + jnp.equal(agent_sensor, 0), + sensor_check_for_self, + ) + obs, idx = jax.lax.cond( + sensor_check_for_self_or_no_other, + write_no_agent, + write_agent, + obs, + idx, + agent_sensor, + sensor_check_for_self, + ) + return (obs, idx, agent_id), None + + def shelf_sensor_scan( + obs_and_idx: Tuple[chex.Array, chex.Array], shelf_sensor: chex.Array + ) -> Tuple[Tuple[chex.Array, chex.Array], None]: + """Write agent observation with agent sensor information + of other shelves. + """ + obs, idx = obs_and_idx + obs, idx = jax.lax.cond( + jnp.equal(shelf_sensor, 0), + write_no_shelf, + write_shelf, + obs, + idx, + shelf_sensor, + ) + return (obs, idx), None + + (obs, idx, _), _ = jax.lax.scan( + agent_sensor_scan, (obs, idx, agent_id), agents_grid + ) + (obs, _), _ = jax.lax.scan(shelf_sensor_scan, (obs, idx), shelves_grid) + return obs + + +def compute_action_mask(grid: chex.Array, agents: Agent) -> chex.Array: + """Compute the action mask for the environment. + + Args: + grid: the warehouse floor grid array. + agents: a pytree of either Agent type containing agent information. + + Returns: + the action mask for the environment. + """ + # vmap over agents and possible actions + action_mask = jax.vmap( + jax.vmap(functools.partial(is_valid_action, grid), in_axes=(None, 0)), + in_axes=(0, None), + )(agents, jnp.arange(5)) + return action_mask diff --git a/jumanji/environments/routing/robot_warehouse/utils_agent.py b/jumanji/environments/routing/robot_warehouse/utils_agent.py new file mode 100644 index 000000000..1e568f7fb --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/utils_agent.py @@ -0,0 +1,330 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, Union + +import chex +import jax +import jax.numpy as jnp + +from jumanji.environments.routing.robot_warehouse.constants import _AGENTS, _SHELVES +from jumanji.environments.routing.robot_warehouse.types import Action, Agent, Position +from jumanji.environments.routing.robot_warehouse.utils_shelf import ( + set_new_shelf_position_if_carrying, +) +from jumanji.tree_utils import tree_add_element, tree_slice + + +def update_agent( + agents: Agent, + agent_id: chex.Array, + attr: str, + value: Union[chex.Array, Position], +) -> Agent: + """Update the attribute information of a specific agent. + + Args: + agents: a pytree of either Agent type containing agent information. + agent_id: unique ID identifying a specific agent. + attr: the attribute to update, e.g. `direction`, or `is_requested`. + value: the new value to which the attribute is to be set. + + Returns: + the agent with the specified attribute updated to the given value. + """ + params = {attr: value} + agent = tree_slice(agents, agent_id) + agent = agent._replace(**params) + agents: Agent = tree_add_element(agents, agent_id, agent) + return agents + + +def get_new_direction_after_turn( + action: chex.Array, agent_direction: chex.Array +) -> chex.Array: + """Get the correct direction the agent should face given + the turn action it took. E.g. if the agent is facing LEFT + and turns RIGHT it should now be facing UP, etc. + + Args: + action: the agent's action. + agent_direction: the agent's current direction. + + Returns: + the direction the agent should be facing given the action it took. + """ + change_in_direction = jnp.array([0, 0, -1, 1, 0])[action] + return (agent_direction + change_in_direction) % 4 + + +def get_new_position_after_forward( + grid: chex.Array, agent_position: chex.Array, agent_direction: chex.Array +) -> Position: + """Get the correct position the agent will be in after moving forward + in its current direction. E.g. if the agent is facing LEFT and turns + RIGHT it should stay in the same position. If instead it moves FORWARD + it should move left by one cell. + + Args: + grid: the warehouse floor grid array. + agent_position: the agent's current position. + agent_direction: the agent's current direction. + + Returns: + the position the agent should be in given the action it took. + """ + _, grid_width, grid_height = grid.shape + x, y = agent_position.x, agent_position.y + move_up = lambda x, y: Position(jnp.max(jnp.array([0, x - 1])), y) + move_right = lambda x, y: Position(x, jnp.min(jnp.array([grid_height - 1, y + 1]))) + move_down = lambda x, y: Position(jnp.min(jnp.array([grid_width - 1, x + 1])), y) + move_left = lambda x, y: Position(x, jnp.max(jnp.array([0, y - 1]))) + new_position: Position = jax.lax.switch( + agent_direction, [move_up, move_right, move_down, move_left], x, y + ) + return new_position + + +def get_agent_view( + grid: chex.Array, agent: chex.Array, sensor_range: chex.Array +) -> Tuple[chex.Array, chex.Array]: + """Get an agent's view of other agents and shelves within its + sensor range. + + Below is an example of the agent's view of other agents from + the perspective of agent 1 with a sensor range of 1: + + 0, 0, 0 + 0, 1, 2 + 0, 0, 0 + + It sees agent 2 to its right. Separately, the view of shelves + is shown below: + + 0, 0, 0 + 0, 3, 4 + 0, 7, 8 + + Agent 1 is on top of shelf 3 and has 4, 7 and 8 around it in + the bottom right corner of its view. Before returning these + views they are flattened into a 1-d arrays, i.e. + + View of agents: [0, 0, 0, 0, 1, 2, 0, 0, 0] + View of shelves: [0, 0, 0, 0, 3, 4, 0, 7, 8] + + + Args: + grid: the warehouse floor grid array. + agent: the agent for which the view of their receptive field + is to be calculated. + sensor_range: the range of the agent's sensors. + + Returns: + a view of the agents receptive field separated into two arrays: + one for other agents and one for shelves. + """ + receptive_field = sensor_range * 2 + 1 + padded_agents_layer = jnp.pad(grid[_AGENTS], sensor_range, mode="constant") + padded_shelves_layer = jnp.pad(grid[_SHELVES], sensor_range, mode="constant") + agent_view_of_agents = jax.lax.dynamic_slice( + padded_agents_layer, + (agent.position.x, agent.position.y), + (receptive_field, receptive_field), + ).reshape(-1) + agent_view_of_shelves = jax.lax.dynamic_slice( + padded_shelves_layer, + (agent.position.x, agent.position.y), + (receptive_field, receptive_field), + ).reshape(-1) + return agent_view_of_agents, agent_view_of_shelves + + +def set_agent_carrying_if_at_shelf_position( + grid: chex.Array, agents: chex.Array, agent_id: int, is_highway: chex.Array +) -> chex.Array: + """Set the agent as carrying a shelf if it is at a shelf position. + + Args: + grid: the warehouse floor grid array. + agents: a pytree of either Agent type containing agent information. + agent_id: unique ID identifying a specific agent. + is_highway: binary value indicating highway position. + + Returns: + updated agents pytree. + """ + agent = tree_slice(agents, agent_id) + shelf_id = grid[_SHELVES, agent.position.x, agent.position.y] + + return jax.lax.cond( + shelf_id > 0, + lambda: update_agent(agents, agent_id, "is_carrying", 1), + lambda: agents, + ) + + +def offload_shelf_if_position_is_open( + grid: chex.Array, agents: chex.Array, agent_id: int, is_highway: chex.Array +) -> chex.Array: + """Set the agent as not carrying a shelf if it is at a shelf position. + + Args: + grid: the warehouse floor grid array. + agents: a pytree of either Agent type containing agent information. + agent_id: unique ID identifying a specific agent. + is_highway: binary value indicating highway position. + + Returns: + updated agents pytree. + """ + return jax.lax.cond( + jnp.logical_not(is_highway), + lambda: update_agent(agents, agent_id, "is_carrying", 0), + lambda: agents, + ) + + +def set_carrying_shelf_if_load_toggled_and_not_carrying( + grid: chex.Array, + agents: chex.Array, + action: int, + agent_id: int, + is_highway: chex.Array, +) -> chex.Array: + """Set the agent as carrying a shelf if the load toggle action is + performed and the agent is not carrying a shelf. + + Args: + grid: the warehouse floor grid array. + agents: a pytree of either Agent type containing agent information. + action: the agent's action. + agent_id: unique ID identifying a specific agent. + is_highway: binary value indicating highway position. + + Returns: + updated agents pytree. + """ + agent = tree_slice(agents, agent_id) + + agents = jax.lax.cond( + (action == Action.TOGGLE_LOAD.value) & ~agent.is_carrying, + set_agent_carrying_if_at_shelf_position, + offload_shelf_if_position_is_open, + grid, + agents, + agent_id, + is_highway, + ) + return agents + + +def rotate_agent( + grid: chex.Array, + agents: chex.Array, + action: int, + agent_id: int, + is_highway: chex.Array, +) -> chex.Array: + """Rotate the agent in the direction of the action. + + Args: + grid: the warehouse floor grid array. + agents: a pytree of either Agent type containing agent information. + action: the agent's action. + agent_id: unique ID identifying a specific agent. + is_highway: binary value indicating highway position. + + Returns: + updated agents pytree. + """ + agent = tree_slice(agents, agent_id) + new_direction = get_new_direction_after_turn(action, agent.direction) + return update_agent(agents, agent_id, "direction", new_direction) + + +def set_new_position_after_forward( + grid: chex.Array, + agents: chex.Array, + shelves: chex.Array, + action: int, + agent_id: int, + is_highway: chex.Array, +) -> Tuple[chex.Array, chex.Array, chex.Array]: + """Set the new position of the agent after a forward action. + + Args: + grid: the warehouse floor grid array. + agents: a pytree of either Agent type containing agent information. + shelves: a pytree of Shelf type containing shelf information. + action: the agent's action. + agent_id: unique ID identifying a specific agent. + is_highway: binary value indicating highway position. + + Returns: + updated grid array, agents and shelves pytrees. + """ + # update agent position + agent = tree_slice(agents, agent_id) + current_position = agent.position + new_position = get_new_position_after_forward(grid, agent.position, agent.direction) + agents = update_agent(agents, agent_id, "position", new_position) + + # update agent grid placement + grid = grid.at[_AGENTS, current_position.x, current_position.y].set(0) + grid = grid.at[_AGENTS, new_position.x, new_position.y].set(agent_id + 1) + + grid, shelves = jax.lax.cond( + agent.is_carrying, + set_new_shelf_position_if_carrying, + lambda g, s, p, np: (g, s), + grid, + shelves, + current_position, + new_position, + ) + return grid, agents, shelves + + +def set_new_direction_after_turn( + grid: chex.Array, + agents: chex.Array, + shelves: chex.Array, + action: int, + agent_id: int, + is_highway: chex.Array, +) -> Tuple[chex.Array, chex.Array, chex.Array]: + """Set the new direction of the agent after a turning action. + + Args: + grid: the warehouse floor grid array. + agents: a pytree of either Agent type containing agent information. + shelves: a pytree of Shelf type containing shelf information. + action: the agent's action. + agent_id: unique ID identifying a specific agent. + is_highway: binary value indicating highway position. + + Returns: + updated grid array, agents and shelves pytrees. + """ + agents = jax.lax.cond( + jnp.isin(action, jnp.array([Action.LEFT.value, Action.RIGHT.value])), + rotate_agent, + set_carrying_shelf_if_load_toggled_and_not_carrying, + grid, + agents, + action, + agent_id, + is_highway, + ) + return grid, agents, shelves diff --git a/jumanji/environments/routing/robot_warehouse/utils_shelf.py b/jumanji/environments/routing/robot_warehouse/utils_shelf.py new file mode 100644 index 000000000..91fb59ba2 --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/utils_shelf.py @@ -0,0 +1,72 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, Union + +import chex + +from jumanji.environments.routing.robot_warehouse.constants import _SHELVES +from jumanji.environments.routing.robot_warehouse.types import Position, Shelf +from jumanji.tree_utils import tree_add_element, tree_slice + + +def update_shelf( + shelves: Shelf, + shelf_id: chex.Array, + attr: str, + value: Union[chex.Array, Position], +) -> Shelf: + """Update the attribute information of a specific shelf. + + Args: + shelves: a pytree of Shelf type containing shelf information. + shelf_id: unique ID identifying a specific shelf. + attr: the attribute to update, e.g. `direction`, or `is_requested`. + value: the new value to which the attribute is to be set. + + Returns: + the shelf with the specified attribute updated to the given value. + """ + params = {attr: value} + shelf = tree_slice(shelves, shelf_id) + shelf = shelf._replace(**params) + shelves: Shelf = tree_add_element(shelves, shelf_id, shelf) + return shelves + + +def set_new_shelf_position_if_carrying( + grid: chex.Array, + shelves: Shelf, + cur_pos: chex.Array, + new_pos: chex.Array, +) -> Tuple[chex.Array, chex.Array]: + """Set the new position of the shelf if the agent is carrying one. + + Args: + grid: the warehouse floor grid array. + shelves: a pytree of Shelf type containing shelf information. + cur_pos: the current position of the shelf. + new_pos: the new position of the shelf. + + Returns: + updated grid array and shelves pytree. + """ + # update shelf position + shelf_id = grid[_SHELVES, cur_pos.x, cur_pos.y] + shelves = update_shelf(shelves, shelf_id - 1, "position", new_pos) + + # update shelf grid placement + grid = grid.at[_SHELVES, cur_pos.x, cur_pos.y].set(0) + grid = grid.at[_SHELVES, new_pos.x, new_pos.y].set(shelf_id) + return grid, shelves diff --git a/jumanji/environments/routing/robot_warehouse/utils_spawn.py b/jumanji/environments/routing/robot_warehouse/utils_spawn.py new file mode 100644 index 000000000..2da8b91c2 --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/utils_spawn.py @@ -0,0 +1,192 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import chex +import jax +import jax.numpy as jnp + +from jumanji.environments.routing.robot_warehouse.constants import ( + _AGENTS, + _POSSIBLE_DIRECTIONS, + _SHELVES, +) +from jumanji.environments.routing.robot_warehouse.types import ( + Agent, + Entity, + Position, + Shelf, +) +from jumanji.environments.routing.robot_warehouse.utils import get_entity_ids +from jumanji.tree_utils import tree_slice + + +def spawn_agent( + agent_coordinates: chex.Array, + direction: chex.Array, +) -> chex.Array: + """Spawn an agent (robot) at a given position and direction. + + Args: + agent_coordinates: x, y coordinates of the agent. + direction: direction of the agent. + + Returns: + spawned agent. + """ + x, y = agent_coordinates + agent_pos = Position(x=x, y=y) + agent = Agent(position=agent_pos, direction=direction, is_carrying=0) + return agent + + +def spawn_shelf( + shelf_coordinates: chex.Array, + requested: chex.Array, +) -> chex.Array: + """Spawn a shelf at a specific shelf position and label the shelf + as requested or not. + + Args: + shelf_coordinates: x, y coordinates of the shelf. + requested: whether the shelf has been requested or not. + + Returns: + spawned shelf. + """ + x, y = shelf_coordinates + shelf_pos = Position(x=x, y=y) + shelf = Shelf(position=shelf_pos, is_requested=requested) + return shelf + + +def spawn_random_entities( + key: chex.PRNGKey, + grid_size: chex.Array, + agent_ids: chex.Array, + shelf_ids: chex.Array, + shelf_coordinates: chex.Array, + request_queue_size: chex.Array, +) -> Tuple[chex.PRNGKey, Agent, Shelf, chex.Array]: + """Spawn agents and shelves on the warehouse floor grid. + + Args: + key: pseudo random number key. + grid_size: the size of the warehouse floor grid. + agent_ids: array of agent ids. + shelf_ids: array of shelf ids. + shelf_coordinates: x,y coordinates of shelf positions. + request_queue_size: the number of shelves to be delivered. + + Returns: + new key, spawned agents, shelves and the request queue. + """ + + # random agent positions + num_agents = len(agent_ids) + key, position_key = jax.random.split(key) + grid_cells = jnp.array(jnp.arange(grid_size[0] * grid_size[1])) + agent_coords = jax.random.choice( + position_key, + grid_cells, + shape=(num_agents,), + replace=False, + ) + agent_coords = jnp.transpose( + jnp.asarray(jnp.unravel_index(agent_coords, grid_size)) + ) + + # random agent directions + key, direction_key = jax.random.split(key) + + agent_dirs = jax.random.choice( + direction_key, _POSSIBLE_DIRECTIONS, shape=(num_agents,) + ) + + # sample request queue + key, queue_key = jax.random.split(key) + shelf_request_queue = jax.random.choice( + queue_key, + shelf_ids, + shape=(request_queue_size,), + replace=False, + ) + requested_ids = jnp.zeros(shelf_ids.shape) + requested_ids = requested_ids.at[shelf_request_queue].set(1) + + # spawn agents and shelves + agents = jax.vmap(spawn_agent)(agent_coords, agent_dirs) + shelves = jax.vmap(spawn_shelf)(shelf_coordinates, requested_ids) + return key, agents, shelves, shelf_request_queue + + +def place_entity_on_grid( + grid: chex.Array, + channel: chex.Array, + entities: Entity, + entity_id: chex.Array, +) -> chex.Array: + """Places an entity (Agent/Shelf) on the grid based on its + (x, y) position defined once spawned. + + Args: + grid: the warehouse floor grid array. + channel: the grid channel index, either agents or shelves. + entities: a pytree of Agent or Shelf type containing entity information. + entity_id: unique ID identifying a specific entity. + + Returns: + the warehouse grid with the specific entity in its position. + """ + entity = tree_slice(entities, entity_id) + x, y = entity.position.x, entity.position.y + return grid.at[channel, x, y].set(entity_id + 1) + + +def place_entities_on_grid( + grid: chex.Array, agents: Agent, shelves: Shelf +) -> chex.Array: + """Place agents and shelves on the grid. + + Args: + grid: the warehouse floor grid array. + agents: a pytree of Agent type containing agent information. + shelves: a pytree of Shelf type containing shelf information. + + Returns: + the warehouse grid with all agents and shelves placed in their + positions. + """ + agent_ids = get_entity_ids(agents) + shelf_ids = get_entity_ids(shelves) + + # place agents and shelves on warehouse grid + def place_agents_scan( + grid_and_agents: Tuple[chex.Array, chex.Array], agent_id: chex.Array + ) -> Tuple[Tuple[chex.Array, chex.Array], None]: + grid, agents = grid_and_agents + grid = place_entity_on_grid(grid, _AGENTS, agents, agent_id) + return (grid, agents), None + + def place_shelves_scan( + grid_and_shelves: Tuple[chex.Array, chex.Array], shelf_id: chex.Array + ) -> Tuple[Tuple[chex.Array, chex.Array], None]: + grid, shelves = grid_and_shelves + grid = place_entity_on_grid(grid, _SHELVES, shelves, shelf_id) + return (grid, shelves), None + + (grid, _), _ = jax.lax.scan(place_agents_scan, (grid, agents), agent_ids) + (grid, _), _ = jax.lax.scan(place_shelves_scan, (grid, shelves), shelf_ids) + return grid diff --git a/jumanji/environments/routing/robot_warehouse/utils_test.py b/jumanji/environments/routing/robot_warehouse/utils_test.py new file mode 100644 index 000000000..91b04d4b0 --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/utils_test.py @@ -0,0 +1,436 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import chex +import jax +import jax.numpy as jnp +import pytest + +from jumanji.environments.routing.robot_warehouse.types import ( + Action, + Agent, + Position, + Shelf, + State, +) +from jumanji.environments.routing.robot_warehouse.utils import ( + calculate_num_observation_features, + compute_action_mask, + get_valid_actions, + is_collision, + is_valid_action, + move_writer_index, + write_to_observation, +) +from jumanji.environments.routing.robot_warehouse.utils_agent import ( + get_agent_view, + get_new_direction_after_turn, + get_new_position_after_forward, + update_agent, +) +from jumanji.environments.routing.robot_warehouse.utils_shelf import update_shelf +from jumanji.environments.routing.robot_warehouse.utils_spawn import ( + place_entities_on_grid, +) +from jumanji.tree_utils import tree_slice + + +@pytest.fixture +def fake_robot_warehouse_env_state() -> State: + """Create a fake robot_warehouse environment state.""" + + # create agents, shelves and grid + def make_agent( + x: chex.Array, y: chex.Array, direction: chex.Array, is_carrying: chex.Array + ) -> Agent: + return Agent(Position(x=x, y=y), direction=direction, is_carrying=is_carrying) + + def make_shelf(x: chex.Array, y: chex.Array, is_requested: chex.Array) -> Shelf: + return Shelf(Position(x=x, y=y), is_requested=is_requested) + + # agent information + xs = jnp.array([3, 1]) + ys = jnp.array([4, 7]) + dirs = jnp.array([2, 3]) + carries = jnp.array([0, 0]) + agents = jax.vmap(make_agent)(xs, ys, dirs, carries) + + # shelf information + xs = jnp.array([1, 1, 1, 1, 2, 2, 2, 2]) + ys = jnp.array([1, 2, 7, 8, 1, 2, 7, 8]) + requested = jnp.array([0, 1, 1, 0, 0, 0, 1, 1]) + shelves = jax.vmap(make_shelf)(xs, ys, requested) + + # create grid + grid = jnp.array( + [ + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 0, 0, 0, 0, 3, 4, 0], + [0, 5, 6, 0, 0, 0, 0, 7, 8, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 2, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + action_mask = jnp.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]) + state = State( + grid=grid, + agents=agents, + shelves=shelves, + request_queue=jnp.array([1, 2, 6, 7]), + step_count=jnp.array([0], dtype=jnp.int32), + action_mask=action_mask, + key=jax.random.PRNGKey(42), + ) + return state + + +def test_robot_warehouse_utils__entity_placement( + fake_robot_warehouse_env_state: State, +) -> None: + """Test entity placement on the warehouse grid floor.""" + state = fake_robot_warehouse_env_state + agents, shelves = state.agents, state.shelves + empty_grid = jnp.zeros(state.grid.shape, dtype=jnp.int32) + grid_with_agents_and_shelves = place_entities_on_grid(empty_grid, agents, shelves) + + # check that placement is the same as in fake grid + assert jnp.all(grid_with_agents_and_shelves == state.grid) + + +def test_robot_warehouse_utils__entity_update( + fake_robot_warehouse_env_state: State, +) -> None: + """Test entity attribute (e.g. position, direction etc.) updating.""" + state = fake_robot_warehouse_env_state + agents = state.agents + shelves = state.shelves + + # test updating agent position + new_position = Position(x=2, y=4) + agents_with_new_agent_0_position = update_agent(agents, 0, "position", new_position) + agent_0 = tree_slice(agents_with_new_agent_0_position, 0) + assert agent_0.position == new_position + + # test updating agent direction + new_direction = 3 + agents_with_new_agent_0_direction = update_agent( + agents, 0, "direction", new_direction + ) + agent_0 = tree_slice(agents_with_new_agent_0_direction, 0) + assert agent_0.direction == new_direction + + # test updating agent carrying + new_is_carrying = 1 + agents_with_new_agent_0_carrying = update_agent( + agents, 0, "is_carrying", new_is_carrying + ) + agent_0 = tree_slice(agents_with_new_agent_0_carrying, 0) + assert agent_0.is_carrying == new_is_carrying + + # test updating shelf position + new_position = Position(x=1, y=3) + shelves_with_new_shelf_0_position = update_shelf( + shelves, 0, "position", new_position + ) + shelf_0 = tree_slice(shelves_with_new_shelf_0_position, 0) + assert shelf_0.position == new_position + + # test updating shelf requested + new_is_requested = 1 + shelves_with_new_shelf_0_requested = update_shelf( + shelves, 0, "is_requested", new_is_requested + ) + shelf_0 = tree_slice(shelves_with_new_shelf_0_requested, 0) + assert shelf_0.is_requested == new_is_requested + + +def test_robot_warehouse_utils__get_new_direction( + fake_robot_warehouse_env_state: State, +) -> None: + """Test the calculation of the new direction for an agent after turning.""" + state = fake_robot_warehouse_env_state + agents = state.agents + agent = tree_slice(agents, 0) + direction = agent.direction # 2 (facing down) + + # turning: left, left, right, right, right + actions = [2, 2, 3, 3, 3] + expected_directions = [ + 1, # turn left -> facing right + 0, # turn left -> facing up + 1, # turn right -> facing right + 2, # turn right -> face down + 3, # turn right -> face left + ] + + for action, expected_direction in zip(actions, expected_directions): + new_direction = get_new_direction_after_turn(action, direction) + assert new_direction == expected_direction + direction = new_direction + + +def test_robot_warehouse_utils__get_new_position( + fake_robot_warehouse_env_state: State, +) -> None: + """Test the calculation of the new position for an agent after moving + forward in a specific direction.""" + state = fake_robot_warehouse_env_state + grid = state.grid + agents = state.agents + agent = tree_slice(agents, 0) + position = agent.position # [x=3, y=4] + directions = jnp.arange(4) + + # move forward once in each direction + expected_positions = [ + Position(2, 4), # facing up move forward + Position(3, 5), # facing right move forward + Position(4, 4), # facing down move forward + Position(3, 3), # facing left move forward + ] + + for direction, expected_position in zip(directions, expected_positions): + new_position = get_new_position_after_forward(grid, position, direction) + assert new_position == expected_position + + +def test_robot_warehouse_utils__is_collision( + fake_robot_warehouse_env_state: State, +) -> None: + """Test the calculation of collisions between agents and other agents as well as + agents carrying shelves and other shelves.""" + state = fake_robot_warehouse_env_state + grid = state.grid + agents = state.agents + + # check no collision with original grid + agent = tree_slice(agents, 0) + action = 1 # forward + collision = is_collision(grid, agent, action) + assert bool(collision) is False + + # create grid with agents next to each other + grid_prior_to_collision = jnp.array( + [ + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 0, 0, 0, 0, 3, 4, 0], + [0, 5, 6, 0, 0, 0, 0, 7, 8, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 2, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + + # update agent zero to face agent 2 + agents = update_agent(agents, 0, "position", Position(1, 6)) + agents = update_agent(agents, 0, "direction", 1) + agent = tree_slice(agents, 0) + + # check collision if moving forward + collision = is_collision(grid_prior_to_collision, agent, action) + assert bool(collision) is True + + +def test_robot_warehouse_utils__is_valid_action( + fake_robot_warehouse_env_state: State, +) -> None: + """Test the calculation of collisions between agents and other agents as well as + agents carrying shelves and other shelves.""" + state = fake_robot_warehouse_env_state + grid = state.grid + agents = state.agents + + # turn agent 2 around, and check no collision with shelf + # i.e. agent is moving underneath shelf rack via highway + agents = update_agent(agents, 1, "direction", 1) + agent = tree_slice(agents, 1) + action = 1 # forward + action = is_valid_action(grid, agent, action) + assert action == Action.FORWARD.value + + # Let agent 2 pick up shelf and move forward + # to test collision with shelf when carrying + # and convert to NOOP action + agents = update_agent(agents, 1, "is_carrying", 1) + agent = tree_slice(agents, 1) + action = is_valid_action(grid, agent, action) + assert action == Action.NOOP.value + + +def test_robot_warehouse_utils__get_agent_view( + fake_robot_warehouse_env_state: State, +) -> None: + """Test extracting the agent's view of other agents and shelves within + its receptive field as set via a given sensor range.""" + state = fake_robot_warehouse_env_state + grid = jnp.array( + [ + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 0, 0, 0, 0, 3, 4, 0], + [0, 5, 6, 0, 0, 0, 0, 7, 8, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 2, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + agents = state.agents + agents = update_agent(agents, 0, "position", Position(1, 6)) + agent = tree_slice(agents, 0) + + # get agent view with sensor range of 1 + sensor_range = 1 + agent_view_of_agents, agent_view_of_shelves = get_agent_view( + grid, agent, sensor_range + ) + + # flattened agent view of other agents and shelves + flat_agents = jnp.array([0, 0, 0, 0, 1, 2, 0, 0, 0]) + flat_shelves = jnp.array([0, 0, 0, 0, 0, 3, 0, 0, 7]) + + assert jnp.array_equal(agent_view_of_agents, flat_agents) + assert jnp.array_equal(agent_view_of_shelves, flat_shelves) + + # get agent view with sensor range of 2 + sensor_range = 2 + agent_view_of_agents, agent_view_of_shelves = get_agent_view( + grid, agent, sensor_range + ) + + # flattened agent view of other agents and shelves + flat_agents = jnp.array( + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ) + flat_shelves = jnp.array( + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 7, 8, 0, 0, 0, 0, 0] + ) + + assert jnp.array_equal(agent_view_of_agents, flat_agents) + assert jnp.array_equal(agent_view_of_shelves, flat_shelves) + + +def test_robot_warehouse_utils__calculate_num_observation_features() -> None: + """Test the calculation of the size of the agent's observation + vector based on sensor range.""" + sensor_range = 1 + num_obs_features = calculate_num_observation_features(sensor_range) + assert num_obs_features == 66 + + sensor_range = 2 + num_obs_features = calculate_num_observation_features(sensor_range) + assert num_obs_features == 178 + + +def test_robot_warehouse_utils__observation_writer( + fake_robot_warehouse_env_state: State, +) -> None: + """Test observation writer to write data to 1-d observation vector. + Note that this test does not construct a full observation vector. It + only tests basic functionality by writing the agent's view of itself + and does not include writing agent view data from other agents/shelves.""" + state = fake_robot_warehouse_env_state + agents = state.agents + agent = tree_slice(agents, 0) + + # write flattened observation just for the agent's own view + obs = jnp.zeros(8, dtype=jnp.int32) + idx = 0 + + # write current agent position and whether carrying a shelf or not + obs, idx = write_to_observation( + obs, + idx, + jnp.array( + [agent.position.x, agent.position.y, agent.is_carrying], + dtype=jnp.int32, + ), + ) + + # write current agent direction + direction = jax.nn.one_hot(agent.direction, 4, dtype=jnp.int32) + obs, idx = write_to_observation(obs, idx, direction) + + # move index by one (keeping zero to indicate agent not on highway) + idx = move_writer_index(idx, 1) + + assert jnp.array_equal(obs, jnp.array([3, 4, 0, 0, 0, 1, 0, 0])) + assert idx == 8 + + +def test_robot_warehouse_utils__compute_action_mask( + fake_robot_warehouse_env_state: State, +) -> None: + state = fake_robot_warehouse_env_state + grid = state.grid + agents = state.agents + + action_mask = compute_action_mask(grid, agents) + assert jnp.array_equal(action_mask[1], jnp.array([1, 1, 1, 1, 1])) + + # Let agent 2 turn around, pick up shelf and move forward + # to test collision with shelf when carrying + # which is an illegal action + agents = update_agent(agents, 1, "direction", 1) + agents = update_agent(agents, 1, "is_carrying", 1) + + action_mask = compute_action_mask(grid, agents) + assert jnp.array_equal(action_mask[1], jnp.array([1, 0, 1, 1, 1])) + + +def test_robot_warehouse_utils__get_valid_action( + fake_robot_warehouse_env_state: State, +) -> None: + state = fake_robot_warehouse_env_state + grid = state.grid + agents = state.agents + actions = jnp.array([1, 1]) # forward + + action_mask = compute_action_mask(grid, agents) + actions = get_valid_actions(actions, action_mask) + jax.debug.print("action, {a}", a=actions) + assert jnp.array_equal(actions, jnp.array([1, 1])) + + # Let agent 2 turn around, pick up shelf and move forward + # to test collision with shelf when carrying + # which is an illegal action + agents = update_agent(agents, 1, "direction", 1) + agents = update_agent(agents, 1, "is_carrying", 1) + + action_mask = compute_action_mask(grid, agents) + actions = get_valid_actions(actions, action_mask) + assert jnp.array_equal(actions, jnp.array([1, 0])) # turn into noop action diff --git a/jumanji/environments/routing/robot_warehouse/viewer.py b/jumanji/environments/routing/robot_warehouse/viewer.py new file mode 100644 index 000000000..5ba135748 --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/viewer.py @@ -0,0 +1,320 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa: CCR001 + +from typing import Callable, Optional, Sequence, Tuple + +import chex +import matplotlib.animation as animation +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.collections import LineCollection +from numpy.typing import NDArray + +import jumanji +import jumanji.environments.routing.robot_warehouse.constants as constants +from jumanji.environments.routing.robot_warehouse.types import Direction, State +from jumanji.tree_utils import tree_slice +from jumanji.viewer import Viewer + + +class RobotWarehouseViewer(Viewer): + def __init__( + self, + grid_size: Tuple[int, int], + goals: chex.Array, + name: str = "RobotWarehouse", + render_mode: str = "human", + ) -> None: + """Viewer for the RobotWarehouse environment. + + Args: + grid_size: the size of the warehouse floor grid (width, height) + goals: x,y coordinates of goal locations (where shelves + should be delivered) + name: custom name for the Viewer. Defaults to `RobotWarehouse`. + """ + self._name = name + self.goals = goals + self.rows, self.cols = grid_size + + self.grid_size = 30 + self.icon_size = 20 + + self.width = 1 + self.cols * (self.grid_size + 1) + self.height = 1 + self.rows * (self.grid_size + 1) + self._display: Callable[[plt.Figure], Optional[NDArray]] + if render_mode == "rgb_array": + self._display = self._display_rgb_array + elif render_mode == "human": + self._display = self._display_human + else: + raise ValueError(f"Invalid render mode: {render_mode}") + + # The animation must be stored in a variable that lives as long as the + # animation should run. Otherwise, the animation will get garbage-collected. + self._animation: Optional[animation.Animation] = None + + def render(self, state: State) -> Optional[NDArray]: + """Render the given state of the `RobotWarehouse` environment. + + Args: + state: the environment state to render. + """ + self._clear_display() + fig, ax = self._get_fig_ax() + ax.clear() + self._prepare_figure(ax) + self._draw_state(ax, state) + return self._display(fig) + + def animate( + self, + states: Sequence[State], + interval: int = 200, + save_path: Optional[str] = None, + ) -> animation.FuncAnimation: + """Create an animation from a sequence of environment states. + + Args: + states: sequence of environment states corresponding to consecutive timesteps. + interval: delay between frames in milliseconds, default to 200. + save_path: the path where the animation file should be saved. If it is None, the plot + will not be saved. + + Returns: + Animation that can be saved as a GIF, MP4, or rendered with HTML. + """ + fig = plt.figure(f"{self._name}Animation", figsize=constants._FIGURE_SIZE) + fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0) + ax = fig.add_subplot(111) + plt.close(fig) + self._prepare_figure(ax) + + def make_frame(state: State) -> None: + ax.clear() + self._prepare_figure(ax) + self._draw_state(ax, state) + + # Create the animation object. + self._animation = animation.FuncAnimation( + fig, + make_frame, + frames=states, + interval=interval, + ) + + # Save the animation as a gif. + if save_path: + self._animation.save(save_path) + + return self._animation + + def close(self) -> None: + plt.close(self._name) + + def _clear_display(self) -> None: + if jumanji.environments.is_colab(): + import IPython.display + + IPython.display.clear_output(True) + + def _get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]: + recreate = not plt.fignum_exists(self._name) + fig = plt.figure(self._name, figsize=constants._FIGURE_SIZE) + fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0) + + if recreate: + fig.tight_layout() + if not plt.isinteractive(): + fig.show() + ax = fig.add_subplot(111) + else: + ax = fig.get_axes()[0] + return fig, ax + + def _prepare_figure(self, ax: plt.Axes) -> None: + ax.set_xlim(0, self.width) + ax.set_ylim(0, self.height) + ax.patch.set_alpha(0.0) + ax.set_axis_off() + + ax.set_aspect("equal", "box") + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_frame_on(False) + + def _draw_state(self, ax: plt.Axes, state: State) -> None: + self.n_agents = state.agents.position.x.shape[0] + self.n_shelves = state.shelves.position.x.shape[0] + self._draw_grid(ax) + self._draw_goals(ax) + self._draw_shelves(ax, state.shelves) + self._draw_agents(ax, state.agents) + + def _draw_grid(self, ax: plt.Axes) -> None: + """Draw grid of warehouse floor.""" + lines = [] + # VERTICAL LINES + for r in range(self.rows + 1): + lines.append( + [ + (0, (self.grid_size + 1) * r + 1), + ((self.grid_size + 1) * self.cols, (self.grid_size + 1) * r + 1), + ] + ) + + # HORIZONTAL LINES + for c in range(self.cols + 1): + lines.append( + [ + ((self.grid_size + 1) * c + 1, 0), + ((self.grid_size + 1) * c + 1, (self.grid_size + 1) * self.rows), + ] + ) + + lc = LineCollection(lines, colors=(constants._GRID_COLOR,)) + ax.add_collection(lc) + + def _draw_goals(self, ax: plt.Axes) -> None: + """Draw goals, i.e. positions where shelves should be delivered.""" + for goal in self.goals: + x, y = goal + y = self.rows - y - 1 # pyglet rendering is reversed + ax.fill( # changed to ax, from plt, check if still works! + [ + x * (self.grid_size + 1) + 1, + (x + 1) * (self.grid_size + 1), + (x + 1) * (self.grid_size + 1), + x * (self.grid_size + 1) + 1, + ], + [ + y * (self.grid_size + 1) + 1, + y * (self.grid_size + 1) + 1, + (y + 1) * (self.grid_size + 1), + (y + 1) * (self.grid_size + 1), + ], + color=constants._GOAL_COLOR, + alpha=1, + ) + + def _draw_shelves(self, ax: plt.Axes, shelves: chex.Array) -> None: + """Draw shelves at their respective positions. + + Args: + shelves: a pytree of Shelf type containing shelves information. + """ + for shelf_id in range(self.n_shelves): + shelf = tree_slice(shelves, shelf_id) + y, x = shelf.position.x, shelf.position.y + y = self.rows - y - 1 # pyglet rendering is reversed + shelf_color = ( + constants._SHELF_REQ_COLOR + if shelf.is_requested + else constants._SHELF_COLOR + ) + shelf_padding = constants._SHELF_PADDING + + x_points = [ + (self.grid_size + 1) * x + shelf_padding + 1, + (self.grid_size + 1) * (x + 1) - shelf_padding, + (self.grid_size + 1) * (x + 1) - shelf_padding, + (self.grid_size + 1) * x + shelf_padding + 1, + ] + + y_points = [ + (self.grid_size + 1) * y + shelf_padding + 1, + (self.grid_size + 1) * y + shelf_padding + 1, + (self.grid_size + 1) * (y + 1) - shelf_padding, + (self.grid_size + 1) * (y + 1) - shelf_padding, + ] + + ax.fill(x_points, y_points, color=shelf_color) + + def _draw_agents(self, ax: plt.Axes, agents: chex.Array) -> None: + """Draw agents at their respective positions. + + Args: + agents: a pytree of Shelf type containing agents information. + """ + radius = self.grid_size / 3 + + resolution = 6 + + for agent_id in range(self.n_agents): + agent = tree_slice(agents, agent_id) + row, col = agent.position.x, agent.position.y + row = self.rows - row - 1 # pyglet rendering is reversed + x_center = (self.grid_size + 1) * col + self.grid_size // 2 + 1 + y_center = (self.grid_size + 1) * row + self.grid_size // 2 + 1 + + # make a circle + verts = [] + for i in range(resolution): + angle = 2 * np.pi * i / resolution + + x_radius = radius * np.cos(angle) + x = x_radius + x_center + 1 + + y_radius = radius * np.sin(angle) + 1 + y = y_radius + y_center + verts += [[x, y]] + facecolor = ( + constants._AGENT_LOADED_COLOR + if agent.is_carrying + else constants._AGENT_COLOR + ) + circle = plt.Polygon( + verts, + edgecolor="none", + facecolor=facecolor, + ) + + ax.add_patch(circle) + + agent_dir = agent.direction + + x_dir = ( + x_center + + (radius if agent_dir == Direction.RIGHT.value else 0) + - (radius if agent_dir == Direction.LEFT.value else 0) + ) + y_dir = ( + y_center + + (radius if agent_dir == Direction.UP.value else 0) + - (radius if agent_dir == Direction.DOWN.value else 0) + ) + + ax.plot( + [x_center, x_dir], + [y_center, y_dir], + color=constants._AGENT_DIR_COLOR, + linewidth=2, + ) + + def _display_human(self, fig: plt.Figure) -> None: + if plt.isinteractive(): + # Required to update render when using Jupyter Notebook. + fig.canvas.draw() + if jumanji.environments.is_colab(): + plt.show(self._name) + else: + # Required to update render when not using Jupyter Notebook. + fig.canvas.draw_idle() + fig.canvas.flush_events() + + def _display_rgb_array(self, fig: plt.Figure) -> NDArray: + fig.canvas.draw() + return np.asarray(fig.canvas.buffer_rgba()) diff --git a/jumanji/environments/routing/robot_warehouse/viewer_test.py b/jumanji/environments/routing/robot_warehouse/viewer_test.py new file mode 100644 index 000000000..e4416db29 --- /dev/null +++ b/jumanji/environments/routing/robot_warehouse/viewer_test.py @@ -0,0 +1,82 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +import jax.random as random +import matplotlib +import matplotlib.pyplot as plt +import numpy as jnp +import py +import pytest + +from jumanji.environments.routing.robot_warehouse import RobotWarehouse +from jumanji.environments.routing.robot_warehouse.viewer import RobotWarehouseViewer + + +def test_robot_warehouse_viewer__render( + robot_warehouse_env: RobotWarehouse, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.setattr(plt, "show", lambda fig: None) + key = random.PRNGKey(0) + state, _ = robot_warehouse_env.reset(key) + grid_size = robot_warehouse_env._generator.grid_size + goals = robot_warehouse_env._generator.goals + + viewer = RobotWarehouseViewer(grid_size, goals) + viewer.render(state) + viewer.close() + + +def test_robot_warehouse_viewer__animate(robot_warehouse_env: RobotWarehouse) -> None: + key = random.PRNGKey(0) + state, _ = jax.jit(robot_warehouse_env.reset)(key) + grid_size = robot_warehouse_env._generator.grid_size + goals = robot_warehouse_env._generator.goals + + num_steps = 5 + states = [state] + for _ in range(num_steps - 1): + key, subkey = jax.random.split(key) + action = jax.random.choice(subkey, jnp.arange(5), shape=(2,)) + state, _ = jax.jit(robot_warehouse_env.step)(state, action) + states.append(state) + + viewer = RobotWarehouseViewer(grid_size, goals) + viewer.animate(states) + viewer.close() + + +def test_robot_warehouse_viewer__save_animation( + robot_warehouse_env: RobotWarehouse, tmpdir: py.path.local +) -> None: + key = random.PRNGKey(0) + state, _ = jax.jit(robot_warehouse_env.reset)(key) + grid_size = robot_warehouse_env._generator.grid_size + goals = robot_warehouse_env._generator.goals + + num_steps = 5 + states = [state] + for _ in range(num_steps - 1): + key, subkey = jax.random.split(key) + action = jax.random.choice(subkey, jnp.arange(5), shape=(2,)) + state, _ = jax.jit(robot_warehouse_env.step)(state, action) + states.append(state) + + viewer = RobotWarehouseViewer(grid_size, goals) + animation = viewer.animate(states) + assert isinstance(animation, matplotlib.animation.Animation) + + save_path = str(tmpdir.join("/robot_warehouse_animation_test.gif")) + animation.save(save_path) + viewer.close() diff --git a/jumanji/training/configs/config.yaml b/jumanji/training/configs/config.yaml index 34d88409c..8698a2af6 100644 --- a/jumanji/training/configs/config.yaml +++ b/jumanji/training/configs/config.yaml @@ -1,6 +1,6 @@ defaults: - _self_ - - env: snake # [bin_pack, cleaner, connector, cvrp, game_2048, job_shop, knapsack, maze, minesweeper, rubiks_cube, snake, tsp] + - env: snake # [bin_pack, cleaner, connector, cvrp, game_2048, job_shop, knapsack, maze, minesweeper, rubiks_cube, robot_warehouse, snake, tsp] agent: random # [random, a2c] diff --git a/jumanji/training/configs/env/rware.yaml b/jumanji/training/configs/env/rware.yaml new file mode 100644 index 000000000..7d52c8f01 --- /dev/null +++ b/jumanji/training/configs/env/rware.yaml @@ -0,0 +1,27 @@ +name: robot_warehouse +registered_version: RobotWarehouse-v0 + +network: + transformer_num_blocks: 4 + transformer_num_heads: 8 + transformer_key_size: 16 + transformer_mlp_units: [512] + +training: + num_epochs: 500 + num_learner_steps_per_epoch: 100 + n_steps: 20 + total_batch_size: 128 + +evaluation: + eval_total_batch_size: 5000 + greedy_eval_total_batch_size: 5000 + +a2c: + normalize_advantage: False + discount_factor: 0.99 + bootstrapping_factor: 0.95 + l_pg: 1.0 + l_td: 1.0 + l_en: 0.01 + learning_rate: 1e-4 diff --git a/jumanji/training/networks/__init__.py b/jumanji/training/networks/__init__.py index 0b24bb2fe..a8b45b9c3 100644 --- a/jumanji/training/networks/__init__.py +++ b/jumanji/training/networks/__init__.py @@ -48,6 +48,12 @@ make_actor_critic_networks_minesweeper, ) from jumanji.training.networks.minesweeper.random import make_random_policy_minesweeper +from jumanji.training.networks.robot_warehouse.actor_critic import ( + make_actor_critic_networks_robot_warehouse, +) +from jumanji.training.networks.robot_warehouse.random import ( + make_random_policy_robot_warehouse, +) from jumanji.training.networks.rubiks_cube.actor_critic import ( make_actor_critic_networks_rubiks_cube, ) diff --git a/jumanji/training/networks/robot_warehouse/__init__.py b/jumanji/training/networks/robot_warehouse/__init__.py new file mode 100644 index 000000000..21db9ec1c --- /dev/null +++ b/jumanji/training/networks/robot_warehouse/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jumanji/training/networks/robot_warehouse/actor_critic.py b/jumanji/training/networks/robot_warehouse/actor_critic.py new file mode 100644 index 000000000..a1aca10cd --- /dev/null +++ b/jumanji/training/networks/robot_warehouse/actor_critic.py @@ -0,0 +1,168 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Sequence + +import chex +import haiku as hk +import jax.numpy as jnp +import numpy as np + +from jumanji.environments import RobotWarehouse +from jumanji.environments.routing.robot_warehouse.types import Observation +from jumanji.training.networks.actor_critic import ( + ActorCriticNetworks, + FeedForwardNetwork, +) +from jumanji.training.networks.parametric_distribution import ( + MultiCategoricalParametricDistribution, +) +from jumanji.training.networks.transformer_block import TransformerBlock + + +def make_actor_critic_networks_robot_warehouse( + robot_warehouse: RobotWarehouse, + transformer_num_blocks: int, + transformer_num_heads: int, + transformer_key_size: int, + transformer_mlp_units: Sequence[int], +) -> ActorCriticNetworks: + """Make actor-critic networks for the `RobotWarehouse` environment.""" + num_values = np.asarray(robot_warehouse.action_spec().num_values) + parametric_action_distribution = MultiCategoricalParametricDistribution( + num_values=num_values + ) + policy_network = make_actor_network( + time_limit=robot_warehouse.time_limit, + transformer_num_blocks=transformer_num_blocks, + transformer_num_heads=transformer_num_heads, + transformer_key_size=transformer_key_size, + transformer_mlp_units=transformer_mlp_units, + ) + value_network = make_critic_network( + time_limit=robot_warehouse.time_limit, + transformer_num_blocks=transformer_num_blocks, + transformer_num_heads=transformer_num_heads, + transformer_key_size=transformer_key_size, + transformer_mlp_units=transformer_mlp_units, + ) + return ActorCriticNetworks( + policy_network=policy_network, + value_network=value_network, + parametric_action_distribution=parametric_action_distribution, + ) + + +class RobotWarehouseTorso(hk.Module): + def __init__( + self, + transformer_num_blocks: int, + transformer_num_heads: int, + transformer_key_size: int, + transformer_mlp_units: Sequence[int], + env_time_limit: int, + name: Optional[str] = None, + ): + super().__init__(name=name) + self.transformer_num_blocks = transformer_num_blocks + self.transformer_num_heads = transformer_num_heads + self.transformer_key_size = transformer_key_size + self.transformer_mlp_units = transformer_mlp_units + self.model_size = transformer_num_heads * transformer_key_size + self.env_time_limit = env_time_limit + + def __call__(self, observation: Observation) -> chex.Array: + # Shape names: + # B: batch size + # N: number of agents + # O: observation size + # H: hidden/embedding size + # (B, N, O) + _, num_agents, _ = observation.agents_view.shape + + percent_done = observation.step_count / self.env_time_limit + step = jnp.repeat(percent_done[:, None], num_agents, axis=-1)[..., None] + agents_view = observation.agents_view + + # join step count and agent view to embed both at the same time + # (B, N, O + 1) + obs = jnp.concatenate((agents_view, step), axis=-1) + # (B, N, O + 1) -> (B, N, H) + embeddings = hk.Linear(self.model_size)(obs) + + # (B, N, H) -> (B, N, H) + for block_id in range(self.transformer_num_blocks): + transformer_block = TransformerBlock( + num_heads=self.transformer_num_heads, + key_size=self.transformer_key_size, + mlp_units=self.transformer_mlp_units, + w_init_scale=2 / self.transformer_num_blocks, + model_size=self.model_size, + name=f"self_attention_block_{block_id}", + ) + embeddings = transformer_block( + query=embeddings, key=embeddings, value=embeddings + ) + return embeddings # (B, N, H) + + +def make_critic_network( + time_limit: int, + transformer_num_blocks: int, + transformer_num_heads: int, + transformer_key_size: int, + transformer_mlp_units: Sequence[int], +) -> FeedForwardNetwork: + def network_fn(observation: Observation) -> chex.Array: + torso = RobotWarehouseTorso( + transformer_num_blocks, + transformer_num_heads, + transformer_key_size, + transformer_mlp_units, + time_limit, + ) + embeddings = torso(observation) + embeddings = jnp.sum(embeddings, axis=-2) + + head = hk.nets.MLP((*transformer_mlp_units, 1), activate_final=False) + values = head(embeddings) # (B, 1) + return jnp.squeeze(values, axis=-1) # (B,) + + init, apply = hk.without_apply_rng(hk.transform(network_fn)) + return FeedForwardNetwork(init=init, apply=apply) + + +def make_actor_network( + time_limit: int, + transformer_num_blocks: int, + transformer_num_heads: int, + transformer_key_size: int, + transformer_mlp_units: Sequence[int], +) -> FeedForwardNetwork: + def network_fn(observation: Observation) -> chex.Array: + torso = RobotWarehouseTorso( + transformer_num_blocks, + transformer_num_heads, + transformer_key_size, + transformer_mlp_units, + time_limit, + ) + output = torso(observation) + + head = hk.nets.MLP((*transformer_mlp_units, 5), activate_final=False) + logits = head(output) # (B, N, 5) + return jnp.where(observation.action_mask, logits, jnp.finfo(jnp.float32).min) + + init, apply = hk.without_apply_rng(hk.transform(network_fn)) + return FeedForwardNetwork(init=init, apply=apply) diff --git a/jumanji/training/networks/robot_warehouse/random.py b/jumanji/training/networks/robot_warehouse/random.py new file mode 100644 index 000000000..eb0ba826e --- /dev/null +++ b/jumanji/training/networks/robot_warehouse/random.py @@ -0,0 +1,23 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jumanji.training.networks.masked_categorical_random import ( + masked_categorical_random, +) +from jumanji.training.networks.protocols import RandomPolicy + + +def make_random_policy_robot_warehouse() -> RandomPolicy: + """Make random policy for RobotWarehouse.""" + return masked_categorical_random diff --git a/jumanji/training/setup_train.py b/jumanji/training/setup_train.py index f58543fd6..e8f8b1b16 100644 --- a/jumanji/training/setup_train.py +++ b/jumanji/training/setup_train.py @@ -33,6 +33,7 @@ Knapsack, Maze, Minesweeper, + RobotWarehouse, RubiksCube, Snake, ) @@ -164,6 +165,9 @@ def _setup_random_policy( # noqa: CCR001 elif cfg.env.name == "connector": assert isinstance(env.unwrapped, Connector) random_policy = networks.make_random_policy_connector() + elif cfg.env.name == "robot_warehouse": + assert isinstance(env.unwrapped, RobotWarehouse) + random_policy = networks.make_random_policy_robot_warehouse() else: raise ValueError(f"Environment name not found. Got {cfg.env.name}.") return random_policy @@ -246,6 +250,15 @@ def _setup_actor_critic_neworks( # noqa: CCR001 step_count_embed_dim=cfg.env.network.step_count_embed_dim, dense_layer_dims=cfg.env.network.dense_layer_dims, ) + elif cfg.env.name == "robot_warehouse": + assert isinstance(env.unwrapped, RobotWarehouse) + actor_critic_networks = networks.make_actor_critic_networks_robot_warehouse( + robot_warehouse=env.unwrapped, + transformer_num_blocks=cfg.env.network.transformer_num_blocks, + transformer_num_heads=cfg.env.network.transformer_num_heads, + transformer_key_size=cfg.env.network.transformer_key_size, + transformer_mlp_units=cfg.env.network.transformer_mlp_units, + ) elif cfg.env.name == "minesweeper": assert isinstance(env.unwrapped, Minesweeper) actor_critic_networks = networks.make_actor_critic_networks_minesweeper(