From 9a749bdcb2be578c387f00c067bade56e8ae7539 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 22 Jan 2025 14:37:17 +0100 Subject: [PATCH] DOC impact of stratification on the target class in cross-validation splitters (#30576) Co-authored-by: Christian Lorentzen Co-authored-by: antoinebaker --- doc/modules/cross_validation.rst | 33 ++++++++++++++++++++++----- sklearn/model_selection/_split.py | 37 +++++++++++++++++++++++++------ 2 files changed, 57 insertions(+), 13 deletions(-) diff --git a/doc/modules/cross_validation.rst b/doc/modules/cross_validation.rst index ee6d7180728a7..bffa1f2727650 100644 --- a/doc/modules/cross_validation.rst +++ b/doc/modules/cross_validation.rst @@ -523,12 +523,33 @@ the proportion of samples on each side of the train / test split. Cross-validation iterators with stratification based on class labels -------------------------------------------------------------------- -Some classification problems can exhibit a large imbalance in the distribution -of the target classes: for instance there could be several times more negative -samples than positive samples. In such cases it is recommended to use -stratified sampling as implemented in :class:`StratifiedKFold` and -:class:`StratifiedShuffleSplit` to ensure that relative class frequencies is -approximately preserved in each train and validation fold. +Some classification tasks can naturally exhibit rare classes: for instance, +there could be orders of magnitude more negative observations than positive +observations (e.g. medical screening, fraud detection, etc). As a result, +cross-validation splitting can generate train or validation folds without any +occurrence of a particular class. This typically leads to undefined +classification metrics (e.g. ROC AUC), exceptions raised when attempting to +call :term:`fit` or missing columns in the output of the `predict_proba` or +`decision_function` methods of multiclass classifiers trained on different +folds. 
+ +To mitigate such problems, splitters such as :class:`StratifiedKFold` and +:class:`StratifiedShuffleSplit` implement stratified sampling to ensure that +relative class frequencies are approximately preserved in each fold. + +.. note:: + + Stratified sampling was introduced in scikit-learn to work around the + aforementioned engineering problems rather than solve a statistical one. + + Stratification makes cross-validation folds more homogeneous, and as a result + hides some of the variability inherent to fitting models with a limited + number of observations. + + As a result, stratification can artificially shrink the spread of the metric + measured across cross-validation iterations: the inter-fold variability + no longer reflects the uncertainty in the performance of classifiers in the + presence of rare classes. .. _stratified_k_fold: diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index 04520c059159c..5501513d114e1 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -684,13 +684,14 @@ def split(self, X, y=None, groups=None): class StratifiedKFold(_BaseKFold): - """Stratified K-Fold cross-validator. + """Class-wise stratified K-Fold cross-validator. Provides train/test indices to split data in train/test sets. This cross-validation object is a variation of KFold that returns stratified folds. The folds are made by preserving the percentage of - samples for each class. + samples for each class in `y` in a binary or multiclass classification + setting. Read more in the :ref:`User Guide `. @@ -698,6 +699,11 @@ class StratifiedKFold(_BaseKFold): comparison between common scikit-learn split methods refer to :ref:`sphx_glr_auto_examples_model_selection_plot_cv_indices.py` + .. note:: + + Stratification on the class label solves an engineering problem rather + than a statistical one. See :ref:`stratification` for more details. 
+ Parameters ---------- n_splits : int, default=5 @@ -883,11 +889,12 @@ def split(self, X, y, groups=None): class StratifiedGroupKFold(GroupsConsumerMixin, _BaseKFold): - """Stratified K-Fold iterator variant with non-overlapping groups. + """Class-wise stratified K-Fold iterator variant with non-overlapping groups. This cross-validation object is a variation of StratifiedKFold attempts to return stratified folds with non-overlapping groups. The folds are made by - preserving the percentage of samples for each class. + preserving the percentage of samples for each class in `y` in a binary or + multiclass classification setting. Each group will appear exactly once in the test set across all folds (the number of distinct groups has to be at least equal to the number of folds). @@ -906,6 +913,11 @@ class StratifiedGroupKFold(GroupsConsumerMixin, _BaseKFold): comparison between common scikit-learn split methods refer to :ref:`sphx_glr_auto_examples_model_selection_plot_cv_indices.py` + .. note:: + + Stratification on the class label solves an engineering problem rather + than a statistical one. See :ref:`stratification` for more details. + Parameters ---------- n_splits : int, default=5 @@ -1726,13 +1738,18 @@ def __init__(self, *, n_splits=5, n_repeats=10, random_state=None): class RepeatedStratifiedKFold(_UnsupportedGroupCVMixin, _RepeatedSplits): - """Repeated Stratified K-Fold cross validator. + """Repeated class-wise stratified K-Fold cross validator. Repeats Stratified K-Fold n times with different randomization in each repetition. Read more in the :ref:`User Guide `. + .. note:: + + Stratification on the class label solves an engineering problem rather + than a statistical one. See :ref:`stratification` for more details. + Parameters ---------- n_splits : int, default=5 @@ -2204,13 +2221,14 @@ def split(self, X, y=None, groups=None): class StratifiedShuffleSplit(BaseShuffleSplit): - """Stratified ShuffleSplit cross-validator. 
+ """Class-wise stratified ShuffleSplit cross-validator. Provides train/test indices to split data in train/test sets. This cross-validation object is a merge of :class:`StratifiedKFold` and :class:`ShuffleSplit`, which returns stratified randomized folds. The folds - are made by preserving the percentage of samples for each class. + are made by preserving the percentage of samples for each class in `y` in a + binary or multiclass classification setting. Note: like the :class:`ShuffleSplit` strategy, stratified random splits do not guarantee that test sets across all folds will be mutually exclusive, @@ -2223,6 +2241,11 @@ class StratifiedShuffleSplit(BaseShuffleSplit): comparison between common scikit-learn split methods refer to :ref:`sphx_glr_auto_examples_model_selection_plot_cv_indices.py` + .. note:: + + Stratification on the class label solves an engineering problem rather + than a statistical one. See :ref:`stratification` for more details. + Parameters ---------- n_splits : int, default=10