Skip to content

Commit

Permalink
MAINT rename base_estimator in _BaseChain subclasses (scikit-lear…
Browse files Browse the repository at this point in the history
…n#30152)

Co-authored-by: Stefanie Senger <91849487+StefanieSenger@users.noreply.github.com>
Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
Co-authored-by: Guillaume Lemaitre <guillaume@probabl.ai>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
  • Loading branch information
5 people authored Jan 2, 2025
1 parent 99d5cd0 commit 6c163c6
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- The parameter `base_estimator` has been deprecated in favour of `estimator` for
:class:`multioutput.RegressorChain` and :class:`multioutput.ClassifierChain`.
By :user:`Success Moses <SuccessMoses>` and :user:`dikraMasrour <dikra_masrour>`
93 changes: 78 additions & 15 deletions sklearn/multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# SPDX-License-Identifier: BSD-3-Clause


import warnings
from abc import ABCMeta, abstractmethod
from numbers import Integral

Expand All @@ -26,7 +27,11 @@
)
from .model_selection import cross_val_predict
from .utils import Bunch, check_random_state, get_tags
from .utils._param_validation import HasMethods, StrOptions
from .utils._param_validation import (
HasMethods,
Hidden,
StrOptions,
)
from .utils._response import _get_response_values
from .utils._user_interface import _print_elapsed_time
from .utils.metadata_routing import (
Expand Down Expand Up @@ -628,7 +633,7 @@ def _available_if_base_estimator_has(attr):
"""

def _check(self):
return hasattr(self.base_estimator, attr) or all(
return hasattr(self._get_estimator(), attr) or all(
hasattr(est, attr) for est in self.estimators_
)

Expand All @@ -637,22 +642,61 @@ def _check(self):

class _BaseChain(BaseEstimator, metaclass=ABCMeta):
_parameter_constraints: dict = {
"base_estimator": [HasMethods(["fit", "predict"])],
"base_estimator": [
HasMethods(["fit", "predict"]),
StrOptions({"deprecated"}),
],
"estimator": [
HasMethods(["fit", "predict"]),
Hidden(None),
],
"order": ["array-like", StrOptions({"random"}), None],
"cv": ["cv_object", StrOptions({"prefit"})],
"random_state": ["random_state"],
"verbose": ["boolean"],
}

# TODO(1.9): Remove base_estimator
def __init__(
self, base_estimator, *, order=None, cv=None, random_state=None, verbose=False
self,
estimator=None,
*,
order=None,
cv=None,
random_state=None,
verbose=False,
base_estimator="deprecated",
):
self.estimator = estimator
self.base_estimator = base_estimator
self.order = order
self.cv = cv
self.random_state = random_state
self.verbose = verbose

# TODO(1.8): This is a temporary getter method to validate input wrt deprecation.
# It was only included to avoid relying on the presence of self.estimator_
def _get_estimator(self):
"""Get and validate estimator."""

if self.estimator is not None and (self.base_estimator != "deprecated"):
raise ValueError(
"Both `estimator` and `base_estimator` are provided. You should only"
" pass `estimator`. `base_estimator` as a parameter is deprecated in"
" version 1.7, and will be removed in version 1.9."
)

if self.base_estimator != "deprecated":

warning_msg = (
"`base_estimator` as an argument was deprecated in 1.7 and will be"
" removed in 1.9. Use `estimator` instead."
)
warnings.warn(warning_msg, FutureWarning)
return self.base_estimator
else:
return self.estimator

def _log_message(self, *, estimator_idx, n_estimators, processing_msg):
if not self.verbose:
return None
Expand Down Expand Up @@ -735,7 +779,7 @@ def fit(self, X, Y, **fit_params):
elif sorted(self.order_) != list(range(Y.shape[1])):
raise ValueError("invalid order")

self.estimators_ = [clone(self.base_estimator) for _ in range(Y.shape[1])]
self.estimators_ = [clone(self._get_estimator()) for _ in range(Y.shape[1])]

if self.cv is None:
Y_pred_chain = Y[:, self.order_]
Expand Down Expand Up @@ -774,7 +818,7 @@ def fit(self, X, Y, **fit_params):

if hasattr(self, "chain_method"):
chain_method = _check_response_method(
self.base_estimator,
self._get_estimator(),
self.chain_method,
).__name__
self.chain_method_ = chain_method
Expand All @@ -799,7 +843,7 @@ def fit(self, X, Y, **fit_params):
if self.cv is not None and chain_idx < len(self.estimators_) - 1:
col_idx = X.shape[1] + chain_idx
cv_result = cross_val_predict(
self.base_estimator,
self._get_estimator(),
X_aug[:, :col_idx],
y=y,
cv=self.cv,
Expand Down Expand Up @@ -832,7 +876,7 @@ def predict(self, X):

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.input_tags.sparse = get_tags(self.base_estimator).input_tags.sparse
tags.input_tags.sparse = get_tags(self._get_estimator()).input_tags.sparse
return tags


Expand All @@ -854,7 +898,7 @@ class ClassifierChain(MetaEstimatorMixin, ClassifierMixin, _BaseChain):
Parameters
----------
base_estimator : estimator
estimator : estimator
The base estimator from which the classifier chain is built.
order : array-like of shape (n_outputs,) or 'random', default=None
Expand Down Expand Up @@ -911,6 +955,13 @@ class ClassifierChain(MetaEstimatorMixin, ClassifierMixin, _BaseChain):
.. versionadded:: 1.2
base_estimator : estimator, default="deprecated"
Use `estimator` instead.
.. deprecated:: 1.7
`base_estimator` is deprecated and will be removed in 1.9.
Use `estimator` instead.
Attributes
----------
classes_ : list
Expand Down Expand Up @@ -985,22 +1036,25 @@ class labels for each estimator in the chain.
],
}

# TODO(1.9): Remove base_estimator from __init__
def __init__(
self,
base_estimator,
estimator=None,
*,
order=None,
cv=None,
chain_method="predict",
random_state=None,
verbose=False,
base_estimator="deprecated",
):
super().__init__(
base_estimator,
estimator,
order=order,
cv=cv,
random_state=random_state,
verbose=verbose,
base_estimator=base_estimator,
)
self.chain_method = chain_method

Expand Down Expand Up @@ -1100,8 +1154,9 @@ def get_metadata_routing(self):
A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
routing information.
"""

router = MetadataRouter(owner=self.__class__.__name__).add(
estimator=self.base_estimator,
estimator=self._get_estimator(),
method_mapping=MethodMapping().add(caller="fit", callee="fit"),
)
return router
Expand All @@ -1128,7 +1183,7 @@ class RegressorChain(MetaEstimatorMixin, RegressorMixin, _BaseChain):
Parameters
----------
base_estimator : estimator
estimator : estimator
The base estimator from which the regressor chain is built.
order : array-like of shape (n_outputs,) or 'random', default=None
Expand Down Expand Up @@ -1172,6 +1227,13 @@ class RegressorChain(MetaEstimatorMixin, RegressorMixin, _BaseChain):
.. versionadded:: 1.2
base_estimator : estimator, default="deprecated"
Use `estimator` instead.
.. deprecated:: 1.7
`base_estimator` is deprecated and will be removed in 1.9.
Use `estimator` instead.
Attributes
----------
estimators_ : list
Expand Down Expand Up @@ -1204,7 +1266,7 @@ class RegressorChain(MetaEstimatorMixin, RegressorMixin, _BaseChain):
>>> from sklearn.linear_model import LogisticRegression
>>> logreg = LogisticRegression(solver='lbfgs')
>>> X, Y = [[1, 0], [0, 1], [1, 1]], [[0, 2], [1, 1], [2, 0]]
>>> chain = RegressorChain(base_estimator=logreg, order=[0, 1]).fit(X, Y)
>>> chain = RegressorChain(logreg, order=[0, 1]).fit(X, Y)
>>> chain.predict(X)
array([[0., 2.],
[1., 1.],
Expand Down Expand Up @@ -1254,8 +1316,9 @@ def get_metadata_routing(self):
A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
routing information.
"""

router = MetadataRouter(owner=self.__class__.__name__).add(
estimator=self.base_estimator,
estimator=self._get_estimator(),
method_mapping=MethodMapping().add(caller="fit", callee="fit"),
)
return router
Expand Down
4 changes: 2 additions & 2 deletions sklearn/tests/test_metaestimators_metadata_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,15 @@
},
{
"metaestimator": ClassifierChain,
"estimator_name": "base_estimator",
"estimator_name": "estimator",
"estimator": "classifier",
"X": X,
"y": y_multi,
"estimator_routing_methods": ["fit"],
},
{
"metaestimator": RegressorChain,
"estimator_name": "base_estimator",
"estimator_name": "estimator",
"estimator": "regressor",
"X": X,
"y": y_multi,
Expand Down
16 changes: 16 additions & 0 deletions sklearn/tests/test_multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,3 +864,19 @@ def test_multioutput_regressor_has_partial_fit():
msg = "This 'MultiOutputRegressor' has no attribute 'partial_fit'"
with pytest.raises(AttributeError, match=msg):
getattr(est, "partial_fit")


# TODO(1.9): remove when deprecated `base_estimator` is removed
@pytest.mark.parametrize("Estimator", [ClassifierChain, RegressorChain])
def test_base_estimator_deprecation(Estimator):
"""Check that we warn about the deprecation of `base_estimator`."""
X = np.array([[1, 2], [3, 4]])
y = np.array([[1, 0], [0, 1]])

estimator = LogisticRegression()

with pytest.warns(FutureWarning):
Estimator(base_estimator=estimator).fit(X, y)

with pytest.raises(ValueError):
Estimator(base_estimator=estimator, estimator=estimator).fit(X, y)
4 changes: 2 additions & 2 deletions sklearn/utils/_test_common/instance_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@
BisectingKMeans: dict(n_init=2, n_clusters=2, max_iter=5),
CalibratedClassifierCV: dict(estimator=LogisticRegression(C=1), cv=3),
CCA: dict(n_components=1, max_iter=5),
ClassifierChain: dict(base_estimator=LogisticRegression(C=1), cv=3),
ClassifierChain: dict(estimator=LogisticRegression(C=1), cv=3),
ColumnTransformer: dict(transformers=[("trans1", StandardScaler(), [0, 1])]),
DictionaryLearning: dict(max_iter=20, transform_algorithm="lasso_lars"),
# the default strategy prior would output constant predictions and fail
Expand Down Expand Up @@ -429,7 +429,7 @@
# For common tests, we can enforce using `LinearRegression` that
# is the default estimator in `RANSACRegressor` instead of `Ridge`.
RANSACRegressor: dict(estimator=LinearRegression(), max_trials=10),
RegressorChain: dict(base_estimator=Ridge(), cv=3),
RegressorChain: dict(estimator=Ridge(), cv=3),
RFECV: dict(estimator=LogisticRegression(C=1), cv=3),
RFE: dict(estimator=LogisticRegression(C=1)),
# be tolerant of noisy datasets (not actually speed)
Expand Down

0 comments on commit 6c163c6

Please sign in to comment.