Skip to content

Commit

Permalink
Simple initial rate limiting implementation (#4976)
Browse files Browse the repository at this point in the history
  • Loading branch information
rbren authored Nov 19, 2024
1 parent c9ed9b1 commit 3c61a95
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 1 deletion.
11 changes: 10 additions & 1 deletion openhands/server/listen.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,12 @@
from openhands.llm import bedrock
from openhands.runtime.base import Runtime
from openhands.server.auth.auth import get_sid_from_token, sign_token
from openhands.server.middleware import LocalhostCORSMiddleware, NoCacheMiddleware
from openhands.server.middleware import (
InMemoryRateLimiter,
LocalhostCORSMiddleware,
NoCacheMiddleware,
RateLimitMiddleware,
)
from openhands.server.session import SessionManager

load_dotenv()
Expand All @@ -84,6 +89,10 @@


# Register the HTTP middleware for the FastAPI app.
app.add_middleware(NoCacheMiddleware)
# Rate limiting backed by an in-process limiter configured with
# requests=2 per seconds=1; the limiter keys its history on the client host.
# NOTE(review): state lives in this process only -- presumably not shared
# across multiple workers/replicas; confirm against the deployment setup.
app.add_middleware(
RateLimitMiddleware, rate_limiter=InMemoryRateLimiter(requests=2, seconds=1)
)


security_scheme = HTTPBearer()

Expand Down
57 changes: 57 additions & 0 deletions openhands/server/middleware.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import asyncio
from collections import defaultdict
from datetime import datetime, timedelta
from urllib.parse import urlparse

from fastapi import Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

Expand Down Expand Up @@ -41,3 +46,55 @@ async def dispatch(self, request, call_next):
response.headers['Pragma'] = 'no-cache'
response.headers['Expires'] = '0'
return response


class InMemoryRateLimiter:
    """Per-client sliding-window rate limiter held in process memory.

    Every call records a timestamp for the calling client and then decides,
    based on how many timestamps fall inside the last ``seconds`` seconds:

    * up to ``requests`` hits: allowed immediately;
    * more than ``requests`` but at most ``2 * requests`` hits: the call is
      delayed by ``sleep_seconds`` and then allowed (or rejected outright
      when ``sleep_seconds`` is not positive);
    * more than ``2 * requests`` hits: rejected.

    Rejected and delayed calls are still recorded, so a client that keeps
    hammering the server stays over the limit.

    Args:
        requests: Number of hits allowed per window without any penalty.
        seconds: Length of the sliding window, in seconds.
        sleep_seconds: Delay applied to hits in the soft-limit band.
    """

    history: dict
    requests: int
    seconds: int
    sleep_seconds: int

    def __init__(self, requests: int = 2, seconds: int = 1, sleep_seconds: int = 1):
        self.requests = requests
        self.seconds = seconds
        # Must be stored on the instance: __call__ reads it on the soft-limit
        # path, and the class-level annotation alone creates no attribute.
        self.sleep_seconds = sleep_seconds
        # Maps a client key to the list of timestamps of its recent hits.
        self.history = defaultdict(list)

    def _clean_old_requests(self, key: str) -> None:
        """Drop the timestamps recorded for ``key`` that fell out of the window."""
        cutoff = datetime.now() - timedelta(seconds=self.seconds)
        self.history[key] = [ts for ts in self.history[key] if ts > cutoff]

    async def __call__(self, request: 'Request') -> bool:
        """Record a hit for the request's client and report whether it may proceed.

        Args:
            request: Incoming request; only ``request.client.host`` is read.

        Returns:
            ``True`` when the request may be served (possibly after a delay),
            ``False`` when it should be rejected.
        """
        # ``request.client`` can be None when the server has no peer address
        # for the connection; such requests share a single fallback bucket
        # instead of crashing the middleware.
        client = request.client
        key = client.host if client is not None else 'unknown'
        now = datetime.now()

        self._clean_old_requests(key)
        self.history[key].append(now)

        hits = len(self.history[key])
        if hits > self.requests * 2:
            # Hard limit: reject without waiting.
            return False
        if hits > self.requests:
            # Soft limit: slow the client down rather than rejecting it.
            if self.sleep_seconds > 0:
                await asyncio.sleep(self.sleep_seconds)
                return True
            return False

        return True


class RateLimitMiddleware(BaseHTTPMiddleware):
    """HTTP middleware that consults a rate limiter before serving a request.

    The limiter is awaited with the incoming request; a truthy result lets
    the request continue down the stack, anything else short-circuits with
    a 429 JSON response.
    """

    def __init__(self, app: ASGIApp, rate_limiter: InMemoryRateLimiter):
        super().__init__(app)
        # Awaitable callable: request -> bool (True means "allowed").
        self.rate_limiter = rate_limiter

    async def dispatch(self, request, call_next):
        """Forward the request when allowed, otherwise answer 429."""
        allowed = await self.rate_limiter(request)
        if allowed:
            return await call_next(request)

        # Over the limit: tell the client to back off and retry later.
        return JSONResponse(
            status_code=429,
            content={'message': 'Too many requests'},
            headers={'Retry-After': '1'},
        )

0 comments on commit 3c61a95

Please sign in to comment.