Skip to content

Commit

Permalink
Support Debug extension (#1991)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex authored Feb 6, 2023
1 parent 3697c8d commit ca1711f
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 11 deletions.
29 changes: 26 additions & 3 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import anyio

from starlette.background import BackgroundTask
from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
from starlette.responses import ContentStream, Response, StreamingResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
Expand Down Expand Up @@ -75,6 +76,9 @@ async def coro() -> None:

try:
message = await recv_stream.receive()
info = message.get("info", None)
if message["type"] == "http.response.debug" and info is not None:
message = await recv_stream.receive()
except anyio.EndOfStream:
if app_exc is not None:
raise app_exc
Expand All @@ -93,8 +97,8 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
if app_exc is not None:
raise app_exc

response = StreamingResponse(
status_code=message["status"], content=body_stream()
response = _StreamingResponse(
status_code=message["status"], content=body_stream(), info=info
)
response.raw_headers = message["headers"]
return response
Expand All @@ -109,3 +113,22 @@ async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
raise NotImplementedError() # pragma: no cover


class _StreamingResponse(StreamingResponse):
def __init__(
self,
content: ContentStream,
status_code: int = 200,
headers: typing.Optional[typing.Mapping[str, str]] = None,
media_type: typing.Optional[str] = None,
background: typing.Optional[BackgroundTask] = None,
info: typing.Optional[typing.Mapping[str, typing.Any]] = None,
) -> None:
self._info = info
super().__init__(content, status_code, headers, media_type, background)

async def stream_response(self, send: Send) -> None:
if self._info:
await send({"type": "http.response.debug", "info": self._info})
return await super().stream_response(send)
10 changes: 6 additions & 4 deletions starlette/templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ def __init__(
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
request = self.context.get("request", {})
extensions = request.get("extensions", {})
if "http.response.template" in extensions:
if "http.response.debug" in extensions:
await send(
{
"type": "http.response.template",
"template": self.template,
"context": self.context,
"type": "http.response.debug",
"info": {
"template": self.template,
"context": self.context,
},
}
)
await super().__call__(scope, receive, send)
Expand Down
8 changes: 4 additions & 4 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
"headers": headers,
"client": ["testclient", 50000],
"server": [host, port],
"extensions": {"http.response.template": {}},
"extensions": {"http.response.debug": {}},
}

request_complete = False
Expand Down Expand Up @@ -324,9 +324,9 @@ async def send(message: Message) -> None:
if not more_body:
raw_kwargs["stream"].seek(0)
response_complete.set()
elif message["type"] == "http.response.template":
template = message["template"]
context = message["context"]
elif message["type"] == "http.response.debug":
template = message["info"]["template"]
context = message["info"]["context"]

try:
with self.portal_factory() as portal:
Expand Down
28 changes: 28 additions & 0 deletions tests/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import pytest

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.routing import Route
from starlette.templating import Jinja2Templates

Expand Down Expand Up @@ -60,3 +62,29 @@ def hello_world_processor(request):
assert response.text == "<html>Hello World</html>"
assert response.template.name == "index.html"
assert set(response.context.keys()) == {"request", "username"}


def test_template_with_middleware(tmpdir, test_client_factory):
path = os.path.join(tmpdir, "index.html")
with open(path, "w") as file:
file.write("<html>Hello, <a href='{{ url_for('homepage') }}'>world</a></html>")

async def homepage(request):
return templates.TemplateResponse("index.html", {"request": request})

class CustomMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
return await call_next(request)

app = Starlette(
debug=True,
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(CustomMiddleware)],
)
templates = Jinja2Templates(directory=str(tmpdir))

client = test_client_factory(app)
response = client.get("/")
assert response.text == "<html>Hello, <a href='http://testserver/'>world</a></html>"
assert response.template.name == "index.html"
assert set(response.context.keys()) == {"request"}

0 comments on commit ca1711f

Please sign in to comment.