diff --git a/aiida/manage/tests/unittest_classes.py b/aiida/manage/tests/unittest_classes.py index bb747c7666..e722a78c8e 100644 --- a/aiida/manage/tests/unittest_classes.py +++ b/aiida/manage/tests/unittest_classes.py @@ -82,7 +82,7 @@ def run(self, suite, backend=None, profile_name=None): import warnings from aiida.common.warnings import AiidaDeprecationWarning warnings.warn( # pylint: disable=no-member - 'Please use "pytest" for testing AiiDA plugins. Support for "unittest" be removed in `v2.0.0`', + 'Please use "pytest" for testing AiiDA plugins. Support for "unittest" will be removed in `v2.0.0`', AiidaDeprecationWarning ) diff --git a/aiida/restapi/common/identifiers.py b/aiida/restapi/common/identifiers.py index 6e3109e49c..4fd1d5ff5d 100644 --- a/aiida/restapi/common/identifiers.py +++ b/aiida/restapi/common/identifiers.py @@ -31,8 +31,7 @@ 'process.calculation%.calcfunction.%|aiida.calculations:arithmetic.add' # More than one operator in segment """ - -import collections +from collections.abc import MutableMapping from aiida.common.escaping import escape_for_sql_like @@ -163,7 +162,7 @@ def load_entry_point_from_full_type(full_type): raise EntryPointError('entry point of the given full type cannot be loaded') -class Namespace(collections.MutableMapping): +class Namespace(MutableMapping): """Namespace that can be used to map the node class hierarchy.""" namespace_separator = '.' diff --git a/aiida/restapi/common/utils.py b/aiida/restapi/common/utils.py index 4abc4dd873..c69a14289f 100644 --- a/aiida/restapi/common/utils.py +++ b/aiida/restapi/common/utils.py @@ -8,13 +8,15 @@ # For further information please visit http://www.aiida.net # ########################################################################### """ Util methods """ -import urllib.parse from datetime import datetime, timedelta +import urllib.parse from flask import jsonify from flask.json import JSONEncoder +from wrapt import decorator from aiida.common.exceptions import InputValidationError, ValidationError +from aiida.manage.manager import get_manager from aiida.restapi.common.exceptions import RestValidationError, \ RestInputValidationError @@ -845,3 +847,17 @@ def list_routes(): output.append(line) return sorted(set(output)) + + +@decorator +def close_session(wrapped, _, args, kwargs): + """Close AiiDA SQLAlchemy (QueryBuilder) session + + This decorator can be used for router endpoints to close the SQLAlchemy global scoped session after the response + has been created. This is needed, since the QueryBuilder uses a SQLAlchemy global scoped session no matter the + profile's database backend. + """ + try: + return wrapped(*args, **kwargs) + finally: + get_manager().get_backend().get_session().close() diff --git a/aiida/restapi/resources.py b/aiida/restapi/resources.py index 49ccadec17..547f8a0f9f 100644 --- a/aiida/restapi/resources.py +++ b/aiida/restapi/resources.py @@ -15,12 +15,11 @@ from aiida.common.lang import classproperty from aiida.restapi.common.exceptions import RestInputValidationError -from aiida.restapi.common.utils import Utils +from aiida.restapi.common.utils import Utils, close_session class ServerInfo(Resource): - # pylint: disable=fixme - """Endpointd to return general server info""" + """Endpoint to return general server info""" def __init__(self, **kwargs): # Configure utils @@ -97,6 +96,8 @@ class BaseResource(Resource): _translator_class = BaseTranslator _parse_pk_uuid = None # Flag to tell the path parser whether to expect a pk or a uuid pattern + method_decorators = [close_session] # Close SQLA session after any method call + ## TODO add the caching support. I cache total count, results, and possibly def __init__(self, **kwargs): @@ -106,11 +107,13 @@ def __init__(self, **kwargs): utils_conf_keys = ('PREFIX', 'PERPAGE_DEFAULT', 'LIMIT_DEFAULT') self.utils_confs = {k: kwargs[k] for k in utils_conf_keys if k in kwargs} self.utils = Utils(**self.utils_confs) - self.method_decorators = {'get': kwargs.get('get_decorators', [])} + + # HTTP Request method decorators + if 'get_decorators' in kwargs and isinstance(kwargs['get_decorators'], (tuple, list, set)): + self.method_decorators = {'get': list(kwargs['get_decorators'])} @classproperty - def parse_pk_uuid(cls): - # pylint: disable=no-self-argument + def parse_pk_uuid(cls): # pylint: disable=no-self-argument return cls._parse_pk_uuid def _load_and_verify(self, node_id=None): @@ -212,17 +215,6 @@ class Node(BaseResource): _translator_class = NodeTranslator _parse_pk_uuid = 'uuid' # Parse a uuid pattern in the URL path (not a pk) - def __init__(self, **kwargs): - super().__init__(**kwargs) - from aiida.orm import Node as tNode - self.tclass = tNode - - # Configure utils - utils_conf_keys = ('PREFIX', 'PERPAGE_DEFAULT', 'LIMIT_DEFAULT') - self.utils_confs = {k: kwargs[k] for k in utils_conf_keys if k in kwargs} - self.utils = Utils(**self.utils_confs) - self.method_decorators = {'get': kwargs.get('get_decorators', [])} - def get(self, id=None, page=None): # pylint: disable=redefined-builtin,invalid-name,unused-argument # pylint: disable=too-many-locals,too-many-statements,too-many-branches,fixme,unused-variable """ diff --git a/aiida/restapi/run_api.py b/aiida/restapi/run_api.py index e26a3a0e97..8cc4df95d8 100755 --- a/aiida/restapi/run_api.py +++ b/aiida/restapi/run_api.py @@ -56,7 +56,7 @@ def run_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **kwargs) port = kwargs.pop('port', CLI_DEFAULTS['PORT']) debug = kwargs.pop('debug', APP_CONFIG['DEBUG']) - app, api = configure_api(flask_app, flask_api, **kwargs) + api = configure_api(flask_app, flask_api, **kwargs) if hookup: # Run app through built-in werkzeug server @@ -66,7 +66,7 @@ def run_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **kwargs) else: # Return the app & api without specifying port/host to be handled by an external server (e.g. apache). # Some of the user-defined configuration of the app is ineffective (only affects built-in server). - return (app, api) + return api.app, api def configure_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **kwargs): @@ -81,7 +81,8 @@ def configure_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **k :param catch_internal_server: If true, catch and print all inter server errors :param wsgi_profile: use WSGI profiler middleware for finding bottlenecks in web application - :returns: tuple (app, api) + :returns: Flask RESTful API + :rtype: :py:class:`flask_restful.Api` """ # Unpack parameters @@ -119,6 +120,5 @@ def configure_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **k app.config['PROFILE'] = True app.wsgi_app = ProfilerMiddleware(app.wsgi_app, restrictions=[30]) - # Instantiate an Api by associating its app - api = flask_api(app, **API_CONFIG) - return (app, api) + # Instantiate and return a Flask RESTful API by associating its app + return flask_api(app, **API_CONFIG) diff --git a/docs/source/developer_guide/core/extend_restapi.rst b/docs/source/developer_guide/core/extend_restapi.rst index af97cf7275..a5d35d28d4 100644 --- a/docs/source/developer_guide/core/extend_restapi.rst +++ b/docs/source/developer_guide/core/extend_restapi.rst @@ -369,13 +369,14 @@ as confirmed by the response to the GET request. As a final remark, there might be circumstances in which you do not want to use the internal werkzeug-based server. For example, you might want to run the app through Apache using a wsgi script. -In this case, simply use ``configure_api`` to return two custom objects ``app`` and ``api``: +In this case, simply use ``configure_api`` to return a custom object ``api``: .. code-block:: python - (app, api) = configure_api(App, MycloudApi, **kwargs) + api = configure_api(App, MycloudApi, **kwargs) +The ``app`` can be retrieved by ``api.app``. This snippet of code becomes the fundamental block of a *wsgi* file used by Apache as documented in :ref:`restapi_apache`. Moreover, we recommend to consult the documentation of `mod_wsgi `_. diff --git a/tests/restapi/conftest.py b/tests/restapi/conftest.py new file mode 100644 index 0000000000..5840127fa8 --- /dev/null +++ b/tests/restapi/conftest.py @@ -0,0 +1,50 @@ +"""pytest fixtures for use with the aiida.restapi tests""" +import pytest + + +@pytest.fixture(scope='function') +def restapi_server(): + """Make REST API server""" + from werkzeug.serving import make_server + + from aiida.restapi.common.config import CLI_DEFAULTS + from aiida.restapi.run_api import configure_api + + def _restapi_server(restapi=None): + if restapi is None: + flask_restapi = configure_api() + else: + flask_restapi = configure_api(flask_api=restapi) + + return make_server( + host=CLI_DEFAULTS['HOST_NAME'], + port=int(CLI_DEFAULTS['PORT']), + app=flask_restapi.app, + threaded=True, + processes=1, + request_handler=None, + passthrough_errors=True, + ssl_context=None, + fd=None + ) + + return _restapi_server + + +@pytest.fixture +def server_url(): + from aiida.restapi.common.config import CLI_DEFAULTS, API_CONFIG + + return 'http://{hostname}:{port}{api}'.format( + hostname=CLI_DEFAULTS['HOST_NAME'], port=CLI_DEFAULTS['PORT'], api=API_CONFIG['PREFIX'] + ) + + +@pytest.fixture +def restrict_sqlalchemy_queuepool(aiida_profile): + """Create special SQLAlchemy engine for use with QueryBuilder - backend-agnostic""" + from aiida.manage.manager import get_manager + + backend_manager = get_manager().get_backend_manager() + backend_manager.reset_backend_environment() + backend_manager.load_backend_environment(aiida_profile, pool_timeout=1, max_overflow=0) diff --git a/tests/restapi/test_routes.py b/tests/restapi/test_routes.py index ab1fb26363..db30d11ce9 100644 --- a/tests/restapi/test_routes.py +++ b/tests/restapi/test_routes.py @@ -43,9 +43,9 @@ def setUpClass(cls, *args, **kwargs): # pylint: disable=too-many-locals, too-ma # order, api.__init__) kwargs = dict(PREFIX=cls._url_prefix, PERPAGE_DEFAULT=cls._PERPAGE_DEFAULT, LIMIT_DEFAULT=cls._LIMIT_DEFAULT) - app, _api = configure_api(catch_internal_server=True) + api = configure_api(catch_internal_server=True) - cls.app = app + cls.app = api.app cls.app.config['TESTING'] = True # create test inputs @@ -286,7 +286,7 @@ def process_test( """ Check whether response matches expected values. - :param entity_type: url requested fot the type of the node + :param entity_type: url requested for the type of the node :param url: web url :param full_list: if url is requested to get full list :param empty_list: if the response list is empty diff --git a/tests/restapi/test_threaded_restapi.py b/tests/restapi/test_threaded_restapi.py new file mode 100644 index 0000000000..7a530061da --- /dev/null +++ b/tests/restapi/test_threaded_restapi.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Tests for the `aiida.restapi` module, using it in threaded mode. + +Threaded mode is the default (and only) way to run the AiiDA REST API (see `aiida.restapi.run_api:run_api()`). +This test file's layout is inspired by https://gist.github.com/prschmid/4643738 +""" +import time +from threading import Thread + +import requests +import pytest + +NO_OF_REQUESTS = 100 + + +@pytest.mark.usefixtures('clear_database_before_test', 'restrict_sqlalchemy_queuepool') +def test_run_threaded_server(restapi_server, server_url, aiida_localhost): + """Run AiiDA REST API threaded in a separate thread and perform many sequential requests""" + + server = restapi_server() + computer_id = aiida_localhost.uuid + + # Create a thread that will contain the running server, + # since we do not wish to block the main thread + server_thread = Thread(target=server.serve_forever) + + try: + server_thread.start() + + for _ in range(NO_OF_REQUESTS): + response = requests.get(server_url + '/computers/{}'.format(computer_id), timeout=10) + + assert response.status_code == 200 + + try: + response_json = response.json() + except ValueError: + pytest.fail('Could not turn response into JSON. Response: {}'.format(response.raw)) + else: + assert 'data' in response_json + + except Exception as exc: # pylint: disable=broad-except + pytest.fail('Something went terribly wrong! Exception: {}'.format(repr(exc))) + finally: + server.shutdown() + + # Wait a total of 1 min (100 x 0.6 s) for the Thread to close/join, else fail + for _ in range(100): + if server_thread.is_alive(): + time.sleep(0.6) + else: + break + else: + pytest.fail('Thread did not close/join within 1 min after REST API server was called to shutdown') + + +@pytest.mark.usefixtures('clear_database_before_test', 'restrict_sqlalchemy_queuepool') +def test_run_without_close_session(restapi_server, server_url, aiida_localhost, capfd): + """Run AiiDA REST API threaded in a separate thread and perform many sequential requests""" + from aiida.restapi.api import AiidaApi + from aiida.restapi.resources import Computer + + class NoCloseSessionApi(AiidaApi): + """Add Computer to this API (again) with a new endpoint, but pass an empty list for `get_decorators`""" + + def __init__(self, app=None, **kwargs): + super().__init__(app=app, **kwargs) + + # This is a copy of adding the `Computer` resource, + # but only a few URLs are added, and `get_decorators` is passed with an empty list. + extra_kwargs = kwargs.copy() + extra_kwargs.update({'get_decorators': []}) + self.add_resource( + Computer, + '/computers_no_close_session/', + '/computers_no_close_session//', + endpoint='computers_no_close_session', + strict_slashes=False, + resource_class_kwargs=extra_kwargs, + ) + + server = restapi_server(NoCloseSessionApi) + computer_id = aiida_localhost.uuid + + # Create a thread that will contain the running server, + # since we do not wish to block the main thread + server_thread = Thread(target=server.serve_forever) + + try: + server_thread.start() + + for _ in range(NO_OF_REQUESTS): + requests.get(server_url + '/computers_no_close_session/{}'.format(computer_id), timeout=10) + pytest.fail('{} requests were not enough to raise a SQLAlchemy TimeoutError!'.format(NO_OF_REQUESTS)) + + except (requests.exceptions.ConnectionError, OSError): + pass + except Exception as exc: # pylint: disable=broad-except + pytest.fail('Something went terribly wrong! Exception: {}'.format(repr(exc))) + finally: + server.shutdown() + + # Wait a total of 1 min (100 x 0.6 s) for the Thread to close/join, else fail + for _ in range(100): + if server_thread.is_alive(): + time.sleep(0.6) + else: + break + else: + pytest.fail('Thread did not close/join within 1 min after REST API server was called to shutdown') + + captured = capfd.readouterr() + assert 'sqlalchemy.exc.TimeoutError: QueuePool limit of size ' in captured.err