diff --git a/replit_river/session.py b/replit_river/session.py index 2be2b81..a321448 100644 --- a/replit_river/session.py +++ b/replit_river/session.py @@ -116,10 +116,15 @@ async def is_websocket_open(self) -> bool: async with self._ws_lock: return await self._ws_wrapper.is_open() - async def _begin_close_session_countdown(self) -> None: - """Begin the countdown to close session, this should be called when - websocket is closed. + async def _handle_connection_closed(self) -> None: """ + Handle the WebSocket connection being closing. + This will trigger connection retries, if enabled, and starts a + session reconnection timer. + If the timer expires before reconnection, the session will be closed. + """ + # ensure websocket is closed and initiate connection retries if applicable. + await self.close_websocket(self._ws_wrapper, not self._is_server) # calculate the value now before establishing it so that there are no # await points between the check and the assignment to avoid a TOCTOU # race. @@ -145,12 +150,7 @@ async def serve(self) -> None: try: await self._handle_messages_from_ws(tg) except ConnectionClosed: - if self._retry_connection_callback: - self._task_manager.create_task( - self._retry_connection_callback() - ) - - await self._begin_close_session_countdown() + await self._handle_connection_closed() logger.debug("ConnectionClosed while serving", exc_info=True) except FailedSendingMessageException: # Expected error if the connection is closed. @@ -310,10 +310,7 @@ async def _heartbeat( "%r closing websocket because of heartbeat misses", self.session_id, ) - await self.close_websocket( - self._ws_wrapper, should_retry=not self._is_server - ) - await self._begin_close_session_countdown() + await self._handle_connection_closed() continue except FailedSendingMessageException: # this is expected during websocket closed period @@ -344,9 +341,7 @@ async def _send_transport_message( websocket: websockets.WebSocketCommonProtocol, ) -> None: try: - await send_transport_message( - msg, websocket, self._begin_close_session_countdown - ) + await send_transport_message(msg, websocket, self._handle_connection_closed) except WebsocketClosedException as e: raise e except FailedSendingMessageException as e: