[python-package] make scikit-learn estimator tags compatible with scikit-learn>=1.6.0dev
#6651
Conversation
Update: The change introduced in scikit-learn/scikit-learn#29677 makes it hard to subclass a sklearn estimator in a codebase while staying compatible with both sklearn < 1.6.0 and sklearn >= 1.6.0. The issue is discussed upstream in scikit-learn, and it looks like a relaxation of the impossibility of having both is being considered.
@vnherdeiro note that it's possible already to support both with this method (scikit-learn/scikit-learn#29677 (comment)); however, the version check and `@available_if` are going to be unnecessary once we merge scikit-learn/scikit-learn#29801.
Correct, I am waiting for that PR to go in to bring back `_more_tags`. Using `@available_if` would require another sklearn import and make the code less readable, I reckon.
Thanks for starting on this @vnherdeiro. I've documented it in an issue: #6653 (and added that to the PR description). Note there that I intentionally put the exact error messages in plain text instead of just referring to them.
Thanks for starting on this! Please see scikit-learn/scikit-learn#29801 (comment):

> The story becomes "If you want to support multiple scikit-learn versions, define both."

I think we should leave `_more_tags()` untouched and add `__sklearn_tags__()`, and have `self.__sklearn_tags__()` call `self._more_tags()` to get its data, so we don't define things like `_xfail_checks` twice.
Do you have time to do that in the next few days? We need to fix this to unblock CI here, so if you don't have time to fix it this week please let me know and I will work on this.
@jameslamb Have just pushed a `__sklearn_tags__` trying a conversion from `_more_tags`. I added a warning for arguments outside the currently-handled scope, to catch a change to the arguments in `_more_tags` (they don't seem to change much, though).
Suggested change: `scikit-learn>=1.16` → `scikit-learn>=1.6.0dev`
Not a maintainer here, but coming from sklearn side. Leaving thoughts hoping it'd help.
Thanks for this.
I've reviewed the dataclasses at https://github.com/scikit-learn/scikit-learn/blob/e2ee93156bd3692722a39130c011eea313628690/sklearn/utils/_tags.py and agree with the choices you've made about how to map the dictionary-formatted values from `_more_tags()` to the dataclass attributes scikit-learn now prefers.
Please see the other comments about simplifying this.
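One way that dict-to-dataclass mapping can look, sketched with stand-in dataclasses so it runs without scikit-learn >= 1.6 installed (the real `Tags` / `InputTags` live in `sklearn/utils/_tags.py` linked above and have more attributes than shown here):

```python
from dataclasses import dataclass, field


# Minimal stand-ins for scikit-learn 1.6's tag dataclasses.
@dataclass
class InputTags:
    allow_nan: bool = False
    sparse: bool = False


@dataclass
class Tags:
    input_tags: InputTags = field(default_factory=InputTags)


def update_tags_from_dict(tags: Tags, tags_dict: dict) -> Tags:
    # Translate old-style dict keys onto the new dataclass attributes.
    if "allow_nan" in tags_dict:
        tags.input_tags.allow_nan = tags_dict["allow_nan"]
    if "X_types" in tags_dict:
        tags.input_tags.sparse = "sparse" in tags_dict["X_types"]
    return tags


tags = update_tags_from_dict(Tags(), {"allow_nan": True, "X_types": ["2darray", "sparse"]})
print(tags.input_tags.allow_nan, tags.input_tags.sparse)  # True True
```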
Co-authored-by: James Lamb <jaylamb20@gmail.com>
@jameslamb have addressed your comments! Thanks for the review!
I'd probably include a test to make sure `X_types` is exactly as is here, so that when somebody changes it in the future in `_more_tags`, the corresponding tags in `__sklearn_tags__` are also changed (and the test itself).
I started looking into this. I'll push commits here adding this test and fixing that.
This is proving to be very challenging to get right, because of LightGBM's MRO:

```shell
python -c "import lightgbm; print(lightgbm.LGBMRegressor.__mro__)"
# (<class 'lightgbm.sklearn.LGBMRegressor'>,
#  <class 'sklearn.base.RegressorMixin'>,
#  <class 'lightgbm.sklearn.LGBMModel'>,
#  <class 'sklearn.base.BaseEstimator'>,
#  <class 'sklearn.utils._estimator_html_repr._HTMLDocumentationLinkMixin'>,
#  <class 'sklearn.utils._metadata_requests._MetadataRequester'>,
#  <class 'object'>)
```

(we do that intentionally, following the advice from "BaseEstimator and mixins" at https://scikit-learn.org/stable/developers/develop.html#rolling-your-own-estimator)

I'm finding it difficult to preserve the LightGBM-specific changes that we want (that @vnherdeiro has implemented here) without them being overwritten by the mixins' defaults.

Will come back to this tomorrow, when I can, and will try to put together a clear reproducible example. The amount of indirection here means that'll take a bit more time than I have today.
.ci/test.sh (Outdated)

```diff
@@ -103,6 +103,7 @@ if [[ $TASK == "lint" ]]; then
     'mypy>=1.11.1' \
     'pre-commit>=3.8.0' \
     'pyarrow-core>=17.0' \
+    'scikit-learn>=1.15.0' \
```
This is to ensure that `mypy` checks `scikit-learn` imports. Extra important now that I'm proposing adding an optional type hint on this new `sklearn.utils.Tags`.
```python
"check_n_features_in_after_fitting": (
    "validate_data() was first added in scikit-learn 1.6 and lightgbm "
    "supports much older versions than that"
),
```
On the `1.6.dev` nightlies, `scikit-learn` is raising this error:

```
E AssertionError: `LGBMRegressor.predict()` does not check for consistency between input number
E of features with `LGBMRegressor.fit()`, via the `n_features_in_` attribute.
E You might want to use `sklearn.utils.validation.validate_data` instead
E of `check_array` in `LGBMRegressor.fit()` and `LGBMRegressor.predict()`. This can be done
E like the following:
E from sklearn.utils.validation import validate_data
```
We should ignore this check here in LightGBM... `validate_data()` will be added for the first time in `scikit-learn` 1.6:

- https://github.com/scikit-learn/scikit-learn/blob/74a33757c8a8df84d227f28bbc9ec7ae2fb51dea/sklearn/utils/validation.py#L2790
- API: move `BaseEstimator._validate_data` to `utils.validation.validate_data` (scikit-learn/scikit-learn#29696)

We have other mechanisms further down in LightGBM to check shape mismatches between training data and the data provided at scoring time. I'd rather rely on those than take on the complexity of try-catching a call to this new-in-v1.6 `validate_data()` function.
Understand you don't want to use `validate_data` here, but you can still conform to the API with your own tools. You probably also want to make sure you store `n_features_in_` as well, to better imitate sklearn's behavior. I would personally go down the `fixes.py` path though.
> Understand you don't want to use validate_data here, but you can still conform to the API with your own tools.

How could we avoid the `check_n_features_in_after_fitting` check failing without calling `validate_data()`? Could you point to a doc I could reference?

> You probably also want to make sure you store `n_features_in_` as well, to better imitate sklearn's behavior.

We do.
LightGBM/python-package/lightgbm/sklearn.py, lines 1063 to 1068 in 41ba9e8:

```python
@property
def n_features_in_(self) -> int:
    """:obj:`int`: The number of features of fitted model."""
    if not self.__sklearn_is_fitted__():
        raise LGBMNotFittedError("No n_features_in found. Need to call fit beforehand.")
    return self._n_features_in
```
python-package/lightgbm/sklearn.py (Outdated)

```python
def _more_tags(self) -> Dict[str, Any]:
    # handle the case where ClassifierMixin possibly provides _more_tags()
    if callable(getattr(_LGBMClassifierBase, "_more_tags", None)):
        tags = _LGBMClassifierBase._more_tags(self)
```
Proposing all these uses of `{some_class}.{some_method}` instead of `super().{some_method}` because we follow this advice from `scikit-learn`'s docs (https://scikit-learn.org/stable/developers/develop.html#rolling-your-own-estimator):

> ...mixins should be "on the left" while the `BaseEstimator` should be "on the right" in the inheritance list for proper MRO.

Using `super()` would get the `_more_tags()` / `__sklearn_tags__()` from e.g. `sklearn.base.RegressorMixin`, but we want to use LightGBM's specific tags.
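A toy demonstration of that MRO point, with stand-in class names (`Mixin` playing the role of `RegressorMixin`, `Model` playing `LGBMModel`):

```python
# With the mixin "on the left", super() from the subclass resolves to
# the mixin's method, not Model's, so an explicit Model.tags(self)
# call is needed to reach Model's version.
class Mixin:
    def tags(self):
        return "mixin tags"


class Model:
    def tags(self):
        return "lightgbm tags"


class Estimator(Mixin, Model):
    def via_super(self):
        return super().tags()  # MRO picks Mixin first

    def via_explicit_class(self):
        return Model.tags(self)  # bypasses the MRO


est = Estimator()
print(est.via_super())           # mixin tags
print(est.via_explicit_class())  # lightgbm tags
```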
I've pushed commits here adding testing and ensuring that the LightGBM-specific tags are preserved. Since I've added so much code to this, my review should not count towards a merge. @StrikerRUS or @jmoralez could you please review whenever you have time? And of course @adrinjalali we'd welcome your feedback if you have time/interest. It's been great having you here helping us adapt so far!
Something that's happening here is that you're adding complexity in a few places to handle different dependency versions. This is quite a common pattern; we have it whenever we have dependencies and support multiple versions.

What we tend to do instead is to have a `utils/fixes.py` kind of thing, where we put all version-dependent code, and we only call those methods / import from there. That means we mostly have only one file to look at when we upgrade minimum dependency versions.

These are two examples:
```python
# _LGBMModelBase.__sklearn_tags__() cannot be called unconditionally,
# because that method isn't defined for scikit-learn<1.6
if not callable(getattr(_LGBMModelBase, "__sklearn_tags__", None)):
    return None
```
I would personally prefer using an `available_if` here, since this logic is no less complicated than having that one. But this works too. However, maybe raising an `AttributeError` would be better? This method doesn't need to exist in older sklearn versions.
Do you mean `sklearn.utils.available_if`?

I prefer this method with `getattr()` that doesn't take on another `sklearn` import that could possibly be moved or changed in future versions.
```python
return self._update_sklearn_tags_from_dict(
    tags=_LGBMModelBase.__sklearn_tags__(self),
    tags_dict=self._more_tags(),
)
```
Wondering why you're not getting it through `super()` to let the MRO decide?
I explained this here: #6651 (comment)
python-package/lightgbm/sklearn.py (Outdated)

```python
def _more_tags(self) -> Dict[str, Any]:
    # handle the case where ClassifierMixin possibly provides _more_tags()
    if callable(getattr(_LGBMClassifierBase, "_more_tags", None)):
        tags = _LGBMClassifierBase._more_tags(self)
```
`_more_tags` shouldn't care about other classes and the MRO; it should only return what it wants to add, so I'm not sure why this complexity here is needed.

Also, interesting that Classifier tags are needed in the Regressor class.
> _more_tags shouldn't care about other classes and the MRO

This is explained here: #6651 (comment)

I'm not confident that `RegressorMixin` / `ClassifierMixin` won't add a `_more_tags()`, and I don't want those to silently override LightGBM's preferred tags because we follow `scikit-learn`'s advice to put mixins first in the MRO.

> Also, interesting that Classifier tags are needed in the Regressor class

Thank you! This was a copy-paste mistake. Fixed in d1915c0. It wasn't caught by tests because the tags for `LGBMClassifier`, `LGBMRegressor`, and `LGBMRanker` happen to be the same today.
Thank you all for working on this PR!

Generally LGTM, except some quite minor comments below:
```python
# sklearn.utils.Tags can be imported unconditionally once
# lightgbm's minimum scikit-learn version is 1.6 or higher
try:
    from sklearn.utils import Tags as _sklearn_Tags
except ImportError:
    _sklearn_Tags = None
```
I think this piece of code should go to `compat.py`.
```python
"check_n_features_in_after_fitting": (
    "validate_data() was first added in scikit-learn 1.6 and lightgbm "
    "supports much older versions than that"
```
I think we should add in this comment that LightGBM supports `predict_disable_shape_check=True` and we won't call `validate_data()` even after the minimum sklearn version bump.
```python
def __sklearn_tags__(self) -> Optional["_sklearn_Tags"]:
    return LGBMModel.__sklearn_tags__(self)
```
We need `__sklearn_tags__()` in `LGBMRegressor` and `LGBMClassifier` due to MRO again, right?
Fixes #6653.

Trying to fix the latest CI job. Sklearn 1.6.0 dev deprecates `BaseEstimator._more_tags()` in favor of `__sklearn_tags__`; see https://scikit-learn.org/dev/whats_new/v1.6.html and scikit-learn/scikit-learn#29677.