From 1059ef55d15ba8f73b40f34bafdea0630f3382c3 Mon Sep 17 00:00:00 2001 From: Vadym Barda Date: Thu, 23 Jan 2025 14:41:22 -0500 Subject: [PATCH] langgraph: handle node return annotations with unions (#3170) --- libs/langgraph/langgraph/graph/state.py | 29 ++++++++++++++++++------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/libs/langgraph/langgraph/graph/state.py b/libs/langgraph/langgraph/graph/state.py index 16bbf7d48..51b12c08e 100644 --- a/libs/langgraph/langgraph/graph/state.py +++ b/libs/langgraph/langgraph/graph/state.py @@ -379,14 +379,27 @@ def add_node( if input_hint := hints.get(first_parameter_name): if isinstance(input_hint, type) and get_type_hints(input_hint): input = input_hint - if ( - (rtn := hints.get("return")) - and get_origin(rtn) is Command - and (rargs := get_args(rtn)) - and get_origin(rargs[0]) is Literal - and (vals := get_args(rargs[0])) - ): - ends = vals + if rtn := hints.get("return"): + # Handle Union types + rtn_origin = get_origin(rtn) + if rtn_origin is Union: + rtn_args = get_args(rtn) + # Look for Command in the union + for arg in rtn_args: + arg_origin = get_origin(arg) + if arg_origin is Command: + rtn = arg + rtn_origin = arg_origin + break + + # Check if it's a Command type + if ( + rtn_origin is Command + and (rargs := get_args(rtn)) + and get_origin(rargs[0]) is Literal + and (vals := get_args(rargs[0])) + ): + ends = vals except (TypeError, StopIteration): pass if input is not None: