Skip to content

Commit

Permalink
ENH Propagate main process warning filters to joblib workers (scikit-…
Browse files Browse the repository at this point in the history
…learn#30380)

Co-authored-by: Loïc Estève <loic.esteve@ymail.com>
  • Loading branch information
thomasjpfan and lesteve authored Jan 15, 2025
1 parent 2707099 commit 10253eb
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- Warning filters from the main process are propagated to joblib workers.
By `Thomas Fan`_
30 changes: 19 additions & 11 deletions sklearn/utils/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
_threadpool_controller = None


def _with_config(delayed_func, config):
def _with_config_and_warning_filters(delayed_func, config, warning_filters):
"""Helper function that intends to attach a config to a delayed function."""
if hasattr(delayed_func, "with_config"):
return delayed_func.with_config(config)
if hasattr(delayed_func, "with_config_and_warning_filters"):
return delayed_func.with_config_and_warning_filters(config, warning_filters)
else:
warnings.warn(
(
Expand Down Expand Up @@ -70,11 +70,16 @@ def __call__(self, iterable):
# in a different thread depending on the backend and on the value of
# pre_dispatch and n_jobs.
config = get_config()
iterable_with_config = (
(_with_config(delayed_func, config), args, kwargs)
warning_filters = warnings.filters
iterable_with_config_and_warning_filters = (
(
_with_config_and_warning_filters(delayed_func, config, warning_filters),
args,
kwargs,
)
for delayed_func, args, kwargs in iterable
)
return super().__call__(iterable_with_config)
return super().__call__(iterable_with_config_and_warning_filters)


# remove when https://github.com/joblib/joblib/issues/1071 is fixed
Expand Down Expand Up @@ -118,13 +123,15 @@ def __init__(self, function):
self.function = function
update_wrapper(self, self.function)

def with_config(self, config):
def with_config_and_warning_filters(self, config, warning_filters):
self.config = config
self.warning_filters = warning_filters
return self

def __call__(self, *args, **kwargs):
config = getattr(self, "config", None)
if config is None:
config = getattr(self, "config", {})
warning_filters = getattr(self, "warning_filters", [])
if not config or not warning_filters:
warnings.warn(
(
"`sklearn.utils.parallel.delayed` should be used with"
Expand All @@ -134,8 +141,9 @@ def __call__(self, *args, **kwargs):
),
UserWarning,
)
config = {}
with config_context(**config):

with config_context(**config), warnings.catch_warnings():
warnings.filters = warning_filters
return self.function(*args, **kwargs)


Expand Down
53 changes: 53 additions & 0 deletions sklearn/utils/tests/test_parallel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
import warnings

import joblib
import numpy as np
Expand All @@ -9,6 +10,7 @@
from sklearn.compose import make_column_transformer
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.exceptions import ConvergenceWarning
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
Expand Down Expand Up @@ -98,3 +100,54 @@ def transform(self, X, y=None):
search_cv.fit(iris.data, iris.target)

assert not np.isnan(search_cv.cv_results_["mean_test_score"]).any()


def raise_warning():
    """Emit a ``ConvergenceWarning`` with the message "Convergence warning"."""
    warnings.warn(message="Convergence warning", category=ConvergenceWarning)


@pytest.mark.parametrize("n_jobs", [1, 2])
@pytest.mark.parametrize("backend", ["loky", "threading", "multiprocessing"])
def test_filter_warning_propagates(n_jobs, backend):
    """An "error" filter installed by the caller must also apply inside the jobs.

    With ``ConvergenceWarning`` escalated to an error in the calling code,
    running ``raise_warning`` through ``Parallel`` is expected to raise it,
    for every parametrized ``n_jobs`` / ``backend`` combination.
    """
    with warnings.catch_warnings():
        # Escalate only ConvergenceWarning; the filter is undone when the
        # ``catch_warnings`` block exits.
        warnings.simplefilter("error", category=ConvergenceWarning)

        with pytest.raises(ConvergenceWarning):
            runner = Parallel(n_jobs=n_jobs, backend=backend)
            runner(delayed(raise_warning)() for _ in range(2))


def get_warnings():
    """Return the ``warnings.filters`` list as seen at call time.

    The list object itself is returned, not a copy.
    """
    active_filters = warnings.filters
    return active_filters


def test_check_warnings_threading():
    """Jobs run with the threading backend must see the caller's warning filters."""
    expected_entry = ("error", None, ConvergenceWarning, None, 0)

    with warnings.catch_warnings():
        warnings.simplefilter("error", category=ConvergenceWarning)

        # Sanity check: the filter really is installed in the calling code.
        main_filters = warnings.filters
        assert expected_entry in main_filters

        run = Parallel(n_jobs=2, backend="threading")
        seen_in_jobs = run(delayed(get_warnings)() for _ in range(2))

        # Every job must report filters equal to the caller's.
        for job_filters in seen_in_jobs:
            assert job_filters == main_filters


def test_filter_warning_propagates_no_side_effect_with_loky_backend():
    """Filters propagated into loky workers must not outlive the call."""
    with warnings.catch_warnings():
        warnings.simplefilter("error", category=ConvergenceWarning)

        # First run goes through sklearn's Parallel/delayed, which propagate
        # the "error" filter into the workers.
        sklearn_tasks = (delayed(time.sleep)(0) for _ in range(10))
        Parallel(n_jobs=2, backend="loky")(sklearn_tasks)

        # Since loky workers are reused, make sure that inside the loky workers,
        # warnings filters have been reset to their original value. Using joblib
        # directly should not turn ConvergenceWarning into an error.
        plain_joblib_tasks = (
            joblib.delayed(warnings.warn)("Convergence warning", ConvergenceWarning)
            for _ in range(10)
        )
        joblib.Parallel(n_jobs=2, backend="loky")(plain_joblib_tasks)

0 comments on commit 10253eb

Please sign in to comment.