feat: add max_retries to TextEmbeddingGenerator and Claude3TextGenerator (#1259)

* chore: fix wording of Gemini max_retries

* feat: add max_retries to TextEmbeddingGenerator and Claude3TextGenerator

---------

Co-authored-by: Shuowei Li <shuowei.l@outlook.com>
GarrettWu and shuoweil authored Jan 6, 2025
1 parent 796fc3e commit 8077ff4
Showing 3 changed files with 369 additions and 92 deletions.
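The change surfaces a max_retries keyword on predict() for both models. A minimal usage sketch follows; the model name and sample data are illustrative assumptions, not part of this commit:

import bigframes.pandas as bpd
from bigframes.ml import llm

# Illustrative input; with multiple columns, predict() expects a "content" column.
df = bpd.DataFrame({"content": ["hello world", "BigQuery DataFrames"]})

# model_name is an assumption for this sketch; use any supported embedding model.
model = llm.TextEmbeddingGenerator(model_name="text-embedding-004")

# Retry rows whose prediction failed up to 3 times; rows that still fail are
# appended to the result with a non-empty status column.
embeddings = model.predict(df, max_retries=3)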
64 changes: 57 additions & 7 deletions bigframes/ml/base.py
@@ -22,7 +22,8 @@
"""

import abc
from typing import cast, Optional, TypeVar
from typing import Callable, cast, Mapping, Optional, TypeVar
import warnings

import bigframes_vendored.sklearn.base

@@ -77,6 +78,9 @@ def fit_transform(self, x_train: Union[DataFrame, Series], y_train: Union[DataFr
...
"""

def __init__(self):
self._bqml_model: Optional[core.BqmlModel] = None

def __repr__(self):
"""Print the estimator's constructor with all non-default parameter values."""

@@ -95,9 +99,6 @@ def __repr__(self):
class Predictor(BaseEstimator):
"""A BigQuery DataFrames ML Model base class that can be used to predict outputs."""

def __init__(self):
self._bqml_model: Optional[core.BqmlModel] = None

@abc.abstractmethod
def predict(self, X):
pass
@@ -213,12 +214,61 @@ def fit(
return self._fit(X, y)


class RetriableRemotePredictor(BaseEstimator):
@property
@abc.abstractmethod
def _predict_func(self) -> Callable[[bpd.DataFrame, Mapping], bpd.DataFrame]:
pass

@property
@abc.abstractmethod
def _status_col(self) -> str:
pass

def _predict_and_retry(
self, X: bpd.DataFrame, options: Mapping, max_retries: int
) -> bpd.DataFrame:
assert self._bqml_model is not None

df_result = bpd.DataFrame(session=self._bqml_model.session) # placeholder
df_fail = X
for _ in range(max_retries + 1):
df = self._predict_func(df_fail, options)

success = df[self._status_col].str.len() == 0
df_succ = df[success]
df_fail = df[~success]

if df_succ.empty:
if max_retries > 0:
warnings.warn(
"Can't make any progress, stop retrying.", RuntimeWarning
)
break

df_result = (
bpd.concat([df_result, df_succ]) if not df_result.empty else df_succ
)

if df_fail.empty:
break

if not df_fail.empty:
warnings.warn(
f"Some predictions failed. Check column {self._status_col} for detailed status. You may want to filter the failed rows and retry.",
RuntimeWarning,
)

df_result = cast(
bpd.DataFrame,
bpd.concat([df_result, df_fail]) if not df_result.empty else df_fail,
)
return df_result


class BaseTransformer(BaseEstimator):
"""Transformer base class."""

def __init__(self):
self._bqml_model: Optional[core.BqmlModel] = None

@abc.abstractmethod
def _keys(self):
pass
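The new RetriableRemotePredictor base class centralizes the retry loop that GeminiTextGenerator.predict previously carried inline (removed further down in llm.py): subclasses only supply _predict_func (the BQML call to run) and _status_col (the per-row status column to inspect). The pattern, sketched standalone with plain pandas and a stand-in predict function; everything below is illustrative, not code from the commit:

import pandas as pd
from typing import Callable

def predict_and_retry(
    predict_func: Callable[[pd.DataFrame], pd.DataFrame],
    X: pd.DataFrame,
    status_col: str,
    max_retries: int,
) -> pd.DataFrame:
    # Keep only the failed rows for the next attempt, accumulate successes,
    # and stop early if an attempt makes no progress at all.
    result = pd.DataFrame()
    failed = X
    for _ in range(max_retries + 1):
        out = predict_func(failed)
        ok = out[status_col].str.len() == 0
        succeeded, failed = out[ok], out[~ok]
        if succeeded.empty:
            break  # no progress, stop retrying
        result = pd.concat([result, succeeded])
        if failed.empty:
            break
    # Rows that never succeeded are appended last, keeping their status value.
    return pd.concat([result, failed])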
109 changes: 50 additions & 59 deletions bigframes/ml/llm.py
@@ -16,7 +16,7 @@

from __future__ import annotations

from typing import cast, Literal, Optional
from typing import Callable, cast, Literal, Mapping, Optional
import warnings

import bigframes_vendored.constants as constants
@@ -616,7 +616,7 @@ def to_gbq(


@log_adapter.class_logger
class TextEmbeddingGenerator(base.BaseEstimator):
class TextEmbeddingGenerator(base.RetriableRemotePredictor):
"""Text embedding generator LLM model.
Args:
@@ -715,18 +715,33 @@ def _from_bq(
model._bqml_model = core.BqmlModel(session, bq_model)
return model

def predict(self, X: utils.ArrayType) -> bpd.DataFrame:
@property
def _predict_func(self) -> Callable[[bpd.DataFrame, Mapping], bpd.DataFrame]:
return self._bqml_model.generate_embedding

@property
def _status_col(self) -> str:
return _ML_GENERATE_EMBEDDING_STATUS

def predict(self, X: utils.ArrayType, *, max_retries: int = 0) -> bpd.DataFrame:
"""Predict the result from input DataFrame.
Args:
X (bigframes.dataframe.DataFrame or bigframes.series.Series or pandas.core.frame.DataFrame or pandas.core.series.Series):
Input DataFrame or Series, can contain one or more columns. If multiple columns are in the DataFrame, it must contain a "content" column for prediction.
max_retries (int, default 0):
Max number of retries if the prediction for any rows failed. Each try needs to make progress (i.e. has successfully predicted rows) to continue the retry.
Each retry will append newly succeeded rows. When the max retries are reached, the remaining rows (the ones without successful predictions) will be appended to the end of the result.
Returns:
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
"""
if max_retries < 0:
raise ValueError(
f"max_retries must be larger than or equal to 0, but is {max_retries}."
)

# Params reference: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
(X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)

if len(X.columns) == 1:
@@ -738,15 +753,7 @@ def predict(self, X: utils.ArrayType) -> bpd.DataFrame:
"flatten_json_output": True,
}

df = self._bqml_model.generate_embedding(X, options)

if (df[_ML_GENERATE_EMBEDDING_STATUS] != "").any():
warnings.warn(
f"Some predictions failed. Check column {_ML_GENERATE_EMBEDDING_STATUS} for detailed status. You may want to filter the failed rows and retry.",
RuntimeWarning,
)

return df
return self._predict_and_retry(X, options=options, max_retries=max_retries)

def to_gbq(self, model_name: str, replace: bool = False) -> TextEmbeddingGenerator:
"""Save the model to BigQuery.
@@ -765,7 +772,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> TextEmbeddingGenerat


@log_adapter.class_logger
class GeminiTextGenerator(base.BaseEstimator):
class GeminiTextGenerator(base.RetriableRemotePredictor):
"""Gemini text generator LLM model.
Args:
@@ -891,6 +898,14 @@ def _bqml_options(self) -> dict:
}
return options

@property
def _predict_func(self) -> Callable[[bpd.DataFrame, Mapping], bpd.DataFrame]:
return self._bqml_model.generate_text

@property
def _status_col(self) -> str:
return _ML_GENERATE_TEXT_STATUS

def fit(
self,
X: utils.ArrayType,
@@ -1028,41 +1043,7 @@ def predict(
"ground_with_google_search": ground_with_google_search,
}

df_result = bpd.DataFrame(session=self._bqml_model.session) # placeholder
df_fail = X
for _ in range(max_retries + 1):
df = self._bqml_model.generate_text(df_fail, options)

success = df[_ML_GENERATE_TEXT_STATUS].str.len() == 0
df_succ = df[success]
df_fail = df[~success]

if df_succ.empty:
if max_retries > 0:
warnings.warn(
"Can't make any progress, stop retrying.", RuntimeWarning
)
break

df_result = (
bpd.concat([df_result, df_succ]) if not df_result.empty else df_succ
)

if df_fail.empty:
break

if not df_fail.empty:
warnings.warn(
f"Some predictions failed. Check column {_ML_GENERATE_TEXT_STATUS} for detailed status. You may want to filter the failed rows and retry.",
RuntimeWarning,
)

df_result = cast(
bpd.DataFrame,
bpd.concat([df_result, df_fail]) if not df_result.empty else df_fail,
)

return df_result
return self._predict_and_retry(X, options=options, max_retries=max_retries)

def score(
self,
@@ -1144,7 +1125,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> GeminiTextGenerator:


@log_adapter.class_logger
class Claude3TextGenerator(base.BaseEstimator):
class Claude3TextGenerator(base.RetriableRemotePredictor):
"""Claude3 text generator LLM model.
Go to Google Cloud Console -> Vertex AI -> Model Garden page to enable the models before use. Must have the Consumer Procurement Entitlement Manager Identity and Access Management (IAM) role to enable the models.
@@ -1273,13 +1254,22 @@ def _bqml_options(self) -> dict:
}
return options

@property
def _predict_func(self) -> Callable[[bpd.DataFrame, Mapping], bpd.DataFrame]:
return self._bqml_model.generate_text

@property
def _status_col(self) -> str:
return _ML_GENERATE_TEXT_STATUS

def predict(
self,
X: utils.ArrayType,
*,
max_output_tokens: int = 128,
top_k: int = 40,
top_p: float = 0.95,
max_retries: int = 0,
) -> bpd.DataFrame:
"""Predict the result from input DataFrame.
@@ -1307,6 +1297,10 @@ def predict(
Specify a lower value for less random responses and a higher value for more random responses.
Default 0.95. Possible values [0.0, 1.0].
max_retries (int, default 0):
Max number of retries if the prediction for any rows failed. Each try needs to make progress (i.e. has successfully predicted rows) to continue the retry.
Each retry will append newly succeeded rows. When the max retries are reached, the remaining rows (the ones without successful predictions) will be appended to the end of the result.
Returns:
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
@@ -1324,6 +1318,11 @@ def predict(
if top_p < 0.0 or top_p > 1.0:
raise ValueError(f"top_p must be [0.0, 1.0], but is {top_p}.")

if max_retries < 0:
raise ValueError(
f"max_retries must be larger than or equal to 0, but is {max_retries}."
)

(X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)

if len(X.columns) == 1:
@@ -1338,15 +1337,7 @@ def predict(
"flatten_json_output": True,
}

df = self._bqml_model.generate_text(X, options)

if (df[_ML_GENERATE_TEXT_STATUS] != "").any():
warnings.warn(
f"Some predictions failed. Check column {_ML_GENERATE_TEXT_STATUS} for detailed status. You may want to filter the failed rows and retry.",
RuntimeWarning,
)

return df
return self._predict_and_retry(X, options=options, max_retries=max_retries)

def to_gbq(self, model_name: str, replace: bool = False) -> Claude3TextGenerator:
"""Save the model to BigQuery.
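Claude3TextGenerator gains the same keyword. A hedged usage sketch; the model name, prompt, and the ml_generate_text_status column name are assumptions for illustration, not taken from this diff:

import bigframes.pandas as bpd
from bigframes.ml import llm

prompts = bpd.DataFrame({"prompt": ["Summarize BigQuery DataFrames in one sentence."]})

# model_name is an assumption; use whichever Claude 3 model is enabled in Model Garden.
claude = llm.Claude3TextGenerator(model_name="claude-3-sonnet")

result = claude.predict(prompts, max_output_tokens=256, max_retries=2)

# After retries are exhausted, failed rows remain in the output with a
# non-empty status string (assumed column name: ml_generate_text_status),
# so they can be filtered out and retried later.
still_failed = result[result["ml_generate_text_status"].str.len() > 0]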