Skip to content

Commit

Permalink
refactor: Align observations more with generalsio
Browse files Browse the repository at this point in the history
  • Loading branch information
strakam committed Oct 22, 2024
1 parent 76353ca commit 64a10ef
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 65 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ The `observation` is a `Dict`. Values are either `numpy` matrices with shape `(N
| `owned_cells` | `(N,M)` | Mask indicating cells owned by the agent |
| `opponent_cells` | `(N,M)` | Mask indicating cells owned by the opponent |
| `neutral_cells` | `(N,M)` | Mask indicating cells that are not owned by any agent |
| `structure` | `(N,M)` | Mask indicating whether cells contain cities or mountains, even out of FoV |
| `structures_in_fog` | `(N,M)` | Mask indicating whether cells contain cities or mountains (in fog) |
| `owned_land_count` || Number of cells the agent owns |
| `owned_army_count` || Total number of units owned by the agent |
| `opponent_land_count`|| Number of cells owned by the opponent |
Expand Down
5 changes: 5 additions & 0 deletions generals/core/channels.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from scipy.ndimage import maximum_filter # type: ignore

from .config import MOUNTAIN, PASSABLE

Expand Down Expand Up @@ -34,6 +35,10 @@ def __init__(self, grid: np.ndarray, _agents: list[str]):
city_costs = np.where(np.char.isdigit(grid), grid, "0").astype(int)
self.army += 40 * self.city + city_costs

def get_visibility(self, agent_id: str) -> np.ndarray:
    """
    Return the visibility channel of the given agent.

    The agent's ownership channel is looked up by ``agent_id`` and run through
    a maximum filter of size 3, so on a 2D grid every cell in the 3x3
    neighbourhood of an owned cell ends up marked.

    :param agent_id: key of the agent in the ownership mapping
    :return: array of the same shape as the agent's ownership channel
    """
    return maximum_filter(self._ownership[agent_id], size=3)

@property
def ownership(self) -> dict[str, np.ndarray]:
return self._ownership
Expand Down
28 changes: 11 additions & 17 deletions generals/core/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import gymnasium as gym
import numpy as np
from scipy.ndimage import maximum_filter # type: ignore

from .channels import Channels
from .config import Action, Direction, Info, Observation
Expand Down Expand Up @@ -48,7 +47,7 @@ def __init__(self, grid: Grid, agents: list[str]):
"opponent_cells": grid_multi_binary,
"neutral_cells": grid_multi_binary,
"visible_cells": grid_multi_binary,
"structure": grid_multi_binary,
"structures_in_fog": grid_multi_binary,
"owned_land_count": gym.spaces.Discrete(self.max_army_value),
"owned_army_count": gym.spaces.Discrete(self.max_army_value),
"opponent_land_count": gym.spaces.Discrete(self.max_army_value),
Expand Down Expand Up @@ -116,12 +115,6 @@ def channel_to_indices(self, channel: np.ndarray) -> np.ndarray:
"""
return np.argwhere(channel != 0)

def visibility_channel(self, ownership_channel: np.ndarray) -> np.ndarray:
    """
    Expand an ownership channel into the matching visibility channel.

    A maximum filter of size 3 is applied, so on a 2D grid a cell is marked
    visible when any cell of its 3x3 neighbourhood is owned.

    :param ownership_channel: mask of the cells owned by one player
    :return: mask of the cells that player can see, same shape as the input
    """
    visible = maximum_filter(ownership_channel, size=3)
    return visible

def step(self, actions: dict[str, Action]) -> tuple[dict[str, Observation], dict[str, Any]]:
"""
Perform one step of the game
Expand Down Expand Up @@ -264,16 +257,17 @@ def agent_observation(self, agent: str) -> Observation:
"""
info = self.get_infos()
opponent = self.agents[0] if agent == self.agents[1] else self.agents[1]
visibility = self.visibility_channel(self.channels.ownership[agent])
visible = self.channels.get_visibility(agent)
invisible = 1 - visible
_observation = {
"army": self.channels.army.astype(int) * visibility,
"general": self.channels.general * visibility,
"city": self.channels.city * visibility,
"owned_cells": self.channels.ownership[agent] * visibility,
"opponent_cells": self.channels.ownership[opponent] * visibility,
"neutral_cells": self.channels.ownership_neutral * visibility,
"visible_cells": visibility,
"structure": self.channels.mountain + self.channels.city,
"army": self.channels.army.astype(int) * visible,
"general": self.channels.general * visible,
"city": self.channels.city * visible,
"owned_cells": self.channels.ownership[agent] * visible,
"opponent_cells": self.channels.ownership[opponent] * visible,
"neutral_cells": self.channels.ownership_neutral * visible,
"visible_cells": visible,
"structures_in_fog": invisible * (self.channels.mountain + self.channels.city),
"owned_land_count": info[agent]["land"],
"owned_army_count": info[agent]["army"],
"opponent_land_count": info[opponent]["land"],
Expand Down
6 changes: 3 additions & 3 deletions generals/envs/gymnasium_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, env):
"opponent_cells": grid_multi_binary,
"neutral_cells": grid_multi_binary,
"visible_cells": grid_multi_binary,
"structure": grid_multi_binary,
"structures_in_fog": grid_multi_binary,
"owned_land_count": unit_box,
"owned_army_count": unit_box,
"opponent_land_count": unit_box,
Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(self, env):
"opponent_cells": grid_multi_binary,
"neutral_cells": grid_multi_binary,
"visible_cells": grid_multi_binary,
"structure": grid_multi_binary,
"structures_in_fog": grid_multi_binary,
"owned_land_count": unit_box,
"owned_army_count": unit_box,
"opponent_land_count": unit_box,
Expand Down Expand Up @@ -106,7 +106,7 @@ def observation(self, observation):
_observation["opponent_cells"],
_observation["neutral_cells"],
_observation["visible_cells"],
_observation["structure"],
_observation["structures_in_fog"],
_owned_land_count,
_owned_army_count,
_opponent_land_count,
Expand Down
2 changes: 1 addition & 1 deletion generals/gui/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def render_grid(self):
ownership = self.game.channels.ownership[agent]
owned_map = np.logical_or(owned_map, ownership)
if self.agent_fov[agent]:
visibility = self.game.visibility_channel(ownership)
visibility = self.game.channels.get_visibility(agent)
visible_map = np.logical_or(visible_map, visibility)

# Helper maps for not owned and invisible cells
Expand Down
46 changes: 37 additions & 9 deletions generals/remote/generalsio_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import numpy as np
from socketio import SimpleClient # type: ignore

from generals.agents.agent import Agent
from generals.core.config import Observation


class GeneralsBotError(Exception):
Expand Down Expand Up @@ -41,24 +43,52 @@ def apply_diff(old: list[int], diff: list[int]) -> list[int]:
i += 1
return new

def _self_test_apply_diff() -> None:
    """
    Sanity-check ``apply_diff`` on two hand-computed examples.

    These checks used to run (and print) on every import of this module; they
    are kept as a helper so that importing the client stays free of side
    effects. Run the file directly to execute them.
    """
    cases = [
        # (old, diff, expected patched result)
        ([0, 0], [1, 1, 3], [0, 3]),
        ([0, 0], [0, 1, 2, 1], [2, 0]),
    ]
    for old, diff, desired in cases:
        assert apply_diff(old, diff) == desired
    print("All tests passed")


if __name__ == "__main__":
    _self_test_apply_diff()


class GeneralsIOState:
    """
    Client-side copy of the game state streamed by the generals.io server.

    An instance is created from the payload received when the game starts and
    is then advanced by passing every update payload to ``update``.
    """

    def __init__(self, data: dict):
        """
        :param data: game-start payload; the keys ``replay_id``, ``usernames``
            and ``playerIndex`` are read from it
        """
        self.replay_id = data["replay_id"]
        self.usernames = data["usernames"]
        self.player_index = data["playerIndex"]
        self.opponent_index = 1 - self.player_index  # works only for 1v1

        self.n_players = len(self.usernames)

        # Flat buffers that are patched in place by ``update``; they stay
        # empty until the first update arrives.
        self.map: list[int] = []
        self.cities: list[int] = []
        self.generals = []
        self.scores = []
        self.stars = []

        self.turn = 0

    def update(self, data: dict) -> None:
        """
        Apply one update payload to the stored state.

        ``map`` and ``cities`` arrive as diffs against the previous buffers and
        are patched with ``apply_diff``; the other fields are replaced.

        :param data: update payload; ``turn``, ``map_diff``, ``cities_diff``,
            ``generals`` and ``scores`` are required, ``stars`` is optional
        """
        self.turn = data["turn"]
        self.map = apply_diff(self.map, data["map_diff"])
        self.cities = apply_diff(self.cities, data["cities_diff"])
        self.generals = data["generals"]
        self.scores = data["scores"]
        if "stars" in data:
            self.stars = data["stars"]

    def agent_observation(self) -> Observation:
        """
        Decode the flat map buffer into per-cell observation channels.

        The buffer is read as ``[width, height, <armies>, <terrain>]``, where
        each of the two sections holds ``width * height`` row-major values.

        NOTE(review): only the channels that follow directly from the army and
        terrain sections are produced; the remaining observation keys still
        have to be derived from ``cities``, ``generals`` and ``scores``.

        :return: dictionary with the ``army``, ``owned_cells``,
            ``opponent_cells`` and ``neutral_cells`` channels, each of shape
            ``(height, width)``
        """
        width, height = self.map[0], self.map[1]
        size = height * width
        shape = (height, width)

        army = np.array(self.map[2 : 2 + size]).reshape(shape)
        terrain = np.array(self.map[2 + size : 2 + 2 * size]).reshape(shape)

        # Terrain stores the owning player's index on owned cells; -1 is
        # treated as a visible neutral cell.
        return {
            "army": army,
            "owned_cells": np.where(terrain == self.player_index, 1, 0),
            "opponent_cells": np.where(terrain == self.opponent_index, 1, 0),
            "neutral_cells": np.where(terrain == -1, 1, 0),
        }


class GeneralsIOClient(SimpleClient):
Expand All @@ -72,8 +102,6 @@ def __init__(self, agent: Agent, user_id: str):
self.connect("https://botws.generals.io")
self.user_id = user_id
self._queue_id = ""
self.replay_id = None
self.usernames = []

@property
def queue_id(self):
Expand Down Expand Up @@ -125,15 +153,14 @@ def join_game(self, force_start: bool = True) -> None:
def _initialize_game(self, data: dict) -> None:
    """
    Create the local game state once the server has started the game.

    NOTE(review): ``data`` is annotated as ``dict`` but is indexed with ``0``;
    confirm the actual payload type against the caller.

    :param data: information received at the beginning of the game; its
        element at index ``0`` is passed to ``GeneralsIOState``
    """
    start_payload = data[0]
    self.game_state = GeneralsIOState(start_payload)

def _play_game(self) -> None:
"""
Triggered after server starts the game.
TODO: spawn a new thread in which Agent will calculate its moves
:param agent_index: The index of agent in the game
"""
winner = False
# TODO deserts?
Expand All @@ -143,6 +170,7 @@ def _play_game(self) -> None:
match event:
case "game_update":
self.game_state.update(data)
self.game_state.agent_observation()
case "game_lost" | "game_won":
# server sends game_lost or game_won before game_over
winner = event == "game_won"
Expand Down
35 changes: 1 addition & 34 deletions tests/test_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,39 +67,6 @@ def test_channel_to_indices():
assert (indices == reference).all()


def test_visibility_channel():
    """
    Each ownership mask must be expanded into its expected visibility mask.
    """
    # (ownership mask, expected visibility mask) pairs
    cases = [
        (
            np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]]),
            np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]]),
        ),
        (
            np.array(
                [
                    [0, 0, 0, 0, 0],
                    [1, 1, 0, 0, 0],
                    [0, 1, 0, 0, 0],
                    [0, 0, 0, 0, 1],
                    [0, 0, 0, 0, 1],
                ]
            ),
            np.array(
                [
                    [1, 1, 1, 0, 0],
                    [1, 1, 1, 0, 0],
                    [1, 1, 1, 1, 1],
                    [1, 1, 1, 1, 1],
                    [0, 0, 0, 1, 1],
                ]
            ),
        ),
    ]

    dummy_game = get_game()
    for ownership, reference in cases:
        visibility = dummy_game.visibility_channel(ownership)
        assert (visibility == reference).all()


def test_action_mask():
"""
For given ownership mask and passable mask, we should get NxNx4 mask of valid actions.
Expand All @@ -115,7 +82,7 @@ def test_action_mask():
],
dtype=int,
)
game.channels._set_passable(
game.channels.passable = (
np.array(
[
[1, 1, 1, 1],
Expand Down

0 comments on commit 64a10ef

Please sign in to comment.