Skip to content

Commit

Permalink
allow 'name' in key word arguments of url_for() and url_path_for() (#608
Browse files Browse the repository at this point in the history
)
  • Loading branch information
dansan committed Aug 28, 2019
1 parent ce2883e commit 0ddea63
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 30 deletions.
3 changes: 3 additions & 0 deletions docs/release-notes.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
## 0.1...
* Backwards compatible change to signatures of URL reversing methods `url_for()` and `url_path_for()` allows to use `name` in path arguments.

## 0.12.1

* Add `URL.include_query_params(**kwargs)`
Expand Down
4 changes: 2 additions & 2 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def decorator(func: typing.Callable) -> typing.Callable:

return decorator

def url_path_for(self, name: str, **path_params: str) -> URLPath:
return self.router.url_path_for(name, **path_params)
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
return self.router.url_path_for(*args, **kwargs)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
scope["app"] = self
Expand Down
4 changes: 2 additions & 2 deletions starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ def state(self) -> State:
self._state = State(self.scope["state"])
return self._state

def url_for(self, name: str, **path_params: typing.Any) -> str:
def url_for(self, *args: str, **kwargs: typing.Any) -> str:
router = self.scope["router"]
url_path = router.url_path_for(name, **path_params)
url_path = router.url_path_for(*args, **kwargs)
return url_path.make_absolute_url(base_url=self.url)


Expand Down
64 changes: 41 additions & 23 deletions starlette/routing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import inspect
import functools
import re
import traceback
import typing
Expand All @@ -17,7 +18,7 @@

class NoMatchFound(Exception):
"""
Raised by `.url_for(name, **path_params)` and `.url_path_for(name, **path_params)`
Raised by `.url_for(*args, **kwargs)` and `.url_path_for(*args, **kwargs)`
if no matching route exists.
"""

Expand All @@ -28,6 +29,16 @@ class Match(Enum):
FULL = 2


def assert_args_is_only_route(func: typing.Callable) -> typing.Callable:
@functools.wraps(func)
def wrapper(self: BaseRoute, *args: str, **kwargs: str) -> URLPath:
assert len(args) > 0, "Missing route name as the first argument."
assert len(args) < 2, "Invalid positional argument passed."
return func(self, *args, **kwargs)

return wrapper


def request_response(func: typing.Callable) -> ASGIApp:
"""
Takes a function or coroutine `func(request) -> response`,
Expand Down Expand Up @@ -127,7 +138,7 @@ class BaseRoute:
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
raise NotImplementedError() # pragma: no cover

def url_path_for(self, name: str, **path_params: str) -> URLPath:
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
raise NotImplementedError() # pragma: no cover

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
Expand Down Expand Up @@ -184,15 +195,16 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
return Match.FULL, child_scope
return Match.NONE, {}

def url_path_for(self, name: str, **path_params: str) -> URLPath:
seen_params = set(path_params.keys())
@assert_args_is_only_route
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
seen_params = set(kwargs.keys())
expected_params = set(self.param_convertors.keys())

if name != self.name or seen_params != expected_params:
if args[0] != self.name or seen_params != expected_params:
raise NoMatchFound()

path, remaining_params = replace_params(
self.path_format, self.param_convertors, path_params
self.path_format, self.param_convertors, kwargs
)
assert not remaining_params
return URLPath(path=path, protocol="http")
Expand Down Expand Up @@ -247,15 +259,16 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
return Match.FULL, child_scope
return Match.NONE, {}

def url_path_for(self, name: str, **path_params: str) -> URLPath:
seen_params = set(path_params.keys())
@assert_args_is_only_route
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
seen_params = set(kwargs.keys())
expected_params = set(self.param_convertors.keys())

if name != self.name or seen_params != expected_params:
if args[0] != self.name or seen_params != expected_params:
raise NoMatchFound()

path, remaining_params = replace_params(
self.path_format, self.param_convertors, path_params
self.path_format, self.param_convertors, kwargs
)
assert not remaining_params
return URLPath(path=path, protocol="websocket")
Expand Down Expand Up @@ -318,12 +331,14 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
return Match.FULL, child_scope
return Match.NONE, {}

def url_path_for(self, name: str, **path_params: str) -> URLPath:
if self.name is not None and name == self.name and "path" in path_params:
@assert_args_is_only_route
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
name = args[0]
if self.name is not None and name == self.name and "path" in kwargs:
# 'name' matches "<mount_name>".
path_params["path"] = path_params["path"].lstrip("/")
kwargs["path"] = kwargs["path"].lstrip("/")
path, remaining_params = replace_params(
self.path_format, self.param_convertors, path_params
self.path_format, self.param_convertors, kwargs
)
if not remaining_params:
return URLPath(path=path)
Expand All @@ -334,9 +349,9 @@ def url_path_for(self, name: str, **path_params: str) -> URLPath:
else:
# 'name' matches "<mount_name>:<child_name>".
remaining_name = name[len(self.name) + 1 :]
path_params["path"] = ""
kwargs["path"] = ""
path, remaining_params = replace_params(
self.path_format, self.param_convertors, path_params
self.path_format, self.param_convertors, kwargs
)
for route in self.routes or []:
try:
Expand Down Expand Up @@ -385,12 +400,14 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
return Match.FULL, child_scope
return Match.NONE, {}

def url_path_for(self, name: str, **path_params: str) -> URLPath:
if self.name is not None and name == self.name and "path" in path_params:
@assert_args_is_only_route
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
name = args[0]
if self.name is not None and name == self.name and "path" in kwargs:
# 'name' matches "<mount_name>".
path = path_params.pop("path")
path = kwargs.pop("path")
host, remaining_params = replace_params(
self.host_format, self.param_convertors, path_params
self.host_format, self.param_convertors, kwargs
)
if not remaining_params:
return URLPath(path=path, host=host)
Expand All @@ -402,7 +419,7 @@ def url_path_for(self, name: str, **path_params: str) -> URLPath:
# 'name' matches "<mount_name>:<child_name>".
remaining_name = name[len(self.name) + 1 :]
host, remaining_params = replace_params(
self.host_format, self.param_convertors, path_params
self.host_format, self.param_convertors, kwargs
)
for route in self.routes or []:
try:
Expand Down Expand Up @@ -567,10 +584,11 @@ async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
response = PlainTextResponse("Not Found", status_code=404)
await response(scope, receive, send)

def url_path_for(self, name: str, **path_params: str) -> URLPath:
@assert_args_is_only_route
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
for route in self.routes:
try:
return route.url_path_for(name, **path_params)
return route.url_path_for(args[0], **kwargs)
except NoMatchFound:
pass
raise NoMatchFound()
Expand Down
6 changes: 4 additions & 2 deletions starlette/templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@ def __init__(self, directory: str) -> None:

def get_env(self, directory: str) -> "jinja2.Environment":
@jinja2.contextfunction
def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
def url_for(context: dict, *args: str, **kwargs: typing.Any) -> str:
assert len(args) > 0, "Missing route name as the first argument."
assert len(args) < 2, "Invalid positional argument passed."
request = context["request"]
return request.url_for(name, **path_params)
return request.url_for(args[0], **kwargs)

loader = jinja2.FileSystemLoader(directory)
env = jinja2.Environment(loader=loader, autoescape=True)
Expand Down
19 changes: 19 additions & 0 deletions tests/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest

from starlette.applications import Starlette
from starlette.requests import ClientDisconnect, Request, State
from starlette.responses import JSONResponse, Response
from starlette.testclient import TestClient
Expand Down Expand Up @@ -285,6 +286,24 @@ async def app(scope, receive, send):
assert response.text == "Hello, cookies!"


def test_request_url_for():
app = Starlette()

@app.route("/users/{name}")
async def func_users(request): # pragma: no cover
...

@app.route("/test")
async def func_url_for_test(request: Request):
url = request.url_for("func_users", name="abcde")
return Response(str(url), media_type="text/plain")

client = TestClient(app)
response = client.get("/test")
assert response.status_code == 200
assert response.text == "http://testserver/users/abcde"


def test_chunked_encoding():
async def app(scope, receive, send):
request = Request(scope, receive)
Expand Down
21 changes: 20 additions & 1 deletion tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ def user(request):
return Response(content, media_type="text/plain")


def user2(request):
content = "User 2 " + request.path_params["name"]
return Response(content, media_type="text/plain")


def user_me(request):
content = "User fixed me"
return Response(content, media_type="text/plain")
Expand All @@ -38,6 +43,7 @@ def user_no_match(request): # pragma: no cover
Route("/", endpoint=users),
Route("/me", endpoint=user_me),
Route("/{username}", endpoint=user),
Route("/n/{name}", endpoint=user2),
Route("/nomatch", endpoint=user_no_match),
],
),
Expand Down Expand Up @@ -112,6 +118,10 @@ def test_router():
assert response.status_code == 200
assert response.text == "User tomchristie"

response = client.get("/users/n/tomchristie")
assert response.status_code == 200
assert response.text == "User 2 tomchristie"

response = client.get("/users/me")
assert response.status_code == 200
assert response.text == "User fixed me"
Expand Down Expand Up @@ -149,14 +159,23 @@ def test_route_converters():

def test_url_path_for():
assert app.url_path_for("homepage") == "/"
assert app.url_path_for("user", username="tomchristie") == "/users/tomchristie"
assert app.url_path_for("user", username="tomchristie1") == "/users/tomchristie1"
assert app.url_path_for("user2", name="tomchristie2") == "/users/n/tomchristie2"
with pytest.raises(NoMatchFound):
assert app.url_path_for("user", name="tomchristie1")
with pytest.raises(NoMatchFound):
assert app.url_path_for("user2", username="tomchristie2")
assert app.url_path_for("websocket_endpoint") == "/ws"
with pytest.raises(NoMatchFound):
assert app.url_path_for("broken")
with pytest.raises(AssertionError):
app.url_path_for("user", username="tom/christie")
with pytest.raises(AssertionError):
app.url_path_for("user", username="")
with pytest.raises(AssertionError):
assert app.url_path_for("user", "args2", name="tomchristie1")
with pytest.raises(AssertionError):
assert app.url_path_for(name="tomchristie1")


def test_url_for():
Expand Down

0 comments on commit 0ddea63

Please sign in to comment.