diff --git a/generals/core/game.py b/generals/core/game.py index feda084..20769c0 100644 --- a/generals/core/game.py +++ b/generals/core/game.py @@ -17,6 +17,7 @@ class Game: def __init__(self, grid: Grid, agents: list[str]): # Agents self.agents = agents + self.agent_order = self.agents[:] # Grid _grid = grid.grid @@ -56,6 +57,7 @@ def __init__(self, grid: Grid, agents: list[str]): "opponent_land_count": gym.spaces.Discrete(self.max_army_value), "opponent_army_count": gym.spaces.Discrete(self.max_army_value), "timestep": gym.spaces.Discrete(self.max_timestep), + "priority": gym.spaces.Discrete(2), } ), "action_mask": gym.spaces.MultiBinary(self.grid_dims + (4,)), @@ -99,8 +101,9 @@ def step(self, actions: dict[str, Action]) -> tuple[dict[str, Observation], dict continue moves[agent] = (i, j, direction, army_to_move) - # Evaluate moves (smaller army movements are prioritized) - for agent in sorted(moves, key=lambda x: moves[x][3]): + for agent in self.agent_order: + if agent not in moves: + continue si, sj, direction, army_to_move = moves[agent] # Cap the amount of army to move (previous moves may have lowered available army) @@ -135,6 +138,9 @@ def step(self, actions: dict[str, Action]) -> tuple[dict[str, Observation], dict if square_winner != target_square_owner: self.channels.ownership[target_square_owner][di, dj] = 0 + # Swap agent order (because priority is alternating) + self.agent_order = self.agent_order[::-1] + if not done_before_actions: self.time += 1 @@ -225,6 +231,7 @@ def agent_observation(self, agent: str) -> Observation: opponent_land_count = scores[opponent]["land"] opponent_army_count = scores[opponent]["army"] timestep = self.time + priority = 1 if agent == self.agents[0] else 0 return Observation( armies=armies, @@ -241,6 +248,7 @@ def agent_observation(self, agent: str) -> Observation: opponent_land_count=opponent_land_count, opponent_army_count=opponent_army_count, timestep=timestep, + priority=priority, ) def agent_won(self, agent: str) -> bool: diff --git a/generals/core/observation.py b/generals/core/observation.py index cbd104b..b7aff2d 100644 --- a/generals/core/observation.py +++ b/generals/core/observation.py @@ -20,6 +20,7 @@ def __init__( opponent_land_count: int, opponent_army_count: int, timestep: int, + priority: int = 0, ): self.armies = armies self.generals = generals @@ -35,6 +36,7 @@ def __init__( self.opponent_land_count = opponent_land_count self.opponent_army_count = opponent_army_count self.timestep = timestep + self.priority = priority # armies, generals, cities, mountains, empty, owner, fogged, structure in fog def action_mask(self) -> np.ndarray: @@ -97,6 +99,7 @@ def as_dict(self, with_mask=True): "opponent_land_count": self.opponent_land_count, "opponent_army_count": self.opponent_army_count, "timestep": self.timestep, + "priority": self.priority, } if with_mask: obs = {