feat: Add ComparisonReport to compare instances of EstimatorReport #1286

Draft · wants to merge 47 commits into base: main

Commits (47)

1596cc1  Squashed commit of the following: (thomass-dev, Feb 4, 2025)
f5296be  Improve warning messages (thomass-dev, Feb 5, 2025)
21d2bd5  mob session (auguste-probabl, Feb 5, 2025)
4a2bdd4  add docstring (auguste-probabl, Feb 5, 2025)
345cdde  docs: add example (MarieS-WiMLDS, Jan 27, 2025)
491f6f8  rename example file (auguste-probabl, Feb 5, 2025)
a142cb2  fix circular import (auguste-probabl, Feb 5, 2025)
beb6cbf  update example (auguste-probabl, Feb 5, 2025)
b6774ba  attempt to fix docs (auguste-probabl, Feb 5, 2025)
5fe9e2c  attempt to fix docs (auguste-probabl, Feb 5, 2025)
b53924c  Squashed commit of the following: (thomass-dev, Feb 6, 2025)
638e871  resolve merge (sylvaincom, Feb 6, 2025)
5835ca5  Merge branch 'comparator' of https://github.com/probabl-ai/skore into… (sylvaincom, Feb 6, 2025)
35ffef2  iter doc example (sylvaincom, Feb 6, 2025)
a37380d  iter doc example (sylvaincom, Feb 6, 2025)
6952abe  Squashed commit of the following: (thomass-dev, Feb 6, 2025)
5478865  Disallow plot from comparison reports on multi-class (thomass-dev, Feb 6, 2025)
bf89597  Merge branch 'main' into comparator (thomass-dev, Feb 7, 2025)
d0a4853  Clean ComparisonReport tests (thomass-dev, Feb 7, 2025)
a202bf7  Fix ComparisonReport tests for python 3.9 (thomass-dev, Feb 7, 2025)
cca35c9  Fix minor typing mistake (thomass-dev, Feb 7, 2025)
c15a3fd  Merge branch 'main' into comparator (thomass-dev, Feb 7, 2025)
f82a32a  adding more stuff in the example (sylvaincom, Feb 7, 2025)
f54078d  docs: Add ComparisonReport to API docs (auguste-probabl, Feb 7, 2025)
6a499f8  fix link in docs (auguste-probabl, Feb 7, 2025)
869b8e9  iter on the doc example (sylvaincom, Feb 7, 2025)
f44c49b  Merge branch 'main' into comparator (thomass-dev, Feb 10, 2025)
689820e  Allow different training datasets (thomass-dev, Feb 10, 2025)
3f04b57  Disallow comparator of estimators without testing data (thomass-dev, Feb 10, 2025)
784c022  Rebase with last changes on plot API introduced by Guillaume (thomass-dev, Feb 10, 2025)
b91cdd0  Update sphinx and examples (thomass-dev, Feb 10, 2025)
33b050b  Update last failing example (thomass-dev, Feb 10, 2025)
cdef817  Merge branch 'main' into comparator (thomass-dev, Feb 10, 2025)
4f59748  replace `usecase` by fixtures (auguste-probabl, Feb 11, 2025)
2ed2770  catch type error explicitly (auguste-probabl, Feb 11, 2025)
37971f7  remove test (auguste-probabl, Feb 11, 2025)
757a37a  collect all ml tasks in error message (auguste-probabl, Feb 11, 2025)
f076274  refactor (auguste-probabl, Feb 11, 2025)
a1ae2b7  move type checking out of loop (auguste-probabl, Feb 11, 2025)
c77bb64  collect hashes in error message (auguste-probabl, Feb 11, 2025)
4ec4c69  refine error message about length of `report_names` (auguste-probabl, Feb 11, 2025)
0c13905  add report_names_ to list of attributes (auguste-probabl, Feb 11, 2025)
0f2504e  fix "see also" reference (auguste-probabl, Feb 11, 2025)
433120e  add cross-validation report to "see also" (auguste-probabl, Feb 11, 2025)
8a3d974  remove dead code (auguste-probabl, Feb 11, 2025)
9e24be2  use short imports in "see also" (auguste-probabl, Feb 11, 2025)
a630d27  use short imports in "see also" (auguste-probabl, Feb 11, 2025)

44 changes: 20 additions & 24 deletions skore/src/skore/sklearn/_comparison/report.py
@@ -105,13 +105,30 @@ def __init__(
We check that the estimator reports can be compared:
- all reports are estimator reports,
- all estimators are in the same ML use case,
- all X_test, y_test have the same hash.
- all estimators have non-empty X_test and y_test,
- all estimators have the same X_test and y_test.
"""
if len(reports) < 2:
raise ValueError("At least 2 instances of EstimatorReport are needed")

if not all(isinstance(report, EstimatorReport) for report in reports):
raise TypeError("Only instances of EstimatorReport are allowed")
ml_tasks = set()
test_dataset_hashes = set()

for report in reports:
if not isinstance(report, EstimatorReport):
raise TypeError("Only instances of EstimatorReport are allowed")

if (report.X_test is None) or (report.y_test is None):
raise ValueError("Cannot compare reports without testing data")
Review comment on lines +113 to +114 (Member): Actually, you should be able to use the external `data_source`. I would expect something like this to work on the side of the end-user:

    comparator.metrics.report_metrics(data_source="X_y", X=X, y=y)


ml_tasks.add(report._ml_task)
test_dataset_hashes.add(joblib.hash((report.X_test, report.y_test)))

if len(ml_tasks) > 1:
raise ValueError("Not all estimators are in the same ML usecase")

if len(test_dataset_hashes) > 1:
raise ValueError("Not all estimators have the same testing data")

if report_names is None:
self.report_names_ = [report.estimator_name_ for report in reports]
@@ -124,34 +141,13 @@ def __init__(

self.estimator_reports_ = deepcopy(reports)

first_report = self.estimator_reports_[0]
first_report_ml_task = first_report._ml_task
first_report_test_hash = joblib.hash((first_report.X_test, first_report.y_test))

for report in self.estimator_reports_[1:]:
if report._ml_task != first_report_ml_task:
raise ValueError("Not all estimators are in the same ML usecase")

if joblib.hash((report.X_test, report.y_test)) != first_report_test_hash:
raise ValueError("Not all estimators have the same testing data")

if (first_report.X_test is None) or (first_report.y_test is None):
warn(
"MissingTestDataWarning",
(
"We cannot ensure that all estimators have been tested "
"with the same dataset. This could lead to incoherent comparisons."
),
)

# NEEDED FOR METRICS ACCESSOR
self.n_jobs = n_jobs
self._rng = np.random.default_rng(time.time_ns())
self._hash = self._rng.integers(
low=np.iinfo(np.int64).min, high=np.iinfo(np.int64).max
)
self._cache = {}

self._ml_task = self.estimator_reports_[0]._ml_task

####################################################################################
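
Taken together, the checks above define the constructor's contract: at least two EstimatorReport instances, all on the same ML task, all carrying identical, non-empty test data. A minimal sketch of the resulting behavior, assuming ComparisonReport and EstimatorReport are exposed at the top level of the skore package and that EstimatorReport fits the estimator on the provided training data by default:

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split

    from skore import ComparisonReport, EstimatorReport

    X, y = make_classification(random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    # Two reports on the same ML task with identical test data: accepted.
    report_1 = EstimatorReport(
        LogisticRegression(C=1.0),
        X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test,
    )
    report_2 = EstimatorReport(
        LogisticRegression(C=0.1),
        X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test,
    )
    comparison = ComparisonReport(
        [report_1, report_2], report_names=["C=1.0", "C=0.1"]
    )

    # A report without test data is now rejected outright.
    no_test = EstimatorReport(LogisticRegression(), X_train=X_train, y_train=y_train)
    try:
        ComparisonReport([no_test, no_test])
    except ValueError as exc:
        print(exc)  # Cannot compare reports without testing data

Whether reports lacking internal test data should instead remain comparable through an external data source, as the review comment above suggests, is still an open question on this PR.
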
89 changes: 60 additions & 29 deletions skore/tests/unit/sklearn/test_comparison.py
@@ -32,8 +32,13 @@ def usecase(
def test_comparison_report_init_wrong_parameters():
"""If the input is not valid, raise."""

estimator, _, _, _, _ = usecase("binary-logistic-regression")
estimator_report = EstimatorReport(estimator, fit=False)
estimator, _, X_test, _, y_test = usecase("binary-logistic-regression")
estimator_report = EstimatorReport(
estimator,
fit=False,
X_test=X_test,
y_test=y_test,
)

with pytest.raises(
TypeError, match="object of type 'EstimatorReport' has no len()"
@@ -59,8 +64,14 @@ def test_comparison_report_init_wrong_parameters():
def test_comparison_report_init_deepcopy():
"""If an estimator report is modified outside of the comparator, it is not modified
inside the comparator."""
estimator, _, _, _, _ = usecase("binary-logistic-regression")
estimator_report = EstimatorReport(estimator, fit=False)
estimator, _, X_test, _, y_test = usecase("binary-logistic-regression")
estimator_report = EstimatorReport(
estimator,
fit=False,
X_test=X_test,
y_test=y_test,
)

comp = ComparisonReport([estimator_report, estimator_report])

# check that the deepcopy works well
@@ -74,33 +85,33 @@ def test_comparison_report_init_deepcopy():
assert comp.estimator_reports_[0]._hash != 0


def test_comparison_report_init_MissingTestDataWarning(capsys):
def test_comparison_report_init_without_testing_data():
"""Raise a warning if there is no test data (`None`) for any estimator
report."""
estimator, _, _, _, _ = usecase("binary-logistic-regression")
estimator_report = EstimatorReport(estimator, fit=False)

estimator, X_train, _, y_train, _ = usecase("binary-logistic-regression")
estimator_report = EstimatorReport(
estimator,
fit=False,
X_train=X_train,
y_train=y_train,
)

ComparisonReport([estimator_report, estimator_report])

captured = capsys.readouterr()

assert "MissingTestDataWarning" in captured.out
with pytest.raises(ValueError, match="Cannot compare reports without testing data"):
ComparisonReport([estimator_report, estimator_report])


def test_comparison_report_init_different_ml_usecases():
linear_regression_estimator, _, _, _, _ = usecase("linear-regression")
linear_regression_report = EstimatorReport(linear_regression_estimator, fit=False)
linear_regression_estimator, _, X_test, _, y_test = usecase("linear-regression")
linear_regression_report = EstimatorReport(
linear_regression_estimator,
fit=False,
X_test=X_test,
y_test=y_test,
)

logistic_regression_estimator, _, _, _, _ = usecase("binary-logistic-regression")
logistic_regression_estimator, _, X_test, _, y_test = usecase(
"binary-logistic-regression"
)
logistic_regression_report = EstimatorReport(
logistic_regression_estimator,
fit=False,
X_test=X_test,
y_test=y_test,
)

with pytest.raises(
@@ -175,8 +186,13 @@ def test_comparison_report_init_without_report_names():


def test_comparison_report_init_with_invalid_report_names():
estimator, _, _, _, _ = usecase("binary-logistic-regression")
estimator_report = EstimatorReport(estimator, fit=False)
estimator, _, X_test, _, y_test = usecase("binary-logistic-regression")
estimator_report = EstimatorReport(
estimator,
fit=False,
X_test=X_test,
y_test=y_test,
)

with pytest.raises(
ValueError, match="There should be as many report names as there are reports"
@@ -185,17 +201,27 @@ def test_comparison_report_init_with_invalid_report_names():


def test_comparison_report_help(capsys):
estimator, _, _, _, _ = usecase("binary-logistic-regression")
estimator_report = EstimatorReport(estimator, fit=False)
estimator, _, X_test, _, y_test = usecase("binary-logistic-regression")
estimator_report = EstimatorReport(
estimator,
fit=False,
X_test=X_test,
y_test=y_test,
)

ComparisonReport([estimator_report, estimator_report]).help()

assert "Tools to compare estimators" in capsys.readouterr().out


def test_comparison_report_repr():
estimator, _, _, _, _ = usecase("binary-logistic-regression")
estimator_report = EstimatorReport(estimator, fit=False)
estimator, _, X_test, _, y_test = usecase("binary-logistic-regression")
estimator_report = EstimatorReport(
estimator,
fit=False,
X_test=X_test,
y_test=y_test,
)

repr_str = repr(ComparisonReport([estimator_report, estimator_report]))

@@ -205,8 +231,13 @@

def test_comparison_report_pickle(tmp_path):
"""Check that we can pickle a comparison report."""
estimator, _, _, _, _ = usecase("binary-logistic-regression")
estimator_report = EstimatorReport(estimator, fit=False)
estimator, _, X_test, _, y_test = usecase("binary-logistic-regression")
estimator_report = EstimatorReport(
estimator,
fit=False,
X_test=X_test,
y_test=y_test,
)

with BytesIO() as stream:
joblib.dump(ComparisonReport([estimator_report, estimator_report]), stream)
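
For completeness, a sketch of the round trip this test exercises, under the same assumption as above that ComparisonReport and EstimatorReport are importable from the top-level skore package (the test's own import block is not shown in this diff):

    from io import BytesIO

    import joblib
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    from skore import ComparisonReport, EstimatorReport

    X, y = make_classification(random_state=0)
    report = EstimatorReport(
        LogisticRegression(), X_train=X, y_train=y, X_test=X, y_test=y
    )
    comparison = ComparisonReport([report, report])

    # Dump to an in-memory buffer, rewind, and load the comparison back.
    with BytesIO() as stream:
        joblib.dump(comparison, stream)
        stream.seek(0)
        restored = joblib.load(stream)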