Fix for offloading when using TorchAO >= 0.7.0 #3332

Merged 5 commits on Jan 13, 2025.
src/accelerate/utils/modeling.py (27 changes: 15 additions & 12 deletions)

```diff
@@ -14,6 +14,7 @@
 
 import contextlib
 import gc
+import importlib
 import inspect
 import json
 import logging
@@ -43,7 +44,7 @@
 from .memory import clear_device_cache, get_xpu_available_memory
 from .offload import load_offloaded_weight, offload_weight, save_offload_index
 from .tqdm import is_tqdm_available, tqdm
-from .versions import is_torch_version
+from .versions import compare_versions, is_torch_version
 
 
 if is_npu_available(check_device=False):
```
```diff
@@ -350,17 +351,19 @@ def set_module_tensor_to_device(
             elif param_cls.__name__ in ["QTensor", "QBitsTensor"]:
                 new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to(device)
             elif param_cls.__name__ in ["AffineQuantizedTensor"]:
-                new_value = torch.nn.Parameter(
-                    param_cls(
-                        new_value.layout_tensor,
-                        new_value.block_size,
-                        new_value.shape,
-                        new_value.quant_min,
-                        new_value.quant_max,
-                        new_value.zero_point_domain,
-                    ),
-                    requires_grad=old_value.requires_grad,
-                ).to(device)
+                if importlib.util.find_spec("torchao") is not None and compare_versions("torchao", ">=", "0.7.0"):
+                    # TorchAO v0.7.0 made layout_tensor an internal private variable and exposed tensor_impl
+                    args = (new_value.tensor_impl,)
+                else:
+                    args = (new_value.layout_tensor,)
+                args += (
+                    new_value.block_size,
+                    new_value.shape,
+                    new_value.quant_min,
+                    new_value.quant_max,
+                    new_value.zero_point_domain,
+                )
+                new_value = torch.nn.Parameter(param_cls(*args), requires_grad=old_value.requires_grad).to(device)
```
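In short, the only thing the version gate changes is which attribute supplies the quantized storage when the subclass is rebuilt. Below is a minimal standalone sketch of that selection; the helper name is hypothetical, and it assumes accelerate is installed so its `compare_versions` utility (imported in the hunk above) can be reused:

```python
import importlib.util

from accelerate.utils.versions import compare_versions


def affine_quantized_storage_attr() -> str:
    # TorchAO >= 0.7.0 exposes `tensor_impl`; older releases expose `layout_tensor`,
    # which 0.7.0 turned into an internal private variable.
    if importlib.util.find_spec("torchao") is not None and compare_versions("torchao", ">=", "0.7.0"):
        return "tensor_impl"
    return "layout_tensor"


# Usage sketch: pick the first constructor argument when rebuilding the subclass,
# e.g. getattr(quantized_param, affine_quantized_storage_attr()).
```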
Review comment from @bdhirsh:

In the interest of making this easier to maintain: it seems like the most painful bit is that you are running one of AO's tensor subclass constructors, which does not have strong BC guarantees.

Is there a reason you need to call the constructor (to construct a new instance of this subclass) in the first place?

It sounds like the purpose of this logic is to take any parameters (which may be subclasses) that live on the wrong device and move them in place to the right device. So, stepping back a bit:

(1) If you can, I would just call module.to(device) directly, which saves you from having to loop over and re-assign the params individually (hopefully this just works for you; if not, maybe we should discuss any problems with other core folks).

(2) If you need to reassign some params individually, can you do it without reconstructing the subclass directly? I would probably do something like this:

```python
# option 1:
new_value = torch.nn.Parameter(new_value.to(device=device), requires_grad=old_value.requires_grad)
# option 2:
new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to(device=device)
```

Reply (Member):

We don't really need it, as you said. This was done for better compatibility with bitsandbytes, as you can see in this PR: https://github.com/huggingface/accelerate/pull/539/files. @a-r-r-o-w, would you like to check whether the changes proposed by @bdhirsh work?
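For reference, a self-contained sketch of what the reviewer's option 2 would look like as a standalone helper; the function name is hypothetical and this is not the code merged in this PR:

```python
import torch


def rewrap_and_move(module: torch.nn.Module, tensor_name: str, device) -> None:
    # Option 2 from the review: rewrap the existing value (which may be a tensor
    # subclass such as AffineQuantizedTensor) and let .to() handle the device move,
    # instead of calling the subclass constructor with its internal fields.
    old_value = module._parameters[tensor_name]
    new_value = torch.nn.Parameter(old_value, requires_grad=old_value.requires_grad).to(device=device)
    module._parameters[tensor_name] = new_value
```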

The hunk then continues:

```diff
             else:
                 new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device)
```