Skip to content

Commit

Permalink
refactor: Codebase combing and docs update
Browse files Browse the repository at this point in the history
  • Loading branch information
strakam committed Oct 24, 2024
1 parent 7e830c4 commit 5ca4adb
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 25 deletions.
18 changes: 9 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,19 +159,19 @@ An observation for one agent is a dictionary `{"observation": observation, "acti
The `observation` is a `Dict`. Values are either `numpy` matrices with shape `(N,M)`, or simple `int` constants:
| Key | Shape | Description |
| -------------------- | --------- | ---------------------------------------------------------------------------- |
| `army` | `(N,M)` | Number of units in a cell regardless of the owner |
| `general` | `(N,M)` | Mask indicating cells containing a general |
| `city` | `(N,M)` | Mask indicating cells containing a city |
| `visible_cells` | `(N,M)` | Mask indicating cells that are visible to the agent |
| `owned_cells` | `(N,M)` | Mask indicating cells owned by the agent |
| `opponent_cells` | `(N,M)` | Mask indicating cells owned by the opponent |
| `neutral_cells` | `(N,M)` | Mask indicating cells that are not owned by any agent |
| `structures_in_fog` | `(N,M)` | Mask indicating whether cells contain cities or mountains (in fog) |
| `armies` | `(N,M)` | Number of units in a visible cell regardless of the owner |
| `generals` | `(N,M)` | Mask indicating visible cells containing a general |
| `cities` | `(N,M)` | Mask indicating visible cells containing a city |
| `mountains` | `(N,M)` | Mask indicating visible cells containing mountains |
| `neutral_cells` | `(N,M)` | Mask indicating visible cells that are not owned by any agent |
| `owned_cells` | `(N,M)` | Mask indicating visible cells owned by the agent |
| `opponent_cells` | `(N,M)` | Mask indicating visible cells owned by the opponent |
| `fog_cells` | `(N,M)` | Mask indicating fog cells that are not mountains or cities |
| `structures_in_fog` | `(N,M)` | Mask showing cells containing either cities or mountains in fog |
| `owned_land_count` || Number of cells the agent owns |
| `owned_army_count` || Total number of units owned by the agent |
| `opponent_land_count`|| Number of cells owned by the opponent |
| `opponent_army_count`|| Total number of units owned by the opponent |
| `is_winner` || Indicates whether the agent won |
| `timestep` || Current timestep of the game |

The `action_mask` is a 3D array with shape `(N, M, 4)`, where each element corresponds to whether a move is valid from cell
Expand Down
13 changes: 7 additions & 6 deletions generals/core/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ def __init__(self, grid: Grid, agents: list[str]):
"generals": grid_multi_binary,
"cities": grid_multi_binary,
"mountains": grid_multi_binary,
"neutral_cells": grid_multi_binary,
"owned_cells": grid_multi_binary,
"opponent_cells": grid_multi_binary,
"neutral_cells": grid_multi_binary,
"visible_cells": grid_multi_binary,
"fog_cells": grid_multi_binary,
"structures_in_fog": grid_multi_binary,
"owned_land_count": gym.spaces.Discrete(self.max_army_value),
"owned_army_count": gym.spaces.Discrete(self.max_army_value),
Expand Down Expand Up @@ -211,10 +211,11 @@ def agent_observation(self, agent: str) -> Observation:
mountains = self.channels.mountains * visible
generals = self.channels.generals * visible
cities = self.channels.cities * visible
neutral_cells = self.channels.ownership_neutral * visible
owned_cells = self.channels.ownership[agent] * visible
opponent_cells = self.channels.ownership[opponent] * visible
neutral_cells = self.channels.ownership_neutral * visible
structures_in_fog = invisible * (self.channels.mountains + self.channels.cities)
fog_cells = invisible - structures_in_fog
owned_land_count = scores[agent]["land"]
owned_army_count = scores[agent]["army"]
opponent_land_count = scores[opponent]["land"]
Expand All @@ -226,17 +227,17 @@ def agent_observation(self, agent: str) -> Observation:
generals=generals,
cities=cities,
mountains=mountains,
neutral_cells=neutral_cells,
owned_cells=owned_cells,
opponent_cells=opponent_cells,
neutral_cells=neutral_cells,
visible_cells=visible,
fog_cells=fog_cells,
structures_in_fog=structures_in_fog,
owned_land_count=owned_land_count,
owned_army_count=owned_army_count,
opponent_land_count=opponent_land_count,
opponent_army_count=opponent_army_count,
timestep=timestep,
).as_dict()
)

def agent_won(self, agent: str) -> bool:
"""
Expand Down
13 changes: 7 additions & 6 deletions generals/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ def __init__(
generals: np.ndarray,
cities: np.ndarray,
mountains: np.ndarray,
neutral_cells: np.ndarray,
owned_cells: np.ndarray,
opponent_cells: np.ndarray,
neutral_cells: np.ndarray,
visible_cells: np.ndarray,
fog_cells: np.ndarray,
structures_in_fog: np.ndarray,
owned_land_count: int,
owned_army_count: int,
Expand All @@ -25,16 +25,17 @@ def __init__(
self.generals = generals
self.cities = cities
self.mountains = mountains
self.neutral_cells = neutral_cells
self.owned_cells = owned_cells
self.opponent_cells = opponent_cells
self.neutral_cells = neutral_cells
self.visible_cells = visible_cells
self.fog_cells = fog_cells
self.structures_in_fog = structures_in_fog
self.owned_land_count = owned_land_count
self.owned_army_count = owned_army_count
self.opponent_land_count = opponent_land_count
self.opponent_army_count = opponent_army_count
self.timestep = timestep
# armies, generals, cities, mountains, empty, owner, fogged, structure in fog

def action_mask(self) -> np.ndarray:
"""
Expand Down Expand Up @@ -86,10 +87,10 @@ def as_dict(self, with_mask=True):
"generals": self.generals,
"cities": self.cities,
"mountains": self.mountains,
"neutral_cells": self.neutral_cells,
"owned_cells": self.owned_cells,
"opponent_cells": self.opponent_cells,
"neutral_cells": self.neutral_cells,
"visible_cells": self.visible_cells,
"fog_cells": self.fog_cells,
"structures_in_fog": self.structures_in_fog,
"owned_land_count": self.owned_land_count,
"owned_army_count": self.owned_army_count,
Expand Down
7 changes: 4 additions & 3 deletions generals/envs/gymnasium_generals.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,20 +91,21 @@ def reset(
self.observation_space = self.game.observation_space
self.action_space = self.game.action_space

observation = self.game.agent_observation(self.agent_id)
observation = self.game.agent_observation(self.agent_id).as_dict()
info: dict[str, Any] = {}
return observation, info

def step(self, action: Action) -> tuple[Observation, SupportsFloat, bool, bool, dict[str, Any]]:
# Get action of NPC
npc_observation = self.game.agent_observation(self.npc.id)
npc_observation = self.game.agent_observation(self.npc.id).as_dict()
npc_action = self.npc.act(npc_observation)
actions = {self.agent_id: action, self.npc.id: npc_action}

observations, infos = self.game.step(actions)
infos = {agent_id: {} for agent_id in self.agent_ids}

# From observations of all agents, pick only those relevant for the main agent
obs = observations[self.agent_id]
obs = observations[self.agent_id].as_dict()
info = infos[self.agent_id]
reward = self.reward_fn(obs, action, self.game.is_done(), info)
terminated = self.game.is_done()
Expand Down
3 changes: 2 additions & 1 deletion generals/envs/pettingzoo_generals.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def reset(
elif hasattr(self, "replay"):
del self.replay

observations = self.game.get_all_observations()
observations = {agent: self.game.agent_observation(agent).as_dict() for agent in self.agents}
infos: dict[str, Any] = {agent: {} for agent in self.agents}
return observations, infos

Expand All @@ -96,6 +96,7 @@ def step(
dict[AgentID, Info],
]:
observations, infos = self.game.step(actions)
observations = {agent: observation.as_dict() for agent, observation in observations.items()}
# You probably want to set your truncation based on self.game.time
truncation = False if self.truncation is None else self.game.time >= self.truncation
truncated = {agent: truncation for agent in self.agents}
Expand Down

0 comments on commit 5ca4adb

Please sign in to comment.