Skip to content

Commit

Permalink
langgraph: more unittests for functional API (#3221)
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda authored Jan 28, 2025
1 parent 69dc5f3 commit 7568862
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 6 deletions.
79 changes: 75 additions & 4 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5950,7 +5950,7 @@ def test_multiple_subgraphs_functional(

# Define addition subgraph
@entrypoint()
def add(inputs):
def add(inputs: tuple[int, int]):
a, b = inputs
return a + b

Expand All @@ -5960,7 +5960,7 @@ def multiply_task(a, b):
return a * b

@entrypoint()
def multiply(inputs):
def multiply(inputs: tuple[int, int]):
return multiply_task(*inputs).result()

# Test calling the same subgraph multiple times
Expand Down Expand Up @@ -5993,9 +5993,10 @@ def parent_call_multiple_subgraphs(inputs):


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_multiple_subgraphs_mixed(
def test_multiple_subgraphs_mixed_entrypoint(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
"""Test calling multiple StateGraph subgraphs from an entrypoint."""
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")

class State(TypedDict):
Expand Down Expand Up @@ -6053,7 +6054,77 @@ def parent_call_multiple_subgraphs(inputs):


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_multiple_subgraphs_mixed_checkpointer(
def test_multiple_subgraphs_mixed_state_graph(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
"""Test calling multiple entrypoint "subgraphs" from a StateGraph."""
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")

class State(TypedDict):
a: int
b: int

class Output(TypedDict):
result: int

# Define addition subgraph
@entrypoint()
def add(inputs: tuple[int, int]):
a, b = inputs
return a + b

# Define multiplication subgraph using tasks
@task
def multiply_task(a, b):
return a * b

@entrypoint()
def multiply(inputs: tuple[int, int]):
return multiply_task(*inputs).result()

# Test calling the same subgraph multiple times
def call_same_subgraph(state):
result = add.invoke([state["a"], state["b"]])
another_result = add.invoke([result, 10])
return {"result": another_result}

parent_call_same_subgraph = (
StateGraph(State, output=Output)
.add_node(call_same_subgraph)
.add_edge(START, "call_same_subgraph")
.compile(checkpointer=checkpointer)
)
config = {"configurable": {"thread_id": "1"}}
assert parent_call_same_subgraph.invoke({"a": 2, "b": 3}, config) == {"result": 15}

# Test calling multiple subgraphs
class Output(TypedDict):
add_result: int
multiply_result: int

def call_multiple_subgraphs(state):
add_result = add.invoke([state["a"], state["b"]])
multiply_result = multiply.invoke([state["a"], state["b"]])
return {
"add_result": add_result,
"multiply_result": multiply_result,
}

parent_call_multiple_subgraphs = (
StateGraph(State, output=Output)
.add_node(call_multiple_subgraphs)
.add_edge(START, "call_multiple_subgraphs")
.compile(checkpointer=checkpointer)
)
config = {"configurable": {"thread_id": "2"}}
assert parent_call_multiple_subgraphs.invoke({"a": 2, "b": 3}, config) == {
"add_result": 5,
"multiply_result": 6,
}


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_multiple_subgraphs_checkpointer(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")
Expand Down
81 changes: 79 additions & 2 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -7131,7 +7131,9 @@ async def parent_call_multiple_subgraphs(inputs):

@NEEDS_CONTEXTVARS
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_multiple_subgraphs_mixed(checkpointer_name: str) -> None:
async def test_multiple_subgraphs_mixed_entrypoint(checkpointer_name: str) -> None:
"""Test calling multiple StateGraph subgraphs from an entrypoint."""

class State(TypedDict):
a: int
b: int
Expand Down Expand Up @@ -7196,7 +7198,82 @@ async def parent_call_multiple_subgraphs(inputs):

@NEEDS_CONTEXTVARS
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_multiple_subgraphs_mixed_checkpointer(
async def test_multiple_subgraphs_mixed_state_graph(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
"""Test calling multiple entrypoint "subgraphs" from a StateGraph."""
async with awith_checkpointer(checkpointer_name) as checkpointer:

class State(TypedDict):
a: int
b: int

class Output(TypedDict):
result: int

# Define addition subgraph
@entrypoint()
async def add(inputs):
a, b = inputs
return a + b

# Define multiplication subgraph using tasks
@task
async def multiply_task(a, b):
return a * b

@entrypoint()
async def multiply(inputs):
return await multiply_task(*inputs)

# Test calling the same subgraph multiple times
async def call_same_subgraph(state):
result = await add.ainvoke([state["a"], state["b"]])
another_result = await add.ainvoke([result, 10])
return {"result": another_result}

parent_call_same_subgraph = (
StateGraph(State, output=Output)
.add_node(call_same_subgraph)
.add_edge(START, "call_same_subgraph")
.compile(checkpointer=checkpointer)
)
config = {"configurable": {"thread_id": "1"}}
assert await parent_call_same_subgraph.ainvoke({"a": 2, "b": 3}, config) == {
"result": 15
}

# Test calling multiple subgraphs
class Output(TypedDict):
add_result: int
multiply_result: int

async def call_multiple_subgraphs(state):
add_result = await add.ainvoke([state["a"], state["b"]])
multiply_result = await multiply.ainvoke([state["a"], state["b"]])
return {
"add_result": add_result,
"multiply_result": multiply_result,
}

parent_call_multiple_subgraphs = (
StateGraph(State, output=Output)
.add_node(call_multiple_subgraphs)
.add_edge(START, "call_multiple_subgraphs")
.compile(checkpointer=checkpointer)
)
config = {"configurable": {"thread_id": "2"}}
assert await parent_call_multiple_subgraphs.ainvoke(
{"a": 2, "b": 3}, config
) == {
"add_result": 5,
"multiply_result": 6,
}


@NEEDS_CONTEXTVARS
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_multiple_subgraphs_checkpointer(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
async with awith_checkpointer(checkpointer_name) as checkpointer:
Expand Down

0 comments on commit 7568862

Please sign in to comment.