Skip to content

Commit

Permalink
genai: fix pydantic structured_output with array (#469)
Browse files Browse the repository at this point in the history
  • Loading branch information
nobu007 authored Oct 7, 2024
1 parent 693a1de commit 5594bc1
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 22 deletions.
96 changes: 85 additions & 11 deletions libs/genai/langchain_google_genai/_function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,20 +273,97 @@ def _convert_pydantic_to_genai_function(
name=tool_name if tool_name else schema.get("title"),
description=tool_description if tool_description else schema.get("description"),
parameters={
"properties": {
k: {
"type_": _get_type_from_schema(v),
"description": v.get("description"),
}
for k, v in schema["properties"].items()
},
"properties": _get_properties_from_schema_any(
schema.get("properties")
), # TODO: use _dict_to_gapic_schema() if possible
# "items": _get_items_from_schema_any(
# schema
# ), # TODO: fix it https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/function-calling?hl#schema
"required": schema.get("required", []),
"type_": TYPE_ENUM[schema["type"]],
},
)
return function_declaration


def _get_properties_from_schema_any(schema: Any) -> Dict[str, Any]:
if isinstance(schema, Dict):
return _get_properties_from_schema(schema)
return {}


def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]:
properties = {}
for k, v in schema.items():
if not isinstance(k, str):
logger.warning(f"Key '{k}' is not supported in schema, type={type(k)}")
continue
if not isinstance(v, Dict):
logger.warning(f"Value '{v}' is not supported in schema, ignoring v={v}")
continue
properties_item: Dict[str, Union[str, int, Dict, List]] = {}
if v.get("type") or v.get("anyOf"):
properties_item["type_"] = _get_type_from_schema(v)

if v.get("enum"):
properties_item["enum"] = v["enum"]

description = v.get("description")
if description and isinstance(description, str):
properties_item["description"] = description

if v.get("type") == "array" and v.get("items"):
properties_item["items"] = _get_items_from_schema_any(v.get("items"))

if v.get("type") == "object" and v.get("properties"):
properties_item["properties"] = _get_properties_from_schema_any(
v.get("properties")
)
if k == "title" and "description" not in properties_item:
properties_item["description"] = k + " is " + str(v)

properties[k] = properties_item

return properties


def _get_items_from_schema_any(schema: Any) -> Dict[str, Any]:
if isinstance(schema, Dict):
return _get_items_from_schema(schema)
if isinstance(schema, List):
return _get_items_from_schema(schema)
if isinstance(schema, str):
return _get_items_from_schema(schema)
return {}


def _get_items_from_schema(schema: Union[Dict, List, str]) -> Dict[str, Any]:
items: Dict = {}
if isinstance(schema, List):
for i, v in enumerate(schema):
items[f"item{i}"] = _get_properties_from_schema_any(v)
elif isinstance(schema, Dict):
item: Dict = {}
for k, v in schema.items():
item["type_"] = _get_type_from_schema(v)
if not isinstance(v, Dict):
logger.warning(
f"Value '{v}' is not supported in schema, ignoring v={v}"
)
continue
if v.get("type") == "object" and v.get("properties"):
item["properties"] = _get_properties_from_schema_any(
v.get("properties")
)
if k == "title" and "description" not in item:
item["description"] = v
items = item
else:
# str
items["type_"] = TYPE_ENUM.get(str(schema), glm.Type.STRING)
return items


def _get_type_from_schema(schema: Dict[str, Any]) -> int:
if "anyOf" in schema:
types = [_get_type_from_schema(sub_schema) for sub_schema in schema["anyOf"]]
Expand All @@ -297,10 +374,7 @@ def _get_type_from_schema(schema: Dict[str, Any]) -> int:
pass
elif "type" in schema:
stype = str(schema["type"])
if stype in TYPE_ENUM:
return TYPE_ENUM[stype]
else:
pass
return TYPE_ENUM.get(stype, glm.Type.STRING)
else:
pass
return TYPE_ENUM["string"] # Default to string if no valid types found
Expand Down
36 changes: 25 additions & 11 deletions libs/genai/tests/integration_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,43 +335,55 @@ def _check_tool_calls(response: BaseMessage, expected_name: str) -> None:
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert response.content == ""

# function_call
function_call = response.additional_kwargs.get("function_call")
assert function_call
assert function_call["name"] == expected_name
arguments_str = function_call.get("arguments")
assert arguments_str
arguments = json.loads(arguments_str)
assert arguments == {
"name": "Erick",
"age": 27.0,
}
_check_tool_call_args(arguments)

# tool_calls
tool_calls = response.tool_calls
assert len(tool_calls) == 1
tool_call = tool_calls[0]
assert tool_call["name"] == expected_name
assert tool_call["args"] == {"age": 27.0, "name": "Erick"}
_check_tool_call_args(tool_call["args"])


def _check_tool_call_args(tool_call_args: dict) -> None:
assert tool_call_args == {
"age": 27.0,
"name": "Erick",
"likes": ["apple", "banana"],
}


@pytest.mark.extended
def test_chat_vertexai_gemini_function_calling() -> None:
class MyModel(BaseModel):
name: str
age: int
likes: list[str]

safety: Dict[HarmCategory, HarmBlockThreshold] = {
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH # type: ignore[dict-item]
}
# Test .bind_tools with BaseModel
message = HumanMessage(content="My name is Erick and I am 27 years old")
message = HumanMessage(
content="My name is Erick and I am 27 years old. I like apple and banana."
)
model = ChatGoogleGenerativeAI(model=_MODEL, safety_settings=safety).bind_tools(
[MyModel]
)
response = model.invoke([message])
_check_tool_calls(response, "MyModel")

# Test .bind_tools with function
def my_model(name: str, age: int) -> None:
"""Invoke this with names and ages."""
def my_model(name: str, age: int, likes: list[str]) -> None:
"""Invoke this with names and age and likes."""
pass

model = ChatGoogleGenerativeAI(model=_MODEL, safety_settings=safety).bind_tools(
Expand All @@ -382,8 +394,8 @@ def my_model(name: str, age: int) -> None:

# Test .bind_tools with tool
@tool
def my_tool(name: str, age: int) -> None:
"""Invoke this with names and ages."""
def my_tool(name: str, age: int, likes: list[str]) -> None:
"""Invoke this with names and age and likes."""
pass

model = ChatGoogleGenerativeAI(model=_MODEL, safety_settings=safety).bind_tools(
Expand All @@ -405,7 +417,9 @@ def my_tool(name: str, age: int) -> None:
assert len(gathered.tool_call_chunks) == 1
tool_call_chunk = gathered.tool_call_chunks[0]
assert tool_call_chunk["name"] == "my_tool"
assert tool_call_chunk["args"] == '{"age": 27.0, "name": "Erick"}'
arguments_str = tool_call_chunk["args"]
arguments = json.loads(str(arguments_str))
_check_tool_call_args(arguments)


# Test with model that supports tool choice (gemini 1.5) and one that doesn't
Expand Down
1 change: 1 addition & 0 deletions libs/genai/tests/unit_tests/test_function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def test_tool_to_dict_pydantic() -> None:
class MyModel(BaseModel):
name: str
age: int
likes: list[str]

gapic_tool = convert_to_genai_function_declarations([MyModel])
tool_dict = tool_to_dict(gapic_tool)
Expand Down

0 comments on commit 5594bc1

Please sign in to comment.