Skip to content

Commit

Permalink
refactor: Codebase combing and docs update
Browse files Browse the repository at this point in the history
  • Loading branch information
strakam committed Oct 24, 2024
1 parent 7e830c4 commit aa33036
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 70 deletions.
18 changes: 9 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,19 +159,19 @@ An observation for one agent is a dictionary `{"observation": observation, "acti
The `observation` is a `Dict`. Values are either `numpy` matrices with shape `(N,M)`, or simple `int` constants:
| Key | Shape | Description |
| -------------------- | --------- | ---------------------------------------------------------------------------- |
| `army` | `(N,M)` | Number of units in a cell regardless of the owner |
| `general` | `(N,M)` | Mask indicating cells containing a general |
| `city` | `(N,M)` | Mask indicating cells containing a city |
| `visible_cells` | `(N,M)` | Mask indicating cells that are visible to the agent |
| `owned_cells` | `(N,M)` | Mask indicating cells owned by the agent |
| `opponent_cells` | `(N,M)` | Mask indicating cells owned by the opponent |
| `neutral_cells` | `(N,M)` | Mask indicating cells that are not owned by any agent |
| `structures_in_fog` | `(N,M)` | Mask indicating whether cells contain cities or mountains (in fog) |
| `armies` | `(N,M)` | Number of units in a visible cell regardless of the owner |
| `generals` | `(N,M)` | Mask indicating visible cells containing a general |
| `cities` | `(N,M)` | Mask indicating visible cells containing a city |
| `mountains` | `(N,M)` | Mask indicating visible cells containing mountains |
| `neutral_cells` | `(N,M)` | Mask indicating visible cells that are not owned by any agent |
| `owned_cells` | `(N,M)` | Mask indicating visible cells owned by the agent |
| `opponent_cells` | `(N,M)` | Mask indicating visible cells owned by the opponent |
| `fog_cells` | `(N,M)` | Mask indicating fog cells that are not mountains or cities |
| `structures_in_fog` | `(N,M)` | Mask showing cells containing either cities or mountains in fog |
| `owned_land_count` || Number of cells the agent owns |
| `owned_army_count` || Total number of units owned by the agent |
| `opponent_land_count`|| Number of cells owned by the opponent |
| `opponent_army_count`|| Total number of units owned by the opponent |
| `is_winner` || Indicates whether the agent won |
| `timestep` || Current timestep of the game |

The `action_mask` is a 3D array with shape `(N, M, 4)`, where each element corresponds to whether a move is valid from cell
Expand Down
2 changes: 1 addition & 1 deletion generals/agents/expander_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from generals.core.config import Action, Direction
from generals.core.game import Action, Direction
from generals.core.observation import Observation

from .agent import Agent
Expand Down
13 changes: 1 addition & 12 deletions generals/core/config.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,6 @@
from collections.abc import Callable
from enum import Enum, IntEnum, StrEnum
from importlib.resources import files
from typing import Any, Literal, TypeAlias

import numpy as np

# Type aliases
Action: TypeAlias = dict[str, int | np.ndarray]
Info: TypeAlias = dict[str, Any]

Reward: TypeAlias = float
RewardFn: TypeAlias = Callable[["Observation", Action, bool, Info], Reward]
AgentID: TypeAlias = str
from typing import Literal

# Game Literals
PASSABLE: Literal["."] = "."
Expand Down
21 changes: 13 additions & 8 deletions generals/core/game.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from typing import Any
from typing import Any, TypeAlias

import gymnasium as gym
import numpy as np

from .channels import Channels
from .config import DIRECTIONS, Action, Info
from .config import DIRECTIONS
from .grid import Grid
from .observation import Observation

# Type aliases
Action: TypeAlias = dict[str, int | np.ndarray]
Info: TypeAlias = dict[str, Any]


class Game:
def __init__(self, grid: Grid, agents: list[str]):
Expand Down Expand Up @@ -42,10 +46,10 @@ def __init__(self, grid: Grid, agents: list[str]):
"generals": grid_multi_binary,
"cities": grid_multi_binary,
"mountains": grid_multi_binary,
"neutral_cells": grid_multi_binary,
"owned_cells": grid_multi_binary,
"opponent_cells": grid_multi_binary,
"neutral_cells": grid_multi_binary,
"visible_cells": grid_multi_binary,
"fog_cells": grid_multi_binary,
"structures_in_fog": grid_multi_binary,
"owned_land_count": gym.spaces.Discrete(self.max_army_value),
"owned_army_count": gym.spaces.Discrete(self.max_army_value),
Expand Down Expand Up @@ -211,10 +215,11 @@ def agent_observation(self, agent: str) -> Observation:
mountains = self.channels.mountains * visible
generals = self.channels.generals * visible
cities = self.channels.cities * visible
neutral_cells = self.channels.ownership_neutral * visible
owned_cells = self.channels.ownership[agent] * visible
opponent_cells = self.channels.ownership[opponent] * visible
neutral_cells = self.channels.ownership_neutral * visible
structures_in_fog = invisible * (self.channels.mountains + self.channels.cities)
fog_cells = invisible - structures_in_fog
owned_land_count = scores[agent]["land"]
owned_army_count = scores[agent]["army"]
opponent_land_count = scores[opponent]["land"]
Expand All @@ -226,17 +231,17 @@ def agent_observation(self, agent: str) -> Observation:
generals=generals,
cities=cities,
mountains=mountains,
neutral_cells=neutral_cells,
owned_cells=owned_cells,
opponent_cells=opponent_cells,
neutral_cells=neutral_cells,
visible_cells=visible,
fog_cells=fog_cells,
structures_in_fog=structures_in_fog,
owned_land_count=owned_land_count,
owned_army_count=owned_army_count,
opponent_land_count=opponent_land_count,
opponent_army_count=opponent_army_count,
timestep=timestep,
).as_dict()
)

def agent_won(self, agent: str) -> bool:
"""
Expand Down
13 changes: 7 additions & 6 deletions generals/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ def __init__(
generals: np.ndarray,
cities: np.ndarray,
mountains: np.ndarray,
neutral_cells: np.ndarray,
owned_cells: np.ndarray,
opponent_cells: np.ndarray,
neutral_cells: np.ndarray,
visible_cells: np.ndarray,
fog_cells: np.ndarray,
structures_in_fog: np.ndarray,
owned_land_count: int,
owned_army_count: int,
Expand All @@ -25,16 +25,17 @@ def __init__(
self.generals = generals
self.cities = cities
self.mountains = mountains
self.neutral_cells = neutral_cells
self.owned_cells = owned_cells
self.opponent_cells = opponent_cells
self.neutral_cells = neutral_cells
self.visible_cells = visible_cells
self.fog_cells = fog_cells
self.structures_in_fog = structures_in_fog
self.owned_land_count = owned_land_count
self.owned_army_count = owned_army_count
self.opponent_land_count = opponent_land_count
self.opponent_army_count = opponent_army_count
self.timestep = timestep
# armies, generals, cities, mountains, empty, owner, fogged, structure in fog

def action_mask(self) -> np.ndarray:
"""
Expand Down Expand Up @@ -86,10 +87,10 @@ def as_dict(self, with_mask=True):
"generals": self.generals,
"cities": self.cities,
"mountains": self.mountains,
"neutral_cells": self.neutral_cells,
"owned_cells": self.owned_cells,
"opponent_cells": self.opponent_cells,
"neutral_cells": self.neutral_cells,
"visible_cells": self.visible_cells,
"fog_cells": self.fog_cells,
"structures_in_fog": self.structures_in_fog,
"owned_land_count": self.owned_land_count,
"owned_army_count": self.owned_army_count,
Expand Down
23 changes: 10 additions & 13 deletions generals/envs/gymnasium_generals.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
from collections.abc import Callable
from copy import deepcopy
from typing import Any, SupportsFloat
from typing import Any, SupportsFloat, TypeAlias

import gymnasium as gym

from generals.agents import Agent, AgentFactory
from generals.core.config import Reward, RewardFn
from generals.core.game import Action, Game, Info
from generals.core.grid import GridFactory
from generals.core.observation import Observation
from generals.core.replay import Replay
from generals.gui import GUI
from generals.gui.properties import GuiMode

Reward: TypeAlias = float
RewardFn: TypeAlias = Callable[[Observation, Action, bool, Info], Reward]


class GymnasiumGenerals(gym.Env):
metadata = {
Expand Down Expand Up @@ -91,20 +94,21 @@ def reset(
self.observation_space = self.game.observation_space
self.action_space = self.game.action_space

observation = self.game.agent_observation(self.agent_id)
observation = self.game.agent_observation(self.agent_id).as_dict()
info: dict[str, Any] = {}
return observation, info

def step(self, action: Action) -> tuple[Observation, SupportsFloat, bool, bool, dict[str, Any]]:
# Get action of NPC
npc_observation = self.game.agent_observation(self.npc.id)
npc_observation = self.game.agent_observation(self.npc.id).as_dict()
npc_action = self.npc.act(npc_observation)
actions = {self.agent_id: action, self.npc.id: npc_action}

observations, infos = self.game.step(actions)
infos = {agent_id: {} for agent_id in self.agent_ids}

# From observations of all agents, pick only those relevant for the main agent
obs = observations[self.agent_id]
obs = observations[self.agent_id].as_dict()
info = infos[self.agent_id]
reward = self.reward_fn(obs, action, self.game.is_done(), info)
terminated = self.game.is_done()
Expand All @@ -127,14 +131,7 @@ def _default_reward(
done: bool,
info: Info,
) -> Reward:
"""
Give 0 if game still running, otherwise 1 for winner and -1 for loser.
"""
if done:
reward = 1 if observation["observation"]["is_winner"] else -1
else:
reward = 0
return reward
return 0

def close(self) -> None:
if self.render_mode == "human":
Expand Down
20 changes: 9 additions & 11 deletions generals/envs/pettingzoo_generals.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
import functools
from collections.abc import Callable
from copy import deepcopy
from typing import Any
from typing import Any, TypeAlias

import pettingzoo # type: ignore
from gymnasium import spaces

from generals.agents.agent import Agent
from generals.core.config import AgentID, Reward, RewardFn
from generals.core.game import Action, Game, Info, Observation
from generals.core.grid import GridFactory
from generals.core.replay import Replay
from generals.gui import GUI
from generals.gui.properties import GuiMode

AgentID: TypeAlias = str
Reward: TypeAlias = float
RewardFn: TypeAlias = Callable[[Observation, Action, bool, Info], Reward]


class PettingZooGenerals(pettingzoo.ParallelEnv):
metadata: dict[str, Any] = {
Expand Down Expand Up @@ -82,7 +86,7 @@ def reset(
elif hasattr(self, "replay"):
del self.replay

observations = self.game.get_all_observations()
observations = {agent: self.game.agent_observation(agent).as_dict() for agent in self.agents}
infos: dict[str, Any] = {agent: {} for agent in self.agents}
return observations, infos

Expand All @@ -96,6 +100,7 @@ def step(
dict[AgentID, Info],
]:
observations, infos = self.game.step(actions)
observations = {agent: observation.as_dict() for agent, observation in observations.items()}
# You probably want to set your truncation based on self.game.time
truncation = False if self.truncation is None else self.game.time >= self.truncation
truncated = {agent: truncation for agent in self.agents}
Expand Down Expand Up @@ -128,14 +133,7 @@ def _default_reward(
done: bool,
info: Info,
) -> Reward:
"""
Give 0 if game still running, otherwise 1 for winner and -1 for loser.
"""
if done:
reward = 1 if observation["observation"]["is_winner"] else -1
else:
reward = 0
return reward
return 0

def close(self) -> None:
if self.render_mode == "human":
Expand Down
19 changes: 9 additions & 10 deletions generals/remote/generalsio_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import numpy as np
from scipy.ndimage import maximum_filter # type: ignore
from socketio import SimpleClient # type: ignore

from generals.agents.agent import Agent
from generals.core.game import Direction
from generals.core.config import Direction
from generals.core.observation import Observation

DIRECTIONS = [Direction.UP, Direction.DOWN, Direction.LEFT, Direction.RIGHT]
Expand Down Expand Up @@ -36,7 +35,7 @@ def __str__(self) -> str:

def apply_diff(old: list[int], diff: list[int]) -> list[int]:
i = 0
new = []
new: list[int] = []
while i < len(diff):
if diff[i] > 0: # matching
new.extend(old[len(new) : len(new) + diff[i]])
Expand Down Expand Up @@ -68,8 +67,8 @@ def __init__(self, data: dict):

self.n_players = len(self.usernames)

self.map = []
self.cities = []
self.map: list[int] = []
self.cities: list[int] = []

def update(self, data: dict) -> None:
self.turn = data["turn"]
Expand Down Expand Up @@ -100,7 +99,7 @@ def get_observation(self) -> "Observation":
opponent_cells = np.where(terrain == self.opponent_index, 1, 0)
neutral_cells = np.where(terrain == -1, 1, 0)
mountain_cells = np.where(terrain == -2, 1, 0)
visible_cells = maximum_filter(np.where(terrain == self.player_index, 1, 0), size=3)
fog_cells = np.where(terrain == -3, 1, 0)
structures_in_fog = np.where(terrain == -4, 1, 0)
owned_land_count = self.scores[self.player_index]["tiles"]
owned_army_count = self.scores[self.player_index]["total"]
Expand All @@ -113,10 +112,10 @@ def get_observation(self) -> "Observation":
generals=generals,
cities=cities,
mountains=mountain_cells,
neutral_cells=neutral_cells,
owned_cells=owned_cells,
opponent_cells=opponent_cells,
neutral_cells=neutral_cells,
visible_cells=visible_cells,
fog_cells=fog_cells,
structures_in_fog=structures_in_fog,
owned_land_count=owned_land_count,
owned_army_count=owned_army_count,
Expand Down Expand Up @@ -210,8 +209,8 @@ def _play_game(self) -> None:
# This code here should be made way prettier, its just POC
action = self.agent.act(obs)
if not action["pass"]:
source = action["cell"]
direction = DIRECTIONS[action["direction"]].value
source: np.ndarray = np.array(action["cell"])
direction = np.array(DIRECTIONS[action["direction"]].value)
split = action["split"]
destination = source + direction
# convert to index
Expand Down

0 comments on commit aa33036

Please sign in to comment.