Skip to content

Commit

Permalink
thread utils respect contextvars (#4074)
Browse files Browse the repository at this point in the history
* thread utils respect contextvars now

* address pablo comments

* removed tenant id from places it was already being passed

* fix rate limit check and pablo comment
  • Loading branch information
evan-danswer authored Feb 24, 2025
1 parent 1f2af37 commit 4a4e4a6
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 49 deletions.
22 changes: 11 additions & 11 deletions backend/ee/onyx/server/query_and_chat/token_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from sqlalchemy.orm import Session

from onyx.db.api_key import is_api_key_email_address
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import TokenRateLimit
Expand All @@ -28,21 +28,21 @@
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel


def _check_token_rate_limits(user: User | None, tenant_id: str) -> None:
def _check_token_rate_limits(user: User | None) -> None:
if user is None:
# Unauthenticated users are only rate limited by global settings
_user_is_rate_limited_by_global(tenant_id)
_user_is_rate_limited_by_global()

elif is_api_key_email_address(user.email):
# API keys are only rate limited by global settings
_user_is_rate_limited_by_global(tenant_id)
_user_is_rate_limited_by_global()

else:
run_functions_tuples_in_parallel(
[
(_user_is_rate_limited, (user.id, tenant_id)),
(_user_is_rate_limited_by_group, (user.id, tenant_id)),
(_user_is_rate_limited_by_global, (tenant_id,)),
(_user_is_rate_limited, (user.id,)),
(_user_is_rate_limited_by_group, (user.id,)),
(_user_is_rate_limited_by_global, ()),
]
)

Expand All @@ -52,8 +52,8 @@ def _check_token_rate_limits(user: User | None, tenant_id: str) -> None:
"""


def _user_is_rate_limited(user_id: UUID, tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
def _user_is_rate_limited(user_id: UUID) -> None:
with get_session_with_current_tenant() as db_session:
user_rate_limits = fetch_all_user_token_rate_limits(
db_session=db_session, enabled_only=True, ordered=False
)
Expand Down Expand Up @@ -93,8 +93,8 @@ def _fetch_user_usage(
"""


def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
def _user_is_rate_limited_by_group(user_id: UUID) -> None:
with get_session_with_current_tenant() as db_session:
group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)

if group_rate_limits:
Expand Down
1 change: 0 additions & 1 deletion backend/onyx/chat/process_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,6 @@ def stream_chat_message_objects(
for img in img_generation_response
if img.image_data
],
tenant_id=tenant_id,
)
info.ai_message_files.extend(
[
Expand Down
26 changes: 12 additions & 14 deletions backend/onyx/file_store/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sqlalchemy.orm import Session

from onyx.configs.constants import FileOrigin
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ChatMessage
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import FileDescriptor
Expand Down Expand Up @@ -53,11 +53,11 @@ def load_all_chat_files(
return files


def save_file_from_url(url: str, tenant_id: str) -> str:
def save_file_from_url(url: str) -> str:
"""NOTE: using multiple sessions here, since this is often called
using multithreading. In practice, sharing a session has resulted in
weird errors."""
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
response = requests.get(url)
response.raise_for_status()

Expand All @@ -75,8 +75,8 @@ def save_file_from_url(url: str, tenant_id: str) -> str:
return unique_id


def save_file_from_base64(base64_string: str, tenant_id: str) -> str:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
def save_file_from_base64(base64_string: str) -> str:
with get_session_with_current_tenant() as db_session:
unique_id = str(uuid4())
file_store = get_default_file_store(db_session)
file_store.save_file(
Expand All @@ -90,14 +90,12 @@ def save_file_from_base64(base64_string: str, tenant_id: str) -> str:


def save_file(
tenant_id: str,
url: str | None = None,
base64_data: str | None = None,
) -> str:
"""Save a file from either a URL or base64 encoded string.
Args:
tenant_id: The tenant ID to save the file under
url: URL to download file from
base64_data: Base64 encoded file data
Expand All @@ -111,22 +109,22 @@ def save_file(
raise ValueError("Cannot specify both url and base64_data")

if url is not None:
return save_file_from_url(url, tenant_id)
return save_file_from_url(url)
elif base64_data is not None:
return save_file_from_base64(base64_data, tenant_id)
return save_file_from_base64(base64_data)
else:
raise ValueError("Must specify either url or base64_data")


def save_files(urls: list[str], base64_files: list[str], tenant_id: str) -> list[str]:
def save_files(urls: list[str], base64_files: list[str]) -> list[str]:
# NOTE: be explicit about typing so that if we change things, we get notified
funcs: list[
tuple[
Callable[[str, str | None, str | None], str],
tuple[str, str | None, str | None],
Callable[[str | None, str | None], str],
tuple[str | None, str | None],
]
] = [(save_file, (tenant_id, url, None)) for url in urls] + [
(save_file, (tenant_id, None, base64_file)) for base64_file in base64_files
] = [(save_file, (url, None)) for url in urls] + [
(save_file, (None, base64_file)) for base64_file in base64_files
]

return run_functions_tuples_in_parallel(funcs)
14 changes: 6 additions & 8 deletions backend/onyx/server/query_and_chat/token_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@

from onyx.auth.users import current_chat_accesssible_user
from onyx.db.engine import get_session_context_manager
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import TokenRateLimit
from onyx.db.models import User
from onyx.db.token_limit import fetch_all_global_token_rate_limits
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.contextvars import get_current_tenant_id


logger = setup_logger()
Expand All @@ -39,22 +37,22 @@ def check_token_rate_limits(
return

versioned_rate_limit_strategy = fetch_versioned_implementation(
"onyx.server.query_and_chat.token_limit", "_check_token_rate_limits"
"onyx.server.query_and_chat.token_limit", _check_token_rate_limits.__name__
)
return versioned_rate_limit_strategy(user, get_current_tenant_id())
return versioned_rate_limit_strategy(user)


def _check_token_rate_limits(_: User | None, tenant_id: str | None) -> None:
_user_is_rate_limited_by_global(tenant_id)
def _check_token_rate_limits(_: User | None) -> None:
_user_is_rate_limited_by_global()


"""
Global rate limits
"""


def _user_is_rate_limited_by_global(tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
def _user_is_rate_limited_by_global() -> None:
with get_session_context_manager() as db_session:
global_rate_limits = fetch_all_global_token_rate_limits(
db_session=db_session, enabled_only=True, ordered=False
)
Expand Down
25 changes: 10 additions & 15 deletions backend/onyx/utils/threadpool_concurrency.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextvars
import threading
import uuid
from collections.abc import Callable
Expand All @@ -14,10 +15,6 @@
R = TypeVar("R")


# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
# executed through this wrapper. Do NOT try to acquire a db session in a function run through this unless
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
# is not safe, update this comment.
def run_functions_tuples_in_parallel(
functions_with_args: list[tuple[Callable, tuple]],
allow_failures: bool = False,
Expand Down Expand Up @@ -45,8 +42,11 @@ def run_functions_tuples_in_parallel(

results = []
with ThreadPoolExecutor(max_workers=workers) as executor:
# The primary reason for propagating contextvars is to allow acquiring a db session
# that respects tenant id. Context.run is expected to be low-overhead, but if we later
# find that it is increasing latency we can make using it optional.
future_to_index = {
executor.submit(func, *args): i
executor.submit(contextvars.copy_context().run, func, *args): i
for i, (func, args) in enumerate(functions_with_args)
}

Expand Down Expand Up @@ -83,10 +83,6 @@ def execute(self) -> R:
return self.func(*self.args, **self.kwargs)


# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
# executed through this wrapper. Do NOT try to acquire a db session in a function run through this unless
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
# is not safe, update this comment.
def run_functions_in_parallel(
function_calls: list[FunctionCall],
allow_failures: bool = False,
Expand All @@ -102,7 +98,9 @@ def run_functions_in_parallel(

with ThreadPoolExecutor(max_workers=len(function_calls)) as executor:
future_to_id = {
executor.submit(func_call.execute): func_call.result_id
executor.submit(
contextvars.copy_context().run, func_call.execute
): func_call.result_id
for func_call in function_calls
}

Expand Down Expand Up @@ -143,18 +141,15 @@ def end(self) -> None:
)


# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
# executed through this wrapper. Do NOT try to acquire a db session in a function run through this unless
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
# is not safe, update this comment.
def run_with_timeout(
timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
) -> R:
"""
Executes a function with a timeout. If the function doesn't complete within the specified
timeout, raises TimeoutError.
"""
task = TimeoutThread(timeout, func, *args, **kwargs)
context = contextvars.copy_context()
task = TimeoutThread(timeout, context.run, func, *args, **kwargs)
task.start()
task.join(timeout)

Expand Down
131 changes: 131 additions & 0 deletions backend/tests/unit/onyx/utils/test_threadpool_contextvars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import contextvars
import time

from onyx.utils.threadpool_concurrency import FunctionCall
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.threadpool_concurrency import run_with_timeout

# Module-level contextvar shared by all tests below: each test sets a value in
# the main thread and verifies that worker threads spawned by the threadpool
# utilities observe (or are isolated from) it as expected.
test_var: contextvars.ContextVar[str] = contextvars.ContextVar("test_var", default="default")


def get_contextvar_value() -> str:
    """Read ``test_var`` from inside a worker thread and return its value.

    The short sleep forces the read to happen while the call is genuinely
    executing off the main thread, so context propagation is exercised.
    """
    time.sleep(0.1)
    value = test_var.get()
    return value


def test_run_with_timeout_preserves_contextvar() -> None:
    """run_with_timeout must propagate contextvars into its worker thread."""
    # Value set on the main thread before spawning the worker.
    test_var.set("test_value")

    # The helper reads test_var off-thread; the copied context must carry
    # the value set above.
    observed = run_with_timeout(1.0, get_contextvar_value)
    assert observed == "test_value"


def test_run_functions_in_parallel_preserves_contextvar() -> None:
    """Test that run_functions_in_parallel preserves contextvar values.

    Sets ``test_var`` on the main thread, runs two FunctionCalls in parallel,
    and asserts every worker observed the main thread's value.
    """
    # Set a value in the main thread; worker threads must inherit it.
    test_var.set("parallel_test")

    # Two independent calls so more than one worker thread is exercised.
    function_calls = [
        FunctionCall(get_contextvar_value),
        FunctionCall(get_contextvar_value),
    ]

    results = run_functions_in_parallel(function_calls)

    # Every submitted call produced a result (results is keyed by result_id).
    assert len(results) == len(function_calls)
    # The keys are opaque result ids; only the values matter here.
    for value in results.values():
        assert value == "parallel_test"


def test_run_functions_tuples_preserves_contextvar() -> None:
    """run_functions_tuples_in_parallel must propagate contextvar values."""
    # Value set on the main thread before fanning out.
    test_var.set("tuple_test")

    # Two identical no-arg calls expressed as (func, args) tuples.
    calls = [(get_contextvar_value, ()) for _ in range(2)]

    outputs = run_functions_tuples_in_parallel(calls)

    # Every worker thread saw the main thread's value.
    assert all(output == "tuple_test" for output in outputs)


def test_nested_contextvar_modifications() -> None:
    """Contextvar writes made inside a worker thread must stay local to it."""

    def modify_and_return_contextvar(new_value: str) -> tuple[str, str]:
        # Record what this thread inherited, overwrite it, then sleep so the
        # two worker threads are guaranteed to overlap in time.
        inherited = test_var.get()
        test_var.set(new_value)
        time.sleep(0.1)
        return inherited, test_var.get()

    # Baseline value set on the main thread.
    test_var.set("initial")

    results = run_functions_tuples_in_parallel(
        [
            (modify_and_return_contextvar, ("thread1",)),
            (modify_and_return_contextvar, ("thread2",)),
        ]
    )

    # Each worker inherited the main thread's value and afterwards saw only
    # its own modification — never the other thread's.
    for inherited, modified in results:
        assert inherited == "initial"
        assert modified in ("thread1", "thread2")

    # The workers' writes did not leak back into the main thread.
    assert test_var.get() == "initial"


def test_contextvar_isolation_between_runs() -> None:
    """Contextvar changes must not leak between separate parallel runs."""

    def set_and_return_contextvar(value: str) -> str:
        # Overwrite the var inside the worker and echo back what it sees.
        test_var.set(value)
        return test_var.get()

    # First batch: each worker sets and reads its own value.
    test_var.set("first_run")
    first_results = run_functions_tuples_in_parallel(
        [
            (set_and_return_contextvar, ("thread1",)),
            (set_and_return_contextvar, ("thread2",)),
        ]
    )
    assert all(result in ("thread1", "thread2") for result in first_results)

    # The workers' writes did not escape into the main thread.
    assert test_var.get() == "first_run"

    # Second batch behaves identically with a fresh main-thread value,
    # proving no state carried over from the first run.
    test_var.set("second_run")
    second_results = run_functions_tuples_in_parallel(
        [
            (set_and_return_contextvar, ("thread3",)),
            (set_and_return_contextvar, ("thread4",)),
        ]
    )
    assert all(result in ("thread3", "thread4") for result in second_results)

0 comments on commit 4a4e4a6

Please sign in to comment.