Skip to content

Commit

Permalink
Merge pull request #1696 from SciPhi-AI/feature/make-limits-configurable
Browse files Browse the repository at this point in the history
checkin limits implementation
  • Loading branch information
emrgnt-cmplxty authored Dec 16, 2024
2 parents 0bd70fa + 7f29118 commit 1ab9f90
Show file tree
Hide file tree
Showing 10 changed files with 138 additions and 163 deletions.
49 changes: 49 additions & 0 deletions py/core/base/providers/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,21 @@ class PostgresConfigurationSettings(BaseModel):
work_mem: Optional[int] = 4096


class LimitSettings(BaseModel):
    """Rate-limit thresholds for one configuration level.

    Every field is optional. ``None`` means "not set at this level" and is
    the only value that ``merge_with_defaults`` replaces with a fallback.
    """

    # Per-minute request cap counted across all routes.
    global_per_min: Optional[int] = None
    # Per-minute request cap counted for a single route.
    route_per_min: Optional[int] = None
    # Request cap counted since the start of the current month.
    monthly_limit: Optional[int] = None

    def merge_with_defaults(
        self, defaults: "LimitSettings"
    ) -> "LimitSettings":
        """Return a new ``LimitSettings`` with unset fields taken from ``defaults``.

        Args:
            defaults: Settings consulted for every field that is ``None``
                on this instance.

        Returns:
            A fresh ``LimitSettings``; neither ``self`` nor ``defaults`` is
            modified.
        """
        # Compare against None explicitly: an `or` chain would treat an
        # explicit limit of 0 as "unset" and silently swap in the default.
        return LimitSettings(
            global_per_min=(
                self.global_per_min
                if self.global_per_min is not None
                else defaults.global_per_min
            ),
            route_per_min=(
                self.route_per_min
                if self.route_per_min is not None
                else defaults.route_per_min
            ),
            monthly_limit=(
                self.monthly_limit
                if self.monthly_limit is not None
                else defaults.monthly_limit
            ),
        )


class DatabaseConfig(ProviderConfig):
"""A base database configuration class"""

Expand Down Expand Up @@ -163,6 +178,13 @@ class DatabaseConfig(ProviderConfig):
)
graph_search_settings: GraphSearchSettings = GraphSearchSettings()

# Rate limits
limits: LimitSettings = LimitSettings(
global_per_min=60, route_per_min=20, monthly_limit=10000
)
route_limits: dict[str, LimitSettings] = {}
user_limits: dict[UUID, dict[str, LimitSettings]] = {}

def __post_init__(self):
self.validate_config()
# Capture additional fields
Expand All @@ -177,6 +199,33 @@ def validate_config(self) -> None:
def supported_providers(self) -> list[str]:
    """Return the provider names accepted by this config (only "postgres")."""
    return ["postgres"]

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "DatabaseConfig":
    """Create a config from ``data`` and apply its rate-limit section.

    The base instance comes from the parent ``from_dict``. Its ``limits``
    attribute is then replaced with values read from ``data["limits"]``,
    and every entry under ``data["limits"]["routes"]`` is added to
    ``route_limits``.

    Args:
        data: Raw configuration mapping.

    Returns:
        The populated configuration instance.
    """
    config = super().from_dict(data)

    limits_section = data.get("limits", {})

    # Global limits: a key missing from the section falls back to the
    # literal default given here.
    config.limits = LimitSettings(
        global_per_min=limits_section.get("global_per_min", 60),
        route_per_min=limits_section.get("route_per_min", 20),
        monthly_limit=limits_section.get("monthly_limit", 10000),
    )

    # Per-route overrides, keyed by the route string.
    per_route = limits_section.get("routes", {})
    for route_name, route_settings in per_route.items():
        config.route_limits[route_name] = LimitSettings(**route_settings)

    # NOTE: per-user overrides (a "users" section) are not parsed here;
    # ``user_limits`` keeps whatever the base instance already holds.
    return config


class DatabaseProvider(Provider):
connection_manager: DatabaseConnectionManager
Expand Down
2 changes: 1 addition & 1 deletion py/core/database/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from fastapi import HTTPException

from core.base import (
Handler,
DatabaseConfig,
Handler,
KGExtractionStatus,
R2RException,
generate_default_user_collection_id,
Expand Down
2 changes: 1 addition & 1 deletion py/core/database/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from fastapi import HTTPException

from core.base import (
Handler,
DocumentResponse,
DocumentType,
Handler,
IngestionStatus,
KGEnrichmentStatus,
KGExtractionStatus,
Expand Down
218 changes: 67 additions & 151 deletions py/core/database/limits.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional
from typing import Dict, Optional
from uuid import UUID

from core.base import Handler, R2RException
from core.base import Handler

from ..base.providers.database import DatabaseConfig, LimitSettings
from .base import PostgresConnectionManager

logger = logging.getLogger()
logger = logging.getLogger(__name__)


class PostgresLimitsHandler(Handler):
Expand All @@ -17,10 +18,10 @@ def __init__(
self,
project_name: str,
connection_manager: PostgresConnectionManager,
route_limits: dict,
config: DatabaseConfig,
):
super().__init__(project_name, connection_manager)
self.route_limits = route_limits
self.config = config

async def create_tables(self):
query = f"""
Expand Down Expand Up @@ -54,176 +55,91 @@ async def _count_requests(
params = [user_id, since]

result = await self.connection_manager.fetchrow_query(query, params)
return result["count"] if result else 0
count = result["count"] if result else 0
logger.debug(
f"_count_requests(user_id={user_id}, route={route}, since={since.isoformat()}): {count}"
)
return count

async def _count_monthly_requests(self, user_id: UUID) -> int:
now = datetime.now(timezone.utc)
start_of_month = now.replace(
day=1, hour=0, minute=0, second=0, microsecond=0
)
return await self._count_requests(
count = await self._count_requests(
user_id, route=None, since=start_of_month
)
return count

def _determine_limits_for(
self, user_id: UUID, route: str
) -> LimitSettings:
limits = self.config.limits

# Route-specific limits
route_limits = self.config.route_limits.get(route)
if route_limits:
limits = limits.merge_with_defaults(route_limits)

# User-specific limits
user_limits = self.config.user_limits.get(user_id)
if user_limits:
limits = limits.merge_with_defaults(user_limits)
return limits

async def check_limits(self, user_id: UUID, route: str):
limits = self.route_limits.get(
route,
{
"global_per_min": 60,
"route_per_min": 30,
"monthly_limit": 10000,
},
)
# Determine final applicable limits
limits = self._determine_limits_for(user_id, route)
if not limits:
# If no limits found, use defaults
limits = self.config.default_limits

global_per_min = limits["global_per_min"]
route_per_min = limits["route_per_min"]
monthly_limit = limits["monthly_limit"]
global_per_min = limits.global_per_min
route_per_min = limits.route_per_min
monthly_limit = limits.monthly_limit

now = datetime.now(timezone.utc)
one_min_ago = now - timedelta(minutes=1)

logger.info(
f"Checking limits for user_id={user_id}, route={route}, "
f"global_per_min={global_per_min}, route_per_min={route_per_min}, monthly_limit={monthly_limit}, now={now.isoformat()}"
)

# Global per-minute check
user_req_count = await self._count_requests(user_id, None, one_min_ago)
print("min req count = ", user_req_count)
if user_req_count >= global_per_min:
raise ValueError("Global per-minute rate limit exceeded")
if global_per_min is not None:
user_req_count = await self._count_requests(
user_id, None, one_min_ago
)
if user_req_count >= global_per_min:
logger.warning(
f"Global per-minute limit exceeded for user_id={user_id}, route={route}"
)
raise ValueError("Global per-minute rate limit exceeded")

# Per-route per-minute check
route_req_count = await self._count_requests(
user_id, route, one_min_ago
)
if route_req_count >= route_per_min:
raise ValueError("Per-route per-minute rate limit exceeded")
if route_per_min is not None:
route_req_count = await self._count_requests(
user_id, route, one_min_ago
)
if route_req_count >= route_per_min:
logger.warning(
f"Per-route per-minute limit exceeded for user_id={user_id}, route={route}"
)
raise ValueError("Per-route per-minute rate limit exceeded")

# Monthly limit check
monthly_count = await self._count_monthly_requests(user_id)
print("monthly_count = ", monthly_count)

if monthly_count >= monthly_limit:
raise ValueError("Monthly rate limit exceeded")
if monthly_limit is not None:
monthly_count = await self._count_monthly_requests(user_id)
if monthly_count >= monthly_limit:
logger.warning(
f"Monthly limit exceeded for user_id={user_id}, route={route}"
)
raise ValueError("Monthly rate limit exceeded")

async def log_request(self, user_id: UUID, route: str):
    """Insert one row for this request into the request-log table.

    The row's ``time`` column is filled in by the database
    (``CURRENT_TIMESTAMP AT TIME ZONE 'UTC'``); ``user_id`` and ``route``
    are passed as bound parameters.
    """
    bound_values = [user_id, route]
    insert_sql = f"""
INSERT INTO {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (time, user_id, route)
VALUES (CURRENT_TIMESTAMP AT TIME ZONE 'UTC', $1, $2)
"""
    await self.connection_manager.execute_query(insert_sql, bound_values)


# import logging
# from datetime import datetime, timedelta
# from typing import Optional
# from uuid import UUID

# from core.base import Handler, R2RException

# from .base import PostgresConnectionManager

# logger = logging.getLogger()


# class PostgresLimitsHandler(Handler):
# TABLE_NAME = "request_log"

# def __init__(
# self,
# project_name: str,
# connection_manager: PostgresConnectionManager,
# route_limits: dict,
# ):
# super().__init__(project_name, connection_manager)
# self.route_limits = route_limits

# async def create_tables(self):
# """
# Create the request_log table if it doesn't exist.
# """
# query = f"""
# CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (
# time TIMESTAMPTZ NOT NULL,
# user_id UUID NOT NULL,
# route TEXT NOT NULL
# );
# """
# await self.connection_manager.execute_query(query)

# async def _count_requests(
# self, user_id: UUID, route: Optional[str], since: datetime
# ) -> int:
# if route:
# query = f"""
# SELECT COUNT(*)::int
# FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
# WHERE user_id = $1
# AND route = $2
# AND time >= $3
# """
# params = [user_id, route, since]
# else:
# query = f"""
# SELECT COUNT(*)::int
# FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
# WHERE user_id = $1
# AND time >= $2
# """
# params = [user_id, since]

# result = await self.connection_manager.fetchrow_query(query, params)
# return result["count"] if result else 0

# async def _count_monthly_requests(self, user_id: UUID) -> int:
# now = datetime.utcnow()
# start_of_month = now.replace(
# day=1, hour=0, minute=0, second=0, microsecond=0
# )
# return await self._count_requests(
# user_id, route=None, since=start_of_month
# )

# async def check_limits(self, user_id: UUID, route: str):
# """
# Check if the user can proceed with the request, using route-specific limits.
# Raises ValueError if the user exceeded any limit.
# """
# limits = self.route_limits.get(
# route,
# {
# "global_per_min": 60, # default global per min
# "route_per_min": 20, # default route per min
# "monthly_limit": 10000, # default monthly limit
# },
# )

# global_per_min = limits["global_per_min"]
# route_per_min = limits["route_per_min"]
# monthly_limit = limits["monthly_limit"]

# now = datetime.utcnow()
# one_min_ago = now - timedelta(minutes=1)

# # Global per-minute check
# user_req_count = await self._count_requests(user_id, None, one_min_ago)
# print('min req count = ', user_req_count)
# if user_req_count >= global_per_min:
# raise ValueError("Global per-minute rate limit exceeded")

# # Per-route per-minute check
# route_req_count = await self._count_requests(
# user_id, route, one_min_ago
# )
# if route_req_count >= route_per_min:
# raise ValueError("Per-route per-minute rate limit exceeded")

# # Monthly limit check
# monthly_count = await self._count_monthly_requests(user_id)
# print('monthly_count = ', monthly_count)

# if monthly_count >= monthly_limit:
# raise ValueError("Monthly rate limit exceeded")

# async def log_request(self, user_id: UUID, route: str):
# query = f"""
# INSERT INTO {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (time, user_id, route)
# VALUES (NOW(), $1, $2)
# """
# await self.connection_manager.execute_query(query, [user_id, route])
3 changes: 1 addition & 2 deletions py/core/database/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,7 @@ def __init__(
self.limits_handler = PostgresLimitsHandler(
project_name=self.project_name,
connection_manager=self.connection_manager,
# TODO - this should be set in the config
route_limits={},
config=self.config,
)

async def initialize(self):
Expand Down
2 changes: 0 additions & 2 deletions py/core/main/api/v3/base_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,10 @@ async def rate_limit_dependency(

request.state.user_id = user_id
request.state.route = route
print("in rate limit dependency....")
# Yield to run the route
try:
yield
finally:
print("finally....")
# After the route completes successfully, log the request
await self.providers.database.limits_handler.log_request(
user_id, route
Expand Down
6 changes: 3 additions & 3 deletions py/core/main/api/v3/users_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@ async def register(
auth_user=Depends(self.providers.auth.auth_wrapper),
) -> WrappedUserResponse:
"""Register a new user with the given email and password."""
print('email = ', email)
print('making request.....')
print("email = ", email)
print("making request.....")
registration_response = await self.services["auth"].register(
email, password
)
print('registration_response = ', registration_response)
print("registration_response = ", registration_response)

if name or bio or profile_picture:
return await self.services["auth"].update_user(
Expand Down
Loading

0 comments on commit 1ab9f90

Please sign in to comment.