Skip to content

Commit

Permalink
chore: Dynamic set query method
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 698923437
  • Loading branch information
shawn-yang-google authored and copybara-github committed Nov 21, 2024
1 parent 58ba55e commit 653ba88
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 46 deletions.
46 changes: 33 additions & 13 deletions tests/unit/vertex_langchain/test_reasoning_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def register_operations(self) -> Dict[str, List[str]]:
_TEST_STANDARD_API_MODE = _reasoning_engines._STANDARD_API_MODE
_TEST_MODE_KEY_IN_SCHEMA = _reasoning_engines._MODE_KEY_IN_SCHEMA
_TEST_DEFAULT_METHOD_NAME = _reasoning_engines._DEFAULT_METHOD_NAME
_TEST_DEFAULT_METHOD_DOCSTRING = _reasoning_engines._DEFAULT_METHOD_DOCSTRING
_TEST_CUSTOM_METHOD_NAME = "custom_method"
_TEST_QUERY_PROMPT = "Find the first fibonacci number greater than 999"
_TEST_REASONING_ENGINE_GCS_URI = "{}/{}/{}".format(
Expand Down Expand Up @@ -413,13 +414,6 @@ def query_reasoning_engine_mock():
yield query_reasoning_engine_mock


@pytest.fixture(scope="module")
def to_dict_mock():
with mock.patch.object(_utils, "to_dict") as to_dict_mock:
to_dict_mock.return_value = {}
yield to_dict_mock


# Function scope is required for the pytest parameterized tests.
@pytest.fixture(scope="function")
def types_reasoning_engine_mock():
Expand Down Expand Up @@ -853,23 +847,49 @@ def test_delete_after_get_reasoning_engine(
name=test_reasoning_engine.resource_name,
)

def test_query_after_create_reasoning_engine(
self,
get_reasoning_engine_mock,
query_reasoning_engine_mock,
get_gca_resource_mock,
):
test_reasoning_engine = reasoning_engines.ReasoningEngine.create(
self.test_app,
display_name=_TEST_REASONING_ENGINE_DISPLAY_NAME,
requirements=_TEST_REASONING_ENGINE_REQUIREMENTS,
extra_packages=[_TEST_REASONING_ENGINE_EXTRA_PACKAGE_PATH],
)
get_reasoning_engine_mock.assert_called_with(
name=_TEST_REASONING_ENGINE_RESOURCE_NAME,
retry=_TEST_RETRY,
)
with mock.patch.object(_utils, "to_dict") as to_dict_mock:
to_dict_mock.return_value = {}
test_reasoning_engine.query(query=_TEST_QUERY_PROMPT)
assert test_reasoning_engine.query.__doc__ == _TEST_DEFAULT_METHOD_DOCSTRING
query_reasoning_engine_mock.assert_called_with(
request=_TEST_REASONING_ENGINE_QUERY_REQUEST
)
to_dict_mock.assert_called_once()

def test_query_reasoning_engine(
self,
get_reasoning_engine_mock,
query_reasoning_engine_mock,
to_dict_mock,
get_gca_resource_mock,
):
test_reasoning_engine = reasoning_engines.ReasoningEngine(_TEST_RESOURCE_ID)
get_reasoning_engine_mock.assert_called_with(
name=_TEST_REASONING_ENGINE_RESOURCE_NAME,
retry=_TEST_RETRY,
)
test_reasoning_engine.query(query=_TEST_QUERY_PROMPT)
query_reasoning_engine_mock.assert_called_with(
request=_TEST_REASONING_ENGINE_QUERY_REQUEST
)
to_dict_mock.assert_called_once()
with mock.patch.object(_utils, "to_dict") as to_dict_mock:
to_dict_mock.return_value = {}
test_reasoning_engine.query(query=_TEST_QUERY_PROMPT)
query_reasoning_engine_mock.assert_called_with(
request=_TEST_REASONING_ENGINE_QUERY_REQUEST
)
to_dict_mock.assert_called_once()

def test_operation_schemas(
self,
Expand Down
106 changes: 73 additions & 33 deletions vertexai/reasoning_engines/_reasoning_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
import os
import sys
import tarfile
import types
import typing
from typing import Any, Dict, List, Optional, Protocol, Sequence, Union
from typing import Any, Callable, Dict, List, Optional, Protocol, Sequence, Union

import proto

Expand All @@ -29,7 +30,7 @@
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils as aip_utils
from google.cloud.aiplatform_v1beta1 import types
from google.cloud.aiplatform_v1beta1 import types as aip_types
from google.cloud.aiplatform_v1beta1.types import reasoning_engine_service
from vertexai.reasoning_engines import _utils
from google.protobuf import field_mask_pb2
Expand All @@ -44,6 +45,19 @@
_STANDARD_API_MODE = ""
_MODE_KEY_IN_SCHEMA = "api_mode"
_DEFAULT_METHOD_NAME = "query"
_DEFAULT_METHOD_DOCSTRING = """
Runs the Reasoning Engine to serve the user query.
This will be based on the `.query(...)` method of the python object that
was passed in when creating the Reasoning Engine.
Args:
**kwargs:
Optional. The arguments of the `.query(...)` method.
Returns:
dict[str, Any]: The response from serving the user query.
"""


@typing.runtime_checkable
Expand Down Expand Up @@ -73,7 +87,7 @@ def register_operations(self, **kwargs):
"""Register the user provided operations (modes and methods)."""


class ReasoningEngine(base.VertexAiResourceNounWithFutureManager, Queryable):
class ReasoningEngine(base.VertexAiResourceNounWithFutureManager):
"""Represents a Vertex AI Reasoning Engine resource."""

client_class = aip_utils.ReasoningEngineClientWithOverride
Expand All @@ -98,6 +112,7 @@ def __init__(self, reasoning_engine_name: str):
client_class=aip_utils.ReasoningEngineExecutionClientWithOverride,
)
self._gca_resource = self._get_gca_resource(resource_name=reasoning_engine_name)
_register_api_method(self)
self._operation_schemas = None

@property
Expand Down Expand Up @@ -233,7 +248,7 @@ def create(
extra_packages=extra_packages,
)
# Update the package spec.
package_spec = types.ReasoningEngineSpec.PackageSpec(
package_spec = aip_types.ReasoningEngineSpec.PackageSpec(
python_version=sys_version,
pickle_object_gcs_uri="{}/{}/{}".format(
staging_bucket,
Expand All @@ -253,7 +268,7 @@ def create(
gcs_dir_name,
_REQUIREMENTS_FILE,
)
reasoning_engine_spec = types.ReasoningEngineSpec(
reasoning_engine_spec = aip_types.ReasoningEngineSpec(
package_spec=package_spec,
)
class_methods_spec = _generate_class_methods_spec_or_raise(
Expand All @@ -264,7 +279,7 @@ def create(
parent=initializer.global_config.common_location_path(
project=sdk_resource.project, location=sdk_resource.location
),
reasoning_engine=types.ReasoningEngine(
reasoning_engine=aip_types.ReasoningEngine(
name=reasoning_engine_name,
display_name=display_name,
description=description,
Expand All @@ -289,6 +304,7 @@ def create(
credentials=sdk_resource.credentials,
location_override=sdk_resource.location,
)
_register_api_method(sdk_resource)
sdk_resource._operation_schemas = None
return sdk_resource

Expand Down Expand Up @@ -431,30 +447,6 @@ def operation_schemas(self) -> Sequence[_utils.JsonDict]:
self._operation_schemas = spec.get("class_methods", [])
return self._operation_schemas

def query(self, **kwargs) -> _utils.JsonDict:
"""Runs the Reasoning Engine to serve the user query.
This will be based on the `.query(...)` method of the python object that
was passed in when creating the Reasoning Engine.
Args:
**kwargs:
Optional. The arguments of the `.query(...)` method.
Returns:
dict[str, Any]: The response from serving the user query.
"""
response = self.execution_api_client.query_reasoning_engine(
request=types.QueryReasoningEngineRequest(
name=self.resource_name,
input=kwargs,
),
)
output = _utils.to_dict(response)
if "output" in output:
return output.get("output")
return output


def _validate_sys_version_or_raise(sys_version: str) -> None:
"""Tries to validate the python system version."""
Expand Down Expand Up @@ -630,8 +622,8 @@ def _generate_update_request_or_raise(
"""Tries to generates the update request for the reasoning engine."""
is_spec_update = False
update_masks: List[str] = []
reasoning_engine_spec = types.ReasoningEngineSpec()
package_spec = types.ReasoningEngineSpec.PackageSpec()
reasoning_engine_spec = aip_types.ReasoningEngineSpec()
package_spec = aip_types.ReasoningEngineSpec.PackageSpec()
if requirements is not None:
is_spec_update = True
update_masks.append("spec.package_spec.requirements_gcs_uri")
Expand Down Expand Up @@ -662,7 +654,7 @@ def _generate_update_request_or_raise(
reasoning_engine_spec.class_methods.extend(class_methods_spec)
update_masks.append("spec.class_methods")

reasoning_engine_message = types.ReasoningEngine(name=resource_name)
reasoning_engine_message = aip_types.ReasoningEngine(name=resource_name)
if is_spec_update:
reasoning_engine_spec.package_spec = package_spec
reasoning_engine_message.spec = reasoning_engine_spec
Expand All @@ -684,6 +676,54 @@ def _generate_update_request_or_raise(
)


def _wrap_query_operation(method_name: str, doc: str) -> Callable[..., _utils.JsonDict]:
"""Wraps a Reasoning Engine method, creating a callable for `query` API.
This function creates a callable object that executes the specified
Reasoning Engine method using the `query` API. It handles the creation of
the API request and the processing of the API response.
Args:
method_name: The name of the Reasoning Engine method to call.
doc: Documentation string for the method.
Returns:
A callable object that executes the method on the Reasoning Engine via
the `query` API.
"""

def _method(self, **kwargs) -> _utils.JsonDict:
response = self.execution_api_client.query_reasoning_engine(
request=aip_types.QueryReasoningEngineRequest(
name=self.resource_name,
input=kwargs,
),
)
output = _utils.to_dict(response)
return output.get("output", output)

_method.__name__ = method_name
_method.__doc__ = doc

return _method


def _register_api_method(obj: "ReasoningEngine"):
"""Registers Reasoning Engine API methods based on operation schemas.
This function registers `query` method on the ReasoningEngine object
to handle API calls based on the specified API mode.
Args:
obj: The ReasoningEngine object to augment with API methods.
"""
query_method = _wrap_query_operation(
method_name=_DEFAULT_METHOD_NAME, doc=_DEFAULT_METHOD_DOCSTRING
)
# Binds the method to the object.
setattr(obj, _DEFAULT_METHOD_NAME, types.MethodType(query_method, obj))


def _get_registered_operations(reasoning_engine: Any) -> Dict[str, List[str]]:
"""Retrieves registered operations for a ReasoningEngine."""
if isinstance(reasoning_engine, OperationRegistrable):
Expand Down

0 comments on commit 653ba88

Please sign in to comment.