feat(connector): single agent #119

Merged · 21 commits · May 12, 2023
Changes from 12 commits
2 changes: 1 addition & 1 deletion README.md
@@ -85,7 +85,7 @@ problems.
| 🏭 JobShop (Job Shop Scheduling Problem) | Packing | `JobShop-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/job_shop/) | [doc](https://instadeepai.github.io/jumanji/environments/job_shop/) |
| 🎒 Knapsack | Packing | `Knapsack-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/knapsack/) | [doc](https://instadeepai.github.io/jumanji/environments/knapsack/) |
| 🧹 Cleaner | Routing | `Cleaner-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/cleaner/) | [doc](https://instadeepai.github.io/jumanji/environments/cleaner/) |
| :link: Connector | Routing | `Connector-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/connector/) | [doc](https://instadeepai.github.io/jumanji/environments/connector/) |
| :link: Connector | Routing | `Connector-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/connector/) | [doc](https://instadeepai.github.io/jumanji/environments/connector/) |
| 🚚 CVRP (Capacitated Vehicle Routing Problem) | Routing | `CVRP-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/cvrp/) | [doc](https://instadeepai.github.io/jumanji/environments/cvrp/) |
| :mag: Maze | Routing | `Maze-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/maze/) | [doc](https://instadeepai.github.io/jumanji/environments/maze/) |
| 🐍 Snake | Routing | `Snake-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/snake/) | [doc](https://instadeepai.github.io/jumanji/environments/snake/) |
10 changes: 5 additions & 5 deletions docs/environments/connector.md
@@ -17,10 +17,10 @@ moves due to being blocked.
At each step, the observation contains 3 items: a grid, an action mask for each agent, and
the episode step count.

- `grid`: jax array (int32) of shape `(grid_size, grid_size)`, a 2D matrix for each
agent that represents pairs of points that need to be connected. Each agent has three types of
points: **position**, **target** and **path** which are represented by different numbers on the
grid. The **position** of an agent has to connect to its **target**, leaving a **path** behind
- `grid`: jax array (int32) of shape `(grid_size, grid_size)`, a 2D matrix that represents pairs
of points that need to be connected. Each agent has three types of points: **position**,
**target** and **path** which are represented by different numbers on the grid. The
**position** of an agent has to connect to its **target**, leaving a **path** behind
it as it moves across the grid forming its route. Each agent connects to only 1 target.

- `action_mask`: jax array (bool) of shape `(num_agents, 5)`, indicates which actions each agent
@@ -74,4 +74,4 @@ Rewards are provided in the shape `(num_agents,)` so that each agent can have a


## Registered Versions 📖
- `Connector-v0`, grid size of 10 and 5 agents.
- `Connector-v1`, grid size of 10 and 5 agents.
2 changes: 1 addition & 1 deletion jumanji/__init__.py
@@ -68,7 +68,7 @@
register(id="Cleaner-v0", entry_point="jumanji.environments:Cleaner")

# Connector with grid size of 10 and 5 agents.
register(id="Connector-v0", entry_point="jumanji.environments:Connector")
register(id="Connector-v1", entry_point="jumanji.environments:Connector")

# CVRP with 20 randomly generated nodes, a maximum capacity of 30,
# a maximum demand for each node of 10, and a dense reward function.
13 changes: 6 additions & 7 deletions jumanji/environments/routing/connector/env.py
@@ -70,17 +70,16 @@ class Connector(Environment[State]):
- each value in the array corresponds to an agent's action.

- reward: jax array (float) of shape ():
- dense: reward is increased by 1 for each successful connection on that step. Additionally,
- dense: reward is 1 for each successful connection on that step. Additionally,
each pair of points that have not connected receives a penalty reward of -0.03.

- episode termination: if an agent can't move, or the time limit is reached, or the agent
connects to its target, it is considered done. Once all agents are done, the episode
terminates. The timestep discounts are of shape (1,) and only set to `0` when all agents
are done.
- episode termination:
* all agents either can't move (no available actions) or have connected to their target.
* the time limit is reached.

- state: State:
- key: jax PRNG key used to randomly spawn agents and targets.
- grid: jax array (int32) of shape (grid_size, grid_size) the observation.
- grid: jax array (int32) of shape (grid_size, grid_size) giving the observation.
- step_count: jax array (int32) of shape () number of steps elapsed in the current episode.

```python
@@ -173,7 +172,7 @@ def step(
grid=grid, step_count=state.step_count + 1, agents=agents, key=state.key
)

# Construct timestep: get rewards, discounts
# Construct timestep: get reward, legal actions and done
reward = self._reward_fn(state, action, new_state)
action_mask = jax.vmap(self._get_action_mask, (0, None))(agents, grid)
observation = Observation(
4 changes: 2 additions & 2 deletions jumanji/environments/routing/connector/env_test.py
@@ -54,8 +54,8 @@ def test_connector__reset(connector: Connector, key: jax.random.KeyArray) -> None:
assert all(is_head_on_grid(state.agents, state.grid))
assert all(is_target_on_grid(state.agents, state.grid))

assert jnp.array_equal(timestep.discount, jnp.asarray(1.0))
assert jnp.array_equal(timestep.reward, jnp.asarray(0.0))
assert timestep.discount == 1.0
assert timestep.reward == 0.0
assert timestep.step_type == StepType.FIRST


3 changes: 2 additions & 1 deletion jumanji/environments/routing/connector/reward.py
@@ -44,7 +44,8 @@ def __call__(

class DenseRewardFn(RewardFn):
"""Returns: reward of 1.0 for each agent that connects on that step and adds a penalty of
-0.03, per agent, at every timestep where they have yet to connect."""
-0.03, per agent, at every timestep where they have yet to connect.
"""

def __init__(
self,
2 changes: 1 addition & 1 deletion jumanji/training/configs/env/connector.yaml
@@ -1,5 +1,5 @@
name: connector
registered_version: Connector-v0
registered_version: Connector-v1

network:
transformer_num_blocks: 4
57 changes: 33 additions & 24 deletions jumanji/training/networks/connector/actor_critic.py
@@ -21,11 +21,11 @@
import numpy as np

from jumanji.environments.routing.connector import Connector, Observation
from jumanji.environments.routing.connector.constants import (
AGENT_INITIAL_VALUE,
PATH,
POSITION,
TARGET,
from jumanji.environments.routing.connector.constants import AGENT_INITIAL_VALUE
from jumanji.environments.routing.connector.utils import (
get_path,
get_position,
get_target,
)
from jumanji.training.networks.actor_critic import (
ActorCriticNetworks,
@@ -60,6 +60,7 @@ def make_actor_critic_networks_connector(
env_time_limit=connector.time_limit,
)
value_network = make_critic_network_connector(
num_agents=num_values[0],
transformer_num_blocks=transformer_num_blocks,
transformer_num_heads=transformer_num_heads,
transformer_key_size=transformer_key_size,
@@ -74,33 +75,35 @@
)


def process_grid(grid: chex.Array) -> chex.Array:
def channels_for_one_agent(agent_grid: chex.Array) -> chex.Array:
def process_grid(grid: chex.Array, num_agents: jnp.int32) -> chex.Array:
def channel_per_agent(agent_grid: chex.Array, agent_id: jnp.int32) -> chex.Array:
"""Concatenates two feature maps: the info of the agent and the info about all other agents
in an indiscernible way (to keep permutation equivariance).
"""
agent_path = get_path(agent_id)
agent_target = get_target(agent_id)
agent_pos = get_position(agent_id)
agent_grid = jnp.expand_dims(agent_grid, -1)
agent_mask = (
    (agent_grid == PATH) | (agent_grid == TARGET) | (agent_grid == POSITION)
)
agent_channel = jnp.where(
    agent_mask,
    agent_grid,
    0,
)
agent_mask = (
    (agent_grid == agent_path)
    | (agent_grid == agent_target)
    | (agent_grid == agent_pos)
)
# only current agent's info as values: 1, 2 or 3
# [G, G, 1]
agent_channel = jnp.where(agent_mask, agent_grid - 3 * agent_id, 0)

# collapse all other agent values into just 1, 2 or 3
offset = AGENT_INITIAL_VALUE
others_channel = offset + (agent_grid - offset) % 3
others_channel = jnp.where(
agent_mask | (agent_grid == 0),
0,
others_channel,
)
channels = jnp.concatenate(
[agent_channel, others_channel], axis=-1
) # [G, G, 2]
# [G, G, 1]
others_channel = jnp.where(agent_mask | (agent_grid == 0), 0, others_channel)
# [G, G, 2]
channels = jnp.concatenate([agent_channel, others_channel], axis=-1)
return channels

channels = jax.vmap(channels_for_one_agent)(grid) # (N, G, G, 2)
# (N, G, G, 2)
channels = jax.vmap(channel_per_agent, (None, 0))(grid, jnp.arange(num_agents))
return channels.astype(float)


Expand All @@ -113,6 +116,7 @@ def __init__(
transformer_mlp_units: Sequence[int],
conv_n_channels: int,
env_time_limit: int,
num_agents: int,
name: Optional[str] = None,
):
super().__init__(name=name)
@@ -121,6 +125,7 @@ def __init__(
self.transformer_key_size = transformer_key_size
self.transformer_mlp_units = transformer_mlp_units
self.model_size = transformer_num_heads * transformer_key_size
self.num_agents = num_agents
self.cnn_block = hk.Sequential(
[
hk.Conv2D(conv_n_channels, (3, 3), 1, padding="VALID"),
@@ -138,7 +143,8 @@ def __init__(
self.env_time_limit = env_time_limit

def __call__(self, observation: Observation) -> chex.Array:
grid = jax.vmap(process_grid)(observation.grid) # (B, N, G, G, 2)
# (B, N, G, G, 2)
grid = jax.vmap(process_grid, (0, None))(observation.grid, self.num_agents)
embeddings = jax.vmap(self.cnn_block)(grid) # (B, N, H)
embeddings = self._augment_with_step_count(embeddings, observation.step_count)

@@ -179,7 +185,7 @@ def _augment_with_step_count(
self, embeddings: chex.Array, step_count: chex.Array
) -> chex.Array:
step_count = jnp.asarray(step_count / self.env_time_limit, float)
num_agents = embeddings.shape[-2]
num_agents = self.num_agents # embeddings.shape[-2]
step_count = jnp.repeat(step_count[:, None], num_agents, axis=-1)[..., None]
embeddings = jnp.concatenate([embeddings, step_count], axis=-1)
return embeddings
@@ -203,6 +209,7 @@ def network_fn(observation: Observation) -> chex.Array:
conv_n_channels=conv_n_channels,
env_time_limit=env_time_limit,
name="policy_torso",
num_agents=num_actions,
)
embeddings = torso(observation)
logits = hk.nets.MLP((*transformer_mlp_units, num_actions), name="policy_head")(
@@ -216,6 +223,7 @@


def make_critic_network_connector(
num_agents: int,
transformer_num_blocks: int,
transformer_num_heads: int,
transformer_key_size: int,
@@ -232,6 +240,7 @@ def network_fn(observation: Observation) -> chex.Array:
conv_n_channels=conv_n_channels,
env_time_limit=env_time_limit,
name="critic_torso",
num_agents=num_agents,
)
embeddings = torso(observation)
# Sum embeddings over the sequence length (num_agents).
7 changes: 2 additions & 5 deletions jumanji/training/setup_train.py
@@ -50,7 +50,7 @@
from jumanji.training.networks.actor_critic import ActorCriticNetworks
from jumanji.training.networks.protocols import RandomPolicy
from jumanji.training.types import ActingState, TrainingState
from jumanji.wrappers import MultiToSingleWrapper, VmapAutoResetWrapper
from jumanji.wrappers import VmapAutoResetWrapper


def setup_logger(cfg: DictConfig) -> Logger:
@@ -78,10 +78,7 @@ def setup_logger(cfg: DictConfig) -> Logger:


def _make_raw_env(cfg: DictConfig) -> Environment:
env: Environment = jumanji.make(cfg.env.registered_version)
if isinstance(env, Connector):
env = MultiToSingleWrapper(env)
return env
return jumanji.make(cfg.env.registered_version)


def setup_env(cfg: DictConfig) -> Environment: