Skip to content

Commit

Permalink
Add support for lifespan state (#337)
Browse files Browse the repository at this point in the history
* Add support for lifespan state

* Add lifespan state to http connection

* Add test

* Revert import changes

* Fix imports

* Fix imports

* Improve a bit the test

* Use typing_extensions.Literal

---------

Co-authored-by: Muhammad Furqan Habibi <furqan.habibi@hennge.com>
  • Loading branch information
Kludex and FurqanHabibi authored Sep 26, 2024
1 parent 5a7121d commit 631b930
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 3 deletions.
4 changes: 3 additions & 1 deletion mangum/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,14 @@ def infer(self, event: LambdaEvent, context: LambdaContext) -> LambdaHandler:

def __call__(self, event: LambdaEvent, context: LambdaContext) -> dict[str, Any]:
handler = self.infer(event, context)
scope = handler.scope
with ExitStack() as stack:
if self.lifespan in ("auto", "on"):
lifespan_cycle = LifespanCycle(self.app, self.lifespan)
stack.enter_context(lifespan_cycle)
scope.update({"state": lifespan_cycle.lifespan_state.copy()})

http_cycle = HTTPCycle(handler.scope, handler.body)
http_cycle = HTTPCycle(scope, handler.body)
http_response = http_cycle(self.app)

return handler(http_response)
Expand Down
6 changes: 4 additions & 2 deletions mangum/protocols/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import enum
import logging
from types import TracebackType
from typing import Any

from mangum.exceptions import LifespanFailure, LifespanUnsupported, UnexpectedMessage
from mangum.types import ASGI, LifespanMode, Message
Expand Down Expand Up @@ -63,6 +64,7 @@ def __init__(self, app: ASGI, lifespan: LifespanMode) -> None:
self.startup_event: asyncio.Event = asyncio.Event()
self.shutdown_event: asyncio.Event = asyncio.Event()
self.logger = logging.getLogger("mangum.lifespan")
self.lifespan_state: dict[str, Any] = {}

def __enter__(self) -> None:
"""Runs the event loop for application startup."""
Expand All @@ -82,7 +84,7 @@ async def run(self) -> None:
"""Calls the application with the `lifespan` connection scope."""
try:
await self.app(
{"type": "lifespan", "asgi": {"spec_version": "2.0", "version": "3.0"}},
{"type": "lifespan", "asgi": {"spec_version": "2.0", "version": "3.0"}, "state": self.lifespan_state},
self.receive,
self.send,
)
Expand All @@ -101,7 +103,7 @@ async def receive(self) -> Message:
if self.state is LifespanCycleState.CONNECTING:
# Connection established. The next event returned by the queue will be
# `lifespan.startup` to inform the application that the connection is
# ready to receive lfiespan messages.
# ready to receive lifespan messages.
self.state = LifespanCycleState.STARTUP

elif self.state is LifespanCycleState.STARTUP:
Expand Down
50 changes: 50 additions & 0 deletions tests/test_lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from quart import Quart
from starlette.applications import Starlette
from starlette.responses import PlainTextResponse
from typing_extensions import Literal

from mangum import Mangum
from mangum.exceptions import LifespanFailure
from mangum.types import Receive, Scope, Send


@pytest.mark.parametrize(
Expand Down Expand Up @@ -209,6 +211,54 @@ async def app(scope, receive, send):
handler(mock_aws_api_gateway_event, {})


@pytest.mark.parametrize(
"mock_aws_api_gateway_event,lifespan",
[(["GET", None, None], "auto"), (["GET", None, None], "on")],
indirect=["mock_aws_api_gateway_event"],
)
def test_lifespan_state(mock_aws_api_gateway_event, lifespan: Literal["on", "auto"]) -> None:
startup_complete = False
shutdown_complete = False

async def app(scope: Scope, receive: Receive, send: Send):
nonlocal startup_complete, shutdown_complete

if scope["type"] == "lifespan":
while True:
message = await receive()
if message["type"] == "lifespan.startup":
scope["state"].update({"test_key": b"Hello, world!"})
await send({"type": "lifespan.startup.complete"})
startup_complete = True
elif message["type"] == "lifespan.shutdown":
await send({"type": "lifespan.shutdown.complete"})
shutdown_complete = True
return

if scope["type"] == "http":
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [[b"content-type", b"text/plain; charset=utf-8"]],
}
)
await send({"type": "http.response.body", "body": scope["state"]["test_key"]})

handler = Mangum(app, lifespan=lifespan)
response = handler(mock_aws_api_gateway_event, {})

assert startup_complete
assert shutdown_complete
assert response == {
"statusCode": 200,
"isBase64Encoded": False,
"headers": {"content-type": "text/plain; charset=utf-8"},
"multiValueHeaders": {},
"body": "Hello, world!",
}


@pytest.mark.parametrize("mock_aws_api_gateway_event", [["GET", None, None]], indirect=True)
def test_starlette_lifespan(mock_aws_api_gateway_event) -> None:
startup_complete = False
Expand Down

0 comments on commit 631b930

Please sign in to comment.