diff --git a/skorch/net.py b/skorch/net.py
index 32dfec7b..6b7748be 100644
--- a/skorch/net.py
+++ b/skorch/net.py
@@ -46,6 +46,7 @@
 from skorch.utils import to_device
 from skorch.utils import to_numpy
 from skorch.utils import to_tensor
+from skorch.utils import get_default_torch_load_kwargs


 # pylint: disable=too-many-instance-attributes
@@ -235,6 +236,33 @@ class NeuralNet:
       callbacks.
       Implementation note: It is the job of the callbacks to honor this
       setting.

+    torch_load_kwargs : dict or None (default=None)
+      Additional arguments that will be passed to torch.load when loading
+      pickled parameters.
+
+      In particular, this is important because PyTorch will switch (probably
+      in version 2.6.0) to only allow weights to be loaded for security reasons
+      (i.e. weights_only switches from False to True). As a consequence, loading
+      pickled parameters may raise an error after upgrading torch because some
+      types are used that are considered insecure. In skorch, we will also make
+      that switch at the same time. To resolve the error, follow the
+      instructions in the torch error message to designate the offending types
+      as secure. Only do this if you trust the source of the file.
+
+      If you want to keep loading non-weight types the same way as before,
+      please pass:
+
+      torch_load_kwargs={'weights_only': False}
+
+      You should be aware that this is considered insecure and should only be
+      used if you trust the source of the file. However, this does not introduce
+      new insecurities; rather, it corresponds to the status quo from before
+      torch made the switch.
+
+      Another way to avoid this issue is to pass use_safetensors=True when
+      calling save_params and load_params. This avoids using pickle in favor of
+      the safetensors format, which is secure by design.
+
     Attributes
     ----------
     prefixes_ : list of str
@@ -311,6 +339,7 @@ def __init__(
             device='cpu',
             compile=False,
             use_caching='auto',
+            torch_load_kwargs=None,
             **kwargs
     ):
         self.module = module
@@ -330,6 +359,7 @@
         self.device = device
         self.compile = compile
         self.use_caching = use_caching
+        self.torch_load_kwargs = torch_load_kwargs

         self._check_deprecated_params(**kwargs)
         history = kwargs.pop('history', None)
@@ -2620,10 +2650,14 @@ def _get_state_dict(f_name):
                 return state_dict
         else:
+            torch_load_kwargs = self.torch_load_kwargs
+            if torch_load_kwargs is None:
+                torch_load_kwargs = get_default_torch_load_kwargs()
+
             def _get_state_dict(f_name):
                 map_location = get_map_location(self.device)
                 self.device = self._check_device(self.device, map_location)
-                return torch.load(f_name, map_location=map_location)
+                return torch.load(f_name, map_location=map_location, **torch_load_kwargs)

         kwargs_full = {}
         if checkpoint is not None:
diff --git a/skorch/tests/test_net.py b/skorch/tests/test_net.py
index 17704e1b..768ef00b 100644
--- a/skorch/tests/test_net.py
+++ b/skorch/tests/test_net.py
@@ -16,6 +16,7 @@
 from unittest.mock import patch
 import sys
 import time
+import warnings

 from contextlib import ExitStack
 from flaky import flaky
@@ -30,6 +31,7 @@
 import torch
 from torch import nn

+import skorch
 from skorch.tests.conftest import INFERENCE_METHODS
 from skorch.utils import flatten
 from skorch.utils import to_numpy
@@ -561,6 +563,17 @@ def test_load_params_unknown_attribute_raises(self, net_fit):
         with pytest.raises(AttributeError, match=msg):
             net_fit.load_params(f_unknown='some-file.pt')

+    def test_load_params_no_warning(self, net_fit, tmp_path, recwarn):
+        # See discussion in 1063
+        # Ensure that there is no FutureWarning (and DeprecationWarning for good
+        # measure) caused by torch.load.
+        net_fit.save_params(f_params=tmp_path / 'weights.pt')
+        net_fit.load_params(f_params=tmp_path / 'weights.pt')
+        assert not any(
+            isinstance(warning.message, (DeprecationWarning, FutureWarning))
+            for warning in recwarn.list
+        )
+
     @pytest.mark.parametrize('use_safetensors', [False, True])
     def test_save_load_state_dict_file(
             self, net_cls, module_cls, net_fit, data, tmpdir, use_safetensors):
@@ -2983,6 +2996,101 @@ def test_save_load_state_dict_custom_module(
         weights_loaded = net_new.custom_.state_dict()['sequential.3.weight']
         assert (weights_before == weights_loaded).all()

+    def test_torch_load_kwargs_auto_weights_only_false_when_load_params(
+            self, net_cls, module_cls, monkeypatch, tmp_path
+    ):
+        # Here we assume that the torch version is low enough that weights_only
+        # defaults to False. Check that when no argument is set in skorch, the
+        # right default is used.
+        # See discussion in 1063
+        net = net_cls(module_cls).initialize()
+        net.save_params(f_params=tmp_path / 'params.pkl')
+        state_dict = net.module_.state_dict()
+        expected_kwargs = {"weights_only": False}
+
+        mock_torch_load = Mock(return_value=state_dict)
+        monkeypatch.setattr(torch, "load", mock_torch_load)
+        monkeypatch.setattr(
+            skorch.net, "get_default_torch_load_kwargs", lambda: expected_kwargs
+        )
+
+        net.load_params(f_params=tmp_path / 'params.pkl')
+
+        call_kwargs = mock_torch_load.call_args_list[0].kwargs
+        del call_kwargs['map_location']  # we're not interested in that
+        assert call_kwargs == expected_kwargs
+
+    def test_torch_load_kwargs_auto_weights_only_true_when_load_params(
+            self, net_cls, module_cls, monkeypatch, tmp_path
+    ):
+        # Here we assume that the torch version is high enough that weights_only
+        # defaults to True. Check that when no argument is set in skorch, the
+        # right default is used.
+        # See discussion in 1063
+        net = net_cls(module_cls).initialize()
+        net.save_params(f_params=tmp_path / 'params.pkl')
+        state_dict = net.module_.state_dict()
+        expected_kwargs = {"weights_only": True}
+
+        mock_torch_load = Mock(return_value=state_dict)
+        monkeypatch.setattr(torch, "load", mock_torch_load)
+        monkeypatch.setattr(
+            skorch.net, "get_default_torch_load_kwargs", lambda: expected_kwargs
+        )
+
+        net.load_params(f_params=tmp_path / 'params.pkl')
+
+        call_kwargs = mock_torch_load.call_args_list[0].kwargs
+        del call_kwargs['map_location']  # we're not interested in that
+        assert call_kwargs == expected_kwargs
+
+    def test_torch_load_kwargs_forwarded_to_torch_load(
+            self, net_cls, module_cls, monkeypatch, tmp_path
+    ):
+        # Here we check that custom-set torch load args are forwarded to
+        # torch.load.
+        # See discussion in 1063
+        expected_kwargs = {'weights_only': 123, 'foo': 'bar'}
+        net = net_cls(module_cls, torch_load_kwargs=expected_kwargs).initialize()
+        net.save_params(f_params=tmp_path / 'params.pkl')
+        state_dict = net.module_.state_dict()
+
+        mock_torch_load = Mock(return_value=state_dict)
+        monkeypatch.setattr(torch, "load", mock_torch_load)
+
+        net.load_params(f_params=tmp_path / 'params.pkl')
+
+        call_kwargs = mock_torch_load.call_args_list[0].kwargs
+        del call_kwargs['map_location']  # we're not interested in that
+        assert call_kwargs == expected_kwargs
+
+    def test_torch_load_kwargs_auto_weights_false_pytorch_lt_2_6(
+            self, net_cls, module_cls, monkeypatch, tmp_path
+    ):
+        # Same test as
+        # test_torch_load_kwargs_auto_weights_only_false_when_load_params but
+        # without monkeypatching get_default_torch_load_kwargs. There is no
+        # corresponding test for >= 2.6.0 since it's not clear yet if the switch
+        # will be made in that version.
+        # See discussion in 1063.
+        from skorch._version import Version
+
+        if Version(torch.__version__) >= Version('2.6.0'):
+            pytest.skip("Test only for torch < v2.6.0")
+
+        net = net_cls(module_cls).initialize()
+        net.save_params(f_params=tmp_path / 'params.pkl')
+        state_dict = net.module_.state_dict()
+        expected_kwargs = {"weights_only": False}
+
+        mock_torch_load = Mock(return_value=state_dict)
+        monkeypatch.setattr(torch, "load", mock_torch_load)
+        net.load_params(f_params=tmp_path / 'params.pkl')
+
+        call_kwargs = mock_torch_load.call_args_list[0].kwargs
+        del call_kwargs['map_location']  # we're not interested in that
+        assert call_kwargs == expected_kwargs
+
     def test_custom_module_params_passed_to_optimizer(
             self, net_custom_module_cls, module_cls):
         # custom module parameters should automatically be passed to the optimizer
diff --git a/skorch/utils.py b/skorch/utils.py
index de679ec3..851936db 100644
--- a/skorch/utils.py
+++ b/skorch/utils.py
@@ -28,6 +28,7 @@
 from skorch.exceptions import DeviceWarning
 from skorch.exceptions import NotInitializedError
+from ._version import Version


 try:
     import torch_geometric
@@ -768,3 +769,18 @@ def _check_f_arguments(caller_name, **kwargs):
         key = 'module_' if key == 'f_params' else key[2:] + '_'
         kwargs_module[key] = val
     return kwargs_module, kwargs_other
+
+
+def get_default_torch_load_kwargs():
+    """Returns the kwargs passed to torch.load that correspond to the current
+    torch version.
+
+    The plan is to switch from weights_only=False to True in PyTorch version
+    2.6.0, but depending on what happens, this may require updating.
+
+    """
+    version_torch = Version(torch.__version__)
+    version_default_switch = Version('2.6.0')
+    if version_torch >= version_default_switch:
+        return {"weights_only": True}
+    return {"weights_only": False}
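Usage sketch (not part of the diff): a minimal illustration of the new torch_load_kwargs option and the use_safetensors alternative mentioned in the docstring above. The module class and file names below are made up for the example; NeuralNetClassifier, save_params, load_params, and use_safetensors are existing skorch API, torch_load_kwargs comes from the changes in this diff.

# Minimal usage sketch for torch_load_kwargs; module and file names are invented.
import torch
from torch import nn
from skorch import NeuralNetClassifier


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(20, 2)

    def forward(self, X):
        return torch.softmax(self.dense(X), dim=-1)


# Explicitly keep the pre-switch pickle behavior; only do this for trusted files.
net = NeuralNetClassifier(MyModule, torch_load_kwargs={'weights_only': False})
net.initialize()
net.save_params(f_params='params.pkl')
net.load_params(f_params='params.pkl')

# Alternative that avoids pickle entirely: the safetensors format.
net.save_params(f_params='params.safetensors', use_safetensors=True)
net.load_params(f_params='params.safetensors', use_safetensors=True)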