Skip to content

Commit

Permalink
fix: Align priority moves with generalsio
Browse files Browse the repository at this point in the history
  • Loading branch information
strakam committed Oct 29, 2024
1 parent 3b400a3 commit 6f539fd
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
12 changes: 10 additions & 2 deletions generals/core/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class Game:
def __init__(self, grid: Grid, agents: list[str]):
# Agents
self.agents = agents
self.agent_order = self.agents[:]

# Grid
_grid = grid.grid
Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(self, grid: Grid, agents: list[str]):
"opponent_land_count": gym.spaces.Discrete(self.max_army_value),
"opponent_army_count": gym.spaces.Discrete(self.max_army_value),
"timestep": gym.spaces.Discrete(self.max_timestep),
"priority": gym.spaces.Discrete(2),
}
),
"action_mask": gym.spaces.MultiBinary(self.grid_dims + (4,)),
Expand Down Expand Up @@ -99,8 +101,9 @@ def step(self, actions: dict[str, Action]) -> tuple[dict[str, Observation], dict
continue
moves[agent] = (i, j, direction, army_to_move)

# Evaluate moves (smaller army movements are prioritized)
for agent in sorted(moves, key=lambda x: moves[x][3]):
for agent in self.agent_order:
if agent not in moves:
continue
si, sj, direction, army_to_move = moves[agent]

# Cap the amount of army to move (previous moves may have lowered available army)
Expand Down Expand Up @@ -135,6 +138,9 @@ def step(self, actions: dict[str, Action]) -> tuple[dict[str, Observation], dict
if square_winner != target_square_owner:
self.channels.ownership[target_square_owner][di, dj] = 0

# Swap agent order (because priority is alternating)
self.agent_order = self.agent_order[::-1]

if not done_before_actions:
self.time += 1

Expand Down Expand Up @@ -225,6 +231,7 @@ def agent_observation(self, agent: str) -> Observation:
opponent_land_count = scores[opponent]["land"]
opponent_army_count = scores[opponent]["army"]
timestep = self.time
priority = 1 if agent == self.agents[0] else 0

return Observation(
armies=armies,
Expand All @@ -241,6 +248,7 @@ def agent_observation(self, agent: str) -> Observation:
opponent_land_count=opponent_land_count,
opponent_army_count=opponent_army_count,
timestep=timestep,
priority=priority,
)

def agent_won(self, agent: str) -> bool:
Expand Down
3 changes: 3 additions & 0 deletions generals/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(
opponent_land_count: int,
opponent_army_count: int,
timestep: int,
priority: int = 0,
):
self.armies = armies
self.generals = generals
Expand All @@ -35,6 +36,7 @@ def __init__(
self.opponent_land_count = opponent_land_count
self.opponent_army_count = opponent_army_count
self.timestep = timestep
self.priority = priority
# armies, generals, cities, mountains, empty, owner, fogged, structure in fog

def action_mask(self) -> np.ndarray:
Expand Down Expand Up @@ -97,6 +99,7 @@ def as_dict(self, with_mask=True):
"opponent_land_count": self.opponent_land_count,
"opponent_army_count": self.opponent_army_count,
"timestep": self.timestep,
"priority": self.priority,
}
if with_mask:
obs = {
Expand Down

0 comments on commit 6f539fd

Please sign in to comment.