Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Actions are now stored in a .pkl file #84

Merged
merged 7 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,7 @@ venv/
.vscode

#databases
*.db
*.db

# stored actions
*.pkl
Empty file.
20 changes: 19 additions & 1 deletion comprl/server/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import abc
from typing import Callable, Optional
from datetime import datetime
import numpy as np
import pickle

from comprl.shared.types import GameID, PlayerID
from comprl.server.util import IDGenerator
Expand Down Expand Up @@ -111,6 +113,12 @@ def __init__(self, players: list[IPlayer]) -> None:
self.scores: dict[PlayerID, float] = {p.id: 0.0 for p in players}
self.start_time = datetime.now()
self.disconnected_player_id: PlayerID | None = None
# dict storing all actions and possibly more information to be saved later.
# "actions" is a list of all actions in the game
self.game_info: dict[str, list[np.ndarray]] = {}
self.all_actions: list[np.ndarray] = []
# When writing a game class you can fill the dict game_info with more
# information

def add_finish_callback(self, callback: Callable[["IGame"], None]) -> None:
"""
Expand All @@ -134,12 +142,21 @@ def start(self):

def _end(self, reason="unknown"):
"""
Notifies all players that the game has ended.
Notifies all players that the game has ended and writes all actions to a file.

Args:
reason (str): The reason why the game has ended. Defaults to "unknown".
"""

# store actions:
# TODO: maybe add multithreading here to ease the load on the main server thread
# as storing the actions can take a while
self.game_info["actions"] = np.array(self.all_actions)

with open("comprl/server/game_actions/" + str(self.id) + ".pkl", "wb") as f:
pickle.dump(self.game_info, f)

# notify end
for callback in self.finish_callbacks:
callback(self)

Expand All @@ -166,6 +183,7 @@ def _res(value, id=p.id):
# update the game, and if the game is over, end it
if self.disconnected_player_id is not None:
return
self.all_actions.append([actions[p] for p in actions])
if not self._update(actions):
self._run()
else:
Expand Down
14 changes: 14 additions & 0 deletions comprl/server/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
"""

import logging as log
import pickle
from typing import Type
import numpy as np

from comprl.server.interfaces import IGame, IPlayer
from comprl.shared.types import GameID, PlayerID
Expand Down Expand Up @@ -88,6 +90,18 @@ def get(self, game_id: GameID) -> IGame | None:
"""
return self.games.get(game_id, None)

def get_stored_actions(self, game_id: GameID) -> dict[str, list[np.ndarray]]:
    """Load the stored information of a game from its .pkl file.

    The file is looked up as ``comprl/server/game_actions/<game_id>.pkl``.

    Args:
        game_id (GameID): id of the game whose stored information is wanted

    Returns:
        dict[str, list[np.ndarray]]: the unpickled dict containing the actions
            and possibly more information
    """
    # NOTE(review): pickle.load can execute arbitrary code while loading;
    # presumably only files written by the server itself are read here --
    # confirm before pointing this at files from any other origin.
    file_name = str(game_id) + ".pkl"
    file_path = "comprl/server/game_actions/" + file_name
    with open(file_path, "rb") as stored_file:
        stored_info = pickle.load(stored_file)
    return stored_info


class PlayerManager:
"""
Expand Down
46 changes: 41 additions & 5 deletions examples/hockey/hockey_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import laserhockey.hockey_env as h_env
import numpy as np
from typing import List


class HockeyGame(IGame):
Expand All @@ -20,8 +21,9 @@ def __init__(self, players: list[IPlayer]) -> None:
self.player_1_id = players[0].id
self.player_2_id = players[1].id

# number of rounds played per game and current score
self.remaining_rounds: int = 4
# number of rounds played per game
self.num_rounds: int = 4
self.remaining_rounds = self.num_rounds
# Bool indicating whether players play in the default orientation or sides are swapped.
# Alternates between rounds.
self.sides_swapped = False
Expand All @@ -32,6 +34,21 @@ def __init__(self, players: list[IPlayer]) -> None:
self.terminated = False
self.truncated = False

# array storing all actions/observations to be saved later.
# contains tuples consisting of the actions of the players
# (in the order of the player dictionary)
self.actions_this_round: List[np.ndarray] = []
self.observations_this_round: List[np.ndarray] = []

# self.game_info also contains:
# - num_rounds: number of rounds played
# - actions_round_0: actions of the first round
# - actions_round_1: actions of the second round
# - ...
# - actions_round_(num_rounds-1): actions of the last round
# - observations_round_0: observations of the first round
# - ...

super().__init__(players)

def start(self):
Expand All @@ -41,25 +58,30 @@ def start(self):
"""

self.obs_player_one, self.info = self.env.reset()
self.observations_this_round.append(self.obs_player_one)
return super().start()

def end(self, reason="unknown"):
def _end(self, reason="unknown"):
"""notifies all players that the game has ended

Args:
reason (str, optional): reason why the game has ended.
Defaults to "unknown"
"""
self.env.close()
return super().end(reason)
# add useful information to the game_info
self.game_info["num_rounds"] = [
np.array([self.num_rounds])
] # to respect type of dict
return super()._end(reason)

def _update(self, actions_dict: dict[PlayerID, list[float]]) -> bool:
"""perform one gym step, using the actions

Returns:
bool: True if the game is over, False otherwise.
"""
self.env.render(mode="human") # (un)comment to render or not
# self.env.render(mode="human") # (un)comment to render or not

self.action = np.hstack(
[
Expand All @@ -75,6 +97,10 @@ def _update(self, actions_dict: dict[PlayerID, list[float]]) -> bool:
self.info,
) = self.env.step(self.action)

# store the actions and observations
self.actions_this_round.append(self.action)
self.observations_this_round.append(self.obs_player_one)

# check if current round has ended
if self.terminated or self.truncated:
# update score
Expand All @@ -84,6 +110,16 @@ def _update(self, actions_dict: dict[PlayerID, list[float]]) -> bool:
if self.winner == -1:
self.scores[self.player_2_id] += 1

# store the actions and observations of the round
self.game_info[
"actions_round_" + str(self.num_rounds - self.remaining_rounds)
] = self.actions_this_round
self.actions_this_round = []
self.game_info[
"observations_round_" + str(self.num_rounds - self.remaining_rounds)
] = self.observations_this_round
self.observations_this_round = []

# reset env, swap player side, swap player ids and decrease remaining rounds
self.obs_player_one, self.info = self.env.reset()
self.sides_swapped = not self.sides_swapped
Expand Down
Loading