diff --git a/README.md b/README.md
index a50afa8..7af309f 100644
--- a/README.md
+++ b/README.md
@@ -159,19 +159,19 @@ An observation for one agent is a dictionary `{"observation": observation, "acti
 The `observation` is a `Dict`. Values are either `numpy` matrices with shape `(N,M)`, or simple `int` constants:
 | Key | Shape | Description |
 | -------------------- | --------- | ---------------------------------------------------------------------------- |
-| `army` | `(N,M)` | Number of units in a cell regardless of the owner |
-| `general` | `(N,M)` | Mask indicating cells containing a general |
-| `city` | `(N,M)` | Mask indicating cells containing a city |
-| `visible_cells` | `(N,M)` | Mask indicating cells that are visible to the agent |
-| `owned_cells` | `(N,M)` | Mask indicating cells owned by the agent |
-| `opponent_cells` | `(N,M)` | Mask indicating cells owned by the opponent |
-| `neutral_cells` | `(N,M)` | Mask indicating cells that are not owned by any agent |
-| `structures_in_fog` | `(N,M)` | Mask indicating whether cells contain cities or mountains (in fog) |
+| `armies` | `(N,M)` | Number of units in a visible cell regardless of the owner |
+| `generals` | `(N,M)` | Mask indicating visible cells containing a general |
+| `cities` | `(N,M)` | Mask indicating visible cells containing a city |
+| `mountains` | `(N,M)` | Mask indicating visible cells containing mountains |
+| `neutral_cells` | `(N,M)` | Mask indicating visible cells that are not owned by any agent |
+| `owned_cells` | `(N,M)` | Mask indicating visible cells owned by the agent |
+| `opponent_cells` | `(N,M)` | Mask indicating visible cells owned by the opponent |
+| `fog_cells` | `(N,M)` | Mask indicating fog cells that are not mountains or cities |
+| `structures_in_fog` | `(N,M)` | Mask showing cells containing either cities or mountains in fog |
 | `owned_land_count` | — | Number of cells the agent owns |
 | `owned_army_count` | — | Total number of units owned by the agent |
 | `opponent_land_count`| — | Number of cells owned by the opponent |
 | `opponent_army_count`| — | Total number of units owned by the opponent |
-| `is_winner` | — | Indicates whether the agent won |
 | `timestep` | — | Current timestep of the game |
 
 The `action_mask` is a 3D array with shape `(N, M, 4)`, where each element corresponds to whether a move is valid from cell
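For downstream agents, the renamed per-cell masks are often easiest to consume as one stacked tensor. The sketch below is illustrative only: the `SPATIAL_KEYS` list and `observation_to_tensor` helper are not part of the library, and it assumes the inner `observation` dict with exactly the keys from the updated table above.

```python
import numpy as np

# Hypothetical helper, not library code: the key names come from the updated
# README table; everything else is an assumption for illustration.
SPATIAL_KEYS = [
    "armies", "generals", "cities", "mountains",
    "neutral_cells", "owned_cells", "opponent_cells",
    "fog_cells", "structures_in_fog",
]

def observation_to_tensor(observation: dict) -> np.ndarray:
    """Stack the per-cell (N, M) channels into a single (C, N, M) float array."""
    return np.stack([observation[key] for key in SPATIAL_KEYS], axis=0).astype(np.float32)

# Dummy 2x3 grid standing in for a real observation.
dummy = {key: np.zeros((2, 3), dtype=np.int64) for key in SPATIAL_KEYS}
print(observation_to_tensor(dummy).shape)  # (9, 2, 3)
```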
diff --git a/generals/core/game.py b/generals/core/game.py
index 9f25c37..107a939 100644
--- a/generals/core/game.py
+++ b/generals/core/game.py
@@ -42,10 +42,10 @@ def __init__(self, grid: Grid, agents: list[str]):
                 "generals": grid_multi_binary,
                 "cities": grid_multi_binary,
                 "mountains": grid_multi_binary,
+                "neutral_cells": grid_multi_binary,
                 "owned_cells": grid_multi_binary,
                 "opponent_cells": grid_multi_binary,
-                "neutral_cells": grid_multi_binary,
-                "visible_cells": grid_multi_binary,
+                "fog_cells": grid_multi_binary,
                 "structures_in_fog": grid_multi_binary,
                 "owned_land_count": gym.spaces.Discrete(self.max_army_value),
                 "owned_army_count": gym.spaces.Discrete(self.max_army_value),
@@ -211,10 +211,11 @@ def agent_observation(self, agent: str) -> Observation:
         mountains = self.channels.mountains * visible
         generals = self.channels.generals * visible
         cities = self.channels.cities * visible
+        neutral_cells = self.channels.ownership_neutral * visible
         owned_cells = self.channels.ownership[agent] * visible
         opponent_cells = self.channels.ownership[opponent] * visible
-        neutral_cells = self.channels.ownership_neutral * visible
         structures_in_fog = invisible * (self.channels.mountains + self.channels.cities)
+        fog_cells = invisible - structures_in_fog
         owned_land_count = scores[agent]["land"]
         owned_army_count = scores[agent]["army"]
         opponent_land_count = scores[opponent]["land"]
@@ -226,17 +227,17 @@ def agent_observation(self, agent: str) -> Observation:
             generals=generals,
             cities=cities,
             mountains=mountains,
+            neutral_cells=neutral_cells,
             owned_cells=owned_cells,
             opponent_cells=opponent_cells,
-            neutral_cells=neutral_cells,
-            visible_cells=visible,
+            fog_cells=fog_cells,
             structures_in_fog=structures_in_fog,
             owned_land_count=owned_land_count,
             owned_army_count=owned_army_count,
             opponent_land_count=opponent_land_count,
             opponent_army_count=opponent_army_count,
             timestep=timestep,
-        ).as_dict()
+        )
 
     def agent_won(self, agent: str) -> bool:
         """
diff --git a/generals/core/observation.py b/generals/core/observation.py
index ecd69ec..cbd104b 100644
--- a/generals/core/observation.py
+++ b/generals/core/observation.py
@@ -10,10 +10,10 @@ def __init__(
         generals: np.ndarray,
         cities: np.ndarray,
         mountains: np.ndarray,
+        neutral_cells: np.ndarray,
         owned_cells: np.ndarray,
         opponent_cells: np.ndarray,
-        neutral_cells: np.ndarray,
-        visible_cells: np.ndarray,
+        fog_cells: np.ndarray,
         structures_in_fog: np.ndarray,
         owned_land_count: int,
         owned_army_count: int,
@@ -25,16 +25,17 @@ def __init__(
         self.generals = generals
         self.cities = cities
         self.mountains = mountains
+        self.neutral_cells = neutral_cells
         self.owned_cells = owned_cells
         self.opponent_cells = opponent_cells
-        self.neutral_cells = neutral_cells
-        self.visible_cells = visible_cells
+        self.fog_cells = fog_cells
         self.structures_in_fog = structures_in_fog
         self.owned_land_count = owned_land_count
         self.owned_army_count = owned_army_count
         self.opponent_land_count = opponent_land_count
         self.opponent_army_count = opponent_army_count
         self.timestep = timestep
+        # armies, generals, cities, mountains, empty, owner, fogged, structure in fog
 
     def action_mask(self) -> np.ndarray:
         """
@@ -86,10 +87,10 @@ def as_dict(self, with_mask=True):
             "generals": self.generals,
             "cities": self.cities,
             "mountains": self.mountains,
+            "neutral_cells": self.neutral_cells,
             "owned_cells": self.owned_cells,
             "opponent_cells": self.opponent_cells,
-            "neutral_cells": self.neutral_cells,
-            "visible_cells": self.visible_cells,
+            "fog_cells": self.fog_cells,
             "structures_in_fog": self.structures_in_fog,
             "owned_land_count": self.owned_land_count,
             "owned_army_count": self.owned_army_count,
diff --git a/generals/envs/gymnasium_generals.py b/generals/envs/gymnasium_generals.py
index 793924f..4043a56 100644
--- a/generals/envs/gymnasium_generals.py
+++ b/generals/envs/gymnasium_generals.py
@@ -91,20 +91,21 @@ def reset(
         self.observation_space = self.game.observation_space
         self.action_space = self.game.action_space
 
-        observation = self.game.agent_observation(self.agent_id)
+        observation = self.game.agent_observation(self.agent_id).as_dict()
         info: dict[str, Any] = {}
         return observation, info
 
     def step(self, action: Action) -> tuple[Observation, SupportsFloat, bool, bool, dict[str, Any]]:
         # Get action of NPC
-        npc_observation = self.game.agent_observation(self.npc.id)
+        npc_observation = self.game.agent_observation(self.npc.id).as_dict()
         npc_action = self.npc.act(npc_observation)
         actions = {self.agent_id: action, self.npc.id: npc_action}
 
         observations, infos = self.game.step(actions)
         infos = {agent_id: {} for agent_id in self.agent_ids}
+
        # From observations of all agents, pick only those relevant for the main agent
-        obs = observations[self.agent_id]
+        obs = observations[self.agent_id].as_dict()
         info = infos[self.agent_id]
         reward = self.reward_fn(obs, action, self.game.is_done(), info)
         terminated = self.game.is_done()
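Note that the removed `visible_cells` mask stays recoverable from the new keys: per the `agent_observation` change above, `fog_cells = invisible - structures_in_fog`, so the two fog channels together cover exactly the invisible area, and their complement is the visible area (assuming the visible/invisible masks are complementary 0/1 arrays, as the code suggests). The helper below is a hypothetical migration aid, not library code, and it operates on the inner `observation` dict of a single agent.

```python
import numpy as np

def visible_cells(observation: dict) -> np.ndarray:
    """Recompute the old `visible_cells` mask from the new observation keys.

    `fog_cells` and `structures_in_fog` partition the invisible area, so the
    complement of their union is what used to be `visible_cells`.
    """
    return 1 - (observation["fog_cells"] + observation["structures_in_fog"])

# 1x3 toy grid: cell 0 visible, cell 1 plain fog, cell 2 a structure hidden in fog.
toy = {
    "fog_cells": np.array([[0, 1, 0]]),
    "structures_in_fog": np.array([[0, 0, 1]]),
}
print(visible_cells(toy))  # [[1 0 0]]
```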
diff --git a/generals/envs/pettingzoo_generals.py b/generals/envs/pettingzoo_generals.py
index bd550d0..76c413b 100644
--- a/generals/envs/pettingzoo_generals.py
+++ b/generals/envs/pettingzoo_generals.py
@@ -82,7 +82,7 @@ def reset(
         elif hasattr(self, "replay"):
             del self.replay
 
-        observations = self.game.get_all_observations()
+        observations = {agent: self.game.agent_observation(agent).as_dict() for agent in self.agents}
         infos: dict[str, Any] = {agent: {} for agent in self.agents}
         return observations, infos
 
@@ -96,6 +96,7 @@ def step(
         self,
         actions: dict[AgentID, Action],
     ) -> tuple[
         dict[AgentID, Observation],
         dict[AgentID, Info],
     ]:
         observations, infos = self.game.step(actions)
+        observations = {agent: observation.as_dict() for agent, observation in observations.items()}
         # You probably want to set your truncation based on self.game.time
         truncation = False if self.truncation is None else self.game.time >= self.truncation
         truncated = {agent: truncation for agent in self.agents}
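For consumers of the PettingZoo env, the practical effect is that `reset` and `step` now hand back plain per-agent dicts (the output of `Observation.as_dict()`) rather than `Observation` objects. A rough sketch of what agent code sees is below; the agent name `red`, the 4x4 grid, and the nested `observation` / `action_mask` layout quoted from the README are stand-ins for illustration, not guaranteed API details.

```python
import numpy as np

# Stand-in for the per-agent dict that PettingZooGenerals.step(...) now returns.
observations = {
    "red": {
        "observation": {"armies": np.zeros((4, 4), dtype=np.int64)},  # plus the other keys from the table
        "action_mask": np.zeros((4, 4, 4), dtype=np.int64),           # (N, M, 4) move-validity mask
    },
}

for agent, obs in observations.items():
    armies = obs["observation"]["armies"]  # plain dict access; no .as_dict() call on the consumer side
    mask = obs["action_mask"]
    print(agent, armies.shape, mask.shape)  # red (4, 4) (4, 4, 4)
```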