diff --git a/connexion/apps/abstract.py b/connexion/apps/abstract.py index f70767984..c8912be1e 100644 --- a/connexion/apps/abstract.py +++ b/connexion/apps/abstract.py @@ -10,10 +10,12 @@ from starlette.types import ASGIApp, Receive, Scope, Send from connexion.jsonifier import Jsonifier +from connexion.lifecycle import ConnexionRequest, ConnexionResponse from connexion.middleware import ConnexionMiddleware, MiddlewarePosition, SpecMiddleware from connexion.middleware.lifespan import Lifespan from connexion.options import SwaggerUIOptions from connexion.resolver import Resolver +from connexion.types import MaybeAwaitable from connexion.uri_parsing import AbstractURIParser @@ -250,14 +252,18 @@ def decorator(func: t.Callable) -> t.Callable: @abc.abstractmethod def add_error_handler( - self, code_or_exception: t.Union[int, t.Type[Exception]], function: t.Callable + self, + code_or_exception: t.Union[int, t.Type[Exception]], + function: t.Callable[ + [ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse] + ], ) -> None: """ Register a callable to handle application errors. :param code_or_exception: An exception class or the status code of HTTP exceptions to handle. - :param function: Callable that will handle exception. + :param function: Callable that will handle exception, may be async. """ def test_client(self, **kwargs): diff --git a/connexion/apps/asynchronous.py b/connexion/apps/asynchronous.py index 9974381fd..9ab08a401 100644 --- a/connexion/apps/asynchronous.py +++ b/connexion/apps/asynchronous.py @@ -14,11 +14,13 @@ from connexion.apps.abstract import AbstractApp from connexion.decorators import StarletteDecorator from connexion.jsonifier import Jsonifier +from connexion.lifecycle import ConnexionRequest, ConnexionResponse from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware from connexion.middleware.lifespan import Lifespan from connexion.operations import AbstractOperation from connexion.options import SwaggerUIOptions from connexion.resolver import Resolver +from connexion.types import MaybeAwaitable from connexion.uri_parsing import AbstractURIParser logger = logging.getLogger(__name__) @@ -88,7 +90,7 @@ def make_operation(self, operation: AbstractOperation) -> AsyncOperation: ) -class AsyncMiddlewareApp(RoutedMiddleware[AsyncApi]): +class AsyncASGIApp(RoutedMiddleware[AsyncApi]): api_cls = AsyncApi @@ -176,7 +178,7 @@ def __init__( :param security_map: A dictionary of security handlers to use. Defaults to :obj:`security.SECURITY_HANDLERS` """ - self._middleware_app: AsyncMiddlewareApp = AsyncMiddlewareApp() + self._middleware_app: AsyncASGIApp = AsyncASGIApp() super().__init__( import_name, @@ -205,6 +207,10 @@ def add_url_rule( ) def add_error_handler( - self, code_or_exception: t.Union[int, t.Type[Exception]], function: t.Callable + self, + code_or_exception: t.Union[int, t.Type[Exception]], + function: t.Callable[ + [ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse] + ], ) -> None: self.middleware.add_error_handler(code_or_exception, function) diff --git a/connexion/apps/flask.py b/connexion/apps/flask.py index a584c344b..98345ff8d 100644 --- a/connexion/apps/flask.py +++ b/connexion/apps/flask.py @@ -6,23 +6,22 @@ import typing as t import flask -import werkzeug.exceptions from a2wsgi import WSGIMiddleware from flask import Response as FlaskResponse -from flask import signals from starlette.types import Receive, Scope, Send from connexion.apps.abstract import AbstractApp from connexion.decorators import FlaskDecorator -from connexion.exceptions import InternalServerError, ProblemException, ResolverError +from connexion.exceptions import ResolverError from connexion.frameworks import flask as flask_utils from connexion.jsonifier import Jsonifier +from connexion.lifecycle import ConnexionRequest, ConnexionResponse from connexion.middleware.abstract import AbstractRoutingAPI, SpecMiddleware from connexion.middleware.lifespan import Lifespan from connexion.operations import AbstractOperation from connexion.options import SwaggerUIOptions -from connexion.problem import problem from connexion.resolver import Resolver +from connexion.types import MaybeAwaitable from connexion.uri_parsing import AbstractURIParser @@ -117,44 +116,20 @@ def add_url_rule( return self.blueprint.add_url_rule(rule, endpoint, view_func, **options) -class FlaskMiddlewareApp(SpecMiddleware): +class FlaskASGIApp(SpecMiddleware): def __init__(self, import_name, server_args: dict, **kwargs): self.app = flask.Flask(import_name, **server_args) self.app.json = flask_utils.FlaskJSONProvider(self.app) self.app.url_map.converters["float"] = flask_utils.NumberConverter self.app.url_map.converters["int"] = flask_utils.IntegerConverter - self.set_errors_handlers() + # Propagate Errors so we can handle them in the middleware + self.app.config["PROPAGATE_EXCEPTIONS"] = True + self.app.config["TRAP_BAD_REQUEST_ERRORS"] = True + self.app.config["TRAP_HTTP_EXCEPTIONS"] = True self.asgi_app = WSGIMiddleware(self.app.wsgi_app) - def set_errors_handlers(self): - for error_code in werkzeug.exceptions.default_exceptions: - self.app.register_error_handler(error_code, self.common_error_handler) - - self.app.register_error_handler(ProblemException, self.common_error_handler) - - def common_error_handler(self, exception: Exception) -> FlaskResponse: - """Default error handler.""" - if isinstance(exception, ProblemException): - response = exception.to_problem() - else: - if not isinstance(exception, werkzeug.exceptions.HTTPException): - exception = InternalServerError() - - response = problem( - title=exception.name, - detail=exception.description, - status=exception.code, - ) - - if response.status_code >= 500: - signals.got_request_exception.send(self.app, exception=exception) - - return flask.make_response( - (response.body, response.status_code, response.headers) - ) - def add_api(self, specification, *, name: str = None, **kwargs): api = FlaskApi(specification, **kwargs) @@ -177,7 +152,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: class FlaskApp(AbstractApp): """Connexion Application based on ConnexionMiddleware wrapping a Flask application.""" - _middleware_app: FlaskMiddlewareApp + _middleware_app: FlaskASGIApp def __init__( self, @@ -237,7 +212,7 @@ def __init__( :param security_map: A dictionary of security handlers to use. Defaults to :obj:`security.SECURITY_HANDLERS` """ - self._middleware_app = FlaskMiddlewareApp(import_name, server_args or {}) + self._middleware_app = FlaskASGIApp(import_name, server_args or {}) self.app = self._middleware_app.app super().__init__( import_name, @@ -266,6 +241,10 @@ def add_url_rule( ) def add_error_handler( - self, code_or_exception: t.Union[int, t.Type[Exception]], function: t.Callable + self, + code_or_exception: t.Union[int, t.Type[Exception]], + function: t.Callable[ + [ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse] + ], ) -> None: self.app.register_error_handler(code_or_exception, function) diff --git a/connexion/context.py b/connexion/context.py index b72fa079d..e21128c70 100644 --- a/connexion/context.py +++ b/connexion/context.py @@ -3,7 +3,7 @@ from starlette.types import Receive, Scope from werkzeug.local import LocalProxy -from connexion.lifecycle import ASGIRequest +from connexion.lifecycle import ConnexionRequest from connexion.operations import AbstractOperation UNBOUND_MESSAGE = ( @@ -25,5 +25,5 @@ scope = LocalProxy(_scope, unbound_message=UNBOUND_MESSAGE) request = LocalProxy( - lambda: ASGIRequest(scope, receive), unbound_message=UNBOUND_MESSAGE + lambda: ConnexionRequest(scope, receive), unbound_message=UNBOUND_MESSAGE ) diff --git a/connexion/decorators/parameter.py b/connexion/decorators/parameter.py index 3ae1e0fb0..ea12a69bc 100644 --- a/connexion/decorators/parameter.py +++ b/connexion/decorators/parameter.py @@ -16,7 +16,7 @@ from connexion.context import context, operation from connexion.frameworks.abstract import Framework from connexion.http_facts import FORM_CONTENT_TYPES -from connexion.lifecycle import ASGIRequest, WSGIRequest +from connexion.lifecycle import ConnexionRequest, WSGIRequest from connexion.operations import AbstractOperation, Swagger2Operation from connexion.utils import ( deep_merge, @@ -43,7 +43,7 @@ def __init__( def _maybe_get_body( self, - request: t.Union[WSGIRequest, ASGIRequest], + request: t.Union[WSGIRequest, ConnexionRequest], *, arguments: t.List[str], has_kwargs: bool, @@ -95,7 +95,7 @@ def __call__(self, function: t.Callable) -> t.Callable: arguments, has_kwargs = inspect_function_arguments(unwrapped_function) @functools.wraps(function) - async def wrapper(request: ASGIRequest) -> t.Any: + async def wrapper(request: ConnexionRequest) -> t.Any: request_body = self._maybe_get_body( request, arguments=arguments, has_kwargs=has_kwargs ) @@ -118,7 +118,7 @@ async def wrapper(request: ASGIRequest) -> t.Any: def prep_kwargs( - request: t.Union[WSGIRequest, ASGIRequest], + request: t.Union[WSGIRequest, ConnexionRequest], *, request_body: t.Any, files: t.Dict[str, t.Any], diff --git a/connexion/frameworks/starlette.py b/connexion/frameworks/starlette.py index 702c7e48f..097c10e54 100644 --- a/connexion/frameworks/starlette.py +++ b/connexion/frameworks/starlette.py @@ -8,7 +8,7 @@ from starlette.types import Receive, Scope from connexion.frameworks.abstract import Framework -from connexion.lifecycle import ASGIRequest +from connexion.lifecycle import ConnexionRequest from connexion.uri_parsing import AbstractURIParser @@ -48,8 +48,8 @@ def build_response( ) @staticmethod - def get_request(*, scope: Scope, receive: Receive, uri_parser: AbstractURIParser, **kwargs) -> ASGIRequest: # type: ignore - return ASGIRequest(scope, receive, uri_parser=uri_parser) + def get_request(*, scope: Scope, receive: Receive, uri_parser: AbstractURIParser, **kwargs) -> ConnexionRequest: # type: ignore + return ConnexionRequest(scope, receive, uri_parser=uri_parser) PATH_PARAMETER = re.compile(r"\{([^}]*)\}") diff --git a/connexion/lifecycle.py b/connexion/lifecycle.py index e002207e0..a388d79e2 100644 --- a/connexion/lifecycle.py +++ b/connexion/lifecycle.py @@ -130,7 +130,7 @@ def __getattr__(self, item): return getattr(self._werkzeug_request, item) -class ASGIRequest(_RequestInterface): +class ConnexionRequest(_RequestInterface): """ Implementation of the Connexion :code:`_RequestInterface` representing an ASGI request. @@ -142,7 +142,9 @@ class ASGIRequest(_RequestInterface): """ def __init__(self, *args, uri_parser=None, **kwargs): - self._starlette_request = StarletteRequest(*args, **kwargs) + # Might be set in `from_starlette_request` class method + if not hasattr(self, "_starlette_request"): + self._starlette_request = StarletteRequest(*args, **kwargs) self.uri_parser = uri_parser self._context = None @@ -152,6 +154,16 @@ def __init__(self, *args, uri_parser=None, **kwargs): self._form = None self._files = None + @classmethod + def from_starlette_request( + cls, request: StarletteRequest, uri_parser=None + ) -> "ConnexionRequest": + # Instantiate the class, and set the `_starlette_request` property before initializing. + self = cls.__new__(cls) + self._starlette_request = request + self.__init__(uri_parser=uri_parser) # type: ignore + return self + @property def context(self): if self._context is None: @@ -226,7 +238,8 @@ async def get_body(self): return await self.body() or None def __getattr__(self, item): - return getattr(self._starlette_request, item) + if self.__getattribute__("_starlette_request"): + return getattr(self._starlette_request, item) class ConnexionResponse: diff --git a/connexion/middleware/exceptions.py b/connexion/middleware/exceptions.py index cba002d7b..e76f4b93e 100644 --- a/connexion/middleware/exceptions.py +++ b/connexion/middleware/exceptions.py @@ -1,68 +1,114 @@ +import asyncio import logging +import typing as t +import werkzeug.exceptions +from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException from starlette.middleware.exceptions import ( ExceptionMiddleware as StarletteExceptionMiddleware, ) from starlette.requests import Request as StarletteRequest -from starlette.responses import Response +from starlette.responses import Response as StarletteResponse from starlette.types import ASGIApp, Receive, Scope, Send from connexion.exceptions import InternalServerError, ProblemException, problem +from connexion.lifecycle import ConnexionRequest, ConnexionResponse +from connexion.types import MaybeAwaitable logger = logging.getLogger(__name__) -class ExceptionMiddleware(StarletteExceptionMiddleware): - """Subclass of starlette ExceptionMiddleware to change handling of HTTP exceptions to - existing connexion behavior.""" +def connexion_wrapper( + handler: t.Callable[ + [ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse] + ] +) -> t.Callable[[StarletteRequest, Exception], t.Awaitable[StarletteResponse]]: + """Wrapper that translates Starlette requests to Connexion requests before passing + them to the error handler, and translates the returned Connexion responses to + Starlette responses.""" - def __init__(self, next_app: ASGIApp): - super().__init__(next_app) - self.add_exception_handler(ProblemException, self.problem_handler) - self.add_exception_handler(Exception, self.common_error_handler) + async def wrapper(request: StarletteRequest, exc: Exception) -> StarletteResponse: + request = ConnexionRequest.from_starlette_request(request) - @staticmethod - def problem_handler(_request: StarletteRequest, exc: ProblemException): - logger.error("%r", exc) + if asyncio.iscoroutinefunction(handler): + response = await handler(request, exc) # type: ignore + else: + response = await run_in_threadpool(handler, request, exc) - response = exc.to_problem() - - return Response( + return StarletteResponse( content=response.body, status_code=response.status_code, media_type=response.mimetype, headers=response.headers, ) - @staticmethod - def http_exception(_request: StarletteRequest, exc: HTTPException) -> Response: - logger.error("%r", exc) + return wrapper + - headers = exc.headers +class ExceptionMiddleware(StarletteExceptionMiddleware): + """Subclass of starlette ExceptionMiddleware to change handling of HTTP exceptions to + existing connexion behavior.""" - connexion_response = problem( - title=exc.detail, detail=exc.detail, status=exc.status_code, headers=headers + def __init__(self, next_app: ASGIApp): + super().__init__(next_app) + self.add_exception_handler(ProblemException, self.problem_handler) # type: ignore + self.add_exception_handler( + werkzeug.exceptions.HTTPException, self.flask_error_handler ) + self.add_exception_handler(Exception, self.common_error_handler) - return Response( - content=connexion_response.body, - status_code=connexion_response.status_code, - media_type=connexion_response.mimetype, - headers=connexion_response.headers, + def add_exception_handler( + self, + exc_class_or_status_code: t.Union[int, t.Type[Exception]], + handler: t.Callable[[ConnexionRequest, Exception], StarletteResponse], + ) -> None: + super().add_exception_handler( + exc_class_or_status_code, handler=connexion_wrapper(handler) ) @staticmethod - def common_error_handler(_request: StarletteRequest, exc: Exception) -> Response: - logger.error("%r", exc, exc_info=exc) + def problem_handler(_request: ConnexionRequest, exc: ProblemException): + """Default handler for Connexion ProblemExceptions""" + logger.error("%r", exc) + return exc.to_problem() - response = InternalServerError().to_problem() + @staticmethod + @connexion_wrapper + def http_exception( + _request: StarletteRequest, exc: HTTPException, **kwargs + ) -> StarletteResponse: + """Default handler for Starlette HTTPException""" + logger.error("%r", exc) + return problem( + title=exc.detail, + detail=exc.detail, + status=exc.status_code, + headers=exc.headers, + ) - return Response( - content=response.body, - status_code=response.status_code, - media_type=response.mimetype, - headers=response.headers, + @staticmethod + def common_error_handler( + _request: StarletteRequest, exc: Exception + ) -> ConnexionResponse: + """Default handler for any unhandled Exception""" + logger.error("%r", exc, exc_info=exc) + return InternalServerError().to_problem() + + def flask_error_handler( + self, request: StarletteRequest, exc: werkzeug.exceptions.HTTPException + ) -> ConnexionResponse: + """Default handler for Flask / werkzeug HTTPException""" + # If a handler is registered for the received status_code, call it instead. + # This is only done automatically for Starlette HTTPExceptions + if handler := self._status_handlers.get(exc.code): + starlette_exception = HTTPException(exc.code, detail=exc.description) + return handler(request, starlette_exception) + + return problem( + title=exc.name, + detail=exc.description, + status=exc.code, ) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: diff --git a/connexion/middleware/main.py b/connexion/middleware/main.py index ded10d925..075fac664 100644 --- a/connexion/middleware/main.py +++ b/connexion/middleware/main.py @@ -12,6 +12,7 @@ from connexion import utils from connexion.handlers import ResolverErrorHandler from connexion.jsonifier import Jsonifier +from connexion.lifecycle import ConnexionRequest, ConnexionResponse from connexion.middleware.abstract import SpecMiddleware from connexion.middleware.context import ContextMiddleware from connexion.middleware.exceptions import ExceptionMiddleware @@ -23,6 +24,7 @@ from connexion.middleware.swagger_ui import SwaggerUIMiddleware from connexion.options import SwaggerUIOptions from connexion.resolver import Resolver +from connexion.types import MaybeAwaitable from connexion.uri_parsing import AbstractURIParser from connexion.utils import inspect_function_arguments @@ -419,14 +421,18 @@ def add_api( self.apis.append(api) def add_error_handler( - self, code_or_exception: t.Union[int, t.Type[Exception]], function: t.Callable + self, + code_or_exception: t.Union[int, t.Type[Exception]], + function: t.Callable[ + [ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse] + ], ) -> None: """ Register a callable to handle application errors. :param code_or_exception: An exception class or the status code of HTTP exceptions to handle. - :param function: Callable that will handle exception. + :param function: Callable that will handle exception, may be async. """ if self.middleware_stack is not None: raise RuntimeError( diff --git a/connexion/middleware/security.py b/connexion/middleware/security.py index 6c4d9010f..7180c56e6 100644 --- a/connexion/middleware/security.py +++ b/connexion/middleware/security.py @@ -5,7 +5,7 @@ from starlette.types import ASGIApp, Receive, Scope, Send from connexion.exceptions import ProblemException -from connexion.lifecycle import ASGIRequest +from connexion.lifecycle import ConnexionRequest from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware from connexion.operations import AbstractOperation from connexion.security import SecurityHandlerFactory @@ -95,7 +95,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.next_app(scope, receive, send) return - request = ASGIRequest(scope) + request = ConnexionRequest(scope) await self.verification_fn(request) await self.next_app(scope, receive, send) diff --git a/connexion/security.py b/connexion/security.py index c9e273cd2..692d34d91 100644 --- a/connexion/security.py +++ b/connexion/security.py @@ -54,7 +54,7 @@ from connexion.decorators.parameter import inspect_function_arguments from connexion.exceptions import OAuthProblem, OAuthResponseProblem, OAuthScopeProblem -from connexion.lifecycle import ASGIRequest +from connexion.lifecycle import ConnexionRequest from connexion.utils import get_function_from_name logger = logging.getLogger(__name__) @@ -248,7 +248,7 @@ def get_fn(self, security_scheme, required_scopes): def _get_verify_func(self, api_key_info_func, loc, name): check_api_key_func = self.check_api_key(api_key_info_func) - def wrapper(request: ASGIRequest): + def wrapper(request: ConnexionRequest): if loc == "query": api_key = request.query_params.get(name) elif loc == "header": diff --git a/connexion/types.py b/connexion/types.py new file mode 100644 index 000000000..ee20ecfa3 --- /dev/null +++ b/connexion/types.py @@ -0,0 +1,4 @@ +import typing as t + +ReturnType = t.TypeVar("ReturnType") +MaybeAwaitable = t.Union[t.Awaitable[ReturnType], ReturnType] diff --git a/docs/context.rst b/docs/context.rst index 288fd56ea..de344674c 100644 --- a/docs/context.rst +++ b/docs/context.rst @@ -24,12 +24,13 @@ See below for an explanation of the different variables. request ------- -A ``Request`` object representing the incoming request. This is an instance of the ``ASGIRequest``. +A ``Request`` object representing the incoming request. This is an instance of the +``ConnexionRequest``. -.. dropdown:: View a detailed reference of the ``ASGIRequest`` class +.. dropdown:: View a detailed reference of the ``ConnexionRequest`` class :icon: eye - .. autoclass:: connexion.lifecycle.ASGIRequest + .. autoclass:: connexion.lifecycle.ConnexionRequest :noindex: :members: :undoc-members: diff --git a/docs/request.rst b/docs/request.rst index 18faaa3b1..308ad0df6 100644 --- a/docs/request.rst +++ b/docs/request.rst @@ -488,7 +488,7 @@ request. .. dropdown:: View a detailed reference of the ``connexion.request`` class :icon: eye - .. autoclass:: connexion.lifecycle.ASGIRequest + .. autoclass:: connexion.lifecycle.ConnexionRequest :members: :undoc-members: :inherited-members: diff --git a/docs/routing.rst b/docs/routing.rst index f01f97063..e12611700 100644 --- a/docs/routing.rst +++ b/docs/routing.rst @@ -59,7 +59,7 @@ operation: Note that :code:`HEAD` requests will be handled by the :code:`operationId` specified under the :code:`GET` operation in the specification. :code:`Connexion.request.method` can be used to -determine which request was made. See :class:`.ASGIRequest`. +determine which request was made. See :class:`.ConnexionRequest`. Automatic routing ----------------- diff --git a/docs/v3.rst b/docs/v3.rst index 8c8b4859b..578c62813 100644 --- a/docs/v3.rst +++ b/docs/v3.rst @@ -127,6 +127,8 @@ Smaller breaking changes has been added to work with Flask's ``MethodView`` specifically. * Built-in support for uWSGI has been removed. You can re-add this functionality using a custom middleware. * The request body is now passed through for ``GET``, ``HEAD``, ``DELETE``, ``CONNECT`` and ``OPTIONS`` methods as well. +* Error handlers registered on the on the underlying Flask app directly will be ignored. You + should register them on the Connexion app directly. Non-breaking changes diff --git a/tests/decorators/test_security.py b/tests/decorators/test_security.py index 41695072d..abb88eb01 100644 --- a/tests/decorators/test_security.py +++ b/tests/decorators/test_security.py @@ -10,7 +10,7 @@ OAuthResponseProblem, OAuthScopeProblem, ) -from connexion.lifecycle import ASGIRequest +from connexion.lifecycle import ConnexionRequest from connexion.security import ( NO_VALUE, ApiKeySecurityHandler, @@ -61,7 +61,7 @@ def somefunc(token): somefunc, security_handler.validate_scope, ["admin"] ) - request = ASGIRequest(scope={"type": "http", "headers": []}) + request = ConnexionRequest(scope={"type": "http", "headers": []}) assert wrapped_func(request) is NO_VALUE @@ -83,7 +83,7 @@ async def get_tokeninfo_response(*args, **kwargs): token_info_func, security_handler.validate_scope, ["admin"] ) - request = ASGIRequest( + request = ConnexionRequest( scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]} ) @@ -124,7 +124,7 @@ def somefunc(token): somefunc, security_handler.validate_scope, ["admin"] ) - request = ASGIRequest( + request = ConnexionRequest( scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]} ) @@ -143,7 +143,7 @@ def token_info(token): token_info, security_handler.validate_scope, ["admin"] ) - request = ASGIRequest( + request = ConnexionRequest( scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]} ) @@ -178,7 +178,7 @@ def somefunc(username, password, required_scopes=None): security_handler = BasicSecurityHandler() wrapped_func = security_handler._get_verify_func(somefunc) - request = ASGIRequest( + request = ConnexionRequest( scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]} ) @@ -194,7 +194,7 @@ def basic_info(username, password, required_scopes=None): security_handler = BasicSecurityHandler() wrapped_func = security_handler._get_verify_func(basic_info) - request = ASGIRequest( + request = ConnexionRequest( scope={"type": "http", "headers": [[b"authorization", b"Basic Zm9vOmJhcg=="]]} ) @@ -212,7 +212,7 @@ def apikey_info(apikey, required_scopes=None): apikey_info, "query", "auth" ) - request = ASGIRequest(scope={"type": "http", "query_string": b"auth=foobar"}) + request = ConnexionRequest(scope={"type": "http", "query_string": b"auth=foobar"}) assert await wrapped_func(request) is not None @@ -228,7 +228,9 @@ def apikey_info(apikey, required_scopes=None): apikey_info, "header", "X-Auth" ) - request = ASGIRequest(scope={"type": "http", "headers": [[b"x-auth", b"foobar"]]}) + request = ConnexionRequest( + scope={"type": "http", "headers": [[b"x-auth", b"foobar"]]} + ) assert await wrapped_func(request) is not None @@ -259,16 +261,20 @@ def apikey2_info(apikey, required_scopes=None): wrapped_func = security_handler_factory.verify_multiple_schemes(schemes) # Single key does not succeed - request = ASGIRequest(scope={"type": "http", "headers": [[b"x-auth-1", b"foobar"]]}) + request = ConnexionRequest( + scope={"type": "http", "headers": [[b"x-auth-1", b"foobar"]]} + ) assert await wrapped_func(request) is NO_VALUE - request = ASGIRequest(scope={"type": "http", "headers": [[b"x-auth-2", b"bar"]]}) + request = ConnexionRequest( + scope={"type": "http", "headers": [[b"x-auth-2", b"bar"]]} + ) assert await wrapped_func(request) is NO_VALUE # Supplying both keys does succeed - request = ASGIRequest( + request = ConnexionRequest( scope={ "type": "http", "headers": [[b"x-auth-1", b"foobar"], [b"x-auth-2", b"bar"]], @@ -287,7 +293,7 @@ async def test_verify_security_oauthproblem(): security_handler_factory = SecurityHandlerFactory() security_func = security_handler_factory.verify_security([]) - request = MagicMock(spec_set=ASGIRequest) + request = MagicMock(spec_set=ConnexionRequest) with pytest.raises(OAuthProblem) as exc_info: await security_func(request)