Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: enhance API token validation with session locking and last used timestamp update #12426

Merged
merged 7 commits into from
Jan 7, 2025
37 changes: 22 additions & 15 deletions api/controllers/service_api/wraps.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Callable
from datetime import UTC, datetime
from datetime import UTC, datetime, timedelta
from enum import Enum
from functools import wraps
from typing import Optional
Expand All @@ -8,6 +8,8 @@
from flask_login import user_logged_in # type: ignore
from flask_restful import Resource # type: ignore
from pydantic import BaseModel
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, Unauthorized

from extensions.ext_database import db
Expand Down Expand Up @@ -174,7 +176,7 @@ def decorated(*args, **kwargs):
return decorator


def validate_and_get_api_token(scope=None):
def validate_and_get_api_token(scope: str | None = None):
"""
Validate and get API token.
"""
Expand All @@ -188,20 +190,25 @@ def validate_and_get_api_token(scope=None):
if auth_scheme != "bearer":
raise Unauthorized("Authorization scheme must be 'Bearer'")

api_token = (
db.session.query(ApiToken)
.filter(
ApiToken.token == auth_token,
ApiToken.type == scope,
current_time = datetime.now(UTC).replace(tzinfo=None)
cutoff_time = current_time - timedelta(minutes=1)
with Session(db.engine, expire_on_commit=False) as session:
update_stmt = (
update(ApiToken)
.where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
.values(last_used_at=current_time)
.returning(ApiToken)
)
.first()
)

if not api_token:
raise Unauthorized("Access token is invalid")

api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
result = session.execute(update_stmt)
api_token = result.scalar_one_or_none()

if not api_token:
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
api_token = session.scalar(stmt)
if not api_token:
raise Unauthorized("Access token is invalid")
else:
session.commit()

return api_token

Expand Down
1 change: 1 addition & 0 deletions api/docker/entrypoint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ else
--bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
--workers ${SERVER_WORKER_AMOUNT:-1} \
--worker-class ${SERVER_WORKER_CLASS:-gevent} \
--worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \
--timeout ${GUNICORN_TIMEOUT:-200} \
app:app
fi
Expand Down
8 changes: 4 additions & 4 deletions api/services/billing_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Optional
from typing import Literal, Optional

import httpx
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
Expand All @@ -17,7 +17,6 @@ def get_info(cls, tenant_id: str):
params = {"tenant_id": tenant_id}

billing_info = cls._send_request("GET", "/subscription/info", params=params)

return billing_info

@classmethod
Expand Down Expand Up @@ -47,12 +46,13 @@ def get_invoices(cls, prefilled_email: str = "", tenant_id: str = ""):
retry=retry_if_exception_type(httpx.RequestError),
reraise=True,
)
def _send_request(cls, method, endpoint, json=None, params=None):
def _send_request(cls, method: Literal["GET", "POST", "DELETE"], endpoint: str, json=None, params=None):
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}

url = f"{cls.base_url}{endpoint}"
response = httpx.request(method, url, json=json, params=params, headers=headers)

if method == "GET" and response.status_code != httpx.codes.OK:
raise ValueError("Unable to retrieve billing information. Please try again later or contact support.")
return response.json()

@staticmethod
Expand Down
7 changes: 5 additions & 2 deletions docker/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,13 @@ DIFY_PORT=5001
# The number of API server workers, i.e., the number of workers.
# Formula: number of cpu cores x 2 + 1 for sync, 1 for Gevent
# Reference: https://docs.gunicorn.org/en/stable/design.html#how-many-workers
SERVER_WORKER_AMOUNT=
SERVER_WORKER_AMOUNT=1

# Defaults to gevent. If using windows, it can be switched to sync or solo.
SERVER_WORKER_CLASS=
SERVER_WORKER_CLASS=gevent

# Default number of worker connections, the default is 10.
SERVER_WORKER_CONNECTIONS=10

# Similar to SERVER_WORKER_CLASS.
# If using windows, it can be switched to sync or solo.
Expand Down
5 changes: 3 additions & 2 deletions docker/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ x-shared-env: &shared-api-worker-env
APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-1200}
DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0}
DIFY_PORT: ${DIFY_PORT:-5001}
SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-}
SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-}
SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-1}
SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-gevent}
SERVER_WORKER_CONNECTIONS: ${SERVER_WORKER_CONNECTIONS:-10}
CELERY_WORKER_CLASS: ${CELERY_WORKER_CLASS:-}
GUNICORN_TIMEOUT: ${GUNICORN_TIMEOUT:-360}
CELERY_WORKER_AMOUNT: ${CELERY_WORKER_AMOUNT:-}
Expand Down
Loading