Skip to content

Commit

Permalink
♻️ Relax type annotations, refactor custom handler inference method, …
Browse files Browse the repository at this point in the history
…renaming things. (#259)
  • Loading branch information
jordaneremieff authored Apr 23, 2022
1 parent f0b58cb commit b85cd4a
Show file tree
Hide file tree
Showing 9 changed files with 69 additions and 228 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,5 @@ venv.bak/

# IDE Settings
.idea/

.DS_Store
28 changes: 6 additions & 22 deletions mangum/adapter.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from itertools import chain
import logging
from itertools import chain
from contextlib import ExitStack
from typing import List, Optional, Type
import warnings

from mangum.protocols import HTTPCycle, LifespanCycle
from mangum.handlers import ALB, HTTPGateway, APIGateway, LambdaAtEdge
from mangum.exceptions import ConfigurationError
from mangum.types import (
ASGIApp,
ASGI,
LifespanMode,
LambdaConfig,
LambdaEvent,
Expand All @@ -31,7 +30,7 @@
class Mangum:
def __init__(
self,
app: ASGIApp,
app: ASGI,
lifespan: LifespanMode = "auto",
api_gateway_base_path: str = "/",
custom_handlers: Optional[List[Type[LambdaHandler]]] = None,
Expand All @@ -45,27 +44,12 @@ def __init__(
self.lifespan = lifespan
self.api_gateway_base_path = api_gateway_base_path or "/"
self.config = LambdaConfig(api_gateway_base_path=self.api_gateway_base_path)

if custom_handlers is not None:
warnings.warn( # pragma: no cover
"Support for custom event handlers is currently provisional and may "
"drastically change (or be removed entirely) in the future.",
FutureWarning,
)

self.custom_handlers = custom_handlers or []

def infer(self, event: LambdaEvent, context: LambdaContext) -> LambdaHandler:
for handler_cls in chain(
self.custom_handlers,
HANDLERS,
):
handler = handler_cls.infer(
event,
context,
self.config,
)
if handler:
for handler_cls in chain(self.custom_handlers, HANDLERS):
if handler_cls.infer(event, context, self.config):
handler = handler_cls(event, context, self.config)
break
else:
raise RuntimeError( # pragma: no cover
Expand Down
21 changes: 8 additions & 13 deletions mangum/handlers/alb.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
from itertools import islice
from typing import Dict, Generator, List, Optional, Tuple
from typing import Dict, Generator, List, Tuple
from urllib.parse import urlencode, unquote, unquote_plus


from mangum.handlers.utils import (
get_server_and_port,
handle_base64_response_body,
maybe_encode_body,
)
from mangum.types import (
HTTPResponse,
HTTPScope,
Response,
Scope,
LambdaConfig,
LambdaEvent,
LambdaContext,
LambdaHandler,
QueryParams,
)

Expand Down Expand Up @@ -86,11 +84,8 @@ class ALB:
@classmethod
def infer(
cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
) -> Optional[LambdaHandler]:
if "requestContext" in event and "elb" in event["requestContext"]:
return cls(event, context, config)

return None
) -> bool:
return "requestContext" in event and "elb" in event["requestContext"]

def __init__(
self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
Expand All @@ -107,7 +102,7 @@ def body(self) -> bytes:
)

@property
def scope(self) -> HTTPScope:
def scope(self) -> Scope:

headers = transform_headers(self.event)
list_headers = [list(x) for x in headers]
Expand All @@ -129,7 +124,7 @@ def scope(self) -> HTTPScope:
server = get_server_and_port(uq_headers)
client = (source_ip, 0)

scope: HTTPScope = {
scope: Scope = {
"type": "http",
"method": http_method,
"http_version": "1.1",
Expand All @@ -148,7 +143,7 @@ def scope(self) -> HTTPScope:

return scope

def __call__(self, response: HTTPResponse) -> dict:
def __call__(self, response: Response) -> dict:
multi_value_headers: Dict[str, List[str]] = {}
for key, value in response["headers"]:
lower_key = key.decode().lower()
Expand Down
29 changes: 11 additions & 18 deletions mangum/handlers/api_gateway.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Tuple
from urllib.parse import urlencode

from mangum.handlers.utils import (
Expand All @@ -9,14 +9,13 @@
strip_api_gateway_path,
)
from mangum.types import (
HTTPResponse,
Response,
LambdaConfig,
Headers,
LambdaEvent,
LambdaContext,
LambdaHandler,
QueryParams,
HTTPScope,
Scope,
)


Expand Down Expand Up @@ -68,11 +67,8 @@ class APIGateway:
@classmethod
def infer(
cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
) -> Optional[LambdaHandler]:
if "resource" in event and "requestContext" in event:
return cls(event, context, config)

return None
) -> bool:
return "resource" in event and "requestContext" in event

def __init__(
self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
Expand All @@ -89,7 +85,7 @@ def body(self) -> bytes:
)

@property
def scope(self) -> HTTPScope:
def scope(self) -> Scope:
headers = _handle_multi_value_headers_for_request(self.event)
return {
"type": "http",
Expand All @@ -114,7 +110,7 @@ def scope(self) -> HTTPScope:
"aws.context": self.context,
}

def __call__(self, response: HTTPResponse) -> dict:
def __call__(self, response: Response) -> dict:
finalized_headers, multi_value_headers = handle_multi_value_headers(
response["headers"]
)
Expand All @@ -135,11 +131,8 @@ class HTTPGateway:
@classmethod
def infer(
cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
) -> Optional[LambdaHandler]:
if "version" in event and "requestContext" in event:
return cls(event, context, config)

return None
) -> bool:
return "version" in event and "requestContext" in event

def __init__(
self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
Expand All @@ -156,7 +149,7 @@ def body(self) -> bytes:
)

@property
def scope(self) -> HTTPScope:
def scope(self) -> Scope:
request_context = self.event["requestContext"]
event_version = self.event["version"]

Expand Down Expand Up @@ -203,7 +196,7 @@ def scope(self) -> HTTPScope:
"aws.context": self.context,
}

def __call__(self, response: HTTPResponse) -> dict:
def __call__(self, response: Response) -> dict:
if self.scope["aws.event"]["version"] == "2.0":
finalized_headers, cookies = _combine_headers_v2(response["headers"])

Expand Down
26 changes: 9 additions & 17 deletions mangum/handlers/lambda_at_edge.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,27 @@
from typing import Dict, List, Optional
from typing import Dict, List

from mangum.handlers.utils import (
handle_base64_response_body,
handle_multi_value_headers,
maybe_encode_body,
)
from mangum.types import (
HTTPScope,
HTTPResponse,
LambdaConfig,
LambdaEvent,
LambdaContext,
LambdaHandler,
)
from mangum.types import Scope, Response, LambdaConfig, LambdaEvent, LambdaContext


class LambdaAtEdge:
@classmethod
def infer(
cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
) -> Optional[LambdaHandler]:
if (
) -> bool:
return (
"Records" in event
and len(event["Records"]) > 0
and "cf" in event["Records"][0]
):
return cls(event, context, config)
)

# FIXME: Since this is the last in the chain it doesn't get coverage by default,
# just ignoring it for now.
return None # pragma: nocover
# # just ignoring it for now.
# return None # pragma: nocover

def __init__(
self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
Expand All @@ -47,7 +39,7 @@ def body(self) -> bytes:
)

@property
def scope(self) -> HTTPScope:
def scope(self) -> Scope:
cf_request = self.event["Records"][0]["cf"]["request"]
scheme_header = cf_request["headers"].get("cloudfront-forwarded-proto", [{}])
scheme = scheme_header[0].get("value", "https")
Expand Down Expand Up @@ -84,7 +76,7 @@ def scope(self) -> HTTPScope:
"aws.context": self.context,
}

def __call__(self, response: HTTPResponse) -> dict:
def __call__(self, response: Response) -> dict:
multi_value_headers, _ = handle_multi_value_headers(response["headers"])
response_body, is_base64_encoded = handle_base64_response_body(
response["body"], multi_value_headers
Expand Down
24 changes: 8 additions & 16 deletions mangum/protocols/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,7 @@
import logging
from io import BytesIO


from mangum.types import (
ASGIApp,
ASGIReceiveEvent,
ASGISendEvent,
HTTPDisconnectEvent,
HTTPScope,
HTTPResponse,
)
from mangum.types import ASGI, Message, Scope, Response
from mangum.exceptions import UnexpectedMessage


Expand All @@ -35,12 +27,12 @@ class HTTPCycleState(enum.Enum):


class HTTPCycle:
def __init__(self, scope: HTTPScope, body: bytes) -> None:
def __init__(self, scope: Scope, body: bytes) -> None:
self.scope = scope
self.buffer = BytesIO()
self.state = HTTPCycleState.REQUEST
self.logger = logging.getLogger("mangum.http")
self.app_queue: asyncio.Queue[ASGIReceiveEvent] = asyncio.Queue()
self.app_queue: asyncio.Queue[Message] = asyncio.Queue()
self.app_queue.put_nowait(
{
"type": "http.request",
Expand All @@ -49,7 +41,7 @@ def __init__(self, scope: HTTPScope, body: bytes) -> None:
}
)

def __call__(self, app: ASGIApp) -> HTTPResponse:
def __call__(self, app: ASGI) -> Response:
asgi_instance = self.run(app)
loop = asyncio.get_event_loop()
asgi_task = loop.create_task(asgi_instance)
Expand All @@ -61,7 +53,7 @@ def __call__(self, app: ASGIApp) -> HTTPResponse:
"body": self.body,
}

async def run(self, app: ASGIApp) -> None:
async def run(self, app: ASGI) -> None:
try:
await app(self.scope, self.receive, self.send)
except BaseException:
Expand All @@ -86,10 +78,10 @@ async def run(self, app: ASGIApp) -> None:
self.body = b"Internal Server Error"
self.headers = [[b"content-type", b"text/plain; charset=utf-8"]]

async def receive(self) -> ASGIReceiveEvent:
async def receive(self) -> Message:
return await self.app_queue.get() # pragma: no cover

async def send(self, message: ASGISendEvent) -> None:
async def send(self, message: Message) -> None:
if (
self.state is HTTPCycleState.REQUEST
and message["type"] == "http.response.start"
Expand All @@ -110,7 +102,7 @@ async def send(self, message: ASGISendEvent) -> None:
self.buffer.close()

self.state = HTTPCycleState.COMPLETE
await self.app_queue.put(HTTPDisconnectEvent(type="http.disconnect"))
await self.app_queue.put({"type": "http.disconnect"})

self.logger.info(
"%s %s %s",
Expand Down
Loading

0 comments on commit b85cd4a

Please sign in to comment.