diff --git a/.gitignore b/.gitignore index 5ffa2980..2ce352d8 100644 --- a/.gitignore +++ b/.gitignore @@ -105,3 +105,5 @@ venv.bak/ # IDE Settings .idea/ + +.DS_Store diff --git a/mangum/adapter.py b/mangum/adapter.py index 31a0b468..7a937359 100644 --- a/mangum/adapter.py +++ b/mangum/adapter.py @@ -1,14 +1,13 @@ -from itertools import chain import logging +from itertools import chain from contextlib import ExitStack from typing import List, Optional, Type -import warnings from mangum.protocols import HTTPCycle, LifespanCycle from mangum.handlers import ALB, HTTPGateway, APIGateway, LambdaAtEdge from mangum.exceptions import ConfigurationError from mangum.types import ( - ASGIApp, + ASGI, LifespanMode, LambdaConfig, LambdaEvent, @@ -31,7 +30,7 @@ class Mangum: def __init__( self, - app: ASGIApp, + app: ASGI, lifespan: LifespanMode = "auto", api_gateway_base_path: str = "/", custom_handlers: Optional[List[Type[LambdaHandler]]] = None, @@ -45,27 +44,12 @@ def __init__( self.lifespan = lifespan self.api_gateway_base_path = api_gateway_base_path or "/" self.config = LambdaConfig(api_gateway_base_path=self.api_gateway_base_path) - - if custom_handlers is not None: - warnings.warn( # pragma: no cover - "Support for custom event handlers is currently provisional and may " - "drastically change (or be removed entirely) in the future.", - FutureWarning, - ) - self.custom_handlers = custom_handlers or [] def infer(self, event: LambdaEvent, context: LambdaContext) -> LambdaHandler: - for handler_cls in chain( - self.custom_handlers, - HANDLERS, - ): - handler = handler_cls.infer( - event, - context, - self.config, - ) - if handler: + for handler_cls in chain(self.custom_handlers, HANDLERS): + if handler_cls.infer(event, context, self.config): + handler = handler_cls(event, context, self.config) break else: raise RuntimeError( # pragma: no cover diff --git a/mangum/handlers/alb.py b/mangum/handlers/alb.py index 23dc62b9..02ef0a9f 100644 --- a/mangum/handlers/alb.py +++ b/mangum/handlers/alb.py @@ -1,20 +1,18 @@ from itertools import islice -from typing import Dict, Generator, List, Optional, Tuple +from typing import Dict, Generator, List, Tuple from urllib.parse import urlencode, unquote, unquote_plus - from mangum.handlers.utils import ( get_server_and_port, handle_base64_response_body, maybe_encode_body, ) from mangum.types import ( - HTTPResponse, - HTTPScope, + Response, + Scope, LambdaConfig, LambdaEvent, LambdaContext, - LambdaHandler, QueryParams, ) @@ -86,11 +84,8 @@ class ALB: @classmethod def infer( cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig - ) -> Optional[LambdaHandler]: - if "requestContext" in event and "elb" in event["requestContext"]: - return cls(event, context, config) - - return None + ) -> bool: + return "requestContext" in event and "elb" in event["requestContext"] def __init__( self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig @@ -107,7 +102,7 @@ def body(self) -> bytes: ) @property - def scope(self) -> HTTPScope: + def scope(self) -> Scope: headers = transform_headers(self.event) list_headers = [list(x) for x in headers] @@ -129,7 +124,7 @@ def scope(self) -> HTTPScope: server = get_server_and_port(uq_headers) client = (source_ip, 0) - scope: HTTPScope = { + scope: Scope = { "type": "http", "method": http_method, "http_version": "1.1", @@ -148,7 +143,7 @@ def scope(self) -> HTTPScope: return scope - def __call__(self, response: HTTPResponse) -> dict: + def __call__(self, response: Response) -> dict: multi_value_headers: Dict[str, List[str]] = {} for key, value in response["headers"]: lower_key = key.decode().lower() diff --git a/mangum/handlers/api_gateway.py b/mangum/handlers/api_gateway.py index 682642e2..9bca9e28 100644 --- a/mangum/handlers/api_gateway.py +++ b/mangum/handlers/api_gateway.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Tuple from urllib.parse import urlencode from mangum.handlers.utils import ( @@ -9,14 +9,13 @@ strip_api_gateway_path, ) from mangum.types import ( - HTTPResponse, + Response, LambdaConfig, Headers, LambdaEvent, LambdaContext, - LambdaHandler, QueryParams, - HTTPScope, + Scope, ) @@ -68,11 +67,8 @@ class APIGateway: @classmethod def infer( cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig - ) -> Optional[LambdaHandler]: - if "resource" in event and "requestContext" in event: - return cls(event, context, config) - - return None + ) -> bool: + return "resource" in event and "requestContext" in event def __init__( self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig @@ -89,7 +85,7 @@ def body(self) -> bytes: ) @property - def scope(self) -> HTTPScope: + def scope(self) -> Scope: headers = _handle_multi_value_headers_for_request(self.event) return { "type": "http", @@ -114,7 +110,7 @@ def scope(self) -> HTTPScope: "aws.context": self.context, } - def __call__(self, response: HTTPResponse) -> dict: + def __call__(self, response: Response) -> dict: finalized_headers, multi_value_headers = handle_multi_value_headers( response["headers"] ) @@ -135,11 +131,8 @@ class HTTPGateway: @classmethod def infer( cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig - ) -> Optional[LambdaHandler]: - if "version" in event and "requestContext" in event: - return cls(event, context, config) - - return None + ) -> bool: + return "version" in event and "requestContext" in event def __init__( self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig @@ -156,7 +149,7 @@ def body(self) -> bytes: ) @property - def scope(self) -> HTTPScope: + def scope(self) -> Scope: request_context = self.event["requestContext"] event_version = self.event["version"] @@ -203,7 +196,7 @@ def scope(self) -> HTTPScope: "aws.context": self.context, } - def __call__(self, response: HTTPResponse) -> dict: + def __call__(self, response: Response) -> dict: if self.scope["aws.event"]["version"] == "2.0": finalized_headers, cookies = _combine_headers_v2(response["headers"]) diff --git a/mangum/handlers/lambda_at_edge.py b/mangum/handlers/lambda_at_edge.py index 7a91147b..6d307f05 100644 --- a/mangum/handlers/lambda_at_edge.py +++ b/mangum/handlers/lambda_at_edge.py @@ -1,35 +1,27 @@ -from typing import Dict, List, Optional +from typing import Dict, List from mangum.handlers.utils import ( handle_base64_response_body, handle_multi_value_headers, maybe_encode_body, ) -from mangum.types import ( - HTTPScope, - HTTPResponse, - LambdaConfig, - LambdaEvent, - LambdaContext, - LambdaHandler, -) +from mangum.types import Scope, Response, LambdaConfig, LambdaEvent, LambdaContext class LambdaAtEdge: @classmethod def infer( cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig - ) -> Optional[LambdaHandler]: - if ( + ) -> bool: + return ( "Records" in event and len(event["Records"]) > 0 and "cf" in event["Records"][0] - ): - return cls(event, context, config) + ) # FIXME: Since this is the last in the chain it doesn't get coverage by default, - # just ignoring it for now. - return None # pragma: nocover + # # just ignoring it for now. + # return None # pragma: nocover def __init__( self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig @@ -47,7 +39,7 @@ def body(self) -> bytes: ) @property - def scope(self) -> HTTPScope: + def scope(self) -> Scope: cf_request = self.event["Records"][0]["cf"]["request"] scheme_header = cf_request["headers"].get("cloudfront-forwarded-proto", [{}]) scheme = scheme_header[0].get("value", "https") @@ -84,7 +76,7 @@ def scope(self) -> HTTPScope: "aws.context": self.context, } - def __call__(self, response: HTTPResponse) -> dict: + def __call__(self, response: Response) -> dict: multi_value_headers, _ = handle_multi_value_headers(response["headers"]) response_body, is_base64_encoded = handle_base64_response_body( response["body"], multi_value_headers diff --git a/mangum/protocols/http.py b/mangum/protocols/http.py index fd452f5a..b43b11b4 100644 --- a/mangum/protocols/http.py +++ b/mangum/protocols/http.py @@ -3,15 +3,7 @@ import logging from io import BytesIO - -from mangum.types import ( - ASGIApp, - ASGIReceiveEvent, - ASGISendEvent, - HTTPDisconnectEvent, - HTTPScope, - HTTPResponse, -) +from mangum.types import ASGI, Message, Scope, Response from mangum.exceptions import UnexpectedMessage @@ -35,12 +27,12 @@ class HTTPCycleState(enum.Enum): class HTTPCycle: - def __init__(self, scope: HTTPScope, body: bytes) -> None: + def __init__(self, scope: Scope, body: bytes) -> None: self.scope = scope self.buffer = BytesIO() self.state = HTTPCycleState.REQUEST self.logger = logging.getLogger("mangum.http") - self.app_queue: asyncio.Queue[ASGIReceiveEvent] = asyncio.Queue() + self.app_queue: asyncio.Queue[Message] = asyncio.Queue() self.app_queue.put_nowait( { "type": "http.request", @@ -49,7 +41,7 @@ def __init__(self, scope: HTTPScope, body: bytes) -> None: } ) - def __call__(self, app: ASGIApp) -> HTTPResponse: + def __call__(self, app: ASGI) -> Response: asgi_instance = self.run(app) loop = asyncio.get_event_loop() asgi_task = loop.create_task(asgi_instance) @@ -61,7 +53,7 @@ def __call__(self, app: ASGIApp) -> HTTPResponse: "body": self.body, } - async def run(self, app: ASGIApp) -> None: + async def run(self, app: ASGI) -> None: try: await app(self.scope, self.receive, self.send) except BaseException: @@ -86,10 +78,10 @@ async def run(self, app: ASGIApp) -> None: self.body = b"Internal Server Error" self.headers = [[b"content-type", b"text/plain; charset=utf-8"]] - async def receive(self) -> ASGIReceiveEvent: + async def receive(self) -> Message: return await self.app_queue.get() # pragma: no cover - async def send(self, message: ASGISendEvent) -> None: + async def send(self, message: Message) -> None: if ( self.state is HTTPCycleState.REQUEST and message["type"] == "http.response.start" @@ -110,7 +102,7 @@ async def send(self, message: ASGISendEvent) -> None: self.buffer.close() self.state = HTTPCycleState.COMPLETE - await self.app_queue.put(HTTPDisconnectEvent(type="http.disconnect")) + await self.app_queue.put({"type": "http.disconnect"}) self.logger.info( "%s %s %s", diff --git a/mangum/protocols/lifespan.py b/mangum/protocols/lifespan.py index f1f2a6e2..ca873924 100644 --- a/mangum/protocols/lifespan.py +++ b/mangum/protocols/lifespan.py @@ -4,15 +4,7 @@ from types import TracebackType from typing import Optional, Type - -from mangum.types import ( - ASGIApp, - LifespanMode, - ASGIReceiveEvent, - ASGISendEvent, - LifespanShutdownEvent, - LifespanStartupEvent, -) +from mangum.types import ASGI, LifespanMode, Message from mangum.exceptions import LifespanUnsupported, LifespanFailure, UnexpectedMessage @@ -60,13 +52,13 @@ class LifespanCycle: shutdown flow. """ - def __init__(self, app: ASGIApp, lifespan: LifespanMode) -> None: + def __init__(self, app: ASGI, lifespan: LifespanMode) -> None: self.app = app self.lifespan = lifespan self.state: LifespanCycleState = LifespanCycleState.CONNECTING self.exception: Optional[BaseException] = None self.loop = asyncio.get_event_loop() - self.app_queue: asyncio.Queue[ASGIReceiveEvent] = asyncio.Queue() + self.app_queue: asyncio.Queue[Message] = asyncio.Queue() self.startup_event: asyncio.Event = asyncio.Event() self.shutdown_event: asyncio.Event = asyncio.Event() self.logger = logging.getLogger("mangum.lifespan") @@ -103,7 +95,7 @@ async def run(self) -> None: self.startup_event.set() self.shutdown_event.set() - async def receive(self) -> ASGIReceiveEvent: + async def receive(self) -> Message: """Awaited by the application to receive ASGI `lifespan` events.""" if self.state is LifespanCycleState.CONNECTING: @@ -121,7 +113,7 @@ async def receive(self) -> ASGIReceiveEvent: return await self.app_queue.get() - async def send(self, message: ASGISendEvent) -> None: + async def send(self, message: Message) -> None: """Awaited by the application to send ASGI `lifespan` events.""" message_type = message["type"] self.logger.info( @@ -169,7 +161,7 @@ async def send(self, message: ASGISendEvent) -> None: async def startup(self) -> None: """Pushes the `lifespan` startup event to the queue and handles errors.""" self.logger.info("Waiting for application startup.") - await self.app_queue.put(LifespanStartupEvent(type="lifespan.startup")) + await self.app_queue.put({"type": "lifespan.startup"}) await self.startup_event.wait() if self.state is LifespanCycleState.FAILED: raise LifespanFailure(self.exception) @@ -182,7 +174,7 @@ async def startup(self) -> None: async def shutdown(self) -> None: """Pushes the `lifespan` shutdown event to the queue and handles errors.""" self.logger.info("Waiting for application shutdown.") - await self.app_queue.put(LifespanShutdownEvent(type="lifespan.shutdown")) + await self.app_queue.put({"type": "lifespan.shutdown"}) await self.shutdown_event.wait() if self.state is LifespanCycleState.FAILED: raise LifespanFailure(self.exception) diff --git a/mangum/types.py b/mangum/types.py index bcf4a0d0..20e80950 100644 --- a/mangum/types.py +++ b/mangum/types.py @@ -2,7 +2,6 @@ from typing import ( List, - Tuple, Dict, Any, Union, @@ -95,121 +94,21 @@ def get_remaining_time_in_millis(self) -> int: Headers: TypeAlias = List[List[bytes]] +Message: TypeAlias = MutableMapping[str, Any] +Scope: TypeAlias = MutableMapping[str, Any] +Receive: TypeAlias = Callable[[], Awaitable[Message]] +Send: TypeAlias = Callable[[Message], Awaitable[None]] -class HTTPRequestEvent(TypedDict): - type: Literal["http.request"] - body: bytes - more_body: bool - - -class HTTPDisconnectEvent(TypedDict): - type: Literal["http.disconnect"] - - -class HTTPResponseStartEvent(TypedDict): - type: Literal["http.response.start"] - status: int - headers: Headers - - -class HTTPResponseBodyEvent(TypedDict): - type: Literal["http.response.body"] - body: bytes - more_body: bool - - -class LifespanStartupEvent(TypedDict): - type: Literal["lifespan.startup"] - - -class LifespanStartupCompleteEvent(TypedDict): - type: Literal["lifespan.startup.complete"] - - -class LifespanStartupFailedEvent(TypedDict): - type: Literal["lifespan.startup.failed"] - message: str - - -class LifespanShutdownEvent(TypedDict): - type: Literal["lifespan.shutdown"] - - -class LifespanShutdownCompleteEvent(TypedDict): - type: Literal["lifespan.shutdown.complete"] - - -class LifespanShutdownFailedEvent(TypedDict): - type: Literal["lifespan.shutdown.failed"] - message: str - - -ASGIReceiveEvent: TypeAlias = Union[ - HTTPRequestEvent, - HTTPDisconnectEvent, - LifespanStartupEvent, - LifespanShutdownEvent, -] - -ASGISendEvent: TypeAlias = Union[ - HTTPResponseStartEvent, - HTTPResponseBodyEvent, - HTTPDisconnectEvent, - LifespanStartupCompleteEvent, - LifespanStartupFailedEvent, - LifespanShutdownCompleteEvent, - LifespanShutdownFailedEvent, -] - - -ASGIReceive: TypeAlias = Callable[[], Awaitable[ASGIReceiveEvent]] -ASGISend: TypeAlias = Callable[[ASGISendEvent], Awaitable[None]] - - -class ASGISpec(TypedDict): - spec_version: Literal["2.0"] - version: Literal["3.0"] - - -HTTPScope = TypedDict( - "HTTPScope", - { - "type": Literal["http"], - "asgi": ASGISpec, - "http_version": Literal["1.1"], - "scheme": str, - "method": str, - "path": str, - "raw_path": None, - "root_path": Literal[""], - "query_string": bytes, - "headers": Headers, - "client": Tuple[str, int], - "server": Tuple[str, int], - "aws.event": LambdaEvent, - "aws.context": LambdaContext, - }, -) - - -class LifespanScope(TypedDict): - type: Literal["lifespan"] - asgi: ASGISpec +class ASGI(Protocol): + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + ... # pragma: no cover LifespanMode: TypeAlias = Literal["auto", "on", "off"] -Scope: TypeAlias = Union[HTTPScope, LifespanScope] - - -class ASGIApp(Protocol): - async def __call__( - self, scope: Scope, receive: ASGIReceive, send: ASGISend - ) -> None: - ... # pragma: no cover -class HTTPResponse(TypedDict): +class Response(TypedDict): status: int headers: Headers body: bytes @@ -220,15 +119,13 @@ class LambdaConfig(TypedDict): class LambdaHandler(Protocol): + def __init__(self, *args: Any) -> None: + ... # pragma: no cover + @classmethod def infer( cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig - ) -> Optional[LambdaHandler]: - ... # pragma: no cover - - def __init__( - self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig - ) -> None: + ) -> bool: ... # pragma: no cover @property @@ -236,8 +133,8 @@ def body(self) -> bytes: ... # pragma: no cover @property - def scope(self) -> HTTPScope: + def scope(self) -> Scope: ... # pragma: no cover - def __call__(self, response: HTTPResponse) -> dict: + def __call__(self, response: Response) -> dict: ... # pragma: no cover diff --git a/tests/handlers/test_custom.py b/tests/handlers/test_custom.py index 24286ce3..c330bb60 100644 --- a/tests/handlers/test_custom.py +++ b/tests/handlers/test_custom.py @@ -1,12 +1,9 @@ -from typing import Optional - from mangum.types import ( - HTTPScope, + Scope, Headers, LambdaConfig, LambdaContext, LambdaEvent, - LambdaHandler, ) @@ -14,11 +11,8 @@ class CustomHandler: @classmethod def infer( cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig - ) -> Optional[LambdaHandler]: - if "my-custom-key" in event: - return cls(event, context, config) - - return None + ) -> bool: + return "my-custom-key" in event def __init__( self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig @@ -32,7 +26,7 @@ def body(self) -> bytes: return b"My request body" @property - def scope(self) -> HTTPScope: + def scope(self) -> Scope: headers = {} return { "type": "http",