diff --git a/docs/client.rst b/docs/client.rst index 1a55b71e..e3e1fb2c 100644 --- a/docs/client.rst +++ b/docs/client.rst @@ -312,8 +312,8 @@ server:: print("The connection failed!") @sio.event - def disconnect(): - print("I'm disconnected!") + def disconnect(reason): + print("I'm disconnected! reason:", reason) The ``connect_error`` handler is invoked when a connection attempt fails. If the server provides arguments, these are passed on to the handler. The server @@ -325,7 +325,20 @@ server initiated disconnects, or accidental disconnects, for example due to networking failures. In the case of an accidental disconnection, the client is going to attempt to reconnect immediately after invoking the disconnect handler. As soon as the connection is re-established the connect handler will -be invoked once again. +be invoked once again. The handler receives a ``reason`` argument which +provides the cause of the disconnection:: + + @sio.event + def disconnect(reason): + if reason == sio.reason.CLIENT_DISCONNECT: + print('the client disconnected') + elif reason == sio.reason.SERVER_DISCONNECT: + print('the server disconnected the client') + else: + print('disconnect reason:', reason) + +See the The :attr:`socketio.Client.reason` attribute for a list of possible +disconnection reasons. The ``connect``, ``connect_error`` and ``disconnect`` events have to be defined explicitly and are not invoked on a catch-all event handler. @@ -509,7 +522,7 @@ that belong to a namespace can be created as methods of a subclass of def on_connect(self): pass - def on_disconnect(self): + def on_disconnect(self, reason): pass def on_my_event(self, data): @@ -525,7 +538,7 @@ coroutines if desired:: def on_connect(self): pass - def on_disconnect(self): + def on_disconnect(self, reason): pass async def on_my_event(self, data): diff --git a/docs/server.rst b/docs/server.rst index c20adf9f..ed15ed32 100644 --- a/docs/server.rst +++ b/docs/server.rst @@ -232,8 +232,8 @@ automatically when a client connects or disconnects from the server:: print('connect ', sid) @sio.event - def disconnect(sid): - print('disconnect ', sid) + def disconnect(sid, reason): + print('disconnect ', sid, reason) The ``connect`` event is an ideal place to perform user authentication, and any necessary mapping between user entities in the application and the ``sid`` @@ -256,6 +256,21 @@ message:: def connect(sid, environ, auth): raise ConnectionRefusedError('authentication failed') +The disconnect handler receives the ``sid`` assigned to the client and a +``reason``, which provides the cause of the disconnection:: + + @sio.event + def disconnect(sid, reason): + if reason == sio.reason.CLIENT_DISCONNECT: + print('the client disconnected') + elif reason == sio.reason.SERVER_DISCONNECT: + print('the server disconnected the client') + else: + print('disconnect reason:', reason) + +See the The :attr:`socketio.Server.reason` attribute for a list of possible +disconnection reasons. + Catch-All Event Handlers ~~~~~~~~~~~~~~~~~~~~~~~~ @@ -433,7 +448,7 @@ belong to a namespace can be created as methods in a subclass of def on_connect(self, sid, environ): pass - def on_disconnect(self, sid): + def on_disconnect(self, sid, reason): pass def on_my_event(self, sid, data): @@ -449,7 +464,7 @@ if desired:: def on_connect(self, sid, environ): pass - def on_disconnect(self, sid): + def on_disconnect(self, sid, reason): pass async def on_my_event(self, sid, data): diff --git a/examples/client/async/fiddle_client.py b/examples/client/async/fiddle_client.py index 5b43dccd..e5aeb6cc 100644 --- a/examples/client/async/fiddle_client.py +++ b/examples/client/async/fiddle_client.py @@ -10,8 +10,8 @@ async def connect(): @sio.event -async def disconnect(): - print('disconnected from server') +async def disconnect(reason): + print('disconnected from server, reason:', reason) @sio.event diff --git a/examples/client/sync/fiddle_client.py b/examples/client/sync/fiddle_client.py index 50f5e2aa..71a7a540 100644 --- a/examples/client/sync/fiddle_client.py +++ b/examples/client/sync/fiddle_client.py @@ -9,8 +9,8 @@ def connect(): @sio.event -def disconnect(): - print('disconnected from server') +def disconnect(reason): + print('disconnected from server, reason:', reason) @sio.event diff --git a/examples/server/aiohttp/app.html b/examples/server/aiohttp/app.html index 74d404d7..627b9186 100644 --- a/examples/server/aiohttp/app.html +++ b/examples/server/aiohttp/app.html @@ -11,8 +11,8 @@ socket.on('connect', function() { socket.emit('my_event', {data: 'I\'m connected!'}); }); - socket.on('disconnect', function() { - $('#log').append('
Disconnected'); + socket.on('disconnect', function(reason) { + $('#log').append('
Disconnected: ' + reason); }); socket.on('my_response', function(msg) { $('#log').append('
Received: ' + msg.data); diff --git a/examples/server/aiohttp/app.py b/examples/server/aiohttp/app.py index cba51937..1568ca1f 100644 --- a/examples/server/aiohttp/app.py +++ b/examples/server/aiohttp/app.py @@ -70,8 +70,8 @@ async def connect(sid, environ): @sio.event -def disconnect(sid): - print('Client disconnected') +def disconnect(sid, reason): + print('Client disconnected, reason:', reason) app.router.add_static('/static', 'static') @@ -84,4 +84,4 @@ async def init_app(): if __name__ == '__main__': - web.run_app(init_app()) + web.run_app(init_app(), port=5000) diff --git a/examples/server/aiohttp/fiddle.py b/examples/server/aiohttp/fiddle.py index dfde8e10..64ce330d 100644 --- a/examples/server/aiohttp/fiddle.py +++ b/examples/server/aiohttp/fiddle.py @@ -19,8 +19,8 @@ async def connect(sid, environ, auth): @sio.event -def disconnect(sid): - print('disconnected', sid) +def disconnect(sid, reason): + print('disconnected', sid, reason) app.router.add_static('/static', 'static') @@ -28,4 +28,4 @@ def disconnect(sid): if __name__ == '__main__': - web.run_app(app) + web.run_app(app, port=5000) diff --git a/examples/server/asgi/app.html b/examples/server/asgi/app.html index d2f0e9ac..ad826565 100644 --- a/examples/server/asgi/app.html +++ b/examples/server/asgi/app.html @@ -11,8 +11,8 @@ socket.on('connect', function() { socket.emit('my_event', {data: 'I\'m connected!'}); }); - socket.on('disconnect', function() { - $('#log').append('
Disconnected'); + socket.on('disconnect', function(reason) { + $('#log').append('
Disconnected: ' + reason); }); socket.on('my_response', function(msg) { $('#log').append('
Received: ' + msg.data); diff --git a/examples/server/asgi/app.py b/examples/server/asgi/app.py index 36af85f2..d549ab09 100644 --- a/examples/server/asgi/app.py +++ b/examples/server/asgi/app.py @@ -88,8 +88,8 @@ async def test_connect(sid, environ): @sio.on('disconnect') -def test_disconnect(sid): - print('Client disconnected') +def test_disconnect(sid, reason): + print('Client disconnected, reason:', reason) if __name__ == '__main__': diff --git a/examples/server/asgi/fiddle.py b/examples/server/asgi/fiddle.py index 6899ed1a..402a3799 100644 --- a/examples/server/asgi/fiddle.py +++ b/examples/server/asgi/fiddle.py @@ -17,8 +17,8 @@ async def connect(sid, environ, auth): @sio.event -def disconnect(sid): - print('disconnected', sid) +def disconnect(sid, reason): + print('disconnected', sid, reason) if __name__ == '__main__': diff --git a/examples/server/javascript/fiddle.js b/examples/server/javascript/fiddle.js index 940e4da3..c6a039a0 100644 --- a/examples/server/javascript/fiddle.js +++ b/examples/server/javascript/fiddle.js @@ -19,8 +19,8 @@ io.on('connection', socket => { hello: 'you' }); - socket.on('disconnect', () => { - console.log(`disconnect ${socket.id}`); + socket.on('disconnect', (reason) => { + console.log(`disconnect ${socket.id}, reason: ${reason}`); }); }); diff --git a/examples/server/sanic/app.html b/examples/server/sanic/app.html index 30c59643..b87b2df1 100644 --- a/examples/server/sanic/app.html +++ b/examples/server/sanic/app.html @@ -11,8 +11,8 @@ socket.on('connect', function() { socket.emit('my_event', {data: 'I\'m connected!'}); }); - socket.on('disconnect', function() { - $('#log').append('
Disconnected'); + socket.on('disconnect', function(reason) { + $('#log').append('
Disconnected: ' + reason); }); socket.on('my_response', function(msg) { $('#log').append('
Received: ' + msg.data); diff --git a/examples/server/sanic/app.py b/examples/server/sanic/app.py index 7f02d231..447ddff6 100644 --- a/examples/server/sanic/app.py +++ b/examples/server/sanic/app.py @@ -77,8 +77,8 @@ async def connect(sid, environ): @sio.event -def disconnect(sid): - print('Client disconnected') +def disconnect(sid, reason): + print('Client disconnected, reason:', reason) app.static('/static', './static') diff --git a/examples/server/sanic/fiddle.py b/examples/server/sanic/fiddle.py index 5ecb509e..405e6e56 100644 --- a/examples/server/sanic/fiddle.py +++ b/examples/server/sanic/fiddle.py @@ -21,8 +21,8 @@ async def connect(sid, environ, auth): @sio.event -def disconnect(sid): - print('disconnected', sid) +def disconnect(sid, reason): + print('disconnected', sid, reason) app.static('/static', './static') diff --git a/examples/server/tornado/app.py b/examples/server/tornado/app.py index 16f7a191..58317d9b 100644 --- a/examples/server/tornado/app.py +++ b/examples/server/tornado/app.py @@ -75,8 +75,8 @@ async def connect(sid, environ): @sio.event -def disconnect(sid): - print('Client disconnected') +def disconnect(sid, reason): + print('Client disconnected, reason:', reason) def main(): diff --git a/examples/server/tornado/fiddle.py b/examples/server/tornado/fiddle.py index 1e7e9278..b3878a2a 100644 --- a/examples/server/tornado/fiddle.py +++ b/examples/server/tornado/fiddle.py @@ -24,8 +24,8 @@ async def connect(sid, environ, auth): @sio.event -def disconnect(sid): - print('disconnected', sid) +def disconnect(sid, reason): + print('disconnected', sid, reason) def main(): diff --git a/examples/server/tornado/templates/app.html b/examples/server/tornado/templates/app.html index 74d404d7..627b9186 100644 --- a/examples/server/tornado/templates/app.html +++ b/examples/server/tornado/templates/app.html @@ -11,8 +11,8 @@ socket.on('connect', function() { socket.emit('my_event', {data: 'I\'m connected!'}); }); - socket.on('disconnect', function() { - $('#log').append('
Disconnected'); + socket.on('disconnect', function(reason) { + $('#log').append('
Disconnected: ' + reason); }); socket.on('my_response', function(msg) { $('#log').append('
Received: ' + msg.data); diff --git a/examples/server/wsgi/app.py b/examples/server/wsgi/app.py index 7b019fd0..62bd59b1 100644 --- a/examples/server/wsgi/app.py +++ b/examples/server/wsgi/app.py @@ -94,8 +94,8 @@ def connect(sid, environ): @sio.event -def disconnect(sid): - print('Client disconnected') +def disconnect(sid, reason): + print('Client disconnected, reason:', reason) if __name__ == '__main__': diff --git a/examples/server/wsgi/django_socketio/socketio_app/static/index.html b/examples/server/wsgi/django_socketio/socketio_app/static/index.html index 6dbef78b..b10818f4 100644 --- a/examples/server/wsgi/django_socketio/socketio_app/static/index.html +++ b/examples/server/wsgi/django_socketio/socketio_app/static/index.html @@ -11,8 +11,8 @@ socket.on('connect', function() { socket.emit('my_event', {data: 'I\'m connected!'}); }); - socket.on('disconnect', function() { - $('#log').append('
Disconnected'); + socket.on('disconnect', function(reason) { + $('#log').append('
Disconnected: ' + reason); }); socket.on('my_response', function(msg) { $('#log').append('
Received: ' + msg.data); diff --git a/examples/server/wsgi/django_socketio/socketio_app/views.py b/examples/server/wsgi/django_socketio/socketio_app/views.py index 854c0fb0..f54e1d67 100644 --- a/examples/server/wsgi/django_socketio/socketio_app/views.py +++ b/examples/server/wsgi/django_socketio/socketio_app/views.py @@ -78,5 +78,5 @@ def connect(sid, environ): @sio.event -def disconnect(sid): - print('Client disconnected') +def disconnect(sid, reason): + print('Client disconnected, reason:', reason) diff --git a/examples/server/wsgi/fiddle.py b/examples/server/wsgi/fiddle.py index 247751be..e9cd703d 100644 --- a/examples/server/wsgi/fiddle.py +++ b/examples/server/wsgi/fiddle.py @@ -23,8 +23,8 @@ def connect(sid, environ, auth): @sio.event -def disconnect(sid): - print('disconnected', sid) +def disconnect(sid, reason): + print('disconnected', sid, reason) if __name__ == '__main__': diff --git a/examples/server/wsgi/templates/index.html b/examples/server/wsgi/templates/index.html index 8a7308af..e37a6cbd 100644 --- a/examples/server/wsgi/templates/index.html +++ b/examples/server/wsgi/templates/index.html @@ -11,8 +11,8 @@ socket.on('connect', function() { socket.emit('my_event', {data: 'I\'m connected!'}); }); - socket.on('disconnect', function() { - $('#log').append('
Disconnected'); + socket.on('disconnect', function(reason) { + $('#log').append('
Disconnected: ' + reason); }); socket.on('my_response', function(msg) { $('#log').append('
Received: ' + msg.data); diff --git a/pyproject.toml b/pyproject.toml index 8a5453a1..a3ebb6f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ classifiers = [ requires-python = ">=3.8" dependencies = [ "bidict >= 0.21.0", - "python-engineio >= 4.8.0", + "python-engineio >= 4.11.0", ] [project.readme] diff --git a/src/socketio/async_client.py b/src/socketio/async_client.py index fb1abc14..463073e7 100644 --- a/src/socketio/async_client.py +++ b/src/socketio/async_client.py @@ -338,7 +338,6 @@ async def shutdown(self): await self.disconnect() elif self._reconnect_task: # pragma: no branch self._reconnect_abort.set() - print(self._reconnect_task) await self._reconnect_task def start_background_task(self, target, *args, **kwargs): @@ -398,8 +397,9 @@ async def _handle_disconnect(self, namespace): if not self.connected: return namespace = namespace or '/' - await self._trigger_event('disconnect', namespace=namespace) - await self._trigger_event('__disconnect_final', namespace=namespace) + await self._trigger_event('disconnect', namespace, + self.reason.SERVER_DISCONNECT) + await self._trigger_event('__disconnect_final', namespace) if namespace in self.namespaces: del self.namespaces[namespace] if not self.namespaces: @@ -462,11 +462,27 @@ async def _trigger_event(self, event, namespace, *args): if handler: if asyncio.iscoroutinefunction(handler): try: - ret = await handler(*args) + try: + ret = await handler(*args) + except TypeError: + # the legacy disconnect event does not take a reason + # argument + if event == 'disconnect': + ret = await handler(*args[:-1]) + else: # pragma: no cover + raise except asyncio.CancelledError: # pragma: no cover ret = None else: - ret = handler(*args) + try: + ret = handler(*args) + except TypeError: + # the legacy disconnect event does not take a reason + # argument + if event == 'disconnect': + ret = handler(*args[:-1]) + else: # pragma: no cover + raise return ret # or else, forward the event to a namepsace handler if one exists @@ -566,16 +582,15 @@ async def _handle_eio_message(self, data): else: raise ValueError('Unknown packet type.') - async def _handle_eio_disconnect(self): + async def _handle_eio_disconnect(self, reason): """Handle the Engine.IO disconnection event.""" self.logger.info('Engine.IO connection dropped') will_reconnect = self.reconnection and self.eio.state == 'connected' if self.connected: for n in self.namespaces: - await self._trigger_event('disconnect', namespace=n) + await self._trigger_event('disconnect', n, reason) if not will_reconnect: - await self._trigger_event('__disconnect_final', - namespace=n) + await self._trigger_event('__disconnect_final', n) self.namespaces = {} self.connected = False self.callbacks = {} diff --git a/src/socketio/async_namespace.py b/src/socketio/async_namespace.py index 89442aeb..42d65089 100644 --- a/src/socketio/async_namespace.py +++ b/src/socketio/async_namespace.py @@ -34,11 +34,27 @@ async def trigger_event(self, event, *args): handler = getattr(self, handler_name) if asyncio.iscoroutinefunction(handler) is True: try: - ret = await handler(*args) + try: + ret = await handler(*args) + except TypeError: + # legacy disconnect events do not have a reason + # argument + if event == 'disconnect': + ret = await handler(*args[:-1]) + else: # pragma: no cover + raise except asyncio.CancelledError: # pragma: no cover ret = None else: - ret = handler(*args) + try: + ret = handler(*args) + except TypeError: + # legacy disconnect events do not have a reason + # argument + if event == 'disconnect': + ret = handler(*args[:-1]) + else: # pragma: no cover + raise return ret async def emit(self, event, data=None, to=None, room=None, skip_sid=None, @@ -199,11 +215,27 @@ async def trigger_event(self, event, *args): handler = getattr(self, handler_name) if asyncio.iscoroutinefunction(handler) is True: try: - ret = await handler(*args) + try: + ret = await handler(*args) + except TypeError: + # legacy disconnect events do not have a reason + # argument + if event == 'disconnect': + ret = await handler(*args[:-1]) + else: # pragma: no cover + raise except asyncio.CancelledError: # pragma: no cover ret = None else: - ret = handler(*args) + try: + ret = handler(*args) + except TypeError: + # legacy disconnect events do not have a reason + # argument + if event == 'disconnect': + ret = handler(*args[:-1]) + else: # pragma: no cover + raise return ret async def emit(self, event, data=None, namespace=None, callback=None): diff --git a/src/socketio/async_server.py b/src/socketio/async_server.py index 9b0e9774..f10fb8ae 100644 --- a/src/socketio/async_server.py +++ b/src/socketio/async_server.py @@ -427,7 +427,8 @@ async def disconnect(self, sid, namespace=None, ignore_queue=False): eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) await self._send_packet(eio_sid, self.packet_class( packet.DISCONNECT, namespace=namespace)) - await self._trigger_event('disconnect', namespace, sid) + await self._trigger_event('disconnect', namespace, sid, + self.reason.SERVER_DISCONNECT) await self.manager.disconnect(sid, namespace=namespace, ignore_queue=True) @@ -575,14 +576,15 @@ async def _handle_connect(self, eio_sid, namespace, data): await self._send_packet(eio_sid, self.packet_class( packet.CONNECT, {'sid': sid}, namespace=namespace)) - async def _handle_disconnect(self, eio_sid, namespace): + async def _handle_disconnect(self, eio_sid, namespace, reason=None): """Handle a client disconnect.""" namespace = namespace or '/' sid = self.manager.sid_from_eio_sid(eio_sid, namespace) if not self.manager.is_connected(sid, namespace): # pragma: no cover return self.manager.pre_disconnect(sid, namespace=namespace) - await self._trigger_event('disconnect', namespace, sid) + await self._trigger_event('disconnect', namespace, sid, + reason or self.reason.CLIENT_DISCONNECT) await self.manager.disconnect(sid, namespace, ignore_queue=True) async def _handle_event(self, eio_sid, namespace, id, data): @@ -634,11 +636,25 @@ async def _trigger_event(self, event, namespace, *args): if handler: if asyncio.iscoroutinefunction(handler): try: - ret = await handler(*args) + try: + ret = await handler(*args) + except TypeError: + # legacy disconnect events use only one argument + if event == 'disconnect': + ret = await handler(*args[:-1]) + else: # pragma: no cover + raise except asyncio.CancelledError: # pragma: no cover ret = None else: - ret = handler(*args) + try: + ret = handler(*args) + except TypeError: + # legacy disconnect events use only one argument + if event == 'disconnect': + ret = handler(*args[:-1]) + else: # pragma: no cover + raise return ret # or else, forward the event to a namespace handler if one exists handler, args = self._get_namespace_handler(namespace, args) @@ -671,7 +687,8 @@ async def _handle_eio_message(self, eio_sid, data): if pkt.packet_type == packet.CONNECT: await self._handle_connect(eio_sid, pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: - await self._handle_disconnect(eio_sid, pkt.namespace) + await self._handle_disconnect(eio_sid, pkt.namespace, + self.reason.CLIENT_DISCONNECT) elif pkt.packet_type == packet.EVENT: await self._handle_event(eio_sid, pkt.namespace, pkt.id, pkt.data) @@ -686,10 +703,10 @@ async def _handle_eio_message(self, eio_sid, data): else: raise ValueError('Unknown packet type.') - async def _handle_eio_disconnect(self, eio_sid): + async def _handle_eio_disconnect(self, eio_sid, reason): """Handle Engine.IO disconnect event.""" for n in list(self.manager.get_namespaces()).copy(): - await self._handle_disconnect(eio_sid, n) + await self._handle_disconnect(eio_sid, n, reason) if eio_sid in self.environ: del self.environ[eio_sid] diff --git a/src/socketio/base_client.py b/src/socketio/base_client.py index 1becf914..7bf44207 100644 --- a/src/socketio/base_client.py +++ b/src/socketio/base_client.py @@ -3,6 +3,8 @@ import signal import threading +import engineio + from . import base_namespace from . import packet @@ -31,6 +33,7 @@ def signal_handler(sig, frame): # pragma: no cover class BaseClient: reserved_events = ['connect', 'connect_error', 'disconnect', '__disconnect_final'] + reason = engineio.Client.reason def __init__(self, reconnection=True, reconnection_attempts=0, reconnection_delay=1, reconnection_delay_max=5, @@ -285,7 +288,7 @@ def _handle_eio_connect(self): # pragma: no cover def _handle_eio_message(self, data): # pragma: no cover raise NotImplementedError() - def _handle_eio_disconnect(self): # pragma: no cover + def _handle_eio_disconnect(self, reason): # pragma: no cover raise NotImplementedError() def _engineio_client_class(self): # pragma: no cover diff --git a/src/socketio/base_server.py b/src/socketio/base_server.py index d5a353bc..d134eba1 100644 --- a/src/socketio/base_server.py +++ b/src/socketio/base_server.py @@ -1,5 +1,7 @@ import logging +import engineio + from . import manager from . import base_namespace from . import packet @@ -9,6 +11,7 @@ class BaseServer: reserved_events = ['connect', 'disconnect'] + reason = engineio.Server.reason def __init__(self, client_manager=None, logger=False, serializer='default', json=None, async_handlers=True, always_connect=False, diff --git a/src/socketio/client.py b/src/socketio/client.py index c4f9eaa7..ade2dd6f 100644 --- a/src/socketio/client.py +++ b/src/socketio/client.py @@ -377,8 +377,9 @@ def _handle_disconnect(self, namespace): if not self.connected: return namespace = namespace or '/' - self._trigger_event('disconnect', namespace=namespace) - self._trigger_event('__disconnect_final', namespace=namespace) + self._trigger_event('disconnect', namespace, + self.reason.SERVER_DISCONNECT) + self._trigger_event('__disconnect_final', namespace) if namespace in self.namespaces: del self.namespaces[namespace] if not self.namespaces: @@ -436,7 +437,14 @@ def _trigger_event(self, event, namespace, *args): # first see if we have an explicit handler for the event handler, args = self._get_event_handler(event, namespace, args) if handler: - return handler(*args) + try: + return handler(*args) + except TypeError: + # the legacy disconnect event does not take a reason argument + if event == 'disconnect': + return handler(*args[:-1]) + else: # pragma: no cover + raise # or else, forward the event to a namespace handler if one exists handler, args = self._get_namespace_handler(namespace, args) @@ -525,15 +533,15 @@ def _handle_eio_message(self, data): else: raise ValueError('Unknown packet type.') - def _handle_eio_disconnect(self): + def _handle_eio_disconnect(self, reason): """Handle the Engine.IO disconnection event.""" self.logger.info('Engine.IO connection dropped') will_reconnect = self.reconnection and self.eio.state == 'connected' if self.connected: for n in self.namespaces: - self._trigger_event('disconnect', namespace=n) + self._trigger_event('disconnect', n, reason) if not will_reconnect: - self._trigger_event('__disconnect_final', namespace=n) + self._trigger_event('__disconnect_final', n) self.namespaces = {} self.connected = False self.callbacks = {} diff --git a/src/socketio/namespace.py b/src/socketio/namespace.py index 3bf4f95b..60cab783 100644 --- a/src/socketio/namespace.py +++ b/src/socketio/namespace.py @@ -23,7 +23,14 @@ def trigger_event(self, event, *args): """ handler_name = 'on_' + (event or '') if hasattr(self, handler_name): - return getattr(self, handler_name)(*args) + try: + return getattr(self, handler_name)(*args) + except TypeError: + # legacy disconnect events do not have a reason argument + if event == 'disconnect': + return getattr(self, handler_name)(*args[:-1]) + else: # pragma: no cover + raise def emit(self, event, data=None, to=None, room=None, skip_sid=None, namespace=None, callback=None, ignore_queue=False): @@ -154,7 +161,14 @@ def trigger_event(self, event, *args): """ handler_name = 'on_' + (event or '') if hasattr(self, handler_name): - return getattr(self, handler_name)(*args) + try: + return getattr(self, handler_name)(*args) + except TypeError: + # legacy disconnect events do not have a reason argument + if event == 'disconnect': + return getattr(self, handler_name)(*args[:-1]) + else: # pragma: no cover + raise def emit(self, event, data=None, namespace=None, callback=None): """Emit a custom event to the server. diff --git a/src/socketio/server.py b/src/socketio/server.py index ae73df6a..71c702de 100644 --- a/src/socketio/server.py +++ b/src/socketio/server.py @@ -403,7 +403,8 @@ def disconnect(self, sid, namespace=None, ignore_queue=False): eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) self._send_packet(eio_sid, self.packet_class( packet.DISCONNECT, namespace=namespace)) - self._trigger_event('disconnect', namespace, sid) + self._trigger_event('disconnect', namespace, sid, + self.reason.SERVER_DISCONNECT) self.manager.disconnect(sid, namespace=namespace, ignore_queue=True) @@ -557,14 +558,15 @@ def _handle_connect(self, eio_sid, namespace, data): self._send_packet(eio_sid, self.packet_class( packet.CONNECT, {'sid': sid}, namespace=namespace)) - def _handle_disconnect(self, eio_sid, namespace): + def _handle_disconnect(self, eio_sid, namespace, reason=None): """Handle a client disconnect.""" namespace = namespace or '/' sid = self.manager.sid_from_eio_sid(eio_sid, namespace) if not self.manager.is_connected(sid, namespace): # pragma: no cover return self.manager.pre_disconnect(sid, namespace=namespace) - self._trigger_event('disconnect', namespace, sid) + self._trigger_event('disconnect', namespace, sid, + reason or self.reason.CLIENT_DISCONNECT) self.manager.disconnect(sid, namespace, ignore_queue=True) def _handle_event(self, eio_sid, namespace, id, data): @@ -611,7 +613,14 @@ def _trigger_event(self, event, namespace, *args): # first see if we have an explicit handler for the event handler, args = self._get_event_handler(event, namespace, args) if handler: - return handler(*args) + try: + return handler(*args) + except TypeError: + # legacy disconnect events use only one argument + if event == 'disconnect': + return handler(*args[:-1]) + else: # pragma: no cover + raise # or else, forward the event to a namespace handler if one exists handler, args = self._get_namespace_handler(namespace, args) if handler: @@ -642,7 +651,8 @@ def _handle_eio_message(self, eio_sid, data): if pkt.packet_type == packet.CONNECT: self._handle_connect(eio_sid, pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: - self._handle_disconnect(eio_sid, pkt.namespace) + self._handle_disconnect(eio_sid, pkt.namespace, + self.reason.CLIENT_DISCONNECT) elif pkt.packet_type == packet.EVENT: self._handle_event(eio_sid, pkt.namespace, pkt.id, pkt.data) elif pkt.packet_type == packet.ACK: @@ -655,10 +665,10 @@ def _handle_eio_message(self, eio_sid, data): else: raise ValueError('Unknown packet type.') - def _handle_eio_disconnect(self, eio_sid): + def _handle_eio_disconnect(self, eio_sid, reason): """Handle Engine.IO disconnect event.""" for n in list(self.manager.get_namespaces()).copy(): - self._handle_disconnect(eio_sid, n) + self._handle_disconnect(eio_sid, n, reason) if eio_sid in self.environ: del self.environ[eio_sid] diff --git a/tests/async/test_client.py b/tests/async/test_client.py index 26681b76..3eec1a88 100644 --- a/tests/async/test_client.py +++ b/tests/async/test_client.py @@ -578,11 +578,9 @@ async def test_handle_disconnect(self): c._trigger_event = mock.AsyncMock() await c._handle_disconnect('/') c._trigger_event.assert_any_await( - 'disconnect', namespace='/' - ) - c._trigger_event.assert_any_await( - '__disconnect_final', namespace='/' + 'disconnect', '/', c.reason.SERVER_DISCONNECT ) + c._trigger_event.assert_any_await('__disconnect_final', '/') assert not c.connected await c._handle_disconnect('/') assert c._trigger_event.await_count == 2 @@ -593,21 +591,15 @@ async def test_handle_disconnect_namespace(self): c.namespaces = {'/foo': '1', '/bar': '2'} c._trigger_event = mock.AsyncMock() await c._handle_disconnect('/foo') - c._trigger_event.assert_any_await( - 'disconnect', namespace='/foo' - ) - c._trigger_event.assert_any_await( - '__disconnect_final', namespace='/foo' - ) + c._trigger_event.assert_any_await('disconnect', '/foo', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_await('__disconnect_final', '/foo') assert c.namespaces == {'/bar': '2'} assert c.connected await c._handle_disconnect('/bar') - c._trigger_event.assert_any_await( - 'disconnect', namespace='/bar' - ) - c._trigger_event.assert_any_await( - '__disconnect_final', namespace='/bar' - ) + c._trigger_event.assert_any_await('disconnect', '/bar', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_await('__disconnect_final', '/bar') assert c.namespaces == {} assert not c.connected @@ -617,12 +609,9 @@ async def test_handle_disconnect_unknown_namespace(self): c.namespaces = {'/foo': '1', '/bar': '2'} c._trigger_event = mock.AsyncMock() await c._handle_disconnect('/baz') - c._trigger_event.assert_any_await( - 'disconnect', namespace='/baz' - ) - c._trigger_event.assert_any_await( - '__disconnect_final', namespace='/baz' - ) + c._trigger_event.assert_any_await('disconnect', '/baz', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_await('__disconnect_final', '/baz') assert c.namespaces == {'/foo': '1', '/bar': '2'} assert c.connected @@ -632,8 +621,9 @@ async def test_handle_disconnect_default_namespaces(self): c.namespaces = {'/foo': '1', '/bar': '2'} c._trigger_event = mock.AsyncMock() await c._handle_disconnect('/') - c._trigger_event.assert_any_await('disconnect', namespace='/') - c._trigger_event.assert_any_await('__disconnect_final', namespace='/') + c._trigger_event.assert_any_await('disconnect', '/', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_await('__disconnect_final', '/') assert c.namespaces == {'/foo': '1', '/bar': '2'} assert c.connected @@ -818,6 +808,26 @@ async def test_trigger_event_namespace(self): handler.assert_awaited_once_with(1, '2') catchall_handler.assert_awaited_once_with('bar', 1, '2', 3) + async def test_trigger_legacy_disconnect_event(self): + c = async_client.AsyncClient() + + @c.on('disconnect') + def baz(): + return 'baz' + + r = await c._trigger_event('disconnect', '/', 'foo') + assert r == 'baz' + + async def test_trigger_legacy_disconnect_event_async(self): + c = async_client.AsyncClient() + + @c.on('disconnect') + async def baz(): + return 'baz' + + r = await c._trigger_event('disconnect', '/', 'foo') + assert r == 'baz' + async def test_trigger_event_class_namespace(self): c = async_client.AsyncClient() result = [] @@ -1127,10 +1137,8 @@ async def test_eio_disconnect(self): c.start_background_task = mock.MagicMock() c.sid = 'foo' c.eio.state = 'connected' - await c._handle_eio_disconnect() - c._trigger_event.assert_awaited_once_with( - 'disconnect', namespace='/' - ) + await c._handle_eio_disconnect('foo') + c._trigger_event.assert_awaited_once_with('disconnect', '/', 'foo') assert c.sid is None assert not c.connected @@ -1141,9 +1149,13 @@ async def test_eio_disconnect_namespaces(self): c._trigger_event = mock.AsyncMock() c.sid = 'foo' c.eio.state = 'connected' - await c._handle_eio_disconnect() - c._trigger_event.assert_any_await('disconnect', namespace='/foo') - c._trigger_event.assert_any_await('disconnect', namespace='/bar') + await c._handle_eio_disconnect(c.reason.CLIENT_DISCONNECT) + c._trigger_event.assert_any_await('disconnect', '/foo', + c.reason.CLIENT_DISCONNECT) + c._trigger_event.assert_any_await('disconnect', '/bar', + c.reason.CLIENT_DISCONNECT) + c._trigger_event.asserT_any_await('disconnect', '/', + c.reason.CLIENT_DISCONNECT) assert c.sid is None assert not c.connected @@ -1151,14 +1163,14 @@ async def test_eio_disconnect_reconnect(self): c = async_client.AsyncClient(reconnection=True) c.start_background_task = mock.MagicMock() c.eio.state = 'connected' - await c._handle_eio_disconnect() + await c._handle_eio_disconnect(c.reason.CLIENT_DISCONNECT) c.start_background_task.assert_called_once_with(c._handle_reconnect) async def test_eio_disconnect_self_disconnect(self): c = async_client.AsyncClient(reconnection=True) c.start_background_task = mock.MagicMock() c.eio.state = 'disconnected' - await c._handle_eio_disconnect() + await c._handle_eio_disconnect(c.reason.CLIENT_DISCONNECT) c.start_background_task.assert_not_called() async def test_eio_disconnect_no_reconnect(self): @@ -1169,13 +1181,10 @@ async def test_eio_disconnect_no_reconnect(self): c.start_background_task = mock.MagicMock() c.sid = 'foo' c.eio.state = 'connected' - await c._handle_eio_disconnect() - c._trigger_event.assert_any_await( - 'disconnect', namespace='/' - ) - c._trigger_event.assert_any_await( - '__disconnect_final', namespace='/' - ) + await c._handle_eio_disconnect(c.reason.TRANSPORT_ERROR) + c._trigger_event.assert_any_await('disconnect', '/', + c.reason.TRANSPORT_ERROR) + c._trigger_event.assert_any_await('__disconnect_final', '/') assert c.sid is None assert not c.connected c.start_background_task.assert_not_called() diff --git a/tests/async/test_manager.py b/tests/async/test_manager.py index fd5fe816..aa890649 100644 --- a/tests/async/test_manager.py +++ b/tests/async/test_manager.py @@ -183,8 +183,6 @@ async def test_close_room(self): await self.bm.enter_room(sid, '/foo', 'bar') await self.bm.enter_room(sid, '/foo', 'bar') await self.bm.close_room('bar', '/foo') - from pprint import pprint - pprint(self.bm.rooms) assert 'bar' not in self.bm.rooms['/foo'] async def test_close_invalid_room(self): diff --git a/tests/async/test_namespace.py b/tests/async/test_namespace.py index ad9b1a03..526d6769 100644 --- a/tests/async/test_namespace.py +++ b/tests/async/test_namespace.py @@ -19,13 +19,37 @@ async def on_connect(self, sid, environ): async def test_disconnect_event(self): result = {} + class MyNamespace(async_namespace.AsyncNamespace): + async def on_disconnect(self, sid, reason): + result['result'] = (sid, reason) + + ns = MyNamespace('/foo') + ns._set_server(mock.MagicMock()) + await ns.trigger_event('disconnect', 'sid', 'foo') + assert result['result'] == ('sid', 'foo') + + async def test_legacy_disconnect_event(self): + result = {} + + class MyNamespace(async_namespace.AsyncNamespace): + def on_disconnect(self, sid): + result['result'] = sid + + ns = MyNamespace('/foo') + ns._set_server(mock.MagicMock()) + await ns.trigger_event('disconnect', 'sid', 'foo') + assert result['result'] == 'sid' + + async def test_legacy_disconnect_event_async(self): + result = {} + class MyNamespace(async_namespace.AsyncNamespace): async def on_disconnect(self, sid): result['result'] = sid ns = MyNamespace('/foo') ns._set_server(mock.MagicMock()) - await ns.trigger_event('disconnect', 'sid') + await ns.trigger_event('disconnect', 'sid', 'foo') assert result['result'] == 'sid' async def test_sync_event(self): @@ -242,6 +266,42 @@ async def test_disconnect(self): await ns.disconnect('sid', namespace='/bar') ns.server.disconnect.assert_awaited_with('sid', namespace='/bar') + async def test_disconnect_event_client(self): + result = {} + + class MyNamespace(async_namespace.AsyncClientNamespace): + async def on_disconnect(self, reason): + result['result'] = reason + + ns = MyNamespace('/foo') + ns._set_client(mock.MagicMock()) + await ns.trigger_event('disconnect', 'foo') + assert result['result'] == 'foo' + + async def test_legacy_disconnect_event_client(self): + result = {} + + class MyNamespace(async_namespace.AsyncClientNamespace): + def on_disconnect(self): + result['result'] = 'ok' + + ns = MyNamespace('/foo') + ns._set_client(mock.MagicMock()) + await ns.trigger_event('disconnect', 'foo') + assert result['result'] == 'ok' + + async def test_legacy_disconnect_event_client_async(self): + result = {} + + class MyNamespace(async_namespace.AsyncClientNamespace): + async def on_disconnect(self): + result['result'] = 'ok' + + ns = MyNamespace('/foo') + ns._set_client(mock.MagicMock()) + await ns.trigger_event('disconnect', 'foo') + assert result['result'] == 'ok' + async def test_sync_event_client(self): result = {} diff --git a/tests/async/test_server.py b/tests/async/test_server.py index d9129d46..f60de27a 100644 --- a/tests/async/test_server.py +++ b/tests/async/test_server.py @@ -56,7 +56,7 @@ async def test_on_event(self, eio): def foo(): pass - def bar(): + def bar(reason): pass s.on('disconnect', bar) @@ -537,8 +537,36 @@ async def test_handle_disconnect(self, eio): s.on('disconnect', handler) await s._handle_eio_connect('123', 'environ') await s._handle_eio_message('123', '0') - await s._handle_eio_disconnect('123') - handler.assert_called_once_with('1') + await s._handle_eio_disconnect('123', 'foo') + handler.assert_called_once_with('1', 'foo') + s.manager.disconnect.assert_awaited_once_with( + '1', '/', ignore_queue=True) + assert s.environ == {} + + async def test_handle_legacy_disconnect(self, eio): + eio.return_value.send = mock.AsyncMock() + s = async_server.AsyncServer() + s.manager.disconnect = mock.AsyncMock() + handler = mock.MagicMock(side_effect=[TypeError, None]) + s.on('disconnect', handler) + await s._handle_eio_connect('123', 'environ') + await s._handle_eio_message('123', '0') + await s._handle_eio_disconnect('123', 'foo') + handler.assert_called_with('1') + s.manager.disconnect.assert_awaited_once_with( + '1', '/', ignore_queue=True) + assert s.environ == {} + + async def test_handle_legacy_disconnect_async(self, eio): + eio.return_value.send = mock.AsyncMock() + s = async_server.AsyncServer() + s.manager.disconnect = mock.AsyncMock() + handler = mock.AsyncMock(side_effect=[TypeError, None]) + s.on('disconnect', handler) + await s._handle_eio_connect('123', 'environ') + await s._handle_eio_message('123', '0') + await s._handle_eio_disconnect('123', 'foo') + handler.assert_awaited_with('1') s.manager.disconnect.assert_awaited_once_with( '1', '/', ignore_queue=True) assert s.environ == {} @@ -552,9 +580,9 @@ async def test_handle_disconnect_namespace(self, eio): s.on('disconnect', handler_namespace, namespace='/foo') await s._handle_eio_connect('123', 'environ') await s._handle_eio_message('123', '0/foo,') - await s._handle_eio_disconnect('123') + await s._handle_eio_disconnect('123', 'foo') handler.assert_not_called() - handler_namespace.assert_called_once_with('1') + handler_namespace.assert_called_once_with('1', 'foo') assert s.environ == {} async def test_handle_disconnect_only_namespace(self, eio): @@ -568,13 +596,14 @@ async def test_handle_disconnect_only_namespace(self, eio): await s._handle_eio_message('123', '0/foo,') await s._handle_eio_message('123', '1/foo,') assert handler.call_count == 0 - handler_namespace.assert_called_once_with('1') + handler_namespace.assert_called_once_with( + '1', s.reason.CLIENT_DISCONNECT) assert s.environ == {'123': 'environ'} async def test_handle_disconnect_unknown_client(self, eio): mgr = self._get_mock_manager() s = async_server.AsyncServer(client_manager=mgr) - await s._handle_eio_disconnect('123') + await s._handle_eio_disconnect('123', 'foo') async def test_handle_event(self, eio): eio.return_value.send = mock.AsyncMock() @@ -624,7 +653,8 @@ async def test_handle_event_with_catchall_namespace(self, eio): await s._handle_eio_message('123', '2/bar,["msg","a","b"]') await s._handle_eio_message('123', '2/foo,["my message","a","b","c"]') await s._handle_eio_message('123', '2/bar,["my message","a","b","c"]') - await s._trigger_event('disconnect', '/bar', sid_bar) + await s._trigger_event('disconnect', '/bar', sid_bar, + s.reason.CLIENT_DISCONNECT) connect_star_handler.assert_called_once_with('/bar', sid_bar) msg_foo_handler.assert_called_once_with(sid_foo, 'a', 'b') msg_star_handler.assert_called_once_with('/bar', sid_bar, 'a', 'b') @@ -884,8 +914,8 @@ class MyNamespace(async_namespace.AsyncNamespace): def on_connect(self, sid, environ): result['result'] = (sid, environ) - async def on_disconnect(self, sid): - result['result'] = ('disconnect', sid) + async def on_disconnect(self, sid, reason): + result['result'] = ('disconnect', sid, reason) async def on_foo(self, sid, data): result['result'] = (sid, data) @@ -908,7 +938,8 @@ async def on_baz(self, sid, data1, data2): await s._handle_eio_message('123', '2/foo,["baz","a","b"]') assert result['result'] == ('a', 'b') await s.disconnect('1', '/foo') - assert result['result'] == ('disconnect', '1') + assert result['result'] == ('disconnect', '1', + s.reason.SERVER_DISCONNECT) async def test_catchall_namespace_handler(self, eio): eio.return_value.send = mock.AsyncMock() diff --git a/tests/common/test_client.py b/tests/common/test_client.py index c7399dc1..ac930c79 100644 --- a/tests/common/test_client.py +++ b/tests/common/test_client.py @@ -752,8 +752,9 @@ def test_handle_disconnect(self): c.connected = True c._trigger_event = mock.MagicMock() c._handle_disconnect('/') - c._trigger_event.assert_any_call('disconnect', namespace='/') - c._trigger_event.assert_any_call('__disconnect_final', namespace='/') + c._trigger_event.assert_any_call('disconnect', '/', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_call('__disconnect_final', '/') assert not c.connected c._handle_disconnect('/') assert c._trigger_event.call_count == 2 @@ -764,21 +765,15 @@ def test_handle_disconnect_namespace(self): c.namespaces = {'/foo': '1', '/bar': '2'} c._trigger_event = mock.MagicMock() c._handle_disconnect('/foo') - c._trigger_event.assert_any_call( - 'disconnect', namespace='/foo' - ) - c._trigger_event.assert_any_call( - '__disconnect_final', namespace='/foo' - ) + c._trigger_event.assert_any_call('disconnect', '/foo', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_call('__disconnect_final', '/foo') assert c.namespaces == {'/bar': '2'} assert c.connected c._handle_disconnect('/bar') - c._trigger_event.assert_any_call( - 'disconnect', namespace='/bar' - ) - c._trigger_event.assert_any_call( - '__disconnect_final', namespace='/bar' - ) + c._trigger_event.assert_any_call('disconnect', '/bar', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_call('__disconnect_final', '/bar') assert c.namespaces == {} assert not c.connected @@ -788,12 +783,9 @@ def test_handle_disconnect_unknown_namespace(self): c.namespaces = {'/foo': '1', '/bar': '2'} c._trigger_event = mock.MagicMock() c._handle_disconnect('/baz') - c._trigger_event.assert_any_call( - 'disconnect', namespace='/baz' - ) - c._trigger_event.assert_any_call( - '__disconnect_final', namespace='/baz' - ) + c._trigger_event.assert_any_call('disconnect', '/baz', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_call('__disconnect_final', '/baz') assert c.namespaces == {'/foo': '1', '/bar': '2'} assert c.connected @@ -803,9 +795,9 @@ def test_handle_disconnect_default_namespace(self): c.namespaces = {'/foo': '1', '/bar': '2'} c._trigger_event = mock.MagicMock() c._handle_disconnect('/') - print(c._trigger_event.call_args_list) - c._trigger_event.assert_any_call('disconnect', namespace='/') - c._trigger_event.assert_any_call('__disconnect_final', namespace='/') + c._trigger_event.assert_any_call('disconnect', '/', + c.reason.SERVER_DISCONNECT) + c._trigger_event.assert_any_call('__disconnect_final', '/') assert c.namespaces == {'/foo': '1', '/bar': '2'} assert c.connected @@ -1003,8 +995,8 @@ class MyNamespace(namespace.ClientNamespace): def on_connect(self, ns): result['result'] = (ns,) - def on_disconnect(self, ns): - result['result'] = ('disconnect', ns) + def on_disconnect(self, ns, reason): + result['result'] = ('disconnect', ns, reason) def on_foo(self, ns, data): result['result'] = (ns, data) @@ -1025,8 +1017,8 @@ def on_baz(self, ns, data1, data2): assert result['result'] == 'bar/foo' c._trigger_event('baz', '/foo', 'a', 'b') assert result['result'] == ('/foo', 'a', 'b') - c._trigger_event('disconnect', '/foo') - assert result['result'] == ('disconnect', '/foo') + c._trigger_event('disconnect', '/foo', 'bar') + assert result['result'] == ('disconnect', '/foo', 'bar') def test_trigger_event_class_namespace(self): c = client.Client() @@ -1286,8 +1278,8 @@ def test_eio_disconnect(self): c.start_background_task = mock.MagicMock() c.sid = 'foo' c.eio.state = 'connected' - c._handle_eio_disconnect() - c._trigger_event.assert_called_once_with('disconnect', namespace='/') + c._handle_eio_disconnect('foo') + c._trigger_event.assert_called_once_with('disconnect', '/', 'foo') assert c.sid is None assert not c.connected @@ -1299,10 +1291,13 @@ def test_eio_disconnect_namespaces(self): c.start_background_task = mock.MagicMock() c.sid = 'foo' c.eio.state = 'connected' - c._handle_eio_disconnect() - c._trigger_event.assert_any_call('disconnect', namespace='/foo') - c._trigger_event.assert_any_call('disconnect', namespace='/bar') - c._trigger_event.assert_any_call('disconnect', namespace='/') + c._handle_eio_disconnect(c.reason.CLIENT_DISCONNECT) + c._trigger_event.assert_any_call('disconnect', '/foo', + c.reason.CLIENT_DISCONNECT) + c._trigger_event.assert_any_call('disconnect', '/bar', + c.reason.CLIENT_DISCONNECT) + c._trigger_event.assert_any_call('disconnect', '/', + c.reason.CLIENT_DISCONNECT) assert c.sid is None assert not c.connected @@ -1310,14 +1305,14 @@ def test_eio_disconnect_reconnect(self): c = client.Client(reconnection=True) c.start_background_task = mock.MagicMock() c.eio.state = 'connected' - c._handle_eio_disconnect() + c._handle_eio_disconnect(c.reason.CLIENT_DISCONNECT) c.start_background_task.assert_called_once_with(c._handle_reconnect) def test_eio_disconnect_self_disconnect(self): c = client.Client(reconnection=True) c.start_background_task = mock.MagicMock() c.eio.state = 'disconnected' - c._handle_eio_disconnect() + c._handle_eio_disconnect(c.reason.CLIENT_DISCONNECT) c.start_background_task.assert_not_called() def test_eio_disconnect_no_reconnect(self): @@ -1328,9 +1323,10 @@ def test_eio_disconnect_no_reconnect(self): c.start_background_task = mock.MagicMock() c.sid = 'foo' c.eio.state = 'connected' - c._handle_eio_disconnect() - c._trigger_event.assert_any_call('disconnect', namespace='/') - c._trigger_event.assert_any_call('__disconnect_final', namespace='/') + c._handle_eio_disconnect(c.reason.TRANSPORT_ERROR) + c._trigger_event.assert_any_call('disconnect', '/', + c.reason.TRANSPORT_ERROR) + c._trigger_event.assert_any_call('__disconnect_final', '/') assert c.sid is None assert not c.connected c.start_background_task.assert_not_called() diff --git a/tests/common/test_namespace.py b/tests/common/test_namespace.py index 8bfa9899..f1476e41 100644 --- a/tests/common/test_namespace.py +++ b/tests/common/test_namespace.py @@ -19,13 +19,25 @@ def on_connect(self, sid, environ): def test_disconnect_event(self): result = {} + class MyNamespace(namespace.Namespace): + def on_disconnect(self, sid, reason): + result['result'] = (sid, reason) + + ns = MyNamespace('/foo') + ns._set_server(mock.MagicMock()) + ns.trigger_event('disconnect', 'sid', 'foo') + assert result['result'] == ('sid', 'foo') + + def test_legacy_disconnect_event(self): + result = {} + class MyNamespace(namespace.Namespace): def on_disconnect(self, sid): result['result'] = sid ns = MyNamespace('/foo') ns._set_server(mock.MagicMock()) - ns.trigger_event('disconnect', 'sid') + ns.trigger_event('disconnect', 'sid', 'foo') assert result['result'] == 'sid' def test_event(self): @@ -216,6 +228,30 @@ def test_disconnect(self): ns.disconnect('sid', namespace='/bar') ns.server.disconnect.assert_called_with('sid', namespace='/bar') + def test_disconnect_event_client(self): + result = {} + + class MyNamespace(namespace.ClientNamespace): + def on_disconnect(self, reason): + result['result'] = reason + + ns = MyNamespace('/foo') + ns._set_client(mock.MagicMock()) + ns.trigger_event('disconnect', 'foo') + assert result['result'] == 'foo' + + def test_legacy_disconnect_event_client(self): + result = {} + + class MyNamespace(namespace.ClientNamespace): + def on_disconnect(self): + result['result'] = 'ok' + + ns = MyNamespace('/foo') + ns._set_client(mock.MagicMock()) + ns.trigger_event('disconnect', 'foo') + assert result['result'] == 'ok' + def test_event_not_found_client(self): result = {} diff --git a/tests/common/test_server.py b/tests/common/test_server.py index 57ddc2f4..445d5d9e 100644 --- a/tests/common/test_server.py +++ b/tests/common/test_server.py @@ -39,7 +39,7 @@ def test_on_event(self, eio): def foo(): pass - def bar(): + def bar(reason): pass s.on('disconnect', bar) @@ -510,8 +510,21 @@ def test_handle_disconnect(self, eio): s.on('disconnect', handler) s._handle_eio_connect('123', 'environ') s._handle_eio_message('123', '0') - s._handle_eio_disconnect('123') - handler.assert_called_once_with('1') + s._handle_eio_disconnect('123', 'foo') + handler.assert_called_once_with('1', 'foo') + s.manager.disconnect.assert_called_once_with('1', '/', + ignore_queue=True) + assert s.environ == {} + + def test_handle_legacy_disconnect(self, eio): + s = server.Server() + s.manager.disconnect = mock.MagicMock() + handler = mock.MagicMock(side_effect=[TypeError, None]) + s.on('disconnect', handler) + s._handle_eio_connect('123', 'environ') + s._handle_eio_message('123', '0') + s._handle_eio_disconnect('123', 'foo') + handler.assert_called_with('1') s.manager.disconnect.assert_called_once_with('1', '/', ignore_queue=True) assert s.environ == {} @@ -524,9 +537,9 @@ def test_handle_disconnect_namespace(self, eio): s.on('disconnect', handler_namespace, namespace='/foo') s._handle_eio_connect('123', 'environ') s._handle_eio_message('123', '0/foo,') - s._handle_eio_disconnect('123') + s._handle_eio_disconnect('123', 'foo') handler.assert_not_called() - handler_namespace.assert_called_once_with('1') + handler_namespace.assert_called_once_with('1', 'foo') assert s.environ == {} def test_handle_disconnect_only_namespace(self, eio): @@ -539,13 +552,14 @@ def test_handle_disconnect_only_namespace(self, eio): s._handle_eio_message('123', '0/foo,') s._handle_eio_message('123', '1/foo,') assert handler.call_count == 0 - handler_namespace.assert_called_once_with('1') + handler_namespace.assert_called_once_with( + '1', s.reason.CLIENT_DISCONNECT) assert s.environ == {'123': 'environ'} def test_handle_disconnect_unknown_client(self, eio): mgr = mock.MagicMock() s = server.Server(client_manager=mgr) - s._handle_eio_disconnect('123') + s._handle_eio_disconnect('123', 'foo') def test_handle_event(self, eio): s = server.Server(async_handlers=False) @@ -596,7 +610,8 @@ def test_handle_event_with_catchall_namespace(self, eio): s._handle_eio_message('123', '2/bar,["msg","a","b"]') s._handle_eio_message('123', '2/foo,["my message","a","b","c"]') s._handle_eio_message('123', '2/bar,["my message","a","b","c"]') - s._trigger_event('disconnect', '/bar', sid_bar) + s._trigger_event('disconnect', '/bar', sid_bar, + s.reason.CLIENT_DISCONNECT) connect_star_handler.assert_called_once_with('/bar', sid_bar) msg_foo_handler.assert_called_once_with(sid_foo, 'a', 'b') msg_star_handler.assert_called_once_with('/bar', sid_bar, 'a', 'b') @@ -825,8 +840,8 @@ class MyNamespace(namespace.Namespace): def on_connect(self, sid, environ): result['result'] = (sid, environ) - def on_disconnect(self, sid): - result['result'] = ('disconnect', sid) + def on_disconnect(self, sid, reason): + result['result'] = ('disconnect', sid, reason) def on_foo(self, sid, data): result['result'] = (sid, data) @@ -849,7 +864,8 @@ def on_baz(self, sid, data1, data2): s._handle_eio_message('123', '2/foo,["baz","a","b"]') assert result['result'] == ('a', 'b') s.disconnect('1', '/foo') - assert result['result'] == ('disconnect', '1') + assert result['result'] == ('disconnect', '1', + s.reason.SERVER_DISCONNECT) def test_catchall_namespace_handler(self, eio): result = {}