
torch.load planned default flip for weights_only #7799

Closed
mikaylagawarecki opened this issue Aug 1, 2024 · 7 comments · Fixed by #8259

@mikaylagawarecki
mikaylagawarecki commented Aug 1, 2024

TL;DR

PyTorch is planning a BC-breaking change to torch.load that flips the default for weights_only from None (i.e. False) to True (a warning to this effect was added in torch 2.4 :) ). This will break loading of tensors that were serialized while on XLA.
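
For reference, a minimal sketch of opting into the future default today, so that any disallowed GLOBALs surface early (the file name is illustrative):

```python
import torch

t = torch.randn(2, 3)
torch.save(t, "ckpt.pt")

# Explicitly request the future default: plain tensors and state_dicts load
# fine, while checkpoints containing disallowed GLOBALs raise an UnpicklingError.
t1 = torch.load("ckpt.pt", weights_only=True)
```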

Context

Instead of using the default Unpickler provided by pickle, torch.load(weights_only=True) uses a custom Unpickler that restricts the allowed GLOBALs in the checkpoint (classes and functions) to those here (the ones required to build state_dicts).

The purpose of this is to address the risk of remote code execution when using torch.load.

Another feature is that users can allowlist certain globals using add_safe_globals (in torch 2.4) or the safe_globals context manager (in torch nightly). A simple example (here assuming MyTensor is a plain torch.Tensor subclass):

import torch
from torch.serialization import safe_globals

# A user-defined tensor subclass (the base class here is an illustrative choice)
class MyTensor(torch.Tensor):
    pass

t = MyTensor(torch.randn(2, 3))
torch.save(t, "ckpt.pt")

# This fails saying that MyTensor is not an allowed GLOBAL
# t1 = torch.load("ckpt.pt", weights_only=True)

# This succeeds
with safe_globals([MyTensor]):
    torch.load("ckpt.pt", weights_only=True)
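
Continuing the example above, the non-context-manager form (added in torch 2.4) registers the allowlist process-wide instead:

```python
# Alternative to the context manager: allowlist MyTensor for all subsequent loads
torch.serialization.add_safe_globals([MyTensor])
t1 = torch.load("ckpt.pt", weights_only=True)
```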

How this affects XLA

Notably, XLA uses a special path that relies on numpy for serialization/deserialization (see here). However, we have decided not to include the numpy GLOBALs required for unpickling in the default allowlist, since we do not control the codepaths numpy implements for pickling (see the relevant GLOBALs here).
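
To illustrate (this is not the XLA path itself, just a sketch of why numpy GLOBALs in a checkpoint get rejected):

```python
import numpy as np
import torch

# Any numpy array pickled inside a checkpoint pulls in numpy globals such as
# numpy.core.multiarray._reconstruct, which are not in the default allowlist.
torch.save({"arr": np.ones((2, 3))}, "np_ckpt.pt")

try:
    torch.load("np_ckpt.pt", weights_only=True)
except Exception as e:
    # Expect an UnpicklingError mentioning the disallowed numpy GLOBAL
    print(type(e).__name__, e)
```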

Ask

Opening this issue to figure out the best way to move forward and make the flip as smooth as possible!

Ideally, the path for serializing XLA tensors would be refactored to not use numpy, and we would definitely accept a PR that implements this!

Separately, existing checkpoints will likely need some handling as well; see the sketch below.
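
For example, one possible stopgap for existing checkpoints that the user trusts would be to allowlist the numpy globals they reference (the exact set depends on the checkpoint and numpy version, and "old_xla_ckpt.pt" is a hypothetical file name):

```python
import numpy as np
import torch
from torch.serialization import safe_globals

# Temporarily allowlist the numpy globals used by legacy numpy-based checkpoints.
# Only do this for checkpoints from a trusted source.
with safe_globals([np.core.multiarray._reconstruct, np.ndarray, np.dtype]):
    state = torch.load("old_xla_ckpt.pt", weights_only=True)
```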

cc @JackCaoG

@JackCaoG
Collaborator

JackCaoG commented Aug 2, 2024

@will-cromar any thoughts?

mikaylagawarecki added commits to pytorch/pytorch that referenced this issue Aug 13, 2024
…eights_only"

Tests on XLA shard not fixed yet but there is an issue here pytorch/xla#7799

cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Aug 16, 2024
Tests on XLA shard not fixed yet but there is an issue here pytorch/xla#7799

Pull Request resolved: #127627
Approved by: https://github.com/albanD
ghstack dependencies: #132349
@mikaylagawarecki
Author

Gentle bump @will-cromar, what would be the best path forward here?

@JackCaoG
Collaborator

Sorry @mikaylagawarecki for the delay, I will take a look next Monday.

@will-cromar
Collaborator

Discussed this with @JackCaoG. We actually recommend against saving XLA device tensors directly. We implement a wrapper that moves all of the data to CPU first:

cpu_data = _maybe_convert_to_cpu(data, convert=should_write_data)
if should_write_data:
    torch.save(cpu_data, file_or_path)

Is it possible to do the same thing in torch? That is, if you see an XLA tensor, move it to the CPU so that it always loads as a CPU tensor.

I saw there are some other backends that use the same code path. Is there a plan to handle those? If you are planning to add some interface we can implement for custom serialization/deserialization logic, it should also be possible for us to save/load data in an XLA-friendly format (via xla::Literal). I want to avoid upstreaming/inlining a bunch of XLA-specific logic, though.
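
As a rough sketch of the recommended user-side pattern (xm.save is the torch_xla wrapper referred to above; exact behavior may differ across torch_xla versions):

```python
import torch
import torch_xla.core.xla_model as xm

t = torch.randn(2, 3, device=xm.xla_device())

# Option 1: explicitly move to CPU before saving, so the checkpoint is a plain
# CPU checkpoint that loads anywhere, including with weights_only=True.
torch.save(t.cpu(), "ckpt_cpu.pt")

# Option 2: use the torch_xla wrapper, which moves the data to CPU internally.
xm.save(t, "ckpt_xla.pt")
```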

@mikaylagawarecki
Author

mikaylagawarecki commented Sep 26, 2024

Hey @will-cromar, thanks for the response here! First of all, I am very glad to hear that XLA recommends saving CPU tensors 😄 That means that the weights_only flip will be very minimally BC-breaking for XLA.

Is it possible to do the same thing in torch? That is, if you see an xla tensor, move it to the CPU and then it always loads as a CPU tensor.

It is definitely possible, but technically it is also a BC break within torch (separate from the weights_only BC break).

I chatted with @JackCaoG a bit offline; here is a summary of what we came to:

If users always call tensor.cpu() before passing to torch.save, that is great, and it seems the blast radius of the BC break will be small for XLA! I will nevertheless be trying to remove the numpy dependency for 2.6 🙂
So after the torch.load weights_only flip from False to True:
(1) checkpoints with tensors on XLA saved in <2.6 will have errors like "numpy._reconstruct" is not a safe GLOBAL, please add it to safe_globals if you trust it
(2) checkpoints with tensors on XLA saved in >=2.6 will work seamlessly with torch.load(weights_only=True) (I will have a fix to remove the numpy dependency)
(3) users who always move their XLA tensors to CPU before saving will not be affected at all

For the (perhaps tiny) portion of users who save their tensors without moving to CPU, I am personally inclined to keep the "if save on XLA, load on XLA" behavior (with the numpy dependency removed)

But if you feel strongly that always forcing CPU within torch.save is the better path, that sounds plausible too.

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Oct 9, 2024
…n devices without numpy (#137444)

Related: pytorch/xla#7799 (comment)

Follow ups: Do the same for maia and mtia

## Motivation

With the move to `weights_only` by default, we are making an explicit decision not to allowlist GLOBALs required to deserialize `numpy` tensors  by default. The implication is that backends relying on numpy for serialization will fail loudly when `torch.load` flips `weights_only`.

However, we observe that this dependency on numpy is legacy and not actually needed anymore, so we can remove it, which aligns with our weights_only strategy.

## Why is this ok?

The following comment on why numpy is necessary for serialization is legacy

https://github.com/pytorch/pytorch/blob/c87c9f0a01f4840bd19ac5058960c9766dd15ef8/torch/_tensor.py#L303-L312

We no longer do the following, though it was the case 5 years ago in the PR that added this
> CPU storage is reconstructed with randomly initialized data, moved onto backend device, and then storage is updated to the serialized content

**Instead, what now happens is that CPU storage is constructed with data from the file *and then* moved onto the backend device.**

Old behavior (`legacy_load`): https://github.com/ailzhang/pytorch/blob/67adda891a839691790a0dcd99062430050eff3b/torch/serialization.py#L620

Pull Request resolved: #137444
Approved by: https://github.com/albanD
@mikaylagawarecki
Author

mikaylagawarecki commented Oct 9, 2024

Hey @will-cromar, FYI it seems there is a test that saves tensors on the XLA device: https://github.com/pytorch/xla/blob/master/test/test_operations.py#L1460

I removed the numpy dependency, but when I flip weights_only it seems there are other GLOBALs specific to XLA that need to be allowlisted; e.g. the test currently fails with

_pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint.
	(1) Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL torch_xla.utils.serialization.TensorReference was not an allowed global by default. Please use `torch.serialization.add_safe_globals([TensorReference])` to allowlist this global if you trust this class/function

Could you look into this and allowlist the necessary globals in the xla library please? You can do this via torch.serialization.add_safe_globals([foo]), perhaps on import; a sketch follows below.
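
A minimal sketch of what that allowlisting could look like on the torch_xla side (assuming it runs at import time; the exact set of globals to register may be larger than TensorReference alone):

```python
import torch
from torch_xla.utils.serialization import TensorReference

# Register XLA-specific classes found in serialized checkpoints so that
# torch.load(..., weights_only=True) accepts them.
torch.serialization.add_safe_globals([TensorReference])
```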

@JackCaoG
Collaborator

OK, I can repro, looking...
