Skip to content

Commit

Permalink
feat: GenAI - Grounding - Added grounding dynamic_retrieval config to…
Browse files Browse the repository at this point in the history
… Vertex SDK

PiperOrigin-RevId: 696776459
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Nov 15, 2024
1 parent 44587ec commit c39334a
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 5 deletions.
29 changes: 26 additions & 3 deletions tests/system/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,10 +427,33 @@ def test_generate_content_from_text_and_remote_audio(
assert api_transport in get_client_api_transport(pro_model._prediction_client)

def test_grounding_google_search_retriever(self, api_endpoint_env_name):
model = preview_generative_models.GenerativeModel(GEMINI_MODEL_NAME)
model = generative_models.GenerativeModel(GEMINI_MODEL_NAME)
google_search_retriever_tool = (
preview_generative_models.Tool.from_google_search_retrieval(
preview_generative_models.grounding.GoogleSearchRetrieval()
generative_models.Tool.from_google_search_retrieval(
generative_models.grounding.GoogleSearchRetrieval()
)
)
response = model.generate_content(
"Why is sky blue?",
tools=[google_search_retriever_tool],
generation_config=generative_models.GenerationConfig(temperature=0),
)
assert (
response.candidates[0].finish_reason
is generative_models.FinishReason.RECITATION
or response.text
)

def test_grounding_google_search_retriever_with_dynamic_retrieval(
self, api_endpoint_env_name
):
model = generative_models.GenerativeModel(GEMINI_MODEL_NAME)
google_search_retriever_tool = generative_models.Tool.from_google_search_retrieval(
generative_models.grounding.GoogleSearchRetrieval(
generative_models.grounding.DynamicRetrievalConfig(
mode=generative_models.grounding.DynamicRetrievalConfig.Mode.MODE_DYNAMIC,
dynamic_threshold=0.05,
)
)
)
response = model.generate_content(
Expand Down
31 changes: 29 additions & 2 deletions vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2745,14 +2745,41 @@ class grounding: # pylint: disable=invalid-name
def __init__(self):
raise RuntimeError("This class must not be instantiated.")

class DynamicRetrievalConfig:
"""Config for dynamic retrieval."""

Mode = gapic_tool_types.DynamicRetrievalConfig.Mode

def __init__(
self,
mode: Mode = Mode.MODE_UNSPECIFIED,
dynamic_threshold: Optional[float] = None,
):
"""Initializes a DynamicRetrievalConfig."""
self._raw_dynamic_retrieval_config = (
gapic_tool_types.DynamicRetrievalConfig(
mode=mode,
dynamic_threshold=dynamic_threshold,
)
)

class GoogleSearchRetrieval:
r"""Tool to retrieve public web data for grounding, powered by
Google Search.
"""

def __init__(self):
def __init__(
self,
dynamic_retrieval_config: Optional[
"grounding.DynamicRetrievalConfig"
] = None,
):
"""Initializes a Google Search Retrieval tool."""
self._raw_google_search_retrieval = gapic_tool_types.GoogleSearchRetrieval()
self._raw_google_search_retrieval = gapic_tool_types.GoogleSearchRetrieval(
dynamic_retrieval_config=dynamic_retrieval_config._raw_dynamic_retrieval_config
if dynamic_retrieval_config
else None
)


class preview_grounding: # pylint: disable=invalid-name
Expand Down

0 comments on commit c39334a

Please sign in to comment.