From a47d164c82c347cb4f4b115b12cc79f1571c25b6 Mon Sep 17 00:00:00 2001 From: Ghislain Piot Date: Fri, 3 May 2024 10:50:43 +0200 Subject: [PATCH] Correct some of the Scikit-learn stubs --- stubs/sklearn/metrics/_classification.pyi | 2 +- stubs/sklearn/model_selection/_split.pyi | 2 +- stubs/sklearn/utils/__init__.pyi | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/stubs/sklearn/metrics/_classification.pyi b/stubs/sklearn/metrics/_classification.pyi index 80d06354..94b95746 100644 --- a/stubs/sklearn/metrics/_classification.pyi +++ b/stubs/sklearn/metrics/_classification.pyi @@ -112,7 +112,7 @@ def precision_recall_fscore_support( labels: None | ArrayLike = None, pos_label: str | int = 1, average: None | Literal["binary", "micro", "macro", "samples", "weighted"] = None, - warn_for: set | tuple = ..., + warn_for: list | set | tuple = ..., sample_weight: None | ArrayLike = None, zero_division: Literal["warn", "warn"] | int = "warn", ) -> tuple[float | ndarray, float | ndarray, float | ndarray, None | ndarray]: ... diff --git a/stubs/sklearn/model_selection/_split.pyi b/stubs/sklearn/model_selection/_split.pyi index f0841a13..885d8fa8 100644 --- a/stubs/sklearn/model_selection/_split.pyi +++ b/stubs/sklearn/model_selection/_split.pyi @@ -204,7 +204,7 @@ class _CVIterableWrapper(BaseCrossValidator): def split(self, X: Any = None, y: Any = None, groups: Any = None): ... def check_cv( - cv: Iterable | int | BaseShuffleSplit | BaseCrossValidator = 5, + cv: Iterable | int | BaseShuffleSplit | BaseCrossValidator | None = 5, y: None | ArrayLike = None, *, classifier: bool = False, diff --git a/stubs/sklearn/utils/__init__.pyi b/stubs/sklearn/utils/__init__.pyi index 5d2d6ed9..61432283 100644 --- a/stubs/sklearn/utils/__init__.pyi +++ b/stubs/sklearn/utils/__init__.pyi @@ -83,8 +83,8 @@ def resample( n_samples: None | Int = None, random_state: RandomState | None | Int = None, stratify: None | MatrixLike | ArrayLike = None, -) -> list[ndarray]: ... -def shuffle(*arrays, random_state: RandomState | None | Int = None, n_samples: None | Int = None) -> list[SupportsIndex]: ... +) -> list[ndarray] | None: ... +def shuffle(*arrays, random_state: RandomState | None | Int = None, n_samples: None | Int = None) -> list[SupportsIndex] | None: ... def safe_sqr(X: MatrixLike | ArrayLike, *, copy: bool = True) -> ndarray: ... def gen_batches(n: Int, batch_size: Int, *, min_batch_size: Int = 0) -> Iterator[slice]: ... def gen_even_slices(n: Int, n_packs: Int, *, n_samples: None | Int = None) -> Iterator[slice]: ...