Skip to content

Commit

Permalink
Pre release finetuning (#87)
Browse files Browse the repository at this point in the history
* refactor: Restructure env creation

* feat: Convert from tuple action to dict action

* refactor: Pre-release cleanup

* refactor: Further sweeping

update workflow, solve mypy problems, update makefile, remove unused
files

* style: Resolve PR comments
  • Loading branch information
strakam authored Oct 13, 2024
1 parent b3f8e0a commit c6d85fb
Show file tree
Hide file tree
Showing 20 changed files with 272 additions and 227 deletions.
31 changes: 0 additions & 31 deletions .github/workflows/tests.yml

This file was deleted.

2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ repos:
hooks:
- id: mypy
name: mypy
entry: mypy --exclude 'examples/' --exclude 'tests/' --disable-error-code "annotation-unchecked" .
entry: mypy --exclude 'examples/' --exclude 'tests/' --exclude 'generals/agents/' --disable-error-code "annotation-unchecked" .
pass_filenames: false
language: system
types: [python]
Expand Down
4 changes: 3 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
.PHONY: test build clean


# Run PettingZoo example
pz:
python3 -m examples.pettingzoo_example
Expand Down Expand Up @@ -28,6 +27,9 @@ test_performance:
test:
pytest

pc:
pre-commit run --all-files

build:
python setup.py sdist bdist_wheel

Expand Down
2 changes: 1 addition & 1 deletion examples/gymnasium_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from generals import AgentFactory

# Initialize agents
# Initialize opponent agent ("random" or "expander")
npc = AgentFactory.make_agent("random")

# Create environment
Expand Down
11 changes: 5 additions & 6 deletions examples/pettingzoo_example.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import gymnasium as gym

from generals.agents import AgentFactory
from generals.envs import PettingZooGenerals

# Initialize agents
random = AgentFactory.make_agent("random")
Expand All @@ -9,14 +8,14 @@
agents = {
random.id: random,
expander.id: expander,
} # Environment calls agents by name
}
agent_ids = list(agents.keys()) # Environment calls agents by name

# Create environment -- render modes: {None, "human"}
env = gym.make("pz-generals-v0", agents=list(agents.keys()), render_mode="human")
# Create environment
env = PettingZooGenerals(agents=agent_ids, render_mode="human")
observations, info = env.reset()

done = False

while not done:
actions = {}
for agent in env.agents:
Expand Down
19 changes: 10 additions & 9 deletions generals/__init__.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,29 @@
from gymnasium.envs.registration import register

from .agents.agent_factory import AgentFactory
from .core.grid import Grid, GridFactory
from .core.replay import Replay
from generals.agents.agent_factory import AgentFactory
from generals.core.grid import Grid, GridFactory
from generals.core.replay import Replay
from generals.envs.pettingzoo_generals import PettingZooGenerals

__all__ = [
"AgentFactory",
"GridFactory",
"PettingZooGenerals",
"Grid",
"Replay",
]


def _register_generals_envs():
def _register_gym_generals_envs():
register(
id="gym-generals-v0",
entry_point="generals.envs.env:gym_generals_v0",
entry_point="generals.envs.gymnasium_generals:GymnasiumGenerals",
)

register(
id="pz-generals-v0",
entry_point="generals.envs.env:pz_generals_v0",
disable_env_checker=True,
id="gym-generals-normalized-v0",
entry_point="generals.envs.initializers:gyms_generals_normalized_v0",
)


_register_generals_envs()
_register_gym_generals_envs()
6 changes: 4 additions & 2 deletions generals/agents/agent.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from abc import ABC, abstractmethod

from generals.core.game import Action, Observation


class Agent(ABC):
"""
Base class for all agents.
"""

def __init__(self, id="NPC", color=(67, 70, 86)):
def __init__(self, id: str = "NPC", color: tuple[int, int, int] = (67, 70, 86)):
self.id = id
self.color = color

@abstractmethod
def act(self, observation):
def act(self, observation: Observation) -> Action:
"""
This method should be implemented by the child class.
It should receive an observation and return an action.
Expand Down
25 changes: 15 additions & 10 deletions generals/agents/expander_agent.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import numpy as np

from generals.core.config import Direction
from generals.core.config import Action, Direction, Observation

from .agent import Agent


class ExpanderAgent(Agent):
def __init__(self, id="Expander", color=(0, 130, 255)):
def __init__(self, id: str = "Expander", color: tuple[int, int, int] = (0, 130, 255)):
super().__init__(id, color)

def act(self, observation):
def act(self, observation: Observation) -> Action:
"""
Heuristically selects a valid (expanding) action.
Prioritizes capturing opponent and then neutral cells.
Expand All @@ -19,7 +19,12 @@ def act(self, observation):

valid_actions = np.argwhere(mask == 1)
if len(valid_actions) == 0: # No valid actions
return 1, np.array([0, 0]), 0, 0
return {
"pass": 1,
"cell": np.array([0, 0]),
"direction": 0,
"split": 0,
}

army = observation["army"]
opponent = observation["opponent_cells"]
Expand Down Expand Up @@ -49,12 +54,12 @@ def act(self, observation):
action_index = np.random.choice(len(valid_actions))
action = valid_actions[action_index]

pass_turn = 0 # 0 for not passing the turn, 1 for passing the turn
split_army = 0 # 0 for not splitting the army, 1 for splitting the army
cell = np.array([action[0], action[1]])
direction = action[2]

action = (pass_turn, cell, direction, split_army)
action = {
"pass": 0,
"cell": action[:2],
"direction": action[2],
"split": 0,
}
return action

def reset(self):
Expand Down
27 changes: 22 additions & 5 deletions generals/agents/random_agent.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
import numpy as np

from generals.core.game import Action, Observation

from .agent import Agent


class RandomAgent(Agent):
def __init__(self, id="Random", color=(242, 61, 106), split_prob=0.25, idle_prob=0.05):
def __init__(
self,
id: str = "Random",
color: tuple[int, int, int] = (242, 61, 106),
split_prob: float = 0.25,
idle_prob: float = 0.05,
):
super().__init__(id, color)

self.idle_probability = idle_prob
self.split_probability = split_prob

def act(self, observation):
def act(self, observation: Observation) -> Action:
"""
Randomly selects a valid action.
"""
Expand All @@ -19,16 +27,25 @@ def act(self, observation):

valid_actions = np.argwhere(mask == 1)
if len(valid_actions) == 0: # No valid actions
return 1, (0, 0), 0, 0

return {
"pass": 1,
"cell": np.array([0, 0]),
"direction": 0,
"split": 0,
}
pass_turn = 0 if np.random.rand() > self.idle_probability else 1
split_army = 0 if np.random.rand() > self.split_probability else 1

action_index = np.random.choice(len(valid_actions))
cell = valid_actions[action_index][:2]
direction = valid_actions[action_index][2]

action = (pass_turn, cell, direction, split_army)
action = {
"pass": pass_turn,
"cell": cell,
"direction": direction,
"split": split_army,
}
return action

def reset(self):
Expand Down
22 changes: 12 additions & 10 deletions generals/core/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,21 @@ class Channels:
ownership_neutral - ownership mask for neutral cells that are passable (1 if cell is neutral, 0 otherwise)
"""

def __init__(self, map: np.ndarray, _agents: list[str]):
self._army: np.ndarray = np.where(np.isin(map, valid_generals), 1, 0).astype(int)
self._general: np.ndarray = np.where(np.isin(map, valid_generals), 1, 0).astype(bool)
self._mountain: np.ndarray = np.where(map == MOUNTAIN, 1, 0).astype(bool)
self._city: np.ndarray = np.where(np.char.isdigit(map), 1, 0).astype(bool)
self._passable: np.ndarray = (map != MOUNTAIN).astype(bool)

self._ownership: dict[str, np.ndarray] = {"neutral": ((map == PASSABLE) | (np.char.isdigit(map))).astype(bool)}
def __init__(self, grid: np.ndarray, _agents: list[str]):
self._army: np.ndarray = np.where(np.isin(grid, valid_generals), 1, 0).astype(int)
self._general: np.ndarray = np.where(np.isin(grid, valid_generals), 1, 0).astype(bool)
self._mountain: np.ndarray = np.where(grid == MOUNTAIN, 1, 0).astype(bool)
self._city: np.ndarray = np.where(np.char.isdigit(grid), 1, 0).astype(bool)
self._passable: np.ndarray = (grid != MOUNTAIN).astype(bool)

self._ownership: dict[str, np.ndarray] = {
"neutral": ((grid == PASSABLE) | (np.char.isdigit(grid))).astype(bool)
}
for i, agent in enumerate(_agents):
self._ownership[agent] = np.where(map == chr(ord("A") + i), 1, 0).astype(bool)
self._ownership[agent] = np.where(grid == chr(ord("A") + i), 1, 0).astype(bool)

# City costs are 40 + digit in the cell
city_costs = np.where(np.char.isdigit(map), map, "0").astype(int)
city_costs = np.where(np.char.isdigit(grid), grid, "0").astype(int)
self.army += 40 * self.city + city_costs

@property
Expand Down
14 changes: 12 additions & 2 deletions generals/core/config.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
from enum import Enum, IntEnum, StrEnum
from importlib.resources import files
from typing import Literal
from typing import Any, Callable, Literal, TypeAlias

import gymnasium as gym
import numpy as np

# Shared type aliases used in annotations across the package.

# Observation handed to an agent's ``act``: a string-keyed dict whose values
# are numpy arrays or nested dicts of gymnasium spaces.
Observation: TypeAlias = dict[str, np.ndarray | dict[str, gym.Space]]
# Action returned by an agent: a string-keyed dict of ints and numpy arrays
# (the bundled agents fill in "pass", "cell", "direction" and "split").
Action: TypeAlias = dict[str, int | np.ndarray]
# Free-form auxiliary information keyed by string.
Info: TypeAlias = dict[str, Any]

# Scalar reward value.
Reward: TypeAlias = float
# Signature of a reward function: (observation, action, bool, info) -> reward.
# NOTE(review): the bool presumably flags episode termination — confirm against callers.
RewardFn: TypeAlias = Callable[[Observation, Action, bool, Info], Reward]
# Identifier of an agent; agents are referred to by their string name.
AgentID: TypeAlias = str

#################
# Game Literals #
#################
PASSABLE: Literal["."] = "."
MOUNTAIN: Literal["#"] = "#"
CITY: Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] = 0 # CITY can be any digit 0-9


class Dimension(IntEnum):
Expand Down
Loading

0 comments on commit c6d85fb

Please sign in to comment.