Skip to content

Commit

Permalink
Ensure accurate root_path removal in get_route_path function (#2600)
Browse files Browse the repository at this point in the history
* fix: regex inside function get_route_path to remove root_path

* fix: apply format ruff

* fix: mypy

---------

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
  • Loading branch information
gabriel-f-santos and Kludex authored Sep 1, 2024
1 parent 1eb4036 commit 1131b3c
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 2 deletions.
2 changes: 1 addition & 1 deletion starlette/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,5 @@ def collapse_excgroups() -> typing.Generator[None, None, None]:

def get_route_path(scope: Scope) -> str:
root_path = scope.get("root_path", "")
route_path = re.sub(r"^" + root_path, "", scope["path"])
route_path = re.sub(r"^" + root_path + r"(?=/|$)", "", scope["path"])
return route_path
17 changes: 16 additions & 1 deletion tests/test__utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import functools
from typing import Any

from starlette._utils import is_async_callable
import pytest

from starlette._utils import get_route_path, is_async_callable
from starlette.types import Scope


def test_async_func() -> None:
Expand Down Expand Up @@ -78,3 +81,15 @@ async def async_func(
partial = functools.partial(async_func, b=2)
nested_partial = functools.partial(partial, a=1)
assert is_async_callable(nested_partial)


@pytest.mark.parametrize(
"scope, expected_result",
[
({"path": "/foo-123/bar", "root_path": "/foo"}, "/foo-123/bar"),
({"path": "/foo/bar", "root_path": "/foo"}, "/bar"),
({"path": "/foo", "root_path": "/foo"}, ""),
],
)
def test_get_route_path(scope: Scope, expected_result: str) -> None:
assert get_route_path(scope) == expected_result
14 changes: 14 additions & 0 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1221,6 +1221,12 @@ async def pure_asgi_echo_paths(scope: Scope, receive: Receive, send: Send, name:
name="path",
methods=["GET"],
),
Route(
"/root-queue/path",
functools.partial(echo_paths, name="queue_path"),
name="queue_path",
methods=["POST"],
),
Mount("/asgipath", app=functools.partial(pure_asgi_echo_paths, name="asgipath")),
Mount(
"/sub",
Expand Down Expand Up @@ -1266,3 +1272,11 @@ def test_paths_with_root_path(test_client_factory: TestClientFactory) -> None:
"path": "/root/sub/path",
"root_path": "/root/sub",
}

response = client.post("/root/root-queue/path")
assert response.status_code == 200
assert response.json() == {
"name": "queue_path",
"path": "/root/root-queue/path",
"root_path": "/root",
}

0 comments on commit 1131b3c

Please sign in to comment.