diff --git a/generals/core/replay.py b/generals/core/replay.py index 64d17f4..01c4021 100644 --- a/generals/core/replay.py +++ b/generals/core/replay.py @@ -3,6 +3,7 @@ from generals.core.grid import Grid from generals.core.game import Game +from generals.gui.event_handler import ReplayCommand from generals.gui import GUI from copy import deepcopy @@ -33,7 +34,7 @@ def load(cls, path): def play(self): agents = [agent for agent in self.agent_data.keys()] game = Game(self.grid, agents) - gui = GUI(game, self.agent_data, from_replay=True) + gui = GUI(game, self.agent_data, mode="replay") gui_properties = gui.properties game_step, last_input_time, last_move_time = 0, 0, 0 @@ -41,15 +42,15 @@ def play(self): _t = time.time() # Check inputs if _t - last_input_time > 0.008: # check for input every 8ms - control_events = gui.tick() + command = gui.tick() last_input_time = _t else: - control_events = {"time_change": 0} - if "restart" in control_events: + command = ReplayCommand() + if command.restart: game_step = 0 # If we control replay, change game state game_step = max( - 0, min(len(self.game_states) - 1, game_step + control_events["time_change"]) + 0, min(len(self.game_states) - 1, game_step + command.frame_change) ) if gui_properties.paused and game_step != game.time: game.channels = deepcopy(self.game_states[game_step]) @@ -57,7 +58,7 @@ def play(self): last_move_time = _t # If we are not paused, play the game elif ( - _t - last_move_time > gui_properties.game_speed * 0.512 + _t - last_move_time > (1/gui_properties.game_speed) * 0.512 and not gui_properties.paused ): if game.is_done(): diff --git a/generals/envs/gymnasium_integration.py b/generals/envs/gymnasium_integration.py index e9b4942..cd1c991 100644 --- a/generals/envs/gymnasium_integration.py +++ b/generals/envs/gymnasium_integration.py @@ -58,11 +58,11 @@ def action_space(self) -> gym.Space: def render(self, fps: int = 6) -> None: if self.render_mode == "human": - self.gui.tick(fps=fps) + _ = self.gui.tick(fps=fps) - def reset(self, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[ - Observation, dict[str, Any] - ]: + def reset( + self, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[Observation, dict[str, Any]]: if options is None: options = {} super().reset(seed=seed) @@ -95,7 +95,9 @@ def reset(self, seed: int | None = None, options: dict[str, Any] | None = None) info = {} return observation, info - def step(self, action: Action) -> tuple[Observation, SupportsFloat, bool, bool, dict[str, Any]]: + def step( + self, action: Action + ) -> tuple[Observation, SupportsFloat, bool, bool, dict[str, Any]]: # get action of NPC npc_action = self.npc.play(self.game._agent_observation(self.npc.name)) actions = {self.agent_name: action, self.npc.name: npc_action} @@ -133,3 +135,6 @@ def _default_reward( else: reward = 0 return reward + + def close(self) -> None: + self.gui.close() diff --git a/generals/envs/pettingzoo_integration.py b/generals/envs/pettingzoo_integration.py index 9afa322..9e5f4a4 100644 --- a/generals/envs/pettingzoo_integration.py +++ b/generals/envs/pettingzoo_integration.py @@ -59,7 +59,7 @@ def action_space(self, agent: AgentID) -> spaces.Space: def render(self, fps=6) -> None: if self.render_mode == "human": - self.gui.tick(fps=fps) + _ = self.gui.tick(fps=fps) def reset( self, seed: int | None = None, options: dict | None = None @@ -76,7 +76,7 @@ def reset( self.game = Game(grid, self.agents) if self.render_mode == "human": - self.gui = GUI(self.game, self.agent_data) + self.gui = GUI(self.game, self.agent_data, "train") if "replay_file" in options: self.replay = Replay( @@ -146,5 +146,5 @@ def _default_reward( reward = 0 return reward - def close(self): - print("Closing environment") + def close(self) -> None: + self.gui.close() diff --git a/generals/gui/event_handler.py b/generals/gui/event_handler.py index fc56a62..50f1012 100644 --- a/generals/gui/event_handler.py +++ b/generals/gui/event_handler.py @@ -1,77 +1,149 @@ import pygame +from pygame.event import Event +from abc import abstractmethod from .properties import Properties +from generals.core import config as c + +# keybindings # +RIGHT = pygame.K_RIGHT +LEFT = pygame.K_LEFT +SPACE = pygame.K_SPACE +Q = pygame.K_q +R = pygame.K_r +H = pygame.K_h +L = pygame.K_l + + +class Command: + def __init__(self): + self.quit: bool = False + + +class ReplayCommand(Command): + def __init__(self): + super().__init__() + self.frame_change: int = 0 + self.speed_change: float = 1.0 + self.restart: bool = False + self.pause_toggle: bool = False + + +class GameCommand(Command): + def __init__(self): + super().__init__() + raise NotImplementedError + + +class TrainCommand(Command): + def __init__(self): + super().__init__() class EventHandler: - def __init__(self, properties: Properties, from_replay=False): + def __init__(self, properties: Properties): """ Initialize the event handler. Args: properties: the Properties object - from_replay: bool, whether the game is from a replay """ self.properties = properties - self.from_replay = from_replay + self.mode = properties.mode - def handle_events(self): + def handle_events(self) -> Command: """ Handle pygame GUI events """ - control_events = { - "time_change": 0, - } for event in pygame.event.get(): - if event.type == pygame.QUIT or ( - event.type == pygame.KEYDOWN and event.key == pygame.K_q - ): - pygame.quit() - quit() - - if event.type == pygame.KEYDOWN and self.from_replay: - self.__handle_key_controls(event, control_events) + if event.type == pygame.QUIT: + self.command.quit = True + if event.type == pygame.KEYDOWN: + self.handle_key_event(event) elif event.type == pygame.MOUSEBUTTONDOWN: - self.__handle_mouse_click() - - return control_events - + self.handle_mouse_event() + return self.command - def __handle_key_controls(self, event, control_events): + def is_click_on_agents_row(self, x: int, y: int, i: int) -> bool: """ - Handle key controls for replay mode. - Control game speed, pause, and replay frames. + Check if the click is on an agent's row. + + Args: + x: int, x-coordinate of the click + y: int, y-coordinate of the click + i: int, index of the row """ - match event.key: - # Speed up game right arrow is pressed - case pygame.K_RIGHT: - self.properties.game_speed = max(1 / 128, self.properties.game_speed / 2) - # Slow down game left arrow is pressed - case pygame.K_LEFT: - self.properties.game_speed = min(32.0, self.properties.game_speed * 2) - # Toggle play/pause - case pygame.K_SPACE: - self.properties.paused = not self.properties.paused - case pygame.K_r: - control_events["restart"] = True - # Control replay frames - case pygame.K_h: - control_events["time_change"] = -1 - self.properties.paused = True - case pygame.K_l: - control_events["time_change"] = 1 - self.properties.paused = True - - - def __handle_mouse_click(self): + return ( + x >= self.properties.display_grid_width + and (i + 1) * c.GUI_ROW_HEIGHT <= y < (i + 2) * c.GUI_ROW_HEIGHT + ) + + @abstractmethod + def handle_key_event(self, event: Event) -> Command: + raise NotImplementedError + + @abstractmethod + def handle_mouse_event(self): + raise NotImplementedError + + +class ReplayEventHandler(EventHandler): + def __init__(self, properties: Properties): + super().__init__(properties) + self.command = ReplayCommand() + + def handle_key_event(self, event: Event) -> ReplayCommand: + if event.key == Q: + self.command.quit = True + elif event.key == RIGHT: + self.command.speed_change = 2.0 + elif event.key == LEFT: + self.command.speed_change = 0.5 + elif event.key == SPACE: + self.command.pause_toggle = True + elif event.key == R: + self.command.restart = True + elif event.key == H: + self.command.frame_change = -1 + elif event.key == L: + self.command.frame_change = 1 + return self.command + + def handle_mouse_event(self) -> None: """ - Handle mouse click event. + Handle mouse clicks in replay mode. """ agents = self.properties.game.agents agent_fov = self.properties.agent_fov x, y = pygame.mouse.get_pos() for i, agent in enumerate(agents): - if self.properties.is_click_on_agents_row(x, y, i): + if self.is_click_on_agents_row(x, y, i): agent_fov[agent] = not agent_fov[agent] break + + +class GameEventHandler(EventHandler): + def __init__(self, properties: Properties): + super().__init__(properties) + self.command = GameCommand() + + def handle_key_event(self, event: Event) -> GameCommand: + raise NotImplementedError + + def handle_mouse_event(self) -> None: + raise NotImplementedError + + +class TrainEventHandler(EventHandler): + def __init__(self, properties: Properties): + super().__init__(properties) + self.command = TrainCommand() + + def handle_key_event(self, event: Event) -> TrainCommand: + if event.key == Q: + self.command.quit = True + return self.command + + def handle_mouse_event(self) -> None: + raise NotImplementedError diff --git a/generals/gui/gui.py b/generals/gui/gui.py index 43f24da..531fda8 100644 --- a/generals/gui/gui.py +++ b/generals/gui/gui.py @@ -1,20 +1,50 @@ -from typing import Any +import pygame +from typing import Any, Literal from generals.core.game import Game from .properties import Properties -from .event_handler import EventHandler +from .event_handler import TrainEventHandler, GameEventHandler, ReplayEventHandler from .rendering import Renderer class GUI: def __init__( - self, game: Game, agent_data: dict[str, dict[str, Any]], from_replay=False + self, + game: Game, + agent_data: dict[str, dict[str, Any]], + mode: Literal["train", "game", "replay"] = "train", ): - self.properties = Properties(game, agent_data) + self.properties = Properties(game, agent_data, mode) self.__renderer = Renderer(self.properties) - self.__event_handler = EventHandler(self.properties, from_replay) + self.__event_handler = self.__initialize_event_handler() + + pygame.init() + pygame.display.set_caption("Generals") + + # Handle key repeats + pygame.key.set_repeat(500, 64) + + def __initialize_event_handler(self): + if self.properties.mode == "train": + return TrainEventHandler + elif self.properties.mode == "game": + return GameEventHandler + elif self.properties.mode == "replay": + return ReplayEventHandler def tick(self, fps=None): - control_events = self.__event_handler.handle_events() + handler = self.__event_handler(self.properties) + command = handler.handle_events() + if command.quit: + quit() + if self.properties.mode == "replay": + self.properties.update_speed(command.speed_change) + if command.frame_change != 0 or command.restart: + self.properties.paused = True + if command.pause_toggle: + self.properties.paused = not self.properties.paused self.__renderer.render(fps) - return control_events + return command + + def close(self): + pygame.quit() diff --git a/generals/gui/properties.py b/generals/gui/properties.py index 6bc43d2..992da09 100644 --- a/generals/gui/properties.py +++ b/generals/gui/properties.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any +from typing import Any, Literal from pygame.time import Clock @@ -11,7 +11,7 @@ class Properties: __game: Game __agent_data: dict[str, dict[str, Any]] - __paused: bool = False + __mode: Literal["train", "game", "replay"] __game_speed: int = 1 __clock: Clock = Clock() @@ -22,21 +22,11 @@ def __post_init__(self): self.__display_grid_height: int = c.SQUARE_SIZE * self.grid_height self.__right_panel_width: int = 4 * c.GUI_CELL_WIDTH - self.__agent_fov: dict[str, bool] = {name: True for name in self.agent_data.keys()} + self.__paused: bool = False - def is_click_on_agents_row(self, x: int, y: int, i: int) -> bool: - """ - Check if the click is on an agent's row. - - Args: - x: int, x-coordinate of the click - y: int, y-coordinate of the click - i: int, index of the row - """ - return ( - x >= self.display_grid_width - and (i + 1) * c.GUI_ROW_HEIGHT <= y < (i + 2) * c.GUI_ROW_HEIGHT - ) + self.__agent_fov: dict[str, bool] = { + name: True for name in self.agent_data.keys() + } @property def game(self): @@ -46,6 +36,10 @@ def game(self): def agent_data(self): return self.__agent_data + @property + def mode(self): + return self.__mode + @property def paused(self): return self.__paused @@ -59,8 +53,9 @@ def game_speed(self): return self.__game_speed @game_speed.setter - def game_speed(self, value: int): - self.__game_speed = value + def game_speed(self, value: float): + new_speed = min(32.0, max(0.25, value)) # clip speed + self.__game_speed = new_speed @property def clock(self): @@ -89,3 +84,8 @@ def display_grid_height(self): @property def right_panel_width(self): return self.__right_panel_width + + def update_speed(self, multiplier: float) -> None: + """multiplier: usually 2.0 or 0.5""" + new_speed = self.game_speed * multiplier + self.game_speed = new_speed diff --git a/generals/gui/rendering.py b/generals/gui/rendering.py index 334d92a..4cc41a6 100644 --- a/generals/gui/rendering.py +++ b/generals/gui/rendering.py @@ -16,6 +16,7 @@ def __init__(self, properties: Properties): self.properties = properties + self.mode = self.properties.mode self.game = self.properties.game self.agent_data = self.properties.agent_data @@ -140,7 +141,9 @@ def render_stats(self): info_text = { "time": f"Time: {str(self.game.time // 2) + ('.' if self.game.time % 2 == 1 else '')}", - "speed": "Paused" if self.properties.paused else f"Speed: {str(1 / self.properties.game_speed)}x", + "speed": "Paused" + if self.mode == "replay" and self.properties.paused + else f"Speed: {str(self.properties.game_speed)}x", } # Write additional info