Flip default on weights_only (pytorch#137602)
mikaylagawarecki authored and Ryo-not-rio committed Dec 2, 2024
1 parent a2793c8 commit c1fab91
Showing 12 changed files with 88 additions and 42 deletions.
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/torchbench.txt
@@ -1 +1 @@
-e522b45cd4535b9dfe067aa68d7315755df38f48
+766a5e3a189384659fd35a68c3b17b88c761aaac
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/xla.txt
@@ -1 +1 @@
-2eb4a60ed14a38260b85b0c765161f0ce45be6d1
+f71c02d1f457d58371e013632efb016c01bd1866
15 changes: 10 additions & 5 deletions test/distributed/_shard/sharded_tensor/test_sharded_tensor.py
@@ -1245,7 +1245,8 @@ def test_state_dict(self):
module_load._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)

buffer.seek(0)
-state_dict_deser = torch.load(buffer)
+# weights_only=False as ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
+state_dict_deser = torch.load(buffer, weights_only=False)
module_load.load_state_dict(state_dict_deser, strict=False)

module_load._register_state_dict_hook(state_dict_hook)
@@ -1289,7 +1290,8 @@ def test_state_dict_new_group(self):

buffer.seek(0)
with load_with_process_group(pg):
-state_dict_deser = torch.load(buffer)
+# ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
+state_dict_deser = torch.load(buffer, weights_only=False)
module_load.load_state_dict(state_dict_deser, strict=False)

# Verify after load.
@@ -1361,20 +1363,23 @@ def test_load_state_dict_errors(self):
if self.rank != 0:
with self.assertRaisesRegex(RuntimeError, "Local rank at save time was"):
with load_with_process_group(pg):
-state_dict_deser = torch.load(buffer)
+# ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
+state_dict_deser = torch.load(buffer, weights_only=False)
else:
with self.assertRaisesRegex(
RuntimeError, "Local world size at save time was"
):
with load_with_process_group(pg):
-state_dict_deser = torch.load(buffer)
+# ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
+state_dict_deser = torch.load(buffer, weights_only=False)

dist.destroy_process_group()
buffer.seek(0)
with self.assertRaisesRegex(
RuntimeError, "Need to initialize default process group"
):
-state_dict_deser = torch.load(buffer)
+# ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
+state_dict_deser = torch.load(buffer, weights_only=False)
rpc.shutdown()

@with_comms
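The pattern in the hunks above repeats throughout this file: the checkpoints contain ShardedTensor objects, which are not on the weights_only allowlist, so the tests opt out explicitly. A minimal sketch of the distinction (the dict below is a stand-in payload, not the test's actual state dict):

import io

import torch

buffer = io.BytesIO()
torch.save({"weight": torch.randn(2, 2)}, buffer)

# Plain tensors and containers of tensors load fine under the flipped default ...
buffer.seek(0)
state_dict = torch.load(buffer)  # weights_only=True is now implied

# ... while checkpoints holding arbitrary objects (such as ShardedTensor) need
# the explicit opt-out, which is only appropriate for files you fully trust.
buffer.seek(0)
state_dict = torch.load(buffer, weights_only=False)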
23 changes: 22 additions & 1 deletion test/distributed/fsdp/test_fsdp_state_dict.py
@@ -16,6 +16,12 @@
Shard,
ShardedTensor,
)
+from torch.distributed._shard.sharded_tensor.metadata import (
+    MEM_FORMAT_ENCODING,
+    ShardedTensorMetadata,
+    TensorProperties,
+)
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec, ShardMetadata
from torch.distributed._state_dict_utils import (
_all_gather_sharded_tensor,
_gather_state_dict,
@@ -37,6 +43,7 @@
from torch.distributed.fsdp._common_utils import FSDP_PREFIX
from torch.distributed.fsdp._unshard_param_utils import FLAT_PARAM
from torch.distributed.fsdp.wrap import enable_wrap, ModuleWrapPolicy, wrap
+from torch.distributed.remote_device import _remote_device
from torch.nn import Linear, Module, TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel import DistributedDataParallel
from torch.optim import SGD
@@ -1160,7 +1167,21 @@ def test_torch_save_load(self):
checkpoint = io.BytesIO()
torch.save(state_dict, checkpoint)
checkpoint.seek(0)
-state_dict_saved = torch.load(checkpoint)
+with torch.serialization.safe_globals(
+    [
+        Shard,
+        ShardMetadata,
+        ShardedTensor,
+        ShardedTensorMetadata,
+        TensorProperties,
+        MEM_FORMAT_ENCODING,
+        _remote_device,
+        getattr,
+        ShardedTensor.ProcessGroupState,
+        ChunkShardingSpec,
+    ]
+):
+    state_dict_saved = torch.load(checkpoint)
for k, v in state_dict_saved.items():
if isinstance(v, ShardedTensor):
self.assertEqual(
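Rather than opting out, test_torch_save_load keeps the safe default and allowlists the classes the checkpoint actually contains via the torch.serialization.safe_globals context manager. A minimal sketch of the same mechanism, assuming a build where safe_globals is available; MyConfig is a hypothetical stand-in class, not part of the diff:

import io

import torch


class MyConfig:  # hypothetical stand-in for ShardedTensor and friends
    def __init__(self, lr=0.1):
        self.lr = lr


buffer = io.BytesIO()
torch.save({"config": MyConfig()}, buffer)
buffer.seek(0)

# The allowlist only applies inside the context manager.
with torch.serialization.safe_globals([MyConfig]):
    loaded = torch.load(buffer)  # succeeds with weights_only=True implied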
8 changes: 6 additions & 2 deletions test/dynamo/test_compile.py
@@ -46,7 +46,10 @@ def test_save(self):

with tempfile.TemporaryDirectory() as tmpdirname:
torch.save(model, os.path.join(tmpdirname, "model.pt"))
-loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
+# weights_only=False as this is a legacy use case that loads a module
+loaded_model = torch.load(
+    os.path.join(tmpdirname, "model.pt"), weights_only=False
+)
loaded_model(torch.randn(1, 10))

def test_state_dict_save(self):
@@ -58,7 +61,8 @@ def test_state_dict_save(self):
torch.save(model.state_dict(), os.path.join(tmpdirname, "model.pt"))
loaded_model = ToyModel()
loaded_model.load_state_dict(
-torch.load(os.path.join(tmpdirname, "model.pt"))
+# weights_only=False as this is a legacy use case that loads a module
+torch.load(os.path.join(tmpdirname, "model.pt"), weights_only=False)
)
loaded_model(torch.randn(1, 10))

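These compile tests pickle whole modules, which is why they need weights_only=False. For contrast, a short sketch of the pattern that keeps working untouched under the new default: saving only the state_dict, which contains nothing but tensors (file name is illustrative):

import torch
from torch import nn

model = nn.Linear(10, 1)
torch.save(model.state_dict(), "model.pt")  # tensors only, no pickled class

fresh = nn.Linear(10, 1)
# Loads cleanly under the flipped default; no weights_only=False needed.
fresh.load_state_dict(torch.load("model.pt"))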
10 changes: 8 additions & 2 deletions test/dynamo/test_modules.py
@@ -3002,7 +3002,10 @@ def test_save_and_load_inductor(self):

with tempfile.TemporaryDirectory() as tmpdirname:
torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
-loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
+# weights_only=False as this is a legacy use case that loads a module
+loaded_model = torch.load(
+    os.path.join(tmpdirname, "model.pt"), weights_only=False
+)
loaded_model(inp)
self.assertTrue(same_two_models(loaded_model, mod, [inp]))
self.assertTrue(same_two_models(loaded_model, opt_mod, [inp]))
@@ -3020,7 +3023,10 @@ def test_save_and_load_all_backends(self):
opt_mod = torch.compile(mod, backend=backend)
with tempfile.TemporaryDirectory() as tmpdirname:
torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
-loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
+# weights_only=False as this is a legacy use case that loads a module
+loaded_model = torch.load(
+    os.path.join(tmpdirname, "model.pt"), weights_only=False
+)
torch._dynamo.reset() # force recompiles
torch._inductor.metrics.generated_kernel_count = 0
opt_mod(inp)
3 changes: 2 additions & 1 deletion test/load_torchscript_model.py
@@ -5,7 +5,8 @@

if __name__ == "__main__":
script_mod = torch.jit.load(sys.argv[1])
-mod = torch.load(sys.argv[1] + ".orig")
+# weights_only=False as this is loading a sharded model
+mod = torch.load(sys.argv[1] + ".orig", weights_only=False)
print(script_mod)
inp = torch.rand(2, 28 * 28)
_ = mod(inp)
6 changes: 4 additions & 2 deletions test/test_mps.py
@@ -8825,7 +8825,8 @@ def test_module_backcompat(self):
path = download_file('https://download.pytorch.org/test_data/linear.pt')
with warnings.catch_warnings():
warnings.simplefilter('ignore', SourceChangeWarning)
-m = torch.load(path)
+# weights_only=False as this is a legacy use case that loads a module
+m = torch.load(path, weights_only=False)
input = torch.randn(2, 3, dtype=torch.float)
self.assertEqual(m(input).size(), (2, 5))

@@ -8842,7 +8843,8 @@ def test_conv_backcompat(self):
path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
with warnings.catch_warnings():
warnings.simplefilter('ignore', SourceChangeWarning)
-m = torch.load(path, encoding='utf-8')
+# weights_only=False as this is a legacy use case that loads a module
+m = torch.load(path, encoding='utf-8', weights_only=False)
input = torch.randn((1, 1, 1, 1), dtype=torch.float)
self.assertEqual(m(input).size(), (1, 1, 1, 1))

19 changes: 17 additions & 2 deletions test/test_nestedtensor.py
@@ -1,6 +1,7 @@
# Owner(s): ["module: nestedtensor"]

import ast
+import contextlib
import io
import itertools
import math
@@ -3657,7 +3658,8 @@ def _make_tensor(
["contig", "noncontig_transposed", "noncontig_with_holes"],
name_fn=lambda c: c,
)
-def test_serialization(self, device, dtype, contiguity):
+@parametrize("weights_only", [True, False])
+def test_serialization(self, device, dtype, contiguity, weights_only):
# Test with 3 cases:
# 1. contiguous
# 2. non-contiguous transposed
@@ -3693,8 +3695,21 @@ def test_serialization(self, device, dtype, contiguity):

with tempfile.TemporaryFile() as f:
torch.save(nt, f)
+safe_globals = [
+    torch.nested._internal.nested_tensor.NestedTensor,
+    torch.nested._internal.nested_tensor._rebuild_njt,
+    set,
+    torch._dynamo.decorators._DimRange,
+]
f.seek(0)
-nt_loaded = torch.load(f)
+ctx = (
+    torch.serialization.safe_globals(safe_globals)
+    if weights_only
+    else contextlib.nullcontext()
+)
+
+with ctx:
+    nt_loaded = torch.load(f, weights_only=weights_only)

self.assertIsNot(nt, nt_loaded)
# we expect a new offsets tensor -> different nested int upon load
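The nested tensor test exercises both paths: with weights_only=True it wraps the load in safe_globals for the NJT rebuild machinery, and otherwise it uses a nullcontext. For code that cannot wrap every call site, torch.serialization.add_safe_globals registers the allowlist process-wide instead; a short sketch (Point is a hypothetical example class, not from the diff):

import torch


class Point:  # hypothetical example class
    def __init__(self, x, y):
        self.x, self.y = x, y


# Unlike the scoped safe_globals context manager used in the test above, this
# registration persists for the rest of the process, so every subsequent
# torch.load(..., weights_only=True) may rebuild Point instances.
torch.serialization.add_safe_globals([Point])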
2 changes: 1 addition & 1 deletion test/test_serialization.py
@@ -1196,7 +1196,7 @@ def test_weights_only_error(self, unsafe_global):
f.seek(0)
if unsafe_global:
with self.assertRaisesRegex(pickle.UnpicklingError,
r"use `torch.serialization.add_safe_globals\(\[TwoTensor\]\)` to allowlist"):
r"use `torch.serialization.add_safe_globals\(\[TwoTensor\]\)` or .* to allowlist"):
torch.load(f, weights_only=True)
else:
with self.assertRaisesRegex(pickle.UnpicklingError,
5 changes: 3 additions & 2 deletions torch/_weights_only_unpickler.py
@@ -322,8 +322,9 @@ def load(self):
else:
raise UnpicklingError(
f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. "
f"Please use `torch.serialization.add_safe_globals([{name}])` to allowlist "
"this global if you trust this class/function."
f"Please use `torch.serialization.add_safe_globals([{name}])` or the "
f"`torch.serialization.safe_globals([{name}])` context manager to allowlist this global "
"if you trust this class/function."
)
elif key[0] == NEWOBJ[0]:
args = self.stack.pop()
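The reworded error now names both allowlisting APIs. A sketch of the failure mode it describes and the recovery path (Custom is a hypothetical class defined for illustration):

import io
import pickle

import torch


class Custom:  # hypothetical class that is not on the allowlist
    pass


buf = io.BytesIO()
torch.save(Custom(), buf)

buf.seek(0)
try:
    torch.load(buf, weights_only=True)
except pickle.UnpicklingError as exc:
    print(exc)  # points to add_safe_globals and the safe_globals context manager

buf.seek(0)
with torch.serialization.safe_globals([Custom]):
    obj = torch.load(buf, weights_only=True)  # permitted once allowlisted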
35 changes: 13 additions & 22 deletions torch/serialization.py
@@ -67,6 +67,7 @@
"skip_data",
]

+IS_FBCODE = not hasattr(torch.version, "git_version")

DEFAULT_PROTOCOL = 2

@@ -92,6 +93,10 @@
MAP_SHARED, MAP_PRIVATE = None, None # type: ignore[assignment]


+def _default_to_weights_only(pickle_module):
+    return pickle_module is None and not IS_FBCODE


# _serialization_tls is used to store thread local state specific to serialization
# that needs to be propagated to other files, in particular we use this for
# (1) map_location (needed for wrapper subclasses/third party devices to torch._utils)
@@ -1205,7 +1210,7 @@ def load(
# documentation. We need it so that Sphinx doesn't leak `pickle`s path from
# the build environment (e.g. `<module 'pickle' from '/leaked/path').

"""load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args)
"""load(f, map_location=None, pickle_module=pickle, *, weights_only=True, mmap=None, **pickle_load_args)
Loads an object saved with :func:`torch.save` from a file.
@@ -1347,6 +1352,11 @@ def _get_wo_message(message: str) -> str:
"is not supported yet. Please call torch.load outside the skip_data context manager."
)

+weights_only_not_set = weights_only is None
+
+if weights_only_not_set:
+    weights_only = _default_to_weights_only(pickle_module)

true_values = ["1", "y", "yes", "true"]
# Add ability to force safe only or non-safe weight loads via environment variables
force_weights_only_load = (
@@ -1364,7 +1374,8 @@ def _get_wo_message(message: str) -> str:
elif force_weights_only_load:
weights_only = True
elif force_no_weights_only_load:
-if weights_only is None:
+# TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD can only override if callsite did not explicitly set weights_only
+if weights_only_not_set:
warnings.warn(
"Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the"
"`weights_only` argument was not explicitly passed to `torch.load`, forcing weights_only=False.",
@@ -1373,33 +1384,13 @@ def _get_wo_message(message: str) -> str:
)
weights_only = False

-if weights_only is None:
-    weights_only, warn_weights_only = False, True
-else:
-    warn_weights_only = False

if weights_only:
if pickle_module is not None:
raise RuntimeError(
"Can not safely load weights when explicit pickle_module is specified"
)
else:
if pickle_module is None:
-if warn_weights_only:
-    warnings.warn(
-        "You are using `torch.load` with `weights_only=False` (the current default value), which uses "
-        "the default pickle module implicitly. It is possible to construct malicious pickle data "
-        "which will execute arbitrary code during unpickling (See "
-        "https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). "
-        "In a future release, the default value for `weights_only` will be flipped to `True`. This "
-        "limits the functions that could be executed during unpickling. Arbitrary objects will no "
-        "longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the "
-        "user via `torch.serialization.add_safe_globals`. We recommend you start setting "
-        "`weights_only=True` for any use case where you don't have full control of the loaded file. "
-        "Please open an issue on GitHub for any issues related to this experimental feature.",
-        FutureWarning,
-        stacklevel=2,
-    )
pickle_module = pickle

# make flipping default BC-compatible
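Putting the serialization.py hunks together: weights_only now defaults to None at the API boundary and the effective value is derived from the pickle_module, the build, and two environment variables. A condensed sketch of that resolution order, following the diff above (env-var parsing simplified; the real code also rejects contradictory env vars and warns on the opt-out, so this is not the verbatim implementation):

import os

TRUE_VALUES = ["1", "y", "yes", "true"]


def resolve_weights_only(weights_only, pickle_module, is_fbcode=False):
    not_set = weights_only is None
    if not_set:
        # the flipped default: True unless a custom pickle_module is passed
        # (or this is an fbcode build, which keeps the old behavior)
        weights_only = pickle_module is None and not is_fbcode
    if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0") in TRUE_VALUES:
        weights_only = True  # forcing safe loads always wins
    elif os.getenv("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "0") in TRUE_VALUES:
        if not_set:
            # the opt-out env var only applies when the caller did not choose
            weights_only = False
    return weights_only

With no environment variables set, resolve_weights_only(None, None) is True and resolve_weights_only(None, pickle) is False, matching _default_to_weights_only; an explicit weights_only=False from the caller is never overridden by the opt-out variable.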
