Skip to content

Commit

Permalink
feat: add support for query method in Vertex AI Extension SDK
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 662504522
  • Loading branch information
Christine Betts authored and copybara-github committed Aug 13, 2024
1 parent 659ba3f commit 0008735
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 4 deletions.
79 changes: 77 additions & 2 deletions tests/unit/vertexai/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,13 @@
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils as aip_utils
from google.cloud.aiplatform_v1beta1 import types
from google.cloud.aiplatform_v1beta1.services import extension_execution_service
from google.cloud.aiplatform_v1beta1.services import extension_registry_service
from google.cloud.aiplatform_v1beta1.services import (
extension_execution_service,
)
from google.cloud.aiplatform_v1beta1.services import (
extension_registry_service,
)
from vertexai.generative_models import _generative_models
from vertexai.preview import extensions
from vertexai.reasoning_engines import _utils
import pytest
Expand Down Expand Up @@ -180,6 +185,33 @@ def execute_extension_mock():
yield execute_extension_mock


@pytest.fixture
def query_extension_mock():
with mock.patch.object(
extension_execution_service.ExtensionExecutionServiceClient, "query_extension"
) as query_extension_mock:
query_extension_mock.return_value.steps = [
types.Content(
role="user",
parts=[
types.Part(
text=_TEST_QUERY_PROMPT,
)
],
),
types.Content(
role="extension",
parts=[
types.Part(
text=_TEST_RESPONSE_CONTENT,
)
],
),
]
query_extension_mock.return_value.failure_message = ""
yield query_extension_mock


@pytest.fixture
def delete_extension_mock():
with mock.patch.object(
Expand Down Expand Up @@ -325,6 +357,49 @@ def test_execute_extension(
),
)

def test_query_extension(
self,
get_extension_mock,
query_extension_mock,
load_yaml_mock,
):
test_extension = extensions.Extension(_TEST_RESOURCE_ID)
get_extension_mock.assert_called_once_with(
name=_TEST_EXTENSION_RESOURCE_NAME,
retry=aiplatform.base._DEFAULT_RETRY,
)
# Manually set _gca_resource here to prevent the mocks from propagating.
test_extension._gca_resource = _TEST_EXTENSION_OBJ
response = test_extension.query(
contents=[
_generative_models.Content(
parts=[
_generative_models.Part.from_text(
_TEST_QUERY_PROMPT,
)
],
role="user",
)
],
)
assert response.steps[-1].parts[0].text == _TEST_RESPONSE_CONTENT

query_extension_mock.assert_called_once_with(
types.QueryExtensionRequest(
name=_TEST_EXTENSION_RESOURCE_NAME,
contents=[
types.Content(
role="user",
parts=[
types.Part(
text=_TEST_QUERY_PROMPT,
)
],
)
],
),
)

def test_api_spec_from_yaml(self, get_extension_mock, load_yaml_mock):
test_extension = extensions.Extension(_TEST_RESOURCE_ID)
get_extension_mock.assert_called_once_with(
Expand Down
57 changes: 55 additions & 2 deletions vertexai/extensions/_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
# limitations under the License.
#
import json
from typing import Optional, Sequence, Union
from typing import List, Optional, Sequence, Union

from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils as aip_utils
from google.cloud.aiplatform_v1beta1 import types
from vertexai.generative_models import _generative_models
from vertexai.reasoning_engines import _utils

from google.protobuf import struct_pb2

_LOGGER = base.Logger(__name__)
Expand Down Expand Up @@ -248,6 +248,36 @@ def execute(
response = self.execution_api_client.execute_extension(request)
return _try_parse_execution_response(response)

def query(
self,
contents: _generative_models.ContentsType,
) -> "QueryExtensionResponse":
"""Queries an extension with the specified contents.
Args:
contents (ContentsType):
Required. The content of the current
conversation with the model.
For single-turn queries, this is a single
instance. For multi-turn queries, this is a
repeated field that contains conversation
history + latest request.
Returns:
The result of querying the extension.
Raises:
RuntimeError: If the response contains an error.
"""
request = types.QueryExtensionRequest(
name=self.resource_name,
contents=_generative_models._content_types_to_gapic_contents(contents),
)
response = self.execution_api_client.query_extension(request)
if response.failure_message:
raise RuntimeError(response.failure_message)
return QueryExtensionResponse._from_gapic(response)

@classmethod
def from_hub(
cls,
Expand Down Expand Up @@ -317,6 +347,29 @@ def from_hub(
)


class QueryExtensionResponse:
"""A class representing the response from querying an extension."""

def __init__(self, steps: List[_generative_models.Content]):
"""Initializes the QueryExtensionResponse with the given steps."""
self.steps = steps

@classmethod
def _from_gapic(
cls, response: types.QueryExtensionResponse
) -> "QueryExtensionResponse":
"""Creates a QueryExtensionResponse from a gapic response."""
return cls(
steps=[
_generative_models.Content(
parts=[_generative_models.Part._from_gapic(p) for p in c.parts],
role=c.role,
)
for c in response.steps
]
)


def _try_parse_execution_response(
response: types.ExecuteExtensionResponse,
) -> Union[_utils.JsonDict, str]:
Expand Down

0 comments on commit 0008735

Please sign in to comment.