From 97a81eaae163073e99e14106271288fe7b427987 Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Thu, 22 Aug 2024 16:29:59 +0200
Subject: [PATCH 1/4] Fix warning from torch.load starting in torch 2.4

See discussion in #1063

Starting from PyTorch 2.4, there is a warning when torch.load is called
without setting the weights_only argument. This is because in the future,
the default will switch from False to True, which can result in a lot of
errors when trying to load torch files (which are pickle files and thus
insecure).

In this PR, we add the possibility for the user to influence the kwargs
passed to torch.load so that they can control that behavior. If not
further indicated by the user, we will use the same defaults as the
installed torch version. Therefore, users will only encounter this issue
via skorch if they would have encountered it via torch anyway.

Since it's not 100% certain whether the default will switch in torch
2.6.0, we may have to adjust the version check in the future.

Besides directly testing the kwargs being passed on, a test was also added
to ensure that net.load_params does not give any warnings. This is already
indirectly tested through some accelerate tests that are currently failing
with torch 2.4, but it's better to have an explicit test.

After this is merged, the CI should pass when using torch 2.4.0.
---
 skorch/net.py            | 39 ++++++++++++++++++-
 skorch/tests/test_net.py | 81 ++++++++++++++++++++++++++++++++++++++++
 skorch/utils.py          | 14 +++++++
 3 files changed, 133 insertions(+), 1 deletion(-)

diff --git a/skorch/net.py b/skorch/net.py
index 32dfec7b2..e0a4a632e 100644
--- a/skorch/net.py
+++ b/skorch/net.py
@@ -46,6 +46,7 @@
 from skorch.utils import to_device
 from skorch.utils import to_numpy
 from skorch.utils import to_tensor
+from skorch.utils import check_torch_weights_only_default_true
 
 
 # pylint: disable=too-many-instance-attributes
@@ -235,6 +236,33 @@ class NeuralNet:
       callbacks. Implementation note: It is the job of the callbacks to
      honor this setting.
 
+    torch_load_kwargs : dict or None (default=None)
+      Additional arguments that will be passed to torch.load when loading
+      pickled parameters.
+
+      In particular, this is important because PyTorch will switch (probably
+      in version 2.6.0) to only allow weights to be loaded for security
+      reasons (i.e. weights_only switches from False to True). As a
+      consequence, loading pickled parameters may raise an error after
+      upgrading torch because some types are used that are considered
+      insecure. In skorch, we will also make that switch at the same time.
+      To resolve the error, follow the instructions in the torch error
+      message to designate the offending types as secure. Only do this if
+      you trust the source of the file.
+
+      If you want to keep loading non-weight types the same way as before,
+      please pass:
+
+      torch_load_kwargs={'weights_only': False}
+
+      You should be aware that this is considered insecure and should only
+      be used if you trust the source of the file. However, this does not
+      introduce new insecurities; it rather corresponds to the status quo
+      from before torch made the switch.
+
+      Another way to avoid this issue is to pass use_safetensors=True when
+      calling save_params and load_params. This avoids using pickle in
+      favor of the safetensors format, which is secure by design.
+
     Attributes
     ----------
     prefixes_ : list of str
@@ -311,6 +339,7 @@ def __init__(
             device='cpu',
             compile=False,
             use_caching='auto',
+            torch_load_kwargs=None,
             **kwargs
     ):
         self.module = module
@@ -330,6 +359,7 @@
         self.device = device
         self.compile = compile
         self.use_caching = use_caching
+        self.torch_load_kwargs = torch_load_kwargs
 
         self._check_deprecated_params(**kwargs)
         history = kwargs.pop('history', None)
@@ -2620,10 +2650,17 @@ def _get_state_dict(f_name):
                 return state_dict
 
         else:
+            torch_load_kwargs = self.torch_load_kwargs
+            if torch_load_kwargs is None:
+                if check_torch_weights_only_default_true():
+                    torch_load_kwargs = {"weights_only": True}
+                else:
+                    torch_load_kwargs = {"weights_only": False}
+
             def _get_state_dict(f_name):
                 map_location = get_map_location(self.device)
                 self.device = self._check_device(self.device, map_location)
-                return torch.load(f_name, map_location=map_location)
+                return torch.load(f_name, map_location=map_location, **torch_load_kwargs)
 
         kwargs_full = {}
         if checkpoint is not None:
diff --git a/skorch/tests/test_net.py b/skorch/tests/test_net.py
index 17704e1bd..38e99632a 100644
--- a/skorch/tests/test_net.py
+++ b/skorch/tests/test_net.py
@@ -16,6 +16,7 @@
 from unittest.mock import patch
 import sys
 import time
+import warnings
 from contextlib import ExitStack
 
 from flaky import flaky
@@ -30,6 +31,7 @@
 import torch
 from torch import nn
 
+import skorch
 from skorch.tests.conftest import INFERENCE_METHODS
 from skorch.utils import flatten
 from skorch.utils import to_numpy
@@ -561,6 +563,17 @@ def test_load_params_unknown_attribute_raises(self, net_fit):
         with pytest.raises(AttributeError, match=msg):
             net_fit.load_params(f_unknown='some-file.pt')
 
+    def test_load_params_no_warning(self, net_fit, tmp_path, recwarn):
+        # See discussion in 1063
+        # Ensure that there is no FutureWarning (and DeprecationWarning for good
+        # measure) caused by torch.load.
+        net_fit.save_params(f_params=tmp_path / 'weights.pt')
+        net_fit.load_params(f_params=tmp_path / 'weights.pt')
+        assert not any(
+            isinstance(warning.message, (DeprecationWarning, FutureWarning))
+            for warning in recwarn.list
+        )
+
     @pytest.mark.parametrize('use_safetensors', [False, True])
     def test_save_load_state_dict_file(
             self, net_cls, module_cls, net_fit, data, tmpdir, use_safetensors):
@@ -2983,6 +2996,74 @@ def test_save_load_state_dict_custom_module(
         weights_loaded = net_new.custom_.state_dict()['sequential.3.weight']
         assert (weights_before == weights_loaded).all()
 
+    def test_torch_load_kwargs_auto_weights_only_false_when_load_params(
+            self, net_cls, module_cls, monkeypatch, tmp_path
+    ):
+        # Here we assume that the torch version is low enough that weights_only
+        # defaults to False. Check that when no argument is set in skorch, the
+        # right default is used.
+        # See discussion in 1063
+        net = net_cls(module_cls).initialize()
+        net.save_params(f_params=tmp_path / 'params.pkl')
+        state_dict = net.module_.state_dict()
+
+        mock_torch_load = Mock(return_value=state_dict)
+        monkeypatch.setattr(torch, "load", mock_torch_load)
+        monkeypatch.setattr(
+            skorch.net, "check_torch_weights_only_default_true", lambda: False
+        )
+
+        net.load_params(f_params=tmp_path / 'params.pkl')
+
+        call_kwargs = mock_torch_load.call_args_list[0].kwargs
+        del call_kwargs['map_location']  # we're not interested in that
+        expected_kwargs = {"weights_only": False}
+        assert call_kwargs == expected_kwargs
+
+    def test_torch_load_kwargs_auto_weights_only_true_when_load_params(
+            self, net_cls, module_cls, monkeypatch, tmp_path
+    ):
+        # Here we assume that the torch version is high enough that weights_only
+        # defaults to True. Check that when no argument is set in skorch, the
+        # right default is used.
+        # See discussion in 1063
+        net = net_cls(module_cls).initialize()
+        net.save_params(f_params=tmp_path / 'params.pkl')
+        state_dict = net.module_.state_dict()
+
+        mock_torch_load = Mock(return_value=state_dict)
+        monkeypatch.setattr(torch, "load", mock_torch_load)
+        monkeypatch.setattr(
+            skorch.net, "check_torch_weights_only_default_true", lambda: True
+        )
+
+        net.load_params(f_params=tmp_path / 'params.pkl')
+
+        call_kwargs = mock_torch_load.call_args_list[0].kwargs
+        del call_kwargs['map_location']  # we're not interested in that
+        expected_kwargs = {"weights_only": True}
+        assert call_kwargs == expected_kwargs
+
+    def test_torch_load_kwargs_forwarded_to_torch_load(
+            self, net_cls, module_cls, monkeypatch, tmp_path
+    ):
+        # Here we check that custom set torch load args are forwarded to
+        # torch.load.
+        # See discussion in 1063
+        torch_load_kwargs = {'weights_only': 123, 'foo': 'bar'}
+        net = net_cls(module_cls, torch_load_kwargs=torch_load_kwargs).initialize()
+        net.save_params(f_params=tmp_path / 'params.pkl')
+        state_dict = net.module_.state_dict()
+
+        mock_torch_load = Mock(return_value=state_dict)
+        monkeypatch.setattr(torch, "load", mock_torch_load)
+
+        net.load_params(f_params=tmp_path / 'params.pkl')
+
+        call_kwargs = mock_torch_load.call_args_list[0].kwargs
+        del call_kwargs['map_location']  # we're not interested in that
+        assert call_kwargs == torch_load_kwargs
+
     def test_custom_module_params_passed_to_optimizer(
             self, net_custom_module_cls, module_cls):
         # custom module parameters should automatically be passed to the optimizer
diff --git a/skorch/utils.py b/skorch/utils.py
index de679ec35..6fa61934c 100644
--- a/skorch/utils.py
+++ b/skorch/utils.py
@@ -28,6 +28,7 @@
 
 from skorch.exceptions import DeviceWarning
 from skorch.exceptions import NotInitializedError
+from ._version import Version
 
 try:
     import torch_geometric
@@ -768,3 +769,16 @@ def _check_f_arguments(caller_name, **kwargs):
         key = 'module_' if key == 'f_params' else key[2:] + '_'
         kwargs_module[key] = val
     return kwargs_module, kwargs_other
+
+
+def check_torch_weights_only_default_true():
+    """Check if the version of torch is one that made the switch to
+    weights_only=True in torch.load
+
+    The planned switch is PyTorch version 2.6.0, but depending on what happens,
+    this may require updating
+
+    """
+    version_torch = Version(torch.__version__)
+    version_default_switch = Version('2.6.0')
+    return version_torch >= version_default_switch
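To make the new parameter concrete, here is a minimal usage sketch of what
this patch enables; the module choice and file name are illustrative
placeholders, and weights_only=False should only be passed for files from a
trusted source:

    from torch import nn
    from skorch import NeuralNetRegressor

    # Opt in explicitly to the old pickle-based behavior. Because
    # weights_only is now set rather than left unset, torch.load no longer
    # emits the FutureWarning on torch 2.4.
    net = NeuralNetRegressor(
        nn.Linear(10, 1),
        torch_load_kwargs={'weights_only': False},
    )
    net.initialize()
    net.save_params(f_params='weights.pt')
    net.load_params(f_params='weights.pt')

Leaving torch_load_kwargs at its default of None instead makes skorch pick
the defaults that match the installed torch version via
check_torch_weights_only_default_true.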
From 9acfb8479db5546199adf99d3dbc3942844aa37d Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Fri, 23 Aug 2024 17:46:54 +0200
Subject: [PATCH 2/4] Reviewer feedback: return kwargs directly

---
 skorch/net.py            |  7 ++-----
 skorch/tests/test_net.py | 14 +++++++-------
 skorch/utils.py          | 14 ++++++++------
 3 files changed, 17 insertions(+), 18 deletions(-)

diff --git a/skorch/net.py b/skorch/net.py
index e0a4a632e..05e5ec88d 100644
--- a/skorch/net.py
+++ b/skorch/net.py
@@ -46,7 +46,7 @@
 from skorch.utils import to_device
 from skorch.utils import to_numpy
 from skorch.utils import to_tensor
-from skorch.utils import check_torch_weights_only_default_true
+from skorch.utils import get_torch_load_kwargs
 
 
 # pylint: disable=too-many-instance-attributes
@@ -2652,10 +2652,7 @@ def _get_state_dict(f_name):
         else:
             torch_load_kwargs = self.torch_load_kwargs
             if torch_load_kwargs is None:
-                if check_torch_weights_only_default_true():
-                    torch_load_kwargs = {"weights_only": True}
-                else:
-                    torch_load_kwargs = {"weights_only": False}
+                torch_load_kwargs = get_torch_load_kwargs()
 
             def _get_state_dict(f_name):
                 map_location = get_map_location(self.device)
diff --git a/skorch/tests/test_net.py b/skorch/tests/test_net.py
index 38e99632a..d54d3f9e1 100644
--- a/skorch/tests/test_net.py
+++ b/skorch/tests/test_net.py
@@ -3006,18 +3006,18 @@ def test_torch_load_kwargs_auto_weights_only_false_when_load_params(
         net = net_cls(module_cls).initialize()
         net.save_params(f_params=tmp_path / 'params.pkl')
         state_dict = net.module_.state_dict()
+        expected_kwargs = {"weights_only": False}
 
         mock_torch_load = Mock(return_value=state_dict)
         monkeypatch.setattr(torch, "load", mock_torch_load)
         monkeypatch.setattr(
-            skorch.net, "check_torch_weights_only_default_true", lambda: False
+            skorch.net, "get_torch_load_kwargs", lambda: expected_kwargs
         )
 
         net.load_params(f_params=tmp_path / 'params.pkl')
 
         call_kwargs = mock_torch_load.call_args_list[0].kwargs
         del call_kwargs['map_location']  # we're not interested in that
-        expected_kwargs = {"weights_only": False}
         assert call_kwargs == expected_kwargs
 
     def test_torch_load_kwargs_auto_weights_only_true_when_load_params(
@@ -3030,18 +3030,18 @@ def test_torch_load_kwargs_auto_weights_only_true_when_load_params(
         net = net_cls(module_cls).initialize()
         net.save_params(f_params=tmp_path / 'params.pkl')
         state_dict = net.module_.state_dict()
+        expected_kwargs = {"weights_only": True}
 
         mock_torch_load = Mock(return_value=state_dict)
         monkeypatch.setattr(torch, "load", mock_torch_load)
         monkeypatch.setattr(
-            skorch.net, "check_torch_weights_only_default_true", lambda: True
+            skorch.net, "get_torch_load_kwargs", lambda: expected_kwargs
         )
 
         net.load_params(f_params=tmp_path / 'params.pkl')
 
         call_kwargs = mock_torch_load.call_args_list[0].kwargs
         del call_kwargs['map_location']  # we're not interested in that
-        expected_kwargs = {"weights_only": True}
         assert call_kwargs == expected_kwargs
 
     def test_torch_load_kwargs_forwarded_to_torch_load(
@@ -3050,8 +3050,8 @@ def test_torch_load_kwargs_forwarded_to_torch_load(
         # Here we check that custom set torch load args are forwarded to
         # torch.load.
         # See discussion in 1063
-        torch_load_kwargs = {'weights_only': 123, 'foo': 'bar'}
-        net = net_cls(module_cls, torch_load_kwargs=torch_load_kwargs).initialize()
+        expected_kwargs = {'weights_only': 123, 'foo': 'bar'}
+        net = net_cls(module_cls, torch_load_kwargs=expected_kwargs).initialize()
         net.save_params(f_params=tmp_path / 'params.pkl')
         state_dict = net.module_.state_dict()
 
@@ -3062,7 +3062,7 @@ def test_torch_load_kwargs_forwarded_to_torch_load(
 
         call_kwargs = mock_torch_load.call_args_list[0].kwargs
         del call_kwargs['map_location']  # we're not interested in that
-        assert call_kwargs == torch_load_kwargs
+        assert call_kwargs == expected_kwargs
 
     def test_custom_module_params_passed_to_optimizer(
             self, net_custom_module_cls, module_cls):
diff --git a/skorch/utils.py b/skorch/utils.py
index 6fa61934c..b1320c2a2 100644
--- a/skorch/utils.py
+++ b/skorch/utils.py
@@ -771,14 +771,16 @@ def _check_f_arguments(caller_name, **kwargs):
     return kwargs_module, kwargs_other
 
 
-def check_torch_weights_only_default_true():
-    """Check if the version of torch is one that made the switch to
-    weights_only=True in torch.load
+def get_torch_load_kwargs():
+    """Returns the kwargs passed to torch.load the correspond to the current
+    torch version.
 
-    The planned switch is PyTorch version 2.6.0, but depending on what happens,
-    this may require updating
+    The plan is to switch from weights_only=False to True in PyTorch version
+    2.6.0, but depending on what happens, this may require updating.
 
     """
     version_torch = Version(torch.__version__)
     version_default_switch = Version('2.6.0')
-    return version_torch >= version_default_switch
+    if version_torch >= version_default_switch:
+        return {"weights_only": True}
+    return {"weights_only": False}
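The refactored helper can be exercised on its own; a hypothetical REPL
session, where the version string and therefore the result depend on the
installed torch:

    >>> import torch
    >>> from skorch.utils import get_torch_load_kwargs
    >>> torch.__version__
    '2.4.0'
    >>> get_torch_load_kwargs()
    {'weights_only': False}

On torch >= 2.6.0 the same call would return {'weights_only': True}.
Returning the kwargs dict directly keeps the version logic in one place,
which is what this reviewer feedback asked for.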
From ab9c536516466f2fdb029caf71fd43f2d6e5d375 Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Mon, 2 Sep 2024 15:42:54 +0200
Subject: [PATCH 3/4] Reviewer feedback: One more test w/o monkeypatch

Instead, rely on the installed torch version and skip if it doesn't fit.
---
 skorch/tests/test_net.py | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/skorch/tests/test_net.py b/skorch/tests/test_net.py
index d54d3f9e1..b4f056c73 100644
--- a/skorch/tests/test_net.py
+++ b/skorch/tests/test_net.py
@@ -3064,6 +3064,32 @@ def test_torch_load_kwargs_forwarded_to_torch_load(
         del call_kwargs['map_location']  # we're not interested in that
         assert call_kwargs == expected_kwargs
 
+    def test_torch_load_kwargs_auto_weights_false_pytorch_lt_2_6(
+            self, net_cls, module_cls, monkeypatch, tmp_path
+    ):
+        # Same test as test_torch_load_kwargs_auto_weights_only_false_when_load_params
+        # but without monkeypatching get_torch_load_kwargs. There is no corresponding
+        # test for >= 2.6.0 since it's not clear yet if the switch will be made in that
+        # version.
+        # See discussion in 1063.
+        from skorch._version import Version
+
+        if Version(torch.__version__) >= Version('2.6.0'):
+            pytest.skip("Test only for torch < v2.6.0")
+
+        net = net_cls(module_cls).initialize()
+        net.save_params(f_params=tmp_path / 'params.pkl')
+        state_dict = net.module_.state_dict()
+        expected_kwargs = {"weights_only": False}
+
+        mock_torch_load = Mock(return_value=state_dict)
+        monkeypatch.setattr(torch, "load", mock_torch_load)
+        net.load_params(f_params=tmp_path / 'params.pkl')
+
+        call_kwargs = mock_torch_load.call_args_list[0].kwargs
+        del call_kwargs['map_location']  # we're not interested in that
+        assert call_kwargs == expected_kwargs
+
     def test_custom_module_params_passed_to_optimizer(
             self, net_custom_module_cls, module_cls):
         # custom module parameters should automatically be passed to the optimizer
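The Mock-based pattern these tests share can also be reproduced outside the
test suite; a rough standalone sketch (module and file name invented) that
checks which weights_only value skorch ends up passing on the installed
torch version:

    from unittest.mock import patch

    from torch import nn
    from skorch import NeuralNetRegressor

    net = NeuralNetRegressor(nn.Linear(10, 1)).initialize()
    net.save_params(f_params='params.pkl')
    state_dict = net.module_.state_dict()

    # patch torch.load so the call made inside load_params can be inspected
    with patch('torch.load', return_value=state_dict) as mock_load:
        net.load_params(f_params='params.pkl')

    # prints False on torch < 2.6.0, mirroring the skip-guarded test above
    print(mock_load.call_args.kwargs.get('weights_only'))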
From f4162acb55a3ec6bdf45aef3ec08811f383a3992 Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Tue, 3 Sep 2024 14:07:15 +0200
Subject: [PATCH 4/4] Reviewer feedback: rename function, fix typo

---
 skorch/net.py            |  4 ++--
 skorch/tests/test_net.py | 13 +++++++------
 skorch/utils.py          |  4 ++--
 3 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/skorch/net.py b/skorch/net.py
index 05e5ec88d..6b7748be9 100644
--- a/skorch/net.py
+++ b/skorch/net.py
@@ -46,7 +46,7 @@
 from skorch.utils import to_device
 from skorch.utils import to_numpy
 from skorch.utils import to_tensor
-from skorch.utils import get_torch_load_kwargs
+from skorch.utils import get_default_torch_load_kwargs
 
 
 # pylint: disable=too-many-instance-attributes
@@ -2652,7 +2652,7 @@ def _get_state_dict(f_name):
         else:
             torch_load_kwargs = self.torch_load_kwargs
             if torch_load_kwargs is None:
-                torch_load_kwargs = get_torch_load_kwargs()
+                torch_load_kwargs = get_default_torch_load_kwargs()
 
             def _get_state_dict(f_name):
                 map_location = get_map_location(self.device)
diff --git a/skorch/tests/test_net.py b/skorch/tests/test_net.py
index b4f056c73..768ef00b9 100644
--- a/skorch/tests/test_net.py
+++ b/skorch/tests/test_net.py
@@ -3011,7 +3011,7 @@ def test_torch_load_kwargs_auto_weights_only_false_when_load_params(
         mock_torch_load = Mock(return_value=state_dict)
         monkeypatch.setattr(torch, "load", mock_torch_load)
         monkeypatch.setattr(
-            skorch.net, "get_torch_load_kwargs", lambda: expected_kwargs
+            skorch.net, "get_default_torch_load_kwargs", lambda: expected_kwargs
         )
 
         net.load_params(f_params=tmp_path / 'params.pkl')
@@ -3035,7 +3035,7 @@ def test_torch_load_kwargs_auto_weights_only_true_when_load_params(
         mock_torch_load = Mock(return_value=state_dict)
         monkeypatch.setattr(torch, "load", mock_torch_load)
         monkeypatch.setattr(
-            skorch.net, "get_torch_load_kwargs", lambda: expected_kwargs
+            skorch.net, "get_default_torch_load_kwargs", lambda: expected_kwargs
         )
 
         net.load_params(f_params=tmp_path / 'params.pkl')
@@ -3067,10 +3067,11 @@ def test_torch_load_kwargs_auto_weights_false_pytorch_lt_2_6(
             self, net_cls, module_cls, monkeypatch, tmp_path
     ):
-        # Same test as test_torch_load_kwargs_auto_weights_only_false_when_load_params
-        # but without monkeypatching get_torch_load_kwargs. There is no corresponding
-        # test for >= 2.6.0 since it's not clear yet if the switch will be made in that
-        # version.
+        # Same test as
+        # test_torch_load_kwargs_auto_weights_only_false_when_load_params but
+        # without monkeypatching get_default_torch_load_kwargs. There is no
+        # corresponding test for >= 2.6.0 since it's not clear yet if the switch
+        # will be made in that version.
         # See discussion in 1063.
         from skorch._version import Version
 
         if Version(torch.__version__) >= Version('2.6.0'):
             pytest.skip("Test only for torch < v2.6.0")
diff --git a/skorch/utils.py b/skorch/utils.py
index b1320c2a2..851936dbe 100644
--- a/skorch/utils.py
+++ b/skorch/utils.py
@@ -771,8 +771,8 @@ def _check_f_arguments(caller_name, **kwargs):
     return kwargs_module, kwargs_other
 
 
-def get_torch_load_kwargs():
-    """Returns the kwargs passed to torch.load the correspond to the current
+def get_default_torch_load_kwargs():
+    """Returns the kwargs passed to torch.load that correspond to the current
     torch version.
 
     The plan is to switch from weights_only=False to True in PyTorch version
     2.6.0, but depending on what happens, this may require updating.
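As the docstring added in PATCH 1 points out, the weights_only question can
also be sidestepped entirely with safetensors; a short sketch of that
alternative (the file name is a placeholder, and the safetensors package is
assumed to be installed):

    from torch import nn
    from skorch import NeuralNetRegressor

    net = NeuralNetRegressor(nn.Linear(10, 1)).initialize()
    # safetensors stores raw tensors rather than pickles, so loading is
    # secure by design and unaffected by torch.load's weights_only default
    net.save_params(f_params='weights.safetensors', use_safetensors=True)
    net.load_params(f_params='weights.safetensors', use_safetensors=True)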