Skip to content

Commit

Permalink
refactor: action to a subclass of ndarray for a clearer interface and…
Browse files Browse the repository at this point in the history
… consistent repr.
  • Loading branch information
anordin95 committed Dec 27, 2024
1 parent 399c13c commit 8ac225e
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 66 deletions.
71 changes: 37 additions & 34 deletions generals/agents/expander_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from generals.core.action import Action, compute_valid_action_mask
from generals.core.config import Direction
from generals.core.action import Action, compute_valid_move_mask
from generals.core.config import DIRECTIONS
from generals.core.observation import Observation

from .agent import Agent
Expand All @@ -17,40 +17,43 @@ def act(self, observation: Observation) -> Action:
Prioritizes capturing opponent and then neutral cells.
"""

mask = compute_valid_action_mask(observation)
valid_actions = np.argwhere(mask == 1)
if len(valid_actions) == 0: # No valid actions
return np.array([1, 0, 0, 0, 0])

army = observation.armies
opponent = observation.opponent_cells
neutral = observation.neutral_cells

# Find actions that capture opponent or neutral cells
actions_capture_opponent = np.zeros(len(valid_actions))
actions_capture_neutral = np.zeros(len(valid_actions))

directions = [Direction.UP, Direction.DOWN, Direction.LEFT, Direction.RIGHT]
for i, action in enumerate(valid_actions):
di, dj = action[:-1] + directions[action[-1]].value # Destination cell indices
if army[action[0], action[1]] <= army[di, dj] + 1: # Can't capture
continue
elif opponent[di, dj]:
actions_capture_opponent[i] = 1
elif neutral[di, dj]:
actions_capture_neutral[i] = 1

if np.any(actions_capture_opponent): # Capture random opponent cell if possible
action_index = np.random.choice(np.nonzero(actions_capture_opponent)[0])
action = valid_actions[action_index]
elif np.any(actions_capture_neutral): # Capture random neutral cell if possible
action_index = np.random.choice(np.nonzero(actions_capture_neutral)[0])
action = valid_actions[action_index]
mask = compute_valid_move_mask(observation)
valid_moves = np.argwhere(mask == 1)

# Skip the turn if there are no valid moves.
if len(valid_moves) == 0:
return Action(to_pass=True)

army_mask = observation.armies
opponent_mask = observation.opponent_cells
neutral_mask = observation.neutral_cells

# Find moves that capture opponent or neutral cells
capture_opponent_moves = np.zeros(len(valid_moves))
capture_neutral_moves = np.zeros(len(valid_moves))

for move_idx, move in enumerate(valid_moves):
orig_row, orig_col, direction = move
row_offset, col_offset = DIRECTIONS[direction].value
dest_row, dest_col = (orig_row + row_offset, orig_col + col_offset)
enough_armies_to_capture = army_mask[orig_row, orig_col] > army_mask[dest_row, dest_col] + 1

if opponent_mask[dest_row, dest_col] and enough_armies_to_capture:
capture_opponent_moves[move_idx] = 1
elif neutral_mask[dest_row, dest_col] and enough_armies_to_capture:
capture_neutral_moves[move_idx] = 1

if np.any(capture_opponent_moves): # Capture random opponent cell if possible
move_index = np.random.choice(np.nonzero(capture_opponent_moves)[0])
move = valid_moves[move_index]
elif np.any(capture_neutral_moves): # Capture random neutral cell if possible
move_index = np.random.choice(np.nonzero(capture_neutral_moves)[0])
move = valid_moves[move_index]
else: # Otherwise, select a random valid action
action_index = np.random.choice(len(valid_actions))
action = valid_actions[action_index]
move_index = np.random.choice(len(valid_moves))
move = valid_moves[move_index]

action = np.array([0, action[0], action[1], action[2], 0])
action = Action(to_pass=False, row=move[0], col=move[1], direction=move[2], to_split=False)
return action

def reset(self):
Expand Down
24 changes: 13 additions & 11 deletions generals/agents/random_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from generals.core.action import Action, compute_valid_action_mask
from generals.core.action import Action, compute_valid_move_mask
from generals.core.observation import Observation

from .agent import Agent
Expand All @@ -24,19 +24,21 @@ def act(self, observation: Observation) -> Action:
Randomly selects a valid action.
"""

mask = compute_valid_action_mask(observation)
mask = compute_valid_move_mask(observation)

valid_actions = np.argwhere(mask == 1)
if len(valid_actions) == 0: # No valid actions
return [1, 0, 0, 0, 0]
pass_turn = 0 if np.random.rand() > self.idle_probability else 1
split_army = 0 if np.random.rand() > self.split_probability else 1
# Skip the turn if there are no valid moves.
valid_moves = np.argwhere(mask == 1)
if len(valid_moves) == 0:
return Action(to_pass=True)

action_index = np.random.choice(len(valid_actions))
cell = valid_actions[action_index][:2]
direction = valid_actions[action_index][2]
to_pass = 1 if np.random.rand() <= self.idle_probability else 0
to_split = 1 if np.random.rand() <= self.split_probability else 0

action = [pass_turn, cell[0], cell[1], direction, split_army]
move_index = np.random.choice(len(valid_moves))
(row, col) = valid_moves[move_index][:2]
direction = valid_moves[move_index][2]

action = Action(to_pass, row, col, direction, to_split)
return action

def reset(self):
Expand Down
52 changes: 31 additions & 21 deletions generals/core/action.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,43 @@
from typing import TypeAlias

import numpy as np

from generals.core.config import DIRECTIONS
from generals.core.config import DIRECTIONS, Direction

from .observation import Observation

"""
Action is intentionally a numpy array rather than a class for the sake of optimization. Granted,
this hasn't been performance tested, so take that decision with a grain of salt.
The action format is an array with 5 entries: [pass, row, col, direction, split]
Args:
pass: boolean integer (0 or 1) indicating whether the agent should pass/skip this turn and do nothing.
row: The row the agent should move from. In the closed-interval: [0, (grid_height - 1)].
col: The column the agent should move from. In the closed-interval: [0, (grid_width - 1)].
direction: An integer indicating which direction to move. 0 (up), 1 (down), 2 (left), 3 (right).
Note: the integer is effectively an index into the DIRECTIONS enum.
split: boolean integer (0 or 1) indicating whether to split the army when moving.
"""
Action: TypeAlias = np.ndarray

class Action(np.ndarray):
"""
Action objects walk & talk like typical numpy-arrays, but have a more descriptive and narrow interface.
"""

def compute_valid_action_mask(observation: Observation) -> np.ndarray:
def __new__(cls, to_pass: bool, row: int = 0, col: int = 0, direction: int | Direction = 0, to_split: bool = False):
    """
    Build an action stored as a 5-entry int8 array: [to_pass, row, col, direction, to_split].

    Args:
        cls: This argument is automatically provided by Python and is the Action class.
        to_pass: Indicates whether the agent should pass/skip this turn i.e. do nothing. If to_pass is True,
            the other arguments, like row & col, are effectively ignored.
        row: The row the agent should move from. In the closed-interval: [0, (grid_height - 1)].
        col: The column the agent should move from. In the closed-interval: [0, (grid_width - 1)].
        direction: The direction the agent should move from the tile (row, col). Can either pass an enum-member
            of Direction or the integer representation of the direction, which is the relevant index into the
            config.DIRECTIONS array.
        to_split: Indicates whether the army in (row, col) should be split, then moved in direction.

    Returns:
        An instance of cls (an ndarray subclass) holding the five values as int8.
        Every entry must therefore fit in the int8 range.
    """
    # Normalize enum members to their integer index so the array stays purely numeric.
    if isinstance(direction, Direction):
        direction = DIRECTIONS.index(direction)
    action_array = np.array([to_pass, row, col, direction, to_split], dtype=np.int8)
    # np.array() always produces a base ndarray. View-cast it so that the constructor
    # really returns an Action; otherwise isinstance(action, Action) is False and
    # nothing defined on the subclass is ever used.
    return action_array.view(cls)


def compute_valid_move_mask(observation: Observation) -> np.ndarray:
"""
Return a mask of the valid actions for a given observation.
Return a mask of the valid moves for a given observation.
A valid move originates from a cell the agent owns, has at least 2 armies on
and does not attempt to enter a mountain nor exit the grid.
A valid action is an action that originates from an agent's cell, has
at least 2 units and does not attempt to enter a mountain nor exit the grid.
A move is distinct from an action. A move only has 3 dimensions: (row, col, direction).
Whereas an action also includes to_pass & to_split.
Returns:
np.ndarray: an NxNx4 array, where each channel is a boolean mask
Expand Down

0 comments on commit 8ac225e

Please sign in to comment.