Skip to content

Commit

Permalink
fix: Include DeploymentResourcePool class in aiplatform top-level sdk…
Browse files Browse the repository at this point in the history
… module

PiperOrigin-RevId: 652541626
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Jul 15, 2024
1 parent cfe0cc6 commit ecc4f09
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
2 changes: 2 additions & 0 deletions google/cloud/aiplatform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
)
from google.cloud.aiplatform import metadata
from google.cloud.aiplatform.tensorboard import uploader_tracker
from google.cloud.aiplatform.models import DeploymentResourcePool
from google.cloud.aiplatform.models import Endpoint
from google.cloud.aiplatform.models import PrivateEndpoint
from google.cloud.aiplatform.models import Model
Expand Down Expand Up @@ -153,6 +154,7 @@
"CustomTrainingJob",
"CustomContainerTrainingJob",
"CustomPythonPackageTrainingJob",
"DeploymentResourcePool",
"Endpoint",
"EntityType",
"Execution",
Expand Down
17 changes: 8 additions & 9 deletions tests/unit/aiplatform/test_deployment_resource_pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import models

from google.cloud.aiplatform.compat.services import (
deployment_resource_pool_service_client,
Expand Down Expand Up @@ -232,7 +231,7 @@ def test_constructor_gets_drp(self, get_drp_mock):
location=_TEST_LOCATION,
credentials=_TEST_CREDENTIALS,
)
models.DeploymentResourcePool(_TEST_DRP_NAME)
aiplatform.DeploymentResourcePool(_TEST_DRP_NAME)
get_drp_mock.assert_called_once_with(
name=_TEST_DRP_NAME, retry=base._DEFAULT_RETRY
)
Expand All @@ -242,7 +241,7 @@ def test_constructor_with_conflicting_location_fails(self):
"""Passing a full resource name with `_TEST_LOCATION` and providing `_TEST_LOCATION_2` as location"""

with pytest.raises(RuntimeError) as err:
models.DeploymentResourcePool(_TEST_DRP_NAME, location=_TEST_LOCATION_2)
aiplatform.DeploymentResourcePool(_TEST_DRP_NAME, location=_TEST_LOCATION_2)

assert err.match(
regexp=r"is provided, but different from the resource location"
Expand All @@ -251,7 +250,7 @@ def test_constructor_with_conflicting_location_fails(self):
@pytest.mark.usefixtures("create_drp_mock")
@pytest.mark.parametrize("sync", [True, False])
def test_create(self, create_drp_mock, sync):
test_drp = models.DeploymentResourcePool.create(
test_drp = aiplatform.DeploymentResourcePool.create(
deployment_resource_pool_id=_TEST_ID,
machine_type=_TEST_MACHINE_TYPE,
min_replica_count=10,
Expand Down Expand Up @@ -285,7 +284,7 @@ def test_create(self, create_drp_mock, sync):
@pytest.mark.usefixtures("create_drp_mock")
@pytest.mark.parametrize("sync", [True, False])
def test_create_with_timeout(self, create_drp_mock, sync):
test_drp = models.DeploymentResourcePool.create(
test_drp = aiplatform.DeploymentResourcePool.create(
deployment_resource_pool_id=_TEST_ID,
machine_type=_TEST_MACHINE_TYPE,
min_replica_count=10,
Expand Down Expand Up @@ -319,17 +318,17 @@ def test_create_with_timeout(self, create_drp_mock, sync):

@pytest.mark.usefixtures("list_drp_mock")
def test_list(self, list_drp_mock):
drp_list = models.DeploymentResourcePool.list()
drp_list = aiplatform.DeploymentResourcePool.list()

list_drp_mock.assert_called_once()

for drp in drp_list:
assert isinstance(drp, models.DeploymentResourcePool)
assert isinstance(drp, aiplatform.DeploymentResourcePool)

@pytest.mark.usefixtures("delete_drp_mock", "get_drp_mock")
@pytest.mark.parametrize("sync", [True, False])
def test_delete(self, delete_drp_mock, get_drp_mock, sync):
test_drp = models.DeploymentResourcePool(
test_drp = aiplatform.DeploymentResourcePool(
deployment_resource_pool_name=_TEST_DRP_NAME
)
test_drp.delete(sync=sync)
Expand All @@ -341,7 +340,7 @@ def test_delete(self, delete_drp_mock, get_drp_mock, sync):

@pytest.mark.usefixtures("query_deployed_models_mock", "get_drp_mock")
def test_query_deployed_models(self, query_deployed_models_mock, get_drp_mock):
test_drp = models.DeploymentResourcePool(
test_drp = aiplatform.DeploymentResourcePool(
deployment_resource_pool_name=_TEST_DRP_NAME
)
dm_refs = test_drp.query_deployed_models()
Expand Down
2 changes: 2 additions & 0 deletions vertexai/resources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from google.cloud.aiplatform import metadata
from google.cloud.aiplatform.tensorboard import uploader_tracker
from google.cloud.aiplatform.models import DeploymentResourcePool
from google.cloud.aiplatform.models import Endpoint
from google.cloud.aiplatform.models import PrivateEndpoint
from google.cloud.aiplatform.models import Model
Expand Down Expand Up @@ -148,6 +149,7 @@
"CustomTrainingJob",
"CustomContainerTrainingJob",
"CustomPythonPackageTrainingJob",
"DeploymentResourcePool",
"Endpoint",
"EntityType",
"Execution",
Expand Down

0 comments on commit ecc4f09

Please sign in to comment.