From 7568862013524dad001cfd3a931b8a83a91e9cc5 Mon Sep 17 00:00:00 2001 From: Vadym Barda Date: Mon, 27 Jan 2025 19:41:15 -0500 Subject: [PATCH] langgraph: more unittests for functional API (#3221) --- libs/langgraph/tests/test_pregel.py | 79 ++++++++++++++++++++-- libs/langgraph/tests/test_pregel_async.py | 81 ++++++++++++++++++++++- 2 files changed, 154 insertions(+), 6 deletions(-) diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index ed1b649562..4ce2410863 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -5950,7 +5950,7 @@ def test_multiple_subgraphs_functional( # Define addition subgraph @entrypoint() - def add(inputs): + def add(inputs: tuple[int, int]): a, b = inputs return a + b @@ -5960,7 +5960,7 @@ def multiply_task(a, b): return a * b @entrypoint() - def multiply(inputs): + def multiply(inputs: tuple[int, int]): return multiply_task(*inputs).result() # Test calling the same subgraph multiple times @@ -5993,9 +5993,10 @@ def parent_call_multiple_subgraphs(inputs): @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) -def test_multiple_subgraphs_mixed( +def test_multiple_subgraphs_mixed_entrypoint( request: pytest.FixtureRequest, checkpointer_name: str ) -> None: + """Test calling multiple StateGraph subgraphs from an entrypoint.""" checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") class State(TypedDict): @@ -6053,7 +6054,77 @@ def parent_call_multiple_subgraphs(inputs): @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) -def test_multiple_subgraphs_mixed_checkpointer( +def test_multiple_subgraphs_mixed_state_graph( + request: pytest.FixtureRequest, checkpointer_name: str +) -> None: + """Test calling multiple entrypoint "subgraphs" from a StateGraph.""" + checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") + + class State(TypedDict): + a: int + b: int + + class Output(TypedDict): + result: int + + # Define addition subgraph + @entrypoint() + def add(inputs: tuple[int, int]): + a, b = inputs + return a + b + + # Define multiplication subgraph using tasks + @task + def multiply_task(a, b): + return a * b + + @entrypoint() + def multiply(inputs: tuple[int, int]): + return multiply_task(*inputs).result() + + # Test calling the same subgraph multiple times + def call_same_subgraph(state): + result = add.invoke([state["a"], state["b"]]) + another_result = add.invoke([result, 10]) + return {"result": another_result} + + parent_call_same_subgraph = ( + StateGraph(State, output=Output) + .add_node(call_same_subgraph) + .add_edge(START, "call_same_subgraph") + .compile(checkpointer=checkpointer) + ) + config = {"configurable": {"thread_id": "1"}} + assert parent_call_same_subgraph.invoke({"a": 2, "b": 3}, config) == {"result": 15} + + # Test calling multiple subgraphs + class Output(TypedDict): + add_result: int + multiply_result: int + + def call_multiple_subgraphs(state): + add_result = add.invoke([state["a"], state["b"]]) + multiply_result = multiply.invoke([state["a"], state["b"]]) + return { + "add_result": add_result, + "multiply_result": multiply_result, + } + + parent_call_multiple_subgraphs = ( + StateGraph(State, output=Output) + .add_node(call_multiple_subgraphs) + .add_edge(START, "call_multiple_subgraphs") + .compile(checkpointer=checkpointer) + ) + config = {"configurable": {"thread_id": "2"}} + assert parent_call_multiple_subgraphs.invoke({"a": 2, "b": 3}, config) == { + "add_result": 5, + "multiply_result": 6, + } + + +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) +def test_multiple_subgraphs_checkpointer( request: pytest.FixtureRequest, checkpointer_name: str ) -> None: checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 24e53345f3..85c2945b12 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -7131,7 +7131,9 @@ async def parent_call_multiple_subgraphs(inputs): @NEEDS_CONTEXTVARS @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) -async def test_multiple_subgraphs_mixed(checkpointer_name: str) -> None: +async def test_multiple_subgraphs_mixed_entrypoint(checkpointer_name: str) -> None: + """Test calling multiple StateGraph subgraphs from an entrypoint.""" + class State(TypedDict): a: int b: int @@ -7196,7 +7198,82 @@ async def parent_call_multiple_subgraphs(inputs): @NEEDS_CONTEXTVARS @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) -async def test_multiple_subgraphs_mixed_checkpointer( +async def test_multiple_subgraphs_mixed_state_graph( + request: pytest.FixtureRequest, checkpointer_name: str +) -> None: + """Test calling multiple entrypoint "subgraphs" from a StateGraph.""" + async with awith_checkpointer(checkpointer_name) as checkpointer: + + class State(TypedDict): + a: int + b: int + + class Output(TypedDict): + result: int + + # Define addition subgraph + @entrypoint() + async def add(inputs): + a, b = inputs + return a + b + + # Define multiplication subgraph using tasks + @task + async def multiply_task(a, b): + return a * b + + @entrypoint() + async def multiply(inputs): + return await multiply_task(*inputs) + + # Test calling the same subgraph multiple times + async def call_same_subgraph(state): + result = await add.ainvoke([state["a"], state["b"]]) + another_result = await add.ainvoke([result, 10]) + return {"result": another_result} + + parent_call_same_subgraph = ( + StateGraph(State, output=Output) + .add_node(call_same_subgraph) + .add_edge(START, "call_same_subgraph") + .compile(checkpointer=checkpointer) + ) + config = {"configurable": {"thread_id": "1"}} + assert await parent_call_same_subgraph.ainvoke({"a": 2, "b": 3}, config) == { + "result": 15 + } + + # Test calling multiple subgraphs + class Output(TypedDict): + add_result: int + multiply_result: int + + async def call_multiple_subgraphs(state): + add_result = await add.ainvoke([state["a"], state["b"]]) + multiply_result = await multiply.ainvoke([state["a"], state["b"]]) + return { + "add_result": add_result, + "multiply_result": multiply_result, + } + + parent_call_multiple_subgraphs = ( + StateGraph(State, output=Output) + .add_node(call_multiple_subgraphs) + .add_edge(START, "call_multiple_subgraphs") + .compile(checkpointer=checkpointer) + ) + config = {"configurable": {"thread_id": "2"}} + assert await parent_call_multiple_subgraphs.ainvoke( + {"a": 2, "b": 3}, config + ) == { + "add_result": 5, + "multiply_result": 6, + } + + +@NEEDS_CONTEXTVARS +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) +async def test_multiple_subgraphs_checkpointer( request: pytest.FixtureRequest, checkpointer_name: str ) -> None: async with awith_checkpointer(checkpointer_name) as checkpointer: