Skip to content

Commit

Permalink
raise TypeError instead of using assert (#608)
Browse files Browse the repository at this point in the history
  • Loading branch information
dansan committed Sep 13, 2019
1 parent 0ddea63 commit 6bf9b0e
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 15 deletions.
22 changes: 13 additions & 9 deletions starlette/routing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import inspect
import functools
import inspect
import re
import traceback
import typing
Expand Down Expand Up @@ -29,11 +29,15 @@ class Match(Enum):
FULL = 2


def assert_args_is_only_route(func: typing.Callable) -> typing.Callable:
def verify_args_is_only_route(func: typing.Callable) -> typing.Callable:
"""Raise TypeError if number of positional arguments is not exactly 1."""

@functools.wraps(func)
def wrapper(self: BaseRoute, *args: str, **kwargs: str) -> URLPath:
assert len(args) > 0, "Missing route name as the first argument."
assert len(args) < 2, "Invalid positional argument passed."
if len(args) < 1:
raise TypeError("Missing route name as the first argument.")
if len(args) > 1:
raise TypeError("Invalid positional argument passed.")
return func(self, *args, **kwargs)

return wrapper
Expand Down Expand Up @@ -195,7 +199,7 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
return Match.FULL, child_scope
return Match.NONE, {}

@assert_args_is_only_route
@verify_args_is_only_route
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
seen_params = set(kwargs.keys())
expected_params = set(self.param_convertors.keys())
Expand Down Expand Up @@ -259,7 +263,7 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
return Match.FULL, child_scope
return Match.NONE, {}

@assert_args_is_only_route
@verify_args_is_only_route
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
seen_params = set(kwargs.keys())
expected_params = set(self.param_convertors.keys())
Expand Down Expand Up @@ -331,7 +335,7 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
return Match.FULL, child_scope
return Match.NONE, {}

@assert_args_is_only_route
@verify_args_is_only_route
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
name = args[0]
if self.name is not None and name == self.name and "path" in kwargs:
Expand Down Expand Up @@ -400,7 +404,7 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
return Match.FULL, child_scope
return Match.NONE, {}

@assert_args_is_only_route
@verify_args_is_only_route
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
name = args[0]
if self.name is not None and name == self.name and "path" in kwargs:
Expand Down Expand Up @@ -584,7 +588,7 @@ async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
response = PlainTextResponse("Not Found", status_code=404)
await response(scope, receive, send)

@assert_args_is_only_route
@verify_args_is_only_route
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
for route in self.routes:
try:
Expand Down
6 changes: 4 additions & 2 deletions starlette/templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,10 @@ def __init__(self, directory: str) -> None:
def get_env(self, directory: str) -> "jinja2.Environment":
@jinja2.contextfunction
def url_for(context: dict, *args: str, **kwargs: typing.Any) -> str:
assert len(args) > 0, "Missing route name as the first argument."
assert len(args) < 2, "Invalid positional argument passed."
if len(args) < 1:
raise TypeError("Missing route name as the second argument.")
if len(args) > 1:
raise TypeError("Invalid positional argument passed.")
request = context["request"]
return request.url_for(args[0], **kwargs)

Expand Down
6 changes: 6 additions & 0 deletions tests/.ignore_lifespan
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
[coverage:run]
omit =
starlette/middleware/lifespan.py

[report]
exclude_lines =
pragma: no cover
pragma: nocover
raise NotImplementedError
4 changes: 2 additions & 2 deletions tests/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,8 @@ def test_request_url_for():
app = Starlette()

@app.route("/users/{name}")
async def func_users(request): # pragma: no cover
...
async def func_users(request):
raise NotImplementedError() # ignored by coverage

@app.route("/test")
async def func_url_for_test(request: Request):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,9 @@ def test_url_path_for():
app.url_path_for("user", username="tom/christie")
with pytest.raises(AssertionError):
app.url_path_for("user", username="")
with pytest.raises(AssertionError):
with pytest.raises(TypeError):
assert app.url_path_for("user", "args2", name="tomchristie1")
with pytest.raises(AssertionError):
with pytest.raises(TypeError):
assert app.url_path_for(name="tomchristie1")


Expand Down
6 changes: 6 additions & 0 deletions tests/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ def test_templates(tmpdir):
app = Starlette(debug=True)
templates = Jinja2Templates(directory=str(tmpdir))

url_for_func = templates.env.globals["url_for"]
with pytest.raises(TypeError):
assert url_for_func({}, "user", "args2", name="tomchristie")
with pytest.raises(TypeError):
assert url_for_func({}, name="tomchristie")

@app.route("/")
async def homepage(request):
return templates.TemplateResponse("index.html", {"request": request})
Expand Down

0 comments on commit 6bf9b0e

Please sign in to comment.