Skip to content

Commit

Permalink
feat: GenAI - Added support for Grounding
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 607165697
  • Loading branch information
Ark-kun authored and copybara-github committed Feb 15, 2024
1 parent dd80b69 commit 0c3e294
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 5 deletions.
15 changes: 15 additions & 0 deletions tests/system/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from google.cloud import aiplatform
from tests.system.aiplatform import e2e_base
from vertexai import generative_models
from vertexai.preview import generative_models as preview_generative_models


class TestGenerativeModels(e2e_base.TestEndToEnd):
Expand Down Expand Up @@ -134,6 +135,20 @@ def test_generate_content_from_text_and_remote_video(self):
assert response.text
assert "Zootopia" in response.text

def test_grounding_google_search_retriever(self):
model = preview_generative_models.GenerativeModel("gemini-pro")
google_search_retriever_tool = (
preview_generative_models.Tool.from_google_search_retrieval(
preview_generative_models.grounding.GoogleSearchRetrieval(
disable_attribution=False
)
)
)
response = model.generate_content(
"Why is sky blue?", tools=[google_search_retriever_tool]
)
assert response.text

# Chat

def test_send_message_from_text(self):
Expand Down
81 changes: 76 additions & 5 deletions tests/unit/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,27 @@ def mock_generate_content(
contents: Optional[MutableSequence[gapic_content_types.Content]] = None,
) -> Iterable[gapic_prediction_service_types.GenerateContentResponse]:
is_continued_chat = len(request.contents) > 1
has_tools = bool(request.tools)
has_retrieval = any(
tool.retrieval or tool.google_search_retrieval for tool in request.tools
)
has_function_declarations = any(
tool.function_declarations for tool in request.tools
)
has_function_request = any(
content.parts[0].function_call for content in request.contents
)
has_function_response = any(
content.parts[0].function_response for content in request.contents
)

if has_tools:
has_function_response = any(
"function_response" in content.parts[0] for content in request.contents
)
if has_function_request:
assert has_function_response

if has_function_response:
assert has_function_request
assert has_function_declarations

if has_function_declarations:
needs_function_call = not has_function_response
if needs_function_call:
response_part_struct = _RESPONSE_FUNCTION_CALL_PART_STRUCT
Expand Down Expand Up @@ -158,6 +173,24 @@ def mock_generate_content(
gapic_content_types.Citation(_RESPONSE_CITATION_STRUCT),
]
),
grounding_metadata=gapic_content_types.GroundingMetadata(
web_search_queries=[request.contents[0].parts[0].text],
grounding_attributions=[
gapic_content_types.GroundingAttribution(
segment=gapic_content_types.Segment(
start_index=0,
end_index=67,
),
confidence_score=0.69857746,
web=gapic_content_types.GroundingAttribution.Web(
uri="https://math.ucr.edu/home/baez/physics/General/BlueSky/blue_sky.html",
title="Why is the sky blue? - UCR Math",
),
),
],
)
if has_retrieval and request.contents[0].parts[0].text
else None,
),
],
)
Expand Down Expand Up @@ -288,3 +321,41 @@ def test_chat_function_calling(self, generative_models: generative_models):
),
)
assert response2.text == "The weather in Boston is super nice!"

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
new=mock_generate_content,
)
def test_generate_content_grounding_google_search_retriever(self):
model = preview_generative_models.GenerativeModel("gemini-pro")
google_search_retriever_tool = (
preview_generative_models.Tool.from_google_search_retrieval(
preview_generative_models.grounding.GoogleSearchRetrieval(
disable_attribution=False
)
)
)
response = model.generate_content(
"Why is sky blue?", tools=[google_search_retriever_tool]
)
assert response.text

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
new=mock_generate_content,
)
def test_generate_content_grounding_vertex_ai_search_retriever(self):
model = preview_generative_models.GenerativeModel("gemini-pro")
google_search_retriever_tool = preview_generative_models.Tool.from_retrieval(
retrieval=preview_generative_models.grounding.Retrieval(
source=preview_generative_models.grounding.VertexAISearch(
datastore=f"projects/{_TEST_PROJECT}/locations/global/collections/default_collection/dataStores/test-datastore",
)
)
)
response = model.generate_content(
"Why is sky blue?", tools=[google_search_retriever_tool]
)
assert response.text
115 changes: 115 additions & 0 deletions vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,40 @@ def __init__(
function_declarations=gapic_function_declarations
)

@classmethod
def from_function_declarations(
cls,
function_declarations: List["FunctionDeclaration"],
):
gapic_function_declarations = [
function_declaration._raw_function_declaration
for function_declaration in function_declarations
]
raw_tool = gapic_tool_types.Tool(
function_declarations=gapic_function_declarations
)
return cls._from_gapic(raw_tool=raw_tool)

@classmethod
def from_retrieval(
cls,
retrieval: "Retrieval",
):
raw_tool = gapic_tool_types.Tool(
retrieval=retrieval._raw_retrieval
)
return cls._from_gapic(raw_tool=raw_tool)

@classmethod
def from_google_search_retrieval(
cls,
google_search_retrieval: "GoogleSearchRetrieval",
):
raw_tool = gapic_tool_types.Tool(
google_search_retrieval=google_search_retrieval._raw_google_search_retrieval
)
return cls._from_gapic(raw_tool=raw_tool)

@classmethod
def _from_gapic(
cls,
Expand Down Expand Up @@ -1520,6 +1554,87 @@ def _image(self) -> "Image":
return Image.from_bytes(data=self._raw_part.inline_data.data)


class grounding: # pylint: disable=invalid-name
"""Grounding namespace."""

def __init__(self):
raise RuntimeError("This class must not be instantiated.")

class Retrieval:
"""Defines a retrieval tool that model can call to access external knowledge."""

def __init__(
self,
source: Union["grounding.VertexAISearch"],
disable_attribution: Optional[bool] = None,
):
"""Initializes a Retrieval tool.
Args:
source (VertexAISearch):
Set to use data source powered by Vertex AI Search.
disable_attribution (bool):
Optional. Disable using the result from this
tool in detecting grounding attribution. This
does not affect how the result is given to the
model for generation.
"""
self._raw_retrieval = gapic_tool_types.Retrieval(
vertex_ai_search=source._raw_vertex_ai_search,
disable_attribution=disable_attribution,
)

class VertexAISearch:
r"""Retrieve from Vertex AI Search datastore for grounding.
See https://cloud.google.com/vertex-ai-search-and-conversation
"""

def __init__(
self,
datastore: str,
):
"""Initializes a Vertex AI Search tool.
Args:
datastore (str):
Required. Fully-qualified Vertex AI Search's
datastore resource ID.
projects/<>/locations/<>/collections/<>/dataStores/<>
"""
self._raw_vertex_ai_search = gapic_tool_types.VertexAISearch(
datastore=datastore,
)

class GoogleSearchRetrieval:
r"""Tool to retrieve public web data for grounding, powered by
Google Search.
Attributes:
disable_attribution (bool):
Optional. Disable using the result from this
tool in detecting grounding attribution. This
does not affect how the result is given to the
model for generation.
"""

def __init__(
self,
disable_attribution: Optional[bool] = None,
):
"""Initializes a Google Search Retrieval tool.
Args:
disable_attribution (bool):
Optional. Disable using the result from this
tool in detecting grounding attribution. This
does not affect how the result is given to the
model for generation.
"""
self._raw_google_search_retrieval = gapic_tool_types.GoogleSearchRetrieval(
disable_attribution=disable_attribution,
)


def _to_content(
value: Union[
gapic_content_types.Content,
Expand Down
2 changes: 2 additions & 0 deletions vertexai/preview/generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# We just want to re-export certain classes
# pylint: disable=g-multiple-import,g-importing-member
from vertexai.generative_models._generative_models import (
grounding,
_PreviewGenerativeModel,
GenerationConfig,
GenerationResponse,
Expand All @@ -39,6 +40,7 @@ class GenerativeModel(_PreviewGenerativeModel):


__all__ = [
"grounding",
"GenerationConfig",
"GenerativeModel",
"GenerationResponse",
Expand Down

0 comments on commit 0c3e294

Please sign in to comment.