Skip to content

Commit

Permalink
Merge pull request #95 from strakam/better-maps
Browse files Browse the repository at this point in the history
Better maps
  • Loading branch information
strakam authored Oct 18, 2024
2 parents 8af020e + d56c63c commit 3cfac8d
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 54 deletions.
5 changes: 3 additions & 2 deletions generals/core/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def action_mask(self, agent: str) -> np.ndarray:
owned_cells_indices = self.channel_to_indices(more_than_1_army)
valid_action_mask = np.zeros((self.grid_dims[0], self.grid_dims[1], 4), dtype=bool)

if self.is_done():
if self.is_done() and not self.agent_won(agent): # if you lost, return all zeros
return valid_action_mask

for channel_index, direction in enumerate(DIRECTIONS):
Expand All @@ -107,6 +107,7 @@ def action_mask(self, agent: str) -> np.ndarray:
# get valid action mask for a given direction
valid_source_indices = action_destinations - direction.value
valid_action_mask[valid_source_indices[:, 0], valid_source_indices[:, 1], channel_index] = 1.0
# assert False
return valid_action_mask

def channel_to_indices(self, channel: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -197,7 +198,7 @@ def step(self, actions: dict[str, Action]) -> tuple[dict[str, Observation], dict
self.time += 1

if self.is_done():
# Give all cells of loser to winner
# give all cells of loser to winner
winner = self.agents[0] if self.agent_won(self.agents[0]) else self.agents[1]
loser = self.agents[1] if winner == self.agents[0] else self.agents[0]
self.channels.ownership[winner] += self.channels.ownership[loser]
Expand Down
58 changes: 36 additions & 22 deletions generals/core/grid.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from numpy.random import Generator

from .config import MOUNTAIN, PASSABLE

Expand All @@ -24,7 +25,7 @@ def grid(self, grid: str | np.ndarray):
pass
case _:
raise ValueError("Grid must be encoded as a string or a numpy array.")
if not Grid.verify_grid(grid):
if not Grid.verify_grid_connectivity(grid):
raise ValueError("Invalid grid layout - generals cannot reach each other.")
# check that exactly one 'A' and one 'B' are present in the grid
first_general = np.argwhere(np.isin(grid, ["A"]))
Expand All @@ -34,6 +35,11 @@ def grid(self, grid: str | np.ndarray):

self._grid = grid

@staticmethod
def generals_distance(grid: "Grid") -> int:
generals = np.argwhere(np.isin(grid.grid, ["A", "B"]))
return abs(generals[0][0] - generals[1][0]) + abs(generals[0][1] - generals[1][1])

@staticmethod
def numpify_grid(grid: str) -> np.ndarray:
return np.array([list(row) for row in grid.strip().split("\n")])
Expand All @@ -43,17 +49,21 @@ def stringify_grid(grid: np.ndarray) -> str:
return "\n".join(["".join(row) for row in grid])

@staticmethod
def verify_grid(grid: np.ndarray) -> bool:
def verify_grid_connectivity(grid: np.ndarray | str) -> bool:
"""
Verify grid layout (can generals reach each other?)
Returns True if grid is valid, False otherwise
"""
if isinstance(grid, str):
grid = Grid.numpify_grid(grid)

height, width = grid.shape

def dfs(grid, visited, square):
i, j = square
if i < 0 or i >= grid.shape[0] or j < 0 or j >= grid.shape[1] or visited[i, j]:
if i < 0 or i >= height or j < 0 or j >= width or visited[i, j]:
return
if grid[i, j] == MOUNTAIN:
if grid[i, j] == MOUNTAIN or str(grid[i, j]).isdigit(): # mountain or city
return
visited[i, j] = True
for di, dj in [[-1, 0], [1, 0], [0, -1], [0, 1]]:
Expand All @@ -62,6 +72,7 @@ def dfs(grid, visited, square):

generals = np.argwhere(np.isin(grid, ["A", "B"]))
start, end = generals[0], generals[1]

visited = np.zeros_like(grid, dtype=bool)
dfs(grid, visited, start)
return visited[end[0], end[1]]
Expand All @@ -84,7 +95,15 @@ def __init__(
self.mountain_density = mountain_density
self.city_density = city_density
self.general_positions = general_positions
self.seed = seed
self._rng = np.random.default_rng(seed)

@property
def rng(self):
return self._rng

@rng.setter
def rng(self, number_generator: Generator):
self._rng = number_generator

def grid_from_string(self, grid: str) -> Grid:
return Grid(grid)
Expand All @@ -105,32 +124,28 @@ def grid_from_generator(
city_density = self.city_density
if general_positions is None:
general_positions = self.general_positions
if seed is None:
if self.seed is None:
seed = np.random.randint(0, 2**20)
else:
seed = self.seed
if seed is not None:
self.rng = np.random.default_rng(seed)

# Probabilities of each cell type
p_neutral = 1 - mountain_density - city_density
probs = [p_neutral, mountain_density] + [city_density / 10] * 10

# Place cells on the map
rng = np.random.default_rng(seed)
map = rng.choice(
map = self.rng.choice(
[PASSABLE, MOUNTAIN, "0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
size=grid_dims,
p=probs,
)

# Place generals on random squares - generals_positions is a list of two tuples
if general_positions is None:
general_positions = []
while len(general_positions) < 2:
position = tuple(rng.integers(0, grid_dims))
if position not in general_positions:
general_positions.append(position)

# Place generals on random squares, they should be atleast some distance apart
min_distance = max(grid_dims) // 2
p1 = self.rng.integers(0, grid_dims[0]), self.rng.integers(0, grid_dims[1])
while True:
p2 = self.rng.integers(0, grid_dims[0]), self.rng.integers(0, grid_dims[1])
if abs(p1[0] - p2[0]) + abs(p1[1] - p2[1]) >= min_distance:
break
general_positions = [p1, p2]
for i, idx in enumerate(general_positions):
map[idx[0], idx[1]] = chr(ord("A") + i)

Expand All @@ -140,11 +155,10 @@ def grid_from_generator(
try:
return Grid(map_string)
except ValueError:
seed += 1 # Increase seed to generate a different map
return self.grid_from_generator(
grid_dims=grid_dims,
mountain_density=mountain_density,
city_density=city_density,
general_positions=general_positions,
seed=seed,
seed=None,
)
4 changes: 2 additions & 2 deletions generals/envs/gymnasium_generals.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def reset(
if "grid" in options:
grid = self.grid_factory.grid_from_string(options["grid"])
else:
map_seed = self.np_random.integers(0, 2**20)
grid = self.grid_factory.grid_from_generator(seed=map_seed)
self.grid_factory.rng = self.np_random
grid = self.grid_factory.grid_from_generator()

# Create game for current run
self.game = Game(grid, self.agent_ids)
Expand Down
45 changes: 17 additions & 28 deletions tests/test_map.py → tests/test_grid.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from generals.core.grid import Grid
from generals.core.grid import Grid, GridFactory


def test_grid_creation():
Expand All @@ -16,7 +16,6 @@ def test_grid_creation():
grid_nd_array = Grid(map_nd_array)
assert grid_str == grid_nd_array


def test_verify_grid():
map = """
.....
Expand All @@ -25,8 +24,8 @@ def test_verify_grid():
..22.
...B.
"""
map = Grid.numpify_grid(map)
assert Grid.verify_grid(map)
_grid = Grid(map)
assert Grid.verify_grid_connectivity(_grid.grid)

map = """
.....
Expand All @@ -37,27 +36,7 @@ def test_verify_grid():
"""

map = Grid.numpify_grid(map)
assert not Grid.verify_grid(map)

map = """
.....
.A##2
##.2.
..2##
...B.
"""
map = Grid.numpify_grid(map)
assert Grid.verify_grid(map)

map = """
..#..
.A##2
##.2.
..2##
...B.
"""
map = Grid.numpify_grid(map)
assert not Grid.verify_grid(map)
assert not Grid.verify_grid_connectivity(map)

map = """
.....
Expand All @@ -67,7 +46,7 @@ def test_verify_grid():
.....
"""
map = Grid.numpify_grid(map)
assert Grid.verify_grid(map)
assert Grid.verify_grid_connectivity(map)

map = """
...#.
Expand All @@ -77,7 +56,7 @@ def test_verify_grid():
.....
"""
map = Grid.numpify_grid(map)
assert not Grid.verify_grid(map)
assert not Grid.verify_grid_connectivity(map)

map = """
...#.
Expand All @@ -87,7 +66,17 @@ def test_verify_grid():
.....
"""
map = Grid.numpify_grid(map)
assert Grid.verify_grid(map)
assert not Grid.verify_grid_connectivity(map)

def test_grid_factory():
generator = GridFactory()
generator.rng = np.random.default_rng()
for _ in range(10):
grid = generator.grid_from_generator()
assert Grid.verify_grid_connectivity(grid.grid)
height, width = grid.grid.shape
assert Grid.generals_distance(grid) >= max(height, width) // 2



def test_numpify_map():
Expand Down

0 comments on commit 3cfac8d

Please sign in to comment.