Skip to content

Commit

Permalink
langgraph: handle node return annotations with unions (#3170)
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda authored Jan 23, 2025
1 parent 3955225 commit 1059ef5
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,14 +379,27 @@ def add_node(
if input_hint := hints.get(first_parameter_name):
if isinstance(input_hint, type) and get_type_hints(input_hint):
input = input_hint
if (
(rtn := hints.get("return"))
and get_origin(rtn) is Command
and (rargs := get_args(rtn))
and get_origin(rargs[0]) is Literal
and (vals := get_args(rargs[0]))
):
ends = vals
if rtn := hints.get("return"):
# Handle Union types
rtn_origin = get_origin(rtn)
if rtn_origin is Union:
rtn_args = get_args(rtn)
# Look for Command in the union
for arg in rtn_args:
arg_origin = get_origin(arg)
if arg_origin is Command:
rtn = arg
rtn_origin = arg_origin
break

# Check if it's a Command type
if (
rtn_origin is Command
and (rargs := get_args(rtn))
and get_origin(rargs[0]) is Literal
and (vals := get_args(rargs[0]))
):
ends = vals
except (TypeError, StopIteration):
pass
if input is not None:
Expand Down

0 comments on commit 1059ef5

Please sign in to comment.