Skip to content

Commit

Permalink
fix: add support of display_name to create_cached_content in python SDK
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 663329746
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Aug 15, 2024
1 parent e7b239a commit ecc2d54
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
8 changes: 8 additions & 0 deletions tests/unit/vertexai/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
_TEST_PROJECT = "test-project"
_TEST_LOCATION = "us-central1"
_CREATED_CONTENT_ID = "contents-id-mocked"
_TEST_DISPLAY_NAME = "test-display-name"


@pytest.fixture
Expand All @@ -53,6 +54,7 @@ def create_cached_content(self, request):
response = GapicCachedContent(
name=f"{request.parent}/cachedContents/{_CREATED_CONTENT_ID}",
model=f"{request.cached_content.model}",
display_name=f"{request.cached_content.display_name}",
create_time=datetime.datetime(
year=2024,
month=2,
Expand Down Expand Up @@ -199,6 +201,7 @@ def test_create_with_real_payload(
tool_config=GapicToolConfig(),
contents=[GapicContent(role="user")],
ttl=datetime.timedelta(days=1),
display_name=_TEST_DISPLAY_NAME,
)

# parent is automantically set to align with the current project and location.
Expand All @@ -211,6 +214,7 @@ def test_create_with_real_payload(
cache.model_name
== f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/model-name"
)
assert cache.display_name == _TEST_DISPLAY_NAME

def test_create_with_real_payload_and_wrapped_type(
self, mock_create_cached_content, mock_get_cached_content
Expand All @@ -222,6 +226,7 @@ def test_create_with_real_payload_and_wrapped_type(
tool_config=GapicToolConfig(),
contents=["user content"],
ttl=datetime.timedelta(days=1),
display_name=_TEST_DISPLAY_NAME,
)

# parent is automantically set to align with the current project and location.
Expand All @@ -231,6 +236,7 @@ def test_create_with_real_payload_and_wrapped_type(
== f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/model-name"
)
assert cache.name == _CREATED_CONTENT_ID
assert cache.display_name == _TEST_DISPLAY_NAME

def test_list(self, mock_list_cached_contents):
cached_contents = caching.CachedContent.list()
Expand All @@ -248,6 +254,7 @@ def test_print_a_cached_content(
tool_config=GapicToolConfig(),
contents=["user content"],
ttl=datetime.timedelta(days=1),
display_name=_TEST_DISPLAY_NAME,
)
f = io.StringIO()
with redirect_stdout(f):
Expand All @@ -261,6 +268,7 @@ def test_print_a_cached_content(
"createTime": "2024-02-01T01:01:01Z",
"updateTime": "2024-02-01T01:01:01Z",
"expireTime": "2024-02-01T02:01:01Z",
"displayName": "test-display-name",
},
indent=2,
)
Expand Down
10 changes: 7 additions & 3 deletions vertexai/caching/_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def _prepare_create_request(
contents: Optional[ContentsType] = None,
expire_time: Optional[datetime.datetime] = None,
ttl: Optional[datetime.timedelta] = None,
display_name: Optional[str] = None,
) -> CreateCachedContentRequest:
"""Prepares the request create_cached_content RPC."""
(
Expand Down Expand Up @@ -97,6 +98,7 @@ def _prepare_create_request(
contents=contents,
expire_time=expire_time,
ttl=ttl,
display_name=display_name,
),
)
return request
Expand Down Expand Up @@ -158,6 +160,7 @@ def create(
contents: Optional[List[Content]] = None,
expire_time: Optional[datetime.datetime] = None,
ttl: Optional[datetime.timedelta] = None,
display_name: Optional[str] = None,
) -> "CachedContent":
"""Creates a new cached content through the gen ai cache service.
Expand Down Expand Up @@ -194,6 +197,8 @@ def create(
At most one of expire_time and ttl can be set. If neither is set,
default TTL on the API side will be used (currently 1 hour).
display_name:
The user-generated meaningful display name of the cached content.
Returns:
A CachedContent object with only name and model_name specified.
Raises:
Expand All @@ -217,6 +222,7 @@ def create(
contents=contents,
expire_time=expire_time,
ttl=ttl,
display_name=display_name,
)
client = cls._instantiate_client(location=location)
cached_content_resource = client.create_cached_content(request)
Expand Down Expand Up @@ -292,6 +298,4 @@ def get(cls, cached_content_name: str) -> "CachedContent":
@property
def display_name(self) -> str:
"""Display name of this resource."""
# TODO(b/345335749): remove this override when the feature is available
# in the API.
raise NotImplementedError("Display name is not available.")
return self._gca_resource.display_name

0 comments on commit ecc2d54

Please sign in to comment.