Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(BA-524): Deprecate X-BackendAI-SSO header for pipeline service authentication #3353

Merged
Merged
1 change: 1 addition & 0 deletions changes/3353.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Deprecate the JWT-based `X-BackendAI-SSO` header to reduce complexity in authentication process for the pipeline service
1 change: 0 additions & 1 deletion configs/webserver/halfstack.conf
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ max_file_upload_size = 4294967296
[pipeline]
endpoint = "http://127.0.0.1:9500"
frontend-endpoint = "http://127.0.0.1:3000"
jwt.secret = "7<:~[X,^Z1XM!*,Pe:PHR!bv,H~Q#l177<7gf_XHD6.<*<.t<[o|V5W(=0x:jTh-"

[ui]
brand = "Lablup Cloud"
Expand Down
2 changes: 0 additions & 2 deletions configs/webserver/sample.conf
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,6 @@ show_non_installed_images = false
#endpoint = "http://127.0.0.1:9500"
# Endpoint to the pipeline service's frontend
#frontend-endpoint = "http://127.0.0.1:3000"
# A secret to sign JWTs used to authenticate users from the pipeline service
#jwt.secret = "7<:~[X,^Z1XM!*,Pe:PHR!bv,H~Q#l177<7gf_XHD6.<*<.t<[o|V5W(=0x:jTh-"

[ui]
brand = "Lablup Cloud"
Expand Down
7 changes: 0 additions & 7 deletions src/ai/backend/web/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,6 @@
{
t.Key("endpoint", default=_config_defaults["pipeline"]["endpoint"]): tx.URL,
t.Key("frontend-endpoint", default=None): t.Null | tx.URL,
t.Key("jwt", default=_config_defaults["pipeline"]["jwt"]): t.Dict(
{
t.Key(
"secret", default=_config_defaults["pipeline"]["jwt"]["secret"]
): t.String,
},
).allow_extra("*"),
},
).allow_extra("*"),
t.Key("ui"): t.Dict({
Expand Down
65 changes: 20 additions & 45 deletions src/ai/backend/web/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
import json
import logging
import random
from datetime import datetime, timedelta
from typing import Optional, Tuple, Union, cast
from typing import Iterable, Optional, Tuple, Union, cast

import aiohttp
import jwt
from aiohttp import web
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad
Expand Down Expand Up @@ -152,23 +150,21 @@ async def decrypt_payload(request: web.Request, handler) -> web.StreamResponse:
return await handler(request)


async def web_handler(request: web.Request, *, is_anonymous=False) -> web.StreamResponse:
async def web_handler(
request: web.Request,
*,
is_anonymous: bool = False,
api_endpoint: Optional[str] = None,
http_headers_to_forward_extra: Iterable[str] | None = None,
) -> web.StreamResponse:
stats: WebStats = request.app["stats"]
stats.active_proxy_api_handlers.add(asyncio.current_task()) # type: ignore
config = request.app["config"]
path = request.match_info.get("path", "")
proxy_path, _, real_path = request.path.lstrip("/").partition("/")
if proxy_path == "pipeline":
pipeline_config = config["pipeline"]
if not pipeline_config:
raise RuntimeError("'pipeline' config must be set to handle pipeline requests.")
endpoint = pipeline_config["endpoint"]
log.info(f"WEB_HANDLER: {request.path} -> {endpoint}/{real_path}")
api_session = await asyncio.shield(get_api_session(request, endpoint))
elif is_anonymous:
api_session = await asyncio.shield(get_anonymous_session(request))
if is_anonymous:
api_session = await asyncio.shield(get_anonymous_session(request, api_endpoint))
else:
api_session = await asyncio.shield(get_api_session(request))
api_session = await asyncio.shield(get_api_session(request, api_endpoint))
http_headers_to_forward_extra = http_headers_to_forward_extra or []
try:
async with api_session:
# We perform request signing by ourselves using the HTTP session data,
Expand Down Expand Up @@ -204,26 +200,12 @@ async def web_handler(request: web.Request, *, is_anonymous=False) -> web.Stream
api_rqst.headers["Content-Length"] = request.headers["Content-Length"]
if "Content-Length" in request.headers and secure_context:
api_rqst.headers["Content-Length"] = str(decrypted_payload_length)
for hdr in HTTP_HEADERS_TO_FORWARD:
for hdr in {*HTTP_HEADERS_TO_FORWARD, *http_headers_to_forward_extra}:
# Prevent malicious or accidental modification of critical headers.
if hdr in api_rqst.headers:
continue
if request.headers.get(hdr) is not None:
api_rqst.headers[hdr] = request.headers[hdr]
if proxy_path == "pipeline":
session_id = request.headers.get("X-BackendAI-SessionID", "")
if not (sso_token := request.headers.get("X-BackendAI-SSO")):
jwt_secret = config["pipeline"]["jwt"]["secret"]
now = datetime.now().astimezone()
payload = {
# Registered claims
"exp": now + timedelta(seconds=config["session"]["max_age"]),
"iss": "Backend.AI Webserver",
"iat": now,
# Private claims
"aiohttp_session": session_id,
"access_key": api_session.config.access_key, # since 23.03.10
}
sso_token = jwt.encode(payload, key=jwt_secret, algorithm="HS256")
api_rqst.headers["X-BackendAI-SSO"] = sso_token
api_rqst.headers["X-BackendAI-SessionID"] = session_id
Comment on lines -210 to -226
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does Fasttrack handle version compatibility?
When making changes this time, is there any risk of breaking functionality for clients using previous versions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is recommended to use the same version as Backend.AI Core.
To enhance version compatibility, we may adopt a dedicated header (e.g., X-BackendAI-FastTrack-Version).

# Uploading request body happens at the entering of the block,
# and downloading response body happens in the read loop inside.
async with api_rqst.fetch() as up_resp:
Expand Down Expand Up @@ -347,15 +329,16 @@ async def web_plugin_handler(request, *, is_anonymous=False) -> web.StreamRespon
)


async def websocket_handler(request, *, is_anonymous=False) -> web.StreamResponse:
async def websocket_handler(
request, *, is_anonymous=False, api_endpoint: Optional[str] = None
) -> web.StreamResponse:
stats: WebStats = request.app["stats"]
stats.active_proxy_websocket_handlers.add(asyncio.current_task()) # type: ignore
path = request.match_info["path"]
session = await get_session(request)
app = request.query.get("app")

# Choose a specific Manager endpoint for persistent web app connection.
api_endpoint = None
should_save_session = False
configured_endpoints = request.app["config"]["api"]["endpoint"]
if session.get("api_endpoints", {}).get(app):
Expand All @@ -369,15 +352,7 @@ async def websocket_handler(request, *, is_anonymous=False) -> web.StreamRespons
session["api_endpoints"][app] = str(api_endpoint)
should_save_session = True

proxy_path, _, real_path = request.path.lstrip("/").partition("/")
if proxy_path == "pipeline":
pipeline_config = request.app["config"]["pipeline"]
if not pipeline_config:
raise RuntimeError("'pipeline' config must be set to handle pipeline requests.")
endpoint = pipeline_config["endpoint"].with_scheme("ws")
log.info(f"WEBSOCKET_HANDLER {request.path} -> {endpoint}/{real_path}")
api_session = await asyncio.shield(get_anonymous_session(request, endpoint))
elif is_anonymous:
if is_anonymous:
api_session = await asyncio.shield(get_anonymous_session(request, api_endpoint))
else:
api_session = await asyncio.shield(get_api_session(request, api_endpoint))
Expand Down
25 changes: 19 additions & 6 deletions src/ai/backend/web/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,18 @@ async def server_main(
anon_web_handler = partial(web_handler, is_anonymous=True)
anon_web_plugin_handler = partial(web_plugin_handler, is_anonymous=True)

pipeline_api_endpoint = config["pipeline"]["endpoint"]
pipeline_handler = partial(web_handler, is_anonymous=True, api_endpoint=pipeline_api_endpoint)
pipeline_login_handler = partial(
web_handler,
is_anonymous=False,
api_endpoint=pipeline_api_endpoint,
http_headers_to_forward_extra={"X-BackendAI-SessionID"},
)
pipeline_websocket_handler = partial(
websocket_handler, is_anonymous=True, api_endpoint=pipeline_api_endpoint.with_scheme("ws")
)

app.router.add_route("HEAD", "/func/{path:folders/_/tus/upload/.*$}", anon_web_plugin_handler)
app.router.add_route("PATCH", "/func/{path:folders/_/tus/upload/.*$}", anon_web_plugin_handler)
app.router.add_route(
Expand Down Expand Up @@ -688,12 +700,13 @@ async def server_main(
cors.add(app.router.add_route("POST", "/func/{path:.*$}", web_handler))
cors.add(app.router.add_route("PATCH", "/func/{path:.*$}", web_handler))
cors.add(app.router.add_route("DELETE", "/func/{path:.*$}", web_handler))
cors.add(app.router.add_route("GET", "/pipeline/{path:stream/.*$}", websocket_handler))
cors.add(app.router.add_route("GET", "/pipeline/{path:.*$}", web_handler))
cors.add(app.router.add_route("PUT", "/pipeline/{path:.*$}", web_handler))
cors.add(app.router.add_route("POST", "/pipeline/{path:.*$}", web_handler))
cors.add(app.router.add_route("PATCH", "/pipeline/{path:.*$}", web_handler))
cors.add(app.router.add_route("DELETE", "/pipeline/{path:.*$}", web_handler))
cors.add(app.router.add_route("GET", "/pipeline/{path:stream/.*$}", pipeline_websocket_handler))
cors.add(app.router.add_route("POST", "/pipeline/{path:login/$}", pipeline_login_handler))
cors.add(app.router.add_route("GET", "/pipeline/{path:.*$}", pipeline_handler))
cors.add(app.router.add_route("PUT", "/pipeline/{path:.*$}", pipeline_handler))
cors.add(app.router.add_route("POST", "/pipeline/{path:.*$}", pipeline_handler))
cors.add(app.router.add_route("PATCH", "/pipeline/{path:.*$}", pipeline_handler))
cors.add(app.router.add_route("DELETE", "/pipeline/{path:.*$}", pipeline_handler))
if config["service"]["mode"] == "webui":
cors.add(app.router.add_route("GET", "/config.ini", config_ini_handler))
cors.add(app.router.add_route("GET", "/config.toml", config_toml_handler))
Expand Down
Loading