From aed1f0ba1861811ba45fef90c4c94664a54c3be1 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 20 Jan 2025 10:44:46 -0800 Subject: [PATCH 1/5] Re-enable support for running sync tasks from async entrypoints - When using an async entrypoint you can now freely mix and match sync and async tasks with a uniform api (ie all tasks return a sync or async future depending on context) - Fix issues with scheduling deeply nested tasks (use threadsafe methods to schedule coroutines and create futures) --- libs/langgraph/langgraph/pregel/executor.py | 27 +++---- libs/langgraph/langgraph/pregel/runner.py | 30 +++++--- libs/langgraph/langgraph/utils/future.py | 80 ++++++++++++++++++++- libs/langgraph/tests/test_pregel_async.py | 68 +++++++++++++++++- 4 files changed, 173 insertions(+), 32 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/executor.py b/libs/langgraph/langgraph/pregel/executor.py index 8bd8671be..64bb6c90e 100644 --- a/libs/langgraph/langgraph/pregel/executor.py +++ b/libs/langgraph/langgraph/pregel/executor.py @@ -1,6 +1,5 @@ import asyncio import concurrent.futures -import sys import time from contextlib import ExitStack from contextvars import copy_context @@ -22,6 +21,7 @@ from typing_extensions import ParamSpec from langgraph.errors import GraphBubbleUp +from langgraph.utils.future import CONTEXT_NOT_SUPPORTED, run_coroutine_threadsafe P = ParamSpec("P") T = TypeVar("T") @@ -132,8 +132,7 @@ class AsyncBackgroundExecutor(AsyncContextManager): ignoring CancelledError""" def __init__(self, config: RunnableConfig) -> None: - self.context_not_supported = sys.version_info < (3, 11) - self.tasks: dict[asyncio.Task, tuple[bool, bool]] = {} + self.tasks: dict[asyncio.Future, tuple[bool, bool]] = {} self.sentinel = object() self.loop = asyncio.get_running_loop() if max_concurrency := config.get("max_concurrency"): @@ -150,23 +149,23 @@ def submit( # type: ignore[valid-type] __name__: Optional[str] = None, __cancel_on_exit__: bool = False, __reraise_on_exit__: bool = True, - __next_tick__: bool = False, + __next_tick__: bool = False, # noop in async (always True) **kwargs: P.kwargs, - ) -> asyncio.Task[T]: + ) -> asyncio.Future[T]: coro = cast(Coroutine[None, None, T], fn(*args, **kwargs)) if self.semaphore: coro = gated(self.semaphore, coro) - if __next_tick__: - coro = anext_tick(coro) - if self.context_not_supported: - task = self.loop.create_task(coro, name=__name__) + if CONTEXT_NOT_SUPPORTED: + task = run_coroutine_threadsafe(coro, self.loop, name=__name__) else: - task = self.loop.create_task(coro, name=__name__, context=copy_context()) + task = run_coroutine_threadsafe( + coro, self.loop, name=__name__, context=copy_context() + ) self.tasks[task] = (__cancel_on_exit__, __reraise_on_exit__) task.add_done_callback(self.done) return task - def done(self, task: asyncio.Task) -> None: + def done(self, task: asyncio.Future) -> None: try: if exc := task.exception(): # This exception is an interruption signal, not an error @@ -219,9 +218,3 @@ def next_tick(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: """A function that yields control to other threads before running another function.""" time.sleep(0) return fn(*args, **kwargs) - - -async def anext_tick(coro: Coroutine[None, None, T]) -> T: - """A coroutine that yields control to event loop before running another coroutine.""" - await asyncio.sleep(0) - return await coro diff --git a/libs/langgraph/langgraph/pregel/runner.py b/libs/langgraph/langgraph/pregel/runner.py index 90b026919..02e97a8e0 100644 --- a/libs/langgraph/langgraph/pregel/runner.py +++ b/libs/langgraph/langgraph/pregel/runner.py @@ -213,9 +213,7 @@ def call( assert fut is not None, "writer did not return a future for call" # return a chained future to ensure commit() callback is called # before the returned future is resolved, to ensure stream order etc - sfut: concurrent.futures.Future[Any] = concurrent.futures.Future() - chain_future(fut, sfut) - return sfut + return chain_future(fut, concurrent.futures.Future()) tasks = tuple(tasks) futures = FuturesDict( @@ -400,10 +398,6 @@ def call( retry: Optional[RetryPolicy] = None, callbacks: Callbacks = None, ) -> Union[asyncio.Future[Any], concurrent.futures.Future[Any]]: - if not asyncio.iscoroutinefunction(func): - raise RuntimeError( - "In an async context use func.to_thread(...) to invoke tasks" - ) (fut,) = writer( task, [(PUSH, None)], @@ -412,9 +406,25 @@ def call( assert fut is not None, "writer did not return a future for call" # return a chained future to ensure commit() callback is called # before the returned future is resolved, to ensure stream order etc - sfut: asyncio.Future[Any] = asyncio.Future(loop=loop) - chain_future(fut, sfut) - return sfut + try: + asyncio.current_task() + in_async = True + except RuntimeError: + in_async = False + # if in async context return an async future + # otherwise return a chained sync future + if in_async: + if isinstance(fut, asyncio.Task): + sfut = asyncio.Future(loop=loop) + loop.call_soon_threadsafe(chain_future, fut, sfut) + return sfut + else: + # already wrapped in a future + return fut + else: + sfut = concurrent.futures.Future() + loop.call_soon_threadsafe(chain_future, fut, sfut) + return sfut loop = asyncio.get_event_loop() tasks = tuple(tasks) diff --git a/libs/langgraph/langgraph/utils/future.py b/libs/langgraph/langgraph/utils/future.py index 03ce31a50..50fc1d0eb 100644 --- a/libs/langgraph/langgraph/utils/future.py +++ b/libs/langgraph/langgraph/utils/future.py @@ -1,9 +1,16 @@ import asyncio import concurrent.futures -from typing import Union +import contextvars +import inspect +import sys +import types +from typing import Coroutine, Optional, TypeVar, Union +T = TypeVar("T") AnyFuture = Union[asyncio.Future, concurrent.futures.Future] +CONTEXT_NOT_SUPPORTED = sys.version_info < (3, 11) + def _get_loop(fut: asyncio.Future) -> asyncio.AbstractEventLoop: # Tries to call Future.get_loop() if it's available. @@ -52,10 +59,11 @@ def _copy_future_state(source: AnyFuture, dest: asyncio.Future) -> None: The other Future may be a concurrent.futures.Future. """ + if dest.done(): + return assert source.done() if dest.cancelled(): return - assert not dest.done() if source.cancelled(): dest.cancel() else: @@ -112,10 +120,11 @@ def _call_set_state(source: AnyFuture) -> None: source.add_done_callback(_call_set_state) -def chain_future(source: AnyFuture, destination: AnyFuture) -> None: +def chain_future(source: AnyFuture, destination: AnyFuture) -> AnyFuture: # adapted from asyncio.run_coroutine_threadsafe try: _chain_future(source, destination) + return destination except (SystemExit, KeyboardInterrupt): raise except BaseException as exc: @@ -125,3 +134,68 @@ def chain_future(source: AnyFuture, destination: AnyFuture) -> None: else: destination.set_exception(exc) raise + + +def _ensure_future( + coro_or_future: Coroutine[None, None, T], + *, + loop: asyncio.AbstractEventLoop, + name: Optional[str] = None, + context: Optional[contextvars.Context] = None, +) -> asyncio.Task[T]: + called_wrap_awaitable = False + if not asyncio.iscoroutine(coro_or_future): + if inspect.isawaitable(coro_or_future): + coro_or_future = _wrap_awaitable(coro_or_future) + called_wrap_awaitable = True + else: + raise TypeError( + "An asyncio.Future, a coroutine or an awaitable is required" + ) + + try: + if CONTEXT_NOT_SUPPORTED: + return loop.create_task(coro_or_future, name=name) + else: + return loop.create_task(coro_or_future, name=name, context=context) + except RuntimeError: + if not called_wrap_awaitable: + coro_or_future.close() + raise + + +@types.coroutine +def _wrap_awaitable(awaitable): + """Helper for asyncio.ensure_future(). + + Wraps awaitable (an object with __await__) into a coroutine + that will later be wrapped in a Task by ensure_future(). + """ + return (yield from awaitable.__await__()) + + +def run_coroutine_threadsafe( + coro: Coroutine[None, None, T], + loop: asyncio.AbstractEventLoop, + name: Optional[str] = None, + context: Optional[contextvars.Context] = None, +) -> asyncio.Future[T]: + """Submit a coroutine object to a given event loop. + + Return a asyncio.Future to access the result. + """ + future = asyncio.Future(loop=loop) + + def callback(): + try: + chain_future( + _ensure_future(coro, loop=loop, name=name, context=context), future + ) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + future.set_exception(exc) + raise + + loop.call_soon_threadsafe(callback, context=context) + return future diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 17847f4b8..67d31fc38 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -1158,7 +1158,7 @@ async def alittlewhile(input: Any) -> None: with pytest.raises(asyncio.TimeoutError): async for chunk in graph.astream(1, stream_mode="updates"): assert chunk == {"alittlewhile": {"alittlewhile": "1"}} - await asyncio.sleep(0.6) + await asyncio.sleep(stream_hang_s) assert inner_task_cancelled @@ -2484,6 +2484,71 @@ async def graph(input: list[int]) -> list[str]: assert mapper_calls == 2 +@NEEDS_CONTEXTVARS +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) +async def test_imp_nested(checkpointer_name: str) -> None: + async def mynode(input: list[str]) -> list[str]: + return [it + "a" for it in input] + + builder = StateGraph(list[str]) + builder.add_node(mynode) + builder.add_edge(START, "mynode") + add_a = builder.compile() + + @task + def submapper(input: int) -> str: + return str(input) + + @task + async def mapper(input: int) -> str: + await asyncio.sleep(input / 100) + return await submapper(input) * 2 + + async with awith_checkpointer(checkpointer_name) as checkpointer: + + @entrypoint(checkpointer=checkpointer) + async def graph(input: list[int]) -> list[str]: + futures = [mapper(i) for i in input] + mapped = await asyncio.gather(*futures) + answer = interrupt("question") + final = [m + answer for m in mapped] + return await add_a.ainvoke(final) + + assert graph.get_input_jsonschema() == { + "type": "array", + "items": {"type": "integer"}, + "title": "LangGraphInput", + } + assert graph.get_output_jsonschema() == { + "type": "array", + "items": {"type": "string"}, + "title": "LangGraphOutput", + } + + thread1 = {"configurable": {"thread_id": "1"}} + assert [c async for c in graph.astream([0, 1], thread1)] == [ + {"submapper": "0"}, + {"mapper": "00"}, + {"submapper": "1"}, + {"mapper": "11"}, + { + "__interrupt__": ( + Interrupt( + value="question", + resumable=True, + ns=[AnyStr("graph:")], + when="during", + ), + ) + }, + ] + + assert await graph.ainvoke(Command(resume="answer"), thread1) == [ + "00answera", + "11answera", + ] + + @NEEDS_CONTEXTVARS @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_imp_task_cancel(checkpointer_name: str) -> None: @@ -2535,7 +2600,6 @@ async def graph(input: list[int]) -> list[str]: assert mapper_cancels == 2 -@pytest.mark.skip("TODO: re-enable") @NEEDS_CONTEXTVARS @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_imp_sync_from_async(checkpointer_name: str) -> None: From 12be3fac33a95e71ece4b68b8a4816929c3a3da5 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 20 Jan 2025 11:13:37 -0800 Subject: [PATCH 2/5] Lint --- libs/langgraph/langgraph/pregel/runner.py | 4 +++- libs/langgraph/langgraph/utils/future.py | 14 ++++++++------ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/runner.py b/libs/langgraph/langgraph/pregel/runner.py index 02e97a8e0..af80ed715 100644 --- a/libs/langgraph/langgraph/pregel/runner.py +++ b/libs/langgraph/langgraph/pregel/runner.py @@ -415,7 +415,9 @@ def call( # otherwise return a chained sync future if in_async: if isinstance(fut, asyncio.Task): - sfut = asyncio.Future(loop=loop) + sfut: Union[asyncio.Future[Any], concurrent.futures.Future[Any]] = ( + asyncio.Future(loop=loop) + ) loop.call_soon_threadsafe(chain_future, fut, sfut) return sfut else: diff --git a/libs/langgraph/langgraph/utils/future.py b/libs/langgraph/langgraph/utils/future.py index 50fc1d0eb..8445f5396 100644 --- a/libs/langgraph/langgraph/utils/future.py +++ b/libs/langgraph/langgraph/utils/future.py @@ -4,7 +4,7 @@ import inspect import sys import types -from typing import Coroutine, Optional, TypeVar, Union +from typing import Awaitable, Coroutine, Generator, Optional, TypeVar, Union, cast T = TypeVar("T") AnyFuture = Union[asyncio.Future, concurrent.futures.Future] @@ -137,7 +137,7 @@ def chain_future(source: AnyFuture, destination: AnyFuture) -> AnyFuture: def _ensure_future( - coro_or_future: Coroutine[None, None, T], + coro_or_future: Union[Coroutine[None, None, T], Awaitable[T]], *, loop: asyncio.AbstractEventLoop, name: Optional[str] = None, @@ -146,7 +146,9 @@ def _ensure_future( called_wrap_awaitable = False if not asyncio.iscoroutine(coro_or_future): if inspect.isawaitable(coro_or_future): - coro_or_future = _wrap_awaitable(coro_or_future) + coro_or_future = cast( + Coroutine[None, None, T], _wrap_awaitable(coro_or_future) + ) called_wrap_awaitable = True else: raise TypeError( @@ -165,7 +167,7 @@ def _ensure_future( @types.coroutine -def _wrap_awaitable(awaitable): +def _wrap_awaitable(awaitable: Awaitable[T]) -> Generator[None, None, T]: """Helper for asyncio.ensure_future(). Wraps awaitable (an object with __await__) into a coroutine @@ -184,9 +186,9 @@ def run_coroutine_threadsafe( Return a asyncio.Future to access the result. """ - future = asyncio.Future(loop=loop) + future: asyncio.Future[T] = asyncio.Future(loop=loop) - def callback(): + def callback() -> None: try: chain_future( _ensure_future(coro, loop=loop, name=name, context=context), future From c7c62f558759ce608062e40a099ea217b4db98ed Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 21 Jan 2025 09:42:25 -0800 Subject: [PATCH 3/5] Fix --- libs/langgraph/langgraph/pregel/runner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/runner.py b/libs/langgraph/langgraph/pregel/runner.py index af80ed715..1651b0f4f 100644 --- a/libs/langgraph/langgraph/pregel/runner.py +++ b/libs/langgraph/langgraph/pregel/runner.py @@ -407,8 +407,7 @@ def call( # return a chained future to ensure commit() callback is called # before the returned future is resolved, to ensure stream order etc try: - asyncio.current_task() - in_async = True + in_async = asyncio.current_task() is not None except RuntimeError: in_async = False # if in async context return an async future From fca0d2d5bb9a5f3d7f255015b083662c8c5411fd Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 21 Jan 2025 09:43:12 -0800 Subject: [PATCH 4/5] Update libs/langgraph/langgraph/utils/future.py Co-authored-by: William FH <13333726+hinthornw@users.noreply.github.com> --- libs/langgraph/langgraph/utils/future.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/langgraph/langgraph/utils/future.py b/libs/langgraph/langgraph/utils/future.py index 8445f5396..b7790f48c 100644 --- a/libs/langgraph/langgraph/utils/future.py +++ b/libs/langgraph/langgraph/utils/future.py @@ -152,7 +152,8 @@ def _ensure_future( called_wrap_awaitable = True else: raise TypeError( - "An asyncio.Future, a coroutine or an awaitable is required" + "An asyncio.Future, a coroutine or an awaitable is required." + f" Got {type(future).__name__} instead." ) try: From 2cafb4905b6401d247cbef0eb337863bcf74d129 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 21 Jan 2025 10:06:02 -0800 Subject: [PATCH 5/5] Lint --- libs/langgraph/langgraph/utils/future.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langgraph/langgraph/utils/future.py b/libs/langgraph/langgraph/utils/future.py index b7790f48c..e61afec91 100644 --- a/libs/langgraph/langgraph/utils/future.py +++ b/libs/langgraph/langgraph/utils/future.py @@ -153,7 +153,7 @@ def _ensure_future( else: raise TypeError( "An asyncio.Future, a coroutine or an awaitable is required." - f" Got {type(future).__name__} instead." + f" Got {type(coro_or_future).__name__} instead." ) try: