Skip to content

Commit

Permalink
feat(train_test_split): Add "shuffle is True" warning (#791)
Browse files Browse the repository at this point in the history
Addresses #684
  • Loading branch information
augustebaum authored Nov 21, 2024
1 parent 4335f13 commit 09de4b6
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 17 deletions.
5 changes: 4 additions & 1 deletion skore/src/skore/sklearn/train_test_split/warning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
HighClassImbalanceWarning,
)
from .random_state_unset_warning import RandomStateUnsetWarning
from .stratify_is_set import StratifyWarning
from .shuffle_true_warning import ShuffleTrueWarning
from .stratify_is_set_warning import StratifyWarning

TRAIN_TEST_SPLIT_WARNINGS = [
HighClassImbalanceTooFewExamplesWarning,
HighClassImbalanceWarning,
StratifyWarning,
RandomStateUnsetWarning,
ShuffleTrueWarning,
]

__all__ = [
Expand All @@ -26,4 +28,5 @@
"HighClassImbalanceWarning",
"StratifyWarning",
"RandomStateUnsetWarning",
"ShuffleTrueWarning",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""'Shuffle is true' warning.
This warning is shown when ``shuffle`` is set to True.
"""

from __future__ import annotations

from skore.sklearn.train_test_split.warning.train_test_split_warning import (
TrainTestSplitWarning,
)


class ShuffleTrueWarning(TrainTestSplitWarning):
"""Check whether ``shuffle`` is set to ``True``."""

MSG = (
"We detected that the `shuffle` parameter is set to `True` either explicitly "
"or from its default value. In case of time-ordered events (even if they are "
"independent), this will result in inflated model performance evaluation "
"because natural drift will not be taken into account. We recommend setting "
"the shuffle parameter to `False` in order to ensure the evaluation process is "
"really representative of your production release process."
)

@staticmethod
def check(
shuffle: bool,
**kwargs,
) -> bool:
"""Check whether ``shuffle`` is set to ``True``.
Parameters
----------
shuffle : bool
Whether to shuffle the data before splitting.
Returns
-------
bool
True if the check passed, False otherwise.
"""
return shuffle is False
31 changes: 15 additions & 16 deletions skore/tests/unit/sklearn/train_test_split/test_train_test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
HighClassImbalanceTooFewExamplesWarning,
HighClassImbalanceWarning,
RandomStateUnsetWarning,
ShuffleTrueWarning,
StratifyWarning,
)

Expand Down Expand Up @@ -49,6 +50,18 @@ def case_random_state_unset():
return args, kwargs, RandomStateUnsetWarning


def case_shuffle_true():
args = ([[1]] * 4, [0, 1, 1, 1])
kwargs = dict(shuffle=True)
return args, kwargs, ShuffleTrueWarning


def case_shuffle_none():
args = ([[1]] * 4, [0, 1, 1, 1])
kwargs = {}
return args, kwargs, ShuffleTrueWarning


@pytest.mark.parametrize(
"params",
[
Expand All @@ -58,6 +71,8 @@ def case_random_state_unset():
case_high_class_imbalance_too_few_examples_kwargs_mixed,
case_stratify,
case_random_state_unset,
case_shuffle_true,
case_shuffle_none,
],
)
def test_train_test_split_warns(params):
Expand All @@ -70,22 +85,6 @@ def test_train_test_split_warns(params):
train_test_split(*args, **kwargs)


def test_train_test_split_no_y():
"""When calling `train_test_split` with one array argument,
this array is assumed to be `X` and not `y`."""
warnings.simplefilter("error")

# Since the array is `X` and we do no checks on it, this should produce no
# warning
train_test_split([[1]] * 4, random_state=0)


def test_train_test_split_no_warn():
warnings.simplefilter("error")

train_test_split([[1]] * 2000, [0] * 1000 + [1] * 1000, random_state=0)


def test_train_test_split_kwargs():
"""Passing data by keyword arguments should produce the same results as passing
them by position."""
Expand Down

0 comments on commit 09de4b6

Please sign in to comment.