From 0469b8c1682032bf05e8699151328455c3af0740 Mon Sep 17 00:00:00 2001 From: Kirill Kuzminykh Date: Tue, 8 Feb 2022 11:13:15 +0300 Subject: [PATCH 01/23] - Fixed processing of heartbeats and a session expiration. - Fixed ping-pong based heartbeats for web-socket connections. - Added arguments ``heartbeat_delay`` and ``disconnect_delay`` into ``Session.__init__()``. - Added argument ``disconnect_delay`` into ``SessionManager.__init__()``. - **Breaking change:** Removed argument ``timeout`` from ``Session.__init__()`` and ``SessionManager.__init__()``. - **Breaking change:** Argument ``heartbeat`` of ``SessionManager.__init__()`` renamed into ``heartbeat_delay``. - **Breaking change:** ``Session.registry`` renamed into ``Session.app``. - **Breaking change:** Dropped support of Python < 3.7 --- .gitignore | 1 + .travis.yml | 15 +- CHANGES.rst | 15 ++ MANIFEST.in | 1 - Makefile | 10 +- README.rst | 2 +- docs/Makefile | 67 ------- docs/api.rst | 23 --- docs/conf.py | 185 ------------------- docs/index.rst | 18 -- docs/install.rst | 58 ------ docs/overview.rst | 175 ------------------ examples/chat.html | 172 +++++++++--------- requirements.txt | 20 +-- setup.cfg | 3 + setup.py | 22 ++- sockjs-testsrv.py | 4 +- sockjs/__init__.py | 35 ++-- sockjs/protocol.py | 16 +- sockjs/route.py | 68 +++---- sockjs/session.py | 256 ++++++++++++++++----------- sockjs/transports/__init__.py | 25 ++- sockjs/transports/base.py | 124 +++++++------ sockjs/transports/eventsource.py | 17 +- sockjs/transports/htmlfile.py | 28 +-- sockjs/transports/jsonp.py | 28 +-- sockjs/transports/rawwebsocket.py | 83 +++++++-- sockjs/transports/utils.py | 26 ++- sockjs/transports/websocket.py | 126 +++++++++---- sockjs/transports/xhr.py | 8 +- sockjs/transports/xhrsend.py | 8 +- sockjs/transports/xhrstreaming.py | 10 +- tests/asdf | 0 tests/conftest.py | 71 ++++---- tests/test_route.py | 25 +-- tests/test_session.py | 183 ++++++++++--------- tests/test_transport.py | 54 +++--- tests/test_transport_eventsource.py | 17 +- tests/test_transport_htmlfile.py | 17 +- tests/test_transport_jsonp.py | 16 +- tests/test_transport_rawwebsocket.py | 14 +- tests/test_transport_websocket.py | 48 +++-- tests/test_transport_xhr.py | 14 +- tests/test_transport_xhrsend.py | 14 +- tests/test_transport_xhrstreaming.py | 14 +- 45 files changed, 922 insertions(+), 1214 deletions(-) delete mode 100644 docs/Makefile delete mode 100644 docs/api.rst delete mode 100644 docs/conf.py delete mode 100644 docs/index.rst delete mode 100644 docs/install.rst delete mode 100644 docs/overview.rst delete mode 100644 tests/asdf diff --git a/.gitignore b/.gitignore index 7bb7b93d..8a7d23b7 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ eggs sources dist develop-eggs +build *.egg-info *.pyc *.pyo diff --git a/.travis.yml b/.travis.yml index 38727d34..d0c97e62 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,14 +1,19 @@ language: python + + python: - - 3.5 - - 3.6 - 3.7 - 3.8 + - 3.9 + - 3.10 + + install: - pip install --upgrade setuptools - pip install -r requirements.txt - pip install codecov - - python setup.py develop + - pip install -e .[test] + script: - make cov @@ -16,9 +21,9 @@ script: after_success: - - codecov + deploy: provider: pypi user: aio-libs-bot @@ -28,4 +33,4 @@ deploy: on: tags: true all_branches: true - python: 3.6 + python: 3.7 diff --git a/CHANGES.rst b/CHANGES.rst index 7eb64c5e..00392318 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -2,6 +2,21 @@ CHANGES ======= +0.12.0 (not released yet) +------------------------- + +- Fixed processing of heartbeats and a session expiration. +- Fixed ping-pong based heartbeats for web-socket connections. +- Added arguments ``heartbeat_delay`` and ``disconnect_delay`` into + ``Session.__init__()``. +- Added argument ``disconnect_delay`` into ``SessionManager.__init__()``. +- **Breaking change:** Removed argument ``timeout`` from ``Session.__init__()`` + and ``SessionManager.__init__()``. +- **Breaking change:** Argument ``heartbeat`` of ``SessionManager.__init__()`` + renamed into ``heartbeat_delay``. +- **Breaking change:** ``Session.registry`` renamed into ``Session.app``. +- **Breaking change:** Dropped support of Python < 3.7 + 0.11.0 (2020-10-22) ------------------- diff --git a/MANIFEST.in b/MANIFEST.in index bf38be55..6cd5c519 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,7 +4,6 @@ include README.rst include Makefile include sockjs-testsrv.py graft sockjs -graft docs graft examples graft tests global-exclude *.pyc diff --git a/Makefile b/Makefile index aba1428b..ae5d3eb2 100644 --- a/Makefile +++ b/Makefile @@ -13,9 +13,8 @@ flake: fmt: black sockjs tests setup.py - develop: - python setup.py develop + pip install -e .[test] test: flake develop pytest $(FLAGS) ./tests/ @@ -40,11 +39,6 @@ clean: rm -rf coverage rm -rf build rm -rf cover - make -C docs clean python setup.py clean -doc: - make -C docs html - @echo "open file://`pwd`/docs/_build/html/index.html" - -.PHONY: all build venv flake test vtest testloop cov clean doc +.PHONY: all build venv flake test vtest testloop cov clean diff --git a/README.rst b/README.rst index 10625f10..37139570 100644 --- a/README.rst +++ b/README.rst @@ -73,7 +73,7 @@ Supported transports Requirements ------------ -- Python 3.5.3 +- Python 3.7.0 - gunicorn 19.2.0 diff --git a/docs/Makefile b/docs/Makefile deleted file mode 100644 index 8ac72307..00000000 --- a/docs/Makefile +++ /dev/null @@ -1,67 +0,0 @@ -# Makefile for Sphinx documentation -# - -# You can set these variables from the command line. -SPHINXOPTS = -W -SPHINXBUILD = ../../../bin/sphinx-build -PAPER = - -# Internal variables. -PAPEROPT_a4 = -D latex_paper_size=a4 -PAPEROPT_letter = -D latex_paper_size=letter -ALLSPHINXOPTS = -d _build/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . - -.PHONY: help clean html web pickle htmlhelp latex changes linkcheck - -help: - @echo "Please use \`make ' where is one of" - @echo " html to make standalone HTML files" - @echo " pickle to make pickle files (usable by e.g. sphinx-web)" - @echo " htmlhelp to make HTML files and a HTML help project" - @echo " changes to make an overview over all changed/added/deprecated items" - @echo " linkcheck to check all external links for integrity" - -clean: - -rm -rf _build/* - -html: - mkdir -p _build/html _build/doctrees - $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) _build/html - @echo - @echo "Build finished. The HTML pages are in _build/html." - -text: - mkdir -p _build/text _build/doctrees - $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) _build/text - @echo - @echo "Build finished. The HTML pages are in _build/text." - -pickle: - mkdir -p _build/pickle _build/doctrees - $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) _build/pickle - @echo - @echo "Build finished; now you can process the pickle files or run" - @echo " sphinx-web _build/pickle" - @echo "to start the sphinx-web server." - -web: pickle - -htmlhelp: - mkdir -p _build/htmlhelp _build/doctrees - $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) _build/htmlhelp - @echo - @echo "Build finished; now you can run HTML Help Workshop with the" \ - ".hhp project file in _build/htmlhelp." - -changes: - mkdir -p _build/changes _build/doctrees - $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) _build/changes - @echo - @echo "The overview file is in _build/changes." - -linkcheck: - mkdir -p _build/linkcheck _build/doctrees - $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) _build/linkcheck - @echo - @echo "Link check complete; look for any errors in the above output " \ - "or in _build/linkcheck/output.txt." diff --git a/docs/api.rst b/docs/api.rst deleted file mode 100644 index fa0eb87c..00000000 --- a/docs/api.rst +++ /dev/null @@ -1,23 +0,0 @@ -API -=== - - -.. automodule:: pyramid_sockjs - -.. autofunction:: get_session_manager - -Session states: - - .. py:data:: STATE_NEW - - .. py:data:: STATE_OPEN - - .. py:data:: STATE_CLOSING - - .. py:data:: STATE_CLOSED - -.. autoclass:: Session(id) - :members: - -.. autoclass:: SessionManager - :members: diff --git a/docs/conf.py b/docs/conf.py deleted file mode 100644 index f7c2b044..00000000 --- a/docs/conf.py +++ /dev/null @@ -1,185 +0,0 @@ -# -*- coding: utf-8 -*- -# -# ptah documentation build configuration file -# -# This file is execfile()d with the current directory set to its containing -# dir. -# -# The contents of this file are pickled, so don't put values in the -# namespace that aren't pickleable (module imports are okay, they're -# removed automatically). -# -# All configuration values have a default value; values that are commented -# out serve to show the default value. - - -# General configuration -# --------------------- - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = ['sphinx.ext.intersphinx', - 'sphinx.ext.autodoc'] - -# Add any paths that contain templates here, relative to this directory. -templates_path = ['.templates'] - -# The suffix of source filenames. -source_suffix = '.rst' - -# The master toctree document. -master_doc = 'index' - -# General substitutions. -project = 'pyramid_sockjs' - -# The default replacements for |version| and |release|, also used in various -# other places throughout the built documents. -# -# The short X.Y version. -version = '2.0dev1' -# The full version, including alpha/beta/rc tags. -release = version - -# There are two options for replacing |today|: either, you set today to -# some non-false value, then it is used: -#today = '' -# Else, today_fmt is used as the format for a strftime call. -today_fmt = '%B %d, %Y' - -# List of documents that shouldn't be included in the build. -#unused_docs = ['_themes/README'] - -# List of directories, relative to source directories, that shouldn't be -# searched for source files. -#exclude_dirs = [] - -# The reST default role (used for this markup: `text`) to use for all -# documents. -#default_role = None - -# If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True - -# If true, the current module name will be prepended to all description -# unit titles (such as .. function::). -#add_module_names = True - -# If true, sectionauthor and moduleauthor directives will be shown in the -# output. They are ignored by default. -#show_authors = False - -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' - - -# Options for HTML output -# ----------------------- - -#sys.path.append(os.path.abspath('_themes')) -#html_theme_path = ['_themes'] -#html_theme = 'pylons' - -# The style sheet to use for HTML and HTML Help pages. A file of that name -# must exist either in Sphinx' static/ path, or in one of the custom paths -# given in html_static_path. -#html_style = 'pylons.css' - -# The name for this set of Sphinx documents. If None, it defaults to -# " v documentation". -#html_title = None - -# A shorter title for the navigation bar. Default is the same as -# html_title. -#html_short_title = None - -# The name of an image file (within the static path) to place at the top of -# the sidebar. -#html_logo = '.static/logo_hi.gif' - -# The name of an image file (within the static path) to use as favicon of -# the docs. This file should be a Windows icon file (.ico) being 16x16 or -# 32x32 pixels large. -#html_favicon = None - -# Add any paths that contain custom static files (such as style sheets) -# here, relative to this directory. They are copied after the builtin -# static files, so a file named "default.css" will overwrite the builtin -# "default.css". -#html_static_path = ['.static'] - -# If not '', a 'Last updated on:' timestamp is inserted at every page -# bottom, using the given strftime format. -html_last_updated_fmt = '%b %d, %Y' - -# If true, SmartyPants will be used to convert quotes and dashes to -# typographically correct entities. -#html_use_smartypants = True - -# Custom sidebar templates, maps document names to template names. -#html_sidebars = {} - -# Additional templates that should be rendered to pages, maps page names to -# template names. -#html_additional_pages = {} - -# If false, no module index is generated. -#html_use_modindex = True - -# If false, no index is generated. -#html_use_index = True - -# If true, the index is split into individual pages for each letter. -#html_split_index = False - -# If true, the reST sources are included in the HTML build as -# _sources/. -#html_copy_source = True - -# If true, an OpenSearch description file will be output, and all pages -# will contain a tag referring to it. The value of this option must -# be the base URL from which the finished HTML is served. -#html_use_opensearch = '' - -# If nonempty, this is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = '' - -# Output file base name for HTML help builder. -htmlhelp_basename = 'atemplatedoc' - - -# Options for LaTeX output -# ------------------------ - -# The paper size ('letter' or 'a4'). -#latex_paper_size = 'letter' - -# The font size ('10pt', '11pt' or '12pt'). -#latex_font_size = '10pt' - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, -# author, document class [howto/manual]). -latex_documents = [ - ('index', 'atemplate.tex', 'Pyramid SockJS documentation', - 'Developers', 'manual'), -] - -# The name of an image file (relative to this directory) to place at the -# top of the title page. -latex_logo = '.static/logo_hi.gif' - -# For "manual" documents, if this is true, then toplevel headings are -# parts, not chapters. -#latex_use_parts = False - -# Additional stuff for the LaTeX preamble. -#latex_preamble = '' - -# Documents to append as an appendix to all manuals. -#latex_appendices = [] - -# If false, no module index is generated. -#latex_use_modindex = True - -#autoclass_content = 'both' diff --git a/docs/index.rst b/docs/index.rst deleted file mode 100644 index a70d01cd..00000000 --- a/docs/index.rst +++ /dev/null @@ -1,18 +0,0 @@ -Pyramid SockJS -============== - -.. toctree:: - :maxdepth: 2 - - overview.rst - install.rst - api.rst - - -Indices and tables ------------------- - -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` - diff --git a/docs/install.rst b/docs/install.rst deleted file mode 100644 index 21835623..00000000 --- a/docs/install.rst +++ /dev/null @@ -1,58 +0,0 @@ -============ -Installation -============ - -virtualenv -========== - - -1. Install virtualenv:: - - $ wget https://raw.github.com/pypa/virtualenv/master/virtualenv.py - $ python2.7 ./virtualenv.py --no-site-packages sockjs - -2. Install gevent 1.0b2 (non-Windows users):: - - $ ./sockjs/bin/pip install http://gevent.googlecode.com/files/gevent-1.0b2.tar.gz - -2. Install gevent 1.0b2 (Windows users, presuming you are running 32bit Python 2.7):: - - $ ./sockjs/Scripts/easy_install http://gevent.googlecode.com/files/gevent-1.0b2-py2.7-win32.egg - -3. Clone pyramid_sockjs from github and then install:: - - $ git clone https://github.com/fafhrd91/pyramid_sockjs.git - $ cd pyramid_sockjs - $ ../sockjs/bin/python setup.py develop - - -Server config -============= - -To use gevent based server use following configuration -for server section:: - - [server:main] - use = egg:pyramid_sockjs#server - host = 127.0.0.1 - port = 8080 - -To use gunicorn server use following configuation for server section, -gunicorn 0.14.3 or greater is required:: - - [server:main] - use = egg:gunicorn - host = 127.0.0.1 - port = 8080 - workers = 1 - worker_class = gevent - - -Chat example -============ - -You can run `chat` example with following command. It doesnt require -any configuration, it runs on host ``127.0.0.1`` and port ``8080``:: - - - $ ./sockjs/bin/python ./pyramid_sockjs/examples/chat.py diff --git a/docs/overview.rst b/docs/overview.rst deleted file mode 100644 index 9d5110ef..00000000 --- a/docs/overview.rst +++ /dev/null @@ -1,175 +0,0 @@ -Pyramid SockJS -============== - -Overview --------- - -Gevent-based SockJS integration for Pyramid. SockJS interface is -implemented as pyramid route. It runs inside wsgi app rather than wsgi server. -It's possible to create any number of different sockjs routes, ie -`/__sockjs__/*` or `/mycustom-sockjs/*`. also you can provide different -session implementation and management for each of sockjs routes. - -Gevent based server is required for ``pyramid_sockjs``. -For example ``gunicorn`` with gevent worker. ``pyramid_sockjs`` provides -simple paster server runner: - -.. code-block:: text - :linenos: - - [server:main] - use = egg:pyramid_sockjs#server - host = 0.0.0.0 - port = 8080 - -Example of sockjs route: - -.. code-block:: python - - def main(global_settings, **settings): - config = Configurator(settings=settings) - config.add_sockjs_route() - - return config.make_wsgi_app() - -By default :py:func:`add_sockjs_route` directive creates sockjs route -with empty name and prefix ``/__sockjs__``, so js client code should look like: - - -.. code-block:: javascript - - - - - -All interactions between client and server happen through `Sessions`. -Its possible to override default session with custom implementation. -Default session is very stupid, its even not possible to receive -client messages, so in most cases it is required to replace session. -Let's implement `echo` session as example: - -.. code-block:: python - - from pyramid_sockjs.session import Session - - class EchoSession(Session): - - def on_open(self): - self.send('Hello') - self.manager.broadcast("Someone joined.") - - def on_message(self, message): - self.send(message) - - def on_close(self): - self.manager.broadcast("Someone left.") - -To use custom session implementation pass it to :py:func:`add_sockjs_route` -directive: - -.. code-block:: python - - def main(global_settings, **settings): - config = Configurator(settings=settings) - - config.add_sockjs_route(session=EchoSession) - - return config.make_wsgi_app() - - -Sessions are managed by ``SessionManager``, each sockjs route has separate -session manager. Session manage is addressed by same name as sockjs route. -To get session manager use :py:func:`get_sockjs_manager` -request function. - -.. code-block:: python - - def main(...): - ... - config.add_sockjs_route('chat-service') - ... - config.add_route('broadcast', '/broadcast') - ... - return config.make_wsgi_app() - - - @view_config(route_name='broadcast', renderer='string') - def send_message(request): - message = request.GET.get('message') - if message: - manager = request.get_sockjs_manager('chat-service') - for session in manager.active_session(): - session.send(message) - - return 'Message has been sent' - - -To use custom ``SessionManager`` pass it as `session_manager=` argument -to :py:func:`add_sockjs_route` configurator directive. -Check :py:class:`pyramid_sockjs.Session` -and :py:class:`pyramid_sockjs.SessionManager` api for -detailed description. - - -Supported transports --------------------- - -* websocket (`hixie-76 `_ - and `hybi-10 `_) -* `xhr-streaming `_ -* `xhr-polling `_ -* `iframe-xhr-polling `_ -* iframe-eventsource (`EventSource `_ used from an. - `iframe via postMessage `_) -* iframe-htmlfile (`HtmlFile `_ - used from an `iframe via postMessage `_.) -* `jsonp-polling `_ - - -Limitations ------------ - -Pyramid sockjs does not support multple websocket session with same session id. - -gevent does not support Python 3 - -Requirements ------------- - -- Python 2.6/2.7 - -- `virtualenv `_ - -- `gevent 1.0b1 or greater `_ - -- `gevent-websocket 0.3.0 or greater `_ - -- `gunicorn 0.14.3 or greater `_ - - -Examples --------- - -You can find them in the `examples` repository at github. - -https://github.com/fafhrd91/pyramid_sockjs/tree/master/examples - - -License -------- - -pyramid_sockjs is offered under the BSD license. diff --git a/examples/chat.html b/examples/chat.html index b576a37b..7c126637 100644 --- a/examples/chat.html +++ b/examples/chat.html @@ -2,88 +2,98 @@ - + +

Chat!

@@ -129,8 +139,8 @@

Chat!

- - + +
diff --git a/requirements.txt b/requirements.txt index cba0344b..8ec0c149 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,10 @@ -black==20.8b1; python_version>="3.6" -flake8==3.8.4 -docutils==0.16 -pytest==6.1.1 -pytest-aiohttp==0.3.0 -pytest-cov==2.10.1 +black==22.1.0; python_version>="3.6" +flake8==4.0.1 +pytest==7.0.0 +pytest-aiohttp==1.0.3 +pytest-cov==3.0.0 pytest-sugar==0.9.4 -pytest-mock==3.3.1 -pytest-timeout==1.4.2 -sphinx==3.2.1 -aiohttp==3.6.3 --e . +pytest-mock==3.7.0 +pytest-timeout==2.1.0 +aiohttp==3.8.1 +-e .[test] diff --git a/setup.cfg b/setup.cfg index 6c896046..222c4da9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,11 +1,14 @@ [easy_install] zip_ok = false + [flake8] ignore = N801,N802,N803,E226 max-line-length = 88 + [tool:pytest] timeout = 3 filterwarnings= error +asyncio_mode = auto diff --git a/setup.py b/setup.py index 30a3f90b..0ca1cfd6 100644 --- a/setup.py +++ b/setup.py @@ -3,6 +3,7 @@ import re from setuptools import setup, find_packages + with codecs.open( os.path.join(os.path.abspath(os.path.dirname(__file__)), "sockjs", "__init__.py"), "r", @@ -13,8 +14,6 @@ except IndexError: raise RuntimeError("Unable to determine version.") -install_requires = ["aiohttp >= 3.0.0"] - def read(f): return open(os.path.join(os.path.dirname(__file__), f)).read().strip() @@ -31,10 +30,10 @@ def read(f): "Intended Audience :: Developers", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.5", - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", "Programming Language :: Python :: Implementation :: CPython", "Topic :: Internet :: WWW/HTTP", "Framework :: AsyncIO", @@ -44,8 +43,19 @@ def read(f): url="https://github.com/aio-libs/sockjs/", license="Apache 2", packages=find_packages(), - python_requires=">=3.5.3", - install_requires=install_requires, + python_requires=">=3.7.0", + install_requires=[ + "aiohttp>=3.0.0", + ], + extras_require={ + "test": [ + "pytest", + "multidict", + "yarl", + "pytest-aiohttp", + "pytest-timeout", + ], + }, include_package_data=True, zip_safe=False, ) diff --git a/sockjs-testsrv.py b/sockjs-testsrv.py index 6c2a7bce..881c4fce 100644 --- a/sockjs-testsrv.py +++ b/sockjs-testsrv.py @@ -33,7 +33,7 @@ async def broadcastSession(msg, session): EventsourceTransport.maxsize = 4096 XHRStreamingTransport.maxsize = 4096 - app = web.Application(loop=loop) + app = web.Application() sockjs.add_endpoint( app, echoSession, name='echo', prefix='/echo') @@ -48,4 +48,4 @@ async def broadcastSession(msg, session): app, echoSession, name='cookie', prefix='/cookie_needed_echo', cookie_needed=True) - web.run_app(app) + web.run_app(app, port=8081) diff --git a/sockjs/__init__.py b/sockjs/__init__.py index 090bdac1..ff2b35e8 100644 --- a/sockjs/__init__.py +++ b/sockjs/__init__.py @@ -1,26 +1,19 @@ -# pyramid_sockjs - -# Session, SessionManager are not imported - -from sockjs.session import Session -from sockjs.session import SessionManager -from sockjs.exceptions import SessionIsClosed -from sockjs.exceptions import SessionIsAcquired - -from sockjs.protocol import STATE_NEW -from sockjs.protocol import STATE_OPEN -from sockjs.protocol import STATE_CLOSING -from sockjs.protocol import STATE_CLOSED - -from sockjs.protocol import MSG_OPEN -from sockjs.protocol import MSG_MESSAGE -from sockjs.protocol import MSG_CLOSE -from sockjs.protocol import MSG_CLOSED - -from sockjs.route import get_manager, add_endpoint +from .exceptions import SessionIsAcquired, SessionIsClosed +from .protocol import ( + MSG_CLOSE, + MSG_CLOSED, + MSG_MESSAGE, + MSG_OPEN, + STATE_CLOSED, + STATE_CLOSING, + STATE_NEW, + STATE_OPEN, +) +from .route import add_endpoint, get_manager +from .session import Session, SessionManager -__version__ = "0.11.0" +__version__ = "0.12.0" __all__ = ( diff --git a/sockjs/protocol.py b/sockjs/protocol.py index cf9ba367..b56870f2 100644 --- a/sockjs/protocol.py +++ b/sockjs/protocol.py @@ -1,6 +1,8 @@ -import collections +import dataclasses import hashlib from datetime import datetime +from typing import Optional + ENCODING = "utf-8" @@ -74,8 +76,8 @@ def dthandler(obj): IFRAME_HTML = """ - - + + \r\n" % dumps(text)).encode(ENCODING) - await self.response.write(blob) - - self.size += len(blob) - if self.size > self.maxsize: - return True - else: - return False + async def _send(self, text: str): + text = "\r\n" % dumps(text) + return await super()._send(text) async def process(self): request = self.request - try: - callback = request.query.get("c", None) - except Exception: - callback = request.GET.get("c", None) - + callback = request.query.get("c") if callback is None: await self.session._remote_closed() return web.HTTPInternalServerError(text='"callback" parameter required') diff --git a/sockjs/transports/jsonp.py b/sockjs/transports/jsonp.py index ed1d4ec8..6bd0c932 100644 --- a/sockjs/transports/jsonp.py +++ b/sockjs/transports/jsonp.py @@ -2,22 +2,22 @@ import re from urllib.parse import unquote_plus -from aiohttp import web, hdrs +from aiohttp import hdrs, web from .base import StreamingTransport -from .utils import CACHE_CONTROL, session_cookie, cors_headers -from ..protocol import dumps, loads, ENCODING +from .utils import CACHE_CONTROL, cors_headers, session_cookie +from ..protocol import ENCODING, dumps, loads class JSONPolling(StreamingTransport): - + create_session = True + maxsize = 0 check_callback = re.compile(r"^[a-zA-Z0-9_\.]+$") callback = "" - async def send(self, text): - data = "/**/%s(%s);\r\n" % (self.callback, dumps(text)) - await self.response.write(data.encode(ENCODING)) - return True + async def _send(self, text: str): + text = "/**/%s(%s);\r\n" % (self.callback, dumps(text)) + return await super()._send(text) async def process(self): session = self.session @@ -25,11 +25,7 @@ async def process(self): meth = request.method if request.method == hdrs.METH_GET: - try: - callback = self.callback = request.query.get("c") - except Exception: - callback = self.callback = request.GET.get("c") - + callback = self.callback = request.query.get("c") if not callback: await self.session._remote_closed() return web.HTTPInternalServerError(text='"callback" parameter required') @@ -74,7 +70,7 @@ async def process(self): await session._remote_messages(messages) headers = ( - (hdrs.CONTENT_TYPE, "text/html;charset=UTF-8"), + (hdrs.CONTENT_TYPE, "text/plain;charset=UTF-8"), (hdrs.CACHE_CONTROL, CACHE_CONTROL), ) headers += session_cookie(request) @@ -82,3 +78,7 @@ async def process(self): else: return web.HTTPBadRequest(text="No support for such method: %s" % meth) + + +class JSONPollingSend(JSONPolling): + create_session = False diff --git a/sockjs/transports/rawwebsocket.py b/sockjs/transports/rawwebsocket.py index 3997952d..140ac3a3 100644 --- a/sockjs/transports/rawwebsocket.py +++ b/sockjs/transports/rawwebsocket.py @@ -1,19 +1,45 @@ """raw websocket transport.""" import asyncio -from aiohttp import web - from asyncio import ensure_future +from typing import Optional +from uuid import uuid4 + +from aiohttp import web +from async_timeout import timeout from .base import Transport +from .utils import cancel_tasks from ..exceptions import SessionIsClosed -from ..protocol import FRAME_CLOSE, FRAME_MESSAGE, FRAME_MESSAGE_BLOB, FRAME_HEARTBEAT +from ..protocol import FRAME_CLOSE, FRAME_HEARTBEAT, FRAME_MESSAGE, FRAME_MESSAGE_BLOB +from ..session import Session, SessionManager class RawWebSocketTransport(Transport): - async def server(self, ws, session): + heartbeat_timeout = 10 + + @classmethod + def get_session(cls, manager: SessionManager, session_id: str) -> Session: + # For WebSockets, as opposed to other transports, it is valid to + # reuse `session_id`. The lifetime of SockJS WebSocket session is + # defined by a lifetime of underlying WebSocket connection. It is + # correct to have two separate sessions sharing the same + # `session_id` at the same time. + + # Generate unique session_id based on given ID. + orig_session_id = session_id + while session_id in manager: + session_id = "%s-%s" % (orig_session_id, uuid4().hex[-8:]) + return super().get_session(manager, session_id) + + def __init__(self, manager: SessionManager, session: Session, request: web.Request): + super().__init__(manager, session, request) + self._pong_event = asyncio.Event() + self._wait_pong_task: Optional[asyncio.Task] = None + + async def server(self, ws: web.WebSocketResponse): while True: try: - frame, data = await session._wait(pack=False) + frame, data = await self.session._get_frame(pack=False) except SessionIsClosed: break @@ -27,15 +53,32 @@ async def server(self, ws, session): await ws.send_str(data) elif frame == FRAME_HEARTBEAT: await ws.ping() + if self._wait_pong_task is None: + self._wait_pong_task = asyncio.create_task(self._wait_pong()) + self._wait_pong_task.add_done_callback(self._wait_done_callback) elif frame == FRAME_CLOSE: try: - await ws.close(message="Go away!") + await ws.close(message=b"Go away!") finally: - await session._remote_closed() + await self.session._remote_closed() + + async def _wait_pong(self): + try: + async with timeout(self.heartbeat_timeout): + await self._pong_event.wait() + except asyncio.TimeoutError: + self.session.close(3000, "No response from heartbeat") + finally: + self._pong_event.clear() + + def _wait_done_callback(self, _): + self._wait_pong_task = None - async def client(self, ws, session): + async def client(self, ws: web.WebSocketResponse): while True: msg = await ws.receive() + if self._wait_pong_task is not None: + self._pong_event.set() if msg.type == web.WSMsgType.text: if not msg.data: @@ -48,31 +91,35 @@ async def client(self, ws, session): break elif msg.type == web.WSMsgType.PONG: self.session._tick() + elif msg.type == web.WSMsgType.PING: + await ws.pong(msg.data) + self.session._tick() async def process(self): # start websocket connection - ws = self.ws = web.WebSocketResponse(autoping=False) + ws = web.WebSocketResponse(autoping=False) await ws.prepare(self.request) try: - await self.manager.acquire(self.session) + await self.manager.acquire(self.session, self.request) except Exception: # should use specific exception - await ws.close(message="Go away!") + await ws.close(message=b"Go away!") return ws - server = ensure_future(self.server(ws, self.session)) - client = ensure_future(self.client(ws, self.session)) + server = ensure_future(self.server(ws)) + client = ensure_future(self.client(ws)) try: - await asyncio.wait((server, client), return_when=asyncio.FIRST_COMPLETED) + await asyncio.wait( + (server, client), + return_when=asyncio.FIRST_COMPLETED, + ) except asyncio.CancelledError: raise except Exception as exc: await self.session._remote_close(exc) finally: + self.session.expire() await self.manager.release(self.session) - if not server.done(): - server.cancel() - if not client.done(): - client.cancel() + await cancel_tasks(server, client, self._wait_pong_task) return ws diff --git a/sockjs/transports/utils.py b/sockjs/transports/utils.py index b70da21e..1b00183e 100644 --- a/sockjs/transports/utils.py +++ b/sockjs/transports/utils.py @@ -1,7 +1,10 @@ +import asyncio import http.cookies -from aiohttp import hdrs from datetime import datetime, timedelta +import async_timeout +from aiohttp import hdrs + CACHE_CONTROL = "no-store, no-cache, no-transform, must-revalidate, max-age=0" @@ -30,7 +33,7 @@ def session_cookie(request): td365 = timedelta(days=365) td365seconds = str( - (td365.microseconds + (td365.seconds + td365.days * 24 * 3600) * 10 ** 6) // 10 ** 6 + (td365.microseconds + (td365.seconds + td365.days * 24 * 3600) * 10**6) // 10**6 ) @@ -41,3 +44,22 @@ def cache_headers(): (hdrs.CACHE_CONTROL, "max-age=%s, public" % td365seconds), (hdrs.EXPIRES, d.strftime("%a, %d %b %Y %H:%M:%S")), ) + + +async def cancel_tasks(*coros_or_futures, timeout=1): + """Cancel all not stopped coroutine or feature before exit + from this context manager. + """ + futures = [asyncio.ensure_future(cf) for cf in coros_or_futures if cf] + waiting_to_complete = [] + for fut in futures: + if not fut.cancelled() and fut.done(): + continue + fut.cancel() + waiting_to_complete.append(fut) + if waiting_to_complete: + try: + async with async_timeout.timeout(timeout): + await asyncio.gather(*waiting_to_complete, return_exceptions=True) + except asyncio.TimeoutError: + pass diff --git a/sockjs/transports/websocket.py b/sockjs/transports/websocket.py index 6d5d3622..fb2af2bb 100644 --- a/sockjs/transports/websocket.py +++ b/sockjs/transports/websocket.py @@ -1,35 +1,86 @@ """websocket transport""" import asyncio -from aiohttp import web - +import logging from asyncio import ensure_future +from typing import Optional +from uuid import uuid4 + +from aiohttp import web +from aiohttp.web_exceptions import HTTPMethodNotAllowed +from async_timeout import timeout from .base import Transport +from .utils import cancel_tasks from ..exceptions import SessionIsClosed -from ..protocol import STATE_CLOSED, FRAME_CLOSE -from ..protocol import loads, close_frame +from ..protocol import FRAME_CLOSE, FRAME_HEARTBEAT, STATE_CLOSED, close_frame, loads +from ..session import Session, SessionManager + + +log = logging.getLogger("sockjs") class WebSocketTransport(Transport): - async def server(self, ws, session): + heartbeat_timeout = 10 + + @classmethod + def get_session(cls, manager: SessionManager, session_id: str) -> Session: + # For WebSockets, as opposed to other transports, it is valid to + # reuse `session_id`. The lifetime of SockJS WebSocket session is + # defined by a lifetime of underlying WebSocket connection. It is + # correct to have two separate sessions sharing the same + # `session_id` at the same time. + + # Generate unique session_id based on given ID. + orig_session_id = session_id + while session_id in manager: + session_id = "%s-%s" % (orig_session_id, uuid4().hex[-8:]) + return super().get_session(manager, session_id) + + def __init__(self, manager: SessionManager, session: Session, request: web.Request): + super().__init__(manager, session, request) + self._pong_event = asyncio.Event() + self._wait_pong_task: Optional[asyncio.Task] = None + + async def server(self, ws: web.WebSocketResponse): while True: try: - frame, data = await session._wait() + frame, data = await self.session._get_frame() except SessionIsClosed: break - try: - await ws.send_str(data) - except OSError: - pass # ignore 'cannot write to closed transport' + + if frame == FRAME_HEARTBEAT: + await ws.ping() + log.debug("Send WS PING") + if self._wait_pong_task is None: + self._wait_pong_task = asyncio.create_task(self._wait_pong()) + self._wait_pong_task.add_done_callback(self._wait_done_callback) + continue + + await ws.send_str(data) + if frame == FRAME_CLOSE: try: await ws.close() finally: - await session._remote_closed() + await self.session._remote_closed() + + async def _wait_pong(self): + try: + async with timeout(self.heartbeat_timeout): + await self._pong_event.wait() + except asyncio.TimeoutError: + self.session.close(3000, "No response from heartbeat") + finally: + self._pong_event.clear() - async def client(self, ws, session): + def _wait_done_callback(self, _): + self._wait_pong_task = None + + async def client(self, ws: web.WebSocketResponse): while True: msg = await ws.receive() + if self._wait_pong_task is not None: + self._pong_event.set() if msg.type == web.WSMsgType.text: data = msg.data @@ -39,43 +90,56 @@ async def client(self, ws, session): try: text = loads(data) except Exception as exc: - await session._remote_close(exc) - await session._remote_closed() + await self.session._remote_close(exc) + await self.session._remote_closed() await ws.close(message=b"broken json") break if data.startswith("["): - await session._remote_messages(text) + await self.session._remote_messages(text) else: - await session._remote_message(text) - + await self.session._remote_message(text) + elif msg.type == web.WSMsgType.PONG: + log.debug("Received WS PONG") + self.session._tick() + elif msg.type == web.WSMsgType.PING: + log.debug("Received WS PING") + await ws.pong(msg.data) + self.session._tick() elif msg.type == web.WSMsgType.close: - await session._remote_close() + await self.session._remote_close() elif msg.type in (web.WSMsgType.closed, web.WSMsgType.closing): - await session._remote_closed() + await self.session._remote_closed() break async def process(self): + if self.request.method != "GET": + # WebSocket should only accept GET + raise HTTPMethodNotAllowed( + self.request.method, + ["GET"], + body=b"", + content_type="", + ) + # start websocket connection - ws = self.ws = web.WebSocketResponse() + ws = web.WebSocketResponse(autoping=False) await ws.prepare(self.request) # session was interrupted if self.session.interrupted: - await self.ws.send_str(close_frame(1002, "Connection interrupted")) - + await ws.send_str(close_frame(1002, "Connection interrupted")) elif self.session.state == STATE_CLOSED: - await self.ws.send_str(close_frame(3000, "Go away!")) - + await ws.send_str(close_frame(3000, "Go away!")) else: try: - await self.manager.acquire(self.session) + await self.manager.acquire(self.session, self.request) except Exception: # should use specific exception - await self.ws.send_str(close_frame(3000, "Go away!")) + await ws.send_str(close_frame(3000, "Go away!")) await ws.close() return ws - server = ensure_future(self.server(ws, self.session)) - client = ensure_future(self.client(ws, self.session)) + server = ensure_future(self.server(ws)) + client = ensure_future(self.client(ws)) try: await asyncio.wait( (server, client), return_when=asyncio.FIRST_COMPLETED @@ -85,10 +149,8 @@ async def process(self): except Exception as exc: await self.session._remote_close(exc) finally: + self.session.expire() await self.manager.release(self.session) - if not server.done(): - server.cancel() - if not client.done(): - client.cancel() + await cancel_tasks(server, client, self._wait_pong_task) return ws diff --git a/sockjs/transports/xhr.py b/sockjs/transports/xhr.py index 14b5df79..a57b428a 100644 --- a/sockjs/transports/xhr.py +++ b/sockjs/transports/xhr.py @@ -1,15 +1,19 @@ -from aiohttp import web, hdrs +from aiohttp import hdrs, web from .base import StreamingTransport -from .utils import CACHE_CONTROL, session_cookie, cors_headers, cache_headers +from .utils import CACHE_CONTROL, cache_headers, cors_headers, session_cookie class XHRTransport(StreamingTransport): """Long polling derivative transports, used for XHRPolling and JSONPolling.""" + create_session = True maxsize = 0 + async def _send(self, text: str): + return await super()._send(text + "\n") + async def process(self): request = self.request diff --git a/sockjs/transports/xhrsend.py b/sockjs/transports/xhrsend.py index 2a512eea..5a53391d 100644 --- a/sockjs/transports/xhrsend.py +++ b/sockjs/transports/xhrsend.py @@ -1,11 +1,13 @@ -from aiohttp import web, hdrs +from aiohttp import hdrs, web -from ..protocol import loads, ENCODING from .base import Transport -from .utils import CACHE_CONTROL, session_cookie, cors_headers, cache_headers +from .utils import CACHE_CONTROL, cache_headers, cors_headers, session_cookie +from ..protocol import ENCODING, loads class XHRSendTransport(Transport): + create_session = False + async def process(self): request = self.request diff --git a/sockjs/transports/xhrstreaming.py b/sockjs/transports/xhrstreaming.py index bd15c2d6..a51af032 100644 --- a/sockjs/transports/xhrstreaming.py +++ b/sockjs/transports/xhrstreaming.py @@ -1,14 +1,16 @@ -from aiohttp import web, hdrs +from aiohttp import hdrs, web from .base import StreamingTransport -from .utils import CACHE_CONTROL, session_cookie, cors_headers, cache_headers +from .utils import CACHE_CONTROL, cache_headers, cors_headers, session_cookie class XHRStreamingTransport(StreamingTransport): - - maxsize = 131072 # 128K bytes + create_session = True open_seq = b"h" * 2048 + b"\n" + async def _send(self, text: str): + return await super()._send(text + "\n") + async def process(self): request = self.request headers = ( diff --git a/tests/asdf b/tests/asdf deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/conftest.py b/tests/conftest.py index 6802d52f..66a28742 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,10 @@ import asyncio - -from datetime import timedelta from unittest import mock import pytest - from aiohttp import web +from aiohttp.test_utils import make_mocked_coro, make_mocked_request from aiohttp.web_urldispatcher import UrlMappingMatchInfo -from aiohttp.test_utils import make_mocked_request, make_mocked_coro from multidict import CIMultiDict from yarl import URL @@ -15,16 +12,15 @@ from sockjs.route import SockJSRoute -@pytest.fixture -def app(): +@pytest.fixture(name="app") +def app_fixture(): return web.Application() @pytest.fixture def make_fut(): def maker(val, makemock=True): - loop = asyncio.get_event_loop() - fut = loop.create_future() + fut = asyncio.Future() fut.set_result(val) if makemock: @@ -39,41 +35,24 @@ def maker(val, makemock=True): @pytest.fixture def make_handler(): - def maker(result, coro=True, exc=False): + def maker(result, exc=False): if result is None: result = [] output = result - def handler(msg, s): + async def handler(msg, s): if exc: raise ValueError((msg, s)) output.append((msg, s)) - if coro: - - async def async_handler(msg, s): - return handler(msg, s) - - return async_handler - else: - return handler - - return maker - - -@pytest.fixture -def make_route(make_handler, app): - def maker(handlers=transports.handlers): - handler = make_handler([]) - sm = SessionManager("sm", app, handler) - return SockJSRoute("sm", sm, "http:sockjs-cdn", handlers, (), True) + return handler return maker @pytest.fixture def make_request(app): - def maker(method, path, query_params={}, headers=None, match_info=None): + def maker(method, path, query_params=None, headers=None, match_info=None): path = URL(path) if query_params: path = path.with_query(query_params) @@ -98,7 +77,9 @@ def maker(method, path, query_params={}, headers=None, match_info=None): transport = mock.Mock() transport._drain_helper = make_mocked_coro() loop = asyncio.get_event_loop() - ret = make_mocked_request(method, str(path), headers, writer=writer, loop=loop) + ret = make_mocked_request( + method, str(path), headers, writer=writer, transport=transport, loop=loop + ) if match_info is None: match_info = UrlMappingMatchInfo({}, mock.Mock()) @@ -111,24 +92,36 @@ def maker(method, path, query_params={}, headers=None, match_info=None): @pytest.fixture def make_session(make_handler, make_request): - def maker( - name="test", timeout=timedelta(10), request=None, handler=None, result=None - ): - if request is None: - request = make_request("GET", "/TestPath/") - + def maker(name="test", disconnect_delay=10, handler=None, result=None): if handler is None: handler = make_handler(result) - return Session(name, handler, request, timeout=timeout, debug=True) + return Session(name, handler, disconnect_delay=disconnect_delay, debug=True) return maker @pytest.fixture -def make_manager(app, make_handler, make_session): +async def make_manager(event_loop, app, make_handler, make_session): + managers = [] + def maker(handler=None): if handler is None: handler = make_handler([]) - return SessionManager("sm", app, handler, debug=True) + manager = SessionManager("sm", app, handler, debug=True) + managers.append(manager) + return manager + + yield maker + + for sm in managers: + await sm.stop() + + +@pytest.fixture +def make_route(make_manager, make_handler, app): + def maker(handlers=transports.transport_handlers): + sm = make_manager() + app.on_cleanup.append(sm.stop) + return SockJSRoute("sm", sm, "http:sockjs-cdn", handlers, (), True) return maker diff --git a/tests/test_route.py b/tests/test_route.py index c0c390a7..eb2a4818 100644 --- a/tests/test_route.py +++ b/tests/test_route.py @@ -4,6 +4,7 @@ from multidict import CIMultiDict from sockjs import protocol +from sockjs.transports.base import Transport async def test_info(make_route, make_request): @@ -62,8 +63,8 @@ async def test_iframe(make_route, make_request): text = """ - - + + \r\n') assert not stop assert trans.size == len(b'\r\n') trans.maxsize = 1 - stop = trans.send("text data") + stop = await trans._send("text data") assert stop @@ -63,7 +62,7 @@ async def test_process_bad_callback(make_transport, make_fut): assert resp.status == 500 -async def test_session_has_request(make_transport, make_fut): - transp = make_transport(method="POST") - transp.session._remote_messages = make_fut(1) - assert isinstance(transp.session.request, web.Request) +# async def test_session_has_request(make_transport, make_fut): +# transp = make_transport(method="POST") +# transp.session._remote_messages = make_fut(1) +# assert isinstance(transp.session.request, web.Request) diff --git a/tests/test_transport_jsonp.py b/tests/test_transport_jsonp.py index 449f213c..44cdf9a2 100644 --- a/tests/test_transport_jsonp.py +++ b/tests/test_transport_jsonp.py @@ -1,7 +1,5 @@ from unittest import mock -from aiohttp import web - import pytest from aiohttp.test_utils import make_mocked_coro @@ -10,11 +8,11 @@ @pytest.fixture def make_transport(make_request, make_manager, make_handler, make_fut): - def maker(method="GET", path="/", query_params={}): + def maker(method="GET", path="/", query_params=None): handler = make_handler(None) manager = make_manager(handler) request = make_request(method, path, query_params=query_params) - session = manager.get("TestSessionJsonP", create=True, request=request) + session = manager.get("TestSessionJsonP", create=True) request.app.freeze() return jsonp.JSONPolling(manager, session, request) @@ -27,7 +25,7 @@ async def test_streaming_send(make_transport): resp = trans.response = mock.Mock() resp.write = make_mocked_coro(None) - stop = await trans.send("text data") + stop = await trans._send("text data") resp.write.assert_called_with(b'/**/cb("text data");\r\n') assert stop @@ -100,7 +98,7 @@ async def xtest_process_message(make_transport, make_fut): transp.session._remote_messages.assert_called_with(["msg1", "msg2"]) -async def test_session_has_request(make_transport, make_fut): - transp = make_transport(method="POST") - transp.session._remote_messages = make_fut(1) - assert isinstance(transp.session.request, web.Request) +# async def test_session_has_request(make_transport, make_fut): +# transp = make_transport(method="POST") +# transp.session._remote_messages = make_fut(1) +# assert isinstance(transp.session.request, web.Request) diff --git a/tests/test_transport_rawwebsocket.py b/tests/test_transport_rawwebsocket.py index bac60964..d488b76a 100644 --- a/tests/test_transport_rawwebsocket.py +++ b/tests/test_transport_rawwebsocket.py @@ -1,22 +1,21 @@ -from unittest import mock from asyncio import Future +from unittest import mock import pytest +from aiohttp import WSMessage, WSMsgType from sockjs.exceptions import SessionIsClosed from sockjs.protocol import FRAME_CLOSE, FRAME_HEARTBEAT from sockjs.transports.rawwebsocket import RawWebSocketTransport -from aiohttp import WSMessage, WSMsgType - @pytest.fixture def make_transport(make_request, make_fut): - def maker(method="GET", path="/", query_params={}): + def maker(method="GET", path="/", query_params=None): manager = mock.Mock() session = mock.Mock() session._remote_closed = make_fut(1) - session._wait = make_fut((FRAME_CLOSE, "")) + session._get_frame = make_fut((FRAME_CLOSE, "")) request = make_request(method, path, query_params=query_params) request.app.freeze() return RawWebSocketTransport(manager, session, request) @@ -61,7 +60,8 @@ async def test_sends_ping(make_transport, make_fut): session_close_future.set_exception(SessionIsClosed) session = mock.Mock() - session._wait.side_effect = [hb_future, session_close_future] + session._get_frame.side_effect = [hb_future, session_close_future] + transp.session = session - await transp.server(ws, session) + await transp.server(ws) assert ws.ping.called diff --git a/tests/test_transport_websocket.py b/tests/test_transport_websocket.py index 2b1ef593..c9c39406 100644 --- a/tests/test_transport_websocket.py +++ b/tests/test_transport_websocket.py @@ -1,27 +1,26 @@ import asyncio +import datetime from asyncio import Future from unittest import mock -from aiohttp import web, WSMessage, WSMsgType - import pytest - +from aiohttp import WSMessage, WSMsgType from aiohttp.test_utils import make_mocked_coro -from sockjs import SessionManager, MSG_OPEN, MSG_CLOSED, MSG_MESSAGE -from sockjs.protocol import FRAME_CLOSE +from sockjs import MSG_CLOSED, MSG_MESSAGE, MSG_OPEN, Session +from sockjs.protocol import FRAME_CLOSE, SockjsMessage from sockjs.transports import WebSocketTransport @pytest.fixture def make_transport(make_manager, make_request, make_handler, make_fut): - def maker(method="GET", path="/", query_params={}, handler=None): + def maker(method="GET", path="/", query_params=None, handler=None): handler = handler or make_handler(None) manager = make_manager(handler) request = make_request(method, path, query_params=query_params) request.app.freeze() - session = manager.get("TestSessionWebsocket", create=True, request=request) - session._wait = make_fut((FRAME_CLOSE, "")) + session = manager.get("TestSessionWebsocket", create=True) + session._get_frame = make_fut((FRAME_CLOSE, "")) return WebSocketTransport(manager, session, request) return maker @@ -47,35 +46,35 @@ async def xtest_process_release_acquire_and_remote_closed(make_transport): async def test_server_close(app, make_manager, make_request): reached_closed = False - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() - async def handler(msg, session): + async def handler(msg: SockjsMessage, session: Session): nonlocal reached_closed - if msg.tp == MSG_OPEN: - asyncio.ensure_future(session._remote_message("TESTMSG")) - pass - - elif msg.tp == MSG_MESSAGE: + if msg.type == MSG_OPEN: + # To reproduce the ordering which makes the issue + loop.call_later(0.05, session.close) + elif msg.type == MSG_MESSAGE: # To reproduce the ordering which makes the issue loop.call_later(0.05, session.close) - elif msg.tp == MSG_CLOSED: + elif msg.type == MSG_CLOSED: reached_closed = True app.freeze() - request = make_request("GET", "/", query_params={}) - manager = SessionManager("sm", app, handler, debug=True) + request = make_request("GET", "/") + manager = make_manager(handler) session = manager.get("test", create=True) transp = WebSocketTransport(manager, session, request) await transp.process() - assert reached_closed is True - + assert reached_closed is False + assert session.expires + assert not session.expired + session.expires = datetime.datetime.now() + await manager._gc_expired_sessions() -async def test_session_has_request(make_transport, make_fut): - transp = make_transport(method="POST") - assert isinstance(transp.session.request, web.Request) + assert reached_closed is True async def test_frames(make_transport, make_handler): @@ -111,8 +110,7 @@ async def test_frames(make_transport, make_handler): close_frame, ] - session = transp.session - await transp.client(ws, session) + await transp.client(ws) assert result[0][0].data == "single_msg" assert result[1][0].data == "msg1" diff --git a/tests/test_transport_xhr.py b/tests/test_transport_xhr.py index f6aa25b9..2bf8e6cc 100644 --- a/tests/test_transport_xhr.py +++ b/tests/test_transport_xhr.py @@ -1,5 +1,3 @@ -from aiohttp import web - import pytest from sockjs.transports import xhr @@ -7,12 +5,12 @@ @pytest.fixture def make_transport(make_manager, make_request, make_handler, make_fut): - def maker(method="GET", path="/", query_params={}): + def maker(method="GET", path="/", query_params=None): handler = make_handler(None) manager = make_manager(handler) request = make_request(method, path, query_params=query_params) request.app.freeze() - session = manager.get("TestSessionXhr", create=True, request=request) + session = manager.get("TestSessionXhr", create=True) return xhr.XHRTransport(manager, session, request) return maker @@ -32,7 +30,7 @@ async def test_process_OPTIONS(make_transport): assert resp.status == 204 -async def test_session_has_request(make_transport, make_fut): - transp = make_transport() - transp.session._remote_messages = make_fut(1) - assert isinstance(transp.session.request, web.Request) +# async def test_session_has_request(make_transport, make_fut): +# transp = make_transport() +# transp.session._remote_messages = make_fut(1) +# assert isinstance(transp.session.request, web.Request) diff --git a/tests/test_transport_xhrsend.py b/tests/test_transport_xhrsend.py index 5ad64d5a..d9d53792 100644 --- a/tests/test_transport_xhrsend.py +++ b/tests/test_transport_xhrsend.py @@ -1,5 +1,3 @@ -from aiohttp import web - import pytest from sockjs.transports import xhrsend @@ -7,12 +5,12 @@ @pytest.fixture def make_transport(make_manager, make_request, make_handler, make_fut): - def maker(method="GET", path="/", query_params={}): + def maker(method="GET", path="/", query_params=None): handler = make_handler(None) manager = make_manager(handler) request = make_request(method, path, query_params=query_params) request.app.freeze() - session = manager.get("TestSessionXhrSend", create=True, request=request) + session = manager.get("TestSessionXhrSend", create=True) return xhrsend.XHRSendTransport(manager, session, request) return maker @@ -53,7 +51,7 @@ async def test_OPTIONS(make_transport): assert resp.status == 204 -async def test_session_has_request(make_transport, make_fut): - transp = make_transport(method="POST") - transp.session._remote_messages = make_fut(1) - assert isinstance(transp.session.request, web.Request) +# async def test_session_has_request(make_transport, make_fut): +# transp = make_transport(method="POST") +# transp.session._remote_messages = make_fut(1) +# assert isinstance(transp.session.request, web.Request) diff --git a/tests/test_transport_xhrstreaming.py b/tests/test_transport_xhrstreaming.py index 4f27621a..a3894d4a 100644 --- a/tests/test_transport_xhrstreaming.py +++ b/tests/test_transport_xhrstreaming.py @@ -1,5 +1,3 @@ -from aiohttp import web - import pytest from sockjs.transports import xhrstreaming @@ -7,12 +5,12 @@ @pytest.fixture def make_transport(make_manager, make_request, make_handler, make_fut): - def maker(method="GET", path="/", query_params={}): + def maker(method="GET", path="/", query_params=None): handler = make_handler(None) manager = make_manager(handler) request = make_request(method, path, query_params=query_params) request.app.freeze() - session = manager.get("TestSessionXhrStreaming", create=True, request=request) + session = manager.get("TestSessionXhrStreaming", create=True) return xhrstreaming.XHRStreamingTransport(manager, session, request) return maker @@ -32,7 +30,7 @@ async def test_process_OPTIONS(make_transport): assert resp.status == 204 -async def test_session_has_request(make_transport, make_fut): - transp = make_transport(method="POST") - transp.session._remote_messages = make_fut(1) - assert isinstance(transp.session.request, web.Request) +# async def test_session_has_request(make_transport, make_fut): +# transp = make_transport(method="POST") +# transp.session._remote_messages = make_fut(1) +# assert isinstance(transp.session.request, web.Request) From 6b277201d339fd285d9529914af7a77a6f3a5c05 Mon Sep 17 00:00:00 2001 From: Kirill Kuzminykh Date: Tue, 8 Feb 2022 14:20:41 +0300 Subject: [PATCH 02/23] Deleted method ``SessionManager.route_url()``. --- CHANGES.rst | 1 + sockjs/session.py | 28 ++++++++++++---------------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 00392318..01ea0a3f 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -15,6 +15,7 @@ CHANGES - **Breaking change:** Argument ``heartbeat`` of ``SessionManager.__init__()`` renamed into ``heartbeat_delay``. - **Breaking change:** ``Session.registry`` renamed into ``Session.app``. +- **Breaking change:** Deleted method ``SessionManager.route_url()``. - **Breaking change:** Dropped support of Python < 3.7 0.11.0 (2020-10-22) diff --git a/sockjs/session.py b/sockjs/session.py index 72178d2c..fb35899c 100644 --- a/sockjs/session.py +++ b/sockjs/session.py @@ -2,7 +2,6 @@ import collections import logging import warnings -from asyncio import ensure_future from datetime import datetime, timedelta from typing import List, Optional, Tuple @@ -158,7 +157,6 @@ def _heartbeat(self): self._heartbeats += 1 if self._send_heartbeats: self._feed(FRAME_HEARTBEAT, FRAME_HEARTBEAT) - log.debug("heartbeat sent: %s", self.id) def _feed(self, frame, data): # pack messages @@ -319,10 +317,6 @@ def __init__( self.disconnect_delay = disconnect_delay self.debug = debug - def route_url(self, request: web.Request): - url = self.app.router[self.route_name].url_for() - return "%s://%s%s" % (request.scheme, request.host, url) - @property def started(self): return self._hb_task is not None @@ -340,10 +334,6 @@ async def stop(self, _app=None): self._hb_task = None await self.clear() - def _heartbeat(self): - if self._hb_task is None: - self._hb_task = ensure_future(self._heartbeat_task()) - async def _check_expiration(self, session: Session): if session.expired: log.debug("session expired: %s", session.id) @@ -371,14 +361,20 @@ async def _gc_expired_sessions(self): del sessions[idx] async def _heartbeat_task(self): + delay = min(self.heartbeat_delay, self.disconnect_delay) + if delay <= 0: + delay = max(self.heartbeat_delay, self.disconnect_delay, 10) while True: - await asyncio.sleep(self.heartbeat_delay) + await asyncio.sleep(delay) await self._gc_expired_sessions() - # Send heartbeat - now = datetime.now() - for session in self.sessions: - if session.next_heartbeat <= now: - session._heartbeat() + self._heartbeat() + + def _heartbeat(self): + # Send heartbeat + now = datetime.now() + for session in self.sessions: + if session.next_heartbeat <= now: + session._heartbeat() def _add(self, session: Session): if session.expired: From 2b89b0e3035e6d5e191897aab40698c18bb1e52c Mon Sep 17 00:00:00 2001 From: Kirill Kuzminykh Date: Tue, 8 Feb 2022 14:56:15 +0300 Subject: [PATCH 03/23] Added release date into CHANGES.rst --- CHANGES.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 01ea0a3f..fd4d1607 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -2,8 +2,8 @@ CHANGES ======= -0.12.0 (not released yet) -------------------------- +0.12.0 (2022-02-08) +------------------- - Fixed processing of heartbeats and a session expiration. - Fixed ping-pong based heartbeats for web-socket connections. From 7875007ccc10e90aea73549fc19af070db55c22d Mon Sep 17 00:00:00 2001 From: Kirill Kuzminykh Date: Wed, 9 Feb 2022 15:54:13 +0300 Subject: [PATCH 04/23] Removed commented code from tests and fixed CHANGES.rst --- CHANGES.rst | 10 +++++----- requirements.txt | 3 ++- tests/test_transport.py | 9 --------- tests/test_transport_eventsource.py | 6 ------ tests/test_transport_htmlfile.py | 6 ------ tests/test_transport_jsonp.py | 6 ------ tests/test_transport_xhr.py | 6 ------ tests/test_transport_xhrsend.py | 6 ------ tests/test_transport_xhrstreaming.py | 6 ------ 9 files changed, 7 insertions(+), 51 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index fd4d1607..2a7a2c87 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -5,11 +5,6 @@ CHANGES 0.12.0 (2022-02-08) ------------------- -- Fixed processing of heartbeats and a session expiration. -- Fixed ping-pong based heartbeats for web-socket connections. -- Added arguments ``heartbeat_delay`` and ``disconnect_delay`` into - ``Session.__init__()``. -- Added argument ``disconnect_delay`` into ``SessionManager.__init__()``. - **Breaking change:** Removed argument ``timeout`` from ``Session.__init__()`` and ``SessionManager.__init__()``. - **Breaking change:** Argument ``heartbeat`` of ``SessionManager.__init__()`` @@ -17,6 +12,11 @@ CHANGES - **Breaking change:** ``Session.registry`` renamed into ``Session.app``. - **Breaking change:** Deleted method ``SessionManager.route_url()``. - **Breaking change:** Dropped support of Python < 3.7 +- Fixed processing of heartbeats and a session expiration. +- Fixed ping-pong based heartbeats for web-socket connections. +- Added arguments ``heartbeat_delay`` and ``disconnect_delay`` into + ``Session.__init__()``. +- Added argument ``disconnect_delay`` into ``SessionManager.__init__()``. 0.11.0 (2020-10-22) ------------------- diff --git a/requirements.txt b/requirements.txt index 8ec0c149..d3e1bb37 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -black==22.1.0; python_version>="3.6" +black==22.1.0 flake8==4.0.1 pytest==7.0.0 pytest-aiohttp==1.0.3 @@ -7,4 +7,5 @@ pytest-sugar==0.9.4 pytest-mock==3.7.0 pytest-timeout==2.1.0 aiohttp==3.8.1 +twine==3.8.0 -e .[test] diff --git a/tests/test_transport.py b/tests/test_transport.py index 19dd8cf6..bd30e6f0 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -75,12 +75,3 @@ async def test_handle_session_closed(make_transport, make_fut): await trans.handle_session() trans.session._remote_closed.assert_called_with() trans._send.assert_called_with('c[3000,"Go away!"]') - - -# async def test_session_has_request(make_transport, make_fut): -# transp = make_transport(method="POST") -# session = transp.session -# session._remote_messages = make_fut(1) -# assert session.request is None -# await transp.process() -# assert isinstance(session.request, web.Request) diff --git a/tests/test_transport_eventsource.py b/tests/test_transport_eventsource.py index e19b7d6d..4906f668 100644 --- a/tests/test_transport_eventsource.py +++ b/tests/test_transport_eventsource.py @@ -40,9 +40,3 @@ async def test_process(make_transport, make_fut): resp = await transp.process() assert transp.handle_session.called assert resp.status == 200 - - -# async def test_session_has_request(make_transport, make_fut): -# transp = make_transport(method="POST") -# transp.session._remote_messages = make_fut(1) -# assert isinstance(transp.session.request, web.Request) diff --git a/tests/test_transport_htmlfile.py b/tests/test_transport_htmlfile.py index cfc9dba7..c03c5f11 100644 --- a/tests/test_transport_htmlfile.py +++ b/tests/test_transport_htmlfile.py @@ -60,9 +60,3 @@ async def test_process_bad_callback(make_transport, make_fut): resp = await transp.process() assert transp.session._remote_closed.called assert resp.status == 500 - - -# async def test_session_has_request(make_transport, make_fut): -# transp = make_transport(method="POST") -# transp.session._remote_messages = make_fut(1) -# assert isinstance(transp.session.request, web.Request) diff --git a/tests/test_transport_jsonp.py b/tests/test_transport_jsonp.py index 44cdf9a2..d51208f5 100644 --- a/tests/test_transport_jsonp.py +++ b/tests/test_transport_jsonp.py @@ -96,9 +96,3 @@ async def xtest_process_message(make_transport, make_fut): resp = await transp.process() assert resp.status == 200 transp.session._remote_messages.assert_called_with(["msg1", "msg2"]) - - -# async def test_session_has_request(make_transport, make_fut): -# transp = make_transport(method="POST") -# transp.session._remote_messages = make_fut(1) -# assert isinstance(transp.session.request, web.Request) diff --git a/tests/test_transport_xhr.py b/tests/test_transport_xhr.py index 2bf8e6cc..d52433cf 100644 --- a/tests/test_transport_xhr.py +++ b/tests/test_transport_xhr.py @@ -28,9 +28,3 @@ async def test_process_OPTIONS(make_transport): transp = make_transport(method="OPTIONS") resp = await transp.process() assert resp.status == 204 - - -# async def test_session_has_request(make_transport, make_fut): -# transp = make_transport() -# transp.session._remote_messages = make_fut(1) -# assert isinstance(transp.session.request, web.Request) diff --git a/tests/test_transport_xhrsend.py b/tests/test_transport_xhrsend.py index d9d53792..3302a87c 100644 --- a/tests/test_transport_xhrsend.py +++ b/tests/test_transport_xhrsend.py @@ -49,9 +49,3 @@ async def test_OPTIONS(make_transport): transp = make_transport(method="OPTIONS") resp = await transp.process() assert resp.status == 204 - - -# async def test_session_has_request(make_transport, make_fut): -# transp = make_transport(method="POST") -# transp.session._remote_messages = make_fut(1) -# assert isinstance(transp.session.request, web.Request) diff --git a/tests/test_transport_xhrstreaming.py b/tests/test_transport_xhrstreaming.py index a3894d4a..f413161f 100644 --- a/tests/test_transport_xhrstreaming.py +++ b/tests/test_transport_xhrstreaming.py @@ -28,9 +28,3 @@ async def test_process_OPTIONS(make_transport): transp = make_transport(method="OPTIONS") resp = await transp.process() assert resp.status == 204 - - -# async def test_session_has_request(make_transport, make_fut): -# transp = make_transport(method="POST") -# transp.session._remote_messages = make_fut(1) -# assert isinstance(transp.session.request, web.Request) From 04b78f3637f9394618a7b81d05c0c7602f2f69f8 Mon Sep 17 00:00:00 2001 From: Kirill Kuzminykh Date: Thu, 19 Jan 2023 01:14:21 +0400 Subject: [PATCH 05/23] - Added argument ``cors_config`` into function ``add_endpoint()`` to support of CORS settings from ``aiohttp_cors``. - Function ``add_endpoint()`` now returns all registered routes. - Replaced returning instances of error HTTP responses on raising its as exceptions. - Changed name of some routes. --- .coveragerc | 7 -- .github/workflows/check_and_test.yaml | 38 ++++++ .github/workflows/pythonpublish.yml | 6 +- .travis.yml | 36 ------ CHANGES.rst | 10 ++ Makefile | 13 +-- examples/chat.py | 20 +++- requirements.txt | 18 +-- setup.cfg | 2 +- setup.py | 9 +- sockjs/__init__.py | 2 +- sockjs/route.py | 160 +++++++++++++++----------- sockjs/transports/htmlfile.py | 4 +- sockjs/transports/jsonp.py | 12 +- sockjs/transports/xhrsend.py | 6 +- tests/conftest.py | 26 ++++- tests/test_route.py | 86 +++++++++----- tests/test_transport_htmlfile.py | 9 +- tests/test_transport_jsonp.py | 13 ++- tests/test_transport_xhrsend.py | 5 +- 20 files changed, 289 insertions(+), 193 deletions(-) delete mode 100644 .coveragerc create mode 100644 .github/workflows/check_and_test.yaml delete mode 100644 .travis.yml diff --git a/.coveragerc b/.coveragerc deleted file mode 100644 index 639d298c..00000000 --- a/.coveragerc +++ /dev/null @@ -1,7 +0,0 @@ -[run] -branch = True -source = sockjs, tests -omit = site-packages - -[html] -directory = coverage diff --git a/.github/workflows/check_and_test.yaml b/.github/workflows/check_and_test.yaml new file mode 100644 index 00000000..9eeac92b --- /dev/null +++ b/.github/workflows/check_and_test.yaml @@ -0,0 +1,38 @@ +name: Check and Test + +on: + push: + branches: [ "master" ] + pull_request: + branches: [ "master" ] + workflow_dispatch: {} + + +jobs: + run_tests: + strategy: + matrix: + python-version: [ "3.7", "3.8", "3.9", "3.10", "3.11" ] + + name: Test on Python {{ matrix.python-version }} + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + architecture: x64 + + - name: Install dependencies + run: | + - python -m pip install --upgrade setuptools + - python -m pip install -r requirements.txt + - python -m pip install codecov + - python -m pip install -e .[test] + + - name: Run checks and tests + run: | + - make cov + - pytest ./tests diff --git a/.github/workflows/pythonpublish.yml b/.github/workflows/pythonpublish.yml index b143a530..72fb8e7c 100644 --- a/.github/workflows/pythonpublish.yml +++ b/.github/workflows/pythonpublish.yml @@ -8,15 +8,15 @@ jobs: deploy: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: python-version: '3.x' - name: Install dependencies run: | python -m pip install --upgrade pip - pip install setuptools wheel twine + python -m pip install setuptools wheel twine - name: Build and publish env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index d0c97e62..00000000 --- a/.travis.yml +++ /dev/null @@ -1,36 +0,0 @@ -language: python - - -python: - - 3.7 - - 3.8 - - 3.9 - - 3.10 - - -install: - - pip install --upgrade setuptools - - pip install -r requirements.txt - - pip install codecov - - pip install -e .[test] - - -script: - - make cov - - python setup.py check -rms - - -after_success: - - codecov - - -deploy: - provider: pypi - user: aio-libs-bot - password: - secure: "En+U4kq5LKMrOQ2g8qGBLRbTPEfgNsLIcAfs264LjAtSb9Rc8g62wJ3s2Vob5f8k848ZcBmBh+K5M75gNIYZtP2wU1p4sHvTnz5bxSkdgN9A7gdgS8UdxwbIO6dUBGz3YTlXNcbRmiSbA8CL7M8ULswVtqTwpr864q80TBlj49+HbfszI2wb9UOu7kGW9i1qFtCRSagqe6V1MWhRc0H5nD8WmwzmlxMtllJfubDA4EpqAwAPPwfxYP7QQgo8L3e5CBcbDmLwvvjXkxrOpp6yae2003AWJXFNcygcpo2mt1BRe8/bXtIxXLET0djP7sj+3yu5XTksAI2JTcGVW9PZUr+NmhTeKTrAFQ+7qW+QNWQRRYlS5SokhOEoTRPRsH2D9kYKFy0wteBNNe2TD4o09KsqQHwY+Lpr5nlJUkX5HXBrHoGSQW9lNaQEq7nutpEwiTBCAdINjmfjKMxHIXfy93XfjR2wwGIoUr94i+yWG/zIkCKwr31s5CvfRbHmntU/jFTk6cqSTfzphk+7XEMWQlw8tRh55b641IY4/PMXqSXx8oNpoK4/lvrKG0KP4wSBBLjIOVSq46VPij3YQjnN2EzqECKess2D6Wrec3JaPukLtjCnOymbMq72BstnRI41THrL6bNpyUc7OkXL9NwoU6TNSXdMZoVA2lu6nRE6F+w=" - distributions: "sdist bdist_wheel" - on: - tags: true - all_branches: true - python: 3.7 diff --git a/CHANGES.rst b/CHANGES.rst index 2a7a2c87..716ac49b 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -2,6 +2,16 @@ CHANGES ======= +0.13.0 (not-released) +--------------------- + +- Added argument ``cors_config`` into function ``add_endpoint()`` + to support of CORS settings from ``aiohttp_cors``. +- Function ``add_endpoint()`` now returns all registered routes. +- Replaced returning instances of error HTTP responses + on raising its as exceptions. +- Changed name of some routes. + 0.12.0 (2022-02-08) ------------------- diff --git a/Makefile b/Makefile index ae5d3eb2..2db017a9 100644 --- a/Makefile +++ b/Makefile @@ -4,14 +4,7 @@ FLAGS= flake: -# python setup.py check -rms flake8 sockjs tests examples - if python -c "import sys; sys.exit(sys.version_info<(3,6))"; then \ - black --check sockjs tests setup.py; \ - fi - -fmt: - black sockjs tests setup.py develop: pip install -e .[test] @@ -22,10 +15,6 @@ test: flake develop vtest: flake develop pytest -s -v $(FLAGS) ./tests/ -cov cover coverage: flake develop - @py.test --cov=sockjs --cov-report=term --cov-report=html tests - @echo "open file://`pwd`/coverage/index.html" - clean: rm -rf `find . -name __pycache__` rm -f `find . -type f -name '*.py[co]' ` @@ -41,4 +30,4 @@ clean: rm -rf cover python setup.py clean -.PHONY: all build venv flake test vtest testloop cov clean +.PHONY: all flake test vtest clean diff --git a/examples/chat.py b/examples/chat.py index efba4801..ee5c660b 100644 --- a/examples/chat.py +++ b/examples/chat.py @@ -1,6 +1,7 @@ import logging import os +import aiohttp_cors from aiohttp import web import sockjs @@ -32,6 +33,23 @@ def index(request): app = web.Application() app.router.add_route('GET', '/', index) - sockjs.add_endpoint(app, chat_msg_handler, name='chat', prefix='/sockjs/') + + # Configure default CORS settings. + cors = aiohttp_cors.setup(app, defaults={ + '*': aiohttp_cors.ResourceOptions( + allow_credentials=True, + expose_headers='*', + allow_headers='*', + max_age=31536000, + ) + }) + + sockjs.add_endpoint( + app, + chat_msg_handler, + name='chat', + prefix='/sockjs/', + cors_config=cors, + ) web.run_app(app) diff --git a/requirements.txt b/requirements.txt index d3e1bb37..3856eab9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ -black==22.1.0 -flake8==4.0.1 -pytest==7.0.0 -pytest-aiohttp==1.0.3 -pytest-cov==3.0.0 -pytest-sugar==0.9.4 -pytest-mock==3.7.0 +black==22.12.0 +flake8==6.0.0 +pytest==7.2.1 +pytest-aiohttp==1.0.4 +pytest-cov==4.0.0 +pytest-sugar==0.9.6 +pytest-mock==3.10.0 pytest-timeout==2.1.0 -aiohttp==3.8.1 -twine==3.8.0 +aiohttp==3.8.3 +twine==4.0.2 -e .[test] diff --git a/setup.cfg b/setup.cfg index 222c4da9..f88d780e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,7 +3,7 @@ zip_ok = false [flake8] -ignore = N801,N802,N803,E226 +ignore = N801,N802,N803,E226,W503 max-line-length = 88 diff --git a/setup.py b/setup.py index 0ca1cfd6..b006585b 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,8 @@ import codecs import os import re -from setuptools import setup, find_packages + +from setuptools import find_packages, setup with codecs.open( @@ -34,6 +35,7 @@ def read(f): "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Programming Language :: Python :: Implementation :: CPython", "Topic :: Internet :: WWW/HTTP", "Framework :: AsyncIO", @@ -45,7 +47,7 @@ def read(f): packages=find_packages(), python_requires=">=3.7.0", install_requires=[ - "aiohttp>=3.0.0", + "aiohttp>=3.7.4", ], extras_require={ "test": [ @@ -54,6 +56,9 @@ def read(f): "yarl", "pytest-aiohttp", "pytest-timeout", + "pytest-mock", + "cykooz.testing", + 'aiohttp_cors', ], }, include_package_data=True, diff --git a/sockjs/__init__.py b/sockjs/__init__.py index ff2b35e8..31ab5d7e 100644 --- a/sockjs/__init__.py +++ b/sockjs/__init__.py @@ -13,7 +13,7 @@ from .session import Session, SessionManager -__version__ = "0.12.0" +__version__ = "0.13.0" __all__ = ( diff --git a/sockjs/route.py b/sockjs/route.py index b2b0bcdd..fd6941c7 100644 --- a/sockjs/route.py +++ b/sockjs/route.py @@ -4,10 +4,16 @@ import json import logging import random -from typing import Iterable, Type +from typing import Iterable, Optional, Type from aiohttp import hdrs, web + +try: + from aiohttp_cors import CorsConfig +except ImportError: + CorsConfig = None + from .protocol import IFRAME_HTML from .session import SessionManager from .transports import transport_handlers @@ -17,6 +23,7 @@ log = logging.getLogger("sockjs") +ALL_METH_WO_OPTIONS = hdrs.METH_ALL - {hdrs.METH_OPTIONS} def get_manager(name, app) -> SessionManager: @@ -28,19 +35,23 @@ def _gen_endpoint_name(): def add_endpoint( - app: web.Application, - handler, - *, - name="", - prefix="/sockjs", - manager=None, - disable_transports=(), - sockjs_cdn="https://cdn.jsdelivr.net/npm/sockjs-client@1/dist/sockjs.min.js", # noqa - cookie_needed=True -): + app: web.Application, + handler, + *, + name="", + prefix="/sockjs", + manager=None, + disable_transports=(), + sockjs_cdn="https://cdn.jsdelivr.net/npm/sockjs-client@1/dist/sockjs.min.js", # noqa + cookie_needed=True, + cors_config: Optional[CorsConfig] = None, +) -> list[web.AbstractRoute]: + registered_routes = [] + assert callable(handler), handler - if not asyncio.iscoroutinefunction(handler) and not inspect.isgeneratorfunction( - handler + if ( + not asyncio.iscoroutinefunction(handler) + and not inspect.isgeneratorfunction(handler) ): sync_handler = handler @@ -76,57 +87,80 @@ async def handler(msg, session): ) prefix = prefix.rstrip("/") - route_name = "sockjs-url-%s-greeting" % name - router.add_route(hdrs.METH_GET, prefix, route.greeting, name=route_name) + route_name = "sockjs-greeting-%s" % name + registered_routes.append( + router.add_route(hdrs.METH_GET, prefix, route.greeting, name=route_name) + ) - route_name = "sockjs-url-%s" % name - router.add_route(hdrs.METH_GET, "%s/" % prefix, route.greeting, name=route_name) + route_name = "sockjs-greeting-ts-%s" % name + registered_routes.append( + router.add_route(hdrs.METH_GET, "%s/" % prefix, route.greeting, name=route_name) + ) - route_name = "sockjs-%s" % name - router.add_route( - hdrs.METH_ANY, + resource = router.add_resource( "%s/{server}/{session}/{transport}" % prefix, - route.handler, - name=route_name, + name=f"sockjs-transport-{name}" ) + for method in ALL_METH_WO_OPTIONS: + registered_routes.append( + resource.add_route( + method, + route.handler, + ) + ) route_name = "sockjs-websocket-%s" % name - router.add_route( - hdrs.METH_GET, "%s/websocket" % prefix, route.websocket, name=route_name + registered_routes.append( + router.add_route( + hdrs.METH_GET, "%s/websocket" % prefix, route.websocket, name=route_name + ) ) - router.add_route( - hdrs.METH_GET, "%s/info" % prefix, route.info, name="sockjs-info-%s" % name - ) - router.add_route( - hdrs.METH_OPTIONS, - "%s/info" % prefix, - route.info_options, - name="sockjs-info-options-%s" % name, + registered_routes.append( + router.add_route( + hdrs.METH_GET, + "%s/info" % prefix, + route.info, + name="sockjs-info-%s" % name, + ) ) route_name = "sockjs-iframe-%s" % name - router.add_route( - hdrs.METH_GET, "%s/iframe.html" % prefix, route.iframe, name=route_name + registered_routes.append( + router.add_route( + hdrs.METH_GET, "%s/iframe.html" % prefix, route.iframe, name=route_name + ) ) route_name = "sockjs-iframe-ver-%s" % name - router.add_route( - hdrs.METH_GET, "%s/iframe{version}.html" % prefix, route.iframe, name=route_name + registered_routes.append( + router.add_route( + hdrs.METH_GET, + "%s/iframe{version}.html" % prefix, + route.iframe, + name=route_name + ) ) app.on_cleanup.append(manager.stop) + if cors_config is not None: + # Configure CORS on all routes. + for route in registered_routes: + cors_config.add(route) + + return registered_routes + class SockJSRoute: def __init__( - self, - name: str, - manager: SessionManager, - sockjs_cdn: str, - handlers, - disable_transports: Iterable[str], - cookie_needed=True, + self, + name: str, + manager: SessionManager, + sockjs_cdn: str, + handlers, + disable_transports: Iterable[str], + cookie_needed=True, ): self.name = name self.manager = manager @@ -135,6 +169,9 @@ def __init__( self.cookie_needed = cookie_needed self.iframe_html = (IFRAME_HTML % sockjs_cdn).encode("utf-8") self.iframe_html_hxd = hashlib.md5(self.iframe_html).hexdigest() + self._transport_names = sorted( + set(transport_handlers.keys()) - self.disable_transports + ) async def handler(self, request): info = request.match_info @@ -143,7 +180,7 @@ async def handler(self, request): tid = info["transport"] if tid not in self.handlers or tid in self.disable_transports: - return web.HTTPNotFound() + raise web.HTTPNotFound() transport: Type[Transport] = self.handlers[tid] @@ -154,12 +191,12 @@ async def handler(self, request): sid = info["session"] if not sid or "." in sid or "." in info["server"]: - return web.HTTPNotFound() + raise web.HTTPNotFound() try: session = transport.get_session(manager, sid) except KeyError: - return web.HTTPNotFound(headers=session_cookie(request)) + raise web.HTTPNotFound(headers=session_cookie(request)) t = transport(manager, session, request) try: @@ -167,12 +204,12 @@ async def handler(self, request): except asyncio.CancelledError: raise except web.HTTPException as exc: - return exc + raise exc except Exception: log.exception("Exception in transport: %s" % tid) if manager.is_acquired(session): await manager.release(session) - return web.HTTPInternalServerError() + raise web.HTTPInternalServerError() async def websocket(self, request): if not self.manager.started: @@ -188,7 +225,7 @@ async def websocket(self, request): except asyncio.CancelledError: raise except web.HTTPException as exc: - return exc + raise exc async def info(self, request): resp = web.Response() @@ -198,23 +235,14 @@ async def info(self, request): info = { "entropy": random.randint(1, 2147483647), - "websocket": "websocket" not in self.disable_transports, + "websocket": "websocket" in self._transport_names, "cookie_needed": self.cookie_needed, "origins": ["*:*"], + "transports": self._transport_names, } resp.text = json.dumps(info) return resp - async def info_options(self, request): - resp = web.Response(status=204) - resp.headers[hdrs.CONTENT_TYPE] = "application/json;charset=UTF-8" - resp.headers[hdrs.CACHE_CONTROL] = CACHE_CONTROL - resp.headers[hdrs.ACCESS_CONTROL_ALLOW_METHODS] = "OPTIONS, GET" - resp.headers.extend(cors_headers(request.headers)) - resp.headers.extend(cache_headers()) - resp.headers.extend(session_cookie(request)) - return resp - async def iframe(self, request): cached = request.headers.get(hdrs.IF_NONE_MATCH) if cached: @@ -223,15 +251,15 @@ async def iframe(self, request): response.headers.extend(cache_headers()) return response - headers = ( - (hdrs.CONTENT_TYPE, "text/html;charset=UTF-8"), - (hdrs.ETAG, self.iframe_html_hxd), - ) - headers += cache_headers() + headers = { + hdrs.CONTENT_TYPE: "text/html;charset=UTF-8", + hdrs.ETAG: self.iframe_html_hxd, + } + headers.update(dict(cache_headers())) return web.Response(body=self.iframe_html, headers=headers) async def greeting(self, request): return web.Response( body=b"Welcome to SockJS!\n", - headers=((hdrs.CONTENT_TYPE, "text/plain; charset=UTF-8"),), + headers={hdrs.CONTENT_TYPE: "text/plain; charset=UTF-8"}, ) diff --git a/sockjs/transports/htmlfile.py b/sockjs/transports/htmlfile.py index 40e8d73f..b70da18b 100644 --- a/sockjs/transports/htmlfile.py +++ b/sockjs/transports/htmlfile.py @@ -39,11 +39,11 @@ async def process(self): callback = request.query.get("c") if callback is None: await self.session._remote_closed() - return web.HTTPInternalServerError(text='"callback" parameter required') + raise web.HTTPInternalServerError(text='"callback" parameter required') elif not self.check_callback.match(callback): await self.session._remote_closed() - return web.HTTPInternalServerError(text='invalid "callback" parameter') + raise web.HTTPInternalServerError(text='invalid "callback" parameter') headers = ( (hdrs.CONTENT_TYPE, "text/html; charset=UTF-8"), diff --git a/sockjs/transports/jsonp.py b/sockjs/transports/jsonp.py index 6bd0c932..cc5600d0 100644 --- a/sockjs/transports/jsonp.py +++ b/sockjs/transports/jsonp.py @@ -28,11 +28,11 @@ async def process(self): callback = self.callback = request.query.get("c") if not callback: await self.session._remote_closed() - return web.HTTPInternalServerError(text='"callback" parameter required') + raise web.HTTPInternalServerError(text='"callback" parameter required') elif not self.check_callback.match(callback): await self.session._remote_closed() - return web.HTTPInternalServerError(text='invalid "callback" parameter') + raise web.HTTPInternalServerError(text='invalid "callback" parameter') headers = ( (hdrs.CONTENT_TYPE, "application/javascript; charset=UTF-8"), @@ -53,19 +53,19 @@ async def process(self): ctype = request.content_type.lower() if ctype == "application/x-www-form-urlencoded": if not data.startswith(b"d="): - return web.HTTPInternalServerError(text="Payload expected.") + raise web.HTTPInternalServerError(text="Payload expected.") data = unquote_plus(data[2:].decode(ENCODING)) else: data = data.decode(ENCODING) if not data: - return web.HTTPInternalServerError(text="Payload expected.") + raise web.HTTPInternalServerError(text="Payload expected.") try: messages = loads(data) except Exception: - return web.HTTPInternalServerError(text="Broken JSON encoding.") + raise web.HTTPInternalServerError(text="Broken JSON encoding.") await session._remote_messages(messages) @@ -77,7 +77,7 @@ async def process(self): return web.Response(body=b"ok", headers=headers) else: - return web.HTTPBadRequest(text="No support for such method: %s" % meth) + raise web.HTTPBadRequest(text="No support for such method: %s" % meth) class JSONPollingSend(JSONPolling): diff --git a/sockjs/transports/xhrsend.py b/sockjs/transports/xhrsend.py index 5a53391d..3e965ac4 100644 --- a/sockjs/transports/xhrsend.py +++ b/sockjs/transports/xhrsend.py @@ -12,7 +12,7 @@ async def process(self): request = self.request if request.method not in (hdrs.METH_GET, hdrs.METH_POST, hdrs.METH_OPTIONS): - return web.HTTPForbidden(text="Method is not allowed") + raise web.HTTPForbidden(text="Method is not allowed") if self.request.method == hdrs.METH_OPTIONS: headers = ( @@ -26,12 +26,12 @@ async def process(self): data = await request.read() if not data: - return web.HTTPInternalServerError(text="Payload expected.") + raise web.HTTPInternalServerError(text="Payload expected.") try: messages = loads(data.decode(ENCODING)) except Exception: - return web.HTTPInternalServerError(text="Broken JSON encoding.") + raise web.HTTPInternalServerError(text="Broken JSON encoding.") await self.session._remote_messages(messages) diff --git a/tests/conftest.py b/tests/conftest.py index 66a28742..14ae2de7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,14 +1,15 @@ import asyncio from unittest import mock +import aiohttp_cors import pytest from aiohttp import web -from aiohttp.test_utils import make_mocked_coro, make_mocked_request +from aiohttp.test_utils import TestClient, make_mocked_coro, make_mocked_request from aiohttp.web_urldispatcher import UrlMappingMatchInfo from multidict import CIMultiDict from yarl import URL -from sockjs import Session, SessionManager, transports +from sockjs import Session, SessionManager, add_endpoint, transports from sockjs.route import SockJSRoute @@ -125,3 +126,24 @@ def maker(handlers=transports.transport_handlers): return SockJSRoute("sm", sm, "http:sockjs-cdn", handlers, (), True) return maker + + +@pytest.fixture(name='test_client') +async def test_client_fixture(app, aiohttp_client, make_handler) -> TestClient: + handler = make_handler(None) + # Configure default CORS settings. + cors = aiohttp_cors.setup(app, defaults={ + '*': aiohttp_cors.ResourceOptions( + allow_credentials=True, + expose_headers='*', + allow_headers='*', + max_age=31536000, + ) + }) + add_endpoint( + app, + handler, + name='main', + cors_config=cors, + ) + return await aiohttp_client(app) diff --git a/tests/test_route.py b/tests/test_route.py index eb2a4818..b05439cb 100644 --- a/tests/test_route.py +++ b/tests/test_route.py @@ -1,9 +1,14 @@ import asyncio +import pytest from aiohttp import web +from aiohttp.test_utils import TestClient +from cykooz.testing import D from multidict import CIMultiDict from sockjs import protocol +from sockjs.route import ALL_METH_WO_OPTIONS +from sockjs.transports import transport_handlers from sockjs.transports.base import Transport @@ -31,22 +36,6 @@ async def test_info_entropy(make_route, make_request): assert entropy1 != entropy2 -async def test_info_options(make_route, make_request): - route = make_route() - request = make_request("OPTIONS", "/sm/") - response = await route.info_options(request) - - assert response.status == 204 - - headers = response.headers - assert "Access-Control-Max-Age" in headers - assert "Cache-Control" in headers - assert "Expires" in headers - assert "Set-Cookie" in headers - assert "access-control-allow-credentials" in headers - assert "access-control-allow-origin" in headers - - async def test_greeting(make_route, make_request): route = make_route() request = make_request("GET", "/sm/") @@ -95,8 +84,8 @@ async def test_handler_unknown_transport(make_route, make_request): route = make_route() request = make_request("GET", "/sm/", match_info={"transport": "unknown"}) - res = await route.handler(request) - assert isinstance(res, web.HTTPNotFound) + with pytest.raises(web.HTTPNotFound): + await route.handler(request) async def test_handler_emptry_session(make_route, make_request): @@ -104,8 +93,8 @@ async def test_handler_emptry_session(make_route, make_request): request = make_request( "GET", "/sm/", match_info={"transport": "websocket", "session": ""} ) - res = await route.handler(request) - assert isinstance(res, web.HTTPNotFound) + with pytest.raises(web.HTTPNotFound): + await route.handler(request) async def test_handler_bad_session_id(make_route, make_request): @@ -115,8 +104,8 @@ async def test_handler_bad_session_id(make_route, make_request): "/sm/", match_info={"transport": "websocket", "session": "test.1", "server": "000"}, ) - res = await route.handler(request) - assert isinstance(res, web.HTTPNotFound) + with pytest.raises(web.HTTPNotFound): + await route.handler(request) async def test_handler_bad_server_id(make_route, make_request): @@ -126,8 +115,8 @@ async def test_handler_bad_server_id(make_route, make_request): "/sm/", match_info={"transport": "websocket", "session": "test", "server": "test.1"}, ) - res = await route.handler(request) - assert isinstance(res, web.HTTPNotFound) + with pytest.raises(web.HTTPNotFound): + await route.handler(request) async def test_new_session_before_read(make_route, make_request): @@ -137,8 +126,8 @@ async def test_new_session_before_read(make_route, make_request): "/sm/", match_info={"transport": "xhr_send", "session": "s1", "server": "000"}, ) - res = await route.handler(request) - assert isinstance(res, web.HTTPNotFound) + with pytest.raises(web.HTTPNotFound): + await route.handler(request) async def _test_transport(make_route, make_request): @@ -181,8 +170,8 @@ def process(self): raise Exception("Error") route = make_route(handlers={"test": FakeTransport}) - res = await route.handler(request) - assert isinstance(res, web.HTTPInternalServerError) + with pytest.raises(web.HTTPInternalServerError): + await route.handler(request) async def test_release_session_for_failed_transport(make_route, make_request): @@ -200,8 +189,8 @@ async def process(self): raise Exception("Error") route = make_route(handlers={"test": FakeTransport}) - res = await route.handler(request) - assert isinstance(res, web.HTTPInternalServerError) + with pytest.raises(web.HTTPInternalServerError): + await route.handler(request) s1 = route.manager["s1"] assert not route.manager.is_acquired(s1) @@ -227,3 +216,40 @@ async def _test_raw_websocket_fail(make_route, make_request): request = make_request("GET", "/sm/") res = await route.websocket(request) assert not isinstance(res, web.HTTPNotFound) + + +@pytest.mark.parametrize( + ('url', 'method'), + [ + ('/sockjs', "GET"), + ('/sockjs/', "GET"), + ('/sockjs/info', "GET"), + ] + [ + (f'/sockjs/serv1/234/{transport}', method) + for transport in transport_handlers.keys() + for method in ALL_METH_WO_OPTIONS + ] + [ + ('/sockjs/websocket', "GET"), + ('/sockjs/iframe.html', "GET"), + ('/sockjs/iframe12.html', "GET"), + ] +) +async def test_cors_preflight(test_client: TestClient, url, method): + origin = "http://my_example.com" + headers = { + "HOST": "server.example.com", + "ACCESS-CONTROL-REQUEST-METHOD": method, + "ACCESS-CONTROL-REQUEST-HEADERS": "origin, x-requested-with", + "ORIGIN": origin, + } + + response = await test_client.options(url, headers=headers) + assert response.status in (200, 204) + + headers = response.headers + assert dict(headers) == D({ + "Access-Control-Allow-Origin": origin, + "Access-Control-Allow-Methods": method, + "Access-Control-Allow-Credentials": "true", + "Access-Control-Max-Age": "31536000" + }) diff --git a/tests/test_transport_htmlfile.py b/tests/test_transport_htmlfile.py index c03c5f11..b286be13 100644 --- a/tests/test_transport_htmlfile.py +++ b/tests/test_transport_htmlfile.py @@ -1,6 +1,7 @@ from unittest import mock import pytest +from aiohttp import web from aiohttp.test_utils import make_mocked_coro from sockjs.transports import htmlfile @@ -47,9 +48,9 @@ async def test_process_no_callback(make_transport, make_fut): transp.session = mock.Mock() transp.session._remote_closed = make_fut(1) - resp = await transp.process() + with pytest.raises(web.HTTPInternalServerError): + await transp.process() assert transp.session._remote_closed.called - assert resp.status == 500 async def test_process_bad_callback(make_transport, make_fut): @@ -57,6 +58,6 @@ async def test_process_bad_callback(make_transport, make_fut): transp.session = mock.Mock() transp.session._remote_closed = make_fut(1) - resp = await transp.process() + with pytest.raises(web.HTTPInternalServerError): + await transp.process() assert transp.session._remote_closed.called - assert resp.status == 500 diff --git a/tests/test_transport_jsonp.py b/tests/test_transport_jsonp.py index d51208f5..f6da5cb4 100644 --- a/tests/test_transport_jsonp.py +++ b/tests/test_transport_jsonp.py @@ -1,6 +1,7 @@ from unittest import mock import pytest +from aiohttp import web from aiohttp.test_utils import make_mocked_coro from sockjs.transports import jsonp @@ -43,9 +44,9 @@ async def test_process_no_callback(make_transport, make_fut): transp.session = mock.Mock() transp.session._remote_closed = make_fut(1) - resp = await transp.process() + with pytest.raises(web.HTTPInternalServerError): + await transp.process() assert transp.session._remote_closed.called - assert resp.status == 500 async def test_process_bad_callback(make_transport, make_fut): @@ -53,15 +54,15 @@ async def test_process_bad_callback(make_transport, make_fut): transp.session = mock.Mock() transp.session._remote_closed = make_fut(1) - resp = await transp.process() + with pytest.raises(web.HTTPInternalServerError): + await transp.process() assert transp.session._remote_closed.called - assert resp.status == 500 async def test_process_not_supported(make_transport): transp = make_transport(method="PUT") - resp = await transp.process() - assert resp.status == 400 + with pytest.raises(web.HTTPBadRequest): + await transp.process() async def xtest_process_bad_encoding(make_transport, make_fut): diff --git a/tests/test_transport_xhrsend.py b/tests/test_transport_xhrsend.py index 3302a87c..d7b550de 100644 --- a/tests/test_transport_xhrsend.py +++ b/tests/test_transport_xhrsend.py @@ -1,4 +1,5 @@ import pytest +from aiohttp import web from sockjs.transports import xhrsend @@ -18,8 +19,8 @@ def maker(method="GET", path="/", query_params=None): async def test_not_supported_meth(make_transport): transp = make_transport(method="PUT") - resp = await transp.process() - assert resp.status == 403 + with pytest.raises(web.HTTPForbidden): + await transp.process() async def xtest_no_payload(make_transport, make_fut): From 953282b2e7ed2b33eabe833c64a8deabe904a6e3 Mon Sep 17 00:00:00 2001 From: Kirill Kuzminykh Date: Thu, 19 Jan 2023 01:17:47 +0400 Subject: [PATCH 06/23] Fixed GitHub actions --- .github/workflows/check_and_test.yaml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/check_and_test.yaml b/.github/workflows/check_and_test.yaml index 9eeac92b..5fb3bd7a 100644 --- a/.github/workflows/check_and_test.yaml +++ b/.github/workflows/check_and_test.yaml @@ -14,7 +14,7 @@ jobs: matrix: python-version: [ "3.7", "3.8", "3.9", "3.10", "3.11" ] - name: Test on Python {{ matrix.python-version }} + name: Test on Python ${{ matrix.python-version }} runs-on: ubuntu-latest steps: @@ -29,10 +29,9 @@ jobs: run: | - python -m pip install --upgrade setuptools - python -m pip install -r requirements.txt - - python -m pip install codecov - python -m pip install -e .[test] - name: Run checks and tests run: | - - make cov + - make flake - pytest ./tests From 53dadc09ba23df1ddf4f3a83d38cd8079b01e991 Mon Sep 17 00:00:00 2001 From: Kirill Kuzminykh Date: Thu, 19 Jan 2023 01:19:08 +0400 Subject: [PATCH 07/23] Fixed GitHub actions --- .github/workflows/check_and_test.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/check_and_test.yaml b/.github/workflows/check_and_test.yaml index 5fb3bd7a..3c59a0e4 100644 --- a/.github/workflows/check_and_test.yaml +++ b/.github/workflows/check_and_test.yaml @@ -27,11 +27,11 @@ jobs: - name: Install dependencies run: | - - python -m pip install --upgrade setuptools - - python -m pip install -r requirements.txt - - python -m pip install -e .[test] + python -m pip install --upgrade setuptools + python -m pip install -r requirements.txt + python -m pip install -e .[test] - name: Run checks and tests run: | - - make flake - - pytest ./tests + make flake + pytest ./tests From 879ccc0952c1ec605447d047a4ffe45a9132b8ad Mon Sep 17 00:00:00 2001 From: Kirill Kuzminykh Date: Thu, 19 Jan 2023 01:22:58 +0400 Subject: [PATCH 08/23] Fixed GitHub actions --- requirements.txt | 4 ---- setup.py | 1 - 2 files changed, 5 deletions(-) diff --git a/requirements.txt b/requirements.txt index 3856eab9..80e1f396 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,7 @@ -black==22.12.0 flake8==6.0.0 pytest==7.2.1 pytest-aiohttp==1.0.4 -pytest-cov==4.0.0 -pytest-sugar==0.9.6 pytest-mock==3.10.0 -pytest-timeout==2.1.0 aiohttp==3.8.3 twine==4.0.2 -e .[test] diff --git a/setup.py b/setup.py index b006585b..cb5fd90e 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,6 @@ def read(f): "multidict", "yarl", "pytest-aiohttp", - "pytest-timeout", "pytest-mock", "cykooz.testing", 'aiohttp_cors', From 83818c32cd3cd5f83d9692a8a64e3d68f8401deb Mon Sep 17 00:00:00 2001 From: Kirill Kuzminykh Date: Thu, 19 Jan 2023 01:25:18 +0400 Subject: [PATCH 09/23] Fixed version of flake8 to support Python 3.7 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 80e1f396..85cf178a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -flake8==6.0.0 +flake8<6.0.0 pytest==7.2.1 pytest-aiohttp==1.0.4 pytest-mock==3.10.0 From c120b9e7d0595f2031e791192ad236c41dab13d1 Mon Sep 17 00:00:00 2001 From: Kirill Kuzminykh Date: Thu, 19 Jan 2023 01:27:10 +0400 Subject: [PATCH 10/23] Fixed typing --- sockjs/route.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sockjs/route.py b/sockjs/route.py index fd6941c7..3e173757 100644 --- a/sockjs/route.py +++ b/sockjs/route.py @@ -4,7 +4,7 @@ import json import logging import random -from typing import Iterable, Optional, Type +from typing import Iterable, List, Optional, Type from aiohttp import hdrs, web @@ -45,7 +45,7 @@ def add_endpoint( sockjs_cdn="https://cdn.jsdelivr.net/npm/sockjs-client@1/dist/sockjs.min.js", # noqa cookie_needed=True, cors_config: Optional[CorsConfig] = None, -) -> list[web.AbstractRoute]: +) -> List[web.AbstractRoute]: registered_routes = [] assert callable(handler), handler From 91fc90f0b1d61ab935094f563a65372754fd5ad1 Mon Sep 17 00:00:00 2001 From: Kirill Kuzminykh Date: Thu, 19 Jan 2023 01:29:41 +0400 Subject: [PATCH 11/23] Reverted pytest-timeout --- requirements.txt | 1 + setup.py | 1 + 2 files changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index 85cf178a..fbe74eb0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ flake8<6.0.0 pytest==7.2.1 pytest-aiohttp==1.0.4 pytest-mock==3.10.0 +pytest-timeout==2.1.0 aiohttp==3.8.3 twine==4.0.2 -e .[test] diff --git a/setup.py b/setup.py index cb5fd90e..43e4a7b8 100644 --- a/setup.py +++ b/setup.py @@ -56,6 +56,7 @@ def read(f): "yarl", "pytest-aiohttp", "pytest-mock", + "pytest-timeout", "cykooz.testing", 'aiohttp_cors', ], From c1ccb56817e7fe4c5f6c9851266c313a13873a38 Mon Sep 17 00:00:00 2001 From: Kirill Kuzminykh Date: Thu, 19 Jan 2023 16:49:56 +0400 Subject: [PATCH 12/23] Removed embedded CORS processing. --- sockjs/route.py | 3 +-- sockjs/transports/eventsource.py | 3 ++- sockjs/transports/htmlfile.py | 6 +++--- sockjs/transports/jsonp.py | 8 ++++---- sockjs/transports/utils.py | 14 -------------- sockjs/transports/xhr.py | 9 ++++----- sockjs/transports/xhrsend.py | 9 ++++----- sockjs/transports/xhrstreaming.py | 8 ++++---- 8 files changed, 22 insertions(+), 38 deletions(-) diff --git a/sockjs/route.py b/sockjs/route.py index 3e173757..d7f4f4a8 100644 --- a/sockjs/route.py +++ b/sockjs/route.py @@ -19,7 +19,7 @@ from .transports import transport_handlers from .transports.base import Transport from .transports.rawwebsocket import RawWebSocketTransport -from .transports.utils import CACHE_CONTROL, cache_headers, cors_headers, session_cookie +from .transports.utils import CACHE_CONTROL, cache_headers, session_cookie log = logging.getLogger("sockjs") @@ -231,7 +231,6 @@ async def info(self, request): resp = web.Response() resp.headers[hdrs.CONTENT_TYPE] = "application/json;charset=UTF-8" resp.headers[hdrs.CACHE_CONTROL] = CACHE_CONTROL - resp.headers.extend(cors_headers(request.headers)) info = { "entropy": random.randint(1, 2147483647), diff --git a/sockjs/transports/eventsource.py b/sockjs/transports/eventsource.py index 6e443207..e865768b 100644 --- a/sockjs/transports/eventsource.py +++ b/sockjs/transports/eventsource.py @@ -1,5 +1,6 @@ """ iframe-eventsource transport """ from aiohttp import hdrs, web +from multidict import MultiDict from .base import StreamingTransport from .utils import CACHE_CONTROL, session_cookie @@ -20,7 +21,7 @@ async def process(self): headers += session_cookie(self.request) # open sequence (sockjs protocol) - resp = self.response = web.StreamResponse(headers=headers) + resp = self.response = web.StreamResponse(headers=MultiDict(headers)) await resp.prepare(self.request) await resp.write(b"\r\n") diff --git a/sockjs/transports/htmlfile.py b/sockjs/transports/htmlfile.py index b70da18b..46661f76 100644 --- a/sockjs/transports/htmlfile.py +++ b/sockjs/transports/htmlfile.py @@ -2,9 +2,10 @@ import re from aiohttp import hdrs, web +from multidict import MultiDict from .base import StreamingTransport -from .utils import CACHE_CONTROL, cors_headers, session_cookie +from .utils import CACHE_CONTROL, session_cookie from ..protocol import dumps @@ -51,10 +52,9 @@ async def process(self): (hdrs.CONNECTION, "close"), ) headers += session_cookie(request) - headers += cors_headers(request.headers) # open sequence (sockjs protocol) - resp = self.response = web.StreamResponse(headers=headers) + resp = self.response = web.StreamResponse(headers=MultiDict(headers)) await resp.prepare(self.request) await resp.write( b"".join((PRELUDE1, callback.encode("utf-8"), PRELUDE2, b" " * 1024)) diff --git a/sockjs/transports/jsonp.py b/sockjs/transports/jsonp.py index cc5600d0..76ea6a68 100644 --- a/sockjs/transports/jsonp.py +++ b/sockjs/transports/jsonp.py @@ -3,9 +3,10 @@ from urllib.parse import unquote_plus from aiohttp import hdrs, web +from multidict import MultiDict from .base import StreamingTransport -from .utils import CACHE_CONTROL, cors_headers, session_cookie +from .utils import CACHE_CONTROL, session_cookie from ..protocol import ENCODING, dumps, loads @@ -39,9 +40,8 @@ async def process(self): (hdrs.CACHE_CONTROL, CACHE_CONTROL), ) headers += session_cookie(request) - headers += cors_headers(request.headers) - resp = self.response = web.StreamResponse(headers=headers) + resp = self.response = web.StreamResponse(headers=MultiDict(headers)) await resp.prepare(request) await self.handle_session() @@ -74,7 +74,7 @@ async def process(self): (hdrs.CACHE_CONTROL, CACHE_CONTROL), ) headers += session_cookie(request) - return web.Response(body=b"ok", headers=headers) + return web.Response(body=b"ok", headers=MultiDict(headers)) else: raise web.HTTPBadRequest(text="No support for such method: %s" % meth) diff --git a/sockjs/transports/utils.py b/sockjs/transports/utils.py index 1b00183e..2033acb8 100644 --- a/sockjs/transports/utils.py +++ b/sockjs/transports/utils.py @@ -9,20 +9,6 @@ CACHE_CONTROL = "no-store, no-cache, no-transform, must-revalidate, max-age=0" -def cors_headers(headers, nocreds=False): - origin = headers.get(hdrs.ORIGIN, "*") - cors = ((hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, origin),) - - ac_headers = headers.get(hdrs.ACCESS_CONTROL_REQUEST_HEADERS) - if ac_headers: - cors += ((hdrs.ACCESS_CONTROL_ALLOW_HEADERS, ac_headers),) - - if origin != "*": - return cors + ((hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"),) - else: - return cors - - def session_cookie(request): cookie = request.cookies.get("JSESSIONID", "dummy") cookies = http.cookies.SimpleCookie() diff --git a/sockjs/transports/xhr.py b/sockjs/transports/xhr.py index a57b428a..2047e07f 100644 --- a/sockjs/transports/xhr.py +++ b/sockjs/transports/xhr.py @@ -1,7 +1,8 @@ from aiohttp import hdrs, web +from multidict import MultiDict from .base import StreamingTransport -from .utils import CACHE_CONTROL, cache_headers, cors_headers, session_cookie +from .utils import CACHE_CONTROL, cache_headers, session_cookie class XHRTransport(StreamingTransport): @@ -23,18 +24,16 @@ async def process(self): (hdrs.ACCESS_CONTROL_ALLOW_METHODS, "OPTIONS, POST"), ) headers += session_cookie(request) - headers += cors_headers(request.headers) headers += cache_headers() - return web.Response(status=204, headers=headers) + return web.Response(status=204, headers=MultiDict(headers)) headers = ( (hdrs.CONTENT_TYPE, "application/javascript; charset=UTF-8"), (hdrs.CACHE_CONTROL, CACHE_CONTROL), ) headers += session_cookie(request) - headers += cors_headers(request.headers) - resp = self.response = web.StreamResponse(headers=headers) + resp = self.response = web.StreamResponse(headers=MultiDict(headers)) await resp.prepare(request) await self.handle_session() diff --git a/sockjs/transports/xhrsend.py b/sockjs/transports/xhrsend.py index 3e965ac4..db494617 100644 --- a/sockjs/transports/xhrsend.py +++ b/sockjs/transports/xhrsend.py @@ -1,7 +1,8 @@ from aiohttp import hdrs, web +from multidict import MultiDict from .base import Transport -from .utils import CACHE_CONTROL, cache_headers, cors_headers, session_cookie +from .utils import CACHE_CONTROL, cache_headers, session_cookie from ..protocol import ENCODING, loads @@ -20,9 +21,8 @@ async def process(self): (hdrs.CONTENT_TYPE, "application/javascript; charset=UTF-8"), ) headers += session_cookie(request) - headers += cors_headers(request.headers) headers += cache_headers() - return web.Response(status=204, headers=headers) + return web.Response(status=204, headers=MultiDict(headers)) data = await request.read() if not data: @@ -40,6 +40,5 @@ async def process(self): (hdrs.CACHE_CONTROL, CACHE_CONTROL), ) headers += session_cookie(request) - headers += cors_headers(request.headers) - return web.Response(status=204, headers=headers) + return web.Response(status=204, headers=MultiDict(headers)) diff --git a/sockjs/transports/xhrstreaming.py b/sockjs/transports/xhrstreaming.py index a51af032..1e89009e 100644 --- a/sockjs/transports/xhrstreaming.py +++ b/sockjs/transports/xhrstreaming.py @@ -1,7 +1,8 @@ from aiohttp import hdrs, web +from multidict import MultiDict from .base import StreamingTransport -from .utils import CACHE_CONTROL, cache_headers, cors_headers, session_cookie +from .utils import CACHE_CONTROL, cache_headers, session_cookie class XHRStreamingTransport(StreamingTransport): @@ -20,15 +21,14 @@ async def process(self): ) headers += session_cookie(request) - headers += cors_headers(request.headers) if request.method == hdrs.METH_OPTIONS: headers += ((hdrs.ACCESS_CONTROL_ALLOW_METHODS, "OPTIONS, POST"),) headers += cache_headers() - return web.Response(status=204, headers=headers) + return web.Response(status=204, headers=MultiDict(headers)) # open sequence (sockjs protocol) - resp = self.response = web.StreamResponse(headers=headers) + resp = self.response = web.StreamResponse(headers=MultiDict(headers)) resp.force_close() await resp.prepare(request) await resp.write(self.open_seq) From 8747ea138263029fba8cd95341081cdc9b6d487c Mon Sep 17 00:00:00 2001 From: Kirill Kuzminykh Date: Wed, 25 Jan 2023 17:12:01 +0400 Subject: [PATCH 13/23] Added arguments ``heartbeat_delay`` and ``disconnect_delay`` into function ``add_endpoint()``. --- CHANGES.rst | 2 ++ sockjs/route.py | 10 +++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/CHANGES.rst b/CHANGES.rst index 716ac49b..36ec0e4c 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,6 +7,8 @@ CHANGES - Added argument ``cors_config`` into function ``add_endpoint()`` to support of CORS settings from ``aiohttp_cors``. +- Added arguments ``heartbeat_delay`` and ``disconnect_delay`` + into function ``add_endpoint()``. - Function ``add_endpoint()`` now returns all registered routes. - Replaced returning instances of error HTTP responses on raising its as exceptions. diff --git a/sockjs/route.py b/sockjs/route.py index d7f4f4a8..ef20d861 100644 --- a/sockjs/route.py +++ b/sockjs/route.py @@ -45,6 +45,8 @@ def add_endpoint( sockjs_cdn="https://cdn.jsdelivr.net/npm/sockjs-client@1/dist/sockjs.min.js", # noqa cookie_needed=True, cors_config: Optional[CorsConfig] = None, + heartbeat_delay=25, + disconnect_delay=5, ) -> List[web.AbstractRoute]: registered_routes = [] @@ -65,7 +67,13 @@ async def handler(msg, session): # set session manager if manager is None: - manager = SessionManager(name, app, handler) + manager = SessionManager( + name, + app, + handler, + heartbeat_delay, + disconnect_delay, + ) if manager.name != name: raise ValueError("Session manage must have same name as sockjs route") From c8b36fab32320fcf27f9a017714e595e39cd67d6 Mon Sep 17 00:00:00 2001 From: Kirill Kuzminykh Date: Thu, 26 Jan 2023 19:35:32 +0400 Subject: [PATCH 14/23] - Heartbeat task moved from ``SessionManager`` into ``Session``. - Methods ``_acquire`` and ``_release`` of ``Sessions`` renamed into ``acquire`` and ``release``. --- CHANGES.rst | 3 + sockjs/session.py | 126 +++++++++++++++++++++++------------------- tests/test_session.py | 38 ++++++++----- 3 files changed, 96 insertions(+), 71 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 36ec0e4c..95a61d53 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -13,6 +13,9 @@ CHANGES - Replaced returning instances of error HTTP responses on raising its as exceptions. - Changed name of some routes. +- Heartbeat task moved from ``SessionManager`` into ``Session``. +- Methods ``_acquire`` and ``_release`` of ``Sessions`` renamed into + ``acquire`` and ``release``. 0.12.0 (2022-02-08) ------------------- diff --git a/sockjs/session.py b/sockjs/session.py index fb35899c..213e78a8 100644 --- a/sockjs/session.py +++ b/sockjs/session.py @@ -48,15 +48,16 @@ class Session: interrupted = False exception = None app: Optional[web.Application] = None + _hb_task = None def __init__( - self, - session_id, - handler, - *, - heartbeat_delay=25, - disconnect_delay=5, - debug=False + self, + session_id, + handler, + *, + heartbeat_delay=25, + disconnect_delay=5, + debug=False ): self.id = session_id self.handler = handler @@ -111,23 +112,17 @@ def expired(self) -> bool: return self.expires <= datetime.now() return False - def _tick(self, timeout=None): - if timeout is None: - self.next_heartbeat = datetime.now() + timedelta( - seconds=self.heartbeat_delay - ) - else: - self.next_heartbeat = datetime.now() + timedelta(seconds=timeout) - - async def _acquire( - self, manager: "SessionManager", request: web.Request, heartbeat=True + async def acquire( + self, + manager: "SessionManager", + request: web.Request, ): self.acquired = True self.manager = manager self.app = manager.app self.request = request self.expires = None - self._send_heartbeats = heartbeat + self._send_heartbeats = self.heartbeat_delay > 0 self._tick() self._hits += 1 @@ -147,16 +142,43 @@ async def _acquire( self._feed(FRAME_CLOSE, (3000, "Internal error")) log.exception("Exception in open session handling.") - def _release(self): + if self._hb_task is None and self._send_heartbeats: + self._hb_task = asyncio.create_task(self._heartbeat_task()) + + def release(self): self.acquired = False self.manager = None self.request = None self._send_heartbeats = False + if self._hb_task is not None: + try: + self._hb_task.cancel() + except RuntimeError: + pass # an event loop already stopped + self._hb_task = None + + def _tick(self, timeout=None): + if timeout is None: + self.next_heartbeat = datetime.now() + timedelta( + seconds=self.heartbeat_delay + ) + else: + self.next_heartbeat = datetime.now() + timedelta(seconds=timeout) def _heartbeat(self): - self._heartbeats += 1 if self._send_heartbeats: self._feed(FRAME_HEARTBEAT, FRAME_HEARTBEAT) + self._heartbeats += 1 + + async def _heartbeat_task(self): + while True: + now = datetime.now() + if self.next_heartbeat <= now: + self._heartbeat() + self._tick() + delta = (self.next_heartbeat - now).total_seconds() + if delta > 0: + await asyncio.sleep(delta) def _feed(self, frame, data): # pack messages @@ -294,16 +316,16 @@ def close(self, code=3000, reason="Go away!"): class SessionManager(dict): """A basic session manager.""" - _hb_task = None # gc task + _gc_task = None def __init__( - self, - name: str, - app: web.Application, - handler, - heartbeat_delay=25, - disconnect_delay=5, - debug=False, + self, + name: str, + app: web.Application, + handler, + heartbeat_delay=25, + disconnect_delay=5, + debug=False, ): super().__init__() self.name = name @@ -319,19 +341,19 @@ def __init__( @property def started(self): - return self._hb_task is not None + return self._gc_task is not None def start(self): - if not self._hb_task: - self._hb_task = asyncio.create_task(self._heartbeat_task()) + if not self._gc_task: + self._gc_task = asyncio.create_task(self._gc_sessions_task()) async def stop(self, _app=None): - if self._hb_task is not None: + if self._gc_task is not None: try: - self._hb_task.cancel() + self._gc_task.cancel() except RuntimeError: pass # an event loop already stopped - self._hb_task = None + self._gc_task = None await self.clear() async def _check_expiration(self, session: Session): @@ -346,6 +368,12 @@ async def _check_expiration(self, session: Session): await session._remote_closed() return session.id + async def _gc_sessions_task(self): + delay = max(self.disconnect_delay, 5) + while True: + await asyncio.sleep(delay) + await self._gc_expired_sessions() + async def _gc_expired_sessions(self): sessions = self.sessions if sessions: @@ -360,22 +388,6 @@ async def _gc_expired_sessions(self): del self[session_id] del sessions[idx] - async def _heartbeat_task(self): - delay = min(self.heartbeat_delay, self.disconnect_delay) - if delay <= 0: - delay = max(self.heartbeat_delay, self.disconnect_delay, 10) - while True: - await asyncio.sleep(delay) - await self._gc_expired_sessions() - self._heartbeat() - - def _heartbeat(self): - # Send heartbeat - now = datetime.now() - for session in self.sessions: - if session.next_heartbeat <= now: - session._heartbeat() - def _add(self, session: Session): if session.expired: raise ValueError("Can not add expired session") @@ -388,10 +400,10 @@ def _add(self, session: Session): return session def get( - self, - session_id, - create=False, - default=_marker, + self, + session_id, + create=False, + default=_marker, ) -> Session: session = super().get(session_id, None) if session is None: @@ -420,7 +432,7 @@ async def acquire(self, session: Session, request: web.Request): if sid not in self: raise KeyError("Unknown session") - await session._acquire(self, request) + await session.acquire(self, request) self.acquired[sid] = True return session @@ -430,7 +442,7 @@ def is_acquired(self, session): async def release(self, s: Session): if s.id in self.acquired: - s._release() + s.release() del self.acquired[s.id] def active_sessions(self): @@ -456,7 +468,7 @@ def broadcast(self, message): session.send_frame(blob) def __del__(self): - if len(self.sessions) or self._hb_task is not None: + if len(self.sessions) or self._gc_task is not None: warnings.warn( "Please call `await SessionManager.stop()` before del", RuntimeWarning, diff --git a/tests/test_session.py b/tests/test_session.py index 3733b009..12895f88 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -86,6 +86,7 @@ async def test_tick_custom(self, mocker, make_session): async def test_heartbeat(self, make_session): session = make_session("test") + session._send_heartbeats = True assert session._heartbeats == 0 session._heartbeat() assert session._heartbeats == 1 @@ -241,14 +242,23 @@ async def test_acquire_new_session(self, make_manager, make_session, make_reques session = make_session(result=messages) assert session.state == protocol.STATE_NEW + assert session._hb_task is None + assert not session._send_heartbeats - await session._acquire(manager, request=make_request("GET", "/test/")) + await session.acquire(manager, request=make_request("GET", "/test/")) assert session.state == protocol.STATE_OPEN assert session.manager is manager assert session._send_heartbeats + assert session._hb_task is not None assert list(session._queue) == [(protocol.FRAME_OPEN, protocol.FRAME_OPEN)] assert messages == [(protocol.OpenMessage, session)] + hb_task = session._hb_task + session.release() + assert not session._send_heartbeats + assert session._hb_task is None + assert hb_task._must_cancel + async def test_acquire_exception_in_handler( self, make_manager, make_session, make_request ): @@ -259,7 +269,7 @@ async def handler(msg, s): assert session.state == protocol.STATE_NEW sm = make_manager() - await session._acquire(sm, request=make_request("GET", "/test/")) + await session.acquire(sm, request=make_request("GET", "/test/")) assert session.state == protocol.STATE_CLOSING assert session._send_heartbeats assert session.interrupted @@ -458,9 +468,9 @@ async def test_acquire(self, make_manager, make_session, make_request): sm = make_manager() s1 = make_session() sm._add(s1) - s1._acquire = mock.Mock() - s1._acquire.return_value = asyncio.Future() - s1._acquire.return_value.set_result(1) + s1.acquire = mock.Mock() + s1.acquire.return_value = asyncio.Future() + s1.acquire.return_value.set_result(1) s2 = await sm.acquire(s1, request=make_request("GET", "/test/")) @@ -468,7 +478,7 @@ async def test_acquire(self, make_manager, make_session, make_request): assert s1.id in sm.acquired assert sm.acquired[s1.id] assert sm.is_acquired(s1) - assert s1._acquire.called + assert s1.acquire.called async def test_acquire_unknown(self, make_manager, make_session, make_request): sm = make_manager() @@ -488,14 +498,14 @@ async def test_acquire_locked(self, make_manager, make_session, make_request): async def test_release(self, make_manager, make_request): sm = make_manager() s = sm.get("test", True) - s._release = mock.Mock() + s.release = mock.Mock() await sm.acquire(s, request=make_request("GET", "/test/")) await sm.release(s) assert "test" not in sm.acquired assert not sm.is_acquired(s) - assert s._release.called + assert s.release.called async def test_active_sessions(self, make_manager): sm = make_manager() @@ -537,21 +547,21 @@ async def test_clear(self, make_manager): assert s1.state == protocol.STATE_CLOSED assert s2.state == protocol.STATE_CLOSED - async def test_heartbeat(self, make_manager): + async def test_gc_task(self, make_manager): sm = make_manager() assert not sm.started - assert sm._hb_task is None + assert sm._gc_task is None sm.start() assert sm.started - assert sm._hb_task is not None + assert sm._gc_task is not None - hb_task = sm._hb_task + gc_task = sm._gc_task await sm.stop() assert not sm.started - assert sm._hb_task is None - assert hb_task._must_cancel + assert sm._gc_task is None + assert gc_task._must_cancel async def test_gc_expire(self, make_manager, make_session, make_request): sm = make_manager() From 21e2728f1e5b53271c2d61f243eaa5bc5fac190e Mon Sep 17 00:00:00 2001 From: Kirill Kuzminykh Date: Fri, 9 Jun 2023 13:44:26 +0300 Subject: [PATCH 15/23] Added processing of ``ConnectionError`` in ``StreamingTransport``. --- CHANGES.rst | 1 + requirements.txt | 4 ++-- sockjs/route.py | 4 +--- sockjs/transports/base.py | 5 +++-- sockjs/transports/eventsource.py | 1 + 5 files changed, 8 insertions(+), 7 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 95a61d53..275bcf54 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -16,6 +16,7 @@ CHANGES - Heartbeat task moved from ``SessionManager`` into ``Session``. - Methods ``_acquire`` and ``_release`` of ``Sessions`` renamed into ``acquire`` and ``release``. +- Added processing of ``ConnectionError`` in ``StreamingTransport``. 0.12.0 (2022-02-08) ------------------- diff --git a/requirements.txt b/requirements.txt index fbe74eb0..34d71460 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ flake8<6.0.0 -pytest==7.2.1 +pytest==7.3.1 pytest-aiohttp==1.0.4 pytest-mock==3.10.0 pytest-timeout==2.1.0 -aiohttp==3.8.3 +aiohttp==3.8.4 twine==4.0.2 -e .[test] diff --git a/sockjs/route.py b/sockjs/route.py index ef20d861..ab160c60 100644 --- a/sockjs/route.py +++ b/sockjs/route.py @@ -209,10 +209,8 @@ async def handler(self, request): t = transport(manager, session, request) try: return await t.process() - except asyncio.CancelledError: + except (asyncio.CancelledError, web.HTTPException, ConnectionError): raise - except web.HTTPException as exc: - raise exc except Exception: log.exception("Exception in transport: %s" % tid) if manager.is_acquired(session): diff --git a/sockjs/transports/base.py b/sockjs/transports/base.py index 7518a66b..04511a21 100644 --- a/sockjs/transports/base.py +++ b/sockjs/transports/base.py @@ -74,7 +74,8 @@ async def handle_session(self): if self.timeout: try: frame, text = await asyncio.wait_for( - self.session._get_frame(), timeout=self.timeout + self.session._get_frame(), + timeout=self.timeout, ) except asyncio.futures.TimeoutError: frame, text = FRAME_MESSAGE, "a[]" @@ -89,7 +90,7 @@ async def handle_session(self): stop = await self._send(text) if stop: break - except asyncio.CancelledError: + except (asyncio.CancelledError, ConnectionError): await self.session._remote_close(exc=aiohttp.ClientConnectionError) await self.session._remote_closed() raise diff --git a/sockjs/transports/eventsource.py b/sockjs/transports/eventsource.py index e865768b..ee778a19 100644 --- a/sockjs/transports/eventsource.py +++ b/sockjs/transports/eventsource.py @@ -23,6 +23,7 @@ async def process(self): # open sequence (sockjs protocol) resp = self.response = web.StreamResponse(headers=MultiDict(headers)) await resp.prepare(self.request) + # Opera needs one more new line at the start. await resp.write(b"\r\n") # handle session From 140c4b2e1fbf4c27544ab52e1e1226111b7d9f47 Mon Sep 17 00:00:00 2001 From: Kirill Kuzminykh Date: Wed, 14 Jun 2023 17:19:36 +0300 Subject: [PATCH 16/23] Fixed name of transports. --- sockjs/route.py | 39 +++++++++++------- sockjs/transports/__init__.py | 3 +- sockjs/transports/base.py | 1 + sockjs/transports/eventsource.py | 1 + sockjs/transports/htmlfile.py | 1 + sockjs/transports/jsonp.py | 2 + sockjs/transports/rawwebsocket.py | 1 + sockjs/transports/websocket.py | 1 + sockjs/transports/xhr.py | 40 ------------------- .../transports/{xhrsend.py => xhr_pooling.py} | 40 ++++++++++++++++++- sockjs/transports/xhrstreaming.py | 1 + tests/test_route.py | 12 ++++++ tests/test_transport_xhr.py | 4 +- tests/test_transport_xhrsend.py | 4 +- 14 files changed, 87 insertions(+), 63 deletions(-) delete mode 100644 sockjs/transports/xhr.py rename sockjs/transports/{xhrsend.py => xhr_pooling.py} (53%) diff --git a/sockjs/route.py b/sockjs/route.py index ab160c60..17ccf4f9 100644 --- a/sockjs/route.py +++ b/sockjs/route.py @@ -117,12 +117,13 @@ async def handler(msg, session): ) ) - route_name = "sockjs-websocket-%s" % name - registered_routes.append( - router.add_route( - hdrs.METH_GET, "%s/websocket" % prefix, route.websocket, name=route_name + if "websocket-raw" not in route.disable_transports: + route_name = "sockjs-websocket-%s" % name + registered_routes.append( + router.add_route( + hdrs.METH_GET, "%s/websocket" % prefix, route.websocket, name=route_name + ) ) - ) registered_routes.append( router.add_route( @@ -177,21 +178,24 @@ def __init__( self.cookie_needed = cookie_needed self.iframe_html = (IFRAME_HTML % sockjs_cdn).encode("utf-8") self.iframe_html_hxd = hashlib.md5(self.iframe_html).hexdigest() + transport_names = { + transport_class.name + for transport_class in transport_handlers.values() + } + transport_names.add("websocket-raw") self._transport_names = sorted( - set(transport_handlers.keys()) - self.disable_transports + transport_names - self.disable_transports ) async def handler(self, request): info = request.match_info # lookup transport - tid = info["transport"] - - if tid not in self.handlers or tid in self.disable_transports: + t_id = info["transport"] + transport: Optional[Type[Transport]] = self.handlers.get(t_id) + if transport is None or transport.name in self.disable_transports: raise web.HTTPNotFound() - transport: Type[Transport] = self.handlers[tid] - # session manager = self.manager if not manager.started: @@ -209,13 +213,18 @@ async def handler(self, request): t = transport(manager, session, request) try: return await t.process() - except (asyncio.CancelledError, web.HTTPException, ConnectionError): + except (asyncio.CancelledError, web.HTTPException, ConnectionError) as e: + await session._remote_close(exc=e) raise - except Exception: - log.exception("Exception in transport: %s" % tid) + except Exception as e: + log.exception("Exception in transport: %s" % t_id) + await session._remote_close(exc=e) + raise web.HTTPInternalServerError() + finally: if manager.is_acquired(session): + await session._remote_close() + await session._remote_closed() await manager.release(session) - raise web.HTTPInternalServerError() async def websocket(self, request): if not self.manager.started: diff --git a/sockjs/transports/__init__.py b/sockjs/transports/__init__.py index 79af4db2..1c9c75f9 100644 --- a/sockjs/transports/__init__.py +++ b/sockjs/transports/__init__.py @@ -3,8 +3,7 @@ from .htmlfile import HTMLFileTransport from .jsonp import JSONPolling, JSONPollingSend from .websocket import WebSocketTransport -from .xhr import XHRTransport -from .xhrsend import XHRSendTransport +from .xhr_pooling import XHRTransport, XHRSendTransport from .xhrstreaming import XHRStreamingTransport diff --git a/sockjs/transports/base.py b/sockjs/transports/base.py index 04511a21..53307d52 100644 --- a/sockjs/transports/base.py +++ b/sockjs/transports/base.py @@ -17,6 +17,7 @@ class Transport(abc.ABC): + name: str create_session = True @classmethod diff --git a/sockjs/transports/eventsource.py b/sockjs/transports/eventsource.py index ee778a19..eaccff40 100644 --- a/sockjs/transports/eventsource.py +++ b/sockjs/transports/eventsource.py @@ -7,6 +7,7 @@ class EventsourceTransport(StreamingTransport): + name = "eventsource" create_session = True async def _send(self, text: str): diff --git a/sockjs/transports/htmlfile.py b/sockjs/transports/htmlfile.py index 46661f76..86877a10 100644 --- a/sockjs/transports/htmlfile.py +++ b/sockjs/transports/htmlfile.py @@ -27,6 +27,7 @@ class HTMLFileTransport(StreamingTransport): + name = "htmlfile" create_session = True check_callback = re.compile(r"^[a-zA-Z0-9_\.]+$") diff --git a/sockjs/transports/jsonp.py b/sockjs/transports/jsonp.py index 76ea6a68..6e75817c 100644 --- a/sockjs/transports/jsonp.py +++ b/sockjs/transports/jsonp.py @@ -11,6 +11,7 @@ class JSONPolling(StreamingTransport): + name = "jsonp-polling" create_session = True maxsize = 0 check_callback = re.compile(r"^[a-zA-Z0-9_\.]+$") @@ -81,4 +82,5 @@ async def process(self): class JSONPollingSend(JSONPolling): + name = "jsonp-polling" create_session = False diff --git a/sockjs/transports/rawwebsocket.py b/sockjs/transports/rawwebsocket.py index 140ac3a3..53487441 100644 --- a/sockjs/transports/rawwebsocket.py +++ b/sockjs/transports/rawwebsocket.py @@ -15,6 +15,7 @@ class RawWebSocketTransport(Transport): + name = "websocket-raw" heartbeat_timeout = 10 @classmethod diff --git a/sockjs/transports/websocket.py b/sockjs/transports/websocket.py index fb2af2bb..b40fe0c3 100644 --- a/sockjs/transports/websocket.py +++ b/sockjs/transports/websocket.py @@ -20,6 +20,7 @@ class WebSocketTransport(Transport): + name = "websocket" heartbeat_timeout = 10 @classmethod diff --git a/sockjs/transports/xhr.py b/sockjs/transports/xhr.py deleted file mode 100644 index 2047e07f..00000000 --- a/sockjs/transports/xhr.py +++ /dev/null @@ -1,40 +0,0 @@ -from aiohttp import hdrs, web -from multidict import MultiDict - -from .base import StreamingTransport -from .utils import CACHE_CONTROL, cache_headers, session_cookie - - -class XHRTransport(StreamingTransport): - """Long polling derivative transports, - used for XHRPolling and JSONPolling.""" - - create_session = True - maxsize = 0 - - async def _send(self, text: str): - return await super()._send(text + "\n") - - async def process(self): - request = self.request - - if request.method == hdrs.METH_OPTIONS: - headers = ( - (hdrs.CONTENT_TYPE, "application/javascript; charset=UTF-8"), - (hdrs.ACCESS_CONTROL_ALLOW_METHODS, "OPTIONS, POST"), - ) - headers += session_cookie(request) - headers += cache_headers() - return web.Response(status=204, headers=MultiDict(headers)) - - headers = ( - (hdrs.CONTENT_TYPE, "application/javascript; charset=UTF-8"), - (hdrs.CACHE_CONTROL, CACHE_CONTROL), - ) - headers += session_cookie(request) - - resp = self.response = web.StreamResponse(headers=MultiDict(headers)) - await resp.prepare(request) - - await self.handle_session() - return resp diff --git a/sockjs/transports/xhrsend.py b/sockjs/transports/xhr_pooling.py similarity index 53% rename from sockjs/transports/xhrsend.py rename to sockjs/transports/xhr_pooling.py index db494617..23895457 100644 --- a/sockjs/transports/xhrsend.py +++ b/sockjs/transports/xhr_pooling.py @@ -1,12 +1,48 @@ from aiohttp import hdrs, web from multidict import MultiDict -from .base import Transport +from .base import StreamingTransport, Transport from .utils import CACHE_CONTROL, cache_headers, session_cookie -from ..protocol import ENCODING, loads +from ..protocol import loads, ENCODING + + +class XHRTransport(StreamingTransport): + """Long polling derivative transports, + used for XHRPolling and JSONPolling.""" + name = "xhr-polling" + create_session = True + maxsize = 0 + + async def _send(self, text: str): + return await super()._send(text + "\n") + + async def process(self): + request = self.request + + if request.method == hdrs.METH_OPTIONS: + headers = ( + (hdrs.CONTENT_TYPE, "application/javascript; charset=UTF-8"), + (hdrs.ACCESS_CONTROL_ALLOW_METHODS, "OPTIONS, POST"), + ) + headers += session_cookie(request) + headers += cache_headers() + return web.Response(status=204, headers=MultiDict(headers)) + + headers = ( + (hdrs.CONTENT_TYPE, "application/javascript; charset=UTF-8"), + (hdrs.CACHE_CONTROL, CACHE_CONTROL), + ) + headers += session_cookie(request) + + resp = self.response = web.StreamResponse(headers=MultiDict(headers)) + await resp.prepare(request) + + await self.handle_session() + return resp class XHRSendTransport(Transport): + name = "xhr-polling" create_session = False async def process(self): diff --git a/sockjs/transports/xhrstreaming.py b/sockjs/transports/xhrstreaming.py index 1e89009e..f2fcd943 100644 --- a/sockjs/transports/xhrstreaming.py +++ b/sockjs/transports/xhrstreaming.py @@ -6,6 +6,7 @@ class XHRStreamingTransport(StreamingTransport): + name = "xhr-streaming" create_session = True open_seq = b"h" * 2048 + b"\n" diff --git a/tests/test_route.py b/tests/test_route.py index b05439cb..911c5898 100644 --- a/tests/test_route.py +++ b/tests/test_route.py @@ -21,6 +21,15 @@ async def test_info(make_route, make_request): assert info["websocket"] assert info["cookie_needed"] + assert info["transports"] == [ + "eventsource", + "htmlfile", + "jsonp-polling", + "websocket", + "websocket-raw", + "xhr-polling", + "xhr-streaming", + ] async def test_info_entropy(make_route, make_request): @@ -162,6 +171,8 @@ async def test_fail_transport(make_route, make_request): params = [] class FakeTransport(Transport): + name = "test" + def __init__(self, manager, session, request): super().__init__(manager, session, request) params.append((manager, session, request)) @@ -182,6 +193,7 @@ async def test_release_session_for_failed_transport(make_route, make_request): ) class FakeTransport(Transport): + name = "test" create_session = True async def process(self): diff --git a/tests/test_transport_xhr.py b/tests/test_transport_xhr.py index d52433cf..8971c48e 100644 --- a/tests/test_transport_xhr.py +++ b/tests/test_transport_xhr.py @@ -1,6 +1,6 @@ import pytest -from sockjs.transports import xhr +from sockjs.transports import xhr_pooling @pytest.fixture @@ -11,7 +11,7 @@ def maker(method="GET", path="/", query_params=None): request = make_request(method, path, query_params=query_params) request.app.freeze() session = manager.get("TestSessionXhr", create=True) - return xhr.XHRTransport(manager, session, request) + return xhr_pooling.XHRTransport(manager, session, request) return maker diff --git a/tests/test_transport_xhrsend.py b/tests/test_transport_xhrsend.py index d7b550de..56574bc1 100644 --- a/tests/test_transport_xhrsend.py +++ b/tests/test_transport_xhrsend.py @@ -1,7 +1,7 @@ import pytest from aiohttp import web -from sockjs.transports import xhrsend +from sockjs.transports import xhr_pooling @pytest.fixture @@ -12,7 +12,7 @@ def maker(method="GET", path="/", query_params=None): request = make_request(method, path, query_params=query_params) request.app.freeze() session = manager.get("TestSessionXhrSend", create=True) - return xhrsend.XHRSendTransport(manager, session, request) + return xhr_pooling.XHRSendTransport(manager, session, request) return maker From 8516f0226d96be65c61b20e11e16155243710049 Mon Sep 17 00:00:00 2001 From: Kirill Kuzminykh Date: Fri, 7 Jul 2023 21:17:56 +0300 Subject: [PATCH 17/23] - Changed arguments of handler function. Now handler function must be defined like ``async def handler(manager, session, msg):`` - Constants: - FRAME_OPEN - FRAME_CLOSE - FRAME_MESSAGE - FRAME_MESSAGE_BLOB - FRAME_HEARTBEAT replaced by ``Frame`` enums with corresponding values. - Constants: - MSG_OPEN - MSG_MESSAGE - MSG_CLOSE - MSG_CLOSED replaced by ``MsgType`` enums with corresponding values. - Constants: - STATE_NEW - STATE_OPEN - STATE_CLOSING - STATE_CLOSED replaced by ``SessionState`` enums with corresponding values. --- CHANGES.rst | 28 ++ examples/chat.py | 21 +- sockjs-testsrv.py | 34 +-- sockjs/__init__.py | 24 +- sockjs/protocol.py | 58 +++-- sockjs/route.py | 30 ++- sockjs/session.py | 306 +++++++++++----------- sockjs/transports/base.py | 50 ++-- sockjs/transports/htmlfile.py | 4 +- sockjs/transports/jsonp.py | 7 +- sockjs/transports/rawwebsocket.py | 28 +- sockjs/transports/websocket.py | 32 +-- sockjs/transports/xhr_pooling.py | 2 +- tests/conftest.py | 16 +- tests/test_route.py | 2 +- tests/test_session.py | 372 ++++++++++++++------------- tests/test_transport.py | 16 +- tests/test_transport_htmlfile.py | 8 +- tests/test_transport_jsonp.py | 14 +- tests/test_transport_rawwebsocket.py | 12 +- tests/test_transport_websocket.py | 18 +- tests/test_transport_xhrsend.py | 4 +- 22 files changed, 566 insertions(+), 520 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 275bcf54..bb3e0fe7 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -17,6 +17,34 @@ CHANGES - Methods ``_acquire`` and ``_release`` of ``Sessions`` renamed into ``acquire`` and ``release``. - Added processing of ``ConnectionError`` in ``StreamingTransport``. +- Changed arguments of handler function. Now handler function must be defined + like ``async def handler(manager, session, msg):`` +- Constants: + + - FRAME_OPEN + - FRAME_CLOSE + - FRAME_MESSAGE + - FRAME_MESSAGE_BLOB + - FRAME_HEARTBEAT + + replaced by ``Frame`` enums with corresponding values. +- Constants: + + - MSG_OPEN + - MSG_MESSAGE + - MSG_CLOSE + - MSG_CLOSED + + replaced by ``MsgType`` enums with corresponding values. +- Constants: + + - STATE_NEW + - STATE_OPEN + - STATE_CLOSING + - STATE_CLOSED + + replaced by ``SessionState`` enums with corresponding values. + 0.12.0 (2022-02-08) ------------------- diff --git a/examples/chat.py b/examples/chat.py index ee5c660b..e98f0cab 100644 --- a/examples/chat.py +++ b/examples/chat.py @@ -5,21 +5,24 @@ from aiohttp import web import sockjs +from sockjs import SessionManager, Session, SockjsMessage, MsgType CHAT_FILE = open( os.path.join(os.path.dirname(__file__), 'chat.html'), 'rb').read() -async def chat_msg_handler(msg, session): - if session.manager is None: - return - if msg.type == sockjs.MSG_OPEN: - session.manager.broadcast("Someone joined.") - elif msg.type == sockjs.MSG_MESSAGE: - session.manager.broadcast(msg.data) - elif msg.type == sockjs.MSG_CLOSED: - session.manager.broadcast("Someone left.") +async def chat_msg_handler( + manager: SessionManager, + session: Session, + msg: SockjsMessage, +): + if msg.type == MsgType.OPEN: + manager.broadcast("Someone joined.") + elif msg.type == MsgType.MESSAGE: + manager.broadcast(msg.data) + elif msg.type == MsgType.CLOSED: + manager.broadcast("Someone left.") def index(request): diff --git a/sockjs-testsrv.py b/sockjs-testsrv.py index 881c4fce..10541285 100644 --- a/sockjs-testsrv.py +++ b/sockjs-testsrv.py @@ -1,5 +1,5 @@ -import asyncio import logging + from aiohttp import web import sockjs @@ -8,24 +8,23 @@ from sockjs.transports.xhrstreaming import XHRStreamingTransport -async def echoSession(msg, session): - if msg.type == sockjs.MSG_MESSAGE: +async def echo_session(manager, session, msg): + if msg.type == sockjs.MsgType.MESSAGE: session.send(msg.data) -async def closeSessionHander(msg, session): - if msg.type == sockjs.MSG_OPEN: +async def close_session_handler(manager, session, msg): + if msg.type == sockjs.MsgType.OPEN: session.close() -async def broadcastSession(msg, session): - if msg.type == sockjs.MSG_OPEN: - session.manager.broadcast(msg.data) +async def broadcast_session(manager, session, msg): + if msg.type == sockjs.MsgType.OPEN: + manager.broadcast(msg.data) if __name__ == '__main__': """ Sockjs tests server """ - loop = asyncio.get_event_loop() logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s') @@ -36,16 +35,21 @@ async def broadcastSession(msg, session): app = web.Application() sockjs.add_endpoint( - app, echoSession, name='echo', prefix='/echo') + app, echo_session, name='echo', prefix='/echo' + ) sockjs.add_endpoint( - app, closeSessionHander, name='close', prefix='/close') + app, close_session_handler, name='close', prefix='/close' + ) sockjs.add_endpoint( - app, broadcastSession, name='broadcast', prefix='/broadcast') + app, broadcast_session, name='broadcast', prefix='/broadcast' + ) sockjs.add_endpoint( - app, echoSession, name='wsoff', prefix='/disabled_websocket_echo', + app, echo_session, name='wsoff', prefix='/disabled_websocket_echo', disable_transports=('websocket',)) + sockjs.add_endpoint( - app, echoSession, name='cookie', prefix='/cookie_needed_echo', - cookie_needed=True) + app, echo_session, name='cookie', prefix='/cookie_needed_echo', + cookie_needed=True + ) web.run_app(app, port=8081) diff --git a/sockjs/__init__.py b/sockjs/__init__.py index 31ab5d7e..b37cec96 100644 --- a/sockjs/__init__.py +++ b/sockjs/__init__.py @@ -1,21 +1,11 @@ from .exceptions import SessionIsAcquired, SessionIsClosed -from .protocol import ( - MSG_CLOSE, - MSG_CLOSED, - MSG_MESSAGE, - MSG_OPEN, - STATE_CLOSED, - STATE_CLOSING, - STATE_NEW, - STATE_OPEN, -) +from .protocol import SessionState, MsgType, Frame, SockjsMessage from .route import add_endpoint, get_manager from .session import Session, SessionManager __version__ = "0.13.0" - __all__ = ( "get_manager", "add_endpoint", @@ -23,12 +13,8 @@ "SessionManager", "SessionIsClosed", "SessionIsAcquired", - "STATE_NEW", - "STATE_OPEN", - "STATE_CLOSING", - "STATE_CLOSED", - "MSG_OPEN", - "MSG_MESSAGE", - "MSG_CLOSE", - "MSG_CLOSED", + "SessionState", + "MsgType", + "Frame", + "SockjsMessage", ) diff --git a/sockjs/protocol.py b/sockjs/protocol.py index b56870f2..7ec5fa9d 100644 --- a/sockjs/protocol.py +++ b/sockjs/protocol.py @@ -1,15 +1,19 @@ import dataclasses +import enum import hashlib from datetime import datetime -from typing import Optional +from typing import Union ENCODING = "utf-8" -STATE_NEW = 0 -STATE_OPEN = 1 -STATE_CLOSING = 2 -STATE_CLOSED = 3 + +@enum.unique +class SessionState(enum.Enum): + NEW = 0 + OPEN = 1 + CLOSING = 2 + CLOSED = 3 _days = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"] @@ -35,9 +39,9 @@ try: import ujson as json + kwargs = {} # pragma: no cover except ImportError: # pragma: no cover - def dthandler(obj): if isinstance(obj, datetime): now = obj.timetuple() @@ -51,6 +55,7 @@ def dthandler(obj): now[5], ) + kwargs = {"default": dthandler, "separators": (",", ":")} # Faster @@ -64,11 +69,13 @@ def dthandler(obj): # Frames # ------ -FRAME_OPEN = "o" -FRAME_CLOSE = "c" -FRAME_MESSAGE = "a" -FRAME_MESSAGE_BLOB = "a1" -FRAME_HEARTBEAT = "h" +@enum.unique +class Frame(enum.Enum): + OPEN = "o" + CLOSE = "c" + MESSAGE = "a" + MESSAGE_BLOB = "a1" + HEARTBEAT = "h" # ------------------ @@ -91,7 +98,6 @@ def dthandler(obj): """.strip() - IFRAME_MD5 = hashlib.md5(IFRAME_HTML.encode()).hexdigest() loads = json.loads @@ -103,36 +109,38 @@ def dumps(data): def close_frame(code, reason): - return FRAME_CLOSE + json.dumps([code, reason], **kwargs) + return Frame.CLOSE.value + json.dumps([code, reason], **kwargs) def message_frame(message): - return FRAME_MESSAGE + json.dumps([message], **kwargs) + return Frame.MESSAGE.value + json.dumps([message], **kwargs) def messages_frame(messages): - return FRAME_MESSAGE + json.dumps(messages, **kwargs) + return Frame.MESSAGE.value + json.dumps(messages, **kwargs) # Handler messages # --------------------- -MSG_OPEN = 1 -MSG_MESSAGE = 2 -MSG_CLOSE = 3 -MSG_CLOSED = 4 +@enum.unique +class MsgType(enum.Enum): + OPEN = 1 + MESSAGE = 2 + CLOSE = 3 + CLOSED = 4 @dataclasses.dataclass(frozen=True) class SockjsMessage: - type: int - data: Optional[str] + type: MsgType + data: Union[str, Exception, None] @property - def tp(self) -> int: + def tp(self) -> MsgType: return self.type -OpenMessage = SockjsMessage(MSG_OPEN, None) -CloseMessage = SockjsMessage(MSG_CLOSE, None) -ClosedMessage = SockjsMessage(MSG_CLOSED, None) +OPEN_MESSAGE = SockjsMessage(MsgType.OPEN, None) +CLOSE_MESSAGE = SockjsMessage(MsgType.CLOSE, None) +CLOSED_MESSAGE = SockjsMessage(MsgType.CLOSED, None) diff --git a/sockjs/route.py b/sockjs/route.py index 17ccf4f9..c4c2f6d0 100644 --- a/sockjs/route.py +++ b/sockjs/route.py @@ -15,7 +15,7 @@ CorsConfig = None from .protocol import IFRAME_HTML -from .session import SessionManager +from .session import SessionManager, HandlerType from .transports import transport_handlers from .transports.base import Transport from .transports.rawwebsocket import RawWebSocketTransport @@ -36,7 +36,7 @@ def _gen_endpoint_name(): def add_endpoint( app: web.Application, - handler, + handler: HandlerType, *, name="", prefix="/sockjs", @@ -57,8 +57,8 @@ def add_endpoint( ): sync_handler = handler - async def handler(msg, session): - return sync_handler(msg, session) + async def handler(m, s, msg): + return sync_handler(m, s, msg) router = app.router @@ -192,8 +192,8 @@ async def handler(self, request): # lookup transport t_id = info["transport"] - transport: Optional[Type[Transport]] = self.handlers.get(t_id) - if transport is None or transport.name in self.disable_transports: + transport_class: Optional[Type[Transport]] = self.handlers.get(t_id) + if transport_class is None or transport_class.name in self.disable_transports: raise web.HTTPNotFound() # session @@ -206,24 +206,26 @@ async def handler(self, request): raise web.HTTPNotFound() try: - session = transport.get_session(manager, sid) + session = transport_class.get_session(manager, sid) except KeyError: raise web.HTTPNotFound(headers=session_cookie(request)) - t = transport(manager, session, request) + transport = transport_class(manager, session, request) try: - return await t.process() + return await transport.process() except (asyncio.CancelledError, web.HTTPException, ConnectionError) as e: - await session._remote_close(exc=e) + if transport.create_session: + await manager.remote_close(session, exc=e) raise except Exception as e: log.exception("Exception in transport: %s" % t_id) - await session._remote_close(exc=e) + if transport.create_session: + await manager.remote_close(session, exc=e) raise web.HTTPInternalServerError() finally: - if manager.is_acquired(session): - await session._remote_close() - await session._remote_closed() + if transport.create_session and manager.is_acquired(session): + await manager.remote_close(session) + await manager.remote_closed(session) await manager.release(session) async def websocket(self, request): diff --git a/sockjs/session.py b/sockjs/session.py index 213e78a8..645375e2 100644 --- a/sockjs/session.py +++ b/sockjs/session.py @@ -1,35 +1,28 @@ import asyncio -import collections import logging import warnings +from collections import deque from datetime import datetime, timedelta -from typing import List, Optional, Tuple +from typing import Optional, Tuple, Callable, Awaitable, TypeVar, Union, Any from aiohttp import web +from . import SessionState from .exceptions import SessionIsAcquired, SessionIsClosed from .protocol import ( - ClosedMessage, - FRAME_CLOSE, - FRAME_HEARTBEAT, - FRAME_MESSAGE, - FRAME_MESSAGE_BLOB, - FRAME_OPEN, - MSG_CLOSE, - MSG_MESSAGE, - OpenMessage, - STATE_CLOSED, - STATE_CLOSING, - STATE_NEW, - STATE_OPEN, + CLOSED_MESSAGE, + MsgType, + OPEN_MESSAGE, SockjsMessage, close_frame, message_frame, messages_frame, + Frame, ) log = logging.getLogger("sockjs") +HandlerType = Callable[["SessionManager", "Session", SockjsMessage], Awaitable] class Session: @@ -42,25 +35,21 @@ class Session: ``acquired``: Acquired state, indicates that transport is using session """ - manager: Optional["SessionManager"] = None acquired = False - state = STATE_NEW + state = SessionState.NEW interrupted = False exception = None - app: Optional[web.Application] = None _hb_task = None def __init__( self, - session_id, - handler, + session_id: str, *, heartbeat_delay=25, disconnect_delay=5, debug=False ): self.id = session_id - self.handler = handler self.heartbeat_delay = heartbeat_delay self.disconnect_delay = disconnect_delay self.next_heartbeat = datetime.now() + timedelta(seconds=heartbeat_delay) @@ -74,14 +63,14 @@ def __init__( self._send_heartbeats = False self._debug = debug self._waiter = None - self._queue = collections.deque() + self._queue: deque[tuple[Frame, Any]] = deque() def __str__(self): result = ["id=%r" % (self.id,)] - if self.state == STATE_OPEN: + if self.state == SessionState.OPEN: result.append("connected") - elif self.state == STATE_CLOSED: + elif self.state == SessionState.CLOSED: result.append("closed") else: result.append("disconnected") @@ -112,42 +101,26 @@ def expired(self) -> bool: return self.expires <= datetime.now() return False - async def acquire( - self, - manager: "SessionManager", - request: web.Request, - ): + def acquire(self, request: web.Request) -> bool: + """Returns True if session has opened.""" self.acquired = True - self.manager = manager - self.app = manager.app self.request = request self.expires = None self._send_heartbeats = self.heartbeat_delay > 0 - self._tick() + self.tick() self._hits += 1 - if self.state == STATE_NEW: + if self.state == SessionState.NEW: log.debug("open session: %s", self.id) - self.state = STATE_OPEN - self._feed(FRAME_OPEN, FRAME_OPEN) - try: - await self.handler(OpenMessage, self) - except asyncio.CancelledError: - raise - except Exception as exc: - self.state = STATE_CLOSING - self.exception = exc - self.interrupted = True - self._feed(FRAME_CLOSE, (3000, "Internal error")) - log.exception("Exception in open session handling.") + self.state = SessionState.OPEN + self.feed(Frame.OPEN, Frame.OPEN.value) + return True - if self._hb_task is None and self._send_heartbeats: - self._hb_task = asyncio.create_task(self._heartbeat_task()) + return False def release(self): self.acquired = False - self.manager = None self.request = None self._send_heartbeats = False if self._hb_task is not None: @@ -157,7 +130,11 @@ def release(self): pass # an event loop already stopped self._hb_task = None - def _tick(self, timeout=None): + def create_heartbeat_task(self): + if self._hb_task is None and self._send_heartbeats: + self._hb_task = asyncio.create_task(self._heartbeat_task()) + + def tick(self, timeout=None): if timeout is None: self.next_heartbeat = datetime.now() + timedelta( seconds=self.heartbeat_delay @@ -165,90 +142,55 @@ def _tick(self, timeout=None): else: self.next_heartbeat = datetime.now() + timedelta(seconds=timeout) - def _heartbeat(self): + def heartbeat(self): if self._send_heartbeats: - self._feed(FRAME_HEARTBEAT, FRAME_HEARTBEAT) + self.feed(Frame.HEARTBEAT, Frame.HEARTBEAT.value) self._heartbeats += 1 async def _heartbeat_task(self): while True: now = datetime.now() if self.next_heartbeat <= now: - self._heartbeat() - self._tick() + self.heartbeat() + self.tick() delta = (self.next_heartbeat - now).total_seconds() if delta > 0: await asyncio.sleep(delta) - def _feed(self, frame, data): + def feed(self, frame: Frame, data): # pack messages - if frame == FRAME_MESSAGE: - if self._queue and self._queue[-1][0] == FRAME_MESSAGE: + if frame == Frame.MESSAGE: + if self._queue and self._queue[-1][0] == Frame.MESSAGE: self._queue[-1][1].append(data) else: self._queue.append((frame, [data])) else: self._queue.append((frame, data)) - # notify waiter - waiter = self._waiter - if waiter is not None: - self._waiter = None - if not waiter.cancelled(): - waiter.set_result(True) - self._tick() + self.release_waiters() + self.tick() - async def _get_frame(self, pack=True) -> Tuple[str, str]: - if not self._queue and self.state != STATE_CLOSED: + async def get_frame(self, pack=True) -> Tuple[Frame, str]: + if not self._queue and self.state != SessionState.CLOSED: assert not self._waiter self._waiter = asyncio.Future() await self._waiter if self._queue: frame, payload = self._queue.popleft() - self._tick() + self.tick() if pack: - if frame == FRAME_CLOSE: - return FRAME_CLOSE, close_frame(*payload) - elif frame == FRAME_MESSAGE: - return FRAME_MESSAGE, messages_frame(payload) + match frame: + case Frame.CLOSE: + return frame, close_frame(*payload) + case Frame.MESSAGE: + return frame, messages_frame(payload) return frame, payload else: raise SessionIsClosed() - async def _remote_close(self, exc=None): - """Close session from remote.""" - if self.state in (STATE_CLOSING, STATE_CLOSED): - return - - log.info("close session: %s", self.id) - self._tick() - self.state = STATE_CLOSING - if exc is not None: - self.exception = exc - self.interrupted = True - try: - await self.handler(SockjsMessage(MSG_CLOSE, exc), self) - except Exception: - log.exception("Exception in close handler.") - - async def _remote_closed(self): - if self.state == STATE_CLOSED: - return - - if self.disconnect_delay and not self.expired: - self.expire() - return - - log.info("session closed: %s", self.id) - self.state = STATE_CLOSED - self.expire() - try: - await self.handler(ClosedMessage, self) - except Exception: - log.exception("Exception in closed handler.") - + def release_waiters(self): # notify waiter waiter = self._waiter if waiter is not None: @@ -256,25 +198,6 @@ async def _remote_closed(self): if not waiter.cancelled(): waiter.set_result(True) - async def _remote_message(self, msg): - log.debug("incoming message: %s, %s", self.id, msg[:200]) - self._tick() - - try: - await self.handler(SockjsMessage(MSG_MESSAGE, msg), self) - except Exception: - log.exception("Exception in message handler.") - - async def _remote_messages(self, messages): - self._tick() - - for msg in messages: - log.debug("incoming message: %s, %s", self.id, msg[:200]) - try: - await self.handler(SockjsMessage(MSG_MESSAGE, msg), self) - except Exception: - log.exception("Exception in message handler.") - def send(self, msg: str) -> bool: """send message to client.""" assert isinstance(msg, str), "String is required" @@ -282,10 +205,10 @@ def send(self, msg: str) -> bool: if self._debug: log.info("outgoing message: %s, %s", self.id, str(msg)[:200]) - if self.state != STATE_OPEN: + if self.state != SessionState.OPEN: return False - self._feed(FRAME_MESSAGE, msg) + self.feed(Frame.MESSAGE, msg) return True def send_frame(self, frm): @@ -293,27 +216,27 @@ def send_frame(self, frm): if self._debug: log.info("outgoing message: %s, %s", self.id, frm[:200]) - if self.state != STATE_OPEN: + if self.state != SessionState.OPEN: return - self._feed(FRAME_MESSAGE_BLOB, frm) + self.feed(Frame.MESSAGE_BLOB, frm) def close(self, code=3000, reason="Go away!"): """close session""" - if self.state in (STATE_CLOSING, STATE_CLOSED): + if self.state in (SessionState.CLOSING, SessionState.CLOSED): return if self._debug: log.debug("close session: %s", self.id) - self.state = STATE_CLOSING - self._feed(FRAME_CLOSE, (code, reason)) + self.state = SessionState.CLOSING + self.feed(Frame.CLOSE, (code, reason)) _marker = object() -class SessionManager(dict): +class SessionManager: """A basic session manager.""" _gc_task = None @@ -322,19 +245,18 @@ def __init__( self, name: str, app: web.Application, - handler, + handler: HandlerType, heartbeat_delay=25, disconnect_delay=5, debug=False, ): - super().__init__() self.name = name self.route_name = "sockjs-url-%s" % name self.app = app self.handler = handler self.factory = Session self.acquired = {} - self.sessions: List[Session] = [] + self.sessions: dict[str, Session] = {} self.heartbeat_delay = heartbeat_delay self.disconnect_delay = disconnect_delay self.debug = debug @@ -362,10 +284,10 @@ async def _check_expiration(self, session: Session): # Session is to be GC'd immediately if session.id in self.acquired: await self.release(session) - if session.state == STATE_OPEN: - await session._remote_close() - if session.state == STATE_CLOSING: - await session._remote_closed() + if session.state == SessionState.OPEN: + await self.remote_close(session) + if session.state == SessionState.CLOSING: + await self.remote_closed(session) return session.id async def _gc_sessions_task(self): @@ -377,7 +299,10 @@ async def _gc_sessions_task(self): async def _gc_expired_sessions(self): sessions = self.sessions if sessions: - tasks = [self._check_expiration(session) for session in sessions] + tasks = [ + self._check_expiration(session) + for session in sessions.values() + ] expired_session_ids = await asyncio.gather(*tasks) idx = 0 @@ -385,33 +310,29 @@ async def _gc_expired_sessions(self): if session_id is None: idx += 1 continue - del self[session_id] - del sessions[idx] + sessions.pop(session_id, None) def _add(self, session: Session): if session.expired: raise ValueError("Can not add expired session") - session.manager = self - session.app = self.app - - self[session.id] = session - self.sessions.append(session) + self.sessions[session.id] = session return session + _T = TypeVar('_T') + def get( self, session_id, create=False, - default=_marker, - ) -> Session: - session = super().get(session_id, None) + default: _T = _marker, + ) -> Union[Session, _T]: + session = self.sessions.get(session_id, None) if session is None: if create: session = self._add( self.factory( session_id, - self.handler, heartbeat_delay=self.heartbeat_delay, disconnect_delay=self.disconnect_delay, debug=self.debug, @@ -429,10 +350,22 @@ async def acquire(self, session: Session, request: web.Request): if sid in self.acquired: raise SessionIsAcquired("Another connection still open") - if sid not in self: + if sid not in self.sessions: raise KeyError("Unknown session") - await session.acquire(self, request) + if session.acquire(request): + try: + await self.handler(self, session, OPEN_MESSAGE) + except asyncio.CancelledError: + raise + except Exception as exc: + session.state = SessionState.CLOSING + session.exception = exc + session.interrupted = True + session.feed(Frame.CLOSE, (3000, "Internal error")) + log.exception("Exception in open session handling.") + + session.create_heartbeat_task() self.acquired[sid] = True return session @@ -446,25 +379,24 @@ async def release(self, s: Session): del self.acquired[s.id] def active_sessions(self): - for session in list(self.values()): + for session in list(self.sessions.values()): if not session.expired: yield session async def clear(self): """Manually expire all sessions in the pool.""" - for session in list(self.values()): - if session.state != STATE_CLOSED: + for session in list(self.sessions.values()): + if session.state != SessionState.CLOSED: session.disconnect_delay = 0 - await session._remote_closed() - + await self.remote_closed(session) self.sessions.clear() - super().clear() - def broadcast(self, message): + def broadcast(self, message, exclude_session_ids: Optional[set] = None): blob = message_frame(message) + exclude_session_ids = exclude_session_ids or set() - for session in list(self.values()): - if not session.expired: + for session in self.sessions.values(): + if not session.expired and session.id not in exclude_session_ids: session.send_frame(blob) def __del__(self): @@ -473,3 +405,59 @@ def __del__(self): "Please call `await SessionManager.stop()` before del", RuntimeWarning, ) + + async def remote_message(self, session: Session, msg): + """Call handler with message received from client.""" + log.debug("incoming message: %s, %s", session.id, msg[:200]) + session.tick() + + try: + await self.handler(self, session, SockjsMessage(MsgType.MESSAGE, msg)) + except Exception: + log.exception("Exception in message handler.") + + async def remote_messages(self, session: Session, messages): + """Call handler for all messages received from client.""" + session.tick() + + for msg in messages: + log.debug("incoming message: %s, %s", session.id, msg[:200]) + try: + await self.handler(self, session, SockjsMessage(MsgType.MESSAGE, msg)) + except Exception: + log.exception("Exception in message handler.") + + async def remote_close(self, session: Session, exc=None): + """Start session closing.""" + if session.state in (SessionState.CLOSING, SessionState.CLOSED): + return + + log.info("close session: %s", session.id) + session.tick() + session.state = SessionState.CLOSING + if exc is not None: + session.exception = exc + session.interrupted = True + try: + await self.handler(self, session, SockjsMessage(MsgType.CLOSE, exc)) + except Exception: + log.exception("Exception in close handler.") + + async def remote_closed(self, session: Session): + """Close session.""" + if session.state == SessionState.CLOSED: + return + + if session.disconnect_delay and not session.expired: + session.expire() + return + + log.info("session closed: %s", session.id) + session.state = SessionState.CLOSED + session.expire() + try: + await self.handler(self, session, CLOSED_MESSAGE) + except Exception: + log.exception("Exception in closed handler.") + + session.release_waiters() diff --git a/sockjs/transports/base.py b/sockjs/transports/base.py index 53307d52..f836c1ef 100644 --- a/sockjs/transports/base.py +++ b/sockjs/transports/base.py @@ -1,21 +1,23 @@ import abc import asyncio -import aiohttp from aiohttp import web +from aiohttp.web_exceptions import HTTPClientError from ..exceptions import SessionIsAcquired, SessionIsClosed from ..protocol import ( ENCODING, - FRAME_CLOSE, - FRAME_MESSAGE, - STATE_CLOSED, - STATE_CLOSING, close_frame, + SessionState, + Frame, ) from ..session import Session, SessionManager +class HTTPClientClosedConnection(HTTPClientError): + status_code = 499 + + class Transport(abc.ABC): name: str create_session = True @@ -24,7 +26,12 @@ class Transport(abc.ABC): def get_session(cls, manager: SessionManager, session_id: str) -> Session: return manager.get(session_id, create=cls.create_session) - def __init__(self, manager: SessionManager, session: Session, request: web.Request): + def __init__( + self, + manager: SessionManager, + session: Session, + request: web.Request, + ): self.manager = manager self.session = session self.request = request @@ -44,10 +51,13 @@ def __init__(self, manager: SessionManager, session: Session, request: web.Reque self.response = None async def _send(self, text: str): - blob = text.encode(ENCODING) - await self.response.write(blob) - self.size += len(blob) - return self.size > self.maxsize + try: + blob = text.encode(ENCODING) + await self.response.write(blob) + self.size += len(blob) + return self.size > self.maxsize + except ConnectionResetError as e: + raise HTTPClientClosedConnection() from e async def handle_session(self): assert self.response is not None, "Response is not specified." @@ -58,8 +68,8 @@ async def handle_session(self): return # session is closing or closed - if self.session.state in (STATE_CLOSING, STATE_CLOSED): - await self.session._remote_closed() + if self.session.state in (SessionState.CLOSING, SessionState.CLOSED): + await self.manager.remote_closed(self.session) await self._send(close_frame(3000, "Go away!")) return @@ -75,25 +85,25 @@ async def handle_session(self): if self.timeout: try: frame, text = await asyncio.wait_for( - self.session._get_frame(), + self.session.get_frame(), timeout=self.timeout, ) except asyncio.futures.TimeoutError: - frame, text = FRAME_MESSAGE, "a[]" + frame, text = Frame.MESSAGE, "a[]" else: - frame, text = await self.session._get_frame() + frame, text = await self.session.get_frame() - if frame == FRAME_CLOSE: - await self.session._remote_closed() + if frame == Frame.CLOSE: + await self.manager.remote_closed(self.session) await self._send(text) break stop = await self._send(text) if stop: break - except (asyncio.CancelledError, ConnectionError): - await self.session._remote_close(exc=aiohttp.ClientConnectionError) - await self.session._remote_closed() + except (asyncio.CancelledError, ConnectionError) as e: + await self.manager.remote_close(self.session, exc=e) + await self.manager.remote_closed(self.session) raise except SessionIsClosed: pass diff --git a/sockjs/transports/htmlfile.py b/sockjs/transports/htmlfile.py index 86877a10..ef193658 100644 --- a/sockjs/transports/htmlfile.py +++ b/sockjs/transports/htmlfile.py @@ -40,11 +40,11 @@ async def process(self): callback = request.query.get("c") if callback is None: - await self.session._remote_closed() + await self.manager.remote_closed(self.session) raise web.HTTPInternalServerError(text='"callback" parameter required') elif not self.check_callback.match(callback): - await self.session._remote_closed() + await self.manager.remote_closed(self.session) raise web.HTTPInternalServerError(text='invalid "callback" parameter') headers = ( diff --git a/sockjs/transports/jsonp.py b/sockjs/transports/jsonp.py index 6e75817c..fd2389c6 100644 --- a/sockjs/transports/jsonp.py +++ b/sockjs/transports/jsonp.py @@ -22,6 +22,7 @@ async def _send(self, text: str): return await super()._send(text) async def process(self): + manager = self.manager session = self.session request = self.request meth = request.method @@ -29,11 +30,11 @@ async def process(self): if request.method == hdrs.METH_GET: callback = self.callback = request.query.get("c") if not callback: - await self.session._remote_closed() + await self.manager.remote_closed(self.session) raise web.HTTPInternalServerError(text='"callback" parameter required') elif not self.check_callback.match(callback): - await self.session._remote_closed() + await self.manager.remote_closed(self.session) raise web.HTTPInternalServerError(text='invalid "callback" parameter') headers = ( @@ -68,7 +69,7 @@ async def process(self): except Exception: raise web.HTTPInternalServerError(text="Broken JSON encoding.") - await session._remote_messages(messages) + await manager.remote_messages(session, messages) headers = ( (hdrs.CONTENT_TYPE, "text/plain;charset=UTF-8"), diff --git a/sockjs/transports/rawwebsocket.py b/sockjs/transports/rawwebsocket.py index 53487441..f63ea44a 100644 --- a/sockjs/transports/rawwebsocket.py +++ b/sockjs/transports/rawwebsocket.py @@ -10,7 +10,7 @@ from .base import Transport from .utils import cancel_tasks from ..exceptions import SessionIsClosed -from ..protocol import FRAME_CLOSE, FRAME_HEARTBEAT, FRAME_MESSAGE, FRAME_MESSAGE_BLOB +from ..protocol import Frame from ..session import Session, SessionManager @@ -28,7 +28,7 @@ def get_session(cls, manager: SessionManager, session_id: str) -> Session: # Generate unique session_id based on given ID. orig_session_id = session_id - while session_id in manager: + while session_id in manager.sessions: session_id = "%s-%s" % (orig_session_id, uuid4().hex[-8:]) return super().get_session(manager, session_id) @@ -40,28 +40,28 @@ def __init__(self, manager: SessionManager, session: Session, request: web.Reque async def server(self, ws: web.WebSocketResponse): while True: try: - frame, data = await self.session._get_frame(pack=False) + frame, data = await self.session.get_frame(pack=False) except SessionIsClosed: break - if frame == FRAME_MESSAGE: + if frame == Frame.MESSAGE: for text in data: await ws.send_str(text) - elif frame == FRAME_MESSAGE_BLOB: + elif frame == Frame.MESSAGE_BLOB: data = data[1:] if data.startswith("["): data = data[1:-1] await ws.send_str(data) - elif frame == FRAME_HEARTBEAT: + elif frame == Frame.HEARTBEAT: await ws.ping() if self._wait_pong_task is None: self._wait_pong_task = asyncio.create_task(self._wait_pong()) self._wait_pong_task.add_done_callback(self._wait_done_callback) - elif frame == FRAME_CLOSE: + elif frame == Frame.CLOSE: try: await ws.close(message=b"Go away!") finally: - await self.session._remote_closed() + await self.manager.remote_closed(self.session) async def _wait_pong(self): try: @@ -84,17 +84,17 @@ async def client(self, ws: web.WebSocketResponse): if msg.type == web.WSMsgType.text: if not msg.data: continue - await self.session._remote_message(msg.data) + await self.manager.remote_message(self.session, msg.data) elif msg.type == web.WSMsgType.close: - await self.session._remote_close() + await self.manager.remote_close(self.session) elif msg.type in (web.WSMsgType.closed, web.WSMsgType.closing): - await self.session._remote_closed() + await self.manager.remote_closed(self.session) break elif msg.type == web.WSMsgType.PONG: - self.session._tick() + self.session.tick() elif msg.type == web.WSMsgType.PING: await ws.pong(msg.data) - self.session._tick() + self.session.tick() async def process(self): # start websocket connection @@ -117,7 +117,7 @@ async def process(self): except asyncio.CancelledError: raise except Exception as exc: - await self.session._remote_close(exc) + await self.manager.remote_close(self.session, exc) finally: self.session.expire() await self.manager.release(self.session) diff --git a/sockjs/transports/websocket.py b/sockjs/transports/websocket.py index b40fe0c3..fd0ab4cd 100644 --- a/sockjs/transports/websocket.py +++ b/sockjs/transports/websocket.py @@ -12,7 +12,7 @@ from .base import Transport from .utils import cancel_tasks from ..exceptions import SessionIsClosed -from ..protocol import FRAME_CLOSE, FRAME_HEARTBEAT, STATE_CLOSED, close_frame, loads +from ..protocol import SessionState, Frame, close_frame, loads from ..session import Session, SessionManager @@ -33,7 +33,7 @@ def get_session(cls, manager: SessionManager, session_id: str) -> Session: # Generate unique session_id based on given ID. orig_session_id = session_id - while session_id in manager: + while session_id in manager.sessions: session_id = "%s-%s" % (orig_session_id, uuid4().hex[-8:]) return super().get_session(manager, session_id) @@ -45,11 +45,11 @@ def __init__(self, manager: SessionManager, session: Session, request: web.Reque async def server(self, ws: web.WebSocketResponse): while True: try: - frame, data = await self.session._get_frame() + frame, data = await self.session.get_frame() except SessionIsClosed: break - if frame == FRAME_HEARTBEAT: + if frame == Frame.HEARTBEAT: await ws.ping() log.debug("Send WS PING") if self._wait_pong_task is None: @@ -59,11 +59,11 @@ async def server(self, ws: web.WebSocketResponse): await ws.send_str(data) - if frame == FRAME_CLOSE: + if frame == Frame.CLOSE: try: await ws.close() finally: - await self.session._remote_closed() + await self.manager.remote_closed(self.session) async def _wait_pong(self): try: @@ -91,26 +91,26 @@ async def client(self, ws: web.WebSocketResponse): try: text = loads(data) except Exception as exc: - await self.session._remote_close(exc) - await self.session._remote_closed() + await self.manager.remote_close(self.session, exc) + await self.manager.remote_closed(self.session) await ws.close(message=b"broken json") break if data.startswith("["): - await self.session._remote_messages(text) + await self.manager.remote_messages(self.session, text) else: - await self.session._remote_message(text) + await self.manager.remote_message(self.session, text) elif msg.type == web.WSMsgType.PONG: log.debug("Received WS PONG") - self.session._tick() + self.session.tick() elif msg.type == web.WSMsgType.PING: log.debug("Received WS PING") await ws.pong(msg.data) - self.session._tick() + self.session.tick() elif msg.type == web.WSMsgType.close: - await self.session._remote_close() + await self.manager.remote_close(self.session) elif msg.type in (web.WSMsgType.closed, web.WSMsgType.closing): - await self.session._remote_closed() + await self.manager.remote_closed(self.session) break async def process(self): @@ -130,7 +130,7 @@ async def process(self): # session was interrupted if self.session.interrupted: await ws.send_str(close_frame(1002, "Connection interrupted")) - elif self.session.state == STATE_CLOSED: + elif self.session.state == SessionState.CLOSED: await ws.send_str(close_frame(3000, "Go away!")) else: try: @@ -148,7 +148,7 @@ async def process(self): except asyncio.CancelledError: raise except Exception as exc: - await self.session._remote_close(exc) + await self.manager.remote_close(self.session, exc) finally: self.session.expire() await self.manager.release(self.session) diff --git a/sockjs/transports/xhr_pooling.py b/sockjs/transports/xhr_pooling.py index 23895457..0ddd96f4 100644 --- a/sockjs/transports/xhr_pooling.py +++ b/sockjs/transports/xhr_pooling.py @@ -69,7 +69,7 @@ async def process(self): except Exception: raise web.HTTPInternalServerError(text="Broken JSON encoding.") - await self.session._remote_messages(messages) + await self.manager.remote_messages(self.session, messages) headers = ( (hdrs.CONTENT_TYPE, "text/plain; charset=UTF-8"), diff --git a/tests/conftest.py b/tests/conftest.py index 14ae2de7..f97bba11 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import asyncio +from typing import Optional from unittest import mock import aiohttp_cors @@ -41,7 +42,7 @@ def maker(result, exc=False): result = [] output = result - async def handler(msg, s): + async def handler(manager, s, msg): if exc: raise ValueError((msg, s)) output.append((msg, s)) @@ -93,10 +94,15 @@ def maker(method, path, query_params=None, headers=None, match_info=None): @pytest.fixture def make_session(make_handler, make_request): - def maker(name="test", disconnect_delay=10, handler=None, result=None): - if handler is None: - handler = make_handler(result) - return Session(name, handler, disconnect_delay=disconnect_delay, debug=True) + def maker( + name="test", + disconnect_delay=10, + manager: Optional[SessionManager] = None + ): + session = Session(name, disconnect_delay=disconnect_delay, debug=True) + if manager: + manager.sessions[session.id] = session + return session return maker diff --git a/tests/test_route.py b/tests/test_route.py index 911c5898..75b967c0 100644 --- a/tests/test_route.py +++ b/tests/test_route.py @@ -204,7 +204,7 @@ async def process(self): with pytest.raises(web.HTTPInternalServerError): await route.handler(request) - s1 = route.manager["s1"] + s1 = route.manager.sessions["s1"] assert not route.manager.is_acquired(s1) diff --git a/tests/test_session.py b/tests/test_session.py index 12895f88..f5db114e 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -6,26 +6,24 @@ import pytest from aiohttp import web -from sockjs import Session, SessionIsAcquired, SessionIsClosed, protocol +from sockjs import Session, SessionIsAcquired, SessionIsClosed, protocol, SessionState, Frame class TestSession: - async def test_ctor(self, mocker, make_handler, make_request): + async def test_ctor(self, mocker): dt = mocker.patch("sockjs.session.datetime") now = dt.now.return_value = datetime.now() - handler = make_handler([]) - session = Session("id", handler) - + session = Session("id") assert session.id == "id" assert not session.expired assert session.expires == now + timedelta(seconds=5) assert session._hits == 0 assert session._heartbeats == 0 - assert session.state == protocol.STATE_NEW + assert session.state == SessionState.NEW - session = Session("id", handler, disconnect_delay=15) + session = Session("id", disconnect_delay=15) assert session.id == "id" assert not session.expired @@ -33,22 +31,22 @@ async def test_ctor(self, mocker, make_handler, make_request): async def test_str(self, make_session): session = make_session("test") - session.state = protocol.STATE_OPEN + session.state = SessionState.OPEN assert str(session) == "id='test' connected" session._hits = 10 session._heartbeats = 50 - session.state = protocol.STATE_CLOSING + session.state = SessionState.CLOSING assert str(session) == "id='test' disconnected hits=10 heartbeats=50" - session._feed(protocol.FRAME_MESSAGE, "msg") + session.feed(Frame.MESSAGE, "msg") assert str(session) == "id='test' disconnected queue[1] hits=10 heartbeats=50" - session.state = protocol.STATE_CLOSED + session.state = SessionState.CLOSED assert str(session) == "id='test' closed queue[1] hits=10 heartbeats=50" - session.state = protocol.STATE_OPEN + session.state = SessionState.OPEN session.acquired = True expected = "id='test' connected acquired queue[1] hits=10 heartbeats=50" assert str(session) == expected @@ -59,7 +57,7 @@ async def test_tick(self, mocker, make_session): session = make_session("test") now = dt.now.return_value = now + timedelta(hours=1) - session._tick() + session.tick() assert session.next_heartbeat == now + timedelta( seconds=session.heartbeat_delay ) @@ -70,7 +68,7 @@ async def test_tick_different_timeoutk(self, mocker, make_session): session = make_session("test", disconnect_delay=20) now = dt.now.return_value = now + timedelta(hours=1) - session._tick() + session.tick() assert session.next_heartbeat == now + timedelta( seconds=session.heartbeat_delay ) @@ -81,24 +79,24 @@ async def test_tick_custom(self, mocker, make_session): session = make_session("test", disconnect_delay=20) now = dt.now.return_value = now + timedelta(hours=1) - session._tick(30) + session.tick(30) assert session.next_heartbeat == now + timedelta(seconds=30) async def test_heartbeat(self, make_session): session = make_session("test") session._send_heartbeats = True assert session._heartbeats == 0 - session._heartbeat() + session.heartbeat() assert session._heartbeats == 1 - session._heartbeat() + session.heartbeat() assert session._heartbeats == 2 async def test_heartbeat_transport(self, make_session): session = make_session("test") session._send_heartbeats = True - session._heartbeat() + session.heartbeat() assert list(session._queue) == [ - (protocol.FRAME_HEARTBEAT, protocol.FRAME_HEARTBEAT) + (Frame.HEARTBEAT, Frame.HEARTBEAT.value) ] async def test_expire(self, make_session, mocker): @@ -117,10 +115,10 @@ async def test_send(self, make_session): session.send("message") assert list(session._queue) == [] - session.state = protocol.STATE_OPEN + session.state = SessionState.OPEN session.send("message") - assert list(session._queue) == [(protocol.FRAME_MESSAGE, ["message"])] + assert list(session._queue) == [(Frame.MESSAGE, ["message"])] async def test_send_non_str(self, make_session): session = make_session("test") @@ -132,126 +130,125 @@ async def test_send_frame(self, make_session): session.send_frame('a["message"]') assert list(session._queue) == [] - session.state = protocol.STATE_OPEN + session.state = SessionState.OPEN session.send_frame('a["message"]') - assert list(session._queue) == [(protocol.FRAME_MESSAGE_BLOB, 'a["message"]')] + assert list(session._queue) == [(Frame.MESSAGE_BLOB, 'a["message"]')] async def test_feed(self, make_session): session = make_session("test") - session._feed(protocol.FRAME_OPEN, protocol.FRAME_OPEN) - session._feed(protocol.FRAME_MESSAGE, "msg") - session._feed(protocol.FRAME_CLOSE, (3001, "reason")) + session.feed(Frame.OPEN, Frame.OPEN.value) + session.feed(Frame.MESSAGE, "msg") + session.feed(Frame.CLOSE, (3001, "reason")) assert list(session._queue) == [ - (protocol.FRAME_OPEN, protocol.FRAME_OPEN), - (protocol.FRAME_MESSAGE, ["msg"]), - (protocol.FRAME_CLOSE, (3001, "reason")), + (Frame.OPEN, Frame.OPEN.value), + (Frame.MESSAGE, ["msg"]), + (Frame.CLOSE, (3001, "reason")), ] async def test_feed_msg_packing(self, make_session): session = make_session("test") - session._feed(protocol.FRAME_MESSAGE, "msg1") - session._feed(protocol.FRAME_MESSAGE, "msg2") - session._feed(protocol.FRAME_CLOSE, (3001, "reason")) - session._feed(protocol.FRAME_MESSAGE, "msg3") + session.feed(Frame.MESSAGE, "msg1") + session.feed(Frame.MESSAGE, "msg2") + session.feed(Frame.CLOSE, (3001, "reason")) + session.feed(Frame.MESSAGE, "msg3") assert list(session._queue) == [ - (protocol.FRAME_MESSAGE, ["msg1", "msg2"]), - (protocol.FRAME_CLOSE, (3001, "reason")), - (protocol.FRAME_MESSAGE, ["msg3"]), + (Frame.MESSAGE, ["msg1", "msg2"]), + (Frame.CLOSE, (3001, "reason")), + (Frame.MESSAGE, ["msg3"]), ] async def test_feed_with_waiter(self, make_session): session = make_session("test") session._waiter = waiter = asyncio.Future() - session._feed(protocol.FRAME_MESSAGE, "msg") + session.feed(Frame.MESSAGE, "msg") - assert list(session._queue) == [(protocol.FRAME_MESSAGE, ["msg"])] + assert list(session._queue) == [(Frame.MESSAGE, ["msg"])] assert session._waiter is None assert waiter.done() async def test_wait(self, make_session): s = make_session("test") - s.state = protocol.STATE_OPEN + s.state = SessionState.OPEN async def send(): await asyncio.sleep(0.001) - s._feed(protocol.FRAME_MESSAGE, "msg1") + s.feed(Frame.MESSAGE, "msg1") ensure_future(send()) - frame, payload = await s._get_frame() - assert frame == protocol.FRAME_MESSAGE + frame, payload = await s.get_frame() + assert frame == Frame.MESSAGE assert payload == 'a["msg1"]' async def test_wait_closed(self, make_session): s = make_session("test") - s.state = protocol.STATE_CLOSED + s.state = SessionState.CLOSED with pytest.raises(SessionIsClosed): - await s._get_frame() + await s.get_frame() async def test_wait_message(self, make_session): s = make_session("test") - s.state = protocol.STATE_OPEN - s._feed(protocol.FRAME_MESSAGE, "msg1") - frame, payload = await s._get_frame() - assert frame == protocol.FRAME_MESSAGE + s.state = SessionState.OPEN + s.feed(Frame.MESSAGE, "msg1") + frame, payload = await s.get_frame() + assert frame == Frame.MESSAGE assert payload == 'a["msg1"]' async def test_wait_close(self, make_session): s = make_session("test") - s.state = protocol.STATE_OPEN - s._feed(protocol.FRAME_CLOSE, (3000, "Go away!")) - frame, payload = await s._get_frame() - assert frame == protocol.FRAME_CLOSE + s.state = SessionState.OPEN + s.feed(Frame.CLOSE, (3000, "Go away!")) + frame, payload = await s.get_frame() + assert frame == Frame.CLOSE assert payload == 'c[3000,"Go away!"]' async def test_wait_message_unpack(self, make_session): s = make_session("test") - s.state = protocol.STATE_OPEN - s._feed(protocol.FRAME_MESSAGE, "msg1") - frame, payload = await s._get_frame(pack=False) - assert frame == protocol.FRAME_MESSAGE + s.state = SessionState.OPEN + s.feed(Frame.MESSAGE, "msg1") + frame, payload = await s.get_frame(pack=False) + assert frame == Frame.MESSAGE assert payload == ["msg1"] async def test_wait_close_unpack(self, make_session): s = make_session("test") - s.state = protocol.STATE_OPEN - s._feed(protocol.FRAME_CLOSE, (3000, "Go away!")) - frame, payload = await s._get_frame(pack=False) - assert frame == protocol.FRAME_CLOSE + s.state = SessionState.OPEN + s.feed(Frame.CLOSE, (3000, "Go away!")) + frame, payload = await s.get_frame(pack=False) + assert frame == Frame.CLOSE assert payload == (3000, "Go away!") async def test_close(self, make_session): session = make_session("test") - session.state = protocol.STATE_OPEN + session.state = SessionState.OPEN session.close() - assert session.state == protocol.STATE_CLOSING - assert list(session._queue) == [(protocol.FRAME_CLOSE, (3000, "Go away!"))] + assert session.state == SessionState.CLOSING + assert list(session._queue) == [(Frame.CLOSE, (3000, "Go away!"))] async def test_close_idempotent(self, make_session): session = make_session("test") - session.state = protocol.STATE_CLOSED + session.state = SessionState.CLOSED session.close() - assert session.state == protocol.STATE_CLOSED + assert session.state == SessionState.CLOSED assert list(session._queue) == [] - async def test_acquire_new_session(self, make_manager, make_session, make_request): - manager = make_manager() + async def test_acquire_new_session(self, make_manager, make_session, make_request, make_handler): messages = [] - - session = make_session(result=messages) - assert session.state == protocol.STATE_NEW + handler = make_handler(result=messages) + manager = make_manager(handler) + session = make_session(manager=manager) + assert session.state == SessionState.NEW assert session._hb_task is None assert not session._send_heartbeats - await session.acquire(manager, request=make_request("GET", "/test/")) - assert session.state == protocol.STATE_OPEN - assert session.manager is manager + await manager.acquire(session, request=make_request("GET", "/test/")) + assert session.state == SessionState.OPEN assert session._send_heartbeats assert session._hb_task is not None - assert list(session._queue) == [(protocol.FRAME_OPEN, protocol.FRAME_OPEN)] - assert messages == [(protocol.OpenMessage, session)] + assert list(session._queue) == [(Frame.OPEN, Frame.OPEN.value)] + assert messages == [(protocol.OPEN_MESSAGE, session)] hb_task = session._hb_task session.release() @@ -265,146 +262,168 @@ async def test_acquire_exception_in_handler( async def handler(msg, s): raise ValueError - session = make_session(handler=handler) - assert session.state == protocol.STATE_NEW + sm = make_manager(handler) + session = make_session(manager=sm) + assert session.state == SessionState.NEW - sm = make_manager() - await session.acquire(sm, request=make_request("GET", "/test/")) - assert session.state == protocol.STATE_CLOSING + await sm.acquire(session, request=make_request("GET", "/test/")) + assert session.state == SessionState.CLOSING assert session._send_heartbeats assert session.interrupted assert list(session._queue) == [ - (protocol.FRAME_OPEN, protocol.FRAME_OPEN), - (protocol.FRAME_CLOSE, (3000, "Internal error")), + (Frame.OPEN, Frame.OPEN.value), + (Frame.CLOSE, (3000, "Internal error")), ] - async def test_remote_close(self, make_session): + async def test_remote_close(self, make_session, make_manager, make_handler): messages = [] - session = make_session(result=messages) + handler = make_handler(result=messages) + manager = make_manager(handler) + session = make_session(manager=manager) - await session._remote_close() + await manager.remote_close(session) assert not session.interrupted - assert session.state == protocol.STATE_CLOSING - assert messages == [(protocol.SockjsMessage(protocol.MSG_CLOSE, None), session)] + assert session.state == SessionState.CLOSING + assert messages == [(protocol.SockjsMessage(protocol.MsgType.CLOSE, None), session)] - async def test_remote_close_idempotent(self, make_session): + async def test_remote_close_idempotent(self, make_session, make_manager, make_handler): messages = [] - session = make_session(result=messages) - session.state = protocol.STATE_CLOSED + handler = make_handler(result=messages) + manager = make_manager(handler) + session = make_session() + session.state = SessionState.CLOSED - await session._remote_close() - assert session.state == protocol.STATE_CLOSED + await manager.remote_close(session) + assert session.state == SessionState.CLOSED assert messages == [] - async def test_remote_close_with_exc(self, make_session): + async def test_remote_close_with_exc(self, make_session, make_manager, make_handler): messages = [] - session = make_session(result=messages) + handler = make_handler(result=messages) + manager = make_manager(handler) + session = make_session(manager=manager) exc = ValueError() - await session._remote_close(exc=exc) + await manager.remote_close(session, exc=exc) assert session.interrupted - assert session.state == protocol.STATE_CLOSING - assert messages == [(protocol.SockjsMessage(protocol.MSG_CLOSE, exc), session)] + assert session.state == SessionState.CLOSING + assert messages == [(protocol.SockjsMessage(protocol.MsgType.CLOSE, exc), session)] - async def test_remote_close_exc_in_handler(self, make_session, make_handler): + async def test_remote_close_exc_in_handler(self, make_session, make_manager, make_handler): handler = make_handler([], exc=True) - session = make_session(handler=handler) + manager = make_manager(handler) + session = make_session() - await session._remote_close() + await manager.remote_close(session) assert not session.interrupted - assert session.state == protocol.STATE_CLOSING + assert session.state == SessionState.CLOSING - async def test_remote_closed(self, make_session): + async def test_remote_closed(self, make_session, make_manager, make_handler): messages = [] - session = make_session(result=messages) + handler = make_handler(result=messages) + manager = make_manager(handler) + session = make_session(manager=manager) - await session._remote_closed() + await manager.remote_closed(session) assert session.expires > datetime.now() assert not session.expired - assert session.state == protocol.STATE_NEW + assert session.state == SessionState.NEW assert messages == [] session.expires = datetime.now() assert session.expired - await session._remote_closed() - assert session.state == protocol.STATE_CLOSED - assert messages == [(protocol.ClosedMessage, session)] + await manager.remote_closed(session) + assert session.state == SessionState.CLOSED + assert messages == [(protocol.CLOSED_MESSAGE, session)] # Without delay messages = [] - session = make_session(result=messages, disconnect_delay=0) - await session._remote_closed() + handler = make_handler(result=messages) + manager = make_manager(handler) + session = make_session(disconnect_delay=0) + await manager.remote_closed(session) assert session.expired - await session._remote_closed() - assert session.state == protocol.STATE_CLOSED - assert messages == [(protocol.ClosedMessage, session)] + await manager.remote_closed(session) + assert session.state == SessionState.CLOSED + assert messages == [(protocol.CLOSED_MESSAGE, session)] - async def test_remote_closed_idempotent(self, make_session): + async def test_remote_closed_idempotent(self, make_session, make_manager, make_handler): messages = [] - session = make_session(result=messages) - session.state = protocol.STATE_CLOSED + handler = make_handler(result=messages) + manager = make_manager(handler) + session = make_session() + session.state = SessionState.CLOSED - await session._remote_closed() - assert session.state == protocol.STATE_CLOSED + await manager.remote_closed(session) + assert session.state == SessionState.CLOSED assert messages == [] - async def test_remote_closed_with_waiter(self, make_session): + async def test_remote_closed_with_waiter(self, make_session, make_manager, make_handler): messages = [] - session = make_session(result=messages, disconnect_delay=0) + handler = make_handler(result=messages) + manager = make_manager(handler) + session = make_session(manager=manager, disconnect_delay=0) session._waiter = waiter = asyncio.Future() now = datetime.now() - await session._remote_closed() + await manager.remote_closed(session) assert waiter.done() assert session.expires <= now assert session.expired assert session._waiter is None - assert session.state == protocol.STATE_CLOSED - assert messages == [(protocol.ClosedMessage, session)] + assert session.state == SessionState.CLOSED + assert messages == [(protocol.CLOSED_MESSAGE, session)] - async def test_remote_closed_exc_in_handler(self, make_handler, make_session): + async def test_remote_closed_exc_in_handler(self, make_session, make_manager, make_handler): handler = make_handler([], exc=True) - session = make_session(handler=handler, disconnect_delay=0) + manager = make_manager(handler) + session = make_session(disconnect_delay=0) now = datetime.now() - await session._remote_closed() + await manager.remote_closed(session) assert session.expires <= now assert session.expired - assert session.state == protocol.STATE_CLOSED + assert session.state == SessionState.CLOSED - async def test_remote_message(self, make_session): + async def test_remote_message(self, make_session, make_manager, make_handler): messages = [] - session = make_session(result=messages) + handler = make_handler(result=messages) + manager = make_manager(handler) + session = make_session(manager=manager) - await session._remote_message("msg") + await manager.remote_message(session, "msg") assert messages == [ - (protocol.SockjsMessage(protocol.MSG_MESSAGE, "msg"), session) + (protocol.SockjsMessage(protocol.MsgType.MESSAGE, "msg"), session) ] - async def test_remote_message_exc(self, make_handler, make_session): + async def test_remote_message_exc(self, make_session, make_manager, make_handler): messages = [] handler = make_handler(messages, exc=True) - session = make_session(handler=handler) + manager = make_manager(handler) + session = make_session() - await session._remote_message("msg") + await manager.remote_message(session, "msg") assert messages == [] - async def test_remote_messages(self, make_session): + async def test_remote_messages(self, make_session, make_manager, make_handler): messages = [] - session = make_session(result=messages) + handler = make_handler(result=messages) + manager = make_manager(handler) + session = make_session(manager=manager) - await session._remote_messages(("msg1", "msg2")) + await manager.remote_messages(session, ("msg1", "msg2")) assert messages == [ - (protocol.SockjsMessage(protocol.MSG_MESSAGE, "msg1"), session), - (protocol.SockjsMessage(protocol.MSG_MESSAGE, "msg2"), session), + (protocol.SockjsMessage(protocol.MsgType.MESSAGE, "msg1"), session), + (protocol.SockjsMessage(protocol.MsgType.MESSAGE, "msg2"), session), ] - async def test_remote_messages_exc(self, make_handler, make_session): + async def test_remote_messages_exc(self, make_session, make_manager, make_handler): messages = [] handler = make_handler(messages, exc=True) - session = make_session(handler=handler) + manager = make_manager(handler) + session = make_session() - await session._remote_messages(("msg1", "msg2")) + await manager.remote_messages(session, ("msg1", "msg2")) assert messages == [] @@ -421,16 +440,15 @@ async def test_fresh(self, make_manager, make_session): sm = make_manager() s = make_session() sm._add(s) - assert "test" in sm + assert "test" in sm.sessions async def test_add(self, make_manager, make_session): sm = make_manager() s = make_session() sm._add(s) - assert "test" in sm - assert sm["test"] is s - assert s.manager is sm + assert "test" in sm.sessions + assert sm.sessions["test"] is s async def test_add_expired(self, make_manager, make_session): sm = make_manager() @@ -461,7 +479,7 @@ async def test_get_with_create(self, make_manager): sm = make_manager() s = sm.get("test", True) - assert s.id in sm + assert s.id in sm.sessions assert isinstance(s, Session) async def test_acquire(self, make_manager, make_session, make_request): @@ -470,7 +488,7 @@ async def test_acquire(self, make_manager, make_session, make_request): sm._add(s1) s1.acquire = mock.Mock() s1.acquire.return_value = asyncio.Future() - s1.acquire.return_value.set_result(1) + s1.acquire.return_value.set_result(True) s2 = await sm.acquire(s1, request=make_request("GET", "/test/")) @@ -523,29 +541,29 @@ async def test_broadcast(self, make_manager): sm = make_manager() s1 = sm.get("test1", True) - s1.state = protocol.STATE_OPEN + s1.state = SessionState.OPEN s2 = sm.get("test2", True) - s2.state = protocol.STATE_OPEN + s2.state = SessionState.OPEN sm.broadcast("msg") - assert list(s1._queue) == [(protocol.FRAME_MESSAGE_BLOB, 'a["msg"]')] - assert list(s2._queue) == [(protocol.FRAME_MESSAGE_BLOB, 'a["msg"]')] + assert list(s1._queue) == [(Frame.MESSAGE_BLOB, 'a["msg"]')] + assert list(s2._queue) == [(Frame.MESSAGE_BLOB, 'a["msg"]')] async def test_clear(self, make_manager): sm = make_manager() s1 = sm.get("s1", True) - s1.state = protocol.STATE_OPEN + s1.state = SessionState.OPEN s2 = sm.get("s2", True) - s2.state = protocol.STATE_OPEN + s2.state = SessionState.OPEN await sm.clear() - assert not bool(sm) + assert not bool(sm.sessions) assert s1.expired assert s2.expired - assert s1.state == protocol.STATE_CLOSED - assert s2.state == protocol.STATE_CLOSED + assert s1.state == SessionState.CLOSED + assert s2.state == SessionState.CLOSED async def test_gc_task(self, make_manager): sm = make_manager() @@ -565,9 +583,7 @@ async def test_gc_task(self, make_manager): async def test_gc_expire(self, make_manager, make_session, make_request): sm = make_manager() - s = make_session() - - sm._add(s) + s = make_session(manager=sm) await sm.acquire(s, request=make_request("GET", "/test/")) await sm.release(s) @@ -576,30 +592,26 @@ async def test_gc_expire(self, make_manager, make_session, make_request): assert s.expired await sm._gc_expired_sessions() - assert s.id not in sm + assert s.id not in sm.sessions assert s.expired - assert s.state == protocol.STATE_CLOSED + assert s.state == SessionState.CLOSED async def test_gc_expire_acquired(self, make_manager, make_session, make_request): sm = make_manager() - s = make_session() - sm._add(s) + s = make_session(manager=sm) await sm.acquire(s, request=make_request("GET", "/test/")) s.expires = datetime.now() - timedelta(seconds=30) await sm._gc_expired_sessions() - assert s.id not in sm + assert s.id not in sm.sessions assert s.id not in sm.acquired assert s.expired - assert s.state == protocol.STATE_CLOSED + assert s.state == SessionState.CLOSED async def test_gc_one_expire(self, make_manager, make_session, make_request): sm = make_manager() - s1 = make_session("id1") - s2 = make_session("id2") - - sm._add(s1) - sm._add(s2) + s1 = make_session("id1", manager=sm) + s2 = make_session("id2", manager=sm) await sm.acquire(s1, request=make_request("GET", "/test/")) await sm.acquire(s2, request=make_request("GET", "/test/")) await sm.release(s1) @@ -608,16 +620,13 @@ async def test_gc_one_expire(self, make_manager, make_session, make_request): s1.expires = datetime.now() - timedelta(seconds=30) await sm._gc_expired_sessions() - assert s1.id not in sm - assert s2.id in sm + assert s1.id not in sm.sessions + assert s2.id in sm.sessions async def test_emits_warning_on_del(self, make_manager, make_session): sm = make_manager() - s1 = make_session("id1") - s2 = make_session("id2") - - sm._add(s1) - sm._add(s2) + make_session("id1", manager=sm) + make_session("id2", manager=sm) with pytest.warns(RuntimeWarning) as warning: getattr(sm, "__del__")() @@ -628,11 +637,8 @@ async def test_does_not_emits_warning_on_del_if_no_sessions( self, make_manager, make_session ): sm = make_manager() - s1 = make_session("id1") - s2 = make_session("id2") - - sm._add(s1) - sm._add(s2) + make_session("id1", manager=sm) + make_session("id2", manager=sm) await sm.clear() getattr(sm, "__del__")() diff --git a/tests/test_transport.py b/tests/test_transport.py index bd30e6f0..24c3d820 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -4,7 +4,7 @@ from aiohttp import web from aiohttp.test_utils import make_mocked_coro -from sockjs import protocol +from sockjs import SessionState from sockjs.transports import base @@ -55,23 +55,25 @@ async def test_handle_session_interrupted(make_transport, make_fut): async def test_handle_session_closing(make_transport, make_fut): trans = make_transport() + manager = trans.manager trans._send = make_fut(1) trans.session.interrupted = False - trans.session.state = protocol.STATE_CLOSING - trans.session._remote_closed = make_fut(1) + trans.session.state = SessionState.CLOSING + manager.remote_closed = make_fut(1) trans.response = web.StreamResponse() await trans.handle_session() - trans.session._remote_closed.assert_called_with() + manager.remote_closed.assert_called() trans._send.assert_called_with('c[3000,"Go away!"]') async def test_handle_session_closed(make_transport, make_fut): trans = make_transport() + manager = trans.manager trans._send = make_fut(1) trans.session.interrupted = False - trans.session.state = protocol.STATE_CLOSED - trans.session._remote_closed = make_fut(1) + trans.session.state = SessionState.CLOSED + manager.remote_closed = make_fut(1) trans.response = web.StreamResponse() await trans.handle_session() - trans.session._remote_closed.assert_called_with() + manager.remote_closed.assert_called() trans._send.assert_called_with('c[3000,"Go away!"]') diff --git a/tests/test_transport_htmlfile.py b/tests/test_transport_htmlfile.py index b286be13..6d8e77dc 100644 --- a/tests/test_transport_htmlfile.py +++ b/tests/test_transport_htmlfile.py @@ -46,18 +46,18 @@ async def test_process(make_transport, make_fut): async def test_process_no_callback(make_transport, make_fut): transp = make_transport() transp.session = mock.Mock() - transp.session._remote_closed = make_fut(1) + transp.manager.remote_closed = make_fut(1) with pytest.raises(web.HTTPInternalServerError): await transp.process() - assert transp.session._remote_closed.called + assert transp.manager.remote_closed.called async def test_process_bad_callback(make_transport, make_fut): transp = make_transport(query_params={"c": "calback!!!!"}) transp.session = mock.Mock() - transp.session._remote_closed = make_fut(1) + transp.manager.remote_closed = make_fut(1) with pytest.raises(web.HTTPInternalServerError): await transp.process() - assert transp.session._remote_closed.called + assert transp.manager.remote_closed.called diff --git a/tests/test_transport_jsonp.py b/tests/test_transport_jsonp.py index f6da5cb4..1109d984 100644 --- a/tests/test_transport_jsonp.py +++ b/tests/test_transport_jsonp.py @@ -42,21 +42,21 @@ async def test_process(make_transport, make_fut): async def test_process_no_callback(make_transport, make_fut): transp = make_transport() transp.session = mock.Mock() - transp.session._remote_closed = make_fut(1) + transp.manager.remote_closed = make_fut(1) with pytest.raises(web.HTTPInternalServerError): await transp.process() - assert transp.session._remote_closed.called + assert transp.manager.remote_closed.called async def test_process_bad_callback(make_transport, make_fut): transp = make_transport(query_params={"c": "calback!!!!"}) transp.session = mock.Mock() - transp.session._remote_closed = make_fut(1) + transp.manager.remote_closed = make_fut(1) with pytest.raises(web.HTTPInternalServerError): await transp.process() - assert transp.session._remote_closed.called + assert transp.manager.remote_closed.called async def test_process_not_supported(make_transport): @@ -68,7 +68,6 @@ async def test_process_not_supported(make_transport): async def xtest_process_bad_encoding(make_transport, make_fut): transp = make_transport(method="POST") transp.request.read = make_fut(b"test") - transp.request.content_type transp.request._content_type = "application/x-www-form-urlencoded" resp = await transp.process() assert resp.status == 500 @@ -77,7 +76,6 @@ async def xtest_process_bad_encoding(make_transport, make_fut): async def xtest_process_no_payload(make_transport, make_fut): transp = make_transport(method="POST") transp.request.read = make_fut(b"d=") - transp.request.content_type transp.request._content_type = "application/x-www-form-urlencoded" resp = await transp.process() assert resp.status == 500 @@ -92,8 +90,8 @@ async def xtest_process_bad_json(make_transport, make_fut): async def xtest_process_message(make_transport, make_fut): transp = make_transport(method="POST") - transp.session._remote_messages = make_fut(1) + transp.manager.remote_messages = make_fut(1) transp.request.read = make_fut(b'["msg1","msg2"]') resp = await transp.process() assert resp.status == 200 - transp.session._remote_messages.assert_called_with(["msg1", "msg2"]) + transp.manager.remote_messages.assert_called_with(["msg1", "msg2"]) diff --git a/tests/test_transport_rawwebsocket.py b/tests/test_transport_rawwebsocket.py index d488b76a..62e5f0b3 100644 --- a/tests/test_transport_rawwebsocket.py +++ b/tests/test_transport_rawwebsocket.py @@ -5,8 +5,9 @@ from aiohttp import WSMessage, WSMsgType from sockjs.exceptions import SessionIsClosed -from sockjs.protocol import FRAME_CLOSE, FRAME_HEARTBEAT +from sockjs.protocol import Frame from sockjs.transports.rawwebsocket import RawWebSocketTransport +from sockjs.transports.utils import cancel_tasks @pytest.fixture @@ -15,7 +16,7 @@ def maker(method="GET", path="/", query_params=None): manager = mock.Mock() session = mock.Mock() session._remote_closed = make_fut(1) - session._get_frame = make_fut((FRAME_CLOSE, "")) + session._get_frame = make_fut((Frame.CLOSE, "")) request = make_request(method, path, query_params=query_params) request.app.freeze() return RawWebSocketTransport(manager, session, request) @@ -41,7 +42,7 @@ async def xtest_ticks_pong(make_transport, make_fut): session = transp.session await transp.client(ws, session) - assert session._tick.called + assert session.tick.called async def test_sends_ping(make_transport, make_fut): @@ -54,14 +55,15 @@ async def test_sends_ping(make_transport, make_fut): ws.ping.side_effect = [future] hb_future = Future() - hb_future.set_result((FRAME_HEARTBEAT, b"")) + hb_future.set_result((Frame.HEARTBEAT, b"")) session_close_future = Future() session_close_future.set_exception(SessionIsClosed) session = mock.Mock() - session._get_frame.side_effect = [hb_future, session_close_future] + session.get_frame.side_effect = [hb_future, session_close_future] transp.session = session await transp.server(ws) assert ws.ping.called + await cancel_tasks(transp._wait_pong_task) diff --git a/tests/test_transport_websocket.py b/tests/test_transport_websocket.py index c9c39406..6f9defcc 100644 --- a/tests/test_transport_websocket.py +++ b/tests/test_transport_websocket.py @@ -7,8 +7,8 @@ from aiohttp import WSMessage, WSMsgType from aiohttp.test_utils import make_mocked_coro -from sockjs import MSG_CLOSED, MSG_MESSAGE, MSG_OPEN, Session -from sockjs.protocol import FRAME_CLOSE, SockjsMessage +from sockjs import MsgType, Session, Frame +from sockjs.protocol import SockjsMessage from sockjs.transports import WebSocketTransport @@ -20,7 +20,7 @@ def maker(method="GET", path="/", query_params=None, handler=None): request = make_request(method, path, query_params=query_params) request.app.freeze() session = manager.get("TestSessionWebsocket", create=True) - session._get_frame = make_fut((FRAME_CLOSE, "")) + session.get_frame = make_fut((Frame.CLOSE, "")) return WebSocketTransport(manager, session, request) return maker @@ -31,6 +31,8 @@ async def xtest_process_release_acquire_and_remote_closed(make_transport): transp.session.interrupted = False transp.manager.acquire = make_mocked_coro() transp.manager.release = make_mocked_coro() + transp.manager.remote_closed = make_mocked_coro() + resp = await transp.process() await transp.manager.clear() @@ -38,7 +40,7 @@ async def xtest_process_release_acquire_and_remote_closed(make_transport): assert resp.headers.get("upgrade", "").lower() == "websocket" assert resp.headers.get("connection", "").lower() == "upgrade" - transp.session._remote_closed.assert_called_once_with() + assert transp.manager.remote_closed.called assert transp.manager.acquire.called assert transp.manager.release.called @@ -48,15 +50,15 @@ async def test_server_close(app, make_manager, make_request): loop = asyncio.get_running_loop() - async def handler(msg: SockjsMessage, session: Session): + async def handler(manager, session: Session, msg: SockjsMessage): nonlocal reached_closed - if msg.type == MSG_OPEN: + if msg.type == MsgType.OPEN: # To reproduce the ordering which makes the issue loop.call_later(0.05, session.close) - elif msg.type == MSG_MESSAGE: + elif msg.type == MsgType.MESSAGE: # To reproduce the ordering which makes the issue loop.call_later(0.05, session.close) - elif msg.type == MSG_CLOSED: + elif msg.type == MsgType.CLOSED: reached_closed = True app.freeze() diff --git a/tests/test_transport_xhrsend.py b/tests/test_transport_xhrsend.py index 56574bc1..072ef66f 100644 --- a/tests/test_transport_xhrsend.py +++ b/tests/test_transport_xhrsend.py @@ -39,11 +39,11 @@ async def xtest_bad_json(make_transport, make_fut): async def xtest_post_message(make_transport, make_fut): transp = make_transport() - transp.session._remote_messages = make_fut(1) + transp.manager.remote_messages = make_fut(1) transp.request.read = make_fut(b'["msg1","msg2"]') resp = await transp.process() assert resp.status == 204 - transp.session._remote_messages.assert_called_with(["msg1", "msg2"]) + transp.manager.remote_messages.assert_called_with(["msg1", "msg2"]) async def test_OPTIONS(make_transport): From 7b3af51bbb9dc6b8286fc049d71340874cd2e2b2 Mon Sep 17 00:00:00 2001 From: Kirill Kuzminykh Date: Fri, 7 Jul 2023 22:19:38 +0300 Subject: [PATCH 18/23] Fixed error processing in StreamingTransport --- sockjs/route.py | 2 ++ sockjs/session.py | 18 ++++++++++++------ sockjs/transports/base.py | 4 ++-- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/sockjs/route.py b/sockjs/route.py index c4c2f6d0..9174a0c8 100644 --- a/sockjs/route.py +++ b/sockjs/route.py @@ -47,6 +47,7 @@ def add_endpoint( cors_config: Optional[CorsConfig] = None, heartbeat_delay=25, disconnect_delay=5, + debug=False, ) -> List[web.AbstractRoute]: registered_routes = [] @@ -73,6 +74,7 @@ async def handler(m, s, msg): handler, heartbeat_delay, disconnect_delay, + debug=debug, ) if manager.name != name: diff --git a/sockjs/session.py b/sockjs/session.py index 645375e2..fdf33f91 100644 --- a/sockjs/session.py +++ b/sockjs/session.py @@ -112,7 +112,8 @@ def acquire(self, request: web.Request) -> bool: self._hits += 1 if self.state == SessionState.NEW: - log.debug("open session: %s", self.id) + if self._debug: + log.debug("open session: %s", self.id) self.state = SessionState.OPEN self.feed(Frame.OPEN, Frame.OPEN.value) return True @@ -280,7 +281,8 @@ async def stop(self, _app=None): async def _check_expiration(self, session: Session): if session.expired: - log.debug("session expired: %s", session.id) + if self.debug: + log.debug("session expired: %s", session.id) # Session is to be GC'd immediately if session.id in self.acquired: await self.release(session) @@ -408,7 +410,8 @@ def __del__(self): async def remote_message(self, session: Session, msg): """Call handler with message received from client.""" - log.debug("incoming message: %s, %s", session.id, msg[:200]) + if self.debug: + log.debug("incoming message: %s, %s", session.id, msg[:200]) session.tick() try: @@ -421,7 +424,8 @@ async def remote_messages(self, session: Session, messages): session.tick() for msg in messages: - log.debug("incoming message: %s, %s", session.id, msg[:200]) + if self.debug: + log.debug("incoming message: %s, %s", session.id, msg[:200]) try: await self.handler(self, session, SockjsMessage(MsgType.MESSAGE, msg)) except Exception: @@ -432,7 +436,8 @@ async def remote_close(self, session: Session, exc=None): if session.state in (SessionState.CLOSING, SessionState.CLOSED): return - log.info("close session: %s", session.id) + if self.debug: + log.info("close session: %s", session.id) session.tick() session.state = SessionState.CLOSING if exc is not None: @@ -452,7 +457,8 @@ async def remote_closed(self, session: Session): session.expire() return - log.info("session closed: %s", session.id) + if self.debug: + log.info("session closed: %s", session.id) session.state = SessionState.CLOSED session.expire() try: diff --git a/sockjs/transports/base.py b/sockjs/transports/base.py index f836c1ef..7af79618 100644 --- a/sockjs/transports/base.py +++ b/sockjs/transports/base.py @@ -2,7 +2,7 @@ import asyncio from aiohttp import web -from aiohttp.web_exceptions import HTTPClientError +from aiohttp.web_exceptions import HTTPClientError, HTTPError from ..exceptions import SessionIsAcquired, SessionIsClosed from ..protocol import ( @@ -101,7 +101,7 @@ async def handle_session(self): stop = await self._send(text) if stop: break - except (asyncio.CancelledError, ConnectionError) as e: + except (asyncio.CancelledError, ConnectionError, HTTPError) as e: await self.manager.remote_close(self.session, exc=e) await self.manager.remote_closed(self.session) raise From 6850a483c2101a65fd055306a4ff0c46d439fec0 Mon Sep 17 00:00:00 2001 From: Kirill Kuzminykh Date: Fri, 7 Jul 2023 22:33:44 +0300 Subject: [PATCH 19/23] Changed supported versions of Python. --- .github/workflows/check_and_test.yaml | 2 +- README.rst | 2 +- setup.py | 5 +---- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/.github/workflows/check_and_test.yaml b/.github/workflows/check_and_test.yaml index 3c59a0e4..266e80d4 100644 --- a/.github/workflows/check_and_test.yaml +++ b/.github/workflows/check_and_test.yaml @@ -12,7 +12,7 @@ jobs: run_tests: strategy: matrix: - python-version: [ "3.7", "3.8", "3.9", "3.10", "3.11" ] + python-version: [ "3.10", "3.11" ] name: Test on Python ${{ matrix.python-version }} runs-on: ubuntu-latest diff --git a/README.rst b/README.rst index 37139570..c3cf70bd 100644 --- a/README.rst +++ b/README.rst @@ -73,7 +73,7 @@ Supported transports Requirements ------------ -- Python 3.7.0 +- Python 3.10.0 - gunicorn 19.2.0 diff --git a/setup.py b/setup.py index 43e4a7b8..44860253 100644 --- a/setup.py +++ b/setup.py @@ -31,9 +31,6 @@ def read(f): "Intended Audience :: Developers", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: Implementation :: CPython", @@ -45,7 +42,7 @@ def read(f): url="https://github.com/aio-libs/sockjs/", license="Apache 2", packages=find_packages(), - python_requires=">=3.7.0", + python_requires=">=3.10.0", install_requires=[ "aiohttp>=3.7.4", ], From cb90f5d7a8e98cd8ad11cdf596c24ef5320d49df Mon Sep 17 00:00:00 2001 From: Kirill Kuzminykh Date: Mon, 10 Jul 2023 10:10:27 +0300 Subject: [PATCH 20/23] Added "sockjs_transport_name" into request object. --- sockjs/route.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sockjs/route.py b/sockjs/route.py index 9174a0c8..12060f95 100644 --- a/sockjs/route.py +++ b/sockjs/route.py @@ -7,6 +7,7 @@ from typing import Iterable, List, Optional, Type from aiohttp import hdrs, web +from aiohttp.web_request import Request try: @@ -189,7 +190,7 @@ def __init__( transport_names - self.disable_transports ) - async def handler(self, request): + async def handler(self, request: Request): info = request.match_info # lookup transport @@ -212,6 +213,7 @@ async def handler(self, request): except KeyError: raise web.HTTPNotFound(headers=session_cookie(request)) + request["sockjs_transport_name"] = transport_class.name transport = transport_class(manager, session, request) try: return await transport.process() From 1ef9c1df78902911aadf23a21727742618b5852d Mon Sep 17 00:00:00 2001 From: Kirill Kuzminykh Date: Thu, 13 Jun 2024 14:36:30 +0300 Subject: [PATCH 21/23] - Updated dependencies versions. - Fixed code style. --- requirements.txt | 14 +++--- setup.cfg | 4 +- setup.py | 3 +- sockjs/protocol.py | 5 +- sockjs/route.py | 62 ++++++++++++------------ sockjs/session.py | 36 ++++++-------- sockjs/transports/base.py | 8 ++-- sockjs/transports/eventsource.py | 3 +- sockjs/transports/htmlfile.py | 3 +- sockjs/transports/jsonp.py | 1 + sockjs/transports/rawwebsocket.py | 1 + sockjs/transports/websocket.py | 1 + sockjs/transports/xhr_pooling.py | 1 + tests/conftest.py | 29 ++++++------ tests/test_session.py | 71 +++++++++++++++++++++++----- tests/test_transport_xhr.py | 2 +- tests/test_transport_xhrsend.py | 2 +- tests/test_transport_xhrstreaming.py | 2 +- 18 files changed, 146 insertions(+), 102 deletions(-) diff --git a/requirements.txt b/requirements.txt index 34d71460..a2afe0e8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ -flake8<6.0.0 -pytest==7.3.1 -pytest-aiohttp==1.0.4 -pytest-mock==3.10.0 -pytest-timeout==2.1.0 -aiohttp==3.8.4 -twine==4.0.2 +flake8==7.0.0 +pytest==8.2.2 +pytest-aiohttp==1.0.5 +pytest-mock==3.14.0 +pytest-timeout==2.3.1 +aiohttp==3.9.5 +twine==5.1.0 -e .[test] diff --git a/setup.cfg b/setup.cfg index f88d780e..a2ed1736 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,6 +9,6 @@ max-line-length = 88 [tool:pytest] timeout = 3 -filterwarnings= - error +#filterwarnings= +# error asyncio_mode = auto diff --git a/setup.py b/setup.py index 44860253..c4e31438 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,7 @@ def read(f): "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Programming Language :: Python :: Implementation :: CPython", "Topic :: Internet :: WWW/HTTP", "Framework :: AsyncIO", @@ -55,7 +56,7 @@ def read(f): "pytest-mock", "pytest-timeout", "cykooz.testing", - 'aiohttp_cors', + "aiohttp_cors", ], }, include_package_data=True, diff --git a/sockjs/protocol.py b/sockjs/protocol.py index 7ec5fa9d..61c0868f 100644 --- a/sockjs/protocol.py +++ b/sockjs/protocol.py @@ -39,9 +39,9 @@ class SessionState(enum.Enum): try: import ujson as json - kwargs = {} # pragma: no cover except ImportError: # pragma: no cover + def dthandler(obj): if isinstance(obj, datetime): now = obj.timetuple() @@ -55,7 +55,6 @@ def dthandler(obj): now[5], ) - kwargs = {"default": dthandler, "separators": (",", ":")} # Faster @@ -69,6 +68,7 @@ def dthandler(obj): # Frames # ------ + @enum.unique class Frame(enum.Enum): OPEN = "o" @@ -123,6 +123,7 @@ def messages_frame(messages): # Handler messages # --------------------- + @enum.unique class MsgType(enum.Enum): OPEN = 1 diff --git a/sockjs/route.py b/sockjs/route.py index 12060f95..09a535c9 100644 --- a/sockjs/route.py +++ b/sockjs/route.py @@ -36,26 +36,25 @@ def _gen_endpoint_name(): def add_endpoint( - app: web.Application, - handler: HandlerType, - *, - name="", - prefix="/sockjs", - manager=None, - disable_transports=(), - sockjs_cdn="https://cdn.jsdelivr.net/npm/sockjs-client@1/dist/sockjs.min.js", # noqa - cookie_needed=True, - cors_config: Optional[CorsConfig] = None, - heartbeat_delay=25, - disconnect_delay=5, - debug=False, + app: web.Application, + handler: HandlerType, + *, + name="", + prefix="/sockjs", + manager=None, + disable_transports=(), + sockjs_cdn="https://cdn.jsdelivr.net/npm/sockjs-client@1/dist/sockjs.min.js", # noqa + cookie_needed=True, + cors_config: Optional[CorsConfig] = None, + heartbeat_delay=25, + disconnect_delay=5, + debug=False, ) -> List[web.AbstractRoute]: registered_routes = [] assert callable(handler), handler - if ( - not asyncio.iscoroutinefunction(handler) - and not inspect.isgeneratorfunction(handler) + if not asyncio.iscoroutinefunction(handler) and not inspect.isgeneratorfunction( + handler ): sync_handler = handler @@ -109,8 +108,7 @@ async def handler(m, s, msg): ) resource = router.add_resource( - "%s/{server}/{session}/{transport}" % prefix, - name=f"sockjs-transport-{name}" + "%s/{server}/{session}/{transport}" % prefix, name=f"sockjs-transport-{name}" ) for method in ALL_METH_WO_OPTIONS: registered_routes.append( @@ -124,7 +122,10 @@ async def handler(m, s, msg): route_name = "sockjs-websocket-%s" % name registered_routes.append( router.add_route( - hdrs.METH_GET, "%s/websocket" % prefix, route.websocket, name=route_name + hdrs.METH_GET, + "%s/websocket" % prefix, + route.websocket, + name=route_name, ) ) @@ -150,7 +151,7 @@ async def handler(m, s, msg): hdrs.METH_GET, "%s/iframe{version}.html" % prefix, route.iframe, - name=route_name + name=route_name, ) ) @@ -166,13 +167,13 @@ async def handler(m, s, msg): class SockJSRoute: def __init__( - self, - name: str, - manager: SessionManager, - sockjs_cdn: str, - handlers, - disable_transports: Iterable[str], - cookie_needed=True, + self, + name: str, + manager: SessionManager, + sockjs_cdn: str, + handlers, + disable_transports: Iterable[str], + cookie_needed=True, ): self.name = name self.manager = manager @@ -182,13 +183,10 @@ def __init__( self.iframe_html = (IFRAME_HTML % sockjs_cdn).encode("utf-8") self.iframe_html_hxd = hashlib.md5(self.iframe_html).hexdigest() transport_names = { - transport_class.name - for transport_class in transport_handlers.values() + transport_class.name for transport_class in transport_handlers.values() } transport_names.add("websocket-raw") - self._transport_names = sorted( - transport_names - self.disable_transports - ) + self._transport_names = sorted(transport_names - self.disable_transports) async def handler(self, request: Request): info = request.match_info diff --git a/sockjs/session.py b/sockjs/session.py index fdf33f91..126ef445 100644 --- a/sockjs/session.py +++ b/sockjs/session.py @@ -42,12 +42,7 @@ class Session: _hb_task = None def __init__( - self, - session_id: str, - *, - heartbeat_delay=25, - disconnect_delay=5, - debug=False + self, session_id: str, *, heartbeat_delay=25, disconnect_delay=5, debug=False, ): self.id = session_id self.heartbeat_delay = heartbeat_delay @@ -243,13 +238,13 @@ class SessionManager: _gc_task = None def __init__( - self, - name: str, - app: web.Application, - handler: HandlerType, - heartbeat_delay=25, - disconnect_delay=5, - debug=False, + self, + name: str, + app: web.Application, + handler: HandlerType, + heartbeat_delay=25, + disconnect_delay=5, + debug=False, ): self.name = name self.route_name = "sockjs-url-%s" % name @@ -301,10 +296,7 @@ async def _gc_sessions_task(self): async def _gc_expired_sessions(self): sessions = self.sessions if sessions: - tasks = [ - self._check_expiration(session) - for session in sessions.values() - ] + tasks = [self._check_expiration(session) for session in sessions.values()] expired_session_ids = await asyncio.gather(*tasks) idx = 0 @@ -321,13 +313,13 @@ def _add(self, session: Session): self.sessions[session.id] = session return session - _T = TypeVar('_T') + _T = TypeVar("_T") def get( - self, - session_id, - create=False, - default: _T = _marker, + self, + session_id, + create=False, + default: _T = _marker, ) -> Union[Session, _T]: session = self.sessions.get(session_id, None) if session is None: diff --git a/sockjs/transports/base.py b/sockjs/transports/base.py index 7af79618..dcfa4949 100644 --- a/sockjs/transports/base.py +++ b/sockjs/transports/base.py @@ -27,10 +27,10 @@ def get_session(cls, manager: SessionManager, session_id: str) -> Session: return manager.get(session_id, create=cls.create_session) def __init__( - self, - manager: SessionManager, - session: Session, - request: web.Request, + self, + manager: SessionManager, + session: Session, + request: web.Request, ): self.manager = manager self.session = session diff --git a/sockjs/transports/eventsource.py b/sockjs/transports/eventsource.py index eaccff40..0681ddac 100644 --- a/sockjs/transports/eventsource.py +++ b/sockjs/transports/eventsource.py @@ -1,4 +1,5 @@ -""" iframe-eventsource transport """ +"""iframe-eventsource transport""" + from aiohttp import hdrs, web from multidict import MultiDict diff --git a/sockjs/transports/htmlfile.py b/sockjs/transports/htmlfile.py index ef193658..240a89f5 100644 --- a/sockjs/transports/htmlfile.py +++ b/sockjs/transports/htmlfile.py @@ -1,4 +1,5 @@ -""" iframe-htmlfile transport """ +"""iframe-htmlfile transport""" + import re from aiohttp import hdrs, web diff --git a/sockjs/transports/jsonp.py b/sockjs/transports/jsonp.py index fd2389c6..3c2099cf 100644 --- a/sockjs/transports/jsonp.py +++ b/sockjs/transports/jsonp.py @@ -1,4 +1,5 @@ """jsonp transport""" + import re from urllib.parse import unquote_plus diff --git a/sockjs/transports/rawwebsocket.py b/sockjs/transports/rawwebsocket.py index f63ea44a..15415ab3 100644 --- a/sockjs/transports/rawwebsocket.py +++ b/sockjs/transports/rawwebsocket.py @@ -1,4 +1,5 @@ """raw websocket transport.""" + import asyncio from asyncio import ensure_future from typing import Optional diff --git a/sockjs/transports/websocket.py b/sockjs/transports/websocket.py index fd0ab4cd..e4e0979b 100644 --- a/sockjs/transports/websocket.py +++ b/sockjs/transports/websocket.py @@ -1,4 +1,5 @@ """websocket transport""" + import asyncio import logging from asyncio import ensure_future diff --git a/sockjs/transports/xhr_pooling.py b/sockjs/transports/xhr_pooling.py index 0ddd96f4..e05f194a 100644 --- a/sockjs/transports/xhr_pooling.py +++ b/sockjs/transports/xhr_pooling.py @@ -9,6 +9,7 @@ class XHRTransport(StreamingTransport): """Long polling derivative transports, used for XHRPolling and JSONPolling.""" + name = "xhr-polling" create_session = True maxsize = 0 diff --git a/tests/conftest.py b/tests/conftest.py index f97bba11..164383ea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -95,9 +95,7 @@ def maker(method, path, query_params=None, headers=None, match_info=None): @pytest.fixture def make_session(make_handler, make_request): def maker( - name="test", - disconnect_delay=10, - manager: Optional[SessionManager] = None + name="test", disconnect_delay=10, manager: Optional[SessionManager] = None ): session = Session(name, disconnect_delay=disconnect_delay, debug=True) if manager: @@ -108,7 +106,7 @@ def maker( @pytest.fixture -async def make_manager(event_loop, app, make_handler, make_session): +async def make_manager(app, make_handler, make_session): managers = [] def maker(handler=None): @@ -134,22 +132,25 @@ def maker(handlers=transports.transport_handlers): return maker -@pytest.fixture(name='test_client') +@pytest.fixture(name="test_client") async def test_client_fixture(app, aiohttp_client, make_handler) -> TestClient: handler = make_handler(None) # Configure default CORS settings. - cors = aiohttp_cors.setup(app, defaults={ - '*': aiohttp_cors.ResourceOptions( - allow_credentials=True, - expose_headers='*', - allow_headers='*', - max_age=31536000, - ) - }) + cors = aiohttp_cors.setup( + app, + defaults={ + "*": aiohttp_cors.ResourceOptions( + allow_credentials=True, + expose_headers="*", + allow_headers="*", + max_age=31536000, + ) + }, + ) add_endpoint( app, handler, - name='main', + name="main", cors_config=cors, ) return await aiohttp_client(app) diff --git a/tests/test_session.py b/tests/test_session.py index f5db114e..648d08b4 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -6,7 +6,14 @@ import pytest from aiohttp import web -from sockjs import Session, SessionIsAcquired, SessionIsClosed, protocol, SessionState, Frame +from sockjs import ( + Session, + SessionIsAcquired, + SessionIsClosed, + protocol, + SessionState, + Frame, +) class TestSession: @@ -95,9 +102,7 @@ async def test_heartbeat_transport(self, make_session): session = make_session("test") session._send_heartbeats = True session.heartbeat() - assert list(session._queue) == [ - (Frame.HEARTBEAT, Frame.HEARTBEAT.value) - ] + assert list(session._queue) == [(Frame.HEARTBEAT, Frame.HEARTBEAT.value)] async def test_expire(self, make_session, mocker): dt = mocker.patch("sockjs.session.datetime") @@ -234,7 +239,13 @@ async def test_close_idempotent(self, make_session): assert session.state == SessionState.CLOSED assert list(session._queue) == [] - async def test_acquire_new_session(self, make_manager, make_session, make_request, make_handler): + async def test_acquire_new_session( + self, + make_manager, + make_session, + make_request, + make_handler, + ): messages = [] handler = make_handler(result=messages) manager = make_manager(handler) @@ -284,9 +295,16 @@ async def test_remote_close(self, make_session, make_manager, make_handler): await manager.remote_close(session) assert not session.interrupted assert session.state == SessionState.CLOSING - assert messages == [(protocol.SockjsMessage(protocol.MsgType.CLOSE, None), session)] + assert messages == [ + (protocol.SockjsMessage(protocol.MsgType.CLOSE, None), session) + ] - async def test_remote_close_idempotent(self, make_session, make_manager, make_handler): + async def test_remote_close_idempotent( + self, + make_session, + make_manager, + make_handler, + ): messages = [] handler = make_handler(result=messages) manager = make_manager(handler) @@ -297,7 +315,12 @@ async def test_remote_close_idempotent(self, make_session, make_manager, make_ha assert session.state == SessionState.CLOSED assert messages == [] - async def test_remote_close_with_exc(self, make_session, make_manager, make_handler): + async def test_remote_close_with_exc( + self, + make_session, + make_manager, + make_handler, + ): messages = [] handler = make_handler(result=messages) manager = make_manager(handler) @@ -307,9 +330,16 @@ async def test_remote_close_with_exc(self, make_session, make_manager, make_hand await manager.remote_close(session, exc=exc) assert session.interrupted assert session.state == SessionState.CLOSING - assert messages == [(protocol.SockjsMessage(protocol.MsgType.CLOSE, exc), session)] + assert messages == [ + (protocol.SockjsMessage(protocol.MsgType.CLOSE, exc), session) + ] - async def test_remote_close_exc_in_handler(self, make_session, make_manager, make_handler): + async def test_remote_close_exc_in_handler( + self, + make_session, + make_manager, + make_handler, + ): handler = make_handler([], exc=True) manager = make_manager(handler) session = make_session() @@ -347,7 +377,12 @@ async def test_remote_closed(self, make_session, make_manager, make_handler): assert session.state == SessionState.CLOSED assert messages == [(protocol.CLOSED_MESSAGE, session)] - async def test_remote_closed_idempotent(self, make_session, make_manager, make_handler): + async def test_remote_closed_idempotent( + self, + make_session, + make_manager, + make_handler, + ): messages = [] handler = make_handler(result=messages) manager = make_manager(handler) @@ -358,7 +393,12 @@ async def test_remote_closed_idempotent(self, make_session, make_manager, make_h assert session.state == SessionState.CLOSED assert messages == [] - async def test_remote_closed_with_waiter(self, make_session, make_manager, make_handler): + async def test_remote_closed_with_waiter( + self, + make_session, + make_manager, + make_handler, + ): messages = [] handler = make_handler(result=messages) manager = make_manager(handler) @@ -374,7 +414,12 @@ async def test_remote_closed_with_waiter(self, make_session, make_manager, make_ assert session.state == SessionState.CLOSED assert messages == [(protocol.CLOSED_MESSAGE, session)] - async def test_remote_closed_exc_in_handler(self, make_session, make_manager, make_handler): + async def test_remote_closed_exc_in_handler( + self, + make_session, + make_manager, + make_handler, + ): handler = make_handler([], exc=True) manager = make_manager(handler) session = make_session(disconnect_delay=0) diff --git a/tests/test_transport_xhr.py b/tests/test_transport_xhr.py index 8971c48e..e9ccf0f7 100644 --- a/tests/test_transport_xhr.py +++ b/tests/test_transport_xhr.py @@ -24,7 +24,7 @@ async def test_process(make_transport, make_fut): assert resp.status == 200 -async def test_process_OPTIONS(make_transport): +async def test_process_options(make_transport): transp = make_transport(method="OPTIONS") resp = await transp.process() assert resp.status == 204 diff --git a/tests/test_transport_xhrsend.py b/tests/test_transport_xhrsend.py index 072ef66f..e7b4cf14 100644 --- a/tests/test_transport_xhrsend.py +++ b/tests/test_transport_xhrsend.py @@ -46,7 +46,7 @@ async def xtest_post_message(make_transport, make_fut): transp.manager.remote_messages.assert_called_with(["msg1", "msg2"]) -async def test_OPTIONS(make_transport): +async def test_options(make_transport): transp = make_transport(method="OPTIONS") resp = await transp.process() assert resp.status == 204 diff --git a/tests/test_transport_xhrstreaming.py b/tests/test_transport_xhrstreaming.py index f413161f..060664c4 100644 --- a/tests/test_transport_xhrstreaming.py +++ b/tests/test_transport_xhrstreaming.py @@ -24,7 +24,7 @@ async def test_process(make_transport, make_fut): assert resp.status == 200 -async def test_process_OPTIONS(make_transport): +async def test_process_options(make_transport): transp = make_transport(method="OPTIONS") resp = await transp.process() assert resp.status == 204 From 645ca016e45a9a848941077046feb2ea77a5792e Mon Sep 17 00:00:00 2001 From: Kirill Kuzminykh Date: Thu, 13 Jun 2024 15:04:11 +0300 Subject: [PATCH 22/23] Fixed dependencies. --- .github/workflows/check_and_test.yaml | 6 +++--- setup.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/check_and_test.yaml b/.github/workflows/check_and_test.yaml index 266e80d4..04aa843e 100644 --- a/.github/workflows/check_and_test.yaml +++ b/.github/workflows/check_and_test.yaml @@ -12,15 +12,15 @@ jobs: run_tests: strategy: matrix: - python-version: [ "3.10", "3.11" ] + python-version: [ "3.10", "3.11", "3.12" ] name: Test on Python ${{ matrix.python-version }} runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} architecture: x64 diff --git a/setup.py b/setup.py index c4e31438..2a0e4923 100644 --- a/setup.py +++ b/setup.py @@ -46,6 +46,7 @@ def read(f): python_requires=">=3.10.0", install_requires=[ "aiohttp>=3.7.4", + "async-timeout>=4.0.3", ], extras_require={ "test": [ From f617baad36e20f22f5ce0ed6d377cf3ead3cda89 Mon Sep 17 00:00:00 2001 From: Kirill Kuzminykh Date: Thu, 13 Jun 2024 15:14:20 +0300 Subject: [PATCH 23/23] Fixed versions of GitHub actions. --- .github/workflows/pythonpublish.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pythonpublish.yml b/.github/workflows/pythonpublish.yml index 72fb8e7c..bb2a83e0 100644 --- a/.github/workflows/pythonpublish.yml +++ b/.github/workflows/pythonpublish.yml @@ -8,9 +8,9 @@ jobs: deploy: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.x' - name: Install dependencies