From 2bf18f6629ca8b0415fec8c1c248cca520dd4a4b Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Fri, 13 Jan 2023 22:01:08 +0100 Subject: [PATCH] Expose additional context (#1620) This PR contains 2 main changes: - Expose additional context. We now expose the scope, operation, connexion context, and receive channel as context aware globals. This makes them available to the decorators independent of the framework in between. The user will also be able to use these. I also implemented a `TestContext` class which can be used to populate the context during testing. It's minimal, but can be extended towards the future. - Rename the decorators to be framework specific. This is part of a bigger change for which I'll submit a follow up PR. I was working on this first when it became clear that the context would need to be extended, which is why this is already included. --- connexion/apis/flask_api.py | 10 +- connexion/apps/async_app.py | 10 +- connexion/context.py | 26 +++-- connexion/decorators/__init__.py | 2 +- connexion/decorators/main.py | 113 +++++++++++--------- connexion/decorators/parameter.py | 73 ++++++++----- connexion/decorators/response.py | 22 ++-- connexion/frameworks/flask.py | 12 --- connexion/frameworks/starlette.py | 12 --- connexion/lifecycle.py | 22 ++++ connexion/middleware/abstract.py | 11 +- connexion/middleware/context.py | 36 ++++++- connexion/middleware/request_validation.py | 2 - connexion/middleware/response_validation.py | 2 - connexion/middleware/routing.py | 8 +- connexion/middleware/security.py | 3 - connexion/problem.py | 4 +- connexion/testing.py | 72 +++++++++++++ tests/decorators/test_parameter.py | 68 ++++-------- 19 files changed, 302 insertions(+), 206 deletions(-) create mode 100644 connexion/testing.py diff --git a/connexion/apis/flask_api.py b/connexion/apis/flask_api.py index 69864d6dd..115daed3b 100644 --- a/connexion/apis/flask_api.py +++ b/connexion/apis/flask_api.py @@ -9,9 +9,8 @@ from flask import Response as FlaskResponse from connexion.apis.abstract import AbstractAPI -from connexion.decorators import SyncDecorator +from connexion.decorators import FlaskDecorator from connexion.frameworks import flask as flask_utils -from connexion.frameworks.flask import Flask as FlaskFramework from connexion.jsonifier import Jsonifier from connexion.operations import AbstractOperation from connexion.uri_parsing import AbstractURIParser @@ -91,12 +90,7 @@ def from_operation( @property def fn(self) -> t.Callable: - decorator = SyncDecorator( - self._operation, - uri_parser_cls=self.uri_parser_class, - framework=FlaskFramework, - parameter=True, - response=True, + decorator = FlaskDecorator( pythonic_params=self.pythonic_params, jsonifier=self.api.jsonifier, ) diff --git a/connexion/apps/async_app.py b/connexion/apps/async_app.py index cd1db9fa5..dd280e3e2 100644 --- a/connexion/apps/async_app.py +++ b/connexion/apps/async_app.py @@ -15,9 +15,8 @@ from connexion.apis.abstract import AbstractAPI from connexion.apps.abstract import AbstractApp -from connexion.decorators import AsyncDecorator +from connexion.decorators import StarletteDecorator from connexion.exceptions import MissingMiddleware, ProblemException -from connexion.frameworks.starlette import Starlette as StarletteFramework from connexion.middleware.main import ConnexionMiddleware from connexion.middleware.routing import ROUTING_CONTEXT from connexion.operations import AbstractOperation @@ -192,12 +191,7 @@ def from_operation( @property def fn(self) -> t.Callable: - decorator = AsyncDecorator( - self._operation, - uri_parser_cls=self._operation.uri_parser_class, - framework=StarletteFramework, - parameter=True, - response=True, + decorator = StarletteDecorator( pythonic_params=self.pythonic_params, jsonifier=self.api.jsonifier, ) diff --git a/connexion/context.py b/connexion/context.py index 519da947b..9f10aced0 100644 --- a/connexion/context.py +++ b/connexion/context.py @@ -1,12 +1,24 @@ from contextvars import ContextVar -from starlette.types import Scope +from starlette.types import Receive, Scope +from werkzeug.local import LocalProxy + +from connexion.operations import AbstractOperation + +UNBOUND_MESSAGE = ( + "Working outside of operation context. Make sure your app is wrapped in a " + "ContextMiddleware and you're processing a request while accessing the context." +) -_scope: ContextVar[Scope] = ContextVar("SCOPE") +_context: ContextVar[dict] = ContextVar("CONTEXT") +context = LocalProxy(_context, unbound_message=UNBOUND_MESSAGE) -def __getattr__(name): - if name == "scope": - return _scope.get() - if name == "context": - return _scope.get().get("extensions", {}).get("connexion_context", {}) +_operation: ContextVar[AbstractOperation] = ContextVar("OPERATION") +operation = LocalProxy(_operation, unbound_message=UNBOUND_MESSAGE) + +_receive: ContextVar[Receive] = ContextVar("RECEIVE") +receive = LocalProxy(_receive, unbound_message=UNBOUND_MESSAGE) + +_scope: ContextVar[Scope] = ContextVar("SCOPE") +scope = LocalProxy(_scope, unbound_message=UNBOUND_MESSAGE) diff --git a/connexion/decorators/__init__.py b/connexion/decorators/__init__.py index d6716ba55..187274535 100644 --- a/connexion/decorators/__init__.py +++ b/connexion/decorators/__init__.py @@ -1,4 +1,4 @@ """ This module defines decorators which Connexion uses to wrap user provided view functions. """ -from .main import AsyncDecorator, SyncDecorator # noqa +from .main import FlaskDecorator, StarletteDecorator # noqa diff --git a/connexion/decorators/main.py b/connexion/decorators/main.py index f2b02a917..3b565ff72 100644 --- a/connexion/decorators/main.py +++ b/connexion/decorators/main.py @@ -1,16 +1,17 @@ import abc import asyncio import functools +import json import typing as t from asgiref.sync import async_to_sync from starlette.concurrency import run_in_threadpool +from connexion.context import operation, receive, scope from connexion.decorators.parameter import ( AsyncParameterDecorator, BaseParameterDecorator, SyncParameterDecorator, - inspect_function_arguments, ) from connexion.decorators.response import ( AsyncResponseDecorator, @@ -18,39 +19,28 @@ SyncResponseDecorator, ) from connexion.frameworks.abstract import Framework -from connexion.operations import AbstractOperation +from connexion.frameworks.flask import Flask as FlaskFramework +from connexion.frameworks.starlette import Starlette as StarletteFramework from connexion.uri_parsing import AbstractURIParser class BaseDecorator: """Base class for connexion decorators.""" + framework: t.Type[Framework] + def __init__( self, - operation_spec: AbstractOperation, *, - uri_parser_cls: t.Type[AbstractURIParser], - framework: t.Type[Framework], - parameter: bool, - response: bool, pythonic_params: bool = False, - jsonifier, + uri_parser_class: AbstractURIParser = None, + jsonifier=json, ) -> None: - self.operation_spec = operation_spec - self.uri_parser = uri_parser_cls( - operation_spec.parameters, operation_spec.body_definition() - ) - self.framework = framework - self.produces = self.operation_spec.produces - self.parameter = parameter - self.response = response self.pythonic_params = pythonic_params + self.uri_parser_class = uri_parser_class self.jsonifier = jsonifier - if self.parameter: - self.arguments, self.has_kwargs = inspect_function_arguments( - operation_spec.function - ) + self.arguments, self.has_kwargs = None, None @property @abc.abstractmethod @@ -68,27 +58,25 @@ def _sync_async_decorator(self) -> t.Callable[[t.Callable], t.Callable]: """Decorator to translate between sync and async functions.""" raise NotImplementedError + @property + def uri_parser(self): + uri_parser_class = self.uri_parser_class or operation.uri_parser_class + return uri_parser_class(operation.parameters, operation.body_definition()) + def decorate(self, function: t.Callable) -> t.Callable: """Decorate a function with decorators based on the operation.""" function = self._sync_async_decorator(function) - if self.parameter: - parameter_decorator = self._parameter_decorator_cls( - self.operation_spec, - get_body_fn=self.framework.get_body, - arguments=self.arguments, - has_kwargs=self.has_kwargs, - pythonic_params=self.pythonic_params, - ) - function = parameter_decorator(function) + parameter_decorator = self._parameter_decorator_cls( + pythonic_params=self.pythonic_params, + ) + function = parameter_decorator(function) - if self.response: - response_decorator = self._response_decorator_cls( - self.operation_spec, - framework=self.framework, - jsonifier=self.jsonifier, - ) - function = response_decorator(function) + response_decorator = self._response_decorator_cls( + framework=self.framework, + jsonifier=self.jsonifier, + ) + function = response_decorator(function) return function @@ -97,7 +85,13 @@ def __call__(self, function: t.Callable) -> t.Callable: raise NotImplementedError -class SyncDecorator(BaseDecorator): +class FlaskDecorator(BaseDecorator): + """Decorator for usage with Flask. The parameter decorator works with a Flask request, + and provides Flask datastructures to the view function. The response decorator returns + a Flask response""" + + framework = FlaskFramework + @property def _parameter_decorator_cls(self) -> t.Type[SyncParameterDecorator]: return SyncParameterDecorator @@ -123,25 +117,33 @@ def wrapper(*args, **kwargs) -> t.Callable: def __call__(self, function: t.Callable) -> t.Callable: @functools.wraps(function) def wrapper(*args, **kwargs): - # TODO: move into parameter decorator? - connexion_request = self.framework.get_request( - *args, uri_parser=self.uri_parser, **kwargs - ) - + request = self.framework.get_request(uri_parser=self.uri_parser) decorated_function = self.decorate(function) - return decorated_function(connexion_request) + return decorated_function(request) return wrapper -class AsyncDecorator(BaseDecorator): +class ASGIDecorator(BaseDecorator): + """Decorator for usage with ASGI apps. The parameter decorator works with a Starlette request, + and provides Starlette datastructures to the view function. This works for any ASGI app, since + we get the request via the connexion context provided by ASGI middleware. + + This decorator does not parse responses, but passes them directly to the ASGI App.""" + + framework = StarletteFramework + @property def _parameter_decorator_cls(self) -> t.Type[AsyncParameterDecorator]: return AsyncParameterDecorator @property - def _response_decorator_cls(self) -> t.Type[AsyncResponseDecorator]: - return AsyncResponseDecorator + def _response_decorator_cls(self) -> t.Type[BaseResponseDecorator]: + class NoResponseDecorator(BaseResponseDecorator): + def __call__(self, function: t.Callable) -> t.Callable: + return lambda request: function(request) + + return NoResponseDecorator @property def _sync_async_decorator(self) -> t.Callable[[t.Callable], t.Callable]: @@ -160,15 +162,24 @@ async def wrapper(*args, **kwargs): def __call__(self, function: t.Callable) -> t.Callable: @functools.wraps(function) async def wrapper(*args, **kwargs): - # TODO: move into parameter decorator? - connexion_request = self.framework.get_request( - *args, uri_parser=self.uri_parser, **kwargs + request = self.framework.get_request( + uri_parser=self.uri_parser, scope=scope, receive=receive ) - decorated_function = self.decorate(function) - response = decorated_function(connexion_request) + response = decorated_function(request) while asyncio.iscoroutine(response): response = await response return response return wrapper + + +class StarletteDecorator(ASGIDecorator): + """Decorator for usage with Connexion or Starlette apps. The parameter decorator works with a + Starlette request, and provides Starlette datastructures to the view function. + + The response decorator returns Starlette responses.""" + + @property + def _response_decorator_cls(self) -> t.Type[AsyncResponseDecorator]: + return AsyncResponseDecorator diff --git a/connexion/decorators/parameter.py b/connexion/decorators/parameter.py index 468f5381b..84cc1fa4e 100644 --- a/connexion/decorators/parameter.py +++ b/connexion/decorators/parameter.py @@ -14,6 +14,9 @@ import inflection +from connexion.context import context, operation +from connexion.frameworks.flask import Flask as FlaskFramework +from connexion.frameworks.starlette import Starlette as StarletteFramework from connexion.http_facts import FORM_CONTENT_TYPES from connexion.lifecycle import ConnexionRequest, MiddlewareRequest from connexion.operations import AbstractOperation, Swagger2Operation @@ -27,30 +30,26 @@ class BaseParameterDecorator: def __init__( self, - operation: AbstractOperation, *, - get_body_fn: t.Callable, - arguments: t.List[str], - has_kwargs: bool, pythonic_params: bool = False, ) -> None: - self.operation = operation - self.get_body_fn = get_body_fn - self.arguments = arguments - self.has_kwargs = has_kwargs self.sanitize_fn = pythonic if pythonic_params else sanitized def _maybe_get_body( - self, request: t.Union[ConnexionRequest, MiddlewareRequest] + self, + request: t.Union[ConnexionRequest, MiddlewareRequest], + *, + arguments: t.List[str], + has_kwargs: bool, ) -> t.Any: - body_name = self.sanitize_fn(self.operation.body_name(request.content_type)) + body_name = self.sanitize_fn(operation.body_name(request.content_type)) # Pass form contents separately for Swagger2 for backward compatibility with # Connexion 2 Checking for body_name is not enough - if (body_name in self.arguments or self.has_kwargs) or ( + if (body_name in arguments or has_kwargs) or ( request.mimetype in FORM_CONTENT_TYPES - and isinstance(self.operation, Swagger2Operation) + and isinstance(operation, Swagger2Operation) ): - return self.get_body_fn(request) + return request.get_body() else: return None @@ -60,17 +59,24 @@ def __call__(self, function: t.Callable) -> t.Callable: class SyncParameterDecorator(BaseParameterDecorator): + + framework = FlaskFramework + def __call__(self, function: t.Callable) -> t.Callable: + unwrapped_function = unwrap_decorators(function) + arguments, has_kwargs = inspect_function_arguments(unwrapped_function) + @functools.wraps(function) - def wrapper(request: t.Union[ConnexionRequest, MiddlewareRequest]) -> t.Any: - request_body = self._maybe_get_body(request) + def wrapper(request: ConnexionRequest) -> t.Any: + request_body = self._maybe_get_body( + request, arguments=arguments, has_kwargs=has_kwargs + ) kwargs = prep_kwargs( request, - operation=self.operation, request_body=request_body, - arguments=self.arguments, - has_kwargs=self.has_kwargs, + arguments=arguments, + has_kwargs=has_kwargs, sanitize=self.sanitize_fn, ) @@ -80,22 +86,27 @@ def wrapper(request: t.Union[ConnexionRequest, MiddlewareRequest]) -> t.Any: class AsyncParameterDecorator(BaseParameterDecorator): + + framework = StarletteFramework + def __call__(self, function: t.Callable) -> t.Callable: + unwrapped_function = unwrap_decorators(function) + arguments, has_kwargs = inspect_function_arguments(unwrapped_function) + @functools.wraps(function) - async def wrapper( - request: t.Union[ConnexionRequest, MiddlewareRequest] - ) -> t.Any: - request_body = self._maybe_get_body(request) + async def wrapper(request: MiddlewareRequest) -> t.Any: + request_body = self._maybe_get_body( + request, arguments=arguments, has_kwargs=has_kwargs + ) while asyncio.iscoroutine(request_body): request_body = await request_body kwargs = prep_kwargs( request, - operation=self.operation, request_body=request_body, - arguments=self.arguments, - has_kwargs=self.has_kwargs, + arguments=arguments, + has_kwargs=has_kwargs, sanitize=self.sanitize_fn, ) @@ -107,7 +118,6 @@ async def wrapper( def prep_kwargs( request: t.Union[ConnexionRequest, MiddlewareRequest], *, - operation: AbstractOperation, request_body: t.Any, arguments: t.List[str], has_kwargs: bool, @@ -129,18 +139,25 @@ def prep_kwargs( kwargs = {sanitize(k): v for k, v in kwargs.items()} # add context info (e.g. from security decorator) - for key, value in request.context.items(): + for key, value in context.items(): if has_kwargs or key in arguments: kwargs[key] = value else: logger.debug("Context parameter '%s' not in function arguments", key) # attempt to provide the request context to the function if CONTEXT_NAME in arguments: - kwargs[CONTEXT_NAME] = request.context + kwargs[CONTEXT_NAME] = context return kwargs +def unwrap_decorators(function: t.Callable) -> t.Callable: + """Unwrap decorators to return the original function.""" + while hasattr(function, "__wrapped__"): + function = function.__wrapped__ # type: ignore + return function + + def inspect_function_arguments(function: t.Callable) -> t.Tuple[t.List[str], bool]: """ Returns the list of variables names of a function and if it diff --git a/connexion/decorators/response.py b/connexion/decorators/response.py index 082556fdf..1553cedfc 100644 --- a/connexion/decorators/response.py +++ b/connexion/decorators/response.py @@ -6,21 +6,18 @@ import typing as t from enum import Enum +from connexion.context import operation from connexion.datastructures import NoContent from connexion.exceptions import NonConformingResponseHeaders from connexion.frameworks.abstract import Framework from connexion.lifecycle import ConnexionResponse, MiddlewareResponse -from connexion.operations import AbstractOperation from connexion.utils import is_json_mimetype logger = logging.getLogger(__name__) class BaseResponseDecorator: - def __init__( - self, operation: AbstractOperation, *, framework: t.Type[Framework], jsonifier - ): - self.operation = operation + def __init__(self, *, framework: t.Type[Framework], jsonifier): self.framework = framework self.jsonifier = jsonifier @@ -39,7 +36,8 @@ def build_framework_response(self, handler_response): data, content_type=content_type, status_code=status_code, headers=headers ) - def _deduct_content_type(self, data: t.Any, headers: dict) -> str: + @staticmethod + def _deduct_content_type(data: t.Any, headers: dict) -> str: """Deduct the response content type from the returned data, headers and operation spec. :param data: Response data @@ -52,7 +50,7 @@ def _deduct_content_type(self, data: t.Any, headers: dict) -> str: content_type = headers.get("Content-Type") # TODO: don't default - produces = list(set(self.operation.produces)) + produces = list(set(operation.produces)) if data is not None and not produces: produces = ["application/json"] @@ -60,7 +58,7 @@ def _deduct_content_type(self, data: t.Any, headers: dict) -> str: if content_type not in produces: raise NonConformingResponseHeaders( f"Returned content type ({content_type}) is not defined in operation spec " - f"({self.operation.produces})." + f"({operation.produces})." ) else: if not produces: @@ -153,13 +151,13 @@ def _unpack_handler_response( class SyncResponseDecorator(BaseResponseDecorator): def __call__(self, function: t.Callable) -> t.Callable: @functools.wraps(function) - def wrapper(request): + def wrapper(*args, **kwargs): """ This method converts a handler response to a framework response. The handler response can be a ConnexionResponse, a framework response, a tuple or an object. """ - handler_response = function(request) + handler_response = function(*args, **kwargs) if self.framework.is_framework_response(handler_response): return handler_response elif isinstance(handler_response, (ConnexionResponse, MiddlewareResponse)): @@ -173,13 +171,13 @@ def wrapper(request): class AsyncResponseDecorator(BaseResponseDecorator): def __call__(self, function: t.Callable) -> t.Callable: @functools.wraps(function) - async def wrapper(request): + async def wrapper(*args, **kwargs): """ This method converts a handler response to a framework response. The handler response can be a ConnexionResponse, a framework response, a tuple or an object. """ - handler_response = await function(request) + handler_response = await function(*args, **kwargs) if self.framework.is_framework_response(handler_response): return handler_response elif isinstance(handler_response, (ConnexionResponse, MiddlewareResponse)): diff --git a/connexion/frameworks/flask.py b/connexion/frameworks/flask.py index 5cd5d1b77..d6cbc7b53 100644 --- a/connexion/frameworks/flask.py +++ b/connexion/frameworks/flask.py @@ -11,10 +11,8 @@ import werkzeug from connexion.frameworks.abstract import Framework -from connexion.http_facts import FORM_CONTENT_TYPES from connexion.lifecycle import ConnexionRequest from connexion.uri_parsing import AbstractURIParser -from connexion.utils import is_json_mimetype class Flask(Framework): @@ -58,16 +56,6 @@ def build_response( def get_request(*, uri_parser: AbstractURIParser, **kwargs) -> ConnexionRequest: # type: ignore return ConnexionRequest(flask.request, uri_parser=uri_parser) - @staticmethod - def get_body(request): - if is_json_mimetype(request.content_type): - return request.get_json(silent=True) - elif request.mimetype in FORM_CONTENT_TYPES: - return request.form - else: - # Return explicit None instead of empty bytestring so it is handled as null downstream - return request.get_data() or None - PATH_PARAMETER = re.compile(r"\{([^}]*)\}") diff --git a/connexion/frameworks/starlette.py b/connexion/frameworks/starlette.py index 89eeb9501..3121de6d1 100644 --- a/connexion/frameworks/starlette.py +++ b/connexion/frameworks/starlette.py @@ -8,9 +8,7 @@ from starlette.types import Receive, Scope from connexion.frameworks.abstract import Framework -from connexion.http_facts import FORM_CONTENT_TYPES from connexion.lifecycle import MiddlewareRequest, MiddlewareResponse -from connexion.utils import is_json_mimetype class Starlette(Framework): @@ -54,16 +52,6 @@ def build_response( def get_request(*, scope: Scope, receive: Receive, **kwargs) -> MiddlewareRequest: # type: ignore return MiddlewareRequest(scope, receive) - @staticmethod - async def get_body(request): - if is_json_mimetype(request.content_type): - return await request.json() - elif request.mimetype in FORM_CONTENT_TYPES: - return await request.form() - else: - # Return explicit None instead of empty bytestring so it is handled as null downstream - return await request.data() or None - PATH_PARAMETER = re.compile(r"\{([^}]*)\}") PATH_PARAMETER_CONVERTERS = {"integer": "int", "number": "float", "path": "path"} diff --git a/connexion/lifecycle.py b/connexion/lifecycle.py index f19b213f7..ebf120312 100644 --- a/connexion/lifecycle.py +++ b/connexion/lifecycle.py @@ -9,6 +9,9 @@ from starlette.requests import Request as StarletteRequest from starlette.responses import StreamingResponse as StarletteStreamingResponse +from connexion.http_facts import FORM_CONTENT_TYPES +from connexion.utils import is_json_mimetype + class ConnexionRequest: def __init__(self, flask_request: FlaskRequest, uri_parser=None): @@ -41,6 +44,16 @@ def form(self): form_data = self.uri_parser.resolve_form(form) return form_data + def get_body(self): + """Get body based on content type""" + if is_json_mimetype(self.content_type): + return self.get_json(silent=True) + elif self.mimetype in FORM_CONTENT_TYPES: + return self.form + else: + # Return explicit None instead of empty bytestring so it is handled as null downstream + return self.get_data() or None + def __getattr__(self, item): return getattr(self._flask_request, item) @@ -98,6 +111,15 @@ def files(self): # TODO: separate files? return {} + async def get_body(self): + if is_json_mimetype(self.content_type): + return await self.json() + elif self.mimetype in FORM_CONTENT_TYPES: + return await self.form() + else: + # Return explicit None instead of empty bytestring so it is handled as null downstream + return await self.data() or None + class MiddlewareResponse(StarletteStreamingResponse): """Wraps starlette StreamingResponse so it can easily be extended.""" diff --git a/connexion/middleware/abstract.py b/connexion/middleware/abstract.py index 4dc1ccddc..0247ba6a8 100644 --- a/connexion/middleware/abstract.py +++ b/connexion/middleware/abstract.py @@ -40,10 +40,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: class RoutedAPI(AbstractSpecAPI, t.Generic[OP]): - - operation_cls: t.Type[OP] - """The operation this middleware uses, which should implement the RoutingOperation protocol.""" - def __init__( self, specification: t.Union[pathlib.Path, str, dict], @@ -70,7 +66,12 @@ def add_paths(self) -> None: def add_operation(self, path: str, method: str) -> None: operation_spec_cls = self.specification.operation_cls operation = operation_spec_cls.from_spec( - self.specification, self, path, method, self.resolver + self.specification, + self, + path, + method, + self.resolver, + uri_parser_class=self.options.uri_parser_class, ) routed_operation = self.make_operation(operation) self.operations[operation.operation_id] = routed_operation diff --git a/connexion/middleware/context.py b/connexion/middleware/context.py index 978ac8cfe..71a1febf7 100644 --- a/connexion/middleware/context.py +++ b/connexion/middleware/context.py @@ -2,13 +2,39 @@ middleware stack, so it exposes the scope passed to the application""" from starlette.types import ASGIApp, Receive, Scope, Send -from connexion.context import _scope +from connexion.context import _context, _operation, _receive, _scope +from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware +from connexion.operations import AbstractOperation -class ContextMiddleware: - def __init__(self, app: ASGIApp) -> None: - self.app = app +class ContextOperation: + def __init__( + self, + next_app: ASGIApp, + *, + operation: AbstractOperation, + ) -> None: + self.next_app = next_app + self.operation = operation async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + _context.set(scope.get("extensions", {}).get("connexion_context", {})) + _operation.set(self.operation) + _receive.set(receive) _scope.set(scope) - await self.app(scope, receive, send) + await self.next_app(scope, receive, send) + + +class ContextAPI(RoutedAPI[ContextOperation]): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.add_paths() + + def make_operation(self, operation: AbstractOperation) -> ContextOperation: + return ContextOperation(self.next_app, operation=operation) + + +class ContextMiddleware(RoutedMiddleware[ContextAPI]): + """Middleware to expose operation specific context to application.""" + + api_cls = ContextAPI diff --git a/connexion/middleware/request_validation.py b/connexion/middleware/request_validation.py index be1741ac0..b15d44f88 100644 --- a/connexion/middleware/request_validation.py +++ b/connexion/middleware/request_validation.py @@ -120,8 +120,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send): class RequestValidationAPI(RoutedAPI[RequestValidationOperation]): """Validation API.""" - operation_cls = RequestValidationOperation - def __init__( self, *args, diff --git a/connexion/middleware/response_validation.py b/connexion/middleware/response_validation.py index 9e5a57efb..761858c89 100644 --- a/connexion/middleware/response_validation.py +++ b/connexion/middleware/response_validation.py @@ -125,8 +125,6 @@ async def wrapped_send(message: t.MutableMapping[str, t.Any]) -> None: class ResponseValidationAPI(RoutedAPI[ResponseValidationOperation]): """Validation API.""" - operation_cls = ResponseValidationOperation - def __init__( self, *args, diff --git a/connexion/middleware/routing.py b/connexion/middleware/routing.py index beedb02d9..fe0643eb2 100644 --- a/connexion/middleware/routing.py +++ b/connexion/middleware/routing.py @@ -68,12 +68,18 @@ def __init__( resolver=resolver, resolver_error_handler=resolver_error_handler, debug=debug, + **kwargs, ) def add_operation(self, path: str, method: str) -> None: operation_cls = self.specification.operation_cls operation = operation_cls.from_spec( - self.specification, self, path, method, self.resolver + self.specification, + self, + path, + method, + self.resolver, + uri_parser_class=self.options.uri_parser_class, ) routing_operation = RoutingOperation.from_operation( operation, next_app=self.next_app diff --git a/connexion/middleware/security.py b/connexion/middleware/security.py index 376f5286b..043aa6b2d 100644 --- a/connexion/middleware/security.py +++ b/connexion/middleware/security.py @@ -205,9 +205,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: class SecurityAPI(RoutedAPI[SecurityOperation]): - - operation_cls = SecurityOperation - def __init__(self, *args, auth_all_paths: bool = False, **kwargs): super().__init__(*args, **kwargs) diff --git a/connexion/problem.py b/connexion/problem.py index e749facd1..8d0488b50 100644 --- a/connexion/problem.py +++ b/connexion/problem.py @@ -4,8 +4,6 @@ to communicate distinct "problem types" to non-human consumers. """ -from .lifecycle import ConnexionResponse - def problem(status, title, detail, type=None, instance=None, headers=None, ext=None): """ @@ -33,6 +31,8 @@ def problem(status, title, detail, type=None, instance=None, headers=None, ext=N :return: error response :rtype: ConnexionResponse """ + from .lifecycle import ConnexionResponse # prevent circular import + if not type: type = "about:blank" diff --git a/connexion/testing.py b/connexion/testing.py new file mode 100644 index 000000000..20a4335bf --- /dev/null +++ b/connexion/testing.py @@ -0,0 +1,72 @@ +import contextvars +import typing as t +from unittest.mock import MagicMock + +from starlette.types import Receive, Scope + +from connexion.context import _context, _operation, _receive, _scope +from connexion.operations import AbstractOperation + + +class TestContext: + __test__ = False # Pytest + + def __init__( + self, + *, + context: dict = None, + operation: AbstractOperation = None, + receive: Receive = None, + scope: Scope = None, + ) -> None: + self.context = context if context is not None else self.build_context() + self.operation = operation if operation is not None else self.build_operation() + self.receive = receive if receive is not None else self.build_receive() + self.scope = scope if scope is not None else self.build_scope() + + self.tokens: t.Dict[str, contextvars.Token] = {} + + def __enter__(self) -> None: + self.tokens["context"] = _context.set(self.context) + self.tokens["operation"] = _operation.set(self.operation) + self.tokens["receive"] = _receive.set(self.receive) + self.tokens["scope"] = _scope.set(self.scope) + return + + def __exit__(self, type, value, traceback): + _context.reset(self.tokens["context"]) + _operation.reset(self.tokens["operation"]) + _receive.reset(self.tokens["receive"]) + _scope.reset(self.tokens["scope"]) + return False + + @staticmethod + def build_context() -> dict: + return {} + + @staticmethod + def build_operation() -> AbstractOperation: + return MagicMock(name="operation") + + @staticmethod + def build_receive() -> Receive: + async def receive() -> t.MutableMapping[str, t.Any]: + return { + "type": "http.request", + "body": b"", + } + + return receive + + @staticmethod + def build_scope(**kwargs) -> Scope: + scope = { + "type": "http", + "query_string": b"", + "headers": [(b"Content-Type", b"application/octet-stream")], + } + + for key, value in kwargs.items(): + scope[key] = value + + return scope diff --git a/tests/decorators/test_parameter.py b/tests/decorators/test_parameter.py index 778804f36..03324632a 100644 --- a/tests/decorators/test_parameter.py +++ b/tests/decorators/test_parameter.py @@ -3,14 +3,13 @@ from connexion.decorators.parameter import ( AsyncParameterDecorator, SyncParameterDecorator, - inspect_function_arguments, pythonic, ) +from connexion.testing import TestContext def test_sync_injection(): request = MagicMock(name="request") - request.query_params = {} request.path_params = {"p1": "123"} func = MagicMock() @@ -18,25 +17,18 @@ def test_sync_injection(): def handler(**kwargs): func(**kwargs) - def get_body_fn(_request): - return {} - operation = MagicMock(name="operation") operation.body_name = lambda _: "body" - arguments, has_kwargs = inspect_function_arguments(handler) - - parameter_decorator = SyncParameterDecorator( - operation, get_body_fn=get_body_fn, arguments=arguments, has_kwargs=has_kwargs - ) - decorated_handler = parameter_decorator(handler) - decorated_handler(request) + with TestContext(operation=operation): + parameter_decorator = SyncParameterDecorator() + decorated_handler = parameter_decorator(handler) + decorated_handler(request) func.assert_called_with(p1="123") async def test_async_injection(): request = MagicMock(name="request") - request.query_params = {} request.path_params = {"p1": "123"} func = MagicMock() @@ -44,74 +36,56 @@ async def test_async_injection(): async def handler(**kwargs): func(**kwargs) - def get_body_fn(_request): - return {} - operation = MagicMock(name="operation") operation.body_name = lambda _: "body" - arguments, has_kwargs = inspect_function_arguments(handler) - - parameter_decorator = AsyncParameterDecorator( - operation, get_body_fn=get_body_fn, arguments=arguments, has_kwargs=has_kwargs - ) - decorated_handler = parameter_decorator(handler) - await decorated_handler(request) + with TestContext(operation=operation): + parameter_decorator = AsyncParameterDecorator() + decorated_handler = parameter_decorator(handler) + await decorated_handler(request) func.assert_called_with(p1="123") def test_sync_injection_with_context(): request = MagicMock(name="request") - request.query_params = {} request.path_params = {"p1": "123"} - request.context = {} func = MagicMock() def handler(context_, **kwargs): func(context_, **kwargs) - def get_body_fn(_request): - return {} + context = {"test": "success"} operation = MagicMock(name="operation") operation.body_name = lambda _: "body" - arguments, has_kwargs = inspect_function_arguments(handler) - - parameter_decorator = SyncParameterDecorator( - operation, get_body_fn=get_body_fn, arguments=arguments, has_kwargs=has_kwargs - ) - decorated_handler = parameter_decorator(handler) - decorated_handler(request) - func.assert_called_with(request.context, p1="123") + with TestContext(context=context, operation=operation): + parameter_decorator = SyncParameterDecorator() + decorated_handler = parameter_decorator(handler) + decorated_handler(request) + func.assert_called_with(context, p1="123", test="success") async def test_async_injection_with_context(): request = MagicMock(name="request") - request.query_params = {} request.path_params = {"p1": "123"} - request.context = {} func = MagicMock() async def handler(context_, **kwargs): func(context_, **kwargs) - def get_body_fn(_request): - return {} + context = {"test": "success"} operation = MagicMock(name="operation") operation.body_name = lambda _: "body" - arguments, has_kwargs = inspect_function_arguments(handler) - - parameter_decorator = AsyncParameterDecorator( - operation, get_body_fn=get_body_fn, arguments=arguments, has_kwargs=has_kwargs - ) - decorated_handler = parameter_decorator(handler) - await decorated_handler(request) - func.assert_called_with(request.context, p1="123") + with TestContext(context=context, operation=operation): + parameter_decorator = AsyncParameterDecorator() + decorated_handler = parameter_decorator(handler) + await decorated_handler(request) + func.assert_called_with(context, p1="123", test="success") def test_pythonic_params():