diff --git a/backend/ee/onyx/server/middleware/tenant_tracking.py b/backend/ee/onyx/server/middleware/tenant_tracking.py index 8729c12418f..528219e3b8c 100644 --- a/backend/ee/onyx/server/middleware/tenant_tracking.py +++ b/backend/ee/onyx/server/middleware/tenant_tracking.py @@ -2,15 +2,14 @@ from collections.abc import Awaitable from collections.abc import Callable -import jwt from fastapi import FastAPI from fastapi import HTTPException from fastapi import Request from fastapi import Response from onyx.auth.api_key import extract_tenant_from_api_key_header -from onyx.configs.app_configs import USER_AUTH_SECRET from onyx.db.engine import is_valid_schema_name +from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR @@ -22,11 +21,11 @@ async def set_tenant_id( request: Request, call_next: Callable[[Request], Awaitable[Response]] ) -> Response: try: - tenant_id = ( - _get_tenant_id_from_request(request, logger) - if MULTI_TENANT - else POSTGRES_DEFAULT_SCHEMA - ) + if MULTI_TENANT: + tenant_id = await _get_tenant_id_from_request(request, logger) + else: + tenant_id = POSTGRES_DEFAULT_SCHEMA + CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) return await call_next(request) @@ -35,27 +34,36 @@ async def set_tenant_id( raise -def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter) -> str: - # First check for API key +async def _get_tenant_id_from_request( + request: Request, logger: logging.LoggerAdapter +) -> str: + """ + Attempt to extract tenant_id from: + 1) The API key header + 2) The Redis-based token (stored in Cookie: fastapiusersauth) + Fallback: POSTGRES_DEFAULT_SCHEMA + """ + # Check for API key tenant_id = extract_tenant_from_api_key_header(request) - if tenant_id is not None: + if tenant_id: return tenant_id - # Check for cookie-based auth - token = request.cookies.get("fastapiusersauth") - if not token: - return POSTGRES_DEFAULT_SCHEMA - try: - payload = jwt.decode( - token, - USER_AUTH_SECRET, - audience=["fastapi-users:auth"], - algorithms=["HS256"], - ) - tenant_id_from_payload = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA) + # Look up token data in Redis + token_data = await retrieve_auth_token_data_from_redis(request) - # Since payload.get() can return None, ensure we have a string + if not token_data: + logger.debug( + "Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA" + ) + # Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema + # The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA, + # so we maintain consistency by returning it here when no valid tenant is found. + return POSTGRES_DEFAULT_SCHEMA + + tenant_id_from_payload = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA) + + # Since token_data.get() can return None, ensure we have a string tenant_id = ( str(tenant_id_from_payload) if tenant_id_from_payload is not None @@ -67,9 +75,6 @@ def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter) return tenant_id - except jwt.InvalidTokenError: - return POSTGRES_DEFAULT_SCHEMA - except Exception as e: - logger.error(f"Unexpected error in set_tenant_id_middleware: {str(e)}") + logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}") raise HTTPException(status_code=500, detail="Internal server error") diff --git a/backend/ee/onyx/server/tenants/api.py b/backend/ee/onyx/server/tenants/api.py index 3d646bbb151..f4438523977 100644 --- a/backend/ee/onyx/server/tenants/api.py +++ b/backend/ee/onyx/server/tenants/api.py @@ -19,7 +19,7 @@ from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant from onyx.auth.users import auth_backend from onyx.auth.users import current_admin_user -from onyx.auth.users import get_jwt_strategy +from onyx.auth.users import get_redis_strategy from onyx.auth.users import User from onyx.configs.app_configs import WEB_DOMAIN from onyx.db.auth import get_user_count @@ -112,7 +112,7 @@ async def impersonate_user( ) if user_to_impersonate is None: raise HTTPException(status_code=404, detail="User not found") - token = await get_jwt_strategy().write_token(user_to_impersonate) + token = await get_redis_strategy().write_token(user_to_impersonate) response = await auth_backend.transport.get_login_response(token) response.set_cookie( diff --git a/backend/onyx/auth/users.py b/backend/onyx/auth/users.py index e1a5885c2b2..fb3e5e8eb65 100644 --- a/backend/onyx/auth/users.py +++ b/backend/onyx/auth/users.py @@ -1,3 +1,5 @@ +import json +import secrets import uuid from collections.abc import AsyncGenerator from datetime import datetime @@ -29,10 +31,8 @@ from fastapi_users import UUIDIDMixin from fastapi_users.authentication import AuthenticationBackend from fastapi_users.authentication import CookieTransport -from fastapi_users.authentication import JWTStrategy +from fastapi_users.authentication import RedisStrategy from fastapi_users.authentication import Strategy -from fastapi_users.authentication.strategy.db import AccessTokenDatabase -from fastapi_users.authentication.strategy.db import DatabaseStrategy from fastapi_users.exceptions import UserAlreadyExists from fastapi_users.jwt import decode_jwt from fastapi_users.jwt import generate_jwt @@ -59,6 +59,8 @@ from onyx.configs.app_configs import AUTH_TYPE from onyx.configs.app_configs import DISABLE_AUTH from onyx.configs.app_configs import EMAIL_CONFIGURED +from onyx.configs.app_configs import REDIS_AUTH_EXPIRE_TIME_SECONDS +from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY @@ -73,7 +75,6 @@ from onyx.configs.constants import PASSWORD_SPECIAL_CHARS from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER from onyx.db.api_key import fetch_user_for_api_key -from onyx.db.auth import get_access_token_db from onyx.db.auth import get_default_admin_user_emails from onyx.db.auth import get_user_count from onyx.db.auth import get_user_db @@ -81,10 +82,10 @@ from onyx.db.engine import get_async_session from onyx.db.engine import get_async_session_with_tenant from onyx.db.engine import get_session_with_tenant -from onyx.db.models import AccessToken from onyx.db.models import OAuthAccount from onyx.db.models import User from onyx.db.users import get_user_by_email +from onyx.redis.redis_pool import get_async_redis_connection from onyx.redis.redis_pool import get_redis_client from onyx.utils.logger import setup_logger from onyx.utils.telemetry import create_milestone_and_report @@ -581,49 +582,70 @@ async def get_user_manager( ) -# This strategy is used to add tenant_id to the JWT token -class TenantAwareJWTStrategy(JWTStrategy): - async def _create_token_data(self, user: User, impersonate: bool = False) -> dict: +def get_redis_strategy() -> RedisStrategy: + return TenantAwareRedisStrategy() + + +class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]): + """ + A custom strategy that fetches the actual async Redis connection inside each method. + We do NOT pass a synchronous or "coroutine" redis object to the constructor. + """ + + def __init__( + self, + lifetime_seconds: Optional[int] = REDIS_AUTH_EXPIRE_TIME_SECONDS, + key_prefix: str = REDIS_AUTH_KEY_PREFIX, + ): + self.lifetime_seconds = lifetime_seconds + self.key_prefix = key_prefix + + async def write_token(self, user: User) -> str: + redis = await get_async_redis_connection() + tenant_id = await fetch_ee_implementation_or_noop( "onyx.server.tenants.provisioning", "get_or_provision_tenant", async_return_default_schema, - )( - email=user.email, - ) + )(email=user.email) - data = { + token_data = { "sub": str(user.id), - "aud": self.token_audience, "tenant_id": tenant_id, } - return data - - async def write_token(self, user: User) -> str: - data = await self._create_token_data(user) - return generate_jwt( - data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm + token = secrets.token_urlsafe() + await redis.set( + f"{self.key_prefix}{token}", + json.dumps(token_data), + ex=self.lifetime_seconds, ) + return token + async def read_token( + self, token: Optional[str], user_manager: BaseUserManager[User, uuid.UUID] + ) -> Optional[User]: + redis = await get_async_redis_connection() + token_data_str = await redis.get(f"{self.key_prefix}{token}") + if not token_data_str: + return None -def get_jwt_strategy() -> TenantAwareJWTStrategy: - return TenantAwareJWTStrategy( - secret=USER_AUTH_SECRET, - lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS, - ) - + try: + token_data = json.loads(token_data_str) + user_id = token_data["sub"] + parsed_id = user_manager.parse_id(user_id) + return await user_manager.get(parsed_id) + except (exceptions.UserNotExists, exceptions.InvalidID, KeyError): + return None -def get_database_strategy( - access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db), -) -> DatabaseStrategy: - return DatabaseStrategy( - access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore - ) + async def destroy_token(self, token: str, user: User) -> None: + """Properly delete the token from async redis.""" + redis = await get_async_redis_connection() + await redis.delete(f"{self.key_prefix}{token}") auth_backend = AuthenticationBackend( - name="jwt", transport=cookie_transport, get_strategy=get_jwt_strategy -) # type: ignore + name="redis", transport=cookie_transport, get_strategy=get_redis_strategy +) class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]): diff --git a/backend/onyx/configs/app_configs.py b/backend/onyx/configs/app_configs.py index 21b74153ed6..0a5d4789995 100644 --- a/backend/onyx/configs/app_configs.py +++ b/backend/onyx/configs/app_configs.py @@ -54,6 +54,10 @@ os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false" ) +REDIS_AUTH_EXPIRE_TIME_SECONDS = int( + os.environ.get("REDIS_AUTH_EXPIRE_TIME_SECONDS") or 3600 +) + SESSION_EXPIRE_TIME_SECONDS = int( os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7 ) # 7 days @@ -188,9 +192,11 @@ REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379)) REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or "" -# Rate limiting for auth endpoints +REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:" + +# Rate limiting for auth endpoints RATE_LIMIT_WINDOW_SECONDS: int | None = None _rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS") if _rate_limit_window_seconds_str is not None: @@ -570,7 +576,6 @@ # JWT configuration JWT_ALGORITHM = "HS256" - ##### # API Key Configs ##### diff --git a/backend/onyx/db/engine.py b/backend/onyx/db/engine.py index 86519fa3492..4176b954205 100644 --- a/backend/onyx/db/engine.py +++ b/backend/onyx/db/engine.py @@ -1,4 +1,5 @@ import contextlib +import json import os import re import ssl @@ -14,7 +15,6 @@ import asyncpg # type: ignore import boto3 -import jwt from fastapi import HTTPException from fastapi import Request from sqlalchemy import event @@ -40,9 +40,9 @@ from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE from onyx.configs.app_configs import POSTGRES_PORT from onyx.configs.app_configs import POSTGRES_USER -from onyx.configs.app_configs import USER_AUTH_SECRET from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME from onyx.configs.constants import SSL_CERT_FILE +from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis from onyx.server.utils import BasicAuthenticationError from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT @@ -322,31 +322,33 @@ def provide_iam_token_async( return _ASYNC_ENGINE -def get_current_tenant_id(request: Request) -> str: +async def get_current_tenant_id(request: Request) -> str: if not MULTI_TENANT: tenant_id = POSTGRES_DEFAULT_SCHEMA CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) return tenant_id - token = request.cookies.get("fastapiusersauth") - if not token: - current_value = CURRENT_TENANT_ID_CONTEXTVAR.get() - return current_value - try: - payload = jwt.decode( - token, - USER_AUTH_SECRET, - audience=["fastapi-users:auth"], - algorithms=["HS256"], - ) - tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA) + # Look up token data in Redis + token_data = await retrieve_auth_token_data_from_redis(request) + + if not token_data: + current_value = CURRENT_TENANT_ID_CONTEXTVAR.get() + logger.debug( + f"Token data not found or expired in Redis, defaulting to {current_value}" + ) + return current_value + + tenant_id = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA) + if not is_valid_schema_name(tenant_id): raise HTTPException(status_code=400, detail="Invalid tenant ID format") + CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) return tenant_id - except jwt.InvalidTokenError: - return CURRENT_TENANT_ID_CONTEXTVAR.get() + except json.JSONDecodeError: + logger.error("Error decoding token data from Redis") + return POSTGRES_DEFAULT_SCHEMA except Exception as e: logger.error(f"Unexpected error in get_current_tenant_id: {str(e)}") raise HTTPException(status_code=500, detail="Internal server error") diff --git a/backend/onyx/redis/redis_pool.py b/backend/onyx/redis/redis_pool.py index c118d410dcd..6493f7927f0 100644 --- a/backend/onyx/redis/redis_pool.py +++ b/backend/onyx/redis/redis_pool.py @@ -1,14 +1,17 @@ import asyncio import functools +import json import threading from collections.abc import Callable from typing import Any from typing import Optional import redis +from fastapi import Request from redis import asyncio as aioredis from redis.client import Redis +from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX from onyx.configs.app_configs import REDIS_DB_NUMBER from onyx.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL from onyx.configs.app_configs import REDIS_HOST @@ -228,3 +231,31 @@ async def get_async_redis_connection() -> aioredis.Redis: # Return the established connection (or pool) for all future operations return _async_redis_connection + + +async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None: + token = request.cookies.get("fastapiusersauth") + if not token: + logger.debug("No auth token cookie found") + return None + + try: + redis = await get_async_redis_connection() + redis_key = REDIS_AUTH_KEY_PREFIX + token + token_data_str = await redis.get(redis_key) + + if not token_data_str: + logger.debug(f"Token key {redis_key} not found or expired in Redis") + return None + + return json.loads(token_data_str) + except json.JSONDecodeError: + logger.error("Error decoding token data from Redis") + return None + except Exception as e: + logger.error( + f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}" + ) + raise ValueError( + f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}" + ) diff --git a/web/src/app/admin/settings/SettingsForm.tsx b/web/src/app/admin/settings/SettingsForm.tsx index df6889012c8..aed6f004741 100644 --- a/web/src/app/admin/settings/SettingsForm.tsx +++ b/web/src/app/admin/settings/SettingsForm.tsx @@ -11,6 +11,7 @@ import React, { useContext, useState, useEffect } from "react"; import { SettingsContext } from "@/components/settings/SettingsProvider"; import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled"; import { Modal } from "@/components/Modal"; +import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants"; export function Checkbox({ label, @@ -218,14 +219,19 @@ export function SettingsForm() { handleToggleSettingsField("auto_scroll", e.target.checked) } /> - - handleToggleSettingsField("anonymous_user_enabled", e.target.checked) - } - /> + {!NEXT_PUBLIC_CLOUD_ENABLED && ( + + handleToggleSettingsField( + "anonymous_user_enabled", + e.target.checked + ) + } + /> + )} {showConfirmModal && ( or
+
- - Create an account - - {NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED && ( => { }; export const logout = async (): Promise => { - const response = await fetch("/auth/logout", { + const response = await fetch("/api/auth/logout", { method: "POST", credentials: "include", });