diff --git a/libs/genai/langchain_google_genai/_function_utils.py b/libs/genai/langchain_google_genai/_function_utils.py index 5ac6c52c..6ce80b06 100644 --- a/libs/genai/langchain_google_genai/_function_utils.py +++ b/libs/genai/langchain_google_genai/_function_utils.py @@ -269,13 +269,12 @@ def _convert_pydantic_to_genai_function( name=tool_name if tool_name else schema.get("title"), description=tool_description if tool_description else schema.get("description"), parameters={ - "properties": { - k: { - "type_": _get_type_from_schema(v), - "description": v.get("description"), - } - for k, v in schema["properties"].items() - }, + "properties": _get_properties_from_schema_any( + schema.get("properties") + ), # TODO: use _dict_to_gapic_schema() if possible + # "items": _get_items_from_schema_any( + # schema + # ), # TODO: fix it https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/function-calling?hl#schema "required": schema.get("required", []), "type_": TYPE_ENUM[schema["type"]], }, @@ -283,6 +282,84 @@ def _convert_pydantic_to_genai_function( return function_declaration +def _get_properties_from_schema_any(schema: Any) -> Dict[str, Any]: + if isinstance(schema, Dict): + return _get_properties_from_schema(schema) + return {} + + +def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]: + properties = {} + for k, v in schema.items(): + if not isinstance(k, str): + logger.warning(f"Key '{k}' is not supported in schema, type={type(k)}") + continue + if not isinstance(v, Dict): + logger.warning(f"Value '{v}' is not supported in schema, ignoring v={v}") + continue + properties_item: Dict[str, Union[str, int, Dict, List]] = {} + if v.get("type"): + properties_item["type_"] = _get_type_from_schema(v) + + if v.get("enum"): + properties_item["enum"] = v["enum"] + + description = v.get("description") + if description and isinstance(description, str): + properties_item["description"] = description + + if v.get("type") == "array" and v.get("items"): + properties_item["items"] = _get_items_from_schema_any(v.get("items")) + + if v.get("type") == "object" and v.get("properties"): + properties_item["properties"] = _get_properties_from_schema_any( + v.get("properties") + ) + if k == "title" and "description" not in properties_item: + properties_item["description"] = k + " is " + str(v) + + properties[k] = properties_item + + return properties + + +def _get_items_from_schema_any(schema: Any) -> Dict[str, Any]: + if isinstance(schema, Dict): + return _get_items_from_schema(schema) + if isinstance(schema, List): + return _get_items_from_schema(schema) + if isinstance(schema, str): + return _get_items_from_schema(schema) + return {} + + +def _get_items_from_schema(schema: Union[Dict, List, str]) -> Dict[str, Any]: + items: Dict = {} + if isinstance(schema, List): + for i, v in enumerate(schema): + items[f"item{i}"] = _get_properties_from_schema_any(v) + elif isinstance(schema, Dict): + item: Dict = {} + for k, v in schema.items(): + item["type_"] = _get_type_from_schema(v) + if not isinstance(v, Dict): + logger.warning( + f"Value '{v}' is not supported in schema, ignoring v={v}" + ) + continue + if v.get("type") == "object" and v.get("properties"): + item["properties"] = _get_properties_from_schema_any( + v.get("properties") + ) + if k == "title" and "description" not in item: + item["description"] = v + items = item + else: + # str + items["type_"] = _get_type_from_str(str(schema)) + return items + + def _get_type_from_schema(schema: Dict[str, Any]) -> int: if "anyOf" in schema: types = [_get_type_from_schema(sub_schema) for sub_schema in schema["anyOf"]] @@ -293,15 +370,18 @@ def _get_type_from_schema(schema: Dict[str, Any]) -> int: pass elif "type" in schema: stype = str(schema["type"]) - if stype in TYPE_ENUM: - return TYPE_ENUM[stype] - else: - pass + return _get_type_from_str(stype) else: pass return TYPE_ENUM["string"] # Default to string if no valid types found +def _get_type_from_str(stype: str) -> int: + if stype in TYPE_ENUM: + return TYPE_ENUM[stype] + return TYPE_ENUM["string"] # Default to string if no valid types found + + _ToolChoiceType = Union[ dict, List[str], str, Literal["auto", "none", "any"], Literal[True] ] diff --git a/libs/genai/tests/integration_tests/test_chat_models.py b/libs/genai/tests/integration_tests/test_chat_models.py index dc62d09a..056b71ef 100644 --- a/libs/genai/tests/integration_tests/test_chat_models.py +++ b/libs/genai/tests/integration_tests/test_chat_models.py @@ -335,21 +335,30 @@ def _check_tool_calls(response: BaseMessage, expected_name: str) -> None: assert isinstance(response, AIMessage) assert isinstance(response.content, str) assert response.content == "" + + # function_call function_call = response.additional_kwargs.get("function_call") assert function_call assert function_call["name"] == expected_name arguments_str = function_call.get("arguments") assert arguments_str arguments = json.loads(arguments_str) - assert arguments == { - "name": "Erick", - "age": 27.0, - } + _check_tool_call_args(arguments) + + # tool_calls tool_calls = response.tool_calls assert len(tool_calls) == 1 tool_call = tool_calls[0] assert tool_call["name"] == expected_name - assert tool_call["args"] == {"age": 27.0, "name": "Erick"} + _check_tool_call_args(tool_call["args"]) + + +def _check_tool_call_args(tool_call_args: dict) -> None: + assert tool_call_args == { + "age": 27.0, + "name": "Erick", + "likes": ["apple", "banana"], + } @pytest.mark.extended @@ -357,21 +366,25 @@ def test_chat_vertexai_gemini_function_calling() -> None: class MyModel(BaseModel): name: str age: int + likes: list[str] safety: Dict[HarmCategory, HarmBlockThreshold] = { HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH # type: ignore[dict-item] } # Test .bind_tools with BaseModel - message = HumanMessage(content="My name is Erick and I am 27 years old") + message = HumanMessage( + content="My name is Erick and I am 27 years old. I like apple and banana." + ) model = ChatGoogleGenerativeAI(model=_MODEL, safety_settings=safety).bind_tools( [MyModel] ) response = model.invoke([message]) + print("response=", response) _check_tool_calls(response, "MyModel") # Test .bind_tools with function - def my_model(name: str, age: int) -> None: - """Invoke this with names and ages.""" + def my_model(name: str, age: int, likes: list[str]) -> None: + """Invoke this with names and age and likes.""" pass model = ChatGoogleGenerativeAI(model=_MODEL, safety_settings=safety).bind_tools( @@ -382,8 +395,8 @@ def my_model(name: str, age: int) -> None: # Test .bind_tools with tool @tool - def my_tool(name: str, age: int) -> None: - """Invoke this with names and ages.""" + def my_tool(name: str, age: int, likes: list[str]) -> None: + """Invoke this with names and age and likes.""" pass model = ChatGoogleGenerativeAI(model=_MODEL, safety_settings=safety).bind_tools( @@ -405,7 +418,9 @@ def my_tool(name: str, age: int) -> None: assert len(gathered.tool_call_chunks) == 1 tool_call_chunk = gathered.tool_call_chunks[0] assert tool_call_chunk["name"] == "my_tool" - assert tool_call_chunk["args"] == '{"age": 27.0, "name": "Erick"}' + arguments_str = tool_call_chunk["args"] + arguments = json.loads(str(arguments_str)) + _check_tool_call_args(arguments) # Test with model that supports tool choice (gemini 1.5) and one that doesn't diff --git a/libs/genai/tests/unit_tests/test_function_utils.py b/libs/genai/tests/unit_tests/test_function_utils.py index 3e8a5e88..e1508d46 100644 --- a/libs/genai/tests/unit_tests/test_function_utils.py +++ b/libs/genai/tests/unit_tests/test_function_utils.py @@ -136,6 +136,7 @@ def test_tool_to_dict_pydantic() -> None: class MyModel(BaseModel): name: str age: int + likes: list[str] tool = convert_to_genai_function_declarations([MyModel]) tool_dict = tool_to_dict(tool)