Skip to content

Commit

Permalink
feat: Make sequential processing 40% faster with JIT Numba
Browse files Browse the repository at this point in the history
  • Loading branch information
strakam committed Mar 8, 2025
1 parent c35b2f5 commit b274f25
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 9 deletions.
6 changes: 3 additions & 3 deletions generals/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def compute_valid_move_mask(observation: Observation) -> np.ndarray:
if np.sum(ownership_channel) == 0:
return valid_action_mask

# check if destination is road
passable_cells = 1 - observation.mountains

for channel_index, direction in enumerate(DIRECTIONS):
destinations = owned_cells_indices + direction.value

Expand All @@ -86,10 +89,7 @@ def compute_valid_move_mask(observation: Observation) -> np.ndarray:
in_width_boundary = destinations[:, 1] < width
destinations = destinations[in_first_boundary & in_height_boundary & in_width_boundary]

# check if destination is road
passable_cells = 1 - observation.mountains
# assert that every value is either 0 or 1 in passable cells
assert np.all(np.isin(passable_cells, [0, 1])), f"{passable_cells}"
passable_cell_indices = passable_cells[destinations[:, 0], destinations[:, 1]] == 1
action_destinations = destinations[passable_cell_indices]

Expand Down
2 changes: 1 addition & 1 deletion generals/core/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, grid: np.ndarray, _agents: list[str]):

def get_visibility(self, agent_id: str) -> np.ndarray:
channel = self._ownership[agent_id]
return maximum_filter(channel, size=3).astype(bool)
return np.bool(maximum_filter(channel, size=3))

@staticmethod
def channel_to_indices(channel: np.ndarray) -> np.ndarray:
Expand Down
21 changes: 16 additions & 5 deletions generals/core/game.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, TypeAlias

import numba as nb
import numpy as np

from .action import Action
Expand All @@ -13,6 +14,16 @@
Info: TypeAlias = dict[str, Any]


@nb.njit(cache=True)
def calculate_army_size(armies, ownership):
return np.int32(np.sum(armies * ownership))


@nb.njit(cache=True)
def calculate_land_size(ownership):
return np.int32(np.sum(ownership))


class Game:
def __init__(self, grid: Grid, agents: list[str]):
# Agents
Expand Down Expand Up @@ -156,8 +167,8 @@ def get_infos(self) -> dict[str, Info]:
"""
players_stats = {}
for agent in self.agents:
army_size = np.sum(self.channels.armies * self.channels.ownership[agent]).astype(int)
land_size = np.sum(self.channels.ownership[agent]).astype(int)
army_size = calculate_army_size(self.channels.armies, self.channels.ownership[agent])
land_size = calculate_land_size(self.channels.ownership[agent])
players_stats[agent] = {
"army": army_size,
"land": land_size,
Expand All @@ -172,8 +183,8 @@ def agent_observation(self, agent: str) -> Observation:
"""
scores = {}
for _agent in self.agents:
army_size = np.sum(self.channels.armies * self.channels.ownership[_agent]).astype(int)
land_size = np.sum(self.channels.ownership[_agent]).astype(int)
army_size = calculate_army_size(self.channels.armies, self.channels.ownership[_agent])
land_size = calculate_land_size(self.channels.ownership[_agent])
scores[_agent] = {
"army": army_size,
"land": land_size,
Expand All @@ -184,7 +195,7 @@ def agent_observation(self, agent: str) -> Observation:

opponent = self.agents[0] if agent == self.agents[1] else self.agents[1]

armies = self.channels.armies.astype(int) * visible
armies = self.channels.armies * visible
mountains = self.channels.mountains * visible
generals = self.channels.generals * visible
cities = self.channels.cities * visible
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ packages = [{include = "generals"}]
[tool.poetry.dependencies]
python = "^3.11"
numpy = "^2.1.1"
numba = "0.60.0"
pettingzoo = "^1.24.3"
gymnasium = "^1.0.0"
pygame = "^2.6.0"
Expand Down

0 comments on commit b274f25

Please sign in to comment.