Fix warning from torch.load starting in torch 2.4 #1064

Merged: 5 commits, Sep 19, 2024

Changes from 2 commits
36 changes: 35 additions & 1 deletion skorch/net.py
@@ -46,6 +46,7 @@
from skorch.utils import to_device
from skorch.utils import to_numpy
from skorch.utils import to_tensor
from skorch.utils import get_torch_load_kwargs


# pylint: disable=too-many-instance-attributes
@@ -235,6 +236,33 @@ class NeuralNet:
callbacks.
Implementation note: It is the job of the callbacks to honor this setting.

torch_load_kwargs : dict or None (default=None)
Additional arguments that will be passed to torch.load when loading pickled
parameters.

In particular, this is important because PyTorch will switch (probably
in version 2.6.0) to only allowing weights to be loaded, for security reasons
(i.e. weights_only switches from False to True). As a consequence, loading
pickled parameters may raise an error after upgrading torch because some
types are used that are considered insecure. In skorch, we will also make
that switch at the same time. To resolve the error, follow the
instructions in the torch error message to designate the offending types
as secure. Only do this if you trust the source of the file.

If you want to keep loading non-weight types the same way as before,
please pass:

torch_load_kwargs={'weights_only': False}

You should be aware that this is considered insecure and should only be
used if you trust the source of the file. However, this does not introduce
new insecurities; rather, it corresponds to the status quo from before
torch made the switch.

Another way to avoid this issue is to pass use_safetensors=True when
calling save_params and load_params. This avoids using pickle in favor of
the safetensors format, which is secure by design.

Attributes
----------
prefixes_ : list of str
@@ -311,6 +339,7 @@ def __init__(
device='cpu',
compile=False,
use_caching='auto',
torch_load_kwargs=None,
**kwargs
):
self.module = module
@@ -330,6 +359,7 @@
self.device = device
self.compile = compile
self.use_caching = use_caching
self.torch_load_kwargs = torch_load_kwargs

self._check_deprecated_params(**kwargs)
history = kwargs.pop('history', None)
@@ -2620,10 +2650,14 @@ def _get_state_dict(f_name):

return state_dict
else:
torch_load_kwargs = self.torch_load_kwargs
if torch_load_kwargs is None:
torch_load_kwargs = get_torch_load_kwargs()
Suggested change (review comment):
- torch_load_kwargs = get_torch_load_kwargs()
+ torch_load_kwargs = get_default_torch_load_kwargs()


def _get_state_dict(f_name):
map_location = get_map_location(self.device)
self.device = self._check_device(self.device, map_location)
return torch.load(f_name, map_location=map_location)
return torch.load(f_name, map_location=map_location, **torch_load_kwargs)

kwargs_full = {}
if checkpoint is not None:
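The torch_load_kwargs docstring above describes two ways of dealing with the upcoming weights_only switch: opting out explicitly or switching to safetensors. As a minimal sketch of how that looks from user code (the module class and file names are placeholders, and it assumes a skorch build that includes this change):

import torch
from torch import nn
from skorch import NeuralNetClassifier

class MyModule(nn.Module):
    # Placeholder module, just enough to let the net initialize.
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(20, 2)

    def forward(self, X):
        return torch.softmax(self.dense(X), dim=-1)

# Explicitly keep the pre-2.6 torch.load behaviour; only do this for trusted files.
net = NeuralNetClassifier(MyModule, torch_load_kwargs={'weights_only': False})
net.initialize()
net.save_params(f_params='params.pt')
net.load_params(f_params='params.pt')

# Alternative that sidesteps pickle entirely: the safetensors format.
net.save_params(f_params='params.safetensors', use_safetensors=True)
net.load_params(f_params='params.safetensors', use_safetensors=True)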
81 changes: 81 additions & 0 deletions skorch/tests/test_net.py
@@ -16,6 +16,7 @@
from unittest.mock import patch
import sys
import time
import warnings
from contextlib import ExitStack

from flaky import flaky
@@ -30,6 +31,7 @@
import torch
from torch import nn

import skorch
from skorch.tests.conftest import INFERENCE_METHODS
from skorch.utils import flatten
from skorch.utils import to_numpy
@@ -561,6 +563,17 @@ def test_load_params_unknown_attribute_raises(self, net_fit):
with pytest.raises(AttributeError, match=msg):
net_fit.load_params(f_unknown='some-file.pt')

def test_load_params_no_warning(self, net_fit, tmp_path, recwarn):
# See discussion in 1063
# Ensure that there is no FutureWarning (and DeprecationWarning for good
# measure) caused by torch.load.
net_fit.save_params(f_params=tmp_path / 'weights.pt')
net_fit.load_params(f_params=tmp_path / 'weights.pt')
assert not any(
isinstance(warning.message, (DeprecationWarning, FutureWarning))
for warning in recwarn.list
)

@pytest.mark.parametrize('use_safetensors', [False, True])
def test_save_load_state_dict_file(
self, net_cls, module_cls, net_fit, data, tmpdir, use_safetensors):
@@ -2983,6 +2996,74 @@ def test_save_load_state_dict_custom_module(
weights_loaded = net_new.custom_.state_dict()['sequential.3.weight']
assert (weights_before == weights_loaded).all()

def test_torch_load_kwargs_auto_weights_only_false_when_load_params(
self, net_cls, module_cls, monkeypatch, tmp_path
):
# Here we assume that the torch version is low enough that weights_only
# defaults to False. Check that when no argument is set in skorch, the
# right default is used.
# See discussion in 1063
net = net_cls(module_cls).initialize()
net.save_params(f_params=tmp_path / 'params.pkl')
state_dict = net.module_.state_dict()
expected_kwargs = {"weights_only": False}

mock_torch_load = Mock(return_value=state_dict)
monkeypatch.setattr(torch, "load", mock_torch_load)
monkeypatch.setattr(
skorch.net, "get_torch_load_kwargs", lambda: expected_kwargs
)

net.load_params(f_params=tmp_path / 'params.pkl')

call_kwargs = mock_torch_load.call_args_list[0].kwargs
del call_kwargs['map_location'] # we're not interested in that
assert call_kwargs == expected_kwargs

def test_torch_load_kwargs_auto_weights_only_true_when_load_params(
self, net_cls, module_cls, monkeypatch, tmp_path
):
# Here we assume that the torch version is high enough that weights_only
# defaults to True. Check that when no argument is set in skorch, the
# right default is used.
# See discussion in 1063
net = net_cls(module_cls).initialize()
net.save_params(f_params=tmp_path / 'params.pkl')
state_dict = net.module_.state_dict()
expected_kwargs = {"weights_only": True}

mock_torch_load = Mock(return_value=state_dict)
monkeypatch.setattr(torch, "load", mock_torch_load)
monkeypatch.setattr(
skorch.net, "get_torch_load_kwargs", lambda: expected_kwargs
)

net.load_params(f_params=tmp_path / 'params.pkl')

call_kwargs = mock_torch_load.call_args_list[0].kwargs
del call_kwargs['map_location'] # we're not interested in that
assert call_kwargs == expected_kwargs

def test_torch_load_kwargs_forwarded_to_torch_load(
self, net_cls, module_cls, monkeypatch, tmp_path
):
# Here we check that custom set torch load args are forwarded to
# torch.load.
# See discussion in 1063
expected_kwargs = {'weights_only': 123, 'foo': 'bar'}
net = net_cls(module_cls, torch_load_kwargs=expected_kwargs).initialize()
net.save_params(f_params=tmp_path / 'params.pkl')
state_dict = net.module_.state_dict()

mock_torch_load = Mock(return_value=state_dict)
monkeypatch.setattr(torch, "load", mock_torch_load)

net.load_params(f_params=tmp_path / 'params.pkl')

call_kwargs = mock_torch_load.call_args_list[0].kwargs
del call_kwargs['map_location'] # we're not interested in that
assert call_kwargs == expected_kwargs

def test_custom_module_params_passed_to_optimizer(
self, net_custom_module_cls, module_cls):
# custom module parameters should automatically be passed to the optimizer
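Taken together, these tests pin down the resolution order: a torch_load_kwargs value passed to the net takes precedence, and only when it is None does skorch fall back to the version-dependent default. A small illustrative sketch of that logic (not part of the PR; resolve_torch_load_kwargs is a made-up name):

from skorch.utils import get_torch_load_kwargs  # helper added in skorch/utils.py below

def resolve_torch_load_kwargs(torch_load_kwargs):
    # An explicit user setting wins; otherwise use the version-based default.
    if torch_load_kwargs is not None:
        return torch_load_kwargs
    return get_torch_load_kwargs()

assert resolve_torch_load_kwargs({'weights_only': False}) == {'weights_only': False}
assert 'weights_only' in resolve_torch_load_kwargs(None)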
16 changes: 16 additions & 0 deletions skorch/utils.py
@@ -28,6 +28,7 @@

from skorch.exceptions import DeviceWarning
from skorch.exceptions import NotInitializedError
from ._version import Version

try:
import torch_geometric
@@ -768,3 +769,18 @@ def _check_f_arguments(caller_name, **kwargs):
key = 'module_' if key == 'f_params' else key[2:] + '_'
kwargs_module[key] = val
return kwargs_module, kwargs_other


def get_torch_load_kwargs():
Suggested change (review comment):
- def get_torch_load_kwargs():
+ def get_default_torch_load_kwargs():

"""Returns the kwargs passed to torch.load the correspond to the current
Suggested change (review comment):
- """Returns the kwargs passed to torch.load the correspond to the current
+ """Returns the kwargs passed to torch.load that correspond to the current

torch version.

The plan is to switch from weights_only=False to True in PyTorch version
2.6.0, but depending on what happens, this may require updating.

"""
version_torch = Version(torch.__version__)
version_default_switch = Version('2.6.0')
if version_torch >= version_default_switch:
return {"weights_only": True}
return {"weights_only": False}