From b9d4bbebcbd1719aa0947549c34a689a6853a1bc Mon Sep 17 00:00:00 2001 From: Casper Welzel Andersen Date: Wed, 22 Apr 2020 12:10:22 +0200 Subject: [PATCH] Close SQLA session after every REST API request This is needed due to the server running in threaded mode, i.e., creating a new thread for each incoming request. This concept is great for handling many requests, but crashes when used together with AiiDA's global singleton SQLA session, which the `QueryBuilder` uses no matter the backend of the profile. Specifically, this leads to issues with the SQLA QueuePool, since the connections are not properly released when a thread is closed. This leads to unintended QueuePool overflow. This fix wraps all HTTP method requests and makes sure to close the current thread's SQLA session after the request has been completely handled. Use Flask-RESTful's integrated `Resource` attribute `method_decorators` to apply the `close_session` wrapper to all and any HTTP request that may be requested of AiiDA's `BaseResource` (and its sub-classes). Additionally, remove the `__init__` function overwritten in `Node(BaseResource)`, since it is redundant, and the attribute `tclass` is not relevant with v4 (AiiDA v1.0.0 and above), but was never removed. It should have been removed when moving to v4 in 4ff2829. Concerning the added tests: the timeout needs to be set for Python 3.5 in order to stop the HTTP socket and properly raise (and escape out of an infinite loop). The `capfd` fixture must be used, otherwise the exception cannot be properly captured. The tests were simplified into the pytest scheme with ideas from @sphuber and @greschd. 
--- aiida/manage/tests/unittest_classes.py | 2 +- aiida/restapi/common/identifiers.py | 5 +- aiida/restapi/common/utils.py | 18 ++- aiida/restapi/resources.py | 26 ++-- aiida/restapi/run_api.py | 12 +- .../developer_guide/core/extend_restapi.rst | 5 +- tests/restapi/conftest.py | 50 ++++++++ tests/restapi/test_routes.py | 6 +- tests/restapi/test_threaded_restapi.py | 121 ++++++++++++++++++ 9 files changed, 212 insertions(+), 33 deletions(-) create mode 100644 tests/restapi/conftest.py create mode 100644 tests/restapi/test_threaded_restapi.py diff --git a/aiida/manage/tests/unittest_classes.py b/aiida/manage/tests/unittest_classes.py index bb747c7666..e722a78c8e 100644 --- a/aiida/manage/tests/unittest_classes.py +++ b/aiida/manage/tests/unittest_classes.py @@ -82,7 +82,7 @@ def run(self, suite, backend=None, profile_name=None): import warnings from aiida.common.warnings import AiidaDeprecationWarning warnings.warn( # pylint: disable=no-member - 'Please use "pytest" for testing AiiDA plugins. Support for "unittest" be removed in `v2.0.0`', + 'Please use "pytest" for testing AiiDA plugins. Support for "unittest" will be removed in `v2.0.0`', AiidaDeprecationWarning ) diff --git a/aiida/restapi/common/identifiers.py b/aiida/restapi/common/identifiers.py index 6e3109e49c..4fd1d5ff5d 100644 --- a/aiida/restapi/common/identifiers.py +++ b/aiida/restapi/common/identifiers.py @@ -31,8 +31,7 @@ 'process.calculation%.calcfunction.%|aiida.calculations:arithmetic.add' # More than one operator in segment """ - -import collections +from collections.abc import MutableMapping from aiida.common.escaping import escape_for_sql_like @@ -163,7 +162,7 @@ def load_entry_point_from_full_type(full_type): raise EntryPointError('entry point of the given full type cannot be loaded') -class Namespace(collections.MutableMapping): +class Namespace(MutableMapping): """Namespace that can be used to map the node class hierarchy.""" namespace_separator = '.' 
diff --git a/aiida/restapi/common/utils.py b/aiida/restapi/common/utils.py index 4abc4dd873..c69a14289f 100644 --- a/aiida/restapi/common/utils.py +++ b/aiida/restapi/common/utils.py @@ -8,13 +8,15 @@ # For further information please visit http://www.aiida.net # ########################################################################### """ Util methods """ -import urllib.parse from datetime import datetime, timedelta +import urllib.parse from flask import jsonify from flask.json import JSONEncoder +from wrapt import decorator from aiida.common.exceptions import InputValidationError, ValidationError +from aiida.manage.manager import get_manager from aiida.restapi.common.exceptions import RestValidationError, \ RestInputValidationError @@ -845,3 +847,17 @@ def list_routes(): output.append(line) return sorted(set(output)) + + +@decorator +def close_session(wrapped, _, args, kwargs): + """Close AiiDA SQLAlchemy (QueryBuilder) session + + This decorator can be used for router endpoints to close the SQLAlchemy global scoped session after the response + has been created. This is needed, since the QueryBuilder uses a SQLAlchemy global scoped session no matter the + profile's database backend. 
+ """ + try: + return wrapped(*args, **kwargs) + finally: + get_manager().get_backend().get_session().close() diff --git a/aiida/restapi/resources.py b/aiida/restapi/resources.py index 49ccadec17..547f8a0f9f 100644 --- a/aiida/restapi/resources.py +++ b/aiida/restapi/resources.py @@ -15,12 +15,11 @@ from aiida.common.lang import classproperty from aiida.restapi.common.exceptions import RestInputValidationError -from aiida.restapi.common.utils import Utils +from aiida.restapi.common.utils import Utils, close_session class ServerInfo(Resource): - # pylint: disable=fixme - """Endpointd to return general server info""" + """Endpoint to return general server info""" def __init__(self, **kwargs): # Configure utils @@ -97,6 +96,8 @@ class BaseResource(Resource): _translator_class = BaseTranslator _parse_pk_uuid = None # Flag to tell the path parser whether to expect a pk or a uuid pattern + method_decorators = [close_session] # Close SQLA session after any method call + ## TODO add the caching support. 
I cache total count, results, and possibly def __init__(self, **kwargs): @@ -106,11 +107,13 @@ def __init__(self, **kwargs): utils_conf_keys = ('PREFIX', 'PERPAGE_DEFAULT', 'LIMIT_DEFAULT') self.utils_confs = {k: kwargs[k] for k in utils_conf_keys if k in kwargs} self.utils = Utils(**self.utils_confs) - self.method_decorators = {'get': kwargs.get('get_decorators', [])} + + # HTTP Request method decorators + if 'get_decorators' in kwargs and isinstance(kwargs['get_decorators'], (tuple, list, set)): + self.method_decorators = {'get': list(kwargs['get_decorators'])} @classproperty - def parse_pk_uuid(cls): - # pylint: disable=no-self-argument + def parse_pk_uuid(cls): # pylint: disable=no-self-argument return cls._parse_pk_uuid def _load_and_verify(self, node_id=None): @@ -212,17 +215,6 @@ class Node(BaseResource): _translator_class = NodeTranslator _parse_pk_uuid = 'uuid' # Parse a uuid pattern in the URL path (not a pk) - def __init__(self, **kwargs): - super().__init__(**kwargs) - from aiida.orm import Node as tNode - self.tclass = tNode - - # Configure utils - utils_conf_keys = ('PREFIX', 'PERPAGE_DEFAULT', 'LIMIT_DEFAULT') - self.utils_confs = {k: kwargs[k] for k in utils_conf_keys if k in kwargs} - self.utils = Utils(**self.utils_confs) - self.method_decorators = {'get': kwargs.get('get_decorators', [])} - def get(self, id=None, page=None): # pylint: disable=redefined-builtin,invalid-name,unused-argument # pylint: disable=too-many-locals,too-many-statements,too-many-branches,fixme,unused-variable """ diff --git a/aiida/restapi/run_api.py b/aiida/restapi/run_api.py index e26a3a0e97..8cc4df95d8 100755 --- a/aiida/restapi/run_api.py +++ b/aiida/restapi/run_api.py @@ -56,7 +56,7 @@ def run_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **kwargs) port = kwargs.pop('port', CLI_DEFAULTS['PORT']) debug = kwargs.pop('debug', APP_CONFIG['DEBUG']) - app, api = configure_api(flask_app, flask_api, **kwargs) + api = configure_api(flask_app, flask_api, 
**kwargs) if hookup: # Run app through built-in werkzeug server @@ -66,7 +66,7 @@ def run_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **kwargs) else: # Return the app & api without specifying port/host to be handled by an external server (e.g. apache). # Some of the user-defined configuration of the app is ineffective (only affects built-in server). - return (app, api) + return api.app, api def configure_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **kwargs): @@ -81,7 +81,8 @@ def configure_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **k :param catch_internal_server: If true, catch and print all inter server errors :param wsgi_profile: use WSGI profiler middleware for finding bottlenecks in web application - :returns: tuple (app, api) + :returns: Flask RESTful API + :rtype: :py:class:`flask_restful.Api` """ # Unpack parameters @@ -119,6 +120,5 @@ def configure_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **k app.config['PROFILE'] = True app.wsgi_app = ProfilerMiddleware(app.wsgi_app, restrictions=[30]) - # Instantiate an Api by associating its app - api = flask_api(app, **API_CONFIG) - return (app, api) + # Instantiate and return a Flask RESTful API by associating its app + return flask_api(app, **API_CONFIG) diff --git a/docs/source/developer_guide/core/extend_restapi.rst b/docs/source/developer_guide/core/extend_restapi.rst index af97cf7275..a5d35d28d4 100644 --- a/docs/source/developer_guide/core/extend_restapi.rst +++ b/docs/source/developer_guide/core/extend_restapi.rst @@ -369,13 +369,14 @@ as confirmed by the response to the GET request. As a final remark, there might be circumstances in which you do not want to use the internal werkzeug-based server. For example, you might want to run the app through Apache using a wsgi script. 
-In this case, simply use ``configure_api`` to return two custom objects ``app`` and ``api``: +In this case, simply use ``configure_api`` to return a custom object ``api``: .. code-block:: python - (app, api) = configure_api(App, MycloudApi, **kwargs) + api = configure_api(App, MycloudApi, **kwargs) +The ``app`` can be retrieved by ``api.app``. This snippet of code becomes the fundamental block of a *wsgi* file used by Apache as documented in :ref:`restapi_apache`. Moreover, we recommend to consult the documentation of `mod_wsgi `_. diff --git a/tests/restapi/conftest.py b/tests/restapi/conftest.py new file mode 100644 index 0000000000..5840127fa8 --- /dev/null +++ b/tests/restapi/conftest.py @@ -0,0 +1,50 @@ +"""pytest fixtures for use with the aiida.restapi tests""" +import pytest + + +@pytest.fixture(scope='function') +def restapi_server(): + """Make REST API server""" + from werkzeug.serving import make_server + + from aiida.restapi.common.config import CLI_DEFAULTS + from aiida.restapi.run_api import configure_api + + def _restapi_server(restapi=None): + if restapi is None: + flask_restapi = configure_api() + else: + flask_restapi = configure_api(flask_api=restapi) + + return make_server( + host=CLI_DEFAULTS['HOST_NAME'], + port=int(CLI_DEFAULTS['PORT']), + app=flask_restapi.app, + threaded=True, + processes=1, + request_handler=None, + passthrough_errors=True, + ssl_context=None, + fd=None + ) + + return _restapi_server + + +@pytest.fixture +def server_url(): + from aiida.restapi.common.config import CLI_DEFAULTS, API_CONFIG + + return 'http://{hostname}:{port}{api}'.format( + hostname=CLI_DEFAULTS['HOST_NAME'], port=CLI_DEFAULTS['PORT'], api=API_CONFIG['PREFIX'] + ) + + +@pytest.fixture +def restrict_sqlalchemy_queuepool(aiida_profile): + """Create special SQLAlchemy engine for use with QueryBuilder - backend-agnostic""" + from aiida.manage.manager import get_manager + + backend_manager = get_manager().get_backend_manager() + 
backend_manager.reset_backend_environment() + backend_manager.load_backend_environment(aiida_profile, pool_timeout=1, max_overflow=0) diff --git a/tests/restapi/test_routes.py b/tests/restapi/test_routes.py index ab1fb26363..db30d11ce9 100644 --- a/tests/restapi/test_routes.py +++ b/tests/restapi/test_routes.py @@ -43,9 +43,9 @@ def setUpClass(cls, *args, **kwargs): # pylint: disable=too-many-locals, too-ma # order, api.__init__) kwargs = dict(PREFIX=cls._url_prefix, PERPAGE_DEFAULT=cls._PERPAGE_DEFAULT, LIMIT_DEFAULT=cls._LIMIT_DEFAULT) - app, _api = configure_api(catch_internal_server=True) + api = configure_api(catch_internal_server=True) - cls.app = app + cls.app = api.app cls.app.config['TESTING'] = True # create test inputs @@ -286,7 +286,7 @@ def process_test( """ Check whether response matches expected values. - :param entity_type: url requested fot the type of the node + :param entity_type: url requested for the type of the node :param url: web url :param full_list: if url is requested to get full list :param empty_list: if the response list is empty diff --git a/tests/restapi/test_threaded_restapi.py b/tests/restapi/test_threaded_restapi.py new file mode 100644 index 0000000000..7a530061da --- /dev/null +++ b/tests/restapi/test_threaded_restapi.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Tests for the `aiida.restapi` module, using it in threaded mode. + +Threaded mode is the default (and only) way to run the AiiDA REST API (see `aiida.restapi.run_api:run_api()`). 
+This test file's layout is inspired by https://gist.github.com/prschmid/4643738 +""" +import time +from threading import Thread + +import requests +import pytest + +NO_OF_REQUESTS = 100 + + +@pytest.mark.usefixtures('clear_database_before_test', 'restrict_sqlalchemy_queuepool') +def test_run_threaded_server(restapi_server, server_url, aiida_localhost): + """Run AiiDA REST API threaded in a separate thread and perform many sequential requests""" + + server = restapi_server() + computer_id = aiida_localhost.uuid + + # Create a thread that will contain the running server, + # since we do not wish to block the main thread + server_thread = Thread(target=server.serve_forever) + + try: + server_thread.start() + + for _ in range(NO_OF_REQUESTS): + response = requests.get(server_url + '/computers/{}'.format(computer_id), timeout=10) + + assert response.status_code == 200 + + try: + response_json = response.json() + except ValueError: + pytest.fail('Could not turn response into JSON. Response: {}'.format(response.raw)) + else: + assert 'data' in response_json + + except Exception as exc: # pylint: disable=broad-except + pytest.fail('Something went terribly wrong! 
Exception: {}'.format(repr(exc))) + finally: + server.shutdown() + + # Wait a total of 1 min (100 x 0.6 s) for the Thread to close/join, else fail + for _ in range(100): + if server_thread.is_alive(): + time.sleep(0.6) + else: + break + else: + pytest.fail('Thread did not close/join within 1 min after REST API server was called to shutdown') + + +@pytest.mark.usefixtures('clear_database_before_test', 'restrict_sqlalchemy_queuepool') +def test_run_without_close_session(restapi_server, server_url, aiida_localhost, capfd): + """Run AiiDA REST API threaded in a separate thread and perform many sequential requests""" + from aiida.restapi.api import AiidaApi + from aiida.restapi.resources import Computer + + class NoCloseSessionApi(AiidaApi): + """Add Computer to this API (again) with a new endpoint, but pass an empty list for `get_decorators`""" + + def __init__(self, app=None, **kwargs): + super().__init__(app=app, **kwargs) + + # This is a copy of adding the `Computer` resource, + # but only a few URLs are added, and `get_decorators` is passed with an empty list. 
+ extra_kwargs = kwargs.copy() + extra_kwargs.update({'get_decorators': []}) + self.add_resource( + Computer, + '/computers_no_close_session/', + '/computers_no_close_session//', + endpoint='computers_no_close_session', + strict_slashes=False, + resource_class_kwargs=extra_kwargs, + ) + + server = restapi_server(NoCloseSessionApi) + computer_id = aiida_localhost.uuid + + # Create a thread that will contain the running server, + # since we do not wish to block the main thread + server_thread = Thread(target=server.serve_forever) + + try: + server_thread.start() + + for _ in range(NO_OF_REQUESTS): + requests.get(server_url + '/computers_no_close_session/{}'.format(computer_id), timeout=10) + pytest.fail('{} requests were not enough to raise a SQLAlchemy TimeoutError!'.format(NO_OF_REQUESTS)) + + except (requests.exceptions.ConnectionError, OSError): + pass + except Exception as exc: # pylint: disable=broad-except + pytest.fail('Something went terribly wrong! Exception: {}'.format(repr(exc))) + finally: + server.shutdown() + + # Wait a total of 1 min (100 x 0.6 s) for the Thread to close/join, else fail + for _ in range(100): + if server_thread.is_alive(): + time.sleep(0.6) + else: + break + else: + pytest.fail('Thread did not close/join within 1 min after REST API server was called to shutdown') + + captured = capfd.readouterr() + assert 'sqlalchemy.exc.TimeoutError: QueuePool limit of size ' in captured.err