Skip to content

Commit

Permalink
Merge pull request #71 from strakam/refactors
Browse files Browse the repository at this point in the history
refactor: Improve handle events
  • Loading branch information
strakam authored Oct 2, 2024
2 parents d7f9924 + 7a9c526 commit f4dd9f3
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 87 deletions.
13 changes: 7 additions & 6 deletions generals/core/replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from generals.core.grid import Grid
from generals.core.game import Game
from generals.gui.event_handler import ReplayCommand
from generals.gui import GUI
from copy import deepcopy

Expand Down Expand Up @@ -33,31 +34,31 @@ def load(cls, path):
def play(self):
agents = [agent for agent in self.agent_data.keys()]
game = Game(self.grid, agents)
gui = GUI(game, self.agent_data, from_replay=True)
gui = GUI(game, self.agent_data, mode="replay")
gui_properties = gui.properties

game_step, last_input_time, last_move_time = 0, 0, 0
while 1:
_t = time.time()
# Check inputs
if _t - last_input_time > 0.008: # check for input every 8ms
control_events = gui.tick()
command = gui.tick()
last_input_time = _t
else:
control_events = {"time_change": 0}
if "restart" in control_events:
command = ReplayCommand()
if command.restart:
game_step = 0
# If we control replay, change game state
game_step = max(
0, min(len(self.game_states) - 1, game_step + control_events["time_change"])
0, min(len(self.game_states) - 1, game_step + command.frame_change)
)
if gui_properties.paused and game_step != game.time:
game.channels = deepcopy(self.game_states[game_step])
game.time = game_step
last_move_time = _t
# If we are not paused, play the game
elif (
_t - last_move_time > gui_properties.game_speed * 0.512
_t - last_move_time > (1/gui_properties.game_speed) * 0.512
and not gui_properties.paused
):
if game.is_done():
Expand Down
15 changes: 10 additions & 5 deletions generals/envs/gymnasium_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ def action_space(self) -> gym.Space:

def render(self, fps: int = 6) -> None:
if self.render_mode == "human":
self.gui.tick(fps=fps)
_ = self.gui.tick(fps=fps)

def reset(self, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[
Observation, dict[str, Any]
]:
def reset(
self, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[Observation, dict[str, Any]]:
if options is None:
options = {}
super().reset(seed=seed)
Expand Down Expand Up @@ -95,7 +95,9 @@ def reset(self, seed: int | None = None, options: dict[str, Any] | None = None)
info = {}
return observation, info

def step(self, action: Action) -> tuple[Observation, SupportsFloat, bool, bool, dict[str, Any]]:
def step(
self, action: Action
) -> tuple[Observation, SupportsFloat, bool, bool, dict[str, Any]]:
# get action of NPC
npc_action = self.npc.play(self.game._agent_observation(self.npc.name))
actions = {self.agent_name: action, self.npc.name: npc_action}
Expand Down Expand Up @@ -133,3 +135,6 @@ def _default_reward(
else:
reward = 0
return reward

def close(self) -> None:
self.gui.close()
8 changes: 4 additions & 4 deletions generals/envs/pettingzoo_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def action_space(self, agent: AgentID) -> spaces.Space:

def render(self, fps=6) -> None:
if self.render_mode == "human":
self.gui.tick(fps=fps)
_ = self.gui.tick(fps=fps)

def reset(
self, seed: int | None = None, options: dict | None = None
Expand All @@ -76,7 +76,7 @@ def reset(
self.game = Game(grid, self.agents)

if self.render_mode == "human":
self.gui = GUI(self.game, self.agent_data)
self.gui = GUI(self.game, self.agent_data, "train")

if "replay_file" in options:
self.replay = Replay(
Expand Down Expand Up @@ -146,5 +146,5 @@ def _default_reward(
reward = 0
return reward

def close(self):
print("Closing environment")
def close(self) -> None:
self.gui.close()
164 changes: 118 additions & 46 deletions generals/gui/event_handler.py
Original file line number Diff line number Diff line change
@@ -1,77 +1,149 @@
import pygame
from pygame.event import Event
from abc import abstractmethod

from .properties import Properties
from generals.core import config as c

# keybindings #
RIGHT = pygame.K_RIGHT
LEFT = pygame.K_LEFT
SPACE = pygame.K_SPACE
Q = pygame.K_q
R = pygame.K_r
H = pygame.K_h
L = pygame.K_l


class Command:
def __init__(self):
self.quit: bool = False


class ReplayCommand(Command):
def __init__(self):
super().__init__()
self.frame_change: int = 0
self.speed_change: float = 1.0
self.restart: bool = False
self.pause_toggle: bool = False


class GameCommand(Command):
def __init__(self):
super().__init__()
raise NotImplementedError


class TrainCommand(Command):
def __init__(self):
super().__init__()


class EventHandler:
def __init__(self, properties: Properties, from_replay=False):
def __init__(self, properties: Properties):
"""
Initialize the event handler.
Args:
properties: the Properties object
from_replay: bool, whether the game is from a replay
"""
self.properties = properties
self.from_replay = from_replay
self.mode = properties.mode

def handle_events(self):
def handle_events(self) -> Command:
"""
Handle pygame GUI events
"""
control_events = {
"time_change": 0,
}
for event in pygame.event.get():
if event.type == pygame.QUIT or (
event.type == pygame.KEYDOWN and event.key == pygame.K_q
):
pygame.quit()
quit()

if event.type == pygame.KEYDOWN and self.from_replay:
self.__handle_key_controls(event, control_events)
if event.type == pygame.QUIT:
self.command.quit = True
if event.type == pygame.KEYDOWN:
self.handle_key_event(event)
elif event.type == pygame.MOUSEBUTTONDOWN:
self.__handle_mouse_click()

return control_events

self.handle_mouse_event()
return self.command

def __handle_key_controls(self, event, control_events):
def is_click_on_agents_row(self, x: int, y: int, i: int) -> bool:
"""
Handle key controls for replay mode.
Control game speed, pause, and replay frames.
Check if the click is on an agent's row.
Args:
x: int, x-coordinate of the click
y: int, y-coordinate of the click
i: int, index of the row
"""
match event.key:
# Speed up game right arrow is pressed
case pygame.K_RIGHT:
self.properties.game_speed = max(1 / 128, self.properties.game_speed / 2)
# Slow down game left arrow is pressed
case pygame.K_LEFT:
self.properties.game_speed = min(32.0, self.properties.game_speed * 2)
# Toggle play/pause
case pygame.K_SPACE:
self.properties.paused = not self.properties.paused
case pygame.K_r:
control_events["restart"] = True
# Control replay frames
case pygame.K_h:
control_events["time_change"] = -1
self.properties.paused = True
case pygame.K_l:
control_events["time_change"] = 1
self.properties.paused = True


def __handle_mouse_click(self):
return (
x >= self.properties.display_grid_width
and (i + 1) * c.GUI_ROW_HEIGHT <= y < (i + 2) * c.GUI_ROW_HEIGHT
)

@abstractmethod
def handle_key_event(self, event: Event) -> Command:
raise NotImplementedError

@abstractmethod
def handle_mouse_event(self):
raise NotImplementedError


class ReplayEventHandler(EventHandler):
def __init__(self, properties: Properties):
super().__init__(properties)
self.command = ReplayCommand()

def handle_key_event(self, event: Event) -> ReplayCommand:
if event.key == Q:
self.command.quit = True
elif event.key == RIGHT:
self.command.speed_change = 2.0
elif event.key == LEFT:
self.command.speed_change = 0.5
elif event.key == SPACE:
self.command.pause_toggle = True
elif event.key == R:
self.command.restart = True
elif event.key == H:
self.command.frame_change = -1
elif event.key == L:
self.command.frame_change = 1
return self.command

def handle_mouse_event(self) -> None:
"""
Handle mouse click event.
Handle mouse clicks in replay mode.
"""
agents = self.properties.game.agents
agent_fov = self.properties.agent_fov

x, y = pygame.mouse.get_pos()
for i, agent in enumerate(agents):
if self.properties.is_click_on_agents_row(x, y, i):
if self.is_click_on_agents_row(x, y, i):
agent_fov[agent] = not agent_fov[agent]
break


class GameEventHandler(EventHandler):
def __init__(self, properties: Properties):
super().__init__(properties)
self.command = GameCommand()

def handle_key_event(self, event: Event) -> GameCommand:
raise NotImplementedError

def handle_mouse_event(self) -> None:
raise NotImplementedError


class TrainEventHandler(EventHandler):
def __init__(self, properties: Properties):
super().__init__(properties)
self.command = TrainCommand()

def handle_key_event(self, event: Event) -> TrainCommand:
if event.key == Q:
self.command.quit = True
return self.command

def handle_mouse_event(self) -> None:
raise NotImplementedError
44 changes: 37 additions & 7 deletions generals/gui/gui.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,50 @@
from typing import Any
import pygame
from typing import Any, Literal

from generals.core.game import Game
from .properties import Properties
from .event_handler import EventHandler
from .event_handler import TrainEventHandler, GameEventHandler, ReplayEventHandler
from .rendering import Renderer


class GUI:
def __init__(
self, game: Game, agent_data: dict[str, dict[str, Any]], from_replay=False
self,
game: Game,
agent_data: dict[str, dict[str, Any]],
mode: Literal["train", "game", "replay"] = "train",
):
self.properties = Properties(game, agent_data)
self.properties = Properties(game, agent_data, mode)
self.__renderer = Renderer(self.properties)
self.__event_handler = EventHandler(self.properties, from_replay)
self.__event_handler = self.__initialize_event_handler()

pygame.init()
pygame.display.set_caption("Generals")

# Handle key repeats
pygame.key.set_repeat(500, 64)

def __initialize_event_handler(self):
if self.properties.mode == "train":
return TrainEventHandler
elif self.properties.mode == "game":
return GameEventHandler
elif self.properties.mode == "replay":
return ReplayEventHandler

def tick(self, fps=None):
control_events = self.__event_handler.handle_events()
handler = self.__event_handler(self.properties)
command = handler.handle_events()
if command.quit:
quit()
if self.properties.mode == "replay":
self.properties.update_speed(command.speed_change)
if command.frame_change != 0 or command.restart:
self.properties.paused = True
if command.pause_toggle:
self.properties.paused = not self.properties.paused
self.__renderer.render(fps)
return control_events
return command

def close(self):
pygame.quit()
Loading

0 comments on commit f4dd9f3

Please sign in to comment.