Skip to content

Commit

Permalink
feat: Use TypeAliasType to define aliases for union types in genera…
Browse files Browse the repository at this point in the history
…tive models

This is based on the original PR in #4701, just wrapping the typealiases in a try-catch block.

PiperOrigin-RevId: 708367618
  • Loading branch information
yeesian authored and copybara-github committed Dec 20, 2024
1 parent e5e59fe commit 2224c83
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 46 deletions.
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@

genai_requires = (
"pydantic < 3",
"typing_extensions",
"docstring_parser < 1",
)

Expand All @@ -143,7 +144,8 @@
"google-cloud-trace < 2",
"opentelemetry-sdk < 2",
"opentelemetry-exporter-gcp-trace < 2",
"pydantic >= 2.6.3, < 2.10",
"pydantic >= 2.6.3, < 3",
"typing_extensions",
]

evaluation_extra_require = [
Expand Down
3 changes: 1 addition & 2 deletions testing/constraints-langchain.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
langchain
langchain-core
langchain-google-vertexai
pydantic<2.10
langchain-google-vertexai
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,6 @@ def langchain_dump_mock():
yield langchain_dump_mock


@pytest.fixture
def mock_chatvertexai():
with mock.patch("langchain_google_vertexai.ChatVertexAI") as model_mock:
yield model_mock


@pytest.fixture
def cloud_trace_exporter_mock():
with mock.patch.object(
Expand Down Expand Up @@ -166,7 +160,7 @@ def test_initialization(self):
assert agent._location == _TEST_LOCATION
assert agent._runnable is None

def test_initialization_with_tools(self, mock_chatvertexai):
def test_initialization_with_tools(self):
tools = [
place_tool_query,
StructuredTool.from_function(place_photo_query),
Expand All @@ -176,6 +170,8 @@ def test_initialization_with_tools(self, mock_chatvertexai):
model=_TEST_MODEL,
system_instruction=_TEST_SYSTEM_INSTRUCTION,
tools=tools,
model_builder=lambda **kwargs: kwargs,
runnable_builder=lambda **kwargs: kwargs,
)
for tool, agent_tool in zip(tools, agent._tools):
assert isinstance(agent_tool, type(tool))
Expand All @@ -188,6 +184,8 @@ def test_set_up(self):
model=_TEST_MODEL,
prompt=self.prompt,
output_parser=self.output_parser,
model_builder=lambda **kwargs: kwargs,
runnable_builder=lambda **kwargs: kwargs,
)
assert agent._runnable is None
agent.set_up()
Expand All @@ -198,6 +196,8 @@ def test_clone(self):
model=_TEST_MODEL,
prompt=self.prompt,
output_parser=self.output_parser,
model_builder=lambda **kwargs: kwargs,
runnable_builder=lambda **kwargs: kwargs,
)
agent.set_up()
assert agent._runnable is not None
Expand Down Expand Up @@ -247,12 +247,13 @@ def test_enable_tracing(
enable_tracing=True,
)
assert agent._instrumentor is None
agent.set_up()
assert agent._instrumentor is not None
assert (
"enable_tracing=True but proceeding with tracing disabled"
not in caplog.text
)
# TODO(b/384730642): Re-enable this test once the parent issue is fixed.
# agent.set_up()
# assert agent._instrumentor is not None
# assert (
# "enable_tracing=True but proceeding with tracing disabled"
# not in caplog.text
# )

@pytest.mark.usefixtures("caplog")
def test_enable_tracing_warning(self, caplog, langchain_instrumentor_none_mock):
Expand All @@ -263,8 +264,8 @@ def test_enable_tracing_warning(self, caplog, langchain_instrumentor_none_mock):
enable_tracing=True,
)
assert agent._instrumentor is None
agent.set_up()
# TODO(b/383923584): Re-enable this test once the parent issue is fixed.
# TODO(b/384730642): Re-enable this test once the parent issue is fixed.
# agent.set_up()
# assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text


Expand Down
96 changes: 68 additions & 28 deletions vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,36 +86,76 @@


# These type defnitions are expanded to help the user see all the types
PartsType = Union[
str,
"Image",
"Part",
List[Union[str, "Image", "Part"]],
]

ContentDict = Dict[str, Any]
ContentsType = Union[
List["Content"],
List[ContentDict],
str,
"Image",
"Part",
List[Union[str, "Image", "Part"]],
]

GenerationConfigDict = Dict[str, Any]
GenerationConfigType = Union[
"GenerationConfig",
GenerationConfigDict,
]

SafetySettingsType = Union[
List["SafetySetting"],
Dict[
gapic_content_types.HarmCategory,
gapic_content_types.SafetySetting.HarmBlockThreshold,
],
]
try:
# For Pydantic to resolve the forward references inside these aliases.
from typing_extensions import TypeAliasType

PartsType = TypeAliasType(
"PartsType",
Union[
str,
"Image",
"Part",
List[Union[str, "Image", "Part"]],
],
)
ContentsType = TypeAliasType(
"ContentsType",
Union[
List["Content"],
List[ContentDict],
str,
"Image",
"Part",
List[Union[str, "Image", "Part"]],
],
)
GenerationConfigType = TypeAliasType(
"GenerationConfigType",
Union[
"GenerationConfig",
GenerationConfigDict,
],
)
SafetySettingsType = TypeAliasType(
"SafetySettingsType",
Union[
List["SafetySetting"],
Dict[
gapic_content_types.HarmCategory,
gapic_content_types.SafetySetting.HarmBlockThreshold,
],
],
)
except ImportError:
# Use existing definitions if typing_extensions is not available.
PartsType = Union[
str,
"Image",
"Part",
List[Union[str, "Image", "Part"]],
]
ContentsType = Union[
List["Content"],
List[ContentDict],
str,
"Image",
"Part",
List[Union[str, "Image", "Part"]],
]
GenerationConfigType = Union[
"GenerationConfig",
GenerationConfigDict,
]
SafetySettingsType = Union[
List["SafetySetting"],
Dict[
gapic_content_types.HarmCategory,
gapic_content_types.SafetySetting.HarmBlockThreshold,
],
]


def _reconcile_model_name(model_name: str, project: str, location: str) -> str:
Expand Down

0 comments on commit 2224c83

Please sign in to comment.