diff --git a/tests/unit/vertexai/tuning/test_tuning.py b/tests/unit/vertexai/tuning/test_tuning.py index 5faf590bc1..0d6e74bb59 100644 --- a/tests/unit/vertexai/tuning/test_tuning.py +++ b/tests/unit/vertexai/tuning/test_tuning.py @@ -250,6 +250,30 @@ def test_genai_tuning_service_encryption_spec( ) assert sft_tuning_job.encryption_spec.kms_key_name == "test-key" + @mock.patch.object( + target=tuning.TuningJob, + attribute="client_class", + new=MockTuningJobClientWithOverride, + ) + @pytest.mark.parametrize( + "supervised_tuning", + [supervised_tuning, preview_supervised_tuning], + ) + def test_genai_tuning_service_service_account( + self, supervised_tuning: supervised_tuning + ): + """Test that the service account propagates to the tuning job.""" + vertexai.init(service_account="test-sa@test-project.iam.gserviceaccount.com") + + sft_tuning_job = supervised_tuning.train( + source_model="gemini-1.0-pro-002", + train_dataset="gs://some-bucket/some_dataset.jsonl", + ) + assert ( + sft_tuning_job.service_account + == "test-sa@test-project.iam.gserviceaccount.com" + ) + @mock.patch.object( target=tuning.TuningJob, attribute="client_class", diff --git a/vertexai/tuning/_tuning.py b/vertexai/tuning/_tuning.py index ced3fac9b1..f080608dd9 100644 --- a/vertexai/tuning/_tuning.py +++ b/vertexai/tuning/_tuning.py @@ -107,6 +107,11 @@ def experiment(self) -> Optional[aiplatform.Experiment]: def state(self) -> gca_types.JobState: return self._gca_resource.state + @property + def service_account(self) -> Optional[str]: + self._assert_gca_resource_is_available() + return self._gca_resource.service_account + @property def has_ended(self): return self.state in jobs._JOB_COMPLETE_STATES @@ -204,6 +209,9 @@ def _create( gca_tuning_job.encryption_spec.kms_key_name = ( aiplatform_initializer.global_config.encryption_spec_key_name ) + gca_tuning_job.service_account = ( + aiplatform_initializer.global_config.service_account + ) tuning_job: TuningJob = cls._construct_sdk_resource_from_gapic( gapic_resource=gca_tuning_job,