Skip to content

Commit

Permalink
Fix tracing hierarchy for imperative api (#3036)
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos authored Jan 15, 2025
2 parents e2554c9 + 4426552 commit d12830e
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 36 deletions.
17 changes: 1 addition & 16 deletions .github/workflows/_test_langgraph.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,11 @@ jobs:
- "3.11"
- "3.12"
- "3.13"
core-version:
- "latest"
ff-send-v2:
- "false"
include:
- python-version: "3.11"
core-version: ">=0.2.42,<0.3.0"
- python-version: "3.11"
core-version: "latest"
ff-send-v2: "true"

defaults:
run:
working-directory: libs/langgraph
name: "test #${{ matrix.python-version }} (langchain-core: ${{ matrix.core-version }}, ff-send-v2: ${{ matrix.ff-send-v2 }})"
name: "test #${{ matrix.python-version }}"
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
Expand All @@ -51,14 +41,9 @@ jobs:
shell: bash
run: |
poetry install --with dev
if [ "${{ matrix.core-version }}" != "latest" ]; then
poetry run pip install "langchain-core${{ matrix.core-version }}"
fi
- name: Run tests
shell: bash
env:
LANGGRAPH_FF_SEND_V2: ${{ matrix.ff-send-v2 }}
run: |
make test_parallel
Expand Down
10 changes: 5 additions & 5 deletions libs/langgraph/langgraph/func/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.channels.last_value import LastValue
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.constants import END, START, TAG_HIDDEN
from langgraph.constants import CONF, END, START, TAG_HIDDEN
from langgraph.pregel import Pregel
from langgraph.pregel.call import get_runnable_for_func
from langgraph.pregel.read import PregelNode
Expand All @@ -39,11 +39,11 @@ def call(
**kwargs: Any,
) -> concurrent.futures.Future[T]:
from langgraph.constants import CONFIG_KEY_CALL
from langgraph.utils.config import get_configurable
from langgraph.utils.config import get_config

conf = get_configurable()
impl = conf[CONFIG_KEY_CALL]
fut = impl(func, (args, kwargs), retry=retry)
config = get_config()
impl = config[CONF][CONFIG_KEY_CALL]
fut = impl(func, (args, kwargs), retry=retry, callbacks=config["callbacks"])
return fut


Expand Down
17 changes: 12 additions & 5 deletions libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from uuid import UUID

from langchain_core.callbacks import Callbacks
from langchain_core.callbacks.manager import AsyncParentRunManager, ParentRunManager
from langchain_core.runnables.config import RunnableConfig

Expand Down Expand Up @@ -107,18 +108,25 @@ class PregelTaskWrites(NamedTuple):


class Call:
__slots__ = ("func", "input", "retry")
__slots__ = ("func", "input", "retry", "callbacks")

func: Callable
input: Any
retry: Optional[RetryPolicy]
callbacks: Callbacks

def __init__(
self, func: Callable, input: Any, *, retry: Optional[RetryPolicy]
self,
func: Callable,
input: Any,
*,
retry: Optional[RetryPolicy],
callbacks: Callbacks,
) -> None:
self.func = func
self.input = input
self.retry = retry
self.callbacks = callbacks


def should_interrupt(
Expand Down Expand Up @@ -465,9 +473,8 @@ def prepare_single_task(
patch_config(
merge_configs(config, {"metadata": metadata}),
run_name=name,
callbacks=(
manager.get_child(f"graph:step:{step}") if manager else None
),
callbacks=call.callbacks
or (manager.get_child(f"graph:step:{step}") if manager else None),
configurable={
CONFIG_KEY_TASK_ID: task_id,
# deque.extend is thread-safe
Expand Down
8 changes: 6 additions & 2 deletions libs/langgraph/langgraph/pregel/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,14 @@ def submit( # type: ignore[valid-type]
__next_tick__: bool = False,
**kwargs: P.kwargs,
) -> concurrent.futures.Future[T]:
ctx = copy_context()
if __next_tick__:
task = self.executor.submit(next_tick, fn, *args, **kwargs)
task = cast(
concurrent.futures.Future[T],
self.executor.submit(next_tick, ctx.run, fn, *args, **kwargs), # type: ignore[arg-type]
)
else:
task = self.executor.submit(fn, *args, **kwargs)
task = self.executor.submit(ctx.run, fn, *args, **kwargs)
self.tasks[task] = (__cancel_on_exit__, __reraise_on_exit__)
# add a callback to remove the task from the tasks dict when it's done
task.add_done_callback(self.done)
Expand Down
12 changes: 10 additions & 2 deletions libs/langgraph/langgraph/pregel/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
cast,
)

from langchain_core.callbacks import Callbacks

from langgraph.constants import (
CONF,
CONFIG_KEY_CALL,
Expand Down Expand Up @@ -148,9 +150,12 @@ def call(
input: Any,
*,
retry: Optional[RetryPolicy] = None,
callbacks: Callbacks = None,
) -> concurrent.futures.Future[Any]:
(fut,) = writer(
task, [(PUSH, None)], calls=[Call(func, input, retry=retry)]
task,
[(PUSH, None)],
calls=[Call(func, input, retry=retry, callbacks=callbacks)],
)
assert fut is not None, "writer did not return a future for call"
return fut
Expand Down Expand Up @@ -337,9 +342,12 @@ def call(
input: Any,
*,
retry: Optional[RetryPolicy] = None,
callbacks: Callbacks = None,
) -> Union[asyncio.Future[Any], concurrent.futures.Future[Any]]:
(fut,) = writer(
task, [(PUSH, None)], calls=[Call(func, input, retry=retry)]
task,
[(PUSH, None)],
calls=[Call(func, input, retry=retry, callbacks=callbacks)],
)
assert fut is not None, "writer did not return a future for call"
if asyncio.iscoroutinefunction(func):
Expand Down
4 changes: 2 additions & 2 deletions libs/langgraph/langgraph/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,9 +453,9 @@ def node(state: State):
RESUME,
)
from langgraph.errors import GraphInterrupt
from langgraph.utils.config import get_configurable
from langgraph.utils.config import get_config

conf = get_configurable()
conf = get_config()["configurable"]
# track interrupt index
scratchpad: PregelScratchpad = conf[CONFIG_KEY_SCRATCHPAD]
if "interrupt_counter" not in scratchpad:
Expand Down
6 changes: 3 additions & 3 deletions libs/langgraph/langgraph/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
def patch_config(
config: Optional[RunnableConfig],
*,
callbacks: Optional[Callbacks] = None,
callbacks: Callbacks = None,
recursion_limit: Optional[int] = None,
max_concurrency: Optional[int] = None,
run_name: Optional[str] = None,
Expand Down Expand Up @@ -304,7 +304,7 @@ def ensure_config(*configs: Optional[RunnableConfig]) -> RunnableConfig:
return empty


def get_configurable() -> dict[str, Any]:
def get_config() -> RunnableConfig:
if sys.version_info < (3, 11):
try:
if asyncio.current_task():
Expand All @@ -314,6 +314,6 @@ def get_configurable() -> dict[str, Any]:
except RuntimeError:
pass
if var_config := var_child_runnable_config.get():
return var_config[CONF]
return var_config
else:
raise RuntimeError("Called get_configurable outside of a runnable context")
7 changes: 6 additions & 1 deletion libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2456,6 +2456,7 @@ async def test_imp_task(checkpointer_name: str) -> None:
async def mapper(input: int) -> str:
nonlocal mapper_calls
mapper_calls += 1
await asyncio.sleep(0.1 * input)
return str(input) * 2

@entrypoint(checkpointer=checkpointer)
Expand All @@ -2465,7 +2466,8 @@ async def graph(input: list[int]) -> list[str]:
answer = interrupt("question")
return [m + answer for m in mapped]

thread1 = {"configurable": {"thread_id": "1"}}
tracer = FakeTracer()
thread1 = {"configurable": {"thread_id": "1"}, "callbacks": [tracer]}
assert [c async for c in graph.astream([0, 1], thread1)] == [
{"mapper": "00"},
{"mapper": "11"},
Expand All @@ -2481,6 +2483,9 @@ async def graph(input: list[int]) -> list[str]:
},
]
assert mapper_calls == 2
assert len(tracer.runs) == 1
assert len(tracer.runs[0].child_runs) == 1
assert tracer.runs[0].child_runs[0].name == "graph"

assert await graph.ainvoke(Command(resume="answer"), thread1) == [
"00answer",
Expand Down

0 comments on commit d12830e

Please sign in to comment.