Skip to content

Commit

Permalink
Merge branch 'master' into add-url-for-query-params
Browse files Browse the repository at this point in the history
  • Loading branch information
aminalaee authored Jan 27, 2022
2 parents f650488 + 2b54f42 commit c28d9e9
Show file tree
Hide file tree
Showing 13 changed files with 87 additions and 20 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/publish.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
---
name: Publish

on:
Expand All @@ -11,6 +10,9 @@ jobs:
name: "Publish release"
runs-on: "ubuntu-latest"

environment:
name: deploy

steps:
- uses: "actions/checkout@v2"
- uses: "actions/setup-python@v2"
Expand Down
22 changes: 18 additions & 4 deletions docs/exceptions.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@ how you return responses when errors or handled exceptions occur.

```python
from starlette.applications import Starlette
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import HTMLResponse


HTML_404_PAGE = ...
HTML_500_PAGE = ...


async def not_found(request, exc):
async def not_found(request: Request, exc: HTTPException):
return HTMLResponse(content=HTML_404_PAGE, status_code=exc.status_code)

async def server_error(request, exc):
async def server_error(request: Request, exc: HTTPException):
return HTMLResponse(content=HTML_500_PAGE, status_code=exc.status_code)


Expand All @@ -40,14 +42,26 @@ In particular you might want to override how the built-in `HTTPException` class
is handled. For example, to use JSON style responses:

```python
async def http_exception(request, exc):
async def http_exception(request: Request, exc: HTTPException):
return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)

exception_handlers = {
HTTPException: http_exception
}
```

The `HTTPException` is also equipped with the `headers` argument. Which allows the propagation
of the headers to the response class:

```python
async def http_exception(request: Request, exc: HTTPException):
return JSONResponse(
{"detail": exc.detail},
status_code=exc.status_code,
headers=exc.headers
)
```

## Errors and handled exceptions

It is important to differentiate between handled exceptions and errors.
Expand Down Expand Up @@ -76,7 +90,7 @@ The `HTTPException` class provides a base class that you can use for any
handled exceptions. The `ExceptionMiddleware` implementation defaults to
returning plain-text HTTP responses for any `HTTPException`.

* `HTTPException(status_code, detail=None)`
* `HTTPException(status_code, detail=None, headers=None)`

You should only raise `HTTPException` inside routing or endpoints. Middleware
classes should instead just return appropriate responses directly.
2 changes: 1 addition & 1 deletion docs/testclient.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ May raise `starlette.websockets.WebSocketDisconnect` if the application does not

* `.receive_text()` - Wait for incoming text sent by the application and return it.
* `.receive_bytes()` - Wait for incoming bytestring sent by the application and return it.
* `.receive_json(mode="text")` - Wait for incoming json data sent by the application and return it. Use `mode="binary"` to send JSON over binary data frames.
* `.receive_json(mode="text")` - Wait for incoming json data sent by the application and return it. Use `mode="binary"` to receive JSON over binary data frames.

May raise `starlette.websockets.WebSocketDisconnect`.

Expand Down
10 changes: 8 additions & 2 deletions starlette/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
self.scope = scope
self.receive = receive
self.send = send
self._allowed_methods = [
method
for method in ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS")
if getattr(self, method.lower(), None) is not None
]

def __await__(self) -> typing.Generator:
return self.dispatch().__await__()
Expand All @@ -43,9 +48,10 @@ async def method_not_allowed(self, request: Request) -> Response:
# If we're running inside a starlette application then raise an
# exception, so that the configurable exception handler can deal with
# returning the response. For plain ASGI apps, just return the response.
headers = {"Allow": ", ".join(self._allowed_methods)}
if "app" in self.scope:
raise HTTPException(status_code=405)
return PlainTextResponse("Method Not Allowed", status_code=405)
raise HTTPException(status_code=405, headers=headers)
return PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)


class WebSocketEndpoint:
Expand Down
11 changes: 8 additions & 3 deletions starlette/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@


class HTTPException(Exception):
def __init__(self, status_code: int, detail: str = None) -> None:
def __init__(
self, status_code: int, detail: str = None, headers: dict = None
) -> None:
if detail is None:
detail = http.HTTPStatus(status_code).phrase
self.status_code = status_code
self.detail = detail
self.headers = headers

def __repr__(self) -> str:
class_name = self.__class__.__name__
Expand Down Expand Up @@ -99,5 +102,7 @@ async def sender(message: Message) -> None:

def http_exception(self, request: Request, exc: HTTPException) -> Response:
if exc.status_code in {204, 304}:
return Response(b"", status_code=exc.status_code)
return PlainTextResponse(exc.detail, status_code=exc.status_code)
return Response(status_code=exc.status_code, headers=exc.headers)
return PlainTextResponse(
exc.detail, status_code=exc.status_code, headers=exc.headers
)
2 changes: 1 addition & 1 deletion starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def query_params(self) -> QueryParams:
return self._query_params

@property
def path_params(self) -> dict:
def path_params(self) -> typing.Dict[str, typing.Any]:
return self.scope.get("path_params", {})

@property
Expand Down
18 changes: 14 additions & 4 deletions starlette/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
self,
content: typing.Any = None,
status_code: int = 200,
headers: dict = None,
headers: typing.Mapping[str, str] = None,
media_type: str = None,
background: BackgroundTask = None,
) -> None:
Expand Down Expand Up @@ -174,6 +174,16 @@ class PlainTextResponse(Response):
class JSONResponse(Response):
media_type = "application/json"

def __init__(
self,
content: typing.Any,
status_code: int = 200,
headers: dict = None,
media_type: str = None,
background: BackgroundTask = None,
) -> None:
super().__init__(content, status_code, headers, media_type, background)

def render(self, content: typing.Any) -> bytes:
return json.dumps(
content,
Expand All @@ -189,7 +199,7 @@ def __init__(
self,
url: typing.Union[str, URL],
status_code: int = 307,
headers: dict = None,
headers: typing.Mapping[str, str] = None,
background: BackgroundTask = None,
) -> None:
super().__init__(
Expand All @@ -203,7 +213,7 @@ def __init__(
self,
content: typing.Any,
status_code: int = 200,
headers: dict = None,
headers: typing.Mapping[str, str] = None,
media_type: str = None,
background: BackgroundTask = None,
) -> None:
Expand Down Expand Up @@ -258,7 +268,7 @@ def __init__(
self,
path: typing.Union[str, "os.PathLike[str]"],
status_code: int = 200,
headers: dict = None,
headers: typing.Mapping[str, str] = None,
media_type: str = None,
background: BackgroundTask = None,
filename: str = None,
Expand Down
7 changes: 5 additions & 2 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,10 +250,13 @@ def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:

async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.methods and scope["method"] not in self.methods:
headers = {"Allow": ", ".join(self.methods)}
if "app" in scope:
raise HTTPException(status_code=405)
raise HTTPException(status_code=405, headers=headers)
else:
response = PlainTextResponse("Method Not Allowed", status_code=405)
response = PlainTextResponse(
"Method Not Allowed", status_code=405, headers=headers
)
await response(scope, receive, send)
else:
await self.app(scope, receive, send)
Expand Down
4 changes: 2 additions & 2 deletions starlette/templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(
template: typing.Any,
context: dict,
status_code: int = 200,
headers: dict = None,
headers: typing.Mapping[str, str] = None,
media_type: str = None,
background: BackgroundTask = None,
):
Expand Down Expand Up @@ -85,7 +85,7 @@ def TemplateResponse(
name: str,
context: dict,
status_code: int = 200,
headers: dict = None,
headers: typing.Mapping[str, str] = None,
media_type: str = None,
background: BackgroundTask = None,
) -> _TemplateResponse:
Expand Down
1 change: 1 addition & 0 deletions tests/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def test_http_endpoint_route_method(client):
response = client.post("/")
assert response.status_code == 405
assert response.text == "Method Not Allowed"
assert response.headers["allow"] == "GET"


def test_websocket_endpoint_on_connect(test_client_factory):
Expand Down
22 changes: 22 additions & 0 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,18 @@ def not_acceptable(request):
raise HTTPException(status_code=406)


def no_content(request):
raise HTTPException(status_code=204)


def not_modified(request):
raise HTTPException(status_code=304)


def with_headers(request):
raise HTTPException(status_code=200, headers={"x-potato": "always"})


class HandledExcAfterResponse:
async def __call__(self, scope, receive, send):
response = PlainTextResponse("OK", status_code=200)
Expand All @@ -28,7 +36,9 @@ async def __call__(self, scope, receive, send):
routes=[
Route("/runtime_error", endpoint=raise_runtime_error),
Route("/not_acceptable", endpoint=not_acceptable),
Route("/no_content", endpoint=no_content),
Route("/not_modified", endpoint=not_modified),
Route("/with_headers", endpoint=with_headers),
Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()),
WebSocketRoute("/runtime_error", endpoint=raise_runtime_error),
]
Expand All @@ -50,12 +60,24 @@ def test_not_acceptable(client):
assert response.text == "Not Acceptable"


def test_no_content(client):
response = client.get("/no_content")
assert response.status_code == 204
assert "content-length" not in response.headers


def test_not_modified(client):
response = client.get("/not_modified")
assert response.status_code == 304
assert response.text == ""


def test_with_headers(client):
response = client.get("/with_headers")
assert response.status_code == 200
assert response.headers["x-potato"] == "always"


def test_websockets_should_raise(client):
with pytest.raises(RuntimeError):
with client.websocket_connect("/runtime_error"):
Expand Down
3 changes: 3 additions & 0 deletions tests/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ async def app(scope, receive, send):
client = test_client_factory(app)
response = client.get("/")
assert response.json() is None
assert response.content == b"null"


def test_redirect_response(test_client_factory):
Expand Down Expand Up @@ -330,7 +331,9 @@ def test_empty_response(test_client_factory):
app = Response()
client: TestClient = test_client_factory(app)
response = client.get("/")
assert response.content == b""
assert response.headers["content-length"] == "0"
assert "content-type" not in response.headers


def test_empty_204_response(test_client_factory):
Expand Down
1 change: 1 addition & 0 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def test_router(client):
response = client.post("/")
assert response.status_code == 405
assert response.text == "Method Not Allowed"
assert set(response.headers["allow"].split(", ")) == {"HEAD", "GET"}

response = client.get("/foo")
assert response.status_code == 404
Expand Down

0 comments on commit c28d9e9

Please sign in to comment.