From 41625408c41b29d69754f9f22441719c24bb043b Mon Sep 17 00:00:00 2001 From: Huy Do Date: Wed, 6 Nov 2024 03:08:29 +0000 Subject: [PATCH] Fix torch.load (torch.utils.benchmark) after #137602 (#139810) After #137602, the default `weights_only` has been set to True. This test is failing in trunk slow jobs atm benchmark_utils/test_benchmark_utils.py::TestBenchmarkUtils::test_collect_callgrind [GH job link](https://github.com/pytorch/pytorch/actions/runs/11672436111/job/32502454946) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/1aa71be56c39908893273bd9558b127159e1ef3a) Pull Request resolved: https://github.com/pytorch/pytorch/pull/139810 Approved by: https://github.com/kit1980 --- .../benchmark/utils/valgrind_wrapper/timer_interface.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py index 199a49bde20ff2..9525fd54aa8e12 100644 --- a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py +++ b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py @@ -457,7 +457,10 @@ def construct(self) -> str: elif wrapped_value.serialization == Serialization.TORCH: path = os.path.join(self._data_dir, f"{name}.pt") - load_lines.append(f"{name} = torch.load({repr(path)})") + # TODO: Figure out if we can use torch.serialization.add_safe_globals here + # Using weights_only=False after the change in + # https://dev-discuss.pytorch.org/t/bc-breaking-change-torch-load-is-being-flipped-to-use-weights-only-true-by-default-in-the-nightlies-after-137602/2573 + load_lines.append(f"{name} = torch.load({repr(path)}, weights_only=False)") torch.save(wrapped_value.value, path) elif wrapped_value.serialization == Serialization.TORCH_JIT: