Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

♻️ Relax type annotations, refactor custom handler inferences, naming #259

Merged
merged 1 commit into from
Apr 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,5 @@ venv.bak/

# IDE Settings
.idea/

.DS_Store
28 changes: 6 additions & 22 deletions mangum/adapter.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from itertools import chain
import logging
from itertools import chain
from contextlib import ExitStack
from typing import List, Optional, Type
import warnings

from mangum.protocols import HTTPCycle, LifespanCycle
from mangum.handlers import ALB, HTTPGateway, APIGateway, LambdaAtEdge
from mangum.exceptions import ConfigurationError
from mangum.types import (
ASGIApp,
ASGI,
LifespanMode,
LambdaConfig,
LambdaEvent,
Expand All @@ -31,7 +30,7 @@
class Mangum:
def __init__(
self,
app: ASGIApp,
app: ASGI,
lifespan: LifespanMode = "auto",
api_gateway_base_path: str = "/",
custom_handlers: Optional[List[Type[LambdaHandler]]] = None,
Expand All @@ -45,27 +44,12 @@ def __init__(
self.lifespan = lifespan
self.api_gateway_base_path = api_gateway_base_path or "/"
self.config = LambdaConfig(api_gateway_base_path=self.api_gateway_base_path)

if custom_handlers is not None:
warnings.warn( # pragma: no cover
"Support for custom event handlers is currently provisional and may "
"drastically change (or be removed entirely) in the future.",
FutureWarning,
)

self.custom_handlers = custom_handlers or []

def infer(self, event: LambdaEvent, context: LambdaContext) -> LambdaHandler:
for handler_cls in chain(
self.custom_handlers,
HANDLERS,
):
handler = handler_cls.infer(
event,
context,
self.config,
)
if handler:
for handler_cls in chain(self.custom_handlers, HANDLERS):
if handler_cls.infer(event, context, self.config):
handler = handler_cls(event, context, self.config)
break
else:
raise RuntimeError( # pragma: no cover
Expand Down
21 changes: 8 additions & 13 deletions mangum/handlers/alb.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
from itertools import islice
from typing import Dict, Generator, List, Optional, Tuple
from typing import Dict, Generator, List, Tuple
from urllib.parse import urlencode, unquote, unquote_plus


from mangum.handlers.utils import (
get_server_and_port,
handle_base64_response_body,
maybe_encode_body,
)
from mangum.types import (
HTTPResponse,
HTTPScope,
Response,
Scope,
LambdaConfig,
LambdaEvent,
LambdaContext,
LambdaHandler,
QueryParams,
)

Expand Down Expand Up @@ -86,11 +84,8 @@ class ALB:
@classmethod
def infer(
cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
) -> Optional[LambdaHandler]:
if "requestContext" in event and "elb" in event["requestContext"]:
return cls(event, context, config)

return None
) -> bool:
return "requestContext" in event and "elb" in event["requestContext"]

def __init__(
self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
Expand All @@ -107,7 +102,7 @@ def body(self) -> bytes:
)

@property
def scope(self) -> HTTPScope:
def scope(self) -> Scope:

headers = transform_headers(self.event)
list_headers = [list(x) for x in headers]
Expand All @@ -129,7 +124,7 @@ def scope(self) -> HTTPScope:
server = get_server_and_port(uq_headers)
client = (source_ip, 0)

scope: HTTPScope = {
scope: Scope = {
"type": "http",
"method": http_method,
"http_version": "1.1",
Expand All @@ -148,7 +143,7 @@ def scope(self) -> HTTPScope:

return scope

def __call__(self, response: HTTPResponse) -> dict:
def __call__(self, response: Response) -> dict:
multi_value_headers: Dict[str, List[str]] = {}
for key, value in response["headers"]:
lower_key = key.decode().lower()
Expand Down
29 changes: 11 additions & 18 deletions mangum/handlers/api_gateway.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Tuple
from urllib.parse import urlencode

from mangum.handlers.utils import (
Expand All @@ -9,14 +9,13 @@
strip_api_gateway_path,
)
from mangum.types import (
HTTPResponse,
Response,
LambdaConfig,
Headers,
LambdaEvent,
LambdaContext,
LambdaHandler,
QueryParams,
HTTPScope,
Scope,
)


Expand Down Expand Up @@ -68,11 +67,8 @@ class APIGateway:
@classmethod
def infer(
cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
) -> Optional[LambdaHandler]:
if "resource" in event and "requestContext" in event:
return cls(event, context, config)

return None
) -> bool:
return "resource" in event and "requestContext" in event

def __init__(
self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
Expand All @@ -89,7 +85,7 @@ def body(self) -> bytes:
)

@property
def scope(self) -> HTTPScope:
def scope(self) -> Scope:
headers = _handle_multi_value_headers_for_request(self.event)
return {
"type": "http",
Expand All @@ -114,7 +110,7 @@ def scope(self) -> HTTPScope:
"aws.context": self.context,
}

def __call__(self, response: HTTPResponse) -> dict:
def __call__(self, response: Response) -> dict:
finalized_headers, multi_value_headers = handle_multi_value_headers(
response["headers"]
)
Expand All @@ -135,11 +131,8 @@ class HTTPGateway:
@classmethod
def infer(
cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
) -> Optional[LambdaHandler]:
if "version" in event and "requestContext" in event:
return cls(event, context, config)

return None
) -> bool:
return "version" in event and "requestContext" in event

def __init__(
self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
Expand All @@ -156,7 +149,7 @@ def body(self) -> bytes:
)

@property
def scope(self) -> HTTPScope:
def scope(self) -> Scope:
request_context = self.event["requestContext"]
event_version = self.event["version"]

Expand Down Expand Up @@ -203,7 +196,7 @@ def scope(self) -> HTTPScope:
"aws.context": self.context,
}

def __call__(self, response: HTTPResponse) -> dict:
def __call__(self, response: Response) -> dict:
if self.scope["aws.event"]["version"] == "2.0":
finalized_headers, cookies = _combine_headers_v2(response["headers"])

Expand Down
26 changes: 9 additions & 17 deletions mangum/handlers/lambda_at_edge.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,27 @@
from typing import Dict, List, Optional
from typing import Dict, List

from mangum.handlers.utils import (
handle_base64_response_body,
handle_multi_value_headers,
maybe_encode_body,
)
from mangum.types import (
HTTPScope,
HTTPResponse,
LambdaConfig,
LambdaEvent,
LambdaContext,
LambdaHandler,
)
from mangum.types import Scope, Response, LambdaConfig, LambdaEvent, LambdaContext


class LambdaAtEdge:
@classmethod
def infer(
cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
) -> Optional[LambdaHandler]:
if (
) -> bool:
return (
"Records" in event
and len(event["Records"]) > 0
and "cf" in event["Records"][0]
):
return cls(event, context, config)
)

# FIXME: Since this is the last in the chain it doesn't get coverage by default,
# just ignoring it for now.
return None # pragma: nocover
# # just ignoring it for now.
# return None # pragma: nocover

def __init__(
self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
Expand All @@ -47,7 +39,7 @@ def body(self) -> bytes:
)

@property
def scope(self) -> HTTPScope:
def scope(self) -> Scope:
cf_request = self.event["Records"][0]["cf"]["request"]
scheme_header = cf_request["headers"].get("cloudfront-forwarded-proto", [{}])
scheme = scheme_header[0].get("value", "https")
Expand Down Expand Up @@ -84,7 +76,7 @@ def scope(self) -> HTTPScope:
"aws.context": self.context,
}

def __call__(self, response: HTTPResponse) -> dict:
def __call__(self, response: Response) -> dict:
multi_value_headers, _ = handle_multi_value_headers(response["headers"])
response_body, is_base64_encoded = handle_base64_response_body(
response["body"], multi_value_headers
Expand Down
24 changes: 8 additions & 16 deletions mangum/protocols/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,7 @@
import logging
from io import BytesIO


from mangum.types import (
ASGIApp,
ASGIReceiveEvent,
ASGISendEvent,
HTTPDisconnectEvent,
HTTPScope,
HTTPResponse,
)
from mangum.types import ASGI, Message, Scope, Response
from mangum.exceptions import UnexpectedMessage


Expand All @@ -35,12 +27,12 @@ class HTTPCycleState(enum.Enum):


class HTTPCycle:
def __init__(self, scope: HTTPScope, body: bytes) -> None:
def __init__(self, scope: Scope, body: bytes) -> None:
self.scope = scope
self.buffer = BytesIO()
self.state = HTTPCycleState.REQUEST
self.logger = logging.getLogger("mangum.http")
self.app_queue: asyncio.Queue[ASGIReceiveEvent] = asyncio.Queue()
self.app_queue: asyncio.Queue[Message] = asyncio.Queue()
self.app_queue.put_nowait(
{
"type": "http.request",
Expand All @@ -49,7 +41,7 @@ def __init__(self, scope: HTTPScope, body: bytes) -> None:
}
)

def __call__(self, app: ASGIApp) -> HTTPResponse:
def __call__(self, app: ASGI) -> Response:
asgi_instance = self.run(app)
loop = asyncio.get_event_loop()
asgi_task = loop.create_task(asgi_instance)
Expand All @@ -61,7 +53,7 @@ def __call__(self, app: ASGIApp) -> HTTPResponse:
"body": self.body,
}

async def run(self, app: ASGIApp) -> None:
async def run(self, app: ASGI) -> None:
try:
await app(self.scope, self.receive, self.send)
except BaseException:
Expand All @@ -86,10 +78,10 @@ async def run(self, app: ASGIApp) -> None:
self.body = b"Internal Server Error"
self.headers = [[b"content-type", b"text/plain; charset=utf-8"]]

async def receive(self) -> ASGIReceiveEvent:
async def receive(self) -> Message:
return await self.app_queue.get() # pragma: no cover

async def send(self, message: ASGISendEvent) -> None:
async def send(self, message: Message) -> None:
if (
self.state is HTTPCycleState.REQUEST
and message["type"] == "http.response.start"
Expand All @@ -110,7 +102,7 @@ async def send(self, message: ASGISendEvent) -> None:
self.buffer.close()

self.state = HTTPCycleState.COMPLETE
await self.app_queue.put(HTTPDisconnectEvent(type="http.disconnect"))
await self.app_queue.put({"type": "http.disconnect"})

self.logger.info(
"%s %s %s",
Expand Down
Loading