From c1fab914d3ecd55461287419d73940edac17c2ba Mon Sep 17 00:00:00 2001
From: Mikayla Gawarecki
Date: Fri, 1 Nov 2024 18:58:05 -0700
Subject: [PATCH] Flip default on weights_only (#137602)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137602
Approved by: https://github.com/malfet, https://github.com/albanD
ghstack dependencies: #138936, #139221, #139433, #139541
---
 .github/ci_commit_pins/torchbench.txt         |  2 +-
 .github/ci_commit_pins/xla.txt                |  2 +-
 .../sharded_tensor/test_sharded_tensor.py     | 15 +++++---
 test/distributed/fsdp/test_fsdp_state_dict.py | 23 +++++++++++-
 test/dynamo/test_compile.py                   |  8 +++--
 test/dynamo/test_modules.py                   | 10 ++++--
 test/load_torchscript_model.py                |  3 +-
 test/test_mps.py                              |  6 ++--
 test/test_nestedtensor.py                     | 19 ++++++++--
 test/test_serialization.py                    |  2 +-
 torch/_weights_only_unpickler.py              |  5 +--
 torch/serialization.py                        | 35 +++++++------------
 12 files changed, 88 insertions(+), 42 deletions(-)

diff --git a/.github/ci_commit_pins/torchbench.txt b/.github/ci_commit_pins/torchbench.txt
index 21b3c3481f3988..4f922a0676eb2c 100644
--- a/.github/ci_commit_pins/torchbench.txt
+++ b/.github/ci_commit_pins/torchbench.txt
@@ -1 +1 @@
-e522b45cd4535b9dfe067aa68d7315755df38f48
+766a5e3a189384659fd35a68c3b17b88c761aaac
diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt
index 9d412df07f46c2..03db6224c4139e 100644
--- a/.github/ci_commit_pins/xla.txt
+++ b/.github/ci_commit_pins/xla.txt
@@ -1 +1 @@
-2eb4a60ed14a38260b85b0c765161f0ce45be6d1
+f71c02d1f457d58371e013632efb016c01bd1866
diff --git a/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py b/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py
index 76d06a972bdf8c..730b2c2c0ac27b 100644
--- a/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py
+++ b/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py
@@ -1245,7 +1245,8 @@ def test_state_dict(self):
         module_load._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
 
         buffer.seek(0)
-        state_dict_deser = torch.load(buffer)
+        # weights_only=False as ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
+        state_dict_deser = torch.load(buffer, weights_only=False)
         module_load.load_state_dict(state_dict_deser, strict=False)
 
         module_load._register_state_dict_hook(state_dict_hook)
@@ -1289,7 +1290,8 @@ def test_state_dict_new_group(self):
 
         buffer.seek(0)
         with load_with_process_group(pg):
-            state_dict_deser = torch.load(buffer)
+            # ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
+            state_dict_deser = torch.load(buffer, weights_only=False)
             module_load.load_state_dict(state_dict_deser, strict=False)
 
         # Verify after load.
@@ -1361,20 +1363,23 @@ def test_load_state_dict_errors(self):
         if self.rank != 0:
             with self.assertRaisesRegex(RuntimeError, "Local rank at save time was"):
                 with load_with_process_group(pg):
-                    state_dict_deser = torch.load(buffer)
+                    # ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
+                    state_dict_deser = torch.load(buffer, weights_only=False)
         else:
             with self.assertRaisesRegex(
                 RuntimeError, "Local world size at save time was"
             ):
                 with load_with_process_group(pg):
-                    state_dict_deser = torch.load(buffer)
+                    # ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
+                    state_dict_deser = torch.load(buffer, weights_only=False)
 
         dist.destroy_process_group()
         buffer.seek(0)
         with self.assertRaisesRegex(
             RuntimeError, "Need to initialize default process group"
         ):
-            state_dict_deser = torch.load(buffer)
+            # ShardedTensor weights_only is already tested in TestFSDPStateDict.test_torch_save_load
+            state_dict_deser = torch.load(buffer, weights_only=False)
         rpc.shutdown()
 
     @with_comms
diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py
index 0fa1b38eef42bc..a246375caba8ff 100644
--- a/test/distributed/fsdp/test_fsdp_state_dict.py
+++ b/test/distributed/fsdp/test_fsdp_state_dict.py
@@ -16,6 +16,12 @@
     Shard,
     ShardedTensor,
 )
+from torch.distributed._shard.sharded_tensor.metadata import (
+    MEM_FORMAT_ENCODING,
+    ShardedTensorMetadata,
+    TensorProperties,
+)
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec, ShardMetadata
 from torch.distributed._state_dict_utils import (
     _all_gather_sharded_tensor,
     _gather_state_dict,
@@ -37,6 +43,7 @@
 from torch.distributed.fsdp._common_utils import FSDP_PREFIX
 from torch.distributed.fsdp._unshard_param_utils import FLAT_PARAM
 from torch.distributed.fsdp.wrap import enable_wrap, ModuleWrapPolicy, wrap
+from torch.distributed.remote_device import _remote_device
 from torch.nn import Linear, Module, TransformerDecoderLayer, TransformerEncoderLayer
 from torch.nn.parallel import DistributedDataParallel
 from torch.optim import SGD
@@ -1160,7 +1167,21 @@ def test_torch_save_load(self):
             checkpoint = io.BytesIO()
             torch.save(state_dict, checkpoint)
             checkpoint.seek(0)
-            state_dict_saved = torch.load(checkpoint)
+            with torch.serialization.safe_globals(
+                [
+                    Shard,
+                    ShardMetadata,
+                    ShardedTensor,
+                    ShardedTensorMetadata,
+                    TensorProperties,
+                    MEM_FORMAT_ENCODING,
+                    _remote_device,
+                    getattr,
+                    ShardedTensor.ProcessGroupState,
+                    ChunkShardingSpec,
+                ]
+            ):
+                state_dict_saved = torch.load(checkpoint)
             for k, v in state_dict_saved.items():
                 if isinstance(v, ShardedTensor):
                     self.assertEqual(
diff --git a/test/dynamo/test_compile.py b/test/dynamo/test_compile.py
index f28855c1ae2548..791ff7a67ffde3 100644
--- a/test/dynamo/test_compile.py
+++ b/test/dynamo/test_compile.py
@@ -46,7 +46,10 @@ def test_save(self):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             torch.save(model, os.path.join(tmpdirname, "model.pt"))
-            loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
+            # weights_only=False as this is a legacy use case that loads a module
+            loaded_model = torch.load(
+                os.path.join(tmpdirname, "model.pt"), weights_only=False
+            )
             loaded_model(torch.randn(1, 10))
 
     def test_state_dict_save(self):
@@ -58,7 +61,8 @@ def test_state_dict_save(self):
             torch.save(model.state_dict(), os.path.join(tmpdirname, "model.pt"))
             loaded_model = ToyModel()
             loaded_model.load_state_dict(
-                torch.load(os.path.join(tmpdirname, "model.pt"))
+                # weights_only=False as this is a legacy use case that loads a module
+                torch.load(os.path.join(tmpdirname, "model.pt"), weights_only=False)
             )
             loaded_model(torch.randn(1, 10))
 
diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py
index 3a81380af3ea88..acdd687b6c7b1c 100644
--- a/test/dynamo/test_modules.py
+++ b/test/dynamo/test_modules.py
@@ -3002,7 +3002,10 @@ def test_save_and_load_inductor(self):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
-            loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
+            # weights_only=False as this is a legacy use case that loads a module
+            loaded_model = torch.load(
+                os.path.join(tmpdirname, "model.pt"), weights_only=False
+            )
             loaded_model(inp)
             self.assertTrue(same_two_models(loaded_model, mod, [inp]))
             self.assertTrue(same_two_models(loaded_model, opt_mod, [inp]))
@@ -3020,7 +3023,10 @@ def test_save_and_load_all_backends(self):
             opt_mod = torch.compile(mod, backend=backend)
             with tempfile.TemporaryDirectory() as tmpdirname:
                 torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
-                loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
+                # weights_only=False as this is a legacy use case that loads a module
+                loaded_model = torch.load(
+                    os.path.join(tmpdirname, "model.pt"), weights_only=False
+                )
                 torch._dynamo.reset()  # force recompiles
                 torch._inductor.metrics.generated_kernel_count = 0
                 opt_mod(inp)
diff --git a/test/load_torchscript_model.py b/test/load_torchscript_model.py
index 807f27ffe76050..d362ae5dd93a00 100644
--- a/test/load_torchscript_model.py
+++ b/test/load_torchscript_model.py
@@ -5,7 +5,8 @@
 
 if __name__ == "__main__":
     script_mod = torch.jit.load(sys.argv[1])
-    mod = torch.load(sys.argv[1] + ".orig")
+    # weights_only=False as this is loading a sharded model
+    mod = torch.load(sys.argv[1] + ".orig", weights_only=False)
     print(script_mod)
     inp = torch.rand(2, 28 * 28)
     _ = mod(inp)
diff --git a/test/test_mps.py b/test/test_mps.py
index 8962ece03dc18d..5a5f7944d486e6 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -8825,7 +8825,8 @@ def test_module_backcompat(self):
         path = download_file('https://download.pytorch.org/test_data/linear.pt')
         with warnings.catch_warnings():
             warnings.simplefilter('ignore', SourceChangeWarning)
-            m = torch.load(path)
+            # weights_only=False as this is a legacy use case that loads a module
+            m = torch.load(path, weights_only=False)
         input = torch.randn(2, 3, dtype=torch.float)
         self.assertEqual(m(input).size(), (2, 5))
 
@@ -8842,7 +8843,8 @@ def test_conv_backcompat(self):
         path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
         with warnings.catch_warnings():
             warnings.simplefilter('ignore', SourceChangeWarning)
-            m = torch.load(path, encoding='utf-8')
+            # weights_only=False as this is a legacy use case that loads a module
+            m = torch.load(path, encoding='utf-8', weights_only=False)
         input = torch.randn((1, 1, 1, 1), dtype=torch.float)
         self.assertEqual(m(input).size(), (1, 1, 1, 1))
 
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index 1feca105d60f48..ba2af0927c8e12 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -1,6 +1,7 @@
 # Owner(s): ["module: nestedtensor"]
 
 import ast
+import contextlib
 import io
 import itertools
 import math
@@ -3657,7 +3658,8 @@ def _make_tensor(
         ["contig", "noncontig_transposed", "noncontig_with_holes"],
         name_fn=lambda c: c,
     )
-    def test_serialization(self, device, dtype, contiguity):
+    @parametrize("weights_only", [True, False])
+    def test_serialization(self, device, dtype, contiguity, weights_only):
         # Test with 3 cases:
         # 1. contiguous
         # 2. non-contiguous transposed
@@ -3693,8 +3695,21 @@ def test_serialization(self, device, dtype, contiguity):
 
         with tempfile.TemporaryFile() as f:
             torch.save(nt, f)
+            safe_globals = [
+                torch.nested._internal.nested_tensor.NestedTensor,
+                torch.nested._internal.nested_tensor._rebuild_njt,
+                set,
+                torch._dynamo.decorators._DimRange,
+            ]
             f.seek(0)
-            nt_loaded = torch.load(f)
+            ctx = (
+                torch.serialization.safe_globals(safe_globals)
+                if weights_only
+                else contextlib.nullcontext()
+            )
+
+            with ctx:
+                nt_loaded = torch.load(f, weights_only=weights_only)
 
             self.assertIsNot(nt, nt_loaded)
             # we expect a new offsets tensor -> different nested int upon load
diff --git a/test/test_serialization.py b/test/test_serialization.py
index 331b8c85f9c783..f24886ac4cd251 100644
--- a/test/test_serialization.py
+++ b/test/test_serialization.py
@@ -1196,7 +1196,7 @@ def test_weights_only_error(self, unsafe_global):
             f.seek(0)
             if unsafe_global:
                 with self.assertRaisesRegex(pickle.UnpicklingError,
-                                            r"use `torch.serialization.add_safe_globals\(\[TwoTensor\]\)` to allowlist"):
+                                            r"use `torch.serialization.add_safe_globals\(\[TwoTensor\]\)` or .* to allowlist"):
                     torch.load(f, weights_only=True)
             else:
                 with self.assertRaisesRegex(pickle.UnpicklingError,
diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py
index c624ad914e8fe5..a2d83425d2be60 100644
--- a/torch/_weights_only_unpickler.py
+++ b/torch/_weights_only_unpickler.py
@@ -322,8 +322,9 @@ def load(self):
                 else:
                     raise UnpicklingError(
                         f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. "
-                        f"Please use `torch.serialization.add_safe_globals([{name}])` to allowlist "
-                        "this global if you trust this class/function."
+                        f"Please use `torch.serialization.add_safe_globals([{name}])` or the "
+                        f"`torch.serialization.safe_globals([{name}])` context manager to allowlist this global "
+                        "if you trust this class/function."
                     )
             elif key[0] == NEWOBJ[0]:
                 args = self.stack.pop()
diff --git a/torch/serialization.py b/torch/serialization.py
index 857e70c23a1a96..352514d541505a 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -67,6 +67,7 @@
     "skip_data",
 ]
 
+IS_FBCODE = not hasattr(torch.version, "git_version")
 
 DEFAULT_PROTOCOL = 2
 
@@ -92,6 +93,10 @@
     MAP_SHARED, MAP_PRIVATE = None, None  # type: ignore[assignment]
 
 
+def _default_to_weights_only(pickle_module):
+    return pickle_module is None and not IS_FBCODE
+
+
 # _serialization_tls is used to store thread local state specific to serialization
 # that needs to be propagated to other files, in particular we use this for
 # (1) map_location (needed for wrapper subclasses/third party devices to torch._utils)
@@ -1205,7 +1210,7 @@ def load(
     # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
    # the build environment (e.g. `<module 'pickle' from '/leaked/path'>`)
-    """load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args)
+    """load(f, map_location=None, pickle_module=pickle, *, weights_only=True, mmap=None, **pickle_load_args)
 
     Loads an object saved with :func:`torch.save` from a file.
 
     :func:`torch.load` uses Python's unpickling facilities but treats storages,
@@ -1347,6 +1352,11 @@ def _get_wo_message(message: str) -> str:
             "is not supported yet. Please call torch.load outside the skip_data context manager."
         )
 
+    weights_only_not_set = weights_only is None
+
+    if weights_only_not_set:
+        weights_only = _default_to_weights_only(pickle_module)
+
     true_values = ["1", "y", "yes", "true"]
     # Add ability to force safe only or non-safe weight loads via environment variables
     force_weights_only_load = (
@@ -1364,7 +1374,8 @@ def _get_wo_message(message: str) -> str:
     elif force_weights_only_load:
         weights_only = True
     elif force_no_weights_only_load:
-        if weights_only is None:
+        # TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD can only override if callsite did not explicitly set weights_only
+        if weights_only_not_set:
             warnings.warn(
                 "Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the"
                 "`weights_only` argument was not explicitly passed to `torch.load`, forcing weights_only=False.",
@@ -1373,11 +1384,6 @@ def _get_wo_message(message: str) -> str:
             )
             weights_only = False
 
-    if weights_only is None:
-        weights_only, warn_weights_only = False, True
-    else:
-        warn_weights_only = False
-
     if weights_only:
         if pickle_module is not None:
             raise RuntimeError(
@@ -1385,21 +1391,6 @@ def _get_wo_message(message: str) -> str:
             )
     else:
         if pickle_module is None:
-            if warn_weights_only:
-                warnings.warn(
-                    "You are using `torch.load` with `weights_only=False` (the current default value), which uses "
-                    "the default pickle module implicitly. It is possible to construct malicious pickle data "
-                    "which will execute arbitrary code during unpickling (See "
-                    "https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). "
-                    "In a future release, the default value for `weights_only` will be flipped to `True`. This "
-                    "limits the functions that could be executed during unpickling. Arbitrary objects will no "
-                    "longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the "
-                    "user via `torch.serialization.add_safe_globals`. We recommend you start setting "
-                    "`weights_only=True` for any use case where you don't have full control of the loaded file. "
-                    "Please open an issue on GitHub for any issues related to this experimental feature.",
-                    FutureWarning,
-                    stacklevel=2,
-                )
             pickle_module = pickle
 
     # make flipping default BC-compatible
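
Usage sketch (illustrative note, not part of the diff): with this patch, torch.load
defaults to weights_only=True whenever no custom pickle_module is passed (and the
build is not fbcode, per _default_to_weights_only above). Tensor and state_dict
checkpoints load unchanged; checkpoints carrying arbitrary pickled classes must be
allowlisted or loaded with an explicit weights_only=False. The Point class and the
ckpt.pt path below are invented for the example; safe_globals and add_safe_globals
are the torch.serialization APIs the new error message points to.

    import torch

    class Point:
        # Stand-in for a trusted non-tensor class stored in a checkpoint.
        def __init__(self, x, y):
            self.x, self.y = x, y

    torch.save({"w": torch.randn(2, 2), "p": Point(1, 2)}, "ckpt.pt")

    # Under the new default (weights_only=True), the Point entry is rejected
    # unless the class is allowlisted for the weights_only unpickler.
    with torch.serialization.safe_globals([Point]):
        state = torch.load("ckpt.pt")

    # Explicit opt-out keeps the old behavior, for fully trusted files only.
    state = torch.load("ckpt.pt", weights_only=False)

The context-manager form scopes the allowlist to a single load, as the updated
tests above do; add_safe_globals extends it process-wide. Both are narrower
escape hatches than passing weights_only=False.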