Skip to content

Commit

Permalink
allow 'name' in key word arguments of url_for() and url_path_for() (#608
Browse files Browse the repository at this point in the history
)
  • Loading branch information
dansan committed Aug 22, 2019
1 parent cba5eb8 commit e158072
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 39 deletions.
3 changes: 3 additions & 0 deletions docs/release-notes.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
## 0.1...
* Backwards compatible change to signatures of URL reversing methods `url_for()` and `url_path_for()` allows to use `name` in path arguments.

## 0.12.1

* Add `URL.include_query_params(**kwargs)`
Expand Down
4 changes: 2 additions & 2 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ def decorator(func: typing.Callable) -> typing.Callable:

return decorator

def url_path_for(self, name: str, **path_params: str) -> URLPath:
return self.router.url_path_for(name, **path_params)
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
return self.router.url_path_for(*args, **kwargs)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
scope["app"] = self
Expand Down
4 changes: 2 additions & 2 deletions starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,9 @@ def state(self) -> State:
self._state = State(self.scope["state"])
return self._state

def url_for(self, name: str, **path_params: typing.Any) -> str:
def url_for(self, *args: str, **kwargs: typing.Any) -> str:
router = self.scope["router"]
url_path = router.url_path_for(name, **path_params)
url_path = router.url_path_for(*args, **kwargs)
return url_path.make_absolute_url(base_url=self.url)


Expand Down
53 changes: 30 additions & 23 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

class NoMatchFound(Exception):
"""
Raised by `.url_for(name, **path_params)` and `.url_path_for(name, **path_params)`
Raised by `.url_for(*args, **kwargs)` and `.url_path_for(*args, **kwargs)`
if no matching route exists.
"""

Expand Down Expand Up @@ -127,7 +127,7 @@ class BaseRoute:
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
raise NotImplementedError() # pragma: no cover

def url_path_for(self, name: str, **path_params: str) -> URLPath:
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
raise NotImplementedError() # pragma: no cover

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
Expand Down Expand Up @@ -184,15 +184,16 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
return Match.FULL, child_scope
return Match.NONE, {}

def url_path_for(self, name: str, **path_params: str) -> URLPath:
seen_params = set(path_params.keys())
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
assert len(args) == 1, "Exactly one positional argument required: the routes name."
seen_params = set(kwargs.keys())
expected_params = set(self.param_convertors.keys())

if name != self.name or seen_params != expected_params:
if args[0] != self.name or seen_params != expected_params:
raise NoMatchFound()

path, remaining_params = replace_params(
self.path_format, self.param_convertors, path_params
self.path_format, self.param_convertors, kwargs
)
assert not remaining_params
return URLPath(path=path, protocol="http")
Expand Down Expand Up @@ -247,15 +248,16 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
return Match.FULL, child_scope
return Match.NONE, {}

def url_path_for(self, name: str, **path_params: str) -> URLPath:
seen_params = set(path_params.keys())
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
assert len(args) == 1, "Exactly one positional argument required: the routes name."
seen_params = set(kwargs.keys())
expected_params = set(self.param_convertors.keys())

if name != self.name or seen_params != expected_params:
if args[0] != self.name or seen_params != expected_params:
raise NoMatchFound()

path, remaining_params = replace_params(
self.path_format, self.param_convertors, path_params
self.path_format, self.param_convertors, kwargs
)
assert not remaining_params
return URLPath(path=path, protocol="websocket")
Expand Down Expand Up @@ -318,12 +320,14 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
return Match.FULL, child_scope
return Match.NONE, {}

def url_path_for(self, name: str, **path_params: str) -> URLPath:
if self.name is not None and name == self.name and "path" in path_params:
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
assert len(args) == 1, "Exactly one positional argument required: the routes name."
name = args[0]
if self.name is not None and name == self.name and "path" in kwargs:
# 'name' matches "<mount_name>".
path_params["path"] = path_params["path"].lstrip("/")
kwargs["path"] = kwargs["path"].lstrip("/")
path, remaining_params = replace_params(
self.path_format, self.param_convertors, path_params
self.path_format, self.param_convertors, kwargs
)
if not remaining_params:
return URLPath(path=path)
Expand All @@ -334,9 +338,9 @@ def url_path_for(self, name: str, **path_params: str) -> URLPath:
else:
# 'name' matches "<mount_name>:<child_name>".
remaining_name = name[len(self.name) + 1 :]
path_params["path"] = ""
kwargs["path"] = ""
path, remaining_params = replace_params(
self.path_format, self.param_convertors, path_params
self.path_format, self.param_convertors, kwargs
)
for route in self.routes or []:
try:
Expand Down Expand Up @@ -385,12 +389,14 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
return Match.FULL, child_scope
return Match.NONE, {}

def url_path_for(self, name: str, **path_params: str) -> URLPath:
if self.name is not None and name == self.name and "path" in path_params:
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
assert len(args) == 1, "Exactly one positional argument required: the routes name."
name = args[0]
if self.name is not None and name == self.name and "path" in kwargs:
# 'name' matches "<mount_name>".
path = path_params.pop("path")
path = kwargs.pop("path")
host, remaining_params = replace_params(
self.host_format, self.param_convertors, path_params
self.host_format, self.param_convertors, kwargs
)
if not remaining_params:
return URLPath(path=path, host=host)
Expand All @@ -402,7 +408,7 @@ def url_path_for(self, name: str, **path_params: str) -> URLPath:
# 'name' matches "<mount_name>:<child_name>".
remaining_name = name[len(self.name) + 1 :]
host, remaining_params = replace_params(
self.host_format, self.param_convertors, path_params
self.host_format, self.param_convertors, kwargs
)
for route in self.routes or []:
try:
Expand Down Expand Up @@ -567,10 +573,11 @@ async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
response = PlainTextResponse("Not Found", status_code=404)
await response(scope, receive, send)

def url_path_for(self, name: str, **path_params: str) -> URLPath:
def url_path_for(self, *args: str, **kwargs: str) -> URLPath:
assert len(args) == 1, "Exactly one positional argument required: the routes name."
for route in self.routes:
try:
return route.url_path_for(name, **path_params)
return route.url_path_for(args[0], **kwargs)
except NoMatchFound:
pass
raise NoMatchFound()
Expand Down
5 changes: 3 additions & 2 deletions starlette/templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,10 @@ def __init__(self, directory: str) -> None:

def get_env(self, directory: str) -> "jinja2.Environment":
@jinja2.contextfunction
def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
def url_for(context: dict, *args: str, **kwargs: typing.Any) -> str:
assert len(args) == 1, "Exactly one positional argument in *args required."
request = context["request"]
return request.url_for(name, **path_params)
return request.url_for(args[0], **kwargs)

loader = jinja2.FileSystemLoader(directory)
env = jinja2.Environment(loader=loader, autoescape=True)
Expand Down
20 changes: 20 additions & 0 deletions tests/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest

from starlette.applications import Starlette
from starlette.requests import ClientDisconnect, Request, State
from starlette.responses import JSONResponse, Response
from starlette.testclient import TestClient
Expand Down Expand Up @@ -285,6 +286,25 @@ async def app(scope, receive, send):
assert response.text == "Hello, cookies!"


def test_request_url_for():
app = Starlette()

@app.route("/users/{name}")
async def func_users(request):
name = request.path_params["name"]
return Response(name, media_type="text/plain")

@app.route("/test")
async def func_url_for_test(request: Request):
url = request.url_for("func_users", name="abcde")
return Response(str(url), media_type="text/plain")

client = TestClient(app)
response = client.get("/test")
assert response.status_code == 200
assert response.text == "http://testserver/users/abcde"


def test_chunked_encoding():
async def app(scope, receive, send):
request = Request(scope, receive)
Expand Down
20 changes: 10 additions & 10 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def users(request):


def user(request):
content = "User " + request.path_params["username"]
content = "User " + request.path_params["name"]
return Response(content, media_type="text/plain")


Expand All @@ -37,7 +37,7 @@ def user_no_match(request): # pragma: no cover
routes=[
Route("/", endpoint=users),
Route("/me", endpoint=user_me),
Route("/{username}", endpoint=user),
Route("/{name}", endpoint=user),
Route("/nomatch", endpoint=user_no_match),
],
),
Expand Down Expand Up @@ -149,14 +149,14 @@ def test_route_converters():

def test_url_path_for():
assert app.url_path_for("homepage") == "/"
assert app.url_path_for("user", username="tomchristie") == "/users/tomchristie"
assert app.url_path_for("user", name="tomchristie") == "/users/tomchristie"
assert app.url_path_for("websocket_endpoint") == "/ws"
with pytest.raises(NoMatchFound):
assert app.url_path_for("broken")
with pytest.raises(AssertionError):
app.url_path_for("user", username="tom/christie")
app.url_path_for("user", name="tom/christie")
with pytest.raises(AssertionError):
app.url_path_for("user", username="")
app.url_path_for("user", name="")


def test_url_for():
Expand All @@ -165,7 +165,7 @@ def test_url_for():
== "https://example.org/"
)
assert (
app.url_path_for("user", username="tomchristie").make_absolute_url(
app.url_path_for("user", name="tomchristie").make_absolute_url(
base_url="https://example.org"
)
== "https://example.org/users/tomchristie"
Expand Down Expand Up @@ -252,10 +252,10 @@ def test_reverse_mount_urls():
mounted = Router([Mount("/users", ok, name="users")])
assert mounted.url_path_for("users", path="/a") == "/users/a"

users = Router([Route("/{username}", ok, name="user")])
users = Router([Route("/{name}", ok, name="user")])
mounted = Router([Mount("/{subpath}/users", users, name="users")])
assert (
mounted.url_path_for("users:user", subpath="test", username="tom")
mounted.url_path_for("users:user", subpath="test", name="tom")
== "/test/users/tom"
)
assert (
Expand All @@ -270,7 +270,7 @@ def test_mount_at_root():


def users_api(request):
return JSONResponse({"users": [{"username": "tom"}]})
return JSONResponse({"users": [{"name": "tom"}]})


mixed_hosts_app = Router(
Expand Down Expand Up @@ -298,7 +298,7 @@ def test_host_routing():

response = client.get("/users")
assert response.status_code == 200
assert response.json() == {"users": [{"username": "tom"}]}
assert response.json() == {"users": [{"name": "tom"}]}

response = client.get("/")
assert response.status_code == 404
Expand Down

0 comments on commit e158072

Please sign in to comment.