Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: project/location parsing for nested resources #1700

Merged
merged 15 commits into from
Sep 30, 2022
7 changes: 7 additions & 0 deletions google/cloud/aiplatform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,13 @@ def _list(
Returns:
List[VertexAiResourceNoun] - A list of SDK resource objects
"""
if parent:
parent_resources = utils.extract_project_and_location_from_parent(parent)
if parent_resources:
project, location = (
parent_resources["project"],
parent_resources["location"],
)

resource = cls._empty_constructor(
project=project, location=location, credentials=credentials
Expand Down
28 changes: 28 additions & 0 deletions google/cloud/aiplatform/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,34 @@ def extract_bucket_and_prefix_from_gcs_path(gcs_path: str) -> Tuple[str, Optiona
return (gcs_bucket, gcs_blob_prefix)


def extract_project_and_location_from_parent(
parent: str,
) -> Dict[str, str]:
"""Given a complete parent resource name, return the project and location as a dict.

Example Usage:

parent_resources = extract_project_and_location_from_parent(
"projects/123/locations/us-central1/datasets/456"
)

parent_resources["project"] = "123"
parent_resources["location"] = "us-central1"

Args:
parent (str):
Required. A complete parent resource name.

Returns:
Dict[str, str]
A project, location dict from provided parent resource name.
"""
parent_resources = re.match(
nayaknishant marked this conversation as resolved.
Show resolved Hide resolved
r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)(/|$)", parent
)
return parent_resources.groupdict() if parent_resources else {}

nayaknishant marked this conversation as resolved.
Show resolved Hide resolved

class ClientWithOverride:
class WrappedClient:
"""Wrapper class for client that creates client at API invocation
Expand Down
37 changes: 35 additions & 2 deletions tests/unit/aiplatform/test_featurestores.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,7 +1023,23 @@ def test_list_entity_types(self, list_entity_types_mock):
aiplatform.init(project=_TEST_PROJECT)
nayaknishant marked this conversation as resolved.
Show resolved Hide resolved

my_featurestore = aiplatform.Featurestore(
featurestore_name=_TEST_FEATURESTORE_ID
featurestore_name=_TEST_FEATURESTORE_ID,
)
my_entity_type_list = my_featurestore.list_entity_types()

list_entity_types_mock.assert_called_once_with(
request={"parent": _TEST_FEATURESTORE_NAME}
)
assert len(my_entity_type_list) == len(_TEST_ENTITY_TYPE_LIST)
for my_entity_type in my_entity_type_list:
assert type(my_entity_type) == aiplatform.EntityType

@pytest.mark.usefixtures("get_featurestore_mock")
def test_list_entity_types_with_no_init(self, list_entity_types_mock):
my_featurestore = aiplatform.Featurestore(
featurestore_name=_TEST_FEATURESTORE_ID,
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)
my_entity_type_list = my_featurestore.list_entity_types()

Expand Down Expand Up @@ -1762,7 +1778,7 @@ def test_update_entity_type(self, update_entity_type_mock):
@pytest.mark.parametrize(
"featurestore_name", [_TEST_FEATURESTORE_NAME, _TEST_FEATURESTORE_ID]
)
def test_list_entity_types(self, featurestore_name, list_entity_types_mock):
def test_list_entity_type(self, featurestore_name, list_entity_types_mock):
aiplatform.init(project=_TEST_PROJECT)

my_entity_type_list = aiplatform.EntityType.list(
Expand Down Expand Up @@ -1790,6 +1806,23 @@ def test_list_features(self, list_features_mock):
for my_feature in my_feature_list:
assert type(my_feature) == aiplatform.Feature

@pytest.mark.usefixtures("get_entity_type_mock")
def test_list_features_with_no_init(self, list_features_mock):
my_entity_type = aiplatform.EntityType(
entity_type_name=_TEST_ENTITY_TYPE_ID,
featurestore_id=_TEST_FEATURESTORE_ID,
project=_TEST_PROJECT,
location=_TEST_LOCATION
)
my_feature_list = my_entity_type.list_features()

list_features_mock.assert_called_once_with(
request={"parent": _TEST_ENTITY_TYPE_NAME}
)
assert len(my_feature_list) == len(_TEST_FEATURE_LIST)
for my_feature in my_feature_list:
assert type(my_feature) == aiplatform.Feature

@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.usefixtures("get_entity_type_mock", "get_feature_mock")
def test_delete_features(self, delete_feature_mock, sync):
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/aiplatform/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,30 @@ def test_extract_bucket_and_prefix_from_gcs_path(gcs_path: str, expected: tuple)
assert expected == utils.extract_bucket_and_prefix_from_gcs_path(gcs_path)


@pytest.mark.parametrize(
"parent, expected",
[
(
"projects/123/locations/us-central1/datasets/456",
{"project": "123", "location": "us-central1"},
),
(
"projects/123/locations/us-central1/",
{"project": "123", "location": "us-central1"},
),
(
"projects/123/locations/us-central1",
{"project": "123", "location": "us-central1"},
),
("projects/123/locations/", {}),
("projects/123", {}),
],
)
def test_extract_project_and_location_from_parent(parent: str, expected: tuple):
# Given a parent resource name, ensure correct project and location are extracted
assert expected == utils.extract_project_and_location_from_parent(parent)


@pytest.mark.usefixtures("google_auth_mock")
def test_wrapped_client():
test_client_info = gapic_v1.client_info.ClientInfo()
Expand Down