Skip to content

Commit

Permalink
feat: add labels parameter to the supervised tuning train method
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 636381156
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed May 23, 2024
1 parent 0936f35 commit f7c5567
Showing 1 changed file with 28 additions and 27 deletions.
55 changes: 28 additions & 27 deletions vertexai/tuning/_supervised_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
# limitations under the License.
#

from typing import Literal, Optional, Union
from typing import Dict, Literal, Optional, Union

from google.cloud.aiplatform_v1.types import tuning_job as gca_tuning_job_types

from vertexai import generative_models
from vertexai.tuning import _tuning

Expand All @@ -30,27 +29,28 @@ def train(
epochs: Optional[int] = None,
learning_rate_multiplier: Optional[float] = None,
adapter_size: Optional[Literal[1, 4, 8, 16]] = None,
labels: Optional[Dict[str, str]] = None,
) -> "SupervisedTuningJob":
"""Tunes a model using supervised training.
"""Tunes a model using supervised training.
Args:
source_model (str):
Model name for tuning, e.g., "gemini-1.0-pro-002".
train_dataset: Cloud Storage path to file containing training dataset for tuning.
The dataset should be in JSONL format.
validation_dataset: Cloud Storage path to file containing validation dataset for tuning.
The dataset should be in JSONL format.
tuned_model_display_name: The display name of the
[TunedModel][google.cloud.aiplatform.v1.Model]. The name can
be up to 128 characters long and can consist of any UTF-8 characters.
epochs: Number of training epochs for this tuning job.
learning_rate_multiplier: Learning rate multiplier for tuning.
adapter_size: Adapter size for tuning.
Args:
source_model (str): Model name for tuning, e.g., "gemini-1.0-pro-002".
train_dataset: Cloud Storage path to file containing training dataset for
tuning. The dataset should be in JSONL format.
validation_dataset: Cloud Storage path to file containing validation
dataset for tuning. The dataset should be in JSONL format.
tuned_model_display_name: The display name of the
[TunedModel][google.cloud.aiplatform.v1.Model]. The name can be up to
128 characters long and can consist of any UTF-8 characters.
epochs: Number of training epochs for this tuning job.
learning_rate_multiplier: Learning rate multiplier for tuning.
adapter_size: Adapter size for tuning.
labels: User-defined metadata to be associated with the tuned model.
Returns:
A `TuningJob` object.
"""
supervised_tuning_spec = gca_tuning_job_types.SupervisedTuningSpec(
Returns:
A `TuningJob` object.
"""
supervised_tuning_spec = gca_tuning_job_types.SupervisedTuningSpec(
training_dataset_uri=train_dataset,
validation_dataset_uri=validation_dataset,
hyper_parameters=gca_tuning_job_types.SupervisedHyperParameters(
Expand All @@ -60,14 +60,15 @@ def train(
),
)

if isinstance(source_model, generative_models.GenerativeModel):
source_model = source_model._prediction_resource_name.rpartition('/')[-1]
if isinstance(source_model, generative_models.GenerativeModel):
source_model = source_model._prediction_resource_name.rpartition('/')[-1]

return SupervisedTuningJob._create( # pylint: disable=protected-access
base_model=source_model,
tuning_spec=supervised_tuning_spec,
tuned_model_display_name=tuned_model_display_name,
)
return SupervisedTuningJob._create( # pylint: disable=protected-access
base_model=source_model,
tuning_spec=supervised_tuning_spec,
tuned_model_display_name=tuned_model_display_name,
labels=labels,
)


class SupervisedTuningJob(_tuning.TuningJob):
Expand Down

0 comments on commit f7c5567

Please sign in to comment.