Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async Redis #3618

Merged
merged 4 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/onyx/configs/app_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@

REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"


# Rate limiting for auth endpoints
RATE_LIMIT_WINDOW_SECONDS: int | None = None
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
Expand All @@ -213,6 +212,7 @@
except ValueError:
pass

# Auth rate limiting is active only when both the max-request count and the
# window length are configured. Coerce to a real bool so the flag is never an
# int or None (``a and b`` returns one of its operands, not a boolean).
AUTH_RATE_LIMITING_ENABLED = bool(
    RATE_LIMIT_MAX_REQUESTS and RATE_LIMIT_WINDOW_SECONDS
)

# Used for general redis things
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))

Expand Down
15 changes: 8 additions & 7 deletions backend/onyx/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from onyx.configs.app_configs import APP_API_PREFIX
from onyx.configs.app_configs import APP_HOST
from onyx.configs.app_configs import APP_PORT
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
Expand Down Expand Up @@ -74,9 +75,9 @@
from onyx.server.manage.slack_bot import router as slack_bot_management_router
from onyx.server.manage.users import router as user_router
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
from onyx.server.middleware.rate_limiting import close_limiter
from onyx.server.middleware.rate_limiting import close_auth_limiter
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
from onyx.server.middleware.rate_limiting import setup_limiter
from onyx.server.middleware.rate_limiting import setup_auth_limiter
from onyx.server.onyx_api.ingestion import router as onyx_api_router
from onyx.server.openai_assistants_api.full_openai_assistants_api import (
get_full_openai_assistants_api_router,
Expand Down Expand Up @@ -174,7 +175,7 @@ def include_auth_router_with_prefix(


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# Set recursion limit
if SYSTEM_RECURSION_LIMIT is not None:
sys.setrecursionlimit(SYSTEM_RECURSION_LIMIT)
Expand Down Expand Up @@ -215,13 +216,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:

optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})

# Set up rate limiter
await setup_limiter()
if AUTH_RATE_LIMITING_ENABLED:
await setup_auth_limiter()

yield

# Close rate limiter
await close_limiter()
if AUTH_RATE_LIMITING_ENABLED:
await close_auth_limiter()


def log_http_error(_: Request, exc: Exception) -> JSONResponse:
Expand Down
55 changes: 42 additions & 13 deletions backend/onyx/redis/redis_pool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import functools
import json
import ssl
import threading
from collections.abc import Callable
from typing import Any
Expand Down Expand Up @@ -194,10 +195,6 @@ def create_pool(
redis_pool = RedisPool()


def get_redis_client(*, tenant_id: str | None) -> Redis:
return redis_pool.get_client(tenant_id)


# # Usage example
# redis_pool = RedisPool()
# redis_client = redis_pool.get_client()
Expand All @@ -207,6 +204,18 @@ def get_redis_client(*, tenant_id: str | None) -> Redis:
# value = redis_client.get('key')
# print(value.decode()) # Output: 'value'


def get_redis_client(*, tenant_id: str | None) -> Redis:
    """Return the client that the module-level ``redis_pool`` hands out for ``tenant_id``.

    ``tenant_id`` is keyword-only and is forwarded unchanged (including
    ``None``) to ``redis_pool.get_client``.
    """
    client = redis_pool.get_client(tenant_id)
    return client


# Lookup table from a certificate-requirement setting string to the matching
# ``ssl.CERT_*`` verification-mode constant.
SSL_CERT_REQS_MAP: dict[str, ssl.VerifyMode] = dict(
    none=ssl.CERT_NONE,
    optional=ssl.CERT_OPTIONAL,
    required=ssl.CERT_REQUIRED,
)


_async_redis_connection: aioredis.Redis | None = None
_async_lock = asyncio.Lock()

Expand All @@ -224,15 +233,35 @@ async def get_async_redis_connection() -> aioredis.Redis:
async with _async_lock:
# Double-check inside the lock to avoid race conditions
if _async_redis_connection is None:
scheme = "rediss" if REDIS_SSL else "redis"
url = f"{scheme}://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER}"

# Create a new Redis connection (or connection pool) from the URL
_async_redis_connection = aioredis.from_url(
url,
password=REDIS_PASSWORD,
max_connections=REDIS_POOL_MAX_CONNECTIONS,
)
# Build the connection settings from the Redis config constants

connection_kwargs: dict[str, Any] = {
"host": REDIS_HOST,
"port": REDIS_PORT,
"db": REDIS_DB_NUMBER,
"password": REDIS_PASSWORD,
"max_connections": REDIS_POOL_MAX_CONNECTIONS,
"health_check_interval": REDIS_HEALTH_CHECK_INTERVAL,
"socket_keepalive": True,
"socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS,
}

if REDIS_SSL:
ssl_context = ssl.create_default_context()

if REDIS_SSL_CA_CERTS:
ssl_context.load_verify_locations(REDIS_SSL_CA_CERTS)
ssl_context.check_hostname = False

# Map the REDIS_SSL_CERT_REQS string to the matching ssl.CERT_* constant (falls back to CERT_NONE)
ssl_context.verify_mode = SSL_CERT_REQS_MAP.get(
REDIS_SSL_CERT_REQS, ssl.CERT_NONE
)

connection_kwargs["ssl"] = ssl_context

# Create a new Redis connection (or connection pool) with SSL configuration
_async_redis_connection = aioredis.Redis(**connection_kwargs)

# Return the established connection (or pool) for all future operations
return _async_redis_connection
Expand Down
11 changes: 6 additions & 5 deletions backend/onyx/server/middleware/rate_limiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,19 @@
from fastapi_limiter import FastAPILimiter
from fastapi_limiter.depends import RateLimiter

from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
from onyx.configs.app_configs import RATE_LIMIT_MAX_REQUESTS
from onyx.configs.app_configs import RATE_LIMIT_WINDOW_SECONDS
from onyx.redis.redis_pool import get_async_redis_connection


async def setup_limiter() -> None:
async def setup_auth_limiter() -> None:
# Use the centralized async Redis connection
redis = await get_async_redis_connection()
await FastAPILimiter.init(redis)


async def close_limiter() -> None:
async def close_auth_limiter() -> None:
# This closes the FastAPILimiter connection so we don't leave open connections to Redis.
await FastAPILimiter.close()

Expand All @@ -32,14 +33,14 @@ async def rate_limit_key(request: Request) -> str:


def get_auth_rate_limiters() -> List[Callable]:
if not (RATE_LIMIT_MAX_REQUESTS and RATE_LIMIT_WINDOW_SECONDS):
if not AUTH_RATE_LIMITING_ENABLED:
return []

return [
Depends(
RateLimiter(
times=RATE_LIMIT_MAX_REQUESTS,
seconds=RATE_LIMIT_WINDOW_SECONDS,
times=RATE_LIMIT_MAX_REQUESTS or 100,
seconds=RATE_LIMIT_WINDOW_SECONDS or 60,
# Use the custom key function to distinguish users
identifier=rate_limit_key,
)
Expand Down
Loading