diff --git a/starlette/routing.py b/starlette/routing.py index 46f8b56285..8233f1cbe4 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -1,6 +1,6 @@ import asyncio -import inspect import functools +import inspect import re import traceback import typing @@ -29,11 +29,15 @@ class Match(Enum): FULL = 2 -def assert_args_is_only_route(func: typing.Callable) -> typing.Callable: +def verify_args_is_only_route(func: typing.Callable) -> typing.Callable: + """Raise TypeError if number of positional arguments is not exactly 1.""" + @functools.wraps(func) def wrapper(self: BaseRoute, *args: str, **kwargs: str) -> URLPath: - assert len(args) > 0, "Missing route name as the first argument." - assert len(args) < 2, "Invalid positional argument passed." + if len(args) < 1: + raise TypeError("Missing route name as the first argument.") + if len(args) > 1: + raise TypeError("Invalid positional argument passed.") return func(self, *args, **kwargs) return wrapper @@ -195,7 +199,7 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: return Match.FULL, child_scope return Match.NONE, {} - @assert_args_is_only_route + @verify_args_is_only_route def url_path_for(self, *args: str, **kwargs: str) -> URLPath: seen_params = set(kwargs.keys()) expected_params = set(self.param_convertors.keys()) @@ -259,7 +263,7 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: return Match.FULL, child_scope return Match.NONE, {} - @assert_args_is_only_route + @verify_args_is_only_route def url_path_for(self, *args: str, **kwargs: str) -> URLPath: seen_params = set(kwargs.keys()) expected_params = set(self.param_convertors.keys()) @@ -331,7 +335,7 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: return Match.FULL, child_scope return Match.NONE, {} - @assert_args_is_only_route + @verify_args_is_only_route def url_path_for(self, *args: str, **kwargs: str) -> URLPath: name = args[0] if self.name is not None and name == self.name and "path" in kwargs: @@ -400,7 +404,7 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: return Match.FULL, child_scope return Match.NONE, {} - @assert_args_is_only_route + @verify_args_is_only_route def url_path_for(self, *args: str, **kwargs: str) -> URLPath: name = args[0] if self.name is not None and name == self.name and "path" in kwargs: @@ -584,7 +588,7 @@ async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None: response = PlainTextResponse("Not Found", status_code=404) await response(scope, receive, send) - @assert_args_is_only_route + @verify_args_is_only_route def url_path_for(self, *args: str, **kwargs: str) -> URLPath: for route in self.routes: try: diff --git a/starlette/templating.py b/starlette/templating.py index e63e98a983..a60ace5fb2 100644 --- a/starlette/templating.py +++ b/starlette/templating.py @@ -55,8 +55,10 @@ def __init__(self, directory: str) -> None: def get_env(self, directory: str) -> "jinja2.Environment": @jinja2.contextfunction def url_for(context: dict, *args: str, **kwargs: typing.Any) -> str: - assert len(args) > 0, "Missing route name as the first argument." - assert len(args) < 2, "Invalid positional argument passed." + if len(args) < 1: + raise TypeError("Missing route name as the second argument.") + if len(args) > 1: + raise TypeError("Invalid positional argument passed.") request = context["request"] return request.url_for(args[0], **kwargs) diff --git a/tests/.ignore_lifespan b/tests/.ignore_lifespan index 0a33582179..c517550bb8 100644 --- a/tests/.ignore_lifespan +++ b/tests/.ignore_lifespan @@ -1,3 +1,9 @@ [coverage:run] omit = starlette/middleware/lifespan.py + +[report] +exclude_lines = + pragma: no cover + pragma: nocover + raise NotImplementedError diff --git a/tests/test_requests.py b/tests/test_requests.py index 9713014d97..c1985c2a99 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -290,8 +290,8 @@ def test_request_url_for(): app = Starlette() @app.route("/users/{name}") - async def func_users(request): # pragma: no cover - ... + async def func_users(request): + raise NotImplementedError() # ignored by coverage @app.route("/test") async def func_url_for_test(request: Request): diff --git a/tests/test_routing.py b/tests/test_routing.py index 0a358f21fc..0a4c6f6144 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -172,9 +172,9 @@ def test_url_path_for(): app.url_path_for("user", username="tom/christie") with pytest.raises(AssertionError): app.url_path_for("user", username="") - with pytest.raises(AssertionError): + with pytest.raises(TypeError): assert app.url_path_for("user", "args2", name="tomchristie1") - with pytest.raises(AssertionError): + with pytest.raises(TypeError): assert app.url_path_for(name="tomchristie1") diff --git a/tests/test_templates.py b/tests/test_templates.py index a0ab3e1b0b..d5ecc1495f 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -15,6 +15,12 @@ def test_templates(tmpdir): app = Starlette(debug=True) templates = Jinja2Templates(directory=str(tmpdir)) + url_for_func = templates.env.globals["url_for"] + with pytest.raises(TypeError): + assert url_for_func({}, "user", "args2", name="tomchristie") + with pytest.raises(TypeError): + assert url_for_func({}, name="tomchristie") + @app.route("/") async def homepage(request): return templates.TemplateResponse("index.html", {"request": request})