Skip to content

Commit

Permalink
feat: support PSC-Interface in Ray on Vertex
Browse files Browse the repository at this point in the history
feat: support disable Cloud logging in Ray on Vertex

PiperOrigin-RevId: 661019434
  • Loading branch information
yinghsienwu authored and copybara-github committed Aug 8, 2024
1 parent a521ba6 commit accaa97
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 30 deletions.
2 changes: 2 additions & 0 deletions google/cloud/aiplatform/vertex_ray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from google.cloud.aiplatform.vertex_ray.util.resources import (
Resources,
NodeImages,
PscIConfig,
)

from google.cloud.aiplatform.vertex_ray.dashboard_sdk import (
Expand All @@ -61,4 +62,5 @@
"update_ray_cluster",
"Resources",
"NodeImages",
"PscIConfig",
)
33 changes: 27 additions & 6 deletions google/cloud/aiplatform/vertex_ray/cluster_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,20 @@
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.utils import resource_manager_utils
from google.cloud.aiplatform_v1.types import persistent_resource_service
from google.cloud.aiplatform_v1beta1.types import persistent_resource_service

from google.cloud.aiplatform_v1.types.persistent_resource import (
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
PersistentResource,
RayLogsSpec,
RaySpec,
RayMetricSpec,
ResourcePool,
ResourceRuntimeSpec,
ServiceAccountSpec,
)

from google.cloud.aiplatform_v1beta1.types.service_networking import (
PscInterfaceConfig,
)
from google.cloud.aiplatform.vertex_ray.util import (
_gapic_utils,
_validation_utils,
Expand All @@ -56,6 +59,8 @@ def create_ray_cluster(
worker_node_types: Optional[List[resources.Resources]] = [resources.Resources()],
custom_images: Optional[resources.NodeImages] = None,
enable_metrics_collection: Optional[bool] = True,
enable_logging: Optional[bool] = True,
psc_interface_config: Optional[resources.PscIConfig] = None,
labels: Optional[Dict[str, str]] = None,
) -> str:
"""Create a ray cluster on the Vertex AI.
Expand Down Expand Up @@ -119,6 +124,8 @@ def create_ray_cluster(
head/worker_node_type(s). Note that configuring `Resources.custom_image`
will override `custom_images` here. Allowlist only.
enable_metrics_collection: Enable Ray metrics collection for visualization.
enable_logging: Enable exporting Ray logs to Cloud Logging.
psc_interface_config: PSC-I config.
labels:
The labels with user-defined metadata to organize Ray cluster.
Expand Down Expand Up @@ -258,10 +265,17 @@ def create_ray_cluster(
i += 1

resource_pools = [resource_pool_0] + worker_pools
disabled = not enable_metrics_collection
ray_metric_spec = RayMetricSpec(disabled=disabled)

metrics_collection_disabled = not enable_metrics_collection
ray_metric_spec = RayMetricSpec(disabled=metrics_collection_disabled)

logging_disabled = not enable_logging
ray_logs_spec = RayLogsSpec(disabled=logging_disabled)

ray_spec = RaySpec(
resource_pool_images=resource_pool_images, ray_metric_spec=ray_metric_spec
resource_pool_images=resource_pool_images,
ray_metric_spec=ray_metric_spec,
ray_logs_spec=ray_logs_spec,
)
if service_account:
service_account_spec = ServiceAccountSpec(
Expand All @@ -274,11 +288,18 @@ def create_ray_cluster(
)
else:
resource_runtime_spec = ResourceRuntimeSpec(ray_spec=ray_spec)
if psc_interface_config:
gapic_psc_interface_config = PscInterfaceConfig(
network_attachment=psc_interface_config.network_attachment,
)
else:
gapic_psc_interface_config = None
persistent_resource = PersistentResource(
resource_pools=resource_pools,
network=network,
labels=labels,
resource_runtime_spec=resource_runtime_spec,
psc_interface_config=gapic_psc_interface_config,
)

location = initializer.global_config.location
Expand Down
17 changes: 14 additions & 3 deletions google/cloud/aiplatform/vertex_ray/util/_gapic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@
from google.cloud.aiplatform.vertex_ray.util import _validation_utils
from google.cloud.aiplatform.vertex_ray.util.resources import (
Cluster,
PscIConfig,
Resources,
)
from google.cloud.aiplatform_v1.types.persistent_resource import (
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
PersistentResource,
)
from google.cloud.aiplatform_v1.types.persistent_resource_service import (
from google.cloud.aiplatform_v1beta1.types.persistent_resource_service import (
GetPersistentResourceRequest,
)

Expand All @@ -47,7 +48,7 @@ def create_persistent_resource_client():
return initializer.global_config.create_client(
client_class=PersistentResourceClientWithOverride,
appended_gapic_version="vertex_ray",
)
).select_version("v1beta1")


def polling_delay(num_attempts: int, time_scale: float) -> datetime.timedelta:
Expand Down Expand Up @@ -159,6 +160,10 @@ def persistent_resource_to_cluster(
% persistent_resource.name,
)
return
if persistent_resource.psc_interface_config:
cluster.psc_interface_config = PscIConfig(
network_attachment=persistent_resource.psc_interface_config.network_attachment
)
resource_pools = persistent_resource.resource_pools

head_resource_pool = resource_pools[0]
Expand Down Expand Up @@ -192,6 +197,12 @@ def persistent_resource_to_cluster(
ray_version = None
cluster.python_version = python_version
cluster.ray_version = ray_version
cluster.ray_metric_enabled = not (
persistent_resource.resource_runtime_spec.ray_spec.ray_metric_spec.disabled
)
cluster.ray_logs_enabled = not (
persistent_resource.resource_runtime_spec.ray_spec.ray_logs_spec.disabled
)

accelerator_type = head_resource_pool.machine_spec.accelerator_type
if accelerator_type.value != 0:
Expand Down
26 changes: 25 additions & 1 deletion google/cloud/aiplatform/vertex_ray/util/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#
import dataclasses
from typing import Dict, List, Optional
from google.cloud.aiplatform_v1.types import PersistentResource
from google.cloud.aiplatform_v1beta1.types import PersistentResource


@dataclasses.dataclass
Expand Down Expand Up @@ -68,6 +68,27 @@ class NodeImages:
worker: str = None


@dataclasses.dataclass
class PscIConfig:
"""PSC-I config.
Attributes:
network_attachment: Optional. The name or full name of the Compute Engine
`network attachment <https://cloud.google.com/vpc/docs/about-network-attachments>`
to attach to the resource. It has a format:
``projects/{project}/regions/{region}/networkAttachments/{networkAttachment}``.
Where {project} is a project number, as in ``12345``, and
{networkAttachment} is a network attachment name. To specify
this field, you must have already [created a network
attachment]
(https://cloud.google.com/vpc/docs/create-manage-network-attachments#create-network-attachments).
This field is only used for resources using PSC-I. Make sure you do not
specify the network here for VPC peering.
"""

network_attachment: str = None


@dataclasses.dataclass
class Cluster:
"""Ray cluster (output only).
Expand Down Expand Up @@ -111,6 +132,9 @@ class Cluster:
head_node_type: Resources = None
worker_node_types: List[Resources] = None
dashboard_address: str = None
ray_metric_enabled: bool = True
ray_logs_enabled: bool = True
psc_interface_config: PscIConfig = None
labels: Dict[str, str] = None


Expand Down
8 changes: 4 additions & 4 deletions tests/unit/vertex_ray/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@
from google.auth import credentials as auth_credentials
from google.cloud import resourcemanager
from google.cloud.aiplatform import vertex_ray
from google.cloud.aiplatform_v1.services.persistent_resource_service import (
from google.cloud.aiplatform_v1beta1.services.persistent_resource_service import (
PersistentResourceServiceClient,
)
from google.cloud.aiplatform_v1.types.persistent_resource import (
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
PersistentResource,
)
from google.cloud.aiplatform_v1.types.persistent_resource import (
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
ResourceRuntime,
)
from google.cloud.aiplatform_v1.types.persistent_resource_service import (
from google.cloud.aiplatform_v1beta1.types.persistent_resource_service import (
DeletePersistentResourceRequest,
)
import test_constants as tc
Expand Down
12 changes: 9 additions & 3 deletions tests/unit/vertex_ray/test_cluster_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
Resources,
NodeImages,
)
from google.cloud.aiplatform_v1.services.persistent_resource_service import (
from google.cloud.aiplatform_v1beta1.services.persistent_resource_service import (
PersistentResourceServiceClient,
)
from google.cloud.aiplatform_v1.types import persistent_resource_service
from google.cloud.aiplatform_v1beta1.types import persistent_resource_service
import test_constants as tc
import mock
import pytest
Expand Down Expand Up @@ -352,13 +352,15 @@ def test_create_ray_cluster_1_pool_gpu_with_labels_success(
self, create_persistent_resource_1_pool_mock
):
"""If head and worker nodes are duplicate, merge to head pool."""
# Also test disable logging and metrics collection.
cluster_name = vertex_ray.create_ray_cluster(
head_node_type=tc.ClusterConstants.TEST_HEAD_NODE_TYPE_1_POOL,
worker_node_types=tc.ClusterConstants.TEST_WORKER_NODE_TYPES_1_POOL,
network=tc.ProjectConstants.TEST_VPC_NETWORK,
cluster_name=tc.ClusterConstants.TEST_VERTEX_RAY_PR_ID,
labels=tc.ClusterConstants.TEST_LABELS,
enable_metrics_collection=False,
enable_logging=False,
)

assert tc.ClusterConstants.TEST_VERTEX_RAY_PR_ADDRESS == cluster_name
Expand Down Expand Up @@ -401,11 +403,15 @@ def test_create_ray_cluster_2_pools_success(
self, create_persistent_resource_2_pools_mock
):
"""If head and worker nodes are not duplicate, create separate resource_pools."""
# Also test PSC-I.
psc_interface_config = vertex_ray.PscIConfig(
network_attachment=tc.ClusterConstants.TEST_PSC_NETWORK_ATTACHMENT
)
cluster_name = vertex_ray.create_ray_cluster(
head_node_type=tc.ClusterConstants.TEST_HEAD_NODE_TYPE_2_POOLS,
worker_node_types=tc.ClusterConstants.TEST_WORKER_NODE_TYPES_2_POOLS,
network=tc.ProjectConstants.TEST_VPC_NETWORK,
cluster_name=tc.ClusterConstants.TEST_VERTEX_RAY_PR_ID,
psc_interface_config=psc_interface_config,
)

assert tc.ClusterConstants.TEST_VERTEX_RAY_PR_ADDRESS == cluster_name
Expand Down
Loading

0 comments on commit accaa97

Please sign in to comment.