Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JWT -> Redis #3574

Merged
merged 16 commits into from
Jan 4, 2025
59 changes: 32 additions & 27 deletions backend/ee/onyx/server/middleware/tenant_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
from collections.abc import Awaitable
from collections.abc import Callable

import jwt
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi import Request
from fastapi import Response

from onyx.auth.api_key import extract_tenant_from_api_key_header
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.db.engine import is_valid_schema_name
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
Expand All @@ -22,11 +21,11 @@ async def set_tenant_id(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
try:
tenant_id = (
_get_tenant_id_from_request(request, logger)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
if MULTI_TENANT:
tenant_id = await _get_tenant_id_from_request(request, logger)
else:
tenant_id = POSTGRES_DEFAULT_SCHEMA

CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
return await call_next(request)

Expand All @@ -35,27 +34,36 @@ async def set_tenant_id(
raise


def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter) -> str:
# First check for API key
async def _get_tenant_id_from_request(
request: Request, logger: logging.LoggerAdapter
) -> str:
"""
Attempt to extract tenant_id from:
1) The API key header
2) The Redis-based token (stored in Cookie: fastapiusersauth)
Fallback: POSTGRES_DEFAULT_SCHEMA
"""
# Check for API key
tenant_id = extract_tenant_from_api_key_header(request)
if tenant_id is not None:
if tenant_id:
return tenant_id

# Check for cookie-based auth
token = request.cookies.get("fastapiusersauth")
if not token:
return POSTGRES_DEFAULT_SCHEMA

try:
payload = jwt.decode(
token,
USER_AUTH_SECRET,
audience=["fastapi-users:auth"],
algorithms=["HS256"],
)
tenant_id_from_payload = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
# Look up token data in Redis
token_data = await retrieve_auth_token_data_from_redis(request)

# Since payload.get() can return None, ensure we have a string
if not token_data:
logger.debug(
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
)
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
# so we maintain consistency by returning it here when no valid tenant is found.
return POSTGRES_DEFAULT_SCHEMA
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are the implications of returning POSTGRES_DEFAULT_SCHEMA in the multi-tenant case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It results in an auth error (which is what we'd want!)


tenant_id_from_payload = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)

# Since token_data.get() can return None, ensure we have a string
tenant_id = (
str(tenant_id_from_payload)
if tenant_id_from_payload is not None
Expand All @@ -67,9 +75,6 @@ def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter)

return tenant_id

except jwt.InvalidTokenError:
return POSTGRES_DEFAULT_SCHEMA

except Exception as e:
logger.error(f"Unexpected error in set_tenant_id_middleware: {str(e)}")
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
4 changes: 2 additions & 2 deletions backend/ee/onyx/server/tenants/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import auth_backend
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_jwt_strategy
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.auth import get_user_count
Expand Down Expand Up @@ -112,7 +112,7 @@ async def impersonate_user(
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_jwt_strategy().write_token(user_to_impersonate)
token = await get_redis_strategy().write_token(user_to_impersonate)

response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
Expand Down
88 changes: 55 additions & 33 deletions backend/onyx/auth/users.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json
import secrets
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
Expand Down Expand Up @@ -29,10 +31,8 @@
from fastapi_users import UUIDIDMixin
from fastapi_users.authentication import AuthenticationBackend
from fastapi_users.authentication import CookieTransport
from fastapi_users.authentication import JWTStrategy
from fastapi_users.authentication import RedisStrategy
from fastapi_users.authentication import Strategy
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
from fastapi_users.authentication.strategy.db import DatabaseStrategy
from fastapi_users.exceptions import UserAlreadyExists
from fastapi_users.jwt import decode_jwt
from fastapi_users.jwt import generate_jwt
Expand All @@ -59,6 +59,8 @@
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import REDIS_AUTH_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
Expand All @@ -73,18 +75,17 @@
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
from onyx.db.api_key import fetch_user_for_api_key
from onyx.db.auth import get_access_token_db
from onyx.db.auth import get_default_admin_user_emails
from onyx.db.auth import get_user_count
from onyx.db.auth import get_user_db
from onyx.db.auth import SQLAlchemyUserAdminDB
from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_with_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import User
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
Expand Down Expand Up @@ -581,49 +582,70 @@ async def get_user_manager(
)


# This strategy is used to add tenant_id to the JWT token
class TenantAwareJWTStrategy(JWTStrategy):
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
def get_redis_strategy() -> RedisStrategy:
return TenantAwareRedisStrategy()


class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
"""
A custom strategy that fetches the actual async Redis connection inside each method.
We do NOT pass a synchronous or "coroutine" redis object to the constructor.
"""

def __init__(
self,
lifetime_seconds: Optional[int] = REDIS_AUTH_EXPIRE_TIME_SECONDS,
key_prefix: str = REDIS_AUTH_KEY_PREFIX,
):
self.lifetime_seconds = lifetime_seconds
self.key_prefix = key_prefix

async def write_token(self, user: User) -> str:
redis = await get_async_redis_connection()

tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
)(
email=user.email,
)
)(email=user.email)

data = {
token_data = {
"sub": str(user.id),
"aud": self.token_audience,
"tenant_id": tenant_id,
}
return data

async def write_token(self, user: User) -> str:
data = await self._create_token_data(user)
return generate_jwt(
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
token = secrets.token_urlsafe()
await redis.set(
f"{self.key_prefix}{token}",
json.dumps(token_data),
ex=self.lifetime_seconds,
)
return token

async def read_token(
self, token: Optional[str], user_manager: BaseUserManager[User, uuid.UUID]
) -> Optional[User]:
redis = await get_async_redis_connection()
token_data_str = await redis.get(f"{self.key_prefix}{token}")
if not token_data_str:
return None

def get_jwt_strategy() -> TenantAwareJWTStrategy:
return TenantAwareJWTStrategy(
secret=USER_AUTH_SECRET,
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
)

try:
token_data = json.loads(token_data_str)
user_id = token_data["sub"]
parsed_id = user_manager.parse_id(user_id)
return await user_manager.get(parsed_id)
except (exceptions.UserNotExists, exceptions.InvalidID, KeyError):
return None

def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
return DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
)
async def destroy_token(self, token: str, user: User) -> None:
"""Properly delete the token from async redis."""
redis = await get_async_redis_connection()
await redis.delete(f"{self.key_prefix}{token}")


auth_backend = AuthenticationBackend(
name="jwt", transport=cookie_transport, get_strategy=get_jwt_strategy
) # type: ignore
name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
)


class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
Expand Down
9 changes: 7 additions & 2 deletions backend/onyx/configs/app_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
)

REDIS_AUTH_EXPIRE_TIME_SECONDS = int(
os.environ.get("REDIS_AUTH_EXPIRE_TIME_SECONDS") or 3600
)

SESSION_EXPIRE_TIME_SECONDS = int(
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days
Expand Down Expand Up @@ -188,9 +192,11 @@
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""

# Rate limiting for auth endpoints

REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"


# Rate limiting for auth endpoints
RATE_LIMIT_WINDOW_SECONDS: int | None = None
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
if _rate_limit_window_seconds_str is not None:
Expand Down Expand Up @@ -570,7 +576,6 @@
# JWT configuration
JWT_ALGORITHM = "HS256"


#####
# API Key Configs
#####
Expand Down
36 changes: 19 additions & 17 deletions backend/onyx/db/engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
import json
import os
import re
import ssl
Expand All @@ -14,7 +15,6 @@

import asyncpg # type: ignore
import boto3
import jwt
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import event
Expand All @@ -40,9 +40,9 @@
from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from onyx.configs.constants import SSL_CERT_FILE
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
Expand Down Expand Up @@ -322,31 +322,33 @@ def provide_iam_token_async(
return _ASYNC_ENGINE


def get_current_tenant_id(request: Request) -> str:
async def get_current_tenant_id(request: Request) -> str:
if not MULTI_TENANT:
tenant_id = POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
return tenant_id

token = request.cookies.get("fastapiusersauth")
if not token:
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
return current_value

try:
payload = jwt.decode(
token,
USER_AUTH_SECRET,
audience=["fastapi-users:auth"],
algorithms=["HS256"],
)
tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
# Look up token data in Redis
token_data = await retrieve_auth_token_data_from_redis(request)

if not token_data:
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
logger.debug(
f"Token data not found or expired in Redis, defaulting to {current_value}"
)
return current_value

tenant_id = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)

if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID format")

CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
return tenant_id
except jwt.InvalidTokenError:
return CURRENT_TENANT_ID_CONTEXTVAR.get()
except json.JSONDecodeError:
logger.error("Error decoding token data from Redis")
return POSTGRES_DEFAULT_SCHEMA
except Exception as e:
logger.error(f"Unexpected error in get_current_tenant_id: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
Expand Down
Loading
Loading