Skip to content

Commit

Permalink
refactor: observation & action for clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
anordin95 committed Dec 24, 2024
1 parent 70f61ef commit 09e7a10
Show file tree
Hide file tree
Showing 9 changed files with 140 additions and 152 deletions.
12 changes: 5 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,7 @@ You can control your replays to your liking! Currently, we support these control
## 🌍 Environment
### 🔭 Observation
An observation for one agent is a dictionary `{"observation": observation, "action_mask": action_mask}`.

The `observation` is a `Dict`. Values are either `numpy` matrices with shape `(N,M)`, or simple `int` constants:
An agents observation contains a broad swath of information about their position in the game. Values are either `numpy` matrices with shape `(N,M)`, or `int` constants:
| Key | Shape | Description |
| -------------------- | --------- | ---------------------------------------------------------------------------- |
| `armies` | `(N,M)` | Number of units in a visible cell regardless of the owner |
Expand All @@ -179,9 +177,6 @@ The `observation` is a `Dict`. Values are either `numpy` matrices with shape `(N
| `timestep` || Current timestep of the game |
| `priority` || `1` if your move is evaluted first, `0` otherwise |

The `action_mask` is a 3D array with shape `(N, M, 4)`, where each element corresponds to whether a move is valid from cell
`[i, j]` in one of four directions: `0 (up)`, `1 (down)`, `2 (left)`, or `3 (right)`.

### ⚡ Action
Actions are lists of 5 values `[pass, cell_i, cell_j, direction, split]`, where
- `pass` indicates whether you want to `1 (pass)` or `0 (play)`.
Expand All @@ -190,6 +185,9 @@ Actions are lists of 5 values `[pass, cell_i, cell_j, direction, split]`, where
- `direction` indicates whether you want to move `0 (up)`, `1 (down)`, `2 (left)`, or `3 (right)`
- `split` indicates whether you want to `1 (split)` units and send only half, or `0 (no split)` where you send all units to the next cell

A convenience function `compute_valid_action_mask` is also provided for detailing the set of legal moves an agent can make based on its `observation`. The `valid_action_mask` is a 3D array with shape `(N, M, 4)`, where each element corresponds to whether a move is valid from cell
`[i, j]` in one of four directions: `0 (up)`, `1 (down)`, `2 (left)`, or `3 (right)`.

> [!TIP]
> You can see how actions and observations look like by printing a sample form the environment:
> ```python
Expand All @@ -203,7 +201,7 @@ and gives `1` for winner and `-1` for loser, otherwise `0`.
```python
def custom_reward_fn(observation, action, done, info):
# Give agent a reward based on the number of cells they own
return observation["observation"]["owned_land_count"]
return observation["owned_land_count"]
env = gym.make(..., reward_fn=custom_reward_fn)
observations, info = env.reset()
Expand Down
11 changes: 5 additions & 6 deletions generals/agents/expander_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from generals.core.action import Action, compute_valid_action_mask
from generals.core.config import Direction
from generals.core.game import Action
from generals.core.observation import Observation

from .agent import Agent
Expand All @@ -16,16 +16,15 @@ def act(self, observation: Observation) -> Action:
Heuristically selects a valid (expanding) action.
Prioritizes capturing opponent and then neutral cells.
"""
mask = observation["action_mask"]
observation = observation["observation"]

mask = compute_valid_action_mask(observation)
valid_actions = np.argwhere(mask == 1)
if len(valid_actions) == 0: # No valid actions
return np.array([1, 0, 0, 0, 0])

army = observation["armies"]
opponent = observation["opponent_cells"]
neutral = observation["neutral_cells"]
army = observation.armies
opponent = observation.opponent_cells
neutral = observation.neutral_cells

# Find actions that capture opponent or neutral cells
actions_capture_opponent = np.zeros(len(valid_actions))
Expand Down
6 changes: 3 additions & 3 deletions generals/agents/random_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from generals.core.game import Action
from generals.core.action import Action, compute_valid_action_mask
from generals.core.observation import Observation

from .agent import Agent
Expand All @@ -23,8 +23,8 @@ def act(self, observation: Observation) -> Action:
"""
Randomly selects a valid action.
"""
mask = observation["action_mask"]
observation = observation["observation"]

mask = compute_valid_action_mask(observation)

valid_actions = np.argwhere(mask == 1)
if len(valid_actions) == 0: # No valid actions
Expand Down
68 changes: 68 additions & 0 deletions generals/core/action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import TypeAlias

import numpy as np

from generals.core.config import DIRECTIONS

from .observation import Observation

"""
Action is intentionally a numpy array rather than a class for the sake of optimization. Granted,
this hasn't been performance tested, so take that decision with a grain of salt.
The action format is an array with 5 entries: [pass, row, col, direction, split]
Args:
pass: boolean integer (0 or 1) indicating whether the agent should pass/skip this turn and do nothing.
row: The row the agent should move from. In the closed-interval: [0, (grid_height - 1)].
col: The column the agent should move from. In the closed-interval: [0, (grid_width - 1)].
direction: An integer indicating which direction to move. 0 (up), 1 (down), 2 (left), 3 (right).
Note: the integer is effecitlvey an index into the DIRECTIONS enum.
split: boolean integer (0 or 1) indicating whether to split the army when moving.
"""
Action: TypeAlias = np.ndarray


def compute_valid_action_mask(observation: Observation) -> np.ndarray:
"""
Return a mask of the valid actions for a given observation.
A valid action is an action that originates from an agent's cell, has
at least 2 units and does not attempt to enter a mountain nor exit the grid.
Returns:
np.ndarray: an NxNx4 array, where each channel is a boolean mask
of valid actions (UP, DOWN, LEFT, RIGHT) for each cell in the grid.
I.e. valid_action_mask[i, j, k] is 1 if action k is valid in cell (i, j).
"""
height, width = observation.owned_cells.shape

ownership_channel = observation.owned_cells
more_than_1_army = (observation.armies > 1) * ownership_channel
owned_cells_indices = np.argwhere(more_than_1_army)
valid_action_mask = np.zeros((height, width, 4), dtype=bool)

if np.sum(ownership_channel) == 0:
return valid_action_mask

for channel_index, direction in enumerate(DIRECTIONS):
destinations = owned_cells_indices + direction.value

# check if destination is in grid bounds
in_first_boundary = np.all(destinations >= 0, axis=1)
in_height_boundary = destinations[:, 0] < height
in_width_boundary = destinations[:, 1] < width
destinations = destinations[in_first_boundary & in_height_boundary & in_width_boundary]

# check if destination is road
passable_cells = 1 - observation.mountains
# assert that every value is either 0 or 1 in passable cells
assert np.all(np.isin(passable_cells, [0, 1])), f"{passable_cells}"
passable_cell_indices = passable_cells[destinations[:, 0], destinations[:, 1]] == 1
action_destinations = destinations[passable_cell_indices]

# get valid action mask for a given direction
valid_source_indices = action_destinations - direction.value
valid_action_mask[valid_source_indices[:, 0], valid_source_indices[:, 1], channel_index] = 1.0

return valid_action_mask
38 changes: 17 additions & 21 deletions generals/core/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import gymnasium as gym
import numpy as np

from .action import Action
from .channels import Channels
from .config import DIRECTIONS
from .grid import Grid
from .observation import Observation

# Type aliases
Action: TypeAlias = np.ndarray

Info: TypeAlias = dict[str, Any]


Expand Down Expand Up @@ -41,26 +42,21 @@ def __init__(self, grid: Grid, agents: list[str]):
grid_discrete = np.ones(self.grid_dims, dtype=int) * self.max_army_value
self.observation_space = gym.spaces.Dict(
{
"observation": gym.spaces.Dict(
{
"armies": gym.spaces.MultiDiscrete(grid_discrete),
"generals": grid_multi_binary,
"cities": grid_multi_binary,
"mountains": grid_multi_binary,
"neutral_cells": grid_multi_binary,
"owned_cells": grid_multi_binary,
"opponent_cells": grid_multi_binary,
"fog_cells": grid_multi_binary,
"structures_in_fog": grid_multi_binary,
"owned_land_count": gym.spaces.Discrete(self.max_land_value),
"owned_army_count": gym.spaces.Discrete(self.max_army_value),
"opponent_land_count": gym.spaces.Discrete(self.max_land_value),
"opponent_army_count": gym.spaces.Discrete(self.max_army_value),
"timestep": gym.spaces.Discrete(self.max_timestep),
"priority": gym.spaces.Discrete(2),
}
),
"action_mask": gym.spaces.MultiBinary(self.grid_dims + (4,)),
"armies": gym.spaces.MultiDiscrete(grid_discrete),
"generals": grid_multi_binary,
"cities": grid_multi_binary,
"mountains": grid_multi_binary,
"neutral_cells": grid_multi_binary,
"owned_cells": grid_multi_binary,
"opponent_cells": grid_multi_binary,
"fog_cells": grid_multi_binary,
"structures_in_fog": grid_multi_binary,
"owned_land_count": gym.spaces.Discrete(self.max_land_value),
"owned_army_count": gym.spaces.Discrete(self.max_army_value),
"opponent_land_count": gym.spaces.Discrete(self.max_land_value),
"opponent_army_count": gym.spaces.Discrete(self.max_army_value),
"timestep": gym.spaces.Discrete(self.max_timestep),
"priority": gym.spaces.Discrete(2),
}
)

Expand Down
144 changes: 36 additions & 108 deletions generals/core/observation.py
Original file line number Diff line number Diff line change
@@ -1,111 +1,39 @@
import numpy as np

from generals.core.config import DIRECTIONS


class Observation:
def __init__(
self,
armies: np.ndarray,
generals: np.ndarray,
cities: np.ndarray,
mountains: np.ndarray,
neutral_cells: np.ndarray,
owned_cells: np.ndarray,
opponent_cells: np.ndarray,
fog_cells: np.ndarray,
structures_in_fog: np.ndarray,
owned_land_count: int,
owned_army_count: int,
opponent_land_count: int,
opponent_army_count: int,
timestep: int,
priority: int = 0,
):
self.armies = armies
self.generals = generals
self.cities = cities
self.mountains = mountains
self.neutral_cells = neutral_cells
self.owned_cells = owned_cells
self.opponent_cells = opponent_cells
self.fog_cells = fog_cells
self.structures_in_fog = structures_in_fog
self.owned_land_count = owned_land_count
self.owned_army_count = owned_army_count
self.opponent_land_count = opponent_land_count
self.opponent_army_count = opponent_army_count
self.timestep = timestep
self.priority = priority
# armies, generals, cities, mountains, empty, owner, fogged, structure in fog

def action_mask(self) -> np.ndarray:
"""
Function to compute valid actions from a given ownership mask.
Valid action is an action that originates from agent's cell with atleast 2 units
and does not bump into a mountain or fall out of the grid.
Returns:
np.ndarray: an NxNx4 array, where each channel is a boolean mask
of valid actions (UP, DOWN, LEFT, RIGHT) for each cell in the grid.
I.e. valid_action_mask[i, j, k] is 1 if action k is valid in cell (i, j).
"""
height, width = self.owned_cells.shape
import dataclasses

ownership_channel = self.owned_cells
more_than_1_army = (self.armies > 1) * ownership_channel
owned_cells_indices = np.argwhere(more_than_1_army)
valid_action_mask = np.zeros((height, width, 4), dtype=bool)

if np.sum(ownership_channel) == 0:
return valid_action_mask

for channel_index, direction in enumerate(DIRECTIONS):
destinations = owned_cells_indices + direction.value

# check if destination is in grid bounds
in_first_boundary = np.all(destinations >= 0, axis=1)
in_height_boundary = destinations[:, 0] < height
in_width_boundary = destinations[:, 1] < width
destinations = destinations[in_first_boundary & in_height_boundary & in_width_boundary]

# check if destination is road
passable_cells = 1 - self.mountains
# assert that every value is either 0 or 1 in passable cells
assert np.all(np.isin(passable_cells, [0, 1])), f"{passable_cells}"
passable_cell_indices = passable_cells[destinations[:, 0], destinations[:, 1]] == 1
action_destinations = destinations[passable_cell_indices]

# get valid action mask for a given direction
valid_source_indices = action_destinations - direction.value
valid_action_mask[valid_source_indices[:, 0], valid_source_indices[:, 1], channel_index] = 1.0
import numpy as np

return valid_action_mask

def as_dict(self, with_mask=True):
_obs = {
"armies": self.armies,
"generals": self.generals,
"cities": self.cities,
"mountains": self.mountains,
"neutral_cells": self.neutral_cells,
"owned_cells": self.owned_cells,
"opponent_cells": self.opponent_cells,
"fog_cells": self.fog_cells,
"structures_in_fog": self.structures_in_fog,
"owned_land_count": self.owned_land_count,
"owned_army_count": self.owned_army_count,
"opponent_land_count": self.opponent_land_count,
"opponent_army_count": self.opponent_army_count,
"timestep": self.timestep,
"priority": self.priority,
}
if with_mask:
obs = {
"observation": _obs,
"action_mask": self.action_mask(),
}
else:
obs = _obs
return obs
@dataclasses.dataclass
class Observation(dict):
"""
We override some dictionary methods and subclass dict to allow the
Observation object to be accessible in dictionary-style format,
e.g. observation["armies"]. And to allow for providing a
listing of the keys/attributes.
These steps are necessary because PettingZoo & Gymnasium expect
dictionary-like Observation objects, but we want the benefits of
knowing the dictionaries' members which a dataclass/class provides.
"""

armies: np.ndarray
generals: np.ndarray
cities: np.ndarray
mountains: np.ndarray
neutral_cells: np.ndarray
owned_cells: np.ndarray
opponent_cells: np.ndarray
fog_cells: np.ndarray
structures_in_fog: np.ndarray
owned_land_count: int
owned_army_count: int
opponent_land_count: int
opponent_army_count: int
timestep: int
priority: int = 0

def __getitem__(self, attribute_name: str):
return getattr(self, attribute_name)

def keys(self):
return dataclasses.asdict(self).keys()
6 changes: 3 additions & 3 deletions generals/envs/gymnasium_generals.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,20 +96,20 @@ def reset(
self.observation_space = self.game.observation_space
self.action_space = self.game.action_space

observation = self.game.agent_observation(self.agent_id).as_dict()
observation = self.game.agent_observation(self.agent_id)
info: dict[str, Any] = {}
return observation, info

def step(self, action: Action) -> tuple[Observation, SupportsFloat, bool, bool, dict[str, Any]]:
# Get action of NPC
npc_observation = self.game.agent_observation(self.npc.id).as_dict()
npc_observation = self.game.agent_observation(self.npc.id)
npc_action = self.npc.act(npc_observation)
actions = {self.agent_id: action, self.npc.id: npc_action}

observations, infos = self.game.step(actions)

# From observations of all agents, pick only those relevant for the main agent
obs = observations[self.agent_id].as_dict()
obs = observations[self.agent_id]
info = infos[self.agent_id]
reward = self.reward_fn(obs, action, self.game.is_done(), info)
terminated = self.game.is_done()
Expand Down
Loading

0 comments on commit 09e7a10

Please sign in to comment.