From a90ee8da161f95aa489aa4f09309a3fa34320a4c Mon Sep 17 00:00:00 2001 From: Jaycee Li Date: Mon, 24 Jun 2024 10:10:44 -0700 Subject: [PATCH] feat: GenAI - Batch Prediction - Added support for tuned GenAI models PiperOrigin-RevId: 646136098 --- tests/unit/vertexai/test_batch_prediction.py | 126 +++++++++++++++++- .../batch_prediction/_batch_prediction.py | 54 +++++--- 2 files changed, 163 insertions(+), 17 deletions(-) diff --git a/tests/unit/vertexai/test_batch_prediction.py b/tests/unit/vertexai/test_batch_prediction.py index 59fb9ddaa7..485da92ae6 100644 --- a/tests/unit/vertexai/test_batch_prediction.py +++ b/tests/unit/vertexai/test_batch_prediction.py @@ -25,11 +25,15 @@ import vertexai from google.cloud.aiplatform import base as aiplatform_base from google.cloud.aiplatform import initializer as aiplatform_initializer -from google.cloud.aiplatform.compat.services import job_service_client +from google.cloud.aiplatform.compat.services import ( + job_service_client, + model_service_client, +) from google.cloud.aiplatform.compat.types import ( batch_prediction_job as gca_batch_prediction_job_compat, io as gca_io_compat, job_state as gca_job_state_compat, + model as gca_model, ) from vertexai.preview import batch_prediction from vertexai.generative_models import GenerativeModel @@ -43,6 +47,7 @@ _TEST_GEMINI_MODEL_NAME = "gemini-1.0-pro" _TEST_GEMINI_MODEL_RESOURCE_NAME = f"publishers/google/models/{_TEST_GEMINI_MODEL_NAME}" +_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME = "projects/123/locations/us-central1/models/456" _TEST_PALM_MODEL_NAME = "text-bison" _TEST_PALM_MODEL_RESOURCE_NAME = f"publishers/google/models/{_TEST_PALM_MODEL_NAME}" @@ -122,6 +127,48 @@ def get_batch_prediction_job_with_gcs_output_mock(): yield get_job_mock +@pytest.fixture +def get_batch_prediction_job_with_tuned_gemini_model_mock(): + with mock.patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_job_mock: + get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME, + state=_TEST_JOB_STATE_SUCCESS, + output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo( + gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX + ), + ) + yield get_job_mock + + +@pytest.fixture +def get_gemini_model_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as get_model_mock: + get_model_mock.return_value = gca_model.Model( + name=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME, + model_source_info=gca_model.ModelSourceInfo( + source_type=gca_model.ModelSourceInfo.ModelSourceType.GENIE + ), + ) + yield get_model_mock + + +@pytest.fixture +def get_non_gemini_model_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as get_model_mock: + get_model_mock.return_value = gca_model.Model( + name=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME, + ) + yield get_model_mock + + @pytest.fixture def get_batch_prediction_job_invalid_model_mock(): with mock.patch.object( @@ -205,6 +252,21 @@ def test_init_batch_prediction_job( name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY ) + def test_init_batch_prediction_job_with_tuned_gemini_model( + self, + get_batch_prediction_job_with_tuned_gemini_model_mock, + get_gemini_model_mock, + ): + batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID) + + get_batch_prediction_job_with_tuned_gemini_model_mock.assert_called_once_with( + name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY + ) + get_gemini_model_mock.assert_called_once_with( + name=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME, + retry=aiplatform_base._DEFAULT_RETRY, + ) + @pytest.mark.usefixtures("get_batch_prediction_job_invalid_model_mock") def test_init_batch_prediction_job_invalid_model(self): with pytest.raises( @@ -217,6 +279,23 @@ def test_init_batch_prediction_job_invalid_model(self): ): batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID) + @pytest.mark.usefixtures( + "get_batch_prediction_job_with_tuned_gemini_model_mock", + "get_non_gemini_model_mock", + ) + def test_init_batch_prediction_job_with_invalid_tuned_model( + self, + ): + with pytest.raises( + ValueError, + match=( + f"BatchPredictionJob '{_TEST_BATCH_PREDICTION_JOB_ID}' " + f"runs with the model '{_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME}', " + "which is not a GenAI model." + ), + ): + batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID) + @pytest.mark.usefixtures("get_batch_prediction_job_with_gcs_output_mock") def test_submit_batch_prediction_job_with_gcs_input( self, create_batch_prediction_job_mock @@ -368,16 +447,59 @@ def test_submit_batch_prediction_job_with_bq_input_without_output_uri_prefix( timeout=None, ) + @pytest.mark.usefixtures("create_batch_prediction_job_mock") + def test_submit_batch_prediction_job_with_tuned_model( + self, + get_gemini_model_mock, + ): + job = batch_prediction.BatchPredictionJob.submit( + source_model=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME, + input_dataset=_TEST_BQ_INPUT_URI, + ) + + assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB + get_gemini_model_mock.assert_called_once_with( + name=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME, + retry=aiplatform_base._DEFAULT_RETRY, + ) + def test_submit_batch_prediction_job_with_invalid_source_model(self): with pytest.raises( ValueError, - match=(f"Model '{_TEST_PALM_MODEL_RESOURCE_NAME}' is not a GenAI model."), + match=( + f"Model '{_TEST_PALM_MODEL_RESOURCE_NAME}' is not a Generative AI model." + ), ): batch_prediction.BatchPredictionJob.submit( source_model=_TEST_PALM_MODEL_NAME, input_dataset=_TEST_GCS_INPUT_URI, ) + @pytest.mark.usefixtures("get_non_gemini_model_mock") + def test_submit_batch_prediction_job_with_non_gemini_tuned_model(self): + with pytest.raises( + ValueError, + match=( + f"Model '{_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME}' " + "is not a Generative AI model." + ), + ): + batch_prediction.BatchPredictionJob.submit( + source_model=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME, + input_dataset=_TEST_GCS_INPUT_URI, + ) + + def test_submit_batch_prediction_job_with_invalid_model_name(self): + invalid_model_name = "invalid/model/name" + with pytest.raises( + ValueError, + match=(f"Invalid format for model name: {invalid_model_name}."), + ): + batch_prediction.BatchPredictionJob.submit( + source_model=invalid_model_name, + input_dataset=_TEST_GCS_INPUT_URI, + ) + def test_submit_batch_prediction_job_with_invalid_input_dataset(self): with pytest.raises( ValueError, diff --git a/vertexai/batch_prediction/_batch_prediction.py b/vertexai/batch_prediction/_batch_prediction.py index 7900579cf4..c7cafc7543 100644 --- a/vertexai/batch_prediction/_batch_prediction.py +++ b/vertexai/batch_prediction/_batch_prediction.py @@ -22,6 +22,7 @@ from google.cloud.aiplatform import base as aiplatform_base from google.cloud.aiplatform import initializer as aiplatform_initializer from google.cloud.aiplatform import jobs +from google.cloud.aiplatform import models from google.cloud.aiplatform import utils as aiplatform_utils from google.cloud.aiplatform_v1 import types as gca_types from vertexai import generative_models @@ -32,6 +33,7 @@ _LOGGER = aiplatform_base.Logger(__name__) _GEMINI_MODEL_PATTERN = r"publishers/google/models/gemini" +_GEMINI_TUNED_MODEL_PATTERN = r"^projects/[0-9]+?/locations/[0-9a-z-]+?/models/[0-9]+?$" class BatchPredictionJob(aiplatform_base._VertexAiResourceNounPlus): @@ -64,8 +66,7 @@ def __init__(self, batch_prediction_job_name: str): self._gca_resource = self._get_gca_resource( resource_name=batch_prediction_job_name ) - # TODO(b/338452508) Support tuned GenAI models. - if not re.search(_GEMINI_MODEL_PATTERN, self.model_name): + if not self._is_genai_model(self.model_name): raise ValueError( f"BatchPredictionJob '{batch_prediction_job_name}' " f"runs with the model '{self.model_name}', " @@ -117,9 +118,12 @@ def submit( Args: source_model (Union[str, generative_models.GenerativeModel]): - Model name or a GenerativeModel instance for batch prediction. - Supported formats: "gemini-1.0-pro", "models/gemini-1.0-pro", - and "publishers/google/models/gemini-1.0-pro" + A GenAI model name or a tuned model name or a GenerativeModel instance + for batch prediction. + Supported formats for model name: "gemini-1.0-pro", + "models/gemini-1.0-pro", and "publishers/google/models/gemini-1.0-pro" + Supported formats for tuned model name: "789" and + "projects/123/locations/456/models/789" input_dataset (Union[str,List[str]]): GCS URI(-s) or Bigquery URI to your input data to run batch prediction on. Example: "gs://path/to/input/data.jsonl" or @@ -142,12 +146,13 @@ def submit( set in vertexai.init(). """ # Handle model name - # TODO(b/338452508) Support tuned GenAI models. model_name = cls._reconcile_model_name( source_model._model_name if isinstance(source_model, generative_models.GenerativeModel) else source_model ) + if not cls._is_genai_model(model_name): + raise ValueError(f"Model '{model_name}' is not a Generative AI model.") # Handle input URI gcs_source = None @@ -244,9 +249,7 @@ def delete(self): def list(cls, filter=None) -> List["BatchPredictionJob"]: """Lists all BatchPredictionJob instances that run with GenAI models.""" return cls._list( - cls_filter=lambda gca_resource: re.search( - _GEMINI_MODEL_PATTERN, gca_resource.model - ), + cls_filter=lambda gca_resource: cls._is_genai_model(gca_resource.model), filter=filter, ) @@ -263,23 +266,44 @@ def _dashboard_uri(self) -> Optional[str]: @classmethod def _reconcile_model_name(cls, model_name: str) -> str: - """Reconciles model name to a publisher model resource name.""" + """Reconciles model name to a publisher model resource name or a tuned model resource name.""" if not model_name: raise ValueError("model_name must not be empty") + if "/" not in model_name: + # model name (e.g., gemini-1.0-pro) model_name = "publishers/google/models/" + model_name elif model_name.startswith("models/"): + # publisher model name (e.g., models/gemini-1.0-pro) model_name = "publishers/google/" + model_name - elif not model_name.startswith("publishers/google/models/") and not re.search( - r"^projects/.*?/locations/.*?/publishers/google/models/.*$", model_name + elif ( + # publisher model full name + not model_name.startswith("publishers/google/models/") + # tuned model full resource name + and not re.search(_GEMINI_TUNED_MODEL_PATTERN, model_name) ): raise ValueError(f"Invalid format for model name: {model_name}.") - if not re.search(_GEMINI_MODEL_PATTERN, model_name): - raise ValueError(f"Model '{model_name}' is not a GenAI model.") - return model_name + @classmethod + def _is_genai_model(cls, model_name: str) -> bool: + """Validates if a given model_name represents a GenAI model.""" + if re.search(_GEMINI_MODEL_PATTERN, model_name): + # Model is a Gemini model. + return True + + if re.search(_GEMINI_TUNED_MODEL_PATTERN, model_name): + model = models.Model(model_name) + if ( + model.gca_resource.model_source_info.source_type + == gca_types.model.ModelSourceInfo.ModelSourceType.GENIE + ): + # Model is a tuned Gemini model. + return True + + return False + @classmethod def _complete_bq_uri(cls, uri: Optional[str] = None): """Completes a BigQuery uri to a BigQuery table uri."""