Skip to content

Commit

Permalink
Re-enable support for running sync tasks from async entrypoints (#3108)
Browse files Browse the repository at this point in the history
- When using an async entrypoint you can now freely mix and match sync
and async tasks with a uniform api (ie all tasks return a sync or async
future depending on context)
- Fix issues with scheduling deeply nested tasks (use threadsafe methods
to schedule coroutines and create futures)
  • Loading branch information
nfcampos authored Jan 21, 2025
2 parents 31cc6b9 + 2cafb49 commit e10b7c1
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 32 deletions.
27 changes: 10 additions & 17 deletions libs/langgraph/langgraph/pregel/executor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import concurrent.futures
import sys
import time
from contextlib import ExitStack
from contextvars import copy_context
Expand All @@ -22,6 +21,7 @@
from typing_extensions import ParamSpec

from langgraph.errors import GraphBubbleUp
from langgraph.utils.future import CONTEXT_NOT_SUPPORTED, run_coroutine_threadsafe

P = ParamSpec("P")
T = TypeVar("T")
Expand Down Expand Up @@ -132,8 +132,7 @@ class AsyncBackgroundExecutor(AsyncContextManager):
ignoring CancelledError"""

def __init__(self, config: RunnableConfig) -> None:
self.context_not_supported = sys.version_info < (3, 11)
self.tasks: dict[asyncio.Task, tuple[bool, bool]] = {}
self.tasks: dict[asyncio.Future, tuple[bool, bool]] = {}
self.sentinel = object()
self.loop = asyncio.get_running_loop()
if max_concurrency := config.get("max_concurrency"):
Expand All @@ -150,23 +149,23 @@ def submit( # type: ignore[valid-type]
__name__: Optional[str] = None,
__cancel_on_exit__: bool = False,
__reraise_on_exit__: bool = True,
__next_tick__: bool = False,
__next_tick__: bool = False, # noop in async (always True)
**kwargs: P.kwargs,
) -> asyncio.Task[T]:
) -> asyncio.Future[T]:
coro = cast(Coroutine[None, None, T], fn(*args, **kwargs))
if self.semaphore:
coro = gated(self.semaphore, coro)
if __next_tick__:
coro = anext_tick(coro)
if self.context_not_supported:
task = self.loop.create_task(coro, name=__name__)
if CONTEXT_NOT_SUPPORTED:
task = run_coroutine_threadsafe(coro, self.loop, name=__name__)
else:
task = self.loop.create_task(coro, name=__name__, context=copy_context())
task = run_coroutine_threadsafe(
coro, self.loop, name=__name__, context=copy_context()
)
self.tasks[task] = (__cancel_on_exit__, __reraise_on_exit__)
task.add_done_callback(self.done)
return task

def done(self, task: asyncio.Task) -> None:
def done(self, task: asyncio.Future) -> None:
try:
if exc := task.exception():
# This exception is an interruption signal, not an error
Expand Down Expand Up @@ -219,9 +218,3 @@ def next_tick(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
"""A function that yields control to other threads before running another function."""
time.sleep(0)
return fn(*args, **kwargs)


async def anext_tick(coro: Coroutine[None, None, T]) -> T:
"""A coroutine that yields control to event loop before running another coroutine."""
await asyncio.sleep(0)
return await coro
31 changes: 21 additions & 10 deletions libs/langgraph/langgraph/pregel/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,7 @@ def call(
assert fut is not None, "writer did not return a future for call"
# return a chained future to ensure commit() callback is called
# before the returned future is resolved, to ensure stream order etc
sfut: concurrent.futures.Future[Any] = concurrent.futures.Future()
chain_future(fut, sfut)
return sfut
return chain_future(fut, concurrent.futures.Future())

tasks = tuple(tasks)
futures = FuturesDict(
Expand Down Expand Up @@ -400,10 +398,6 @@ def call(
retry: Optional[RetryPolicy] = None,
callbacks: Callbacks = None,
) -> Union[asyncio.Future[Any], concurrent.futures.Future[Any]]:
if not asyncio.iscoroutinefunction(func):
raise RuntimeError(
"In an async context use func.to_thread(...) to invoke tasks"
)
(fut,) = writer(
task,
[(PUSH, None)],
Expand All @@ -412,9 +406,26 @@ def call(
assert fut is not None, "writer did not return a future for call"
# return a chained future to ensure commit() callback is called
# before the returned future is resolved, to ensure stream order etc
sfut: asyncio.Future[Any] = asyncio.Future(loop=loop)
chain_future(fut, sfut)
return sfut
try:
in_async = asyncio.current_task() is not None
except RuntimeError:
in_async = False
# if in async context return an async future
# otherwise return a chained sync future
if in_async:
if isinstance(fut, asyncio.Task):
sfut: Union[asyncio.Future[Any], concurrent.futures.Future[Any]] = (
asyncio.Future(loop=loop)
)
loop.call_soon_threadsafe(chain_future, fut, sfut)
return sfut
else:
# already wrapped in a future
return fut
else:
sfut = concurrent.futures.Future()
loop.call_soon_threadsafe(chain_future, fut, sfut)
return sfut

loop = asyncio.get_event_loop()
tasks = tuple(tasks)
Expand Down
83 changes: 80 additions & 3 deletions libs/langgraph/langgraph/utils/future.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import asyncio
import concurrent.futures
from typing import Union
import contextvars
import inspect
import sys
import types
from typing import Awaitable, Coroutine, Generator, Optional, TypeVar, Union, cast

T = TypeVar("T")
AnyFuture = Union[asyncio.Future, concurrent.futures.Future]

CONTEXT_NOT_SUPPORTED = sys.version_info < (3, 11)


def _get_loop(fut: asyncio.Future) -> asyncio.AbstractEventLoop:
# Tries to call Future.get_loop() if it's available.
Expand Down Expand Up @@ -52,10 +59,11 @@ def _copy_future_state(source: AnyFuture, dest: asyncio.Future) -> None:
The other Future may be a concurrent.futures.Future.
"""
if dest.done():
return
assert source.done()
if dest.cancelled():
return
assert not dest.done()
if source.cancelled():
dest.cancel()
else:
Expand Down Expand Up @@ -112,10 +120,11 @@ def _call_set_state(source: AnyFuture) -> None:
source.add_done_callback(_call_set_state)


def chain_future(source: AnyFuture, destination: AnyFuture) -> None:
def chain_future(source: AnyFuture, destination: AnyFuture) -> AnyFuture:
# adapted from asyncio.run_coroutine_threadsafe
try:
_chain_future(source, destination)
return destination
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as exc:
Expand All @@ -125,3 +134,71 @@ def chain_future(source: AnyFuture, destination: AnyFuture) -> None:
else:
destination.set_exception(exc)
raise


def _ensure_future(
coro_or_future: Union[Coroutine[None, None, T], Awaitable[T]],
*,
loop: asyncio.AbstractEventLoop,
name: Optional[str] = None,
context: Optional[contextvars.Context] = None,
) -> asyncio.Task[T]:
called_wrap_awaitable = False
if not asyncio.iscoroutine(coro_or_future):
if inspect.isawaitable(coro_or_future):
coro_or_future = cast(
Coroutine[None, None, T], _wrap_awaitable(coro_or_future)
)
called_wrap_awaitable = True
else:
raise TypeError(
"An asyncio.Future, a coroutine or an awaitable is required."
f" Got {type(coro_or_future).__name__} instead."
)

try:
if CONTEXT_NOT_SUPPORTED:
return loop.create_task(coro_or_future, name=name)
else:
return loop.create_task(coro_or_future, name=name, context=context)
except RuntimeError:
if not called_wrap_awaitable:
coro_or_future.close()
raise


@types.coroutine
def _wrap_awaitable(awaitable: Awaitable[T]) -> Generator[None, None, T]:
"""Helper for asyncio.ensure_future().
Wraps awaitable (an object with __await__) into a coroutine
that will later be wrapped in a Task by ensure_future().
"""
return (yield from awaitable.__await__())


def run_coroutine_threadsafe(
coro: Coroutine[None, None, T],
loop: asyncio.AbstractEventLoop,
name: Optional[str] = None,
context: Optional[contextvars.Context] = None,
) -> asyncio.Future[T]:
"""Submit a coroutine object to a given event loop.
Return a asyncio.Future to access the result.
"""
future: asyncio.Future[T] = asyncio.Future(loop=loop)

def callback() -> None:
try:
chain_future(
_ensure_future(coro, loop=loop, name=name, context=context), future
)
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as exc:
future.set_exception(exc)
raise

loop.call_soon_threadsafe(callback, context=context)
return future
68 changes: 66 additions & 2 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -1158,7 +1158,7 @@ async def alittlewhile(input: Any) -> None:
with pytest.raises(asyncio.TimeoutError):
async for chunk in graph.astream(1, stream_mode="updates"):
assert chunk == {"alittlewhile": {"alittlewhile": "1"}}
await asyncio.sleep(0.6)
await asyncio.sleep(stream_hang_s)

assert inner_task_cancelled

Expand Down Expand Up @@ -2489,6 +2489,71 @@ async def graph(input: list[int]) -> list[str]:
assert mapper_calls == 2


@NEEDS_CONTEXTVARS
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_imp_nested(checkpointer_name: str) -> None:
async def mynode(input: list[str]) -> list[str]:
return [it + "a" for it in input]

builder = StateGraph(list[str])
builder.add_node(mynode)
builder.add_edge(START, "mynode")
add_a = builder.compile()

@task
def submapper(input: int) -> str:
return str(input)

@task
async def mapper(input: int) -> str:
await asyncio.sleep(input / 100)
return await submapper(input) * 2

async with awith_checkpointer(checkpointer_name) as checkpointer:

@entrypoint(checkpointer=checkpointer)
async def graph(input: list[int]) -> list[str]:
futures = [mapper(i) for i in input]
mapped = await asyncio.gather(*futures)
answer = interrupt("question")
final = [m + answer for m in mapped]
return await add_a.ainvoke(final)

assert graph.get_input_jsonschema() == {
"type": "array",
"items": {"type": "integer"},
"title": "LangGraphInput",
}
assert graph.get_output_jsonschema() == {
"type": "array",
"items": {"type": "string"},
"title": "LangGraphOutput",
}

thread1 = {"configurable": {"thread_id": "1"}}
assert [c async for c in graph.astream([0, 1], thread1)] == [
{"submapper": "0"},
{"mapper": "00"},
{"submapper": "1"},
{"mapper": "11"},
{
"__interrupt__": (
Interrupt(
value="question",
resumable=True,
ns=[AnyStr("graph:")],
when="during",
),
)
},
]

assert await graph.ainvoke(Command(resume="answer"), thread1) == [
"00answera",
"11answera",
]


@NEEDS_CONTEXTVARS
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_imp_task_cancel(checkpointer_name: str) -> None:
Expand Down Expand Up @@ -2540,7 +2605,6 @@ async def graph(input: list[int]) -> list[str]:
assert mapper_cancels == 2


@pytest.mark.skip("TODO: re-enable")
@NEEDS_CONTEXTVARS
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_imp_sync_from_async(checkpointer_name: str) -> None:
Expand Down

0 comments on commit e10b7c1

Please sign in to comment.