feat: add the recent api method for ML component (#225)
ashleyxuu authored Nov 21, 2023
1 parent 1d14771 commit ed8876d
Showing 11 changed files with 34 additions and 1 deletion.
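
In short: every public ML estimator and transformer class is now decorated with @log_adapter.class_logger, and bigframes.core.log_adapter is imported wherever the decorator is used, so calls to these classes' methods get recorded as recent API methods; the unit tests at the end of the diff drain that record with log_adapter.get_and_reset_api_methods() before asserting on BigQuery job labels. The log_adapter module itself is not part of this diff, so the following is only an assumed, simplified sketch of the pattern, not the real implementation:

# Illustrative sketch only -- the real bigframes.core.log_adapter is not shown in
# this diff. Assumed behavior: record "<class>-<method>" whenever a public method
# of a decorated class is called, and expose a way to read and clear that record.
import functools
import threading

_lock = threading.Lock()
_api_methods = []


def class_logger(cls):
    """Wrap every public method of `cls` so each call is recorded."""
    for name, attr in list(vars(cls).items()):
        if callable(attr) and not name.startswith("_"):
            setattr(cls, name, _add_method_logging(attr, cls.__name__.lower()))
    return cls


def _add_method_logging(method, class_name):
    @functools.wraps(method)
    def wrapper(*args, **kwargs):
        with _lock:
            # Most recent call first; the tests read and reset this list.
            _api_methods.insert(0, f"{class_name}-{method.__name__}")
        return method(*args, **kwargs)

    return wrapper


def get_and_reset_api_methods():
    """Return the methods recorded so far and clear the shared record."""
    with _lock:
        recorded = list(_api_methods)
        _api_methods.clear()
    return recorded
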
2 changes: 2 additions & 0 deletions bigframes/ml/cluster.py
@@ -22,11 +22,13 @@
 from google.cloud import bigquery

 import bigframes
+from bigframes.core import log_adapter
 from bigframes.ml import base, core, globals, utils
 import bigframes.pandas as bpd
 import third_party.bigframes_vendored.sklearn.cluster._kmeans


+@log_adapter.class_logger
 class KMeans(
     base.UnsupervisedTrainablePredictor,
     third_party.bigframes_vendored.sklearn.cluster._kmeans.KMeans,

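Nothing changes for callers: ordinary estimator usage is what populates the recent-method record, which can then be attached to BigQuery job labels (the label plumbing is what the test changes at the end of this diff exercise). A hedged usage sketch; the table and column names below are made up:

import bigframes.pandas as bpd
from bigframes.ml.cluster import KMeans

df = bpd.read_gbq("my_project.my_dataset.my_table")  # hypothetical table
model = KMeans(n_clusters=4)
model.fit(df[["feature_a", "feature_b"]])     # call recorded by class_logger
predictions = model.predict(df[["feature_a", "feature_b"]])
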
2 changes: 2 additions & 0 deletions bigframes/ml/compose.py
@@ -22,6 +22,7 @@
 from typing import List, Optional, Tuple, Union

 from bigframes import constants
+from bigframes.core import log_adapter
 from bigframes.ml import base, core, globals, preprocessing, utils
 import bigframes.pandas as bpd
 import third_party.bigframes_vendored.sklearn.compose._column_transformer
@@ -36,6 +37,7 @@
 ]


+@log_adapter.class_logger
 class ColumnTransformer(
     base.Transformer,
     third_party.bigframes_vendored.sklearn.compose._column_transformer.ColumnTransformer,

2 changes: 2 additions & 0 deletions bigframes/ml/decomposition.py
@@ -22,11 +22,13 @@
 from google.cloud import bigquery

 import bigframes
+from bigframes.core import log_adapter
 from bigframes.ml import base, core, globals, utils
 import bigframes.pandas as bpd
 import third_party.bigframes_vendored.sklearn.decomposition._pca


+@log_adapter.class_logger
 class PCA(
     base.UnsupervisedTrainablePredictor,
     third_party.bigframes_vendored.sklearn.decomposition._pca.PCA,

5 changes: 5 additions & 0 deletions bigframes/ml/ensemble.py
@@ -22,6 +22,7 @@
 from google.cloud import bigquery

 import bigframes
+from bigframes.core import log_adapter
 from bigframes.ml import base, core, globals, utils
 import bigframes.pandas as bpd
 import third_party.bigframes_vendored.sklearn.ensemble._forest
@@ -47,6 +48,7 @@
 }


+@log_adapter.class_logger
 class XGBRegressor(
     base.SupervisedTrainablePredictor,
     third_party.bigframes_vendored.xgboost.sklearn.XGBRegressor,
@@ -202,6 +204,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> XGBRegressor:
         return new_model.session.read_gbq_model(model_name)


+@log_adapter.class_logger
 class XGBClassifier(
     base.SupervisedTrainablePredictor,
     third_party.bigframes_vendored.xgboost.sklearn.XGBClassifier,
@@ -356,6 +359,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> XGBClassifier:
         return new_model.session.read_gbq_model(model_name)


+@log_adapter.class_logger
 class RandomForestRegressor(
     base.SupervisedTrainablePredictor,
     third_party.bigframes_vendored.sklearn.ensemble._forest.RandomForestRegressor,
@@ -521,6 +525,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> RandomForestRegressor:
         return new_model.session.read_gbq_model(model_name)


+@log_adapter.class_logger
 class RandomForestClassifier(
     base.SupervisedTrainablePredictor,
     third_party.bigframes_vendored.sklearn.ensemble._forest.RandomForestClassifier,

2 changes: 2 additions & 0 deletions bigframes/ml/forecasting.py
@@ -21,10 +21,12 @@
 from google.cloud import bigquery

 import bigframes
+from bigframes.core import log_adapter
 from bigframes.ml import base, core, globals, utils
 import bigframes.pandas as bpd


+@log_adapter.class_logger
 class ARIMAPlus(base.SupervisedTrainablePredictor):
     """Time Series ARIMA Plus model."""

3 changes: 3 additions & 0 deletions bigframes/ml/imported.py
@@ -21,10 +21,12 @@
 from google.cloud import bigquery

 import bigframes
+from bigframes.core import log_adapter
 from bigframes.ml import base, core, globals, utils
 import bigframes.pandas as bpd


+@log_adapter.class_logger
 class TensorFlowModel(base.Predictor):
     """Imported TensorFlow model.
@@ -101,6 +103,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> TensorFlowModel:
         return new_model.session.read_gbq_model(model_name)


+@log_adapter.class_logger
 class ONNXModel(base.Predictor):
     """Imported Open Neural Network Exchange (ONNX) model.

3 changes: 3 additions & 0 deletions bigframes/ml/linear_model.py
@@ -23,6 +23,7 @@

 import bigframes
 import bigframes.constants as constants
+from bigframes.core import log_adapter
 from bigframes.ml import base, core, globals, utils
 import bigframes.pandas as bpd
 import third_party.bigframes_vendored.sklearn.linear_model._base
@@ -46,6 +47,7 @@
 }


+@log_adapter.class_logger
 class LinearRegression(
     base.SupervisedTrainablePredictor,
     third_party.bigframes_vendored.sklearn.linear_model._base.LinearRegression,
@@ -178,6 +180,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> LinearRegression:
         return new_model.session.read_gbq_model(model_name)


+@log_adapter.class_logger
 class LogisticRegression(
     base.SupervisedTrainablePredictor,
     third_party.bigframes_vendored.sklearn.linear_model._logistic.LogisticRegression,

4 changes: 3 additions & 1 deletion bigframes/ml/llm.py
@@ -21,7 +21,7 @@

 import bigframes
 from bigframes import clients, constants
-from bigframes.core import blocks
+from bigframes.core import blocks, log_adapter
 from bigframes.ml import base, core, globals, utils
 import bigframes.pandas as bpd

@@ -43,6 +43,7 @@
 _ML_EMBED_TEXT_STATUS = "ml_embed_text_status"


+@log_adapter.class_logger
 class PaLM2TextGenerator(base.Predictor):
     """PaLM2 text generator LLM model.
@@ -200,6 +201,7 @@ def predict(
         return df


+@log_adapter.class_logger
 class PaLM2TextEmbeddingGenerator(base.Predictor):
     """PaLM2 text embedding generator LLM model.

2 changes: 2 additions & 0 deletions bigframes/ml/pipeline.py
@@ -24,11 +24,13 @@

 import bigframes
 import bigframes.constants as constants
+from bigframes.core import log_adapter
 from bigframes.ml import base, compose, forecasting, loader, preprocessing, utils
 import bigframes.pandas as bpd
 import third_party.bigframes_vendored.sklearn.pipeline


+@log_adapter.class_logger
 class Pipeline(
     base.BaseEstimator,
     third_party.bigframes_vendored.sklearn.pipeline.Pipeline,

7 changes: 7 additions & 0 deletions bigframes/ml/preprocessing.py
@@ -20,6 +20,7 @@
 import typing
 from typing import Any, cast, List, Literal, Optional, Tuple, Union

+from bigframes.core import log_adapter
 from bigframes.ml import base, core, globals, utils
 import bigframes.pandas as bpd
 import third_party.bigframes_vendored.sklearn.preprocessing._data
@@ -28,6 +29,7 @@
 import third_party.bigframes_vendored.sklearn.preprocessing._label


+@log_adapter.class_logger
 class StandardScaler(
     base.Transformer,
     third_party.bigframes_vendored.sklearn.preprocessing._data.StandardScaler,
@@ -111,6 +113,7 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
         )


+@log_adapter.class_logger
 class MaxAbsScaler(
     base.Transformer,
     third_party.bigframes_vendored.sklearn.preprocessing._data.MaxAbsScaler,
@@ -194,6 +197,7 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
         )


+@log_adapter.class_logger
 class MinMaxScaler(
     base.Transformer,
     third_party.bigframes_vendored.sklearn.preprocessing._data.MinMaxScaler,
@@ -277,6 +281,7 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
         )


+@log_adapter.class_logger
 class KBinsDiscretizer(
     base.Transformer,
     third_party.bigframes_vendored.sklearn.preprocessing._discretization.KBinsDiscretizer,
@@ -395,6 +400,7 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
         )


+@log_adapter.class_logger
 class OneHotEncoder(
     base.Transformer,
     third_party.bigframes_vendored.sklearn.preprocessing._encoder.OneHotEncoder,
@@ -524,6 +530,7 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
         )


+@log_adapter.class_logger
 class LabelEncoder(
     base.LabelTransformer,
     third_party.bigframes_vendored.sklearn.preprocessing._label.LabelEncoder,

3 changes: 3 additions & 0 deletions tests/unit/session/test_io_bigquery.py
@@ -59,6 +59,7 @@ def test_create_job_configs_labels_length_limit_not_met():


 def test_create_job_configs_labels_log_adaptor_call_method_under_length_limit():
+    log_adapter.get_and_reset_api_methods()
     cur_labels = {
         "bigframes-api": "read_pandas",
         "source": "bigquery-dataframes-temp",
@@ -87,6 +88,7 @@ def test_create_job_configs_labels_log_adaptor_call_method_under_length_limit():


 def test_create_job_configs_labels_length_limit_met_and_labels_is_none():
+    log_adapter.get_and_reset_api_methods()
     df = bpd.DataFrame({"col1": [1, 2], "col2": [3, 4]})
     # Test running methods more than the labels' length limit
     for i in range(66):
@@ -102,6 +104,7 @@ def test_create_job_configs_labels_length_limit_met_and_labels_is_none():


 def test_create_job_configs_labels_length_limit_met():
+    log_adapter.get_and_reset_api_methods()
     cur_labels = {
         "bigframes-api": "read_pandas",
         "source": "bigquery-dataframes-temp",

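All three tests share the same module-level record inside log_adapter, so each one now drains it up front; otherwise methods recorded by an earlier test would leak into the job labels being asserted on. A minimal illustration of the pattern, with the test body and assertion invented for this example:

from bigframes.core import log_adapter


def test_starts_from_a_clean_method_log():
    log_adapter.get_and_reset_api_methods()  # drain anything recorded so far
    # ...exercise bigframes APIs under test here...
    recorded = log_adapter.get_and_reset_api_methods()
    assert isinstance(recorded, list)  # illustrative assertion only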
