DOC impact of stratification on the target class in cross-validation splitters (scikit-learn#30576)

Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com>
Co-authored-by: antoinebaker <antoinebaker@users.noreply.github.com>
3 people authored Jan 22, 2025
1 parent fef4701 commit 9a749bd
Showing 2 changed files with 57 additions and 13 deletions.
33 changes: 27 additions & 6 deletions doc/modules/cross_validation.rst
@@ -523,12 +523,33 @@ the proportion of samples on each side of the train / test split.
Cross-validation iterators with stratification based on class labels
--------------------------------------------------------------------

Some classification problems can exhibit a large imbalance in the distribution
of the target classes: for instance there could be several times more negative
samples than positive samples. In such cases it is recommended to use
stratified sampling as implemented in :class:`StratifiedKFold` and
:class:`StratifiedShuffleSplit` to ensure that relative class frequencies is
approximately preserved in each train and validation fold.
Some classification tasks can naturally exhibit rare classes: for instance,
there could be orders of magnitude more negative observations than positive
observations (e.g. medical screening, fraud detection, etc.). As a result,
cross-validation splitting can generate train or validation folds without any
occurrence of a particular class. This typically leads to undefined
classification metrics (e.g. ROC AUC), exceptions raised when attempting to
call :term:`fit`, or missing columns in the output of the `predict_proba` or
`decision_function` methods of multiclass classifiers trained on different
folds.
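A minimal sketch of this failure mode (toy data, assumed purely for illustration): with a rare class, plain unshuffled `KFold` can concentrate every positive sample in a single fold, leaving the other folds without any positives at all.

```python
import numpy as np
from sklearn.model_selection import KFold

# Toy data (illustrative only): 100 samples, just 3 of them positive.
y = np.zeros(100, dtype=int)
y[:3] = 1
X = np.zeros((100, 1))

# Plain, unshuffled KFold sends all three positives to the first test fold:
# the first training set then contains no positives at all, and the other
# four test sets contain none, so metrics such as ROC AUC are undefined there.
splits = list(KFold(n_splits=5).split(X, y))
for fold, (train, test) in enumerate(splits):
    print(fold, "positives in train:", int(y[train].sum()),
          "positives in test:", int(y[test].sum()))
```

Shuffling before splitting makes this less likely but does not rule it out when a class has very few members.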

To mitigate such problems, splitters such as :class:`StratifiedKFold` and
:class:`StratifiedShuffleSplit` implement stratified sampling to ensure that
relative class frequencies are approximately preserved in each fold.
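A quick check of that property, on hypothetical 10%-positive toy data: every test fold reproduces the class balance of the full dataset.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Toy data (illustrative only): 100 samples, 10% positive.
y = np.zeros(100, dtype=int)
y[:10] = 1
X = np.zeros((100, 1))

# Each test fold receives the same share of positives as the full dataset
# (counts divide evenly here, so the match is exact).
strat_splits = list(StratifiedKFold(n_splits=5).split(X, y))
for fold, (train, test) in enumerate(strat_splits):
    print(fold, "positive fraction in test:", y[test].mean())  # 0.1 in every fold
```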

.. note::

Stratified sampling was introduced in scikit-learn to work around the
aforementioned engineering problems rather than to solve a statistical one.

Stratification makes cross-validation folds more homogeneous, and as a result
hides some of the variability inherent to fitting models with a limited
number of observations.

As a result, stratification can artificially shrink the spread of the metric
measured across cross-validation iterations: the inter-fold variability no
longer reflects the uncertainty in the performance of classifiers in the
presence of rare classes.
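This shrinkage can be made visible with a small sketch (hypothetical data; a majority-class `DummyClassifier` stands in for a real model, since its per-fold accuracy directly tracks each test fold's class balance):

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.model_selection import KFold, StratifiedKFold, cross_val_score

rng = np.random.RandomState(0)
X = rng.randn(200, 2)
y = np.zeros(200, dtype=int)
y[:10] = 1  # exactly 5% rare positives

# A majority-class baseline scores the majority fraction of each test fold,
# so its fold-to-fold score spread mirrors the fold-to-fold class balance.
clf = DummyClassifier(strategy="most_frequent")
plain = cross_val_score(
    clf, X, y, cv=KFold(n_splits=5, shuffle=True, random_state=0))
strat = cross_val_score(
    clf, X, y, cv=StratifiedKFold(n_splits=5, shuffle=True, random_state=0))
print("plain KFold scores:", plain, "std:", plain.std())
print("stratified scores: ", strat, "std:", strat.std())  # all folds exactly 0.95
```

By construction the stratified scores are identical across folds (zero spread), while the plain-KFold scores vary with however the ten positives happen to land; the stratified spread therefore understates the real variability a rare class induces.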

.. _stratified_k_fold:

37 changes: 30 additions & 7 deletions sklearn/model_selection/_split.py
@@ -684,20 +684,26 @@ def split(self, X, y=None, groups=None):


class StratifiedKFold(_BaseKFold):
"""Stratified K-Fold cross-validator.
"""Class-wise stratified K-Fold cross-validator.
Provides train/test indices to split data in train/test sets.
This cross-validation object is a variation of KFold that returns
stratified folds. The folds are made by preserving the percentage of
samples for each class.
samples for each class in `y` in a binary or multiclass classification
setting.
Read more in the :ref:`User Guide <stratified_k_fold>`.
For visualisation of cross-validation behaviour and
comparison between common scikit-learn split methods
refer to :ref:`sphx_glr_auto_examples_model_selection_plot_cv_indices.py`
.. note::
Stratification on the class label solves an engineering problem rather
than a statistical one. See :ref:`stratification` for more details.
Parameters
----------
n_splits : int, default=5
@@ -883,11 +889,12 @@ def split(self, X, y, groups=None):


class StratifiedGroupKFold(GroupsConsumerMixin, _BaseKFold):
"""Stratified K-Fold iterator variant with non-overlapping groups.
"""Class-wise stratified K-Fold iterator variant with non-overlapping groups.
This cross-validation object is a variation of StratifiedKFold that attempts to
return stratified folds with non-overlapping groups. The folds are made by
preserving the percentage of samples for each class.
preserving the percentage of samples for each class in `y` in a binary or
multiclass classification setting.
Each group will appear exactly once in the test set across all folds (the
number of distinct groups has to be at least equal to the number of folds).
@@ -906,6 +913,11 @@ class StratifiedGroupKFold(GroupsConsumerMixin, _BaseKFold):
comparison between common scikit-learn split methods
refer to :ref:`sphx_glr_auto_examples_model_selection_plot_cv_indices.py`
.. note::
Stratification on the class label solves an engineering problem rather
than a statistical one. See :ref:`stratification` for more details.
Parameters
----------
n_splits : int, default=5
@@ -1726,13 +1738,18 @@ def __init__(self, *, n_splits=5, n_repeats=10, random_state=None):


class RepeatedStratifiedKFold(_UnsupportedGroupCVMixin, _RepeatedSplits):
"""Repeated Stratified K-Fold cross validator.
"""Repeated class-wise stratified K-Fold cross-validator.
Repeats Stratified K-Fold n times with different randomization in each
repetition.
Read more in the :ref:`User Guide <repeated_k_fold>`.
.. note::
Stratification on the class label solves an engineering problem rather
than a statistical one. See :ref:`stratification` for more details.
Parameters
----------
n_splits : int, default=5
@@ -2204,13 +2221,14 @@ def split(self, X, y=None, groups=None):


class StratifiedShuffleSplit(BaseShuffleSplit):
"""Stratified ShuffleSplit cross-validator.
"""Class-wise stratified ShuffleSplit cross-validator.
Provides train/test indices to split data in train/test sets.
This cross-validation object is a merge of :class:`StratifiedKFold` and
:class:`ShuffleSplit`, which returns stratified randomized folds. The folds
are made by preserving the percentage of samples for each class.
are made by preserving the percentage of samples for each class in `y` in a
binary or multiclass classification setting.
Note: like the :class:`ShuffleSplit` strategy, stratified random splits
do not guarantee that test sets across all folds will be mutually exclusive,
@@ -2223,6 +2241,11 @@ class StratifiedShuffleSplit(BaseShuffleSplit):
comparison between common scikit-learn split methods
refer to :ref:`sphx_glr_auto_examples_model_selection_plot_cv_indices.py`
.. note::
Stratification on the class label solves an engineering problem rather
than a statistical one. See :ref:`stratification` for more details.
Parameters
----------
n_splits : int, default=10
