Fix TorchAO related bugs; revert device_map changes #10371
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Need to also add torchao to nightly test quantization matrix
diffusers/.github/workflows/nightly_tests.yml
Line 358 in 6dfaec3
- backend: "bitsandbytes"
Just curious, how do we make sure that the quantization-related Python packages are installed? I couldn't find the relevant LoC that handles this, nor does the workflow file have install commands for gguf/bitsandbytes/torchao.
I did some more testing related to serialization. Ran all the nightly tests and added a few more changes. Everything is passing 🤞 Going to run the bitsandbytes fast/slow tests now.
Just for reference in future, the error when loading a serialized
cc @jerryzh168
Something like below works for loading:

```python
import torch
from accelerate import init_empty_weights
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

# Serialize the model
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/Flux.1-Dev",
    subfolder="transformer",
    quantization_config=TorchAoConfig("uint4wo"),
    torch_dtype=torch.bfloat16,
)
transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")
# ...

# Load the model
state_dict = torch.load("/path/to/flux_uint4wo/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu")
with init_empty_weights():
    transformer = FluxTransformer2DModel.from_config("/path/to/flux_uint4wo/config.json")
transformer.load_state_dict(state_dict, strict=True, assign=True)
```

We can't load it directly in diffusers because we use a hardcoded
@sayakpaul I'm seeing two test failures for BnB. I think they are unrelated, but could you confirm when free?

```
_______________________________________________ SlowBnb8bitTests.test_generate_quality_dequantize _______________________________________________
self = <bnb.test_mixed_int8.SlowBnb8bitTests testMethod=test_generate_quality_dequantize>
def test_generate_quality_dequantize(self):
r"""
Test that loading the model and unquantize it produce correct results.
"""
> self.pipeline_8bit.transformer.dequantize()
tests/quantization/bnb/test_mixed_int8.py:415:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/diffusers/models/modeling_utils.py:482: in dequantize
return hf_quantizer.dequantize(self)
src/diffusers/quantizers/base.py:205: in dequantize
model = self._dequantize(model)
src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py:558: in _dequantize
model = dequantize_and_replace(
src/diffusers/quantizers/bitsandbytes/utils.py:281: in dequantize_and_replace
model, has_been_replaced = _dequantize_and_replace(
src/diffusers/quantizers/bitsandbytes/utils.py:264: in _dequantize_and_replace
_, has_been_replaced = _dequantize_and_replace(
src/diffusers/quantizers/bitsandbytes/utils.py:264: in _dequantize_and_replace
_, has_been_replaced = _dequantize_and_replace(
src/diffusers/quantizers/bitsandbytes/utils.py:247: in _dequantize_and_replace
new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
src/diffusers/quantizers/bitsandbytes/utils.py:185: in dequantize_bnb_weight
out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
/opt/venv/lib/python3.10/site-packages/typing_extensions.py:2853: in wrapper
return arg(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
A = tensor([[127, 0, 0, ..., 0, 0, 0],
[ 0, 0, 0, ..., 0, 0, 0],
[ 0, 0, 0, .... 0, 0, ..., 0, 0, 0],
[ 0, 0, 0, ..., 0, 0, 127]], device='cuda:0',
dtype=torch.int8)
B = tensor([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
... ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]], device='cuda:0', dtype=torch.int8), SA = (torch.Size([256, 256]), 'col32')
SB = (torch.Size([1536, 256]), 'row'), out = None, Sout = None, dtype = torch.int32
@deprecated(
"igemmlt is deprecated and will be removed in a future release. Please use int8_linear_matmul instead.",
category=FutureWarning,
)
def igemmlt(
A: torch.Tensor,
B: torch.Tensor,
SA: Tuple[torch.Size, str],
SB: Tuple[torch.Size, str],
out: Optional[torch.Tensor] = None,
Sout: Optional[Tuple[torch.Size, str]] = None,
dtype=torch.int32,
):
if SA is not None and SA[1] != "row":
> raise NotImplementedError(f"Only row-major format inputs are supported, but got format `{SA[1]}`")
E NotImplementedError: Only row-major format inputs are supported, but got format `col32`
/opt/venv/lib/python3.10/site-packages/bitsandbytes/functional.py:2268: NotImplementedError
_________________________________________________________________________________________________________________________________ SlowBnb8bitTests.test_quality _________________________________________________________________________________________________________________________________
self = <bnb.test_mixed_int8.SlowBnb8bitTests testMethod=test_quality>
def test_quality(self):
output = self.pipeline_8bit(
prompt=self.prompt,
num_inference_steps=self.num_inference_steps,
generator=torch.manual_seed(self.seed),
output_type="np",
).images
out_slice = output[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.0376, 0.0359, 0.0015, 0.0449, 0.0479, 0.0098, 0.0083, 0.0295, 0.0295])
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
> self.assertTrue(max_diff < 1e-2)
E AssertionError: False is not true
tests/quantization/bnb/test_mixed_int8.py:378: AssertionError
======================================================================================================================================= warnings summary ========================================================================================================================================
tests/quantization/bnb/test_4bit.py: 12 warnings
tests/quantization/bnb/test_mixed_int8.py: 5 warnings
/__w/diffusers/diffusers/src/diffusers/utils/testing_utils.py:547: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
arry = torch.load(BytesIO(response.content))
tests/quantization/bnb/test_mixed_int8.py::BnB8bitBasicTests::test_keep_modules_in_fp32
tests/quantization/bnb/test_mixed_int8.py::BnB8bitTrainingTests::test_training
/opt/venv/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
tests/quantization/bnb/test_mixed_int8.py::SlowBnb8bitTests::test_generate_quality_dequantize
/__w/diffusers/diffusers/src/diffusers/quantizers/bitsandbytes/utils.py:181: FutureWarning: This function is deprecated. Please use `int8_double_quant` instead.
im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
tests/quantization/bnb/test_mixed_int8.py::SlowBnb8bitTests::test_generate_quality_dequantize
/__w/diffusers/diffusers/src/diffusers/quantizers/bitsandbytes/utils.py:182: FutureWarning: The layout transformation operations will be removed in a future release. Please use row-major layout only.
im, Sim = bnb.functional.transform(im, "col32")
tests/quantization/bnb/test_mixed_int8.py::SlowBnb8bitTests::test_generate_quality_dequantize
/opt/venv/lib/python3.10/site-packages/bitsandbytes/functional.py:2812: FutureWarning: This function is deprecated and will be removed in a future release.
prev_device = pre_call(A.device)
tests/quantization/bnb/test_mixed_int8.py::SlowBnb8bitTests::test_generate_quality_dequantize
/opt/venv/lib/python3.10/site-packages/bitsandbytes/functional.py:2818: FutureWarning: The layout transformation operations will be removed in a future release. Please use row-major layout only.
out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
tests/quantization/bnb/test_mixed_int8.py::SlowBnb8bitTests::test_generate_quality_dequantize
/opt/venv/lib/python3.10/site-packages/bitsandbytes/functional.py:2854: FutureWarning: This function is deprecated and will be removed in a future release.
post_call(prev_device)
tests/quantization/bnb/test_mixed_int8.py::SlowBnb8bitTests::test_generate_quality_dequantize
/__w/diffusers/diffusers/src/diffusers/quantizers/bitsandbytes/utils.py:184: FutureWarning: The layout transformation operations will be removed in a future release. Please use row-major layout only.
state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
tests/quantization/bnb/test_mixed_int8.py::SlowBnb8bitTests::test_generate_quality_dequantize
/__w/diffusers/diffusers/src/diffusers/quantizers/bitsandbytes/utils.py:185: FutureWarning: igemmlt is deprecated and will be removed in a future release. Please use int8_linear_matmul instead.
out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
==================================================================================================================================== short test summary info ====================================================================================================================================
FAILED tests/quantization/bnb/test_mixed_int8.py::SlowBnb8bitTests::test_generate_quality_dequantize - NotImplementedError: Only row-major format inputs are supported, but got format `col32`
FAILED tests/quantization/bnb/test_mixed_int8.py::SlowBnb8bitTests::test_quality - AssertionError: False is not true
============================================ 2 failed, 42 passed, 26 warnings in 1020.55s (0:17:00) ============================================
```
Thanks for the fixes!
Apart from the comments I left, I think it might make sense to also test (integration tests) loading from quantized checkpoints and making sure they are working as expected.
Basically what's done in:
diffusers/tests/quantization/bnb/test_4bit.py
Line 531 in 825979d
class SlowBnb4BitFluxTests(Base4bitTests):
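For illustration, here is a rough sketch of what such a torchao integration test could look like, modeled on the BnB test referenced above. The Hub checkpoint id, quantization type, and the shape/finiteness checks are placeholders and assumptions, not the actual test added in this PR; a real test would pin expected output slices from a verified run.

```python
import gc
import unittest

import numpy as np
import torch

from diffusers import FluxPipeline, FluxTransformer2DModel


class SlowTorchAoPreSerializedTests(unittest.TestCase):
    # Hypothetical Hub repo containing a transformer saved with TorchAoConfig("int8wo");
    # replace with a real pre-quantized checkpoint.
    quantized_ckpt_id = "your-org/flux.1-dev-torchao-int8wo"

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()

    def test_pre_quantized_checkpoint(self):
        # No quantization_config is passed: the quantization metadata is expected to
        # ship with the serialized checkpoint (assumes a quant type that round-trips
        # through from_pretrained, e.g. int8 weight-only).
        transformer = FluxTransformer2DModel.from_pretrained(
            self.quantized_ckpt_id, subfolder="transformer", torch_dtype=torch.bfloat16
        )
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
        ).to("cuda")

        output = pipe(
            prompt="a photo of a cat",
            num_inference_steps=2,
            generator=torch.manual_seed(0),
            output_type="np",
        ).images

        # Placeholder assertions; the real test would compare against pinned slices,
        # as SlowBnb4BitFluxTests does.
        self.assertEqual(output.shape, (1, 1024, 1024, 3))
        self.assertTrue(np.isfinite(output[0, -3:, -3:, -1]).all())
```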
Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but they work as expected when saving them. To work around this, you can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should only be run if the weights were obtained from a trusted source.
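For reference, a minimal sketch of the workaround described above, essentially the same manual load shown earlier in this thread; the checkpoint path is a placeholder for a directory produced by `save_pretrained(..., safe_serialization=False)`.

```python
import torch
from accelerate import init_empty_weights
from diffusers import FluxTransformer2DModel

ckpt_dir = "/path/to/flux_uint4wo"  # placeholder path to the serialized checkpoint

# weights_only=False lets pickle reconstruct the torchao tensor subclasses, so only
# do this for checkpoints that come from a trusted source.
state_dict = torch.load(
    f"{ckpt_dir}/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu"
)

# Build the model on the meta device, then assign the quantized tensors directly.
with init_empty_weights():
    transformer = FluxTransformer2DModel.from_config(f"{ckpt_dir}/config.json")
transformer.load_state_dict(state_dict, strict=True, assign=True)
```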
Cc: @jerryzh168. Is this known?
@a-r-r-o-w regarding the failures,
Thank you!
Fast tests all pass ✅ (Fast test logs)
Slow test for pre-serialized model passes ✅ (Slow preserialized test logs)
Slow test for memory footprint passes ✅ (Slow memory footprint test logs)
Slow test for quantization precision and layer check passes ✅ (Slow quantization logs)
Looks good to merge, I think! Thanks for the reviews everyone, and apologies for bothering you during the vacation period! Going to start the patch release in a bit.
* Revert "Add support for sharded models when TorchAO quantization is enabled (#10256)". This reverts commit 41ba8c0.
* update tests
* udpate
* update
* update
* update device map tests
* apply review suggestions
* update
* make style
* fix
* update docs
* update tests
* update workflow
* update
* improve tests
* allclose tolerance
* Update src/diffusers/models/modeling_utils.py (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* Update tests/quantization/torchao/test_torchao.py (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* improve tests
* fix
* update correct slices

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Reverts part of #10256
Currently, we support:
This PR:
Context: https://huggingface.slack.com/archives/C065E480NN9/p1735010991364189
Running slow tests now