Skip to content

Commit

Permalink
Remove send v2 (#3033)
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos authored Jan 15, 2025
2 parents b18d266 + 0b74e25 commit c7e43f8
Show file tree
Hide file tree
Showing 13 changed files with 552 additions and 1,289 deletions.
3 changes: 0 additions & 3 deletions libs/langgraph/langgraph/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import sys
from os import getenv
from types import MappingProxyType
from typing import Any, Literal, Mapping, cast

Expand Down Expand Up @@ -93,8 +92,6 @@
# for checkpoint_ns, for each level, separates the namespace from the task_id
CONF = cast(Literal["configurable"], sys.intern("configurable"))
# key for the configurable dict in RunnableConfig
FF_SEND_V2 = getenv("LANGGRAPH_FF_SEND_V2", "false").lower() == "true"
# temporary flag to enable new Send semantics
NULL_TASK_ID = sys.intern("00000000-0000-0000-0000-000000000000")
# the task_id to use for writes that are not associated with a task

Expand Down
6 changes: 4 additions & 2 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,7 +966,7 @@ def update_state(
return patch_checkpoint_map(
next_config, saved.metadata if saved else None
)
# no values, copy checkpoint
# no values, empty checkpoint
if values is None and as_node is None:
next_checkpoint = create_checkpoint(checkpoint, None, step)
# copy checkpoint
Expand All @@ -985,6 +985,7 @@ def update_state(
return patch_checkpoint_map(
next_config, saved.metadata if saved else None
)
# no values, copy checkpoint
if values is None and as_node == "__copy__":
next_checkpoint = create_checkpoint(checkpoint, None, step)
# copy checkpoint
Expand Down Expand Up @@ -1248,7 +1249,7 @@ async def aupdate_state(
return patch_checkpoint_map(
next_config, saved.metadata if saved else None
)
# no values, copy checkpoint
# no values, empty checkpoint
if values is None and as_node is None:
next_checkpoint = create_checkpoint(checkpoint, None, step)
# copy checkpoint
Expand All @@ -1267,6 +1268,7 @@ async def aupdate_state(
return patch_checkpoint_map(
next_config, saved.metadata if saved else None
)
# no values, copy checkpoint
if values is None and as_node == "__copy__":
next_checkpoint = create_checkpoint(checkpoint, None, step)
# copy checkpoint
Expand Down
113 changes: 10 additions & 103 deletions libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def apply_writes(
# sort tasks on path, to ensure deterministic order for update application
# any path parts after the 3rd are ignored for sorting
# (we use them for eg. task ids which aren't good for sorting)
tasks = sorted(tasks, key=lambda t: t.path[:3])
tasks = sorted(tasks, key=lambda t: _tuple_str(t.path[:3]))
# if no task has triggers this is applying writes from the null task only
# so we don't do anything other than update the channels written to
bump_step = any(t.triggers for t in tasks)
Expand Down Expand Up @@ -273,7 +273,7 @@ def apply_writes(
for chan, val in task.writes:
if chan in (NO_WRITES, PUSH, RESUME, INTERRUPT, RETURN, ERROR):
pass
elif chan == TASKS: # TODO: remove branch in 1.0
elif chan == TASKS:
checkpoint["pending_sends"].append(val)
elif chan in channels:
pending_writes_by_channel[chan].append(val)
Expand Down Expand Up @@ -363,8 +363,8 @@ def prepare_next_tasks(
This is the union of all PUSH tasks (Sends) and PULL tasks (nodes triggered
by edges)."""
tasks: list[Union[PregelTask, PregelExecutableTask]] = []
# Consume pending_sends from previous step (legacy version of Send)
for idx, _ in enumerate(checkpoint["pending_sends"]): # TODO: remove branch in 1.0
# Consume pending_sends from previous step
for idx, _ in enumerate(checkpoint["pending_sends"]):
if task := prepare_single_task(
(PUSH, idx),
None,
Expand Down Expand Up @@ -400,65 +400,7 @@ def prepare_next_tasks(
manager=manager,
):
tasks.append(task)
# Consume pending Sends from this step (new version of Send)
if any(c == PUSH for _, c, _ in pending_writes):
# group writes by task id
grouped_by_task = defaultdict(list)
for tid, c, _ in pending_writes:
grouped_by_task[tid].append(c)
# prepare send tasks from grouped writes
# 1. start from sends originating from existing tasks
tidx = 0
while tidx < len(tasks):
task = tasks[tidx]
if twrites := grouped_by_task.pop(task.id, None):
for idx, c in enumerate(twrites):
if c != PUSH:
continue
if next_task := prepare_single_task(
(PUSH, task.path, idx, task.id),
None,
checkpoint=checkpoint,
pending_writes=pending_writes,
processes=processes,
channels=channels,
managed=managed,
config=config,
step=step,
for_execution=for_execution,
store=store,
checkpointer=checkpointer,
manager=manager,
):
tasks.append(next_task)
tidx += 1
# key tasks by id
task_map = {t.id: t for t in tasks}
# 2. create new tasks for remaining sends (eg. from update_state)
for tid, writes in grouped_by_task.items():
task = task_map.get(tid)
for idx, c in enumerate(writes):
if c != PUSH:
continue
if next_task := prepare_single_task(
(PUSH, task.path if task else (), idx, tid),
None,
checkpoint=checkpoint,
pending_writes=pending_writes,
processes=processes,
channels=channels,
managed=managed,
config=config,
step=step,
for_execution=for_execution,
store=store,
checkpointer=checkpointer,
manager=manager,
):
task_map[next_task.id] = next_task
else:
task_map = {t.id: t for t in tasks}
return task_map
return {t.id: t for t in tasks}


def prepare_single_task(
Expand Down Expand Up @@ -571,8 +513,8 @@ def prepare_single_task(
else:
return PregelTask(task_id, name, task_path[:3])
elif task_path[0] == PUSH:
if len(task_path) == 2: # TODO: remove branch in 1.0
# legacy SEND tasks, executed in superstep n+1
if len(task_path) == 2:
# SEND tasks, executed in superstep n+1
# (PUSH, idx of pending send)
idx = cast(int, task_path[1])
if idx >= len(checkpoint["pending_sends"]):
Expand Down Expand Up @@ -601,43 +543,6 @@ def prepare_single_task(
PUSH,
str(idx),
)
elif len(task_path) >= 4:
# new PUSH tasks, executed in superstep n
# (PUSH, parent task path, idx of PUSH write, id of parent task)
task_path_tt = cast(tuple[str, tuple, int, str], task_path)
writes_for_path = [w for w in pending_writes if w[0] == task_path_tt[3]]
if task_path_tt[2] >= len(writes_for_path):
logger.warning(
f"Ignoring invalid write index {task_path[2]} in pending writes"
)
return
packet = writes_for_path[task_path_tt[2]][2]
if packet is None:
return
if not isinstance(packet, Send):
logger.warning(
f"Ignoring invalid packet type {type(packet)} in pending writes"
)
return
if packet.node not in processes:
logger.warning(
f"Ignoring unknown node name {packet.node} in pending writes"
)
return
# create task id
triggers = [PUSH]
checkpoint_ns = (
f"{parent_ns}{NS_SEP}{packet.node}" if parent_ns else packet.node
)
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
packet.node,
PUSH,
_tuple_str(task_path[1]),
str(task_path[2]),
)
else:
logger.warning(f"Ignoring invalid PUSH task path {task_path}")
return
Expand Down Expand Up @@ -904,7 +809,9 @@ def _uuid5_str(namespace: bytes, *parts: str) -> str:
def _tuple_str(tup: Union[str, int, tuple]) -> str:
"""Generate a string representation of a tuple."""
return (
f"({', '.join(_tuple_str(x) for x in tup)})"
f"~{', '.join(_tuple_str(x) for x in tup)}"
if isinstance(tup, (tuple, list))
else f"{tup:010d}"
if isinstance(tup, int)
else str(tup)
)
4 changes: 1 addition & 3 deletions libs/langgraph/langgraph/pregel/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
from langgraph.constants import (
EMPTY_SEQ,
ERROR,
FF_SEND_V2,
INTERRUPT,
NULL_TASK_ID,
PUSH,
RESUME,
RETURN,
SELF,
Expand Down Expand Up @@ -83,7 +81,7 @@ def map_command(
sends = [cmd.goto]
for send in sends:
if isinstance(send, Send):
yield (NULL_TASK_ID, PUSH if FF_SEND_V2 else TASKS, send)
yield (NULL_TASK_ID, TASKS, send)
elif isinstance(send, str):
yield (NULL_TASK_ID, f"branch:{START}:{SELF}:{send}", START)
else:
Expand Down
8 changes: 5 additions & 3 deletions libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,9 +699,6 @@ def _suppress_interrupt(
traceback: Optional[TracebackType],
) -> Optional[bool]:
suppress = isinstance(exc_value, GraphInterrupt) and not self.is_nested
if suppress or exc_type is None:
# save final output
self.output = read_channels(self.channels, self.output_keys)
if suppress:
# emit one last "values" event, with pending writes applied
if (
Expand Down Expand Up @@ -729,8 +726,13 @@ def _suppress_interrupt(
"updates",
lambda: iter([{INTERRUPT: cast(GraphInterrupt, exc_value).args[0]}]),
)
# save final output
self.output = read_channels(self.channels, self.output_keys)
# suppress interrupt
return True
elif exc_type is None:
# save final output
self.output = read_channels(self.channels, self.output_keys)

def _emit(
self,
Expand Down
6 changes: 3 additions & 3 deletions libs/langgraph/langgraph/pregel/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables.utils import ConfigurableFieldSpec

from langgraph.constants import CONF, CONFIG_KEY_SEND, FF_SEND_V2, PUSH, TASKS, Send
from langgraph.constants import CONF, CONFIG_KEY_SEND, TASKS, Send
from langgraph.errors import InvalidUpdateError
from langgraph.utils.runnable import RunnableCallable

Expand Down Expand Up @@ -125,7 +125,7 @@ def do_write(
# validate
for w in writes:
if isinstance(w, ChannelWriteEntry):
if w.channel in (TASKS, PUSH):
if w.channel == TASKS:
raise InvalidUpdateError(
"Cannot write to the reserved channel TASKS"
)
Expand All @@ -138,7 +138,7 @@ def do_write(
tuples: list[tuple[str, Any]] = []
for w in writes:
if isinstance(w, Send):
tuples.append((PUSH if FF_SEND_V2 else TASKS, w))
tuples.append((TASKS, w))
elif isinstance(w, ChannelWriteTupleEntry):
if ww := w.mapper(w.value):
tuples.extend(ww)
Expand Down
26 changes: 25 additions & 1 deletion libs/langgraph/tests/test_algo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from langgraph.checkpoint.base import empty_checkpoint
from langgraph.pregel.algo import prepare_next_tasks
from langgraph.constants import PULL, PUSH
from langgraph.pregel.algo import _tuple_str, prepare_next_tasks
from langgraph.pregel.manager import ChannelsManager


Expand Down Expand Up @@ -40,3 +41,26 @@ def test_prepare_next_tasks() -> None:
)

# TODO: add more tests


def test_tuple_str() -> None:
push_path_a = (PUSH, 2)
pull_path_a = (PULL, "abc")
push_path_b = (PUSH, push_path_a, 1)
push_path_c = (PUSH, push_path_b, 3)

assert _tuple_str(push_path_a) == f"~{PUSH}, 0000000002"
assert _tuple_str(push_path_b) == f"~{PUSH}, ~{PUSH}, 0000000002, 0000000001"
assert (
_tuple_str(push_path_c)
== f"~{PUSH}, ~{PUSH}, ~{PUSH}, 0000000002, 0000000001, 0000000003"
)
assert _tuple_str(pull_path_a) == f"~{PULL}, abc"

path_list = [push_path_b, push_path_a, pull_path_a, push_path_c]
assert sorted(map(_tuple_str, path_list)) == [
f"~{PULL}, abc",
f"~{PUSH}, 0000000002",
f"~{PUSH}, ~{PUSH}, 0000000002, 0000000001",
f"~{PUSH}, ~{PUSH}, ~{PUSH}, 0000000002, 0000000001, 0000000003",
]
Loading

0 comments on commit c7e43f8

Please sign in to comment.