chore: Restructure code after client connections
strakam committed Oct 23, 2024
1 parent 1b84fed commit 7e830c4
Showing 8 changed files with 118 additions and 181 deletions.
3 changes: 2 additions & 1 deletion generals/agents/expander_agent.py
@@ -1,6 +1,7 @@
 import numpy as np
 
-from generals.core.config import Action, Direction, Observation
+from generals.core.config import Action, Direction
+from generals.core.observation import Observation
 
 from .agent import Agent
 
3 changes: 2 additions & 1 deletion generals/agents/random_agent.py
@@ -1,6 +1,7 @@
 import numpy as np
 
-from generals.core.game import Action, Observation
+from generals.core.game import Action
+from generals.core.observation import Observation
 
 from .agent import Agent
 
58 changes: 29 additions & 29 deletions generals/core/channels.py
@@ -8,21 +8,21 @@
 
 class Channels:
     """
-    army - army size in each cell
-    general - general mask (1 if general is in cell, 0 otherwise)
-    mountain - mountain mask (1 if cell is mountain, 0 otherwise)
-    city - city mask (1 if cell is city, 0 otherwise)
+    armies - army size in each cell
+    generals - general mask (1 if general is in cell, 0 otherwise)
+    mountains - mountain mask (1 if cell is mountain, 0 otherwise)
+    cities - city mask (1 if cell is city, 0 otherwise)
     passable - passable mask (1 if cell is passable, 0 otherwise)
     ownership_i - ownership mask for player i (1 if player i owns cell, 0 otherwise)
     ownership_neutral - ownership mask for neutral cells that are
         passable (1 if cell is neutral, 0 otherwise)
     """
 
     def __init__(self, grid: np.ndarray, _agents: list[str]):
-        self._army: np.ndarray = np.where(np.isin(grid, valid_generals), 1, 0).astype(int)
-        self._general: np.ndarray = np.where(np.isin(grid, valid_generals), 1, 0).astype(bool)
-        self._mountain: np.ndarray = np.where(grid == MOUNTAIN, 1, 0).astype(bool)
-        self._city: np.ndarray = np.where(np.char.isdigit(grid), 1, 0).astype(bool)
+        self._armies: np.ndarray = np.where(np.isin(grid, valid_generals), 1, 0).astype(int)
+        self._generals: np.ndarray = np.where(np.isin(grid, valid_generals), 1, 0).astype(bool)
+        self._mountains: np.ndarray = np.where(grid == MOUNTAIN, 1, 0).astype(bool)
+        self._cities: np.ndarray = np.where(np.char.isdigit(grid), 1, 0).astype(bool)
         self._passable: np.ndarray = (grid != MOUNTAIN).astype(bool)
 
         self._ownership: dict[str, np.ndarray] = {
@@ -33,7 +33,7 @@ def __init__(self, grid: np.ndarray, _agents: list[str]):
 
         # City costs are 40 + digit in the cell
         city_costs = np.where(np.char.isdigit(grid), grid, "0").astype(int)
-        self.army += 40 * self.city + city_costs
+        self.armies += 40 * self.cities + city_costs
 
     def get_visibility(self, agent_id: str) -> np.ndarray:
         channel = self._ownership[agent_id]
@@ -55,36 +55,36 @@ def ownership(self, value):
         self._ownership = value
 
     @property
-    def army(self) -> np.ndarray:
-        return self._army
+    def armies(self) -> np.ndarray:
+        return self._armies
 
-    @army.setter
-    def army(self, value):
-        self._army = value
+    @armies.setter
+    def armies(self, value):
+        self._armies = value
 
     @property
-    def general(self) -> np.ndarray:
-        return self._general
+    def generals(self) -> np.ndarray:
+        return self._generals
 
-    @general.setter
-    def general(self, value):
-        self._general = value
+    @generals.setter
+    def generals(self, value):
+        self._generals = value
 
     @property
-    def mountain(self) -> np.ndarray:
-        return self._mountain
+    def mountains(self) -> np.ndarray:
+        return self._mountains
 
-    @mountain.setter
-    def mountain(self, value):
-        self._mountain = value
+    @mountains.setter
+    def mountains(self, value):
+        self._mountains = value
 
     @property
-    def city(self) -> np.ndarray:
-        return self._city
+    def cities(self) -> np.ndarray:
+        return self._cities
 
-    @city.setter
-    def city(self, value):
-        self._city = value
+    @cities.setter
+    def cities(self, value):
+        self._cities = value
 
     @property
     def passable(self) -> np.ndarray:
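Downstream code now reads the plural accessors. A minimal sketch of a helper built on the renamed Channels API; the helper name and the agent id handling are hypothetical, only the attributes it touches come from this file:

import numpy as np

from generals.core.channels import Channels


def count_owned_cities(channels: Channels, agent_id: str) -> int:
    # `cities` is the renamed boolean city mask; `ownership[agent_id]` marks the agent's cells.
    owned = channels.ownership[agent_id].astype(bool)
    return int(np.sum(channels.cities & owned))
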
7 changes: 4 additions & 3 deletions generals/core/config.py
@@ -3,16 +3,14 @@
 from importlib.resources import files
 from typing import Any, Literal, TypeAlias
 
-import gymnasium as gym
 import numpy as np
 
 # Type aliases
-Observation: TypeAlias = dict[str, np.ndarray | dict[str, gym.Space]]
 Action: TypeAlias = dict[str, int | np.ndarray]
 Info: TypeAlias = dict[str, Any]
 
 Reward: TypeAlias = float
-RewardFn: TypeAlias = Callable[[Observation, Action, bool, Info], Reward]
+RewardFn: TypeAlias = Callable[["Observation", Action, bool, Info], Reward]
 AgentID: TypeAlias = str
 
 # Game Literals
@@ -34,6 +32,9 @@ class Direction(Enum):
     RIGHT = (0, 1)
 
 
+DIRECTIONS = [Direction.UP, Direction.DOWN, Direction.LEFT, Direction.RIGHT]
+
+
 class Path(StrEnum):
     GENERAL_PATH = str(files("generals.assets.images") / "crownie.png")
     CITY_PATH = str(files("generals.assets.images") / "citie.png")
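config.py no longer defines the Observation alias; RewardFn forward-references it as a string while the concrete class lives in generals.core.observation. A sketch of a function matching the RewardFn signature (Observation, Action, bool, Info) -> Reward; the reward rule, and the assumption that info is the per-agent dict built by Game.get_infos(), are illustrative only:

from generals.core.config import Action, Info, Reward


def win_bonus(observation, action: Action, done: bool, info: Info) -> Reward:
    # Hypothetical reward: +1 on a won terminal step, otherwise 0.
    # Assumes `info` carries the "is_winner" flag that Game.get_infos() computes per agent.
    return 1.0 if done and info.get("is_winner", False) else 0.0
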
157 changes: 69 additions & 88 deletions generals/core/game.py
@@ -4,10 +4,9 @@
 import numpy as np
 
 from .channels import Channels
-from .config import Action, Direction, Info, Observation
+from .config import DIRECTIONS, Action, Info
 from .grid import Grid
-
-DIRECTIONS = [Direction.UP, Direction.DOWN, Direction.LEFT, Direction.RIGHT]
+from .observation import Observation
 
 
 class Game:
@@ -42,6 +41,7 @@ def __init__(self, grid: Grid, agents: list[str]):
                 "armies": gym.spaces.MultiDiscrete(grid_discrete),
                 "generals": grid_multi_binary,
                 "cities": grid_multi_binary,
+                "mountains": grid_multi_binary,
                 "owned_cells": grid_multi_binary,
                 "opponent_cells": grid_multi_binary,
                 "neutral_cells": grid_multi_binary,
@@ -67,47 +67,6 @@ def __init__(self, grid: Grid, agents: list[str]):
            }
        )
 
-    # def action_mask(self, agent: str) -> np.ndarray:
-    #     """
-    #     Function to compute valid actions from a given ownership mask.
-    #
-    #     Valid action is an action that originates from agent's cell with atleast 2 units
-    #     and does not bump into a mountain or fall out of the grid.
-    #     Returns:
-    #         np.ndarray: an NxNx4 array, where each channel is a boolean mask
-    #         of valid actions (UP, DOWN, LEFT, RIGHT) for each cell in the grid.
-    #
-    #     I.e. valid_action_mask[i, j, k] is 1 if action k is valid in cell (i, j).
-    #     """
-    #
-    #     ownership_channel = self.channels.ownership[agent]
-    #     more_than_1_army = (self.channels.army > 1) * ownership_channel
-    #     owned_cells_indices = self.channel_to_indices(more_than_1_army)
-    #     valid_action_mask = np.zeros((self.grid_dims[0], self.grid_dims[1], 4), dtype=bool)
-    #
-    #     if self.is_done() and not self.agent_won(agent):  # if you lost, return all zeros
-    #         return valid_action_mask
-    #
-    #     for channel_index, direction in enumerate(DIRECTIONS):
-    #         destinations = owned_cells_indices + direction.value
-    #
-    #         # check if destination is in grid bounds
-    #         in_first_boundary = np.all(destinations >= 0, axis=1)
-    #         in_height_boundary = destinations[:, 0] < self.grid_dims[0]
-    #         in_width_boundary = destinations[:, 1] < self.grid_dims[1]
-    #         destinations = destinations[in_first_boundary & in_height_boundary & in_width_boundary]
-    #
-    #         # check if destination is road
-    #         passable_cell_indices = self.channels.passable[destinations[:, 0], destinations[:, 1]] == 1
-    #         action_destinations = destinations[passable_cell_indices]
-    #
-    #         # get valid action mask for a given direction
-    #         valid_source_indices = action_destinations - direction.value
-    #         valid_action_mask[valid_source_indices[:, 0], valid_source_indices[:, 1], channel_index] = 1.0
-    #         # assert False
-    #     return valid_action_mask
-
-
     def step(self, actions: dict[str, Action]) -> tuple[dict[str, Observation], dict[str, Any]]:
         """
         Perform one step of the game
@@ -129,9 +88,9 @@ def step(self, actions: dict[str, Action]) -> tuple[dict[str, Observation], dict[str, Any]]:
             if pass_turn == 1:
                 continue
             if split_army == 1:  # Agent wants to split the army
-                army_to_move = self.channels.army[i, j] // 2
+                army_to_move = self.channels.armies[i, j] // 2
             else:  # Leave just one army in the source cell
-                army_to_move = self.channels.army[i, j] - 1
+                army_to_move = self.channels.armies[i, j] - 1
             if army_to_move < 1:  # Skip if army size to move is less than 1
                 continue
             moves[agent] = (i, j, direction, army_to_move)
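Worked numbers for the split rule above (the cell contents are made up; the arithmetic mirrors the two branches):

armies_in_cell = 9                        # hypothetical source cell
moved_when_split = armies_in_cell // 2    # 4 move, 5 stay behind
moved_otherwise = armies_in_cell - 1      # 8 move, 1 stays behind
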
@@ -141,8 +100,8 @@ def step(self, actions: dict[str, Action]) -> tuple[dict[str, Observation], dict[str, Any]]:
             si, sj, direction, army_to_move = moves[agent]
 
             # Cap the amount of army to move (previous moves may have lowered available army)
-            army_to_move = min(army_to_move, self.channels.army[si, sj] - 1)
-            army_to_stay = self.channels.army[si, sj] - army_to_move
+            army_to_move = min(army_to_move, self.channels.armies[si, sj] - 1)
+            army_to_stay = self.channels.armies[si, sj] - army_to_move
 
             # Check if the current agent still owns the source cell and has more than 1 army
             if self.channels.ownership[agent][si, sj] == 0 or army_to_move < 1:
@@ -154,20 +113,20 @@ def step(self, actions: dict[str, Action]) -> tuple[dict[str, Observation], dict[str, Any]]:
             )  # destination indices
 
             # Figure out the target square owner and army size
-            target_square_army = self.channels.army[di, dj]
+            target_square_army = self.channels.armies[di, dj]
             target_square_owner_idx = np.argmax(
                 [self.channels.ownership[agent][di, dj] for agent in ["neutral"] + self.agents]
             )
             target_square_owner = (["neutral"] + self.agents)[target_square_owner_idx]
             if target_square_owner == agent:
-                self.channels.army[di, dj] += army_to_move
-                self.channels.army[si, sj] = army_to_stay
+                self.channels.armies[di, dj] += army_to_move
+                self.channels.armies[si, sj] = army_to_stay
             else:
                 # Calculate resulting army, winner and update channels
                 remaining_army = np.abs(target_square_army - army_to_move)
                 square_winner = agent if target_square_army < army_to_move else target_square_owner
-                self.channels.army[di, dj] = remaining_army
-                self.channels.army[si, sj] = army_to_stay
+                self.channels.armies[di, dj] = remaining_army
+                self.channels.armies[si, sj] = army_to_stay
                 self.channels.ownership[square_winner][di, dj] = 1
                 if square_winner != target_square_owner:
                     self.channels.ownership[target_square_owner][di, dj] = 0
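The combat branch above means the larger stack takes the cell and the absolute difference survives, with ties going to the defender. A worked example with made-up numbers:

target_square_army = 7                                     # defender's stack
army_to_move = 10                                          # attacker's stack
remaining_army = abs(target_square_army - army_to_move)    # 3 armies end up on the target cell
attacker_wins = target_square_army < army_to_move          # True; on a tie the defender keeps the cell
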
@@ -184,6 +143,10 @@ def step(self, actions: dict[str, Action]) -> tuple[dict[str, Observation], dict[str, Any]]:
         else:
             self._global_game_update()
 
+        observations = {agent: self.agent_observation(agent) for agent in self.agents}
+        infos = self.get_infos()
+        return observations, infos
+
     def _global_game_update(self) -> None:
         """
         Update game state globally.
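step() now hands back the per-agent observations and infos directly. A sketch of a driver loop written against that return value; the policy callables (and whatever action format they produce) are assumptions, the Game methods are the ones in this file:

from typing import Callable

from generals.core.config import Action, Info
from generals.core.game import Game


def run_episode(game: Game, policies: dict[str, Callable[[dict], Action]]) -> dict[str, Info]:
    observations = {agent: game.agent_observation(agent) for agent in game.agents}
    infos = game.get_infos()
    while not game.is_done():
        actions = {agent: policies[agent](observations[agent]) for agent in game.agents}
        observations, infos = game.step(actions)
    return infos
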
@@ -194,13 +157,13 @@ def _global_game_update(self) -> None:
         # every `increment_rate` steps, increase army size in each cell
         if self.time % self.increment_rate == 0:
             for owner in owners:
-                self.channels.army += self.channels.ownership[owner]
+                self.channels.armies += self.channels.ownership[owner]
 
         # Increment armies on general and city cells, but only if they are owned by player
         if self.time % 2 == 0 and self.time > 0:
-            update_mask = self.channels.general + self.channels.city
+            update_mask = self.channels.generals + self.channels.cities
             for owner in owners:
-                self.channels.army += update_mask * self.channels.ownership[owner]
+                self.channels.armies += update_mask * self.channels.ownership[owner]
 
     def is_done(self) -> bool:
         """
@@ -217,45 +180,63 @@ def get_infos(self) -> dict[str, Info]:
         """
         players_stats = {}
         for agent in self.agents:
-            army_size = np.sum(self.channels.army * self.channels.ownership[agent]).astype(int)
+            army_size = np.sum(self.channels.armies * self.channels.ownership[agent]).astype(int)
             land_size = np.sum(self.channels.ownership[agent]).astype(int)
             players_stats[agent] = {
                 "army": army_size,
                 "land": land_size,
                 "is_winner": self.agent_won(agent),
             }
         return players_stats
-    #
-    # def agent_observation(self, agent: str) -> Observation:
-    #     """
-    #     Returns an observation for a given agent.
-    #     """
-    #     info = self.get_infos()
-    #     opponent = self.agents[0] if agent == self.agents[1] else self.agents[1]
-    #     visible = self.channels.get_visibility(agent)
-    #     invisible = 1 - visible
-    #     _observation = {
-    #         "army": self.channels.army.astype(int) * visible,
-    #         "general": self.channels.general * visible,
-    #         "city": self.channels.city * visible,
-    #         "owned_cells": self.channels.ownership[agent] * visible,
-    #         "opponent_cells": self.channels.ownership[opponent] * visible,
-    #         "neutral_cells": self.channels.ownership_neutral * visible,
-    #         "visible_cells": visible,
-    #         "structures_in_fog": invisible * (self.channels.mountain + self.channels.city),
-    #         "owned_land_count": info[agent]["land"],
-    #         "owned_army_count": info[agent]["army"],
-    #         "opponent_land_count": info[opponent]["land"],
-    #         "opponent_army_count": info[opponent]["army"],
-    #         "is_winner": int(info[agent]["is_winner"]),
-    #         "timestep": self.time,
-    #     }
-    #     observation: Observation = {
-    #         "observation": _observation,
-    #         "action_mask": self.action_mask(agent),
-    #     }
-    #
-    #     return observation
+
+    def agent_observation(self, agent: str) -> Observation:
+        """
+        Returns an observation for a given agent.
+        """
+        scores = {}
+        for _agent in self.agents:
+            army_size = np.sum(self.channels.armies * self.channels.ownership[_agent]).astype(int)
+            land_size = np.sum(self.channels.ownership[_agent]).astype(int)
+            scores[_agent] = {
+                "army": army_size,
+                "land": land_size,
+            }
+
+        visible = self.channels.get_visibility(agent)
+        invisible = 1 - visible
+
+        opponent = self.agents[0] if agent == self.agents[1] else self.agents[1]
+
+        armies = self.channels.armies.astype(int) * visible
+        mountains = self.channels.mountains * visible
+        generals = self.channels.generals * visible
+        cities = self.channels.cities * visible
+        owned_cells = self.channels.ownership[agent] * visible
+        opponent_cells = self.channels.ownership[opponent] * visible
+        neutral_cells = self.channels.ownership_neutral * visible
+        structures_in_fog = invisible * (self.channels.mountains + self.channels.cities)
+        owned_land_count = scores[agent]["land"]
+        owned_army_count = scores[agent]["army"]
+        opponent_land_count = scores[opponent]["land"]
+        opponent_army_count = scores[opponent]["army"]
+        timestep = self.time
+
+        return Observation(
+            armies=armies,
+            generals=generals,
+            cities=cities,
+            mountains=mountains,
+            owned_cells=owned_cells,
+            opponent_cells=opponent_cells,
+            neutral_cells=neutral_cells,
+            visible_cells=visible,
+            structures_in_fog=structures_in_fog,
+            owned_land_count=owned_land_count,
+            owned_army_count=owned_army_count,
+            opponent_land_count=opponent_land_count,
+            opponent_army_count=opponent_army_count,
+            timestep=timestep,
+        ).as_dict()
 
     def agent_won(self, agent: str) -> bool:
         """
3 changed files not shown.
