diff --git a/aiida/manage/tests/unittest_classes.py b/aiida/manage/tests/unittest_classes.py
index bb747c7666..e722a78c8e 100644
--- a/aiida/manage/tests/unittest_classes.py
+++ b/aiida/manage/tests/unittest_classes.py
@@ -82,7 +82,7 @@ def run(self, suite, backend=None, profile_name=None):
import warnings
from aiida.common.warnings import AiidaDeprecationWarning
warnings.warn( # pylint: disable=no-member
- 'Please use "pytest" for testing AiiDA plugins. Support for "unittest" be removed in `v2.0.0`',
+ 'Please use "pytest" for testing AiiDA plugins. Support for "unittest" will be removed in `v2.0.0`',
AiidaDeprecationWarning
)
diff --git a/aiida/restapi/common/identifiers.py b/aiida/restapi/common/identifiers.py
index 6e3109e49c..4fd1d5ff5d 100644
--- a/aiida/restapi/common/identifiers.py
+++ b/aiida/restapi/common/identifiers.py
@@ -31,8 +31,7 @@
'process.calculation%.calcfunction.%|aiida.calculations:arithmetic.add' # More than one operator in segment
"""
-
-import collections
+from collections.abc import MutableMapping
from aiida.common.escaping import escape_for_sql_like
@@ -163,7 +162,7 @@ def load_entry_point_from_full_type(full_type):
raise EntryPointError('entry point of the given full type cannot be loaded')
-class Namespace(collections.MutableMapping):
+class Namespace(MutableMapping):
"""Namespace that can be used to map the node class hierarchy."""
namespace_separator = '.'
diff --git a/aiida/restapi/common/utils.py b/aiida/restapi/common/utils.py
index 4abc4dd873..c69a14289f 100644
--- a/aiida/restapi/common/utils.py
+++ b/aiida/restapi/common/utils.py
@@ -8,13 +8,15 @@
# For further information please visit http://www.aiida.net #
###########################################################################
""" Util methods """
-import urllib.parse
from datetime import datetime, timedelta
+import urllib.parse
from flask import jsonify
from flask.json import JSONEncoder
+from wrapt import decorator
from aiida.common.exceptions import InputValidationError, ValidationError
+from aiida.manage.manager import get_manager
from aiida.restapi.common.exceptions import RestValidationError, \
RestInputValidationError
@@ -845,3 +847,17 @@ def list_routes():
output.append(line)
return sorted(set(output))
+
+
+@decorator
+def close_session(wrapped, _, args, kwargs):
+ """Close AiiDA SQLAlchemy (QueryBuilder) session
+
+ This decorator can be used for router endpoints to close the SQLAlchemy global scoped session after the response
+ has been created. This is needed, since the QueryBuilder uses a SQLAlchemy global scoped session no matter the
+ profile's database backend.
+ """
+ try:
+ return wrapped(*args, **kwargs)
+ finally:
+ get_manager().get_backend().get_session().close()
diff --git a/aiida/restapi/resources.py b/aiida/restapi/resources.py
index 49ccadec17..547f8a0f9f 100644
--- a/aiida/restapi/resources.py
+++ b/aiida/restapi/resources.py
@@ -15,12 +15,11 @@
from aiida.common.lang import classproperty
from aiida.restapi.common.exceptions import RestInputValidationError
-from aiida.restapi.common.utils import Utils
+from aiida.restapi.common.utils import Utils, close_session
class ServerInfo(Resource):
- # pylint: disable=fixme
- """Endpointd to return general server info"""
+ """Endpoint to return general server info"""
def __init__(self, **kwargs):
# Configure utils
@@ -97,6 +96,8 @@ class BaseResource(Resource):
_translator_class = BaseTranslator
_parse_pk_uuid = None # Flag to tell the path parser whether to expect a pk or a uuid pattern
+ method_decorators = [close_session] # Close SQLA session after any method call
+
## TODO add the caching support. I cache total count, results, and possibly
def __init__(self, **kwargs):
@@ -106,11 +107,13 @@ def __init__(self, **kwargs):
utils_conf_keys = ('PREFIX', 'PERPAGE_DEFAULT', 'LIMIT_DEFAULT')
self.utils_confs = {k: kwargs[k] for k in utils_conf_keys if k in kwargs}
self.utils = Utils(**self.utils_confs)
- self.method_decorators = {'get': kwargs.get('get_decorators', [])}
+
+ # HTTP Request method decorators
+ if 'get_decorators' in kwargs and isinstance(kwargs['get_decorators'], (tuple, list, set)):
+ self.method_decorators = {'get': list(kwargs['get_decorators'])}
@classproperty
- def parse_pk_uuid(cls):
- # pylint: disable=no-self-argument
+ def parse_pk_uuid(cls): # pylint: disable=no-self-argument
return cls._parse_pk_uuid
def _load_and_verify(self, node_id=None):
@@ -212,17 +215,6 @@ class Node(BaseResource):
_translator_class = NodeTranslator
_parse_pk_uuid = 'uuid' # Parse a uuid pattern in the URL path (not a pk)
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- from aiida.orm import Node as tNode
- self.tclass = tNode
-
- # Configure utils
- utils_conf_keys = ('PREFIX', 'PERPAGE_DEFAULT', 'LIMIT_DEFAULT')
- self.utils_confs = {k: kwargs[k] for k in utils_conf_keys if k in kwargs}
- self.utils = Utils(**self.utils_confs)
- self.method_decorators = {'get': kwargs.get('get_decorators', [])}
-
def get(self, id=None, page=None): # pylint: disable=redefined-builtin,invalid-name,unused-argument
# pylint: disable=too-many-locals,too-many-statements,too-many-branches,fixme,unused-variable
"""
diff --git a/aiida/restapi/run_api.py b/aiida/restapi/run_api.py
index e26a3a0e97..8cc4df95d8 100755
--- a/aiida/restapi/run_api.py
+++ b/aiida/restapi/run_api.py
@@ -56,7 +56,7 @@ def run_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **kwargs)
port = kwargs.pop('port', CLI_DEFAULTS['PORT'])
debug = kwargs.pop('debug', APP_CONFIG['DEBUG'])
- app, api = configure_api(flask_app, flask_api, **kwargs)
+ api = configure_api(flask_app, flask_api, **kwargs)
if hookup:
# Run app through built-in werkzeug server
@@ -66,7 +66,7 @@ def run_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **kwargs)
else:
# Return the app & api without specifying port/host to be handled by an external server (e.g. apache).
# Some of the user-defined configuration of the app is ineffective (only affects built-in server).
- return (app, api)
+ return api.app, api
def configure_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **kwargs):
@@ -81,7 +81,8 @@ def configure_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **k
:param catch_internal_server: If true, catch and print all inter server errors
:param wsgi_profile: use WSGI profiler middleware for finding bottlenecks in web application
- :returns: tuple (app, api)
+ :returns: Flask RESTful API
+ :rtype: :py:class:`flask_restful.Api`
"""
# Unpack parameters
@@ -119,6 +120,5 @@ def configure_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **k
app.config['PROFILE'] = True
app.wsgi_app = ProfilerMiddleware(app.wsgi_app, restrictions=[30])
- # Instantiate an Api by associating its app
- api = flask_api(app, **API_CONFIG)
- return (app, api)
+ # Instantiate and return a Flask RESTful API by associating its app
+ return flask_api(app, **API_CONFIG)
diff --git a/docs/source/developer_guide/core/extend_restapi.rst b/docs/source/developer_guide/core/extend_restapi.rst
index af97cf7275..a5d35d28d4 100644
--- a/docs/source/developer_guide/core/extend_restapi.rst
+++ b/docs/source/developer_guide/core/extend_restapi.rst
@@ -369,13 +369,14 @@ as confirmed by the response to the GET request.
As a final remark, there might be circumstances in which you do not want to use the internal werkzeug-based server.
For example, you might want to run the app through Apache using a wsgi script.
-In this case, simply use ``configure_api`` to return two custom objects ``app`` and ``api``:
+In this case, simply use ``configure_api`` to return a custom object ``api``:
.. code-block:: python
- (app, api) = configure_api(App, MycloudApi, **kwargs)
+ api = configure_api(App, MycloudApi, **kwargs)
+The ``app`` can be retrieved by ``api.app``.
This snippet of code becomes the fundamental block of a *wsgi* file used by Apache as documented in :ref:`restapi_apache`.
Moreover, we recommend to consult the documentation of `mod_wsgi `_.
diff --git a/tests/restapi/conftest.py b/tests/restapi/conftest.py
new file mode 100644
index 0000000000..5840127fa8
--- /dev/null
+++ b/tests/restapi/conftest.py
@@ -0,0 +1,50 @@
+"""pytest fixtures for use with the aiida.restapi tests"""
+import pytest
+
+
+@pytest.fixture(scope='function')
+def restapi_server():
+ """Make REST API server"""
+ from werkzeug.serving import make_server
+
+ from aiida.restapi.common.config import CLI_DEFAULTS
+ from aiida.restapi.run_api import configure_api
+
+ def _restapi_server(restapi=None):
+ if restapi is None:
+ flask_restapi = configure_api()
+ else:
+ flask_restapi = configure_api(flask_api=restapi)
+
+ return make_server(
+ host=CLI_DEFAULTS['HOST_NAME'],
+ port=int(CLI_DEFAULTS['PORT']),
+ app=flask_restapi.app,
+ threaded=True,
+ processes=1,
+ request_handler=None,
+ passthrough_errors=True,
+ ssl_context=None,
+ fd=None
+ )
+
+ return _restapi_server
+
+
+@pytest.fixture
+def server_url():
+ from aiida.restapi.common.config import CLI_DEFAULTS, API_CONFIG
+
+ return 'http://{hostname}:{port}{api}'.format(
+ hostname=CLI_DEFAULTS['HOST_NAME'], port=CLI_DEFAULTS['PORT'], api=API_CONFIG['PREFIX']
+ )
+
+
+@pytest.fixture
+def restrict_sqlalchemy_queuepool(aiida_profile):
+ """Create special SQLAlchemy engine for use with QueryBuilder - backend-agnostic"""
+ from aiida.manage.manager import get_manager
+
+ backend_manager = get_manager().get_backend_manager()
+ backend_manager.reset_backend_environment()
+ backend_manager.load_backend_environment(aiida_profile, pool_timeout=1, max_overflow=0)
diff --git a/tests/restapi/test_routes.py b/tests/restapi/test_routes.py
index ab1fb26363..db30d11ce9 100644
--- a/tests/restapi/test_routes.py
+++ b/tests/restapi/test_routes.py
@@ -43,9 +43,9 @@ def setUpClass(cls, *args, **kwargs): # pylint: disable=too-many-locals, too-ma
# order, api.__init__)
kwargs = dict(PREFIX=cls._url_prefix, PERPAGE_DEFAULT=cls._PERPAGE_DEFAULT, LIMIT_DEFAULT=cls._LIMIT_DEFAULT)
- app, _api = configure_api(catch_internal_server=True)
+ api = configure_api(catch_internal_server=True)
- cls.app = app
+ cls.app = api.app
cls.app.config['TESTING'] = True
# create test inputs
@@ -286,7 +286,7 @@ def process_test(
"""
Check whether response matches expected values.
- :param entity_type: url requested fot the type of the node
+ :param entity_type: url requested for the type of the node
:param url: web url
:param full_list: if url is requested to get full list
:param empty_list: if the response list is empty
diff --git a/tests/restapi/test_threaded_restapi.py b/tests/restapi/test_threaded_restapi.py
new file mode 100644
index 0000000000..7a530061da
--- /dev/null
+++ b/tests/restapi/test_threaded_restapi.py
@@ -0,0 +1,121 @@
+# -*- coding: utf-8 -*-
+###########################################################################
+# Copyright (c), The AiiDA team. All rights reserved. #
+# This file is part of the AiiDA code. #
+# #
+# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
+# For further information on the license, see the LICENSE.txt file #
+# For further information please visit http://www.aiida.net #
+###########################################################################
+"""Tests for the `aiida.restapi` module, using it in threaded mode.
+
+Threaded mode is the default (and only) way to run the AiiDA REST API (see `aiida.restapi.run_api:run_api()`).
+This test file's layout is inspired by https://gist.github.com/prschmid/4643738
+"""
+import time
+from threading import Thread
+
+import requests
+import pytest
+
+NO_OF_REQUESTS = 100
+
+
+@pytest.mark.usefixtures('clear_database_before_test', 'restrict_sqlalchemy_queuepool')
+def test_run_threaded_server(restapi_server, server_url, aiida_localhost):
+ """Run AiiDA REST API threaded in a separate thread and perform many sequential requests"""
+
+ server = restapi_server()
+ computer_id = aiida_localhost.uuid
+
+ # Create a thread that will contain the running server,
+ # since we do not wish to block the main thread
+ server_thread = Thread(target=server.serve_forever)
+
+ try:
+ server_thread.start()
+
+ for _ in range(NO_OF_REQUESTS):
+ response = requests.get(server_url + '/computers/{}'.format(computer_id), timeout=10)
+
+ assert response.status_code == 200
+
+ try:
+ response_json = response.json()
+ except ValueError:
+ pytest.fail('Could not turn response into JSON. Response: {}'.format(response.raw))
+ else:
+ assert 'data' in response_json
+
+ except Exception as exc: # pylint: disable=broad-except
+ pytest.fail('Something went terribly wrong! Exception: {}'.format(repr(exc)))
+ finally:
+ server.shutdown()
+
+ # Wait a total of 1 min (100 x 0.6 s) for the Thread to close/join, else fail
+ for _ in range(100):
+ if server_thread.is_alive():
+ time.sleep(0.6)
+ else:
+ break
+ else:
+ pytest.fail('Thread did not close/join within 1 min after REST API server was called to shutdown')
+
+
+@pytest.mark.usefixtures('clear_database_before_test', 'restrict_sqlalchemy_queuepool')
+def test_run_without_close_session(restapi_server, server_url, aiida_localhost, capfd):
+ """Run AiiDA REST API threaded in a separate thread and perform many sequential requests"""
+ from aiida.restapi.api import AiidaApi
+ from aiida.restapi.resources import Computer
+
+ class NoCloseSessionApi(AiidaApi):
+ """Add Computer to this API (again) with a new endpoint, but pass an empty list for `get_decorators`"""
+
+ def __init__(self, app=None, **kwargs):
+ super().__init__(app=app, **kwargs)
+
+ # This is a copy of adding the `Computer` resource,
+ # but only a few URLs are added, and `get_decorators` is passed with an empty list.
+ extra_kwargs = kwargs.copy()
+ extra_kwargs.update({'get_decorators': []})
+ self.add_resource(
+ Computer,
+ '/computers_no_close_session/',
+ '/computers_no_close_session//',
+ endpoint='computers_no_close_session',
+ strict_slashes=False,
+ resource_class_kwargs=extra_kwargs,
+ )
+
+ server = restapi_server(NoCloseSessionApi)
+ computer_id = aiida_localhost.uuid
+
+ # Create a thread that will contain the running server,
+ # since we do not wish to block the main thread
+ server_thread = Thread(target=server.serve_forever)
+
+ try:
+ server_thread.start()
+
+ for _ in range(NO_OF_REQUESTS):
+ requests.get(server_url + '/computers_no_close_session/{}'.format(computer_id), timeout=10)
+ pytest.fail('{} requests were not enough to raise a SQLAlchemy TimeoutError!'.format(NO_OF_REQUESTS))
+
+ except (requests.exceptions.ConnectionError, OSError):
+ pass
+ except Exception as exc: # pylint: disable=broad-except
+ pytest.fail('Something went terribly wrong! Exception: {}'.format(repr(exc)))
+ finally:
+ server.shutdown()
+
+ # Wait a total of 1 min (100 x 0.6 s) for the Thread to close/join, else fail
+ for _ in range(100):
+ if server_thread.is_alive():
+ time.sleep(0.6)
+ else:
+ break
+ else:
+ pytest.fail('Thread did not close/join within 1 min after REST API server was called to shutdown')
+
+ captured = capfd.readouterr()
+ assert 'sqlalchemy.exc.TimeoutError: QueuePool limit of size ' in captured.err