From a076191b8726363e1f7c47ef8343eb86cebf9918 Mon Sep 17 00:00:00 2001
From: A Vertex SDK engineer
Date: Fri, 9 Aug 2024 15:39:27 -0700
Subject: [PATCH] feat: Add support for strategy in custom training jobs.

PiperOrigin-RevId: 661428427
---
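Notes (editorial sketch, not applied by `git am`): the diff below threads a
new `scheduling_strategy` argument through CustomJob.run/submit,
HyperparameterTuningJob.run, and the CustomTrainingJob family, forwarding it
into the job's `Scheduling` proto. A minimal caller-side sketch follows; the
project, bucket, and image names are placeholders, and the worker pool spec
is illustrative only.

    from google.cloud import aiplatform
    from google.cloud.aiplatform.compat.types import custom_job as gca_custom_job_compat

    # Placeholder project/bucket; any existing staging bucket works.
    aiplatform.init(
        project="my-project",
        location="us-central1",
        staging_bucket="gs://my-staging-bucket",
    )

    job = aiplatform.CustomJob(
        display_name="spot-custom-job",
        worker_pool_specs=[
            {
                "machine_spec": {"machine_type": "n1-standard-4"},
                "replica_count": 1,
                # Placeholder training image.
                "container_spec": {"image_uri": "gcr.io/my-project/trainer:latest"},
            }
        ],
    )

    # Request Spot capacity for the job's workers; leaving the argument
    # unset preserves the current default scheduling behavior.
    job.run(
        scheduling_strategy=gca_custom_job_compat.Scheduling.Strategy.SPOT,
    )

The same keyword is accepted by HyperparameterTuningJob.run and by the
run/submit methods of the CustomTrainingJob variants, as exercised by the
unit tests in this patch.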
 google/cloud/aiplatform/jobs.py               |  34 ++++-
 google/cloud/aiplatform/training_jobs.py      |  44 ++++++-
 tests/unit/aiplatform/constants.py            |   1 +
 tests/unit/aiplatform/test_custom_job.py      |  90 +++++++++++++
 .../test_hyperparameter_tuning_job.py         | 120 ++++++++++++++++++
 tests/unit/aiplatform/test_training_jobs.py   | 120 +++++++++++++++++
 6 files changed, 406 insertions(+), 3 deletions(-)

diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py
index 70d5029013..6f6a8380d8 100644
--- a/google/cloud/aiplatform/jobs.py
+++ b/google/cloud/aiplatform/jobs.py
@@ -2214,6 +2214,7 @@ def run(
         create_request_timeout: Optional[float] = None,
         disable_retries: bool = False,
         persistent_resource_id: Optional[str] = None,
+        scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
     ) -> None:
         """Run this configured CustomJob.

@@ -2282,6 +2283,8 @@ def run(
             on-demand short-live machines. The network, CMEK, and node pool
             configs on the job should be consistent with those on the
             PersistentResource, otherwise, the job will be rejected.
+            scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
+                Optional. Indicates the job scheduling strategy.
         """
         network = network or initializer.global_config.network
         service_account = service_account or initializer.global_config.service_account
@@ -2299,6 +2302,7 @@ def run(
             create_request_timeout=create_request_timeout,
             disable_retries=disable_retries,
             persistent_resource_id=persistent_resource_id,
+            scheduling_strategy=scheduling_strategy,
         )

     @base.optional_sync()
@@ -2316,6 +2320,7 @@ def _run(
         create_request_timeout: Optional[float] = None,
         disable_retries: bool = False,
         persistent_resource_id: Optional[str] = None,
+        scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
     ) -> None:
         """Helper method to ensure network synchronization and to run the configured CustomJob.

@@ -2382,6 +2387,8 @@ def _run(
             on-demand short-live machines. The network, CMEK, and node pool
             configs on the job should be consistent with those on the
             PersistentResource, otherwise, the job will be rejected.
+            scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
+                Optional. Indicates the job scheduling strategy.
         """
         self.submit(
             service_account=service_account,
@@ -2395,6 +2402,7 @@ def _run(
             create_request_timeout=create_request_timeout,
             disable_retries=disable_retries,
             persistent_resource_id=persistent_resource_id,
+            scheduling_strategy=scheduling_strategy,
         )
         self._block_until_complete()

@@ -2413,6 +2421,7 @@ def submit(
         create_request_timeout: Optional[float] = None,
         disable_retries: bool = False,
         persistent_resource_id: Optional[str] = None,
+        scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
     ) -> None:
         """Submit the configured CustomJob.

@@ -2476,6 +2485,8 @@ def submit(
             on-demand short-live machines. The network, CMEK, and node pool
             configs on the job should be consistent with those on the
             PersistentResource, otherwise, the job will be rejected.
+            scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
+                Optional. Indicates the job scheduling strategy.

         Raises:
             ValueError:
@@ -2498,12 +2509,18 @@ def submit(
         if network:
             self._gca_resource.job_spec.network = network

-        if timeout or restart_job_on_worker_restart or disable_retries:
+        if (
+            timeout
+            or restart_job_on_worker_restart
+            or disable_retries
+            or scheduling_strategy
+        ):
             timeout = duration_pb2.Duration(seconds=timeout) if timeout else None
             self._gca_resource.job_spec.scheduling = gca_custom_job_compat.Scheduling(
                 timeout=timeout,
                 restart_job_on_worker_restart=restart_job_on_worker_restart,
                 disable_retries=disable_retries,
+                strategy=scheduling_strategy,
             )

         if enable_web_access:
@@ -2868,6 +2885,7 @@ def run(
         sync: bool = True,
         create_request_timeout: Optional[float] = None,
         disable_retries: bool = False,
+        scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
     ) -> None:
         """Run this configured CustomJob.

@@ -2916,6 +2934,8 @@ def run(
             Indicates if the job should retry for internal errors after the
             job starts running. If True, overrides
             `restart_job_on_worker_restart` to False.
+            scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
+                Optional. Indicates the job scheduling strategy.
         """
         network = network or initializer.global_config.network
         service_account = service_account or initializer.global_config.service_account
@@ -2930,6 +2950,7 @@ def run(
             sync=sync,
             create_request_timeout=create_request_timeout,
             disable_retries=disable_retries,
+            scheduling_strategy=scheduling_strategy,
         )

     @base.optional_sync()
@@ -2944,6 +2965,7 @@ def _run(
         sync: bool = True,
         create_request_timeout: Optional[float] = None,
         disable_retries: bool = False,
+        scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
     ) -> None:
         """Helper method to ensure network synchronization and to run the configured CustomJob.

@@ -2990,6 +3012,8 @@ def _run(
             Indicates if the job should retry for internal errors after the
             job starts running. If True, overrides
             `restart_job_on_worker_restart` to False.
+            scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
+                Optional. Indicates the job scheduling strategy.
""" if service_account: self._gca_resource.trial_job_spec.service_account = service_account @@ -2997,13 +3021,19 @@ def _run( if network: self._gca_resource.trial_job_spec.network = network - if timeout or restart_job_on_worker_restart or disable_retries: + if ( + timeout + or restart_job_on_worker_restart + or disable_retries + or scheduling_strategy + ): duration = duration_pb2.Duration(seconds=timeout) if timeout else None self._gca_resource.trial_job_spec.scheduling = ( gca_custom_job_compat.Scheduling( timeout=duration, restart_job_on_worker_restart=restart_job_on_worker_restart, disable_retries=disable_retries, + strategy=scheduling_strategy, ) ) diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index bc415d0ee7..7e52060b05 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -44,6 +44,7 @@ from google.cloud.aiplatform.compat.types import ( training_pipeline as gca_training_pipeline, study as gca_study_compat, + custom_job as gca_custom_job_compat, ) from google.cloud.aiplatform.utils import _timestamped_gcs_dir @@ -1525,6 +1526,7 @@ def _prepare_training_task_inputs_and_output_dir( tensorboard: Optional[str] = None, disable_retries: bool = False, persistent_resource_id: Optional[str] = None, + scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None, ) -> Tuple[Dict, str]: """Prepares training task inputs and output directory for custom job. @@ -1582,6 +1584,8 @@ def _prepare_training_task_inputs_and_output_dir( on-demand short-live machines. The network, CMEK, and node pool configs on the job should be consistent with those on the PersistentResource, otherwise, the job will be rejected. + scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy): + Optional. Indicates the job scheduling strategy. Returns: Training task inputs and Output directory for custom job. @@ -1612,12 +1616,18 @@ def _prepare_training_task_inputs_and_output_dir( if persistent_resource_id: training_task_inputs["persistent_resource_id"] = persistent_resource_id - if timeout or restart_job_on_worker_restart or disable_retries: + if ( + timeout + or restart_job_on_worker_restart + or disable_retries + or scheduling_strategy + ): timeout = f"{timeout}s" if timeout else None scheduling = { "timeout": timeout, "restart_job_on_worker_restart": restart_job_on_worker_restart, "disable_retries": disable_retries, + "strategy": scheduling_strategy, } training_task_inputs["scheduling"] = scheduling @@ -3005,6 +3015,7 @@ def run( disable_retries: bool = False, persistent_resource_id: Optional[str] = None, tpu_topology: Optional[str] = None, + scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None, ) -> Optional[models.Model]: """Runs the custom training job. @@ -3360,6 +3371,8 @@ def run( details on the TPU topology, refer to https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config. The topology must be a supported value for the TPU machine type. + scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy): + Optional. Indicates the job scheduling strategy. 

         Returns:
             The trained Vertex AI model resource or None if the training
@@ -3424,6 +3437,7 @@ def run(
             create_request_timeout=create_request_timeout,
             disable_retries=disable_retries,
             persistent_resource_id=persistent_resource_id,
+            scheduling_strategy=scheduling_strategy,
         )

     def submit(
@@ -3477,6 +3491,7 @@ def submit(
         disable_retries: bool = False,
         persistent_resource_id: Optional[str] = None,
         tpu_topology: Optional[str] = None,
+        scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
     ) -> Optional[models.Model]:
         """Submits the custom training job without blocking until completion.

@@ -3777,6 +3792,8 @@ def submit(
             details on the TPU topology, refer to
             https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config. The topology must
             be a supported value for the TPU machine type.
+            scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
+                Optional. Indicates the job scheduling strategy.

         Returns:
             model: The trained Vertex AI Model resource or None if training did not
@@ -3841,6 +3858,7 @@ def submit(
             block=False,
             disable_retries=disable_retries,
             persistent_resource_id=persistent_resource_id,
+            scheduling_strategy=scheduling_strategy,
         )

     @base.optional_sync(construct_object_on_arg="managed_model")
@@ -3888,6 +3906,7 @@ def _run(
         block: Optional[bool] = True,
         disable_retries: bool = False,
         persistent_resource_id: Optional[str] = None,
+        scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
     ) -> Optional[models.Model]:
         """Packages local script and launches training_job.

@@ -4084,6 +4103,8 @@ def _run(
             on-demand short-live machines. The network, CMEK, and node pool
             configs on the job should be consistent with those on the
             PersistentResource, otherwise, the job will be rejected.
+            scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
+                Optional. Indicates the job scheduling strategy.

         Returns:
             model: The trained Vertex AI Model resource or None if training did not
@@ -4138,6 +4159,7 @@ def _run(
             tensorboard=tensorboard,
             disable_retries=disable_retries,
             persistent_resource_id=persistent_resource_id,
+            scheduling_strategy=scheduling_strategy,
         )

         model = self._run_job(
@@ -4462,6 +4484,7 @@ def run(
         disable_retries: bool = False,
         persistent_resource_id: Optional[str] = None,
         tpu_topology: Optional[str] = None,
+        scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
     ) -> Optional[models.Model]:
         """Runs the custom training job.

@@ -4755,6 +4778,8 @@ def run(
             details on the TPU topology, refer to
             https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config. The topology must
             be a supported value for the TPU machine type.
+            scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
+                Optional. Indicates the job scheduling strategy.

         Returns:
             model: The trained Vertex AI Model resource or None if training did not
@@ -4818,6 +4843,7 @@ def run(
             create_request_timeout=create_request_timeout,
             disable_retries=disable_retries,
             persistent_resource_id=persistent_resource_id,
+            scheduling_strategy=scheduling_strategy,
         )

     def submit(
@@ -4871,6 +4897,7 @@ def submit(
         disable_retries: bool = False,
         persistent_resource_id: Optional[str] = None,
         tpu_topology: Optional[str] = None,
+        scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
     ) -> Optional[models.Model]:
         """Submits the custom training job without blocking until completion.

@@ -5164,6 +5191,8 @@ def submit(
             details on the TPU topology, refer to
             https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config.
             The topology must be a supported value for the TPU machine type.
+            scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
+                Optional. Indicates the job scheduling strategy.

         Returns:
             model: The trained Vertex AI Model resource or None if training did not
@@ -5227,6 +5256,7 @@ def submit(
             block=False,
             disable_retries=disable_retries,
             persistent_resource_id=persistent_resource_id,
+            scheduling_strategy=scheduling_strategy,
         )

     @base.optional_sync(construct_object_on_arg="managed_model")
@@ -5273,6 +5303,7 @@ def _run(
         block: Optional[bool] = True,
         disable_retries: bool = False,
         persistent_resource_id: Optional[str] = None,
+        scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
     ) -> Optional[models.Model]:
         """Packages local script and launches training_job.
         Args:
@@ -5465,6 +5496,8 @@ def _run(
             on-demand short-live machines. The network, CMEK, and node pool
             configs on the job should be consistent with those on the
             PersistentResource, otherwise, the job will be rejected.
+            scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
+                Optional. Indicates the job scheduling strategy.

         Returns:
             model: The trained Vertex AI Model resource or None if training did not
@@ -5513,6 +5546,7 @@ def _run(
             tensorboard=tensorboard,
             disable_retries=disable_retries,
             persistent_resource_id=persistent_resource_id,
+            scheduling_strategy=scheduling_strategy,
         )

         model = self._run_job(
@@ -7537,6 +7571,7 @@ def run(
         disable_retries: bool = False,
         persistent_resource_id: Optional[str] = None,
         tpu_topology: Optional[str] = None,
+        scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
     ) -> Optional[models.Model]:
         """Runs the custom training job.

@@ -7831,6 +7866,8 @@ def run(
             details on the TPU topology, refer to
             https://cloud.google.com/tpu/docs/v5e#tpu-v5e-config. The topology must
             be a supported value for the TPU machine type.
+            scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
+                Optional. Indicates the job scheduling strategy.

         Returns:
             model: The trained Vertex AI Model resource or None if training did not
@@ -7889,6 +7926,7 @@ def run(
             create_request_timeout=create_request_timeout,
             disable_retries=disable_retries,
             persistent_resource_id=persistent_resource_id,
+            scheduling_strategy=scheduling_strategy,
         )

     @base.optional_sync(construct_object_on_arg="managed_model")
@@ -7934,6 +7972,7 @@ def _run(
         create_request_timeout: Optional[float] = None,
         disable_retries: bool = False,
         persistent_resource_id: Optional[str] = None,
+        scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None,
     ) -> Optional[models.Model]:
         """Packages local script and launches training_job.

@@ -8111,6 +8150,8 @@ def _run(
             on-demand short-live machines. The network, CMEK, and node pool
             configs on the job should be consistent with those on the
             PersistentResource, otherwise, the job will be rejected.
+            scheduling_strategy (gca_custom_job_compat.Scheduling.Strategy):
+                Optional. Indicates the job scheduling strategy.

         Returns:
             model: The trained Vertex AI Model resource or None if training did not
@@ -8159,6 +8200,7 @@ def _run(
             tensorboard=tensorboard,
             disable_retries=disable_retries,
             persistent_resource_id=persistent_resource_id,
+            scheduling_strategy=scheduling_strategy,
         )

         model = self._run_job(
diff --git a/tests/unit/aiplatform/constants.py b/tests/unit/aiplatform/constants.py
index 9c9758c2df..365cb71422 100644
--- a/tests/unit/aiplatform/constants.py
+++ b/tests/unit/aiplatform/constants.py
@@ -197,6 +197,7 @@ class TrainingJobConstants:
         "projects/my-project/locations/us-central1/trainingPipelines/12345"
     )
     _TEST_DEFAULT_ENCRYPTION_KEY_NAME = "key_default"
+    _TEST_SPOT_STRATEGY = custom_job.Scheduling.Strategy.SPOT

     def create_tpu_job_proto(tpu_version):
         worker_pool_spec = (
diff --git a/tests/unit/aiplatform/test_custom_job.py b/tests/unit/aiplatform/test_custom_job.py
index 46b9ca3fa0..19762a5059 100644
--- a/tests/unit/aiplatform/test_custom_job.py
+++ b/tests/unit/aiplatform/test_custom_job.py
@@ -62,6 +62,7 @@
     test_constants.TrainingJobConstants._TEST_TRAINING_CONTAINER_IMAGE
 )
 _TEST_PREBUILT_CONTAINER_IMAGE = "gcr.io/cloud-aiplatform/container:image"
+_TEST_SPOT_STRATEGY = test_constants.TrainingJobConstants._TEST_SPOT_STRATEGY

 _TEST_RUN_ARGS = test_constants.TrainingJobConstants._TEST_RUN_ARGS
 _TEST_EXPERIMENT = "test-experiment"
@@ -226,6 +227,12 @@ def _get_custom_tpu_job_proto(state=None, name=None, error=None, tpu_version=Non
     return custom_job_proto


+def _get_custom_job_proto_with_spot_strategy(state=None, name=None, error=None):
+    custom_job_proto = _get_custom_job_proto(state=state, name=name, error=error)
+    custom_job_proto.job_spec.scheduling.strategy = _TEST_SPOT_STRATEGY
+    return custom_job_proto
+
+
 @pytest.fixture
 def mock_builtin_open():
     with patch("builtins.open", mock_open(read_data="data")) as mock_file:
@@ -396,6 +403,28 @@ def get_custom_job_mock_with_enable_web_access_succeeded():
     yield get_custom_job_mock


+@pytest.fixture
+def get_custom_job_mock_with_spot_strategy():
+    with patch.object(
+        job_service_client.JobServiceClient, "get_custom_job"
+    ) as get_custom_job_mock:
+        get_custom_job_mock.side_effect = [
+            _get_custom_job_proto_with_spot_strategy(
+                name=_TEST_CUSTOM_JOB_NAME,
+                state=gca_job_state_compat.JobState.JOB_STATE_PENDING,
+            ),
+            _get_custom_job_proto_with_spot_strategy(
+                name=_TEST_CUSTOM_JOB_NAME,
+                state=gca_job_state_compat.JobState.JOB_STATE_RUNNING,
+            ),
+            _get_custom_job_proto_with_spot_strategy(
+                name=_TEST_CUSTOM_JOB_NAME,
+                state=gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED,
+            ),
+        ]
+        yield get_custom_job_mock
+
+
 @pytest.fixture
 def create_custom_job_mock():
     with mock.patch.object(
@@ -445,6 +474,18 @@ def create_custom_job_mock_fail():
     yield create_custom_job_mock


+@pytest.fixture
+def create_custom_job_mock_with_spot_strategy():
+    with mock.patch.object(
+        job_service_client.JobServiceClient, "create_custom_job"
+    ) as create_custom_job_mock:
+        create_custom_job_mock.return_value = _get_custom_job_proto_with_spot_strategy(
+            name=_TEST_CUSTOM_JOB_NAME,
+            state=gca_job_state_compat.JobState.JOB_STATE_PENDING,
+        )
+        yield create_custom_job_mock
+
+
 _EXPERIMENT_MOCK = copy.deepcopy(_EXPERIMENT_MOCK)
 _EXPERIMENT_MOCK.metadata[
     constants._BACKING_TENSORBOARD_RESOURCE_KEY
@@ -1433,3 +1474,52 @@ def test_create_custom_job_tpu_v3(
         assert (
             job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED
         )
+
+    def test_create_custom_job_with_spot_strategy(
+        self,
+        create_custom_job_mock_with_spot_strategy,
+        get_custom_job_mock_with_spot_strategy,
+    ):
+
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+            staging_bucket=_TEST_STAGING_BUCKET,
+            encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
+        )
+
+        job = aiplatform.CustomJob(
+            display_name=_TEST_DISPLAY_NAME,
+            worker_pool_specs=_TEST_WORKER_POOL_SPEC,
+            base_output_dir=_TEST_BASE_OUTPUT_DIR,
+            labels=_TEST_LABELS,
+        )
+
+        job.run(
+            service_account=_TEST_SERVICE_ACCOUNT,
+            network=_TEST_NETWORK,
+            timeout=_TEST_TIMEOUT,
+            restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
+            create_request_timeout=None,
+            disable_retries=_TEST_DISABLE_RETRIES,
+            scheduling_strategy=_TEST_SPOT_STRATEGY,
+        )
+
+        job.wait_for_resource_creation()
+
+        job.wait()
+
+        assert job.resource_name == _TEST_CUSTOM_JOB_NAME
+
+        expected_custom_job = _get_custom_job_proto_with_spot_strategy()
+
+        create_custom_job_mock_with_spot_strategy.assert_called_once_with(
+            parent=_TEST_PARENT,
+            custom_job=expected_custom_job,
+            timeout=None,
+        )
+
+        assert job.job_spec == expected_custom_job.job_spec
+        assert (
+            job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED
+        )
diff --git a/tests/unit/aiplatform/test_hyperparameter_tuning_job.py b/tests/unit/aiplatform/test_hyperparameter_tuning_job.py
index f49af1868f..e78977e8e4 100644
--- a/tests/unit/aiplatform/test_hyperparameter_tuning_job.py
+++ b/tests/unit/aiplatform/test_hyperparameter_tuning_job.py
@@ -202,6 +202,22 @@ def _get_hyperparameter_tuning_job_proto_with_enable_web_access(
     return hyperparameter_tuning_job_proto


+def _get_hyperparameter_tuning_job_proto_with_spot_strategy(
+    state=None, name=None, error=None, trials=[]
+):
+    hyperparameter_tuning_job_proto = _get_hyperparameter_tuning_job_proto(
+        state=state,
+        name=name,
+        error=error,
+    )
+    hyperparameter_tuning_job_proto.trial_job_spec.scheduling.strategy = (
+        test_constants.TrainingJobConstants._TEST_SPOT_STRATEGY
+    )
+    if state == gca_job_state_compat.JobState.JOB_STATE_RUNNING:
+        hyperparameter_tuning_job_proto.trials = trials
+    return hyperparameter_tuning_job_proto
+
+
 @pytest.fixture
 def get_hyperparameter_tuning_job_mock():
     with patch.object(
@@ -331,6 +347,28 @@ def get_hyperparameter_tuning_job_mock_with_fail():
     yield get_hyperparameter_tuning_job_mock


+@pytest.fixture
+def get_hyperparameter_tuning_job_mock_with_spot_strategy():
+    with patch.object(
+        job_service_client.JobServiceClient, "get_hyperparameter_tuning_job"
+    ) as get_hyperparameter_tuning_job_mock:
+        get_hyperparameter_tuning_job_mock.side_effect = [
+            _get_hyperparameter_tuning_job_proto_with_spot_strategy(
+                name=_TEST_HYPERPARAMETERTUNING_JOB_NAME,
+                state=gca_job_state_compat.JobState.JOB_STATE_PENDING,
+            ),
+            _get_hyperparameter_tuning_job_proto_with_spot_strategy(
+                name=_TEST_HYPERPARAMETERTUNING_JOB_NAME,
+                state=gca_job_state_compat.JobState.JOB_STATE_RUNNING,
+            ),
+            _get_hyperparameter_tuning_job_proto_with_spot_strategy(
+                name=_TEST_HYPERPARAMETERTUNING_JOB_NAME,
+                state=gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED,
+            ),
+        ]
+        yield get_hyperparameter_tuning_job_mock
+
+
 @pytest.fixture
 def create_hyperparameter_tuning_job_mock():
     with mock.patch.object(
@@ -386,6 +424,20 @@ def create_hyperparameter_tuning_job_mock_with_tensorboard():
     yield create_hyperparameter_tuning_job_mock


+@pytest.fixture
+def create_hyperparameter_tuning_job_mock_with_spot_strategy():
+    with mock.patch.object(
+        job_service_client.JobServiceClient, "create_hyperparameter_tuning_job"
+    ) as create_hyperparameter_tuning_job_mock:
+        create_hyperparameter_tuning_job_mock.return_value = (
+            _get_hyperparameter_tuning_job_proto_with_spot_strategy(
+                name=_TEST_HYPERPARAMETERTUNING_JOB_NAME,
+                state=gca_job_state_compat.JobState.JOB_STATE_PENDING,
+            )
+        )
+        yield create_hyperparameter_tuning_job_mock
+
+
 @pytest.mark.usefixtures("google_auth_mock")
 class TestHyperparameterTuningJob:
     def setup_method(self):
@@ -908,3 +960,71 @@ def test_log_enable_web_access_after_get_hyperparameter_tuning_job(
         assert hp_job._logged_web_access_uris == set(
             test_constants.TrainingJobConstants._TEST_WEB_ACCESS_URIS.values()
         )
+
+    def test_create_hyperparameter_tuning_job_with_spot_strategy(
+        self,
+        create_hyperparameter_tuning_job_mock_with_spot_strategy,
+        get_hyperparameter_tuning_job_mock_with_spot_strategy,
+    ):
+
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+            staging_bucket=_TEST_STAGING_BUCKET,
+            encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
+        )
+
+        custom_job = aiplatform.CustomJob(
+            display_name=test_constants.TrainingJobConstants._TEST_DISPLAY_NAME,
+            worker_pool_specs=test_constants.TrainingJobConstants._TEST_WORKER_POOL_SPEC,
+            base_output_dir=test_constants.TrainingJobConstants._TEST_BASE_OUTPUT_DIR,
+        )
+
+        job = aiplatform.HyperparameterTuningJob(
+            display_name=_TEST_DISPLAY_NAME,
+            custom_job=custom_job,
+            metric_spec={_TEST_METRIC_SPEC_KEY: _TEST_METRIC_SPEC_VALUE},
+            parameter_spec={
+                "lr": hpt.DoubleParameterSpec(min=0.001, max=0.1, scale="log"),
+                "units": hpt.IntegerParameterSpec(min=4, max=1028, scale="linear"),
+                "activation": hpt.CategoricalParameterSpec(
+                    values=["relu", "sigmoid", "elu", "selu", "tanh"]
+                ),
+                "batch_size": hpt.DiscreteParameterSpec(
+                    values=[4, 8, 16, 32, 64],
+                    scale="linear",
+                    conditional_parameter_spec={
+                        "decay": _TEST_CONDITIONAL_PARAMETER_DECAY,
+                        "learning_rate": _TEST_CONDITIONAL_PARAMETER_LR,
+                    },
+                ),
+            },
+            parallel_trial_count=_TEST_PARALLEL_TRIAL_COUNT,
+            max_trial_count=_TEST_MAX_TRIAL_COUNT,
+            max_failed_trial_count=_TEST_MAX_FAILED_TRIAL_COUNT,
+            search_algorithm=_TEST_SEARCH_ALGORITHM,
+            measurement_selection=_TEST_MEASUREMENT_SELECTION,
+            labels=_TEST_LABELS,
+        )
+
+        job.run(
+            service_account=_TEST_SERVICE_ACCOUNT,
+            network=_TEST_NETWORK,
+            timeout=_TEST_TIMEOUT,
+            restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
+            create_request_timeout=None,
+            disable_retries=_TEST_DISABLE_RETRIES,
+            scheduling_strategy=test_constants.TrainingJobConstants._TEST_SPOT_STRATEGY,
+        )
+
+        job.wait()
+
+        expected_hyperparameter_tuning_job = (
+            _get_hyperparameter_tuning_job_proto_with_spot_strategy()
+        )
+
+        create_hyperparameter_tuning_job_mock_with_spot_strategy.assert_called_once_with(
+            parent=_TEST_PARENT,
+            hyperparameter_tuning_job=expected_hyperparameter_tuning_job,
+            timeout=None,
+        )
diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py
index e7f7f9d15a..681b255432 100644
--- a/tests/unit/aiplatform/test_training_jobs.py
+++ b/tests/unit/aiplatform/test_training_jobs.py
@@ -248,6 +248,7 @@
 _TEST_PERSISTENT_RESOURCE_ID = (
     test_constants.PersistentResourceConstants._TEST_PERSISTENT_RESOURCE_ID
 )
+_TEST_SPOT_STRATEGY = test_constants.TrainingJobConstants._TEST_SPOT_STRATEGY

 _TEST_BASE_CUSTOM_JOB_PROTO = gca_custom_job.CustomJob(
     job_spec=gca_custom_job.CustomJobSpec(),
@@ -305,6 +306,15 @@ def _get_custom_job_proto_with_scheduling(state=None, name=None, version="v1"):
     return custom_job_proto


+def _get_custom_job_proto_with_spot_strategy(state=None, name=None, version="v1"):
+    custom_job_proto = copy.deepcopy(_TEST_BASE_CUSTOM_JOB_PROTO)
+    custom_job_proto.name = name
+    custom_job_proto.state = state
+
+    custom_job_proto.job_spec.scheduling.strategy = _TEST_SPOT_STRATEGY
+    return custom_job_proto
+
+
 def local_copy_method(path):
     shutil.copy(path, ".")
     return pathlib.Path(path).name
@@ -810,6 +820,21 @@ def make_training_pipeline_with_scheduling(state):
     return training_pipeline


+def make_training_pipeline_with_spot_strategy(state):
+    training_pipeline = gca_training_pipeline.TrainingPipeline(
+        name=_TEST_PIPELINE_RESOURCE_NAME,
+        state=state,
+        training_task_inputs={
+            "scheduling_strategy": _TEST_SPOT_STRATEGY,
+        },
+    )
+    if state == gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING:
+        training_pipeline.training_task_metadata = {
+            "backingCustomJob": _TEST_CUSTOM_JOB_RESOURCE_NAME
+        }
+    return training_pipeline
+
+
 @pytest.fixture
 def mock_pipeline_service_get(make_call=make_training_pipeline):
     with mock.patch.object(
@@ -952,6 +977,35 @@ def mock_pipeline_service_get_with_scheduling():
     yield mock_get_training_pipeline


+@pytest.fixture
+def mock_pipeline_service_get_with_spot_strategy():
+    with mock.patch.object(
+        pipeline_service_client.PipelineServiceClient, "get_training_pipeline"
+    ) as mock_get_training_pipeline:
+        mock_get_training_pipeline.side_effect = [
+            make_training_pipeline_with_spot_strategy(
+                state=gca_pipeline_state.PipelineState.PIPELINE_STATE_PENDING,
+            ),
+            make_training_pipeline_with_spot_strategy(
+                state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING,
+            ),
+            make_training_pipeline_with_spot_strategy(
+                state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING,
+            ),
+            make_training_pipeline_with_spot_strategy(
+                state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING,
+            ),
+            make_training_pipeline_with_spot_strategy(
+                state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED,
+            ),
+            make_training_pipeline_with_spot_strategy(
+                state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED,
+            ),
+        ]
+
+        yield mock_get_training_pipeline
+
+
 @pytest.fixture
 def mock_pipeline_service_cancel():
     with mock.patch.object(
@@ -1026,6 +1080,19 @@ def mock_pipeline_service_create_with_scheduling():
     yield mock_create_training_pipeline


+@pytest.fixture
+def mock_pipeline_service_create_with_spot_strategy():
+    with mock.patch.object(
+        pipeline_service_client.PipelineServiceClient, "create_training_pipeline"
+    ) as mock_create_training_pipeline:
+        mock_create_training_pipeline.return_value = (
+            make_training_pipeline_with_spot_strategy(
+                state=gca_pipeline_state.PipelineState.PIPELINE_STATE_PENDING,
+            )
+        )
+        yield mock_create_training_pipeline
+
+
 @pytest.fixture
 def mock_pipeline_service_get_with_no_model_to_upload():
     with mock.patch.object(
@@ -2388,6 +2455,59 @@ def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog):
             == _TEST_DISABLE_RETRIES
         )

+    @mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
+    @mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
+    @pytest.mark.usefixtures(
+        "mock_pipeline_service_create_with_spot_strategy",
+        "mock_pipeline_service_get_with_spot_strategy",
+        "mock_python_package_to_gcs",
+    )
+    @pytest.mark.parametrize("sync", [True, False])
+    def test_run_call_pipeline_service_create_with_spot_strategy(self, sync):
+
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_BUCKET_NAME,
+            credentials=_TEST_CREDENTIALS,
+        )
+
+        job = training_jobs.CustomTrainingJob(
+            display_name=_TEST_DISPLAY_NAME,
+            script_path=_TEST_LOCAL_SCRIPT_FILE_NAME,
+            container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
+        )
+
+        job.run(
+            base_output_dir=_TEST_BASE_OUTPUT_DIR,
+            args=_TEST_RUN_ARGS,
+            machine_type=_TEST_MACHINE_TYPE,
+            accelerator_type=_TEST_ACCELERATOR_TYPE,
+            accelerator_count=_TEST_ACCELERATOR_COUNT,
+            timeout=_TEST_TIMEOUT,
+            restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
+            sync=sync,
+            create_request_timeout=None,
+            disable_retries=_TEST_DISABLE_RETRIES,
+            scheduling_strategy=_TEST_SPOT_STRATEGY,
+        )
+
+        if not sync:
+            job.wait()
+
+        assert job._gca_resource == make_training_pipeline_with_spot_strategy(
+            gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+        )
+
+        assert (
+            job._gca_resource.state
+            == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+        )
+
+        assert (
+            job._gca_resource.training_task_inputs["scheduling_strategy"]
+            == _TEST_SPOT_STRATEGY
+        )
+
     @mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
     @mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
     @pytest.mark.usefixtures(