Skip to content

Commit

Permalink
feat: GenAI - Tuning - Added support for BYOSA
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 700223388
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Nov 26, 2024
1 parent 598c931 commit 7cbda03
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
24 changes: 24 additions & 0 deletions tests/unit/vertexai/tuning/test_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,30 @@ def test_genai_tuning_service_encryption_spec(
)
assert sft_tuning_job.encryption_spec.kms_key_name == "test-key"

@mock.patch.object(
target=tuning.TuningJob,
attribute="client_class",
new=MockTuningJobClientWithOverride,
)
@pytest.mark.parametrize(
"supervised_tuning",
[supervised_tuning, preview_supervised_tuning],
)
def test_genai_tuning_service_service_account(
self, supervised_tuning: supervised_tuning
):
"""Test that the service account propagates to the tuning job."""
vertexai.init(service_account="test-sa@test-project.iam.gserviceaccount.com")

sft_tuning_job = supervised_tuning.train(
source_model="gemini-1.0-pro-002",
train_dataset="gs://some-bucket/some_dataset.jsonl",
)
assert (
sft_tuning_job.service_account
== "test-sa@test-project.iam.gserviceaccount.com"
)

@mock.patch.object(
target=tuning.TuningJob,
attribute="client_class",
Expand Down
8 changes: 8 additions & 0 deletions vertexai/tuning/_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ def experiment(self) -> Optional[aiplatform.Experiment]:
def state(self) -> gca_types.JobState:
return self._gca_resource.state

@property
def service_account(self) -> Optional[str]:
self._assert_gca_resource_is_available()
return self._gca_resource.service_account

@property
def has_ended(self):
return self.state in jobs._JOB_COMPLETE_STATES
Expand Down Expand Up @@ -204,6 +209,9 @@ def _create(
gca_tuning_job.encryption_spec.kms_key_name = (
aiplatform_initializer.global_config.encryption_spec_key_name
)
gca_tuning_job.service_account = (
aiplatform_initializer.global_config.service_account
)

tuning_job: TuningJob = cls._construct_sdk_resource_from_gapic(
gapic_resource=gca_tuning_job,
Expand Down

0 comments on commit 7cbda03

Please sign in to comment.