diff --git a/generals/core/game.py b/generals/core/game.py index dadc849..db9808d 100644 --- a/generals/core/game.py +++ b/generals/core/game.py @@ -88,7 +88,7 @@ def action_mask(self, agent: str) -> np.ndarray: owned_cells_indices = self.channel_to_indices(more_than_1_army) valid_action_mask = np.zeros((self.grid_dims[0], self.grid_dims[1], 4), dtype=bool) - if self.is_done(): + if self.is_done() and not self.agent_won(agent): # if you lost, return all zeros return valid_action_mask for channel_index, direction in enumerate(DIRECTIONS): @@ -107,6 +107,7 @@ def action_mask(self, agent: str) -> np.ndarray: # get valid action mask for a given direction valid_source_indices = action_destinations - direction.value valid_action_mask[valid_source_indices[:, 0], valid_source_indices[:, 1], channel_index] = 1.0 + # assert False return valid_action_mask def channel_to_indices(self, channel: np.ndarray) -> np.ndarray: @@ -197,7 +198,7 @@ def step(self, actions: dict[str, Action]) -> tuple[dict[str, Observation], dict self.time += 1 if self.is_done(): - # Give all cells of loser to winner + # give all cells of loser to winner winner = self.agents[0] if self.agent_won(self.agents[0]) else self.agents[1] loser = self.agents[1] if winner == self.agents[0] else self.agents[0] self.channels.ownership[winner] += self.channels.ownership[loser] diff --git a/generals/core/grid.py b/generals/core/grid.py index 842477b..3083ca1 100644 --- a/generals/core/grid.py +++ b/generals/core/grid.py @@ -1,4 +1,5 @@ import numpy as np +from numpy.random import Generator from .config import MOUNTAIN, PASSABLE @@ -24,7 +25,7 @@ def grid(self, grid: str | np.ndarray): pass case _: raise ValueError("Grid must be encoded as a string or a numpy array.") - if not Grid.verify_grid(grid): + if not Grid.verify_grid_connectivity(grid): raise ValueError("Invalid grid layout - generals cannot reach each other.") # check that exactly one 'A' and one 'B' are present in the grid first_general = np.argwhere(np.isin(grid, ["A"])) @@ -34,6 +35,11 @@ def grid(self, grid: str | np.ndarray): self._grid = grid + @staticmethod + def generals_distance(grid: "Grid") -> int: + generals = np.argwhere(np.isin(grid.grid, ["A", "B"])) + return abs(generals[0][0] - generals[1][0]) + abs(generals[0][1] - generals[1][1]) + @staticmethod def numpify_grid(grid: str) -> np.ndarray: return np.array([list(row) for row in grid.strip().split("\n")]) @@ -43,17 +49,21 @@ def stringify_grid(grid: np.ndarray) -> str: return "\n".join(["".join(row) for row in grid]) @staticmethod - def verify_grid(grid: np.ndarray) -> bool: + def verify_grid_connectivity(grid: np.ndarray | str) -> bool: """ Verify grid layout (can generals reach each other?) Returns True if grid is valid, False otherwise """ + if isinstance(grid, str): + grid = Grid.numpify_grid(grid) + + height, width = grid.shape def dfs(grid, visited, square): i, j = square - if i < 0 or i >= grid.shape[0] or j < 0 or j >= grid.shape[1] or visited[i, j]: + if i < 0 or i >= height or j < 0 or j >= width or visited[i, j]: return - if grid[i, j] == MOUNTAIN: + if grid[i, j] == MOUNTAIN or str(grid[i, j]).isdigit(): # mountain or city return visited[i, j] = True for di, dj in [[-1, 0], [1, 0], [0, -1], [0, 1]]: @@ -62,6 +72,7 @@ def dfs(grid, visited, square): generals = np.argwhere(np.isin(grid, ["A", "B"])) start, end = generals[0], generals[1] + visited = np.zeros_like(grid, dtype=bool) dfs(grid, visited, start) return visited[end[0], end[1]] @@ -84,7 +95,15 @@ def __init__( self.mountain_density = mountain_density self.city_density = city_density self.general_positions = general_positions - self.seed = seed + self._rng = np.random.default_rng(seed) + + @property + def rng(self): + return self._rng + + @rng.setter + def rng(self, number_generator: Generator): + self._rng = number_generator def grid_from_string(self, grid: str) -> Grid: return Grid(grid) @@ -105,32 +124,28 @@ def grid_from_generator( city_density = self.city_density if general_positions is None: general_positions = self.general_positions - if seed is None: - if self.seed is None: - seed = np.random.randint(0, 2**20) - else: - seed = self.seed + if seed is not None: + self.rng = np.random.default_rng(seed) # Probabilities of each cell type p_neutral = 1 - mountain_density - city_density probs = [p_neutral, mountain_density] + [city_density / 10] * 10 # Place cells on the map - rng = np.random.default_rng(seed) - map = rng.choice( + map = self.rng.choice( [PASSABLE, MOUNTAIN, "0", "1", "2", "3", "4", "5", "6", "7", "8", "9"], size=grid_dims, p=probs, ) - # Place generals on random squares - generals_positions is a list of two tuples - if general_positions is None: - general_positions = [] - while len(general_positions) < 2: - position = tuple(rng.integers(0, grid_dims)) - if position not in general_positions: - general_positions.append(position) - + # Place generals on random squares, they should be atleast some distance apart + min_distance = max(grid_dims) // 2 + p1 = self.rng.integers(0, grid_dims[0]), self.rng.integers(0, grid_dims[1]) + while True: + p2 = self.rng.integers(0, grid_dims[0]), self.rng.integers(0, grid_dims[1]) + if abs(p1[0] - p2[0]) + abs(p1[1] - p2[1]) >= min_distance: + break + general_positions = [p1, p2] for i, idx in enumerate(general_positions): map[idx[0], idx[1]] = chr(ord("A") + i) @@ -140,11 +155,10 @@ def grid_from_generator( try: return Grid(map_string) except ValueError: - seed += 1 # Increase seed to generate a different map return self.grid_from_generator( grid_dims=grid_dims, mountain_density=mountain_density, city_density=city_density, general_positions=general_positions, - seed=seed, + seed=None, ) diff --git a/generals/envs/gymnasium_generals.py b/generals/envs/gymnasium_generals.py index e036ef9..febaf91 100644 --- a/generals/envs/gymnasium_generals.py +++ b/generals/envs/gymnasium_generals.py @@ -67,8 +67,8 @@ def reset( if "grid" in options: grid = self.grid_factory.grid_from_string(options["grid"]) else: - map_seed = self.np_random.integers(0, 2**20) - grid = self.grid_factory.grid_from_generator(seed=map_seed) + self.grid_factory.rng = self.np_random + grid = self.grid_factory.grid_from_generator() # Create game for current run self.game = Game(grid, self.agent_ids) diff --git a/tests/test_map.py b/tests/test_grid.py similarity index 70% rename from tests/test_map.py rename to tests/test_grid.py index 460dcea..38aa974 100644 --- a/tests/test_map.py +++ b/tests/test_grid.py @@ -1,6 +1,6 @@ import numpy as np -from generals.core.grid import Grid +from generals.core.grid import Grid, GridFactory def test_grid_creation(): @@ -16,7 +16,6 @@ def test_grid_creation(): grid_nd_array = Grid(map_nd_array) assert grid_str == grid_nd_array - def test_verify_grid(): map = """ ..... @@ -25,8 +24,8 @@ def test_verify_grid(): ..22. ...B. """ - map = Grid.numpify_grid(map) - assert Grid.verify_grid(map) + _grid = Grid(map) + assert Grid.verify_grid_connectivity(_grid.grid) map = """ ..... @@ -37,27 +36,7 @@ def test_verify_grid(): """ map = Grid.numpify_grid(map) - assert not Grid.verify_grid(map) - - map = """ -..... -.A##2 -##.2. -..2## -...B. - """ - map = Grid.numpify_grid(map) - assert Grid.verify_grid(map) - - map = """ -..#.. -.A##2 -##.2. -..2## -...B. - """ - map = Grid.numpify_grid(map) - assert not Grid.verify_grid(map) + assert not Grid.verify_grid_connectivity(map) map = """ ..... @@ -67,7 +46,7 @@ def test_verify_grid(): ..... """ map = Grid.numpify_grid(map) - assert Grid.verify_grid(map) + assert Grid.verify_grid_connectivity(map) map = """ ...#. @@ -77,7 +56,7 @@ def test_verify_grid(): ..... """ map = Grid.numpify_grid(map) - assert not Grid.verify_grid(map) + assert not Grid.verify_grid_connectivity(map) map = """ ...#. @@ -87,7 +66,17 @@ def test_verify_grid(): ..... """ map = Grid.numpify_grid(map) - assert Grid.verify_grid(map) + assert not Grid.verify_grid_connectivity(map) + +def test_grid_factory(): + generator = GridFactory() + generator.rng = np.random.default_rng() + for _ in range(10): + grid = generator.grid_from_generator() + assert Grid.verify_grid_connectivity(grid.grid) + height, width = grid.grid.shape + assert Grid.generals_distance(grid) >= max(height, width) // 2 + def test_numpify_map():