Skip to content

Commit

Permalink
chore: Refactoring product mapping for environment types
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 658480094
  • Loading branch information
speedstorm1 authored and copybara-github committed Aug 1, 2024
1 parent d92e7c9 commit cff8ae0
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
14 changes: 8 additions & 6 deletions google/cloud/aiplatform/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,12 +387,14 @@ def get_resource_type(self) -> _Product:
return self._resource_type

vertex_product = os.getenv("VERTEX_PRODUCT")
if vertex_product == "COLAB_ENTERPRISE":
self._resource_type = _Product.COLAB_ENTERPRISE
if vertex_product == "WORKBENCH_CUSTOM_CONTAINER":
self._resource_type = _Product.WORKBENCH_CUSTOM_CONTAINER
if vertex_product == "WORKBENCH_INSTANCE":
self._resource_type = _Product.WORKBENCH_INSTANCE
product_mapping = {
"COLAB_ENTERPRISE": _Product.COLAB_ENTERPRISE,
"WORKBENCH_CUSTOM_CONTAINER": _Product.WORKBENCH_CUSTOM_CONTAINER,
"WORKBENCH_INSTANCE": _Product.WORKBENCH_INSTANCE,
}

if vertex_product in product_mapping:
self._resource_type = product_mapping[vertex_product]

return self._resource_type

Expand Down
19 changes: 19 additions & 0 deletions tests/unit/aiplatform/test_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,25 @@ def test_get_client_options_with_api_override(self):

assert client_options.api_endpoint == "asia-east1-override.googleapis.com"

def test_get_resource_type(self):
initializer.global_config.init()
os.environ["VERTEX_PRODUCT"] = "COLAB_ENTERPRISE"
assert initializer.global_config.get_resource_type().value == (
"COLAB_ENTERPRISE"
)

initializer.global_config.init()
os.environ["VERTEX_PRODUCT"] = "WORKBENCH_INSTANCE"
assert initializer.global_config.get_resource_type().value == (
"WORKBENCH_INSTANCE"
)

initializer.global_config.init()
os.environ["VERTEX_PRODUCT"] = "WORKBENCH_CUSTOM_CONTAINER"
assert initializer.global_config.get_resource_type().value == (
"WORKBENCH_CUSTOM_CONTAINER"
)

def test_init_with_only_creds_does_not_override_set_project(self):
assert initializer.global_config.project is not _TEST_PROJECT_2
initializer.global_config.init(project=_TEST_PROJECT_2)
Expand Down

0 comments on commit cff8ae0

Please sign in to comment.