Skip to content

Commit

Permalink
Merge pull request #82 from martius-lab/carina/disconnect-error-handling
Browse files Browse the repository at this point in the history
Disconnect/Timeout handling and Input Validation
  • Loading branch information
Cari1111 authored Feb 26, 2024
2 parents f42a268 + eb542c4 commit 1ebdbe3
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 53 deletions.
5 changes: 3 additions & 2 deletions comprl/client/networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def connectionLost(self, reason):
reason (object): The reason for the lost connection.
"""
log.debug(f"Disconnected from the server. Reason: {reason}")
reactor.stop()
if reactor.running:
reactor.stop()
return super().connectionLost(reason)

@Auth.responder
Expand Down Expand Up @@ -116,7 +117,7 @@ def on_error(self, msg):
Args:
msg (object): The error description.
"""
self.agent.on_error(msg=msg)
self.agent.on_error(msg=str(msg, encoding="utf-8"))
return {}


Expand Down
6 changes: 6 additions & 0 deletions comprl/server/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ def on_disconnect(self, player: IPlayer):
log.debug(f"Player {player.id} disconnected")
self.matchmaking.remove(player.id)
self.player_manager.remove(player)
self.game_manager.force_game_end(player.id)

def on_timeout(self, player: IPlayer, failure, timeout):
"""gets called when a player has a timeout"""
log.debug(f"Player {player.id} had timeout after {timeout}s")
player.disconnect(reason=f"Timeout after {timeout}s")

def on_update(self):
"""gets called every update cycle"""
Expand Down
79 changes: 64 additions & 15 deletions comprl/server/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from comprl.shared.types import GameID, PlayerID
from comprl.server.util import IDGenerator
from comprl.server.data.interfaces import GameResult
from comprl.server.data.interfaces import GameResult, GameEndState


class IAction:
Expand Down Expand Up @@ -101,8 +101,10 @@ def __init__(self, players: list[IPlayer]) -> None:
"""
self.id: GameID = IDGenerator.generate_game_id()
self.players = {p.id: p for p in players}
self.start_time = datetime.now()
self.finish_callbacks: list[Callable[["IGame"], None]] = []
self.scores: dict[PlayerID, float] = {p.id: 0.0 for p in players}
self.start_time = datetime.now()
self.disconnected_player_id: PlayerID | None = None

def add_finish_callback(self, callback: Callable[["IGame"], None]) -> None:
"""
Expand All @@ -124,7 +126,7 @@ def start(self):

self._run()

def end(self, reason="unknown"):
def _end(self, reason="unknown"):
"""
Notifies all players that the game has ended.
Expand Down Expand Up @@ -158,15 +160,63 @@ def _run(self):
for p in self.players.values():

def _res(value, id=p.id):
# TODO: validate action here ?
if not self._validate_action(value):
self.players[id].disconnect("Invalid action")
actions[id] = value
if len(actions) == len(self.players):
# all players have submitted their actions
# update the game, and if the game is over, end it
self._run() if not self.update(actions) else self.end()
if self.disconnected_player_id is not None:
return
if not self.update(actions):
self._run()
else:
self._end(reason="Player won")

p.get_action(self.get_observation(p.id), _res)

def force_end(self, player_id: PlayerID):
"""forces the end of the game. Should be used when a player disconnects.
Args:
player_id (PlayerID): the player that caused the forced end (disconnected)
"""
self.disconnected_player_id = player_id
self._end(reason="Player disconnected")

def get_result(self) -> GameResult | None:
"""
Returns the result of the game.
Returns:
GameResult: The result of the game.
"""
player_ids = list(self.players.keys())

game_end_state = GameEndState.DRAW
if self._player_won(player_ids[0]) or self._player_won(player_ids[1]):
game_end_state = GameEndState.WIN
if self.disconnected_player_id is not None:
game_end_state = GameEndState.DISCONNECTED

user1_id = self.players[player_ids[0]].user_id
user2_id = self.players[player_ids[1]].user_id

if user1_id is None or user2_id is None:
return None

return GameResult(
game_id=self.id,
user1_id=user1_id,
user2_id=user2_id,
score_user_1=self.scores[player_ids[0]],
score_user_2=self.scores[player_ids[1]],
start_time=self.start_time,
end_state=game_end_state,
is_user1_winner=self._player_won(player_ids[0]),
is_user1_disconnected=(self.disconnected_player_id == player_ids[0]),
)

@abc.abstractmethod
def _validate_action(self, action) -> bool:
"""
Expand Down Expand Up @@ -219,16 +269,6 @@ def get_player_result(self, id: PlayerID) -> int:
"""
...

@abc.abstractmethod
def get_result(self) -> GameResult:
"""
Returns the result of the game.
Returns:
GameResult: The result of the game.
"""
...


class IServer:
"""
Expand Down Expand Up @@ -269,6 +309,15 @@ def on_disconnect(self, player: IPlayer):
"""
...

@abc.abstractmethod
def on_timeout(self, player: IPlayer, failure, timeout):
"""
Gets called when a player has a timeout.
Args:
player (IPlayer): The player that has a timeout.
"""
...

@abc.abstractmethod
def on_update(self):
"""
Expand Down
27 changes: 22 additions & 5 deletions comprl/server/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,29 @@ def end_game(self, game: IGame) -> None:
"""

if game.id in self.games:
GameData(ConfigProvider.get("game_data")).add(
self.games[game.id].get_result()
)

game_result = game.get_result()
if game_result is not None:
GameData(ConfigProvider.get("game_data")).add(game_result)
else:
log.error(f"Game had no valid result. Game-ID: {game.id}")
del self.games[game.id]

def force_game_end(self, player_id: PlayerID):
"""Forces all games, that a player is currently playing, to end.
Args:
player_id (PlayerID): id of the player
"""
involved_games: list[IGame] = []
for _, game in self.games.items():
for game_player_id in game.players:
if player_id == game_player_id:
involved_games.append(game)
break
for game in involved_games:
log.debug("Game was forced to end because of a disconnected player")
game.force_end(player_id=player_id)

def get(self, game_id: GameID) -> IGame | None:
"""
Retrieves the game instance with the specified ID.
Expand Down Expand Up @@ -259,7 +276,7 @@ def remove(self, player_id: PlayerID) -> None:
Args:
player_id (PlayerID): The ID of the player to be removed.
"""
pass
self.queue = [p for p in self.queue if (p != player_id)]

def update(self):
"""
Expand Down
62 changes: 51 additions & 11 deletions comprl/server/networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ def __init__(self, boxReceiver=None, locator=None):

self.connection_made_callbacks: list[Callable[[], None]] = []

Check failure on line 37 in comprl/server/networking.py

View workflow job for this annotation

GitHub Actions / mypy

note: By default the bodies of untyped functions are not checked, consider using --check-untyped-defs [annotation-unchecked]
self.connection_lost_callbacks: list[Callable[[], None]] = []

Check failure on line 38 in comprl/server/networking.py

View workflow job for this annotation

GitHub Actions / mypy

note: By default the bodies of untyped functions are not checked, consider using --check-untyped-defs [annotation-unchecked]
self.connection_timeout_callbacks: list[Callable[[any, any], None]] = []

Check failure on line 39 in comprl/server/networking.py

View workflow job for this annotation

GitHub Actions / mypy

note: By default the bodies of untyped functions are not checked, consider using --check-untyped-defs [annotation-unchecked]

def addConnectionMadeCallback(self, callback):
def add_connection_made_callback(self, callback):
"""adds callback that is executed, when the connection is made
Args:
Expand All @@ -48,7 +49,7 @@ def addConnectionMadeCallback(self, callback):
"""
self.connection_made_callbacks.append(callback)

def addConnectionLostCallback(self, callback):
def add_connection_lost_callback(self, callback):
"""
Adds a callback function to be executed when the connection is lost.
Expand All @@ -60,6 +61,18 @@ def addConnectionLostCallback(self, callback):
"""
self.connection_lost_callbacks.append(callback)

def add_connection_timeout_callback(self, callback):
"""
Adds a callback function to be executed when there is a timeout.
Args:
callback: The callback function to be added.
Returns:
None
"""
self.connection_timeout_callbacks.append(callback)

def connectionMade(self) -> None:
"""
Called when the connection to the client is established.
Expand All @@ -83,7 +96,7 @@ def connectionLost(self, reason):
Returns:
None
"""
log.debug("connection to client lost")
# log.debug("connection to client lost")
for c in self.connection_lost_callbacks:
c()

Expand All @@ -99,7 +112,9 @@ def connectionTimeout(self, failure, timeout) -> None:
Returns:
None
"""
pass
# log.debug("connection timeout")
for c in self.connection_timeout_callbacks:
c(failure, timeout)

def get_token(self, return_callback: Callable[[str], None]) -> None:
"""
Expand All @@ -122,8 +137,11 @@ def callback(res):
self.send_error(msg="Tried to connect with wrong version")
self.disconnect()

self.callRemote(Auth).addCallback(callback=callback).addTimeout(
ConfigProvider.get("timeout"), reactor, self.connectionTimeout
return (
self.callRemote(Auth)
.addCallback(callback=callback)
.addTimeout(ConfigProvider.get("timeout"), reactor, self.connectionTimeout)
.addErrback(self.handle_remote_error)
)

def is_ready(self, return_callback: Callable[[bool], None]) -> bool:
Expand All @@ -141,6 +159,7 @@ def is_ready(self, return_callback: Callable[[bool], None]) -> bool:
self.callRemote(Ready)
.addCallback(callback=lambda res: return_callback(res["ready"]))
.addTimeout(ConfigProvider.get("timeout"), reactor, self.connectionTimeout)
.addErrback(self.handle_remote_error)
)

def notify_start(self, game_id: GameID) -> None:
Expand Down Expand Up @@ -171,6 +190,7 @@ def get_step(
self.callRemote(Step, obv=obv)
.addCallback(callback=lambda res: return_callback(res["action"]))
.addTimeout(ConfigProvider.get("timeout"), reactor, self.connectionTimeout)
.addErrback(self.handle_remote_error)
)

def notify_end(self, result, stats) -> None:
Expand All @@ -184,8 +204,10 @@ def notify_end(self, result, stats) -> None:
Returns:
None
"""
return self.callRemote(EndGame, result=result, stats=stats).addTimeout(
ConfigProvider.get("timeout"), reactor, self.connectionTimeout
return (
self.callRemote(EndGame, result=result, stats=stats)
.addTimeout(ConfigProvider.get("timeout"), reactor, self.connectionTimeout)
.addErrback(self.handle_remote_error)
)

def send_error(self, msg: str):
Expand All @@ -198,7 +220,9 @@ def send_error(self, msg: str):
Returns:
None
"""
self.callRemote(Error, msg=str.encode(msg))
return self.callRemote(Error, msg=str.encode(msg)).addErrback(
self.handle_remote_error
)

def disconnect(self):
"""
Expand All @@ -209,6 +233,15 @@ def disconnect(self):
"""
self.transport.loseConnection()

def handle_remote_error(place, error):
"""Is called when an error in Deferred occurs
Args:
place : where the error was caused
error : description of the error
"""
log.debug(f"Caught error in remote Callback at {place}")


class COMPPlayer(IPlayer):
"""Represents a player in the COMP game.
Expand Down Expand Up @@ -331,10 +364,17 @@ def buildProtocol(self, addr: IAddress) -> Protocol | None:
comp_player: COMPPlayer = COMPPlayer(protocol)

# set up the callbacks needed for the server
protocol.addConnectionMadeCallback(lambda: self.server.on_connect(comp_player))
protocol.addConnectionLostCallback(
protocol.add_connection_made_callback(
lambda: self.server.on_connect(comp_player)
)
protocol.add_connection_lost_callback(
lambda: self.server.on_disconnect(comp_player)
)
protocol.add_connection_timeout_callback(
lambda failure, timeout: self.server.on_timeout(
comp_player, failure, timeout
)
)

return protocol

Expand Down
3 changes: 2 additions & 1 deletion examples/simple/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

@bob.event
def get_step(obv: list[float]):
return [float(random.randint(1, 2))]
# return [float(random.randint(1, 2))]
return [float(input("enter number: ")) or float(random.randint(1, 2))]


@bob.event
Expand Down
Loading

0 comments on commit 1ebdbe3

Please sign in to comment.