Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add authentication headers when proxying pipeline websocket requests #1457

Closed
wants to merge 10 commits into from
1 change: 1 addition & 0 deletions changes/1457.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `X-BackendAI-SessionID` and `X-BackendAI-SSO` headers to the pipeline websocket proxy handler.
2 changes: 1 addition & 1 deletion src/ai/backend/manager/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2560,7 +2560,7 @@ async def _restart_kernel(kernel: KernelRow) -> None:
keepalive_timeout=self.rpc_keepalive_timeout,
) as rpc:
updated_config: Dict[str, Any] = {
# TODO: support resacling of sub-containers
# TODO: support rescaling of sub-containers
}
kernel_info = await rpc.call.restart_kernel(
str(kernel.session_id),
Expand Down
7 changes: 6 additions & 1 deletion src/ai/backend/web/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,12 @@
t.Key("pipeline"): t.Dict(
{
t.Key("endpoint", default=None): t.Null | tx.URL,
}
t.Key("jwt"): t.Dict(
{
t.Key("secret"): t.String,
},
),
},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I am guessing correctly, this jwt directive is for the traffics from FastTrack. FastTrack is a private project and thus only limited number of users who operate Backend.AI webserver will be required to fill out this field. When taking account into this situation I don't think this whole pipeline configuration body should be a required config values. Please refactor the validator to either mark the pipeline configuration as an Optional value or at lease leave jwt as nullable and dynamically check its existence when the request actually comes from the FastTrack side.

).allow_extra("*"),
t.Key("ui"): t.Dict(
{
Expand Down
32 changes: 26 additions & 6 deletions src/ai/backend/web/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import json
import logging
import random
from datetime import datetime, timedelta, timezone
from datetime import datetime, timedelta
from typing import Optional, Tuple, Union, cast

import aiohttp
Expand Down Expand Up @@ -207,18 +207,18 @@ async def web_handler(request: web.Request, *, is_anonymous=False) -> web.Stream
aiohttp_session = request.cookies.get("AIOHTTP_SESSION")
if not (sso_token := request.headers.get("X-BackendAI-SSO")):
jwt_secret = request.app["config"]["pipeline"]["jwt"]["secret"]
now = datetime.now(tz=timezone(timedelta(hours=9)))
payload = {
now = datetime.now().astimezone()
claims = {
# Registered claims
"exp": now + timedelta(hours=1),
"exp": now + timedelta(days=1),
"iss": "Backend.AI Webserver",
"iat": now,
# Private claims
"aiohttp_session": aiohttp_session,
"access_key": api_session.config.access_key,
# "secret_key": api_session.config.secret_key,
}
sso_token = jwt.encode(payload, key=jwt_secret, algorithm="HS256")
sso_token = jwt.encode(claims, key=jwt_secret, algorithm="HS256")
api_rqst.headers["X-BackendAI-SSO"] = sso_token
if session_id := (request_headers.get("X-BackendAI-SessionID") or aiohttp_session):
api_rqst.headers["X-BackendAI-SessionID"] = session_id
Expand Down Expand Up @@ -382,13 +382,14 @@ async def websocket_handler(request, *, is_anonymous=False) -> web.StreamRespons
else:
endpoint = endpoint.with_scheme("ws")
log.info(f"WEBSOCKET_HANDLER {request.path} -> {endpoint}/{real_path}")
api_session = await asyncio.shield(get_anonymous_session(request, endpoint))
api_session = await asyncio.shield(get_api_session(request, endpoint))
elif is_anonymous:
api_session = await asyncio.shield(get_anonymous_session(request, api_endpoint))
else:
api_session = await asyncio.shield(get_api_session(request, api_endpoint))
try:
async with api_session:
request_headers = extra_config_headers.check(request.headers)
request_api_version = request.headers.get("X-BackendAI-Version", None)
fill_forwarding_hdrs_to_api_session(request, api_session)
api_rqst = Request(
Expand All @@ -399,6 +400,25 @@ async def websocket_handler(request, *, is_anonymous=False) -> web.StreamRespons
content_type=request.content_type,
override_api_version=request_api_version,
)
if proxy_path == "pipeline":
aiohttp_session = request.cookies.get("AIOHTTP_SESSION")
if not (sso_token := request.headers.get("X-BackendAI-SSO")):
jwt_secret = request.app["config"]["pipeline"]["jwt"]["secret"]
kyujin-cho marked this conversation as resolved.
Show resolved Hide resolved
now = datetime.now().astimezone()
claims = {
# Registered claims
"exp": now + timedelta(days=1),
"iss": "Backend.AI Webserver",
"iat": now,
# Private claims
"aiohttp_session": aiohttp_session,
"access_key": api_session.config.access_key,
# "secret_key": api_session.config.secret_key,
}
sso_token = jwt.encode(claims, key=jwt_secret, algorithm="HS256")
api_rqst.headers["X-BackendAI-SSO"] = sso_token
if session_id := (request_headers.get("X-BackendAI-SessionID") or aiohttp_session):
kyujin-cho marked this conversation as resolved.
Show resolved Hide resolved
api_rqst.headers["X-BackendAI-SessionID"] = session_id
async with api_rqst.connect_websocket() as up_conn:
down_conn = web.WebSocketResponse()
await down_conn.prepare(request)
Expand Down