feat(connector): single agent #119

Merged · 21 commits · May 12, 2023
2 changes: 1 addition & 1 deletion README.md
@@ -85,7 +85,7 @@ problems.
| 🏭 JobShop (Job Shop Scheduling Problem) | Packing | `JobShop-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/job_shop/) | [doc](https://instadeepai.github.io/jumanji/environments/job_shop/) |
| 🎒 Knapsack | Packing | `Knapsack-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/knapsack/) | [doc](https://instadeepai.github.io/jumanji/environments/knapsack/) |
| 🧹 Cleaner | Routing | `Cleaner-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/cleaner/) | [doc](https://instadeepai.github.io/jumanji/environments/cleaner/) |
| :link: Connector | Routing | `Connector-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/connector/) | [doc](https://instadeepai.github.io/jumanji/environments/connector/) |
| :link: Connector | Routing | `Connector-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/connector/) | [doc](https://instadeepai.github.io/jumanji/environments/connector/) |
| 🚚 CVRP (Capacitated Vehicle Routing Problem) | Routing | `CVRP-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/cvrp/) | [doc](https://instadeepai.github.io/jumanji/environments/cvrp/) |
| :mag: Maze | Routing | `Maze-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/maze/) | [doc](https://instadeepai.github.io/jumanji/environments/maze/) |
| 🐍 Snake | Routing | `Snake-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/snake/) | [doc](https://instadeepai.github.io/jumanji/environments/snake/) |
2 changes: 0 additions & 2 deletions docs/api/environments/connector.md
@@ -4,8 +4,6 @@
- init
- observation_spec
- action_spec
- reward_spec
- discount_spec
- reset
- step
- render
46 changes: 10 additions & 36 deletions docs/environments/connector.md
@@ -13,26 +13,14 @@ to allow each other to connect to their own targets without overlapping.
An episode ends when all agents have connected to their targets or no agents can make any further
moves due to being blocked.

> ⚠️ Warning
>
> This environment is multi-agent, i.e. the observation, action and action mask are batched on the
> agent dimension.
>
> - If used in a multi-agent RL setting, one can directly vmap the agents' inference functions on
> the observation they receive or unpack the observation and give it to each agent manually, e.g.
> `agents_obs = [jax.tree_util.tree_map(lambda x: x[i] if x.ndim>0 else x, obs) for i in range(len(obs.grid))]`.
>
> - If used in a single-agent RL setting, one can use `jumanji.wrappers.MultiToSingleWrapper` to
> make it a single-agent environment.


## Observation
At each step observation contains 3 items: a grid for each agent, an action mask for each agent and
At each step, the observation contains 3 items: a grid, an action mask for each agent, and
the episode step count.

- `grid`: jax array (int32) of shape `(num_agents, grid_size, grid_size)`, a 2D matrix for each
agent that represents pairs of points that need to be connected from the perspective of each
agent. The **position** of an agent has to connect to its **target**, leaving a **path** behind
- `grid`: jax array (int32) of shape `(grid_size, grid_size)`, a 2D matrix that represents pairs
of points that need to be connected. Each agent has three types of points: **position**,
**target**, and **path**, which are represented by different numbers on the grid. The
**position** of an agent has to connect to its **target**, leaving a **path** behind
it as it moves across the grid forming its route. Each agent connects to only 1 target.

- `action_mask`: jax array (bool) of shape `(num_agents, 5)`, indicates which actions each agent
@@ -43,15 +31,14 @@ the episode step count.


### Encoding
Each agent has 3 components represented in the observation space: position, target, and path. Each
Each agent has 3 components represented in the observation space: **position**, **target**, and **path**. Each
agent in the environment is assigned an integer for each of its components.

- Positions are encoded starting from 2 in multiples of 3: 2, 5, 8, …

- Targets are encoded starting from 3 in multiples of 3: 3, 6, 9, …

- Paths appear in the location of the head once it moves, starting from 1 in
multiples of 3: 1, 4, 7, …
- Paths appear in the location of the head once it moves, starting from 1 in multiples of 3: 1, 4, 7, …

Every group of 3 corresponds to 1 agent: (1,2,3), (4,5,6), …

@@ -62,7 +49,7 @@ Agent2[path=4, position=5, target=6]
Agent3[path=7, position=8, target=9]
```
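
As a sketch, the integer for each component of an agent with 0-indexed id `agent_id` can be
computed as below; the helper names are illustrative rather than a guaranteed part of the
environment's internal API.

```python
def get_path(agent_id: int) -> int:
    # Paths: 1, 4, 7, ... for agents 0, 1, 2, ...
    return 3 * agent_id + 1


def get_position(agent_id: int) -> int:
    # Positions: 2, 5, 8, ...
    return 3 * agent_id + 2


def get_target(agent_id: int) -> int:
    # Targets: 3, 6, 9, ...
    return 3 * agent_id + 3
```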

For example, on a 6x6 grid, the starting observation is shown below.
For example, on a 6x6 grid, a possible observation is shown below.

```
[[ 2 0 3 0 0 0]
@@ -73,31 +60,18 @@ For example, on a 6x6 grid, the starting observation is shown below.
[ 0 0 6 7 7 7]]
```

### Current Agent (multi-agent)

Given that this is a multi-agent environment, each agent gets its own observation, so we need a
way to represent the current agent, i.e. the agent the actor/learner's actions will apply to. The
current agent is always encoded as `(1,2,3)` in the observations. However, this notion of current
agent only exists in the observations; in the state, agent 0 is always encoded as `(1,2,3)`.

The implementation shifts all other agents' values to make the `(1,2,3)` values represent the
current agent, so in each agent’s observation it will be represented by `(1,2,3)`.
This means that the agent with the values `(4,5,6)` will always be the next agent to act.


## Action
The action space is a `MultiDiscreteArray` of shape `(num_agents,)` with integer values in the
range `[0, 4]`. Each value corresponds to an agent moving in 1 of the 4 cardinal directions or taking the
no-op action. That is, [0, 1, 2, 3, 4] -> [No Op, Up, Right, Down, Left].
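
For example, with 5 agents, a single step consumes one action per agent (a sketch with
arbitrarily chosen values):

```python
import jax.numpy as jnp

# [No Op, Up, Right, Down, Left] -> [0, 1, 2, 3, 4]
# Agent 0 no-ops, agent 1 moves up, agent 2 right, agent 3 down, agent 4 left.
action = jnp.array([0, 1, 2, 3, 4], dtype=jnp.int32)
```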


## Reward
The reward is **dense**: +1.0 for each agent that connects at that step and -0.03 for each agent that has not
The reward is **dense**: +1.0 per agent that connects at that step and -0.03 per agent that has not
connected yet.

Rewards are provided in the shape `(num_agents,)` so that each agent can have a reward.


## Registered Versions 📖
- `Connector-v0`, grid size of 10 and 5 agents.
- `Connector-v1`, grid size of 10 and 5 agents.
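
A minimal usage sketch for the registered environment, assuming the standard `jumanji.make` API:

```python
import jax
import jumanji

env = jumanji.make("Connector-v1")
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)

# One random action per agent, then a single environment step.
key, action_key = jax.random.split(key)
action = jax.random.randint(action_key, (env.num_agents,), minval=0, maxval=5)
state, timestep = jax.jit(env.step)(state, action)
```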
2 changes: 1 addition & 1 deletion jumanji/__init__.py
@@ -68,7 +68,7 @@
register(id="Cleaner-v0", entry_point="jumanji.environments:Cleaner")

# Connector with grid size of 10 and 5 agents.
register(id="Connector-v0", entry_point="jumanji.environments:Connector")
register(id="Connector-v1", entry_point="jumanji.environments:Connector")

# CVRP with 20 randomly generated nodes, a maximum capacity of 30,
# a maximum demand for each node of 10, and a dense reward function.
81 changes: 21 additions & 60 deletions jumanji/environments/routing/connector/env.py
@@ -40,26 +40,23 @@
is_valid_position,
move_agent,
move_position,
switch_perspective,
)
from jumanji.environments.routing.connector.viewer import ConnectorViewer
from jumanji.types import TimeStep, restart, termination, transition
from jumanji.viewer import Viewer


class Connector(Environment[State]):
"""The `Connector` environment is a multi-agent gridworld problem where each agent must connect a
start to a target. However, when moving through this gridworld the agent leaves an impassable
trail behind it. Therefore, agents must connect to their targets without overlapping the routes
taken by any other agent.
"""The `Connector` environment is a gridworld problem where multiple pairs of points (sets)
must be connected without overlapping the paths taken by any other set. This is achieved
by allowing certain points to move to an adjacent cell at each step. However, each time a
point moves, it leaves an impassable trail behind it. The goal is to connect all sets.

- observation - `Observation`
- action mask: jax array (bool) of shape (num_agents, 5).
- step_count: jax array (int32) of shape ()
the current episode step.
- grid: jax array (int32) of shape (num_agents, size, size)
- each 2d array (size, size) along axis 0 is the agent's local observation.
- agents have ids from 0 to (num_agents - 1)
- grid: jax array (int32) of shape (grid_size, grid_size)
- with 2 agents you might have a grid like this:
4 0 1
5 0 1
@@ -68,24 +65,21 @@ class Connector(Environment[State]):
the bottom right corner and is aiming to get to the middle bottom cell. Agent 2
started in the top left and moved down once towards its target in the bottom left.

This would just be agent 0's view, the numbers would be flipped for agent 1's view.
So the full observation would be of shape (2, 3, 3).

- action: jax array (int32) of shape (num_agents,):
- can take the values [0,1,2,3,4] which correspond to [No Op, Up, Right, Down, Left].
- each value in the array corresponds to an agent's action.

- reward: jax array (float) of shape ():
- dense: each agent is given 1.0 if it connects on that step, otherwise 0.0. Additionally,
each agent that has not connected receives a penalty reward of -0.03.
- dense: reward is 1 for each successful connection on that step. Additionally,
each pair of points that has not yet connected receives a penalty reward of -0.03.

- episode termination: if an agent can't move, or the time limit is reached, or the agent
connects to its target, it is considered done. Once all agents are done, the episode
terminates. The timestep discounts are of shape (num_agents,).
- episode termination:
- all agents either can't move (no available actions) or have connected to their target.
- the time limit is reached.

- state: State:
- key: jax PRNG key used to randomly spawn agents and targets.
- grid: jax array (int32) of shape (size, size) which corresponds to agent 0's observation.
- grid: jax array (int32) of shape (grid_size, grid_size) giving the observation.
- step_count: jax array (int32) of shape () number of steps elapsed in the current episode.

```python
@@ -147,14 +141,12 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
state.agents, state.grid
)
observation = Observation(
grid=self._obs_from_grid(state.grid),
grid=state.grid,
action_mask=action_mask,
step_count=state.step_count,
)
extras = self._get_extras(state)
timestep = restart(
observation=observation, extras=extras, shape=(self.num_agents,)
)
timestep = restart(observation=observation, extras=extras)
return state, timestep

def step(
@@ -180,31 +172,26 @@ def step(
grid=grid, step_count=state.step_count + 1, agents=agents, key=state.key
)

# Construct timestep: get observations, rewards, discounts
grids = self._obs_from_grid(grid)
# Construct timestep: get reward, legal actions and done
reward = self._reward_fn(state, action, new_state)
action_mask = jax.vmap(self._get_action_mask, (0, None))(agents, grid)
observation = Observation(
grid=grids, action_mask=action_mask, step_count=new_state.step_count
grid=grid, action_mask=action_mask, step_count=new_state.step_count
)

dones = jax.vmap(connected_or_blocked)(agents, action_mask)
discount = jnp.asarray(jnp.logical_not(dones), dtype=float)
done = jnp.all(jax.vmap(connected_or_blocked)(agents, action_mask))
extras = self._get_extras(new_state)
timestep = jax.lax.cond(
dones.all() | (new_state.step_count >= self.time_limit),
done | (new_state.step_count >= self.time_limit),
lambda: termination(
reward=reward,
observation=observation,
extras=extras,
shape=self.num_agents,
),
lambda: transition(
reward=reward,
observation=observation,
discount=discount,
extras=extras,
shape=self.num_agents,
),
)

@@ -271,12 +258,6 @@ def _step_agent(

return new_agent, new_grid

def _obs_from_grid(self, grid: chex.Array) -> chex.Array:
"""Gets the observation vector for all agents."""
return jax.vmap(switch_perspective, (None, 0, None))(
grid, self._agent_ids, self.num_agents
)

def _get_action_mask(self, agent: Agent, grid: chex.Array) -> chex.Array:
"""Gets an agent's action mask."""
# Don't check action 0 because no-op is always valid
@@ -344,12 +325,12 @@ def observation_spec(self) -> specs.Spec[Observation]:

Returns:
Spec for the `Observation` whose fields are:
- grid: BoundedArray (int32) of shape (num_agents, grid_size, grid_size).
- grid: BoundedArray (int32) of shape (grid_size, grid_size).
- action_mask: BoundedArray (bool) of shape (num_agents, 5).
- step_count: BoundedArray (int32) of shape ().
"""
grid = specs.BoundedArray(
shape=(self.num_agents, self.grid_size, self.grid_size),
shape=(self.grid_size, self.grid_size),
dtype=jnp.int32,
name="grid",
minimum=0,
@@ -380,8 +361,8 @@ def observation_spec(self) -> specs.Spec[Observation]:
def action_spec(self) -> specs.MultiDiscreteArray:
"""Returns the action spec for the Connector environment.

5 actions: [0,1,2,3,4] -> [No Op, Up, Right, Down, Left]. Since this is a multi-agent
environment, the environment expects an array of actions of shape (num_agents,).
5 actions: [0,1,2,3,4] -> [No Op, Up, Right, Down, Left]. Since this is an environment with
a multi-dimensional action space, it expects an array of actions of shape (num_agents,).

Returns:
action_spec: `MultiDiscreteArray` of shape (num_agents,).
Expand All @@ -391,23 +372,3 @@ def action_spec(self) -> specs.MultiDiscreteArray:
dtype=jnp.int32,
name="action",
)

def reward_spec(self) -> specs.Array:
"""
Returns:
reward_spec: a `specs.Array` spec of shape (num_agents,). One for each agent.
"""
return specs.Array(shape=(self.num_agents,), dtype=float, name="reward")

def discount_spec(self) -> specs.BoundedArray:
"""
Returns:
discount_spec: a `specs.Array` spec of shape (num_agents,). One for each agent
"""
return specs.BoundedArray(
shape=(self.num_agents,),
dtype=float,
minimum=0.0,
maximum=1.0,
name="discount",
)
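
With `reward_spec` and `discount_spec` deleted, `Connector` falls back to the base `Environment`
defaults. A sketch of the scalar specs this change is expected to produce (the authoritative
defaults live in `jumanji/env.py`):

```python
from jumanji import specs

# Expected defaults after this change: scalar reward and discount,
# replacing the per-agent (num_agents,) shapes removed above.
reward_spec = specs.Array(shape=(), dtype=float, name="reward")
discount_spec = specs.BoundedArray(
    shape=(), dtype=float, minimum=0.0, maximum=1.0, name="discount"
)
```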
53 changes: 6 additions & 47 deletions jumanji/environments/routing/connector/env_test.py
@@ -54,8 +54,8 @@ def test_connector__reset(connector: Connector, key: jax.random.KeyArray) -> None:
assert all(is_head_on_grid(state.agents, state.grid))
assert all(is_target_on_grid(state.agents, state.grid))

assert jnp.array_equal(timestep.discount, jnp.ones(connector.num_agents))
assert jnp.array_equal(timestep.reward, jnp.zeros(connector.num_agents))
assert timestep.discount == 1.0
assert timestep.reward == 0.0
assert timestep.step_type == StepType.FIRST


@@ -91,7 +91,7 @@ def test_connector__step_connected(
chex.assert_trees_all_equal(real_state2, state2)

assert timestep.step_type == StepType.LAST
assert jnp.array_equal(timestep.discount, jnp.zeros(connector.num_agents))
assert jnp.array_equal(timestep.discount, jnp.asarray(0))
reward = connector._reward_fn(real_state1, action2, real_state2)
assert jnp.array_equal(timestep.reward, reward)

@@ -143,7 +143,7 @@ def test_connector__step_blocked(

assert jnp.array_equal(state.grid, expected_grid)
assert timestep.step_type == StepType.LAST
assert jnp.array_equal(timestep.discount, jnp.zeros(connector.num_agents))
assert jnp.array_equal(timestep.discount, jnp.asarray(0))

assert all(is_head_on_grid(state.agents, state.grid))
assert all(is_target_on_grid(state.agents, state.grid))
@@ -162,12 +162,12 @@ def test_connector__step_horizon(connector: Connector, state: State) -> None:
state, timestep = step_fn(state, actions)

assert timestep.step_type != StepType.LAST
assert jnp.array_equal(timestep.discount, jnp.ones(connector.num_agents))
assert jnp.array_equal(timestep.discount, jnp.asarray(1))

# step 5
state, timestep = step_fn(state, actions)
assert timestep.step_type == StepType.LAST
assert jnp.array_equal(timestep.discount, jnp.zeros(connector.num_agents))
assert jnp.array_equal(timestep.discount, jnp.asarray(0))


def test_connector__step_agents_collision(
@@ -230,47 +230,6 @@ def test_connector__does_not_smoke(connector: Connector) -> None:
check_env_does_not_smoke(connector)


def test_connector__obs_from_grid(
connector: Connector,
grid: chex.Array,
path0: int,
path1: int,
path2: int,
targ0: int,
targ1: int,
targ2: int,
posi0: int,
posi1: int,
posi2: int,
) -> None:
"""Tests that observations are correctly generated given the grid."""
observations = connector._obs_from_grid(grid)

expected_agent_1 = jnp.array(
[
[EMPTY, EMPTY, targ2, EMPTY, EMPTY, EMPTY],
[EMPTY, EMPTY, posi2, path2, path2, EMPTY],
[EMPTY, EMPTY, EMPTY, targ1, posi1, EMPTY],
[targ0, EMPTY, posi0, EMPTY, path1, EMPTY],
[EMPTY, EMPTY, path0, EMPTY, path1, EMPTY],
[EMPTY, EMPTY, path0, EMPTY, EMPTY, EMPTY],
]
)
expected_agent_2 = jnp.array(
[
[EMPTY, EMPTY, targ1, EMPTY, EMPTY, EMPTY],
[EMPTY, EMPTY, posi1, path1, path1, EMPTY],
[EMPTY, EMPTY, EMPTY, targ0, posi0, EMPTY],
[targ2, EMPTY, posi2, EMPTY, path0, EMPTY],
[EMPTY, EMPTY, path2, EMPTY, path0, EMPTY],
[EMPTY, EMPTY, path2, EMPTY, EMPTY, EMPTY],
]
)

expected_obs = jnp.stack([grid, expected_agent_1, expected_agent_2])
assert jnp.array_equal(expected_obs, observations)


def test_connector__get_action_mask(state: State, connector: Connector) -> None:
"""Validates the action masking."""
action_masks = jax.vmap(connector._get_action_mask, (0, None))(