Skip to content

Commit

Permalink
feat: Add support for strategy in custom training jobs.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 661428427
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Aug 9, 2024
1 parent 7404f67 commit a076191
Show file tree
Hide file tree
Showing 6 changed files with 405 additions and 3 deletions.
34 changes: 32 additions & 2 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2214,6 +2214,7 @@ def run(
create_request_timeout: Optional[float] = None,
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
) -> None:
"""Run this configured CustomJob.
Expand Down Expand Up @@ -2282,6 +2283,8 @@ def run(
on-demand short-lived machines. The network, CMEK, and node pool
configs on the job should be consistent with those on the
PersistentResource, otherwise, the job will be rejected.
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
Optional. Indicates the job scheduling strategy.
"""
network = network or initializer.global_config.network
service_account = service_account or initializer.global_config.service_account
Expand All @@ -2299,6 +2302,7 @@ def run(
create_request_timeout=create_request_timeout,
disable_retries=disable_retries,
persistent_resource_id=persistent_resource_id,
scheduling_strategy=scheduling_strategy,
)

@base.optional_sync()
Expand All @@ -2316,6 +2320,7 @@ def _run(
create_request_timeout: Optional[float] = None,
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
) -> None:
"""Helper method to ensure network synchronization and to run the configured CustomJob.
Expand Down Expand Up @@ -2382,6 +2387,8 @@ def _run(
on-demand short-lived machines. The network, CMEK, and node pool
configs on the job should be consistent with those on the
PersistentResource, otherwise, the job will be rejected.
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
Optional. Indicates the job scheduling strategy.
"""
self.submit(
service_account=service_account,
Expand All @@ -2395,6 +2402,7 @@ def _run(
create_request_timeout=create_request_timeout,
disable_retries=disable_retries,
persistent_resource_id=persistent_resource_id,
scheduling_strategy=scheduling_strategy,
)

self._block_until_complete()
Expand All @@ -2413,6 +2421,7 @@ def submit(
create_request_timeout: Optional[float] = None,
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
) -> None:
"""Submit the configured CustomJob.
Expand Down Expand Up @@ -2476,6 +2485,8 @@ def submit(
on-demand short-lived machines. The network, CMEK, and node pool
configs on the job should be consistent with those on the
PersistentResource, otherwise, the job will be rejected.
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
Optional. Indicates the job scheduling strategy.
Raises:
ValueError:
Expand All @@ -2498,12 +2509,18 @@ def submit(
if network:
self._gca_resource.job_spec.network = network

if timeout or restart_job_on_worker_restart or disable_retries:
if (
timeout
or restart_job_on_worker_restart
or disable_retries
or scheduling_strategy
):
timeout = duration_pb2.Duration(seconds=timeout) if timeout else None
self._gca_resource.job_spec.scheduling = gca_custom_job_compat.Scheduling(
timeout=timeout,
restart_job_on_worker_restart=restart_job_on_worker_restart,
disable_retries=disable_retries,
strategy=scheduling_strategy,
)

if enable_web_access:
Expand Down Expand Up @@ -2868,6 +2885,7 @@ def run(
sync: bool = True,
create_request_timeout: Optional[float] = None,
disable_retries: bool = False,
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
) -> None:
"""Run this configured CustomJob.
Expand Down Expand Up @@ -2916,6 +2934,8 @@ def run(
Indicates if the job should retry for internal errors after the
job starts running. If True, overrides
`restart_job_on_worker_restart` to False.
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
Optional. Indicates the job scheduling strategy.
"""
network = network or initializer.global_config.network
service_account = service_account or initializer.global_config.service_account
Expand All @@ -2930,6 +2950,7 @@ def run(
sync=sync,
create_request_timeout=create_request_timeout,
disable_retries=disable_retries,
scheduling_strategy=scheduling_strategy,
)

@base.optional_sync()
Expand All @@ -2944,6 +2965,7 @@ def _run(
sync: bool = True,
create_request_timeout: Optional[float] = None,
disable_retries: bool = False,
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
) -> None:
"""Helper method to ensure network synchronization and to run the configured CustomJob.
Expand Down Expand Up @@ -2990,20 +3012,28 @@ def _run(
Indicates if the job should retry for internal errors after the
job starts running. If True, overrides
`restart_job_on_worker_restart` to False.
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
Optional. Indicates the job scheduling strategy.
"""
if service_account:
self._gca_resource.trial_job_spec.service_account = service_account

if network:
self._gca_resource.trial_job_spec.network = network

if timeout or restart_job_on_worker_restart or disable_retries:
if (
timeout
or restart_job_on_worker_restart
or disable_retries
or scheduling_strategy
):
duration = duration_pb2.Duration(seconds=timeout) if timeout else None
self._gca_resource.trial_job_spec.scheduling = (
gca_custom_job_compat.Scheduling(
timeout=duration,
restart_job_on_worker_restart=restart_job_on_worker_restart,
disable_retries=disable_retries,
strategy=scheduling_strategy,
)
)

Expand Down
44 changes: 43 additions & 1 deletion google/cloud/aiplatform/training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from google.cloud.aiplatform.compat.types import (
training_pipeline as gca_training_pipeline,
study as gca_study_compat,
custom_job as gca_custom_job_compat,
)

from google.cloud.aiplatform.utils import _timestamped_gcs_dir
Expand Down Expand Up @@ -1525,6 +1526,7 @@ def _prepare_training_task_inputs_and_output_dir(
tensorboard: Optional[str] = None,
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
) -> Tuple[Dict, str]:
"""Prepares training task inputs and output directory for custom job.
Expand Down Expand Up @@ -1582,6 +1584,8 @@ def _prepare_training_task_inputs_and_output_dir(
on-demand short-live machines. The network, CMEK, and node pool
configs on the job should be consistent with those on the
PersistentResource, otherwise, the job will be rejected.
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
Optional. Indicates the job scheduling strategy.
Returns:
Training task inputs and Output directory for custom job.
Expand Down Expand Up @@ -1612,12 +1616,18 @@ def _prepare_training_task_inputs_and_output_dir(
if persistent_resource_id:
training_task_inputs["persistent_resource_id"] = persistent_resource_id

if timeout or restart_job_on_worker_restart or disable_retries:
if (
timeout
or restart_job_on_worker_restart
or disable_retries
or scheduling_strategy
):
timeout = f"{timeout}s" if timeout else None
scheduling = {
"timeout": timeout,
"restart_job_on_worker_restart": restart_job_on_worker_restart,
"disable_retries": disable_retries,
"strategy": scheduling_strategy,
}
training_task_inputs["scheduling"] = scheduling

Expand Down Expand Up @@ -3005,6 +3015,7 @@ def run(
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
tpu_topology: Optional[str] = None,
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
) -> Optional[models.Model]:
"""Runs the custom training job.
Expand Down Expand Up @@ -3360,6 +3371,8 @@ def run(
details on the TPU topology, refer to
https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config. The topology must
be a supported value for the TPU machine type.
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
Optional. Indicates the job scheduling strategy.
Returns:
The trained Vertex AI model resource or None if the training
Expand Down Expand Up @@ -3424,6 +3437,7 @@ def run(
create_request_timeout=create_request_timeout,
disable_retries=disable_retries,
persistent_resource_id=persistent_resource_id,
scheduling_strategy=scheduling_strategy,
)

def submit(
Expand Down Expand Up @@ -3477,6 +3491,7 @@ def submit(
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
tpu_topology: Optional[str] = None,
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
) -> Optional[models.Model]:
"""Submits the custom training job without blocking until completion.
Expand Down Expand Up @@ -3777,6 +3792,8 @@ def submit(
details on the TPU topology, refer to
https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config. The topology must
be a supported value for the TPU machine type.
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
Optional. Indicates the job scheduling strategy.
Returns:
model: The trained Vertex AI Model resource or None if training did not
Expand Down Expand Up @@ -3841,6 +3858,7 @@ def submit(
block=False,
disable_retries=disable_retries,
persistent_resource_id=persistent_resource_id,
scheduling_strategy=scheduling_strategy,
)

@base.optional_sync(construct_object_on_arg="managed_model")
Expand Down Expand Up @@ -3888,6 +3906,7 @@ def _run(
block: Optional[bool] = True,
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
) -> Optional[models.Model]:
"""Packages local script and launches training_job.
Expand Down Expand Up @@ -4084,6 +4103,8 @@ def _run(
on-demand short-live machines. The network, CMEK, and node pool
configs on the job should be consistent with those on the
PersistentResource, otherwise, the job will be rejected.
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
Optional. Indicates the job scheduling strategy.
Returns:
model: The trained Vertex AI Model resource or None if training did not
Expand Down Expand Up @@ -4138,6 +4159,7 @@ def _run(
tensorboard=tensorboard,
disable_retries=disable_retries,
persistent_resource_id=persistent_resource_id,
scheduling_strategy=scheduling_strategy,
)

model = self._run_job(
Expand Down Expand Up @@ -4462,6 +4484,7 @@ def run(
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
tpu_topology: Optional[str] = None,
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
) -> Optional[models.Model]:
"""Runs the custom training job.
Expand Down Expand Up @@ -4755,6 +4778,8 @@ def run(
details on the TPU topology, refer to
https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config. The topology
must be a supported value for the TPU machine type.
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
Optional. Indicates the job scheduling strategy.
Returns:
model: The trained Vertex AI Model resource or None if training did not
Expand Down Expand Up @@ -4818,6 +4843,7 @@ def run(
create_request_timeout=create_request_timeout,
disable_retries=disable_retries,
persistent_resource_id=persistent_resource_id,
scheduling_strategy=scheduling_strategy,
)

def submit(
Expand Down Expand Up @@ -4871,6 +4897,7 @@ def submit(
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
tpu_topology: Optional[str] = None,
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
) -> Optional[models.Model]:
"""Submits the custom training job without blocking until completion.
Expand Down Expand Up @@ -5164,6 +5191,8 @@ def submit(
details on the TPU topology, refer to
https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config. The topology
must be a supported value for the TPU machine type.
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
Optional. Indicates the job scheduling strategy.
Returns:
model: The trained Vertex AI Model resource or None if training did not
Expand Down Expand Up @@ -5227,6 +5256,7 @@ def submit(
block=False,
disable_retries=disable_retries,
persistent_resource_id=persistent_resource_id,
scheduling_strategy=scheduling_strategy,
)

@base.optional_sync(construct_object_on_arg="managed_model")
Expand Down Expand Up @@ -5273,6 +5303,7 @@ def _run(
block: Optional[bool] = True,
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
) -> Optional[models.Model]:
"""Packages local script and launches training_job.
Args:
Expand Down Expand Up @@ -5465,6 +5496,8 @@ def _run(
on-demand short-live machines. The network, CMEK, and node pool
configs on the job should be consistent with those on the
PersistentResource, otherwise, the job will be rejected.
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
Optional. Indicates the job scheduling strategy.
Returns:
model: The trained Vertex AI Model resource or None if training did not
Expand Down Expand Up @@ -5513,6 +5546,7 @@ def _run(
tensorboard=tensorboard,
disable_retries=disable_retries,
persistent_resource_id=persistent_resource_id,
scheduling_strategy=scheduling_strategy,
)

model = self._run_job(
Expand Down Expand Up @@ -7537,6 +7571,7 @@ def run(
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
tpu_topology: Optional[str] = None,
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
) -> Optional[models.Model]:
"""Runs the custom training job.
Expand Down Expand Up @@ -7831,6 +7866,8 @@ def run(
details on the TPU topology, refer to
https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config. The topology
must be a supported value for the TPU machine type.
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
Optional. Indicates the job scheduling strategy.
Returns:
model: The trained Vertex AI Model resource or None if training did not
Expand Down Expand Up @@ -7889,6 +7926,7 @@ def run(
create_request_timeout=create_request_timeout,
disable_retries=disable_retries,
persistent_resource_id=persistent_resource_id,
scheduling_strategy=scheduling_strategy,
)

@base.optional_sync(construct_object_on_arg="managed_model")
Expand Down Expand Up @@ -7934,6 +7972,7 @@ def _run(
create_request_timeout: Optional[float] = None,
disable_retries: bool = False,
persistent_resource_id: Optional[str] = None,
scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
) -> Optional[models.Model]:
"""Packages local script and launches training_job.
Expand Down Expand Up @@ -8111,6 +8150,8 @@ def _run(
on-demand short-live machines. The network, CMEK, and node pool
configs on the job should be consistent with those on the
PersistentResource, otherwise, the job will be rejected.
scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
Optional. Indicates the job scheduling strategy.
Returns:
model: The trained Vertex AI Model resource or None if training did not
Expand Down Expand Up @@ -8159,6 +8200,7 @@ def _run(
tensorboard=tensorboard,
disable_retries=disable_retries,
persistent_resource_id=persistent_resource_id,
scheduling_strategy=scheduling_strategy,
)

model = self._run_job(
Expand Down
Loading

0 comments on commit a076191

Please sign in to comment.