Skip to content

Commit

Permalink
use context varaibles instead of thread locals
Browse files Browse the repository at this point in the history
Signed-off-by: Achille Roussel <achille.roussel@gmail.com>
  • Loading branch information
achille-roussel committed Jun 18, 2024
1 parent a406d2a commit d5f5848
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 16 deletions.
19 changes: 5 additions & 14 deletions src/dispatch/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio
import contextvars
import logging
import pickle
import sys
import threading
from dataclasses import dataclass, field
from types import coroutine
from typing import (
Expand Down Expand Up @@ -32,19 +32,10 @@
CoroutineID: TypeAlias = int
CorrelationID: TypeAlias = int


class ThreadLocal(threading.local):
in_function_call: bool

def __init__(self):
self.in_function_call = False


thread_local = ThreadLocal()

_in_function_call = contextvars.ContextVar("dispatch.scheduler.in_function_call", default=False)

def in_function_call() -> bool:
return thread_local.in_function_call
return bool(_in_function_call.get())


@dataclass
Expand Down Expand Up @@ -343,15 +334,15 @@ def __init__(

async def run(self, input: Input) -> Output:
try:
thread_local.in_function_call = True
token = _in_function_call.set(True)
return await self._run(input)
except Exception as e:
logger.exception(
"unexpected exception occurred during coroutine scheduling"
)
return Output.error(Error.from_exception(e))
finally:
thread_local.in_function_call = False
_in_function_call.reset(token)

def _init_state(self, input: Input) -> State:
logger.debug("starting main coroutine")
Expand Down
53 changes: 51 additions & 2 deletions src/dispatch/test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import unittest
from datetime import datetime, timedelta
from functools import wraps
from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar
from typing import Any, Callable, Coroutine, Dict, List, Optional, TypeVar

import aiohttp
from aiohttp import web
Expand Down Expand Up @@ -193,6 +193,7 @@ def make_request(call: Call) -> RunRequest:
else:
error = Error(type="status", message=str(res.status))
return CallResult(
correlation_id=call.correlation_id,
dispatch_id=dispatch_id,
error=error,
)
Expand All @@ -203,6 +204,7 @@ def make_request(call: Call) -> RunRequest:
continue
result = res.exit.result
return CallResult(
correlation_id=call.correlation_id,
dispatch_id=dispatch_id,
output=result.output if result.HasField("output") else None,
error=result.error if result.HasField("error") else None,
Expand Down Expand Up @@ -317,6 +319,7 @@ def test(self):
endpoint=DISPATCH_ENDPOINT_URL,
client=Client(api_key=DISPATCH_API_KEY, api_url=DISPATCH_API_URL),
)
set_default_registry(_registry)


@_registry.function
Expand Down Expand Up @@ -354,7 +357,33 @@ async def broken_nested(name: str) -> str:
return await broken()


set_default_registry(_registry)
@_registry.function
async def distributed_merge_sort(values: List[int]) -> List[int]:
if len(values) <= 1:
return values
i = len(values) // 2

(l, r) = await dispatch.gather(
distributed_merge_sort(values[:i]),
distributed_merge_sort(values[i:]),
)

return merge(l, r)


def merge(l: List[int], r: List[int]) -> List[int]:
result = []
i = j = 0
while i < len(l) and j < len(r):
if l[i] < r[j]:
result.append(l[i])
i += 1
else:
result.append(r[j])
j += 1
result.extend(l[i:])
result.extend(r[j:])
return result


class TestCase(unittest.TestCase):
Expand Down Expand Up @@ -473,6 +502,26 @@ async def test_call_nested_function_with_error(self):
with self.assertRaises(ValueError) as e:
await broken_nested("hello")

@aiotest
async def test_distributed_merge_sort_no_values(self):
values: List[int] = []
self.assertEqual(await distributed_merge_sort(values), sorted(values))

@aiotest
async def test_distributed_merge_sort_one_value(self):
values: List[int] = [1]
self.assertEqual(await distributed_merge_sort(values), sorted(values))

@aiotest
async def test_distributed_merge_sort_two_values(self):
values: List[int] = [1, 5]
self.assertEqual(await distributed_merge_sort(values), sorted(values))

@aiotest
async def test_distributed_merge_sort_many_values(self):
values: List[int] = [1, 5, 3, 2, 4, 6, 7, 8, 9, 0]
self.assertEqual(await distributed_merge_sort(values), sorted(values))


class ClientRequestContentLengthMissing(aiohttp.ClientRequest):
def update_headers(self, skip_auto_headers):
Expand Down

0 comments on commit d5f5848

Please sign in to comment.