Skip to content

Commit

Permalink
Bugfix/starlette root path (#1833)
Browse files Browse the repository at this point in the history
Fixes the change in behaviour in starlette v0.33 via PR
[#2352](encode/starlette#2352)
  • Loading branch information
Ruwann authored Dec 29, 2023
1 parent 6dc9436 commit 0ae9ba8
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 8 deletions.
7 changes: 4 additions & 3 deletions connexion/middleware/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
scope.get("path_params", {})
)

api_base_path = scope.get("root_path", "")[
len(original_scope.get("root_path", "")) :
]
def get_root_path(scope: Scope) -> str:
return scope.get("route_root_path", scope.get("root_path", ""))

api_base_path = get_root_path(scope)[len(get_root_path(original_scope)) :]

extensions = original_scope.setdefault("extensions", {})
connexion_routing = extensions.setdefault(ROUTING_CONTEXT, {})
Expand Down
4 changes: 3 additions & 1 deletion connexion/middleware/swagger_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def _base_path_for_prefix(self, request: StarletteRequest) -> str:
"""
returns a modified basePath which includes the incoming root_path.
"""
return request.scope.get("root_path", "").rstrip("/")
return request.scope.get(
"route_root_path", request.scope.get("root_path", "")
).rstrip("/")

def _spec_for_prefix(self, request):
"""
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ Jinja2 = ">= 3.0.0"
python-multipart = ">= 0.0.5"
PyYAML = ">= 5.1"
requests = ">= 2.27"
starlette = ">= 0.27, <0.33"
starlette = ">= 0.27"
typing-extensions = ">= 4"
werkzeug = ">= 2.2.1"

Expand Down
6 changes: 3 additions & 3 deletions tests/test_flask_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,16 @@ def get_value(data, path):
assert example == "a7b8869c-5f24-4ce0-a5d1-3e44c3663aa9"

res = app_client.get("/v1.0/datetime")
assert res.status_code == 200, f"Error is {res.data}"
assert res.status_code == 200, f"Error is {res.text}"
data = res.json()
assert data == {"value": "2000-01-02T03:04:05.000006Z"}

res = app_client.get("/v1.0/date")
assert res.status_code == 200, f"Error is {res.data}"
assert res.status_code == 200, f"Error is {res.text}"
data = res.json()
assert data == {"value": "2000-01-02"}

res = app_client.get("/v1.0/uuid")
assert res.status_code == 200, f"Error is {res.data}"
assert res.status_code == 200, f"Error is {res.text}"
data = res.json()
assert data == {"value": "e7ff66d0-3ec2-4c4e-bed0-6e4723c24c51"}

0 comments on commit 0ae9ba8

Please sign in to comment.