Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Centralize error handling in ExceptionMiddleware #1754

Merged
merged 4 commits into from
Oct 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions connexion/apps/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
from starlette.types import ASGIApp, Receive, Scope, Send

from connexion.jsonifier import Jsonifier
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
from connexion.middleware import ConnexionMiddleware, MiddlewarePosition, SpecMiddleware
from connexion.middleware.lifespan import Lifespan
from connexion.options import SwaggerUIOptions
from connexion.resolver import Resolver
from connexion.types import MaybeAwaitable
from connexion.uri_parsing import AbstractURIParser


Expand Down Expand Up @@ -250,14 +252,18 @@ def decorator(func: t.Callable) -> t.Callable:

@abc.abstractmethod
def add_error_handler(
self, code_or_exception: t.Union[int, t.Type[Exception]], function: t.Callable
self,
code_or_exception: t.Union[int, t.Type[Exception]],
function: t.Callable[
[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
],
) -> None:
"""
Register a callable to handle application errors.

:param code_or_exception: An exception class or the status code of HTTP exceptions to
handle.
:param function: Callable that will handle exception.
:param function: Callable that will handle exception, may be async.
"""

def test_client(self, **kwargs):
Expand Down
12 changes: 9 additions & 3 deletions connexion/apps/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
from connexion.apps.abstract import AbstractApp
from connexion.decorators import StarletteDecorator
from connexion.jsonifier import Jsonifier
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
from connexion.middleware.lifespan import Lifespan
from connexion.operations import AbstractOperation
from connexion.options import SwaggerUIOptions
from connexion.resolver import Resolver
from connexion.types import MaybeAwaitable
from connexion.uri_parsing import AbstractURIParser

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -88,7 +90,7 @@ def make_operation(self, operation: AbstractOperation) -> AsyncOperation:
)


class AsyncMiddlewareApp(RoutedMiddleware[AsyncApi]):
class AsyncASGIApp(RoutedMiddleware[AsyncApi]):

api_cls = AsyncApi

Expand Down Expand Up @@ -176,7 +178,7 @@ def __init__(
:param security_map: A dictionary of security handlers to use. Defaults to
:obj:`security.SECURITY_HANDLERS`
"""
self._middleware_app: AsyncMiddlewareApp = AsyncMiddlewareApp()
self._middleware_app: AsyncASGIApp = AsyncASGIApp()

super().__init__(
import_name,
Expand Down Expand Up @@ -205,6 +207,10 @@ def add_url_rule(
)

def add_error_handler(
self, code_or_exception: t.Union[int, t.Type[Exception]], function: t.Callable
self,
code_or_exception: t.Union[int, t.Type[Exception]],
function: t.Callable[
[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
],
) -> None:
self.middleware.add_error_handler(code_or_exception, function)
51 changes: 15 additions & 36 deletions connexion/apps/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,22 @@
import typing as t

import flask
import werkzeug.exceptions
from a2wsgi import WSGIMiddleware
from flask import Response as FlaskResponse
from flask import signals
from starlette.types import Receive, Scope, Send

from connexion.apps.abstract import AbstractApp
from connexion.decorators import FlaskDecorator
from connexion.exceptions import InternalServerError, ProblemException, ResolverError
from connexion.exceptions import ResolverError
from connexion.frameworks import flask as flask_utils
from connexion.jsonifier import Jsonifier
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
from connexion.middleware.abstract import AbstractRoutingAPI, SpecMiddleware
from connexion.middleware.lifespan import Lifespan
from connexion.operations import AbstractOperation
from connexion.options import SwaggerUIOptions
from connexion.problem import problem
from connexion.resolver import Resolver
from connexion.types import MaybeAwaitable
from connexion.uri_parsing import AbstractURIParser


Expand Down Expand Up @@ -117,44 +116,20 @@ def add_url_rule(
return self.blueprint.add_url_rule(rule, endpoint, view_func, **options)


class FlaskMiddlewareApp(SpecMiddleware):
class FlaskASGIApp(SpecMiddleware):
def __init__(self, import_name, server_args: dict, **kwargs):
self.app = flask.Flask(import_name, **server_args)
self.app.json = flask_utils.FlaskJSONProvider(self.app)
self.app.url_map.converters["float"] = flask_utils.NumberConverter
self.app.url_map.converters["int"] = flask_utils.IntegerConverter

self.set_errors_handlers()
# Propagate Errors so we can handle them in the middleware
self.app.config["PROPAGATE_EXCEPTIONS"] = True
self.app.config["TRAP_BAD_REQUEST_ERRORS"] = True
self.app.config["TRAP_HTTP_EXCEPTIONS"] = True

self.asgi_app = WSGIMiddleware(self.app.wsgi_app)

def set_errors_handlers(self):
for error_code in werkzeug.exceptions.default_exceptions:
self.app.register_error_handler(error_code, self.common_error_handler)

self.app.register_error_handler(ProblemException, self.common_error_handler)

def common_error_handler(self, exception: Exception) -> FlaskResponse:
"""Default error handler."""
if isinstance(exception, ProblemException):
response = exception.to_problem()
else:
if not isinstance(exception, werkzeug.exceptions.HTTPException):
exception = InternalServerError()

response = problem(
title=exception.name,
detail=exception.description,
status=exception.code,
)

if response.status_code >= 500:
signals.got_request_exception.send(self.app, exception=exception)

return flask.make_response(
(response.body, response.status_code, response.headers)
)

def add_api(self, specification, *, name: str = None, **kwargs):
api = FlaskApi(specification, **kwargs)

Expand All @@ -177,7 +152,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
class FlaskApp(AbstractApp):
"""Connexion Application based on ConnexionMiddleware wrapping a Flask application."""

_middleware_app: FlaskMiddlewareApp
_middleware_app: FlaskASGIApp

def __init__(
self,
Expand Down Expand Up @@ -237,7 +212,7 @@ def __init__(
:param security_map: A dictionary of security handlers to use. Defaults to
:obj:`security.SECURITY_HANDLERS`
"""
self._middleware_app = FlaskMiddlewareApp(import_name, server_args or {})
self._middleware_app = FlaskASGIApp(import_name, server_args or {})
self.app = self._middleware_app.app
super().__init__(
import_name,
Expand Down Expand Up @@ -266,6 +241,10 @@ def add_url_rule(
)

def add_error_handler(
self, code_or_exception: t.Union[int, t.Type[Exception]], function: t.Callable
self,
code_or_exception: t.Union[int, t.Type[Exception]],
function: t.Callable[
[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
],
) -> None:
self.app.register_error_handler(code_or_exception, function)
4 changes: 2 additions & 2 deletions connexion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from starlette.types import Receive, Scope
from werkzeug.local import LocalProxy

from connexion.lifecycle import ASGIRequest
from connexion.lifecycle import ConnexionRequest
from connexion.operations import AbstractOperation

UNBOUND_MESSAGE = (
Expand All @@ -25,5 +25,5 @@
scope = LocalProxy(_scope, unbound_message=UNBOUND_MESSAGE)

request = LocalProxy(
lambda: ASGIRequest(scope, receive), unbound_message=UNBOUND_MESSAGE
lambda: ConnexionRequest(scope, receive), unbound_message=UNBOUND_MESSAGE
)
8 changes: 4 additions & 4 deletions connexion/decorators/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from connexion.context import context, operation
from connexion.frameworks.abstract import Framework
from connexion.http_facts import FORM_CONTENT_TYPES
from connexion.lifecycle import ASGIRequest, WSGIRequest
from connexion.lifecycle import ConnexionRequest, WSGIRequest
from connexion.operations import AbstractOperation, Swagger2Operation
from connexion.utils import (
deep_merge,
Expand All @@ -43,7 +43,7 @@ def __init__(

def _maybe_get_body(
self,
request: t.Union[WSGIRequest, ASGIRequest],
request: t.Union[WSGIRequest, ConnexionRequest],
*,
arguments: t.List[str],
has_kwargs: bool,
Expand Down Expand Up @@ -95,7 +95,7 @@ def __call__(self, function: t.Callable) -> t.Callable:
arguments, has_kwargs = inspect_function_arguments(unwrapped_function)

@functools.wraps(function)
async def wrapper(request: ASGIRequest) -> t.Any:
async def wrapper(request: ConnexionRequest) -> t.Any:
request_body = self._maybe_get_body(
request, arguments=arguments, has_kwargs=has_kwargs
)
Expand All @@ -118,7 +118,7 @@ async def wrapper(request: ASGIRequest) -> t.Any:


def prep_kwargs(
request: t.Union[WSGIRequest, ASGIRequest],
request: t.Union[WSGIRequest, ConnexionRequest],
*,
request_body: t.Any,
files: t.Dict[str, t.Any],
Expand Down
6 changes: 3 additions & 3 deletions connexion/frameworks/starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from starlette.types import Receive, Scope

from connexion.frameworks.abstract import Framework
from connexion.lifecycle import ASGIRequest
from connexion.lifecycle import ConnexionRequest
from connexion.uri_parsing import AbstractURIParser


Expand Down Expand Up @@ -48,8 +48,8 @@ def build_response(
)

@staticmethod
def get_request(*, scope: Scope, receive: Receive, uri_parser: AbstractURIParser, **kwargs) -> ASGIRequest: # type: ignore
return ASGIRequest(scope, receive, uri_parser=uri_parser)
def get_request(*, scope: Scope, receive: Receive, uri_parser: AbstractURIParser, **kwargs) -> ConnexionRequest: # type: ignore
return ConnexionRequest(scope, receive, uri_parser=uri_parser)


PATH_PARAMETER = re.compile(r"\{([^}]*)\}")
Expand Down
19 changes: 16 additions & 3 deletions connexion/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __getattr__(self, item):
return getattr(self._werkzeug_request, item)


class ASGIRequest(_RequestInterface):
class ConnexionRequest(_RequestInterface):
"""
Implementation of the Connexion :code:`_RequestInterface` representing an ASGI request.

Expand All @@ -142,7 +142,9 @@ class ASGIRequest(_RequestInterface):
"""

def __init__(self, *args, uri_parser=None, **kwargs):
self._starlette_request = StarletteRequest(*args, **kwargs)
# Might be set in `from_starlette_request` class method
if not hasattr(self, "_starlette_request"):
self._starlette_request = StarletteRequest(*args, **kwargs)
self.uri_parser = uri_parser

self._context = None
Expand All @@ -152,6 +154,16 @@ def __init__(self, *args, uri_parser=None, **kwargs):
self._form = None
self._files = None

@classmethod
def from_starlette_request(
cls, request: StarletteRequest, uri_parser=None
) -> "ConnexionRequest":
# Instantiate the class, and set the `_starlette_request` property before initializing.
self = cls.__new__(cls)
self._starlette_request = request
self.__init__(uri_parser=uri_parser) # type: ignore
return self

@property
def context(self):
if self._context is None:
Expand Down Expand Up @@ -226,7 +238,8 @@ async def get_body(self):
return await self.body() or None

def __getattr__(self, item):
return getattr(self._starlette_request, item)
if self.__getattribute__("_starlette_request"):
return getattr(self._starlette_request, item)


class ConnexionResponse:
Expand Down
Loading
Loading