feat: GenAI - Added the model Distillation feature (private preview)
```python
from google.cloud.aiplatform.private_preview import distillation

job = distillation.train(
    student_model="gemma-1.1-2b-it",
    teacher_model="gemini-1.5-flash-001",
    training_dataset="gs://some-bucket/some_dataset.jsonl",
    # Optional:
    validation_dataset="gs://some-bucket/some_dataset.jsonl",
    epoch_count=5,
    learning_rate_multiplier=1.0,
)
```
PiperOrigin-RevId: 666992707
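
The returned `job` can then be polled to completion. A minimal polling sketch, based on the `TuningJob` surface exercised in the unit test below (`refresh()`, `has_ended`, `has_succeeded`, `tuned_model_name`); the sleep interval is illustrative, not an API requirement:

```python
import time

# Re-fetch server-side state until the job reaches a terminal state;
# has_ended/has_succeeded are derived from job.state.
while not job.has_ended:
    time.sleep(60)  # illustrative polling interval
    job.refresh()

if job.has_succeeded:
    print(job.tuned_model_name)  # resource name of the distilled model
```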
Ark-kun authored and copybara-github committed Aug 24, 2024
1 parent d59a052 commit a0d4ff2
Showing 3 changed files with 148 additions and 6 deletions.
57 changes: 53 additions & 4 deletions tests/unit/vertexai/test_tuning.py
```diff
@@ -38,6 +38,8 @@
     sft as preview_supervised_tuning,
 )
 from vertexai.tuning import sft as supervised_tuning
+from vertexai.tuning import _distillation
+from google.cloud import storage
 
 import pytest
 
@@ -79,11 +81,12 @@ def create_tuning_job(
     def _progress_tuning_job(self, name: str):
         tuning_job: gca_tuning_job.TuningJob = self._tuning_jobs[name]
         current_time = datetime.datetime.now(datetime.timezone.utc)
+        training_dataset_uri = (
+            tuning_job.supervised_tuning_spec.training_dataset_uri
+            or tuning_job.distillation_spec.training_dataset_uri
+        )
         if tuning_job.state == job_state.JobState.JOB_STATE_PENDING:
-            if (
-                "invalid_dataset"
-                in tuning_job.supervised_tuning_spec.training_dataset_uri
-            ):
+            if "invalid_dataset" in training_dataset_uri:
                 tuning_job.state = job_state.JobState.JOB_STATE_FAILED
                 tuning_job.error = status_pb2.Status(
                     code=400, message="Invalid dataset."
@@ -162,6 +165,7 @@ def setup_method(self):
         vertexai.init(
             project=_TEST_PROJECT,
             location=_TEST_LOCATION,
+            staging_bucket="gs://test-bucket",
         )
 
     def teardown_method(self):
@@ -233,3 +237,48 @@ def test_genai_tuning_service_encryption_spec(
             train_dataset="gs://some-bucket/some_dataset.jsonl",
         )
         assert sft_tuning_job.encryption_spec.kms_key_name == "test-key"
+
+    @mock.patch.object(
+        target=tuning.TuningJob,
+        attribute="client_class",
+        new=MockTuningJobClientWithOverride,
+    )
+    @mock.patch.object(
+        target=storage.Bucket,
+        attribute="exists",
+        new=lambda _: True,
+    )
+    def test_genai_tuning_service_distillation_distill_model(self):
+        distillation_train = _distillation.distill_model
+
+        tuning_job = distillation_train(
+            student_model="gemma",
+            teacher_model="gemini-1.0-pro-001",
+            training_dataset="gs://some-bucket/some_dataset.jsonl",
+            # Optional:
+            validation_dataset="gs://some-bucket/some_dataset.jsonl",
+            epoch_count=300,
+            learning_rate_multiplier=1.0,
+        )
+        assert tuning_job.state == job_state.JobState.JOB_STATE_PENDING
+        assert not tuning_job.has_ended
+        assert not tuning_job.has_succeeded
+
+        # Refreshing the job
+        tuning_job.refresh()
+        assert tuning_job.state == job_state.JobState.JOB_STATE_PENDING
+        assert not tuning_job.has_ended
+        assert not tuning_job.has_succeeded
+
+        # Refreshing the job
+        tuning_job.refresh()
+        assert tuning_job.state == job_state.JobState.JOB_STATE_RUNNING
+        assert not tuning_job.has_ended
+        assert not tuning_job.has_succeeded
+
+        # Refreshing the job
+        tuning_job.refresh()
+        assert tuning_job.state == job_state.JobState.JOB_STATE_SUCCEEDED
+        assert tuning_job.has_ended
+        assert tuning_job.has_succeeded
+        assert tuning_job.tuned_model_name
```
88 changes: 88 additions & 0 deletions vertexai/tuning/_distillation.py
```python
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pylint: disable=protected-access
"""Classes for model tuning based on distillation."""

from typing import Optional

from google.cloud.aiplatform.utils import gcs_utils
from google.cloud.aiplatform_v1beta1.types import tuning_job as gca_tuning_job_types

from vertexai import generative_models
from vertexai.tuning import _tuning


def distill_model(
    *,
    student_model: str,
    teacher_model: str,
    training_dataset: str,
    validation_dataset: Optional[str] = None,
    epoch_count: Optional[int] = None,
    learning_rate_multiplier: Optional[float] = None,
    tuned_model_display_name: Optional[str] = None,
) -> "DistillationJob":
    """Tunes a model using distillation.

    Args:
        student_model:
            Student model name for distillation, e.g., "gemma-1.1-2b-it".
        teacher_model:
            Teacher model name for distillation, e.g., "gemini-1.5-flash-001".
        training_dataset: Cloud Storage path to file containing training
            dataset for distillation. The dataset should be in JSONL format.
        validation_dataset: Cloud Storage path to file containing validation
            dataset for distillation. The dataset should be in JSONL format.
        epoch_count: Number of training epochs for this tuning job.
        learning_rate_multiplier: Learning rate multiplier for tuning.
        tuned_model_display_name: The display name of the
            [TunedModel][google.cloud.aiplatform.v1.Model]. The name can
            be up to 128 characters long and can consist of any UTF-8 characters.

    Returns:
        A `DistillationJob` object.
    """
    if isinstance(student_model, generative_models.GenerativeModel):
        student_model = student_model._prediction_resource_name

    student_model = student_model.rpartition("/")[-1]
    teacher_model = teacher_model.rpartition("/")[-1]

    pipeline_root = (
        gcs_utils.create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist()
    )

    distillation_spec = gca_tuning_job_types.DistillationSpec(
        student_model=student_model,
        base_teacher_model=teacher_model,
        training_dataset_uri=training_dataset,
        validation_dataset_uri=validation_dataset,
        hyper_parameters=gca_tuning_job_types.DistillationHyperParameters(
            epoch_count=epoch_count,
            learning_rate_multiplier=learning_rate_multiplier,
        ),
        pipeline_root_directory=pipeline_root,
    )

    return DistillationJob._create(  # pylint: disable=protected-access
        base_model=None,
        tuning_spec=distillation_spec,
        tuned_model_display_name=tuned_model_display_name,
    )


class DistillationJob(_tuning.TuningJob):
    pass
```
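
A minimal usage sketch for the module above; the project, location, bucket, and dataset paths are hypothetical placeholders, and `vertexai.init` is assumed to be configured with a staging bucket, as in the unit test:

```python
# Hypothetical usage sketch; all resource names below are placeholders.
import vertexai
from vertexai.tuning import _distillation

vertexai.init(
    project="my-project",             # placeholder
    location="us-central1",           # placeholder
    staging_bucket="gs://my-bucket",  # pipeline artifacts are staged in GCS
)

job = _distillation.distill_model(
    # A full publisher resource path also works; rpartition("/") keeps
    # only the trailing model ID.
    student_model="gemma-1.1-2b-it",
    teacher_model="gemini-1.5-flash-001",
    training_dataset="gs://my-bucket/train.jsonl",  # JSONL, per the docstring
)
```

Note that `distill_model` also accepts a `GenerativeModel` instance as `student_model`, in which case its prediction resource name is used.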
9 changes: 7 additions & 2 deletions vertexai/tuning/_tuning.py
```diff
@@ -128,7 +128,10 @@ def _create(
         cls,
         *,
         base_model: str,
-        tuning_spec: Union[gca_tuning_job_types.SupervisedTuningSpec],
+        tuning_spec: Union[
+            gca_tuning_job_types.SupervisedTuningSpec,
+            gca_tuning_job_types.DistillationSpec,
+        ],
         tuned_model_display_name: Optional[str] = None,
         description: Optional[str] = None,
         labels: Optional[Dict[str, str]] = None,
@@ -145,7 +148,7 @@ def _create(
                 This field is a member of `oneof`_ ``source_model``.
             tuning_spec: Tuning Spec for Fine Tuning.
-                Supported types: SupervisedTuningSpec.
+                Supported types: SupervisedTuningSpec, DistillationSpec.
             tuned_model_display_name: The display name of the
                 [TunedModel][google.cloud.aiplatform.v1.Model]. The name can
                 be up to 128 characters long and can consist of any UTF-8
                 characters.
@@ -192,6 +195,8 @@ def _create(
 
         if isinstance(tuning_spec, gca_tuning_job_types.SupervisedTuningSpec):
             gca_tuning_job.supervised_tuning_spec = tuning_spec
+        elif isinstance(tuning_spec, gca_tuning_job_types.DistillationSpec):
+            gca_tuning_job.distillation_spec = tuning_spec
         else:
             raise RuntimeError(f"Unsupported tuning_spec kind: {tuning_spec}")
 
```
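
For illustration, the `isinstance` dispatch added to `_create` routes each spec type to its matching `TuningJob` proto field; a rough sketch of that behavior, assuming the generated `aiplatform_v1beta1` types and using only fields shown in the diff above:

```python
# Rough illustration of the dispatch in TuningJob._create; dataset path
# and model names are placeholders.
from google.cloud.aiplatform_v1beta1.types import tuning_job as gca_tuning_job_types

spec = gca_tuning_job_types.DistillationSpec(
    student_model="gemma-1.1-2b-it",
    base_teacher_model="gemini-1.5-flash-001",
    training_dataset_uri="gs://my-bucket/train.jsonl",  # placeholder path
)

gca_tuning_job = gca_tuning_job_types.TuningJob()
if isinstance(spec, gca_tuning_job_types.SupervisedTuningSpec):
    gca_tuning_job.supervised_tuning_spec = spec
elif isinstance(spec, gca_tuning_job_types.DistillationSpec):
    gca_tuning_job.distillation_spec = spec  # the branch this commit adds
else:
    raise RuntimeError(f"Unsupported tuning_spec kind: {spec}")
```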
