Skip to content

Commit

Permalink
MNT Sklearn1.6 compatibility (#447)
Browse files Browse the repository at this point in the history
* Fix _sgd imports

* Fix _safe_tags import issue

* Change _construct_instance import

* Change get_tags syntax

* Ignore FutureWarning in sklearn

* Update skops/io/_sklearn.py

Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>

* Update skops/io/_sklearn.py

Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>

* fix typo

* Fix variable name inconsistency

* Add clearer message about warning suppression

* WIP

* Add explicit typing

* Remove stray WIP with prints

* Fix tags issues

* Update skops/io/_sklearn.py

Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>

* Make the use of SGD models conditional on sklearn version

* Add relative paths to fix import errors

* Add construct_instances for both versions

* Move imports for construct_instances

* Partially make tags work between the two versions

* Tags working with both versions

* Remove typing import

* Attempt to fix catboost issues

* Skip quantile-forest futurewarning sklearn 1.7

* Suppress quantile-forest warning

* Update spaces/skops_model_card_creator/requirements.txt

Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>

* Update skops/_min_dependencies.py

Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>

* Add error for SGD class and incompatible sklearn version

* Copy code for scikit-learn for est tags

* Fix loss issues

* minor fix

* reduce diff

* annotations import

* work with all instances from _construct_instances

* Refactor get_input()

* trigger CI

* debug CI

* ...

* ...

* ...

* ...

* ...

* ...

---------

Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
  • Loading branch information
TamaraAtanasoska and adrinjalali authored Dec 2, 2024
1 parent fb35674 commit 00f5f07
Show file tree
Hide file tree
Showing 14 changed files with 728 additions and 150 deletions.
24 changes: 13 additions & 11 deletions .github/workflows/build-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ jobs:
# this is to make the CI run on different sklearn versions
include:
- python: "3.9"
sklearn_version: "1.1"
sklearn_version: ">=1.1,<1.2"
numpy_version: "numpy<2"
- python: "3.10"
sklearn_version: "1.2"
sklearn_version: ">=1.2,<1.3"
numpy_version: "numpy"
- python: "3.11"
sklearn_version: "1.4"
sklearn_version: ">=1.4,<1.5"
numpy_version: "numpy"
- python: "3.12"
sklearn_version: "1.5"
sklearn_version: ">=1.5,<1.6"
numpy_version: "numpy"
- python: "3.13"
sklearn_version: "nightly"
Expand Down Expand Up @@ -59,20 +59,22 @@ jobs:

- name: Install dependencies
run: |
set -x
python -m pip install -U pip
if [ "${{ matrix.os }}" == "macos-latest" ]; then
brew install libomp
fi
pip install "pytest<8"
pip install "${{ matrix.numpy_version }}"
if [ ${{ matrix.sklearn_version }} == "nightly" ];
then pip install --pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple scikit-learn;
else pip install "scikit-learn~=${{ matrix.sklearn_version }}";
if [ ${{ matrix.sklearn_version }} == "nightly" ]; then
pip install --pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple scikit-learn
pip install .[docs,tests]
else
pip install .[docs,tests] "scikit-learn${{ matrix.sklearn_version }}"
fi
pip install .[docs,tests]
pip install black=="23.9.1" ruff=="0.0.292" mypy=="1.6.0"
if [ ${{ matrix.os }} == "ubuntu-latest" ];
then sudo apt install pandoc && pandoc --version;
if [ ${{ matrix.os }} == "ubuntu-latest" ]; then
sudo apt install pandoc && pandoc --version;
fi
python --version
pip --version
Expand All @@ -98,7 +100,7 @@ jobs:
- name: Inference tests (conditional)
if: contains(env.PR_COMMIT_MESSAGE, '[CI inference]')
run: |
python -m pytest -s -v -m "inference" skops/
python -l -m pytest -s -v -m "inference" skops/
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
Expand Down
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ filterwarnings = [
"ignore:DataFrameGroupBy.apply operated on the grouping columns.:DeprecationWarning",
# Ignore Pandas 2.2 warning on PyArrow. It might be reverted in a later release.
"ignore:\\s*Pyarrow will become a required dependency of pandas.*:DeprecationWarning",
# LightGBM sklearn 1.6 deprecation warning, fixed in the next release
"ignore:'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.:FutureWarning",
# RandomForestQuantileRegressor tags deprecation warning in sklearn 1.7
"ignore:The RandomForestQuantileRegressor or classes from which it inherits use `_get_tags` and `_more_tags`:FutureWarning",
# ExtraTreesQuantileRegressor tags deprecation warning in sklearn 1.7
"ignore:The ExtraTreesQuantileRegressor or classes from which it inherits use `_get_tags` and `_more_tags`:FutureWarning",
# BaseEstimator._validate_data deprecation warning in sklearn 1.6 #TODO can be removed when a new release of quantile-forest is out
"ignore:`BaseEstimator._validate_data` is deprecated in 1.6 and will be removed in 1.7:FutureWarning",
]
markers = [
"network: marks tests as requiring internet (deselect with '-m \"not network\"')",
Expand Down
5 changes: 2 additions & 3 deletions scripts/check_file_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from zipfile import ZIP_DEFLATED, ZipFile

import pandas as pd
from sklearn.utils._tags import _safe_tags
from sklearn.utils._testing import set_random_state

import skops.io as sio
Expand All @@ -29,6 +28,7 @@
_tested_estimators,
get_input,
)
from skops.utils._fixes import get_tags

TOPK = 10 # number of largest estimators reported
MAX_ALLOWED_SIZE = 1024 # maximum allowed file size in kb
Expand All @@ -46,8 +46,7 @@ def check_file_size() -> None:
set_random_state(estimator, random_state=0)

X, y = get_input(estimator)
tags = _safe_tags(estimator)
if tags.get("requires_fit", True):
if get_tags(estimator).requires_fit:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", module="sklearn")
if y is not None:
Expand Down
5 changes: 2 additions & 3 deletions scripts/check_persistence_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from typing import Any

import pandas as pd
from sklearn.utils._tags import _safe_tags
from sklearn.utils._testing import set_random_state

import skops.io as sio
Expand All @@ -24,6 +23,7 @@
_tested_estimators,
get_input,
)
from skops.utils._fixes import get_tags

ATOL = 1 # seconds absolute difference allowed at max
NUM_REPS = 10 # number of times the check is repeated
Expand All @@ -43,8 +43,7 @@ def check_persist_performance() -> None:
set_random_state(estimator, random_state=0)

X, y = get_input(estimator)
tags = _safe_tags(estimator)
if tags.get("requires_fit", True):
if get_tags(estimator).requires_fit:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", module="sklearn")
if y is not None:
Expand Down
4 changes: 3 additions & 1 deletion skops/_min_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
# required for persistence tests of external libraries
"lightgbm": ("3", "tests", None),
"xgboost": ("1.6", "tests", None),
"catboost": ("1.0", "tests", None),
# remove python constraint when catboost supports 3.13
# https://github.com/catboost/catboost/issues/2748
"catboost": ("1.0", "tests", 'python_version < "3.13"'),
"fairlearn": ("0.7.0", "docs, tests", None),
"rich": ("12", "tests, rich", None),
}
Expand Down
101 changes: 76 additions & 25 deletions skops/io/_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,94 @@
from typing import Any, Optional, Sequence, Type

from sklearn.cluster import Birch
from sklearn.tree._tree import Tree

from ._general import TypeNode
from ._audit import Node, get_tree
from ._general import TypeNode, unsupported_get_state
from ._protocol import PROTOCOL
from ._utils import LoadContext, SaveContext, get_module, get_state, gettype
from .exceptions import UnsupportedTypeException

try:
# TODO: remove once support for sklearn<1.2 is dropped. See #187
from sklearn.covariance._graph_lasso import _DictWithDeprecatedKeys
except ImportError:
_DictWithDeprecatedKeys = None

from sklearn.linear_model._sgd_fast import (
EpsilonInsensitive,
Hinge,
Huber,
Log,
LossFunction,
ModifiedHuber,
SquaredEpsilonInsensitive,
SquaredHinge,
SquaredLoss,
)
from sklearn.tree._tree import Tree

from ._audit import Node, get_tree
from ._general import unsupported_get_state
from ._utils import LoadContext, SaveContext, get_module, get_state, gettype
from .exceptions import UnsupportedTypeException

ALLOWED_SGD_LOSSES = {
ModifiedHuber,
Hinge,
SquaredHinge,
Log,
SquaredLoss,
Huber,
ALLOWED_LOSSES = {
EpsilonInsensitive,
Hinge,
ModifiedHuber,
SquaredEpsilonInsensitive,
SquaredHinge,
}

try:
# TODO: remove once support for sklearn<1.6 is dropped.
from sklearn.linear_model._sgd_fast import (
Huber,
Log,
SquaredLoss,
)

ALLOWED_LOSSES |= {
Huber,
Log,
SquaredLoss,
}
except ImportError:
pass

try:
# sklearn>=1.6
from sklearn._loss._loss import (
CyAbsoluteError,
CyExponentialLoss,
CyHalfBinomialLoss,
CyHalfGammaLoss,
CyHalfMultinomialLoss,
CyHalfPoissonLoss,
CyHalfSquaredError,
CyHalfTweedieLoss,
CyHalfTweedieLossIdentity,
CyHuberLoss,
CyPinballLoss,
)

ALLOWED_LOSSES |= {
CyAbsoluteError,
CyExponentialLoss,
CyHalfBinomialLoss,
CyHalfGammaLoss,
CyHalfMultinomialLoss,
CyHalfPoissonLoss,
CyHalfSquaredError,
CyHalfTweedieLoss,
CyHalfTweedieLossIdentity,
CyHuberLoss,
CyPinballLoss,
}
except ImportError:
pass

# This import is for the parent class of all loss functions, which is used to
# set the dispatch function for all loss functions.
try:
# From sklearn>=1.6
from sklearn._loss._loss import CyLossFunction as ParentLossClass
except ImportError:
# sklearn<1.6
from sklearn.linear_model._sgd_fast import LossFunction as ParentLossClass


UNSUPPORTED_TYPES = {Birch}


Expand Down Expand Up @@ -163,13 +213,13 @@ def __init__(
super().__init__(state, load_context, constructor=Tree, trusted=self.trusted)


def sgd_loss_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
def loss_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
state = reduce_get_state(obj, save_context)
state["__loader__"] = "SGDNode"
state["__loader__"] = "LossNode"
return state


class SGDNode(ReduceNode):
class LossNode(ReduceNode):
def __init__(
self,
state: dict[str, Any],
Expand All @@ -178,7 +228,7 @@ def __init__(
) -> None:
# TODO: make sure trusted here makes sense and used.
self.trusted = self._get_trusted(
trusted, [get_module(x) + "." + x.__name__ for x in ALLOWED_SGD_LOSSES]
trusted, [get_module(x) + "." + x.__name__ for x in ALLOWED_LOSSES]
)
super().__init__(
state,
Expand Down Expand Up @@ -240,15 +290,16 @@ def _construct(self):

# tuples of type and function that gets the state of that type
GET_STATE_DISPATCH_FUNCTIONS = [
(LossFunction, sgd_loss_get_state),
(ParentLossClass, loss_get_state),
(Tree, tree_get_state),
]

for type_ in UNSUPPORTED_TYPES:
GET_STATE_DISPATCH_FUNCTIONS.append((type_, unsupported_get_state))

# tuples of type and function that creates the instance of that type
NODE_TYPE_MAPPING = {
("SGDNode", PROTOCOL): SGDNode,
NODE_TYPE_MAPPING: dict[tuple[str, int], Any] = {
("LossNode", PROTOCOL): LossNode,
("TreeNode", PROTOCOL): TreeNode,
}

Expand Down
Loading

0 comments on commit 00f5f07

Please sign in to comment.