Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fix: Update create_dynamic_map to always return a float32 tensor #1521

Conversation

mitchellgoffpc
Copy link
Contributor

@mitchellgoffpc mitchellgoffpc commented Feb 17, 2025

bitsandbytes.functional.create_dynamic_map doesn't specify a dtype for the result tensor it creates, so it will use the torch default dtype. However, all of the cquantize_blockwise_*/cdequantize_blockwise_* functions in pythoninterface.cpp expect a float * for code. This discrepancy causes quantize_blockwise and dequantize_blockwise to crash or give incorrect outputs whenever the default torch dtype is modified before name2qmap["dynamic"] is created on the first call. (Obviously using torch.set_default_dtype isn't necessarily the best practice, but at the time of writing it's still in fairly common use, e.g. https://github.com/meta-llama/llama-models/blob/main/models/llama3/reference_impl/generation.py#L157.)

To reproduce:

import torch
from bitsandbytes.functional import quantize_blockwise, name2qmap

torch.manual_seed(42)
X = torch.randn(65536).cuda()
Q1, _ = quantize_blockwise(X)
Q2, _ = quantize_blockwise(X)
torch.testing.assert_close(Q1, Q2)

del name2qmap['dynamic']
torch.set_default_dtype(torch.float16)
Q3, _ = quantize_blockwise(X)
torch.testing.assert_close(Q1, Q3)

@mitchellgoffpc mitchellgoffpc changed the title Update create_dynamic_map to always return a float32 tensor Bug fix: Update create_dynamic_map to always return a float32 tensor Feb 17, 2025
Copy link

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@matthewdouglas matthewdouglas self-assigned this Feb 19, 2025
@matthewdouglas matthewdouglas self-requested a review February 19, 2025 00:35
@matthewdouglas
Copy link
Member

Makes sense, thanks!

@matthewdouglas matthewdouglas merged commit 8ed7d97 into bitsandbytes-foundation:main Feb 19, 2025
33 checks passed
@matthewdouglas matthewdouglas added the bug Something isn't working label Feb 19, 2025
@mitchellgoffpc mitchellgoffpc deleted the fix-dynamic-map-dtype branch February 19, 2025 00:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants