BF16 support for Quant-LLM kernel #1147
Conversation
This approach is (much) slower than multiplying by 2^bias after the fact, which is why it's not usable.
Awesome work! Left some comments for discussion.
Here are some benchmark results on Llama-2-7b-chat-hf. Tested on:
Wikitext perplexity (using
FP16 perf. (`python generate.py --compile --quantization fp6 --precision float16`): Average tokens/sec: 111.55
BF16 perf. (`python generate.py --compile --quantization fp6 --precision bfloat16`): Average tokens/sec: 111.19
# follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py
# doesn't seem to be the right way to check for correctness
correct = (fp6_output - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3
correct_fp16 = (fp6_output_fp16 - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3
correct_bf16 = (fp6_output_bf16 - bf16_output).abs().mean() / bf16_output.abs().mean() < 1e-2
Just curious. I noticed that when BF16 is used, the tolerance is generally quite a bit higher than for FP16. From your experience working on this, do you suspect any part of the code might be responsible for this loss of precision? E.g. perhaps some parts are computed in BF16 instead of FP32. Or maybe it's just the way it is.
All I know is that BF16 has fewer bits for the fraction (mantissa) than FP16 (7 bits vs. 10 bits), so that leads to lower precision for BF16 compared to FP16. I can't think of any part of the FP6 kernel that would inherently lead to more loss of precision for BF16.
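As a rough, illustrative cross-check (not part of the PR), the machine epsilon of each format lines up with the 1e-3 (FP16) and 1e-2 (BF16) tolerances used in the test above:

```cuda
#include <cmath>
#include <cstdio>

int main() {
    // Machine epsilon ~ 2^-(mantissa bits): spacing of representable values around 1.0.
    const double fp16_eps = std::ldexp(1.0, -10); // FP16: 10 mantissa bits -> ~9.77e-4
    const double bf16_eps = std::ldexp(1.0, -7);  // BF16:  7 mantissa bits -> ~7.81e-3
    std::printf("fp16 eps ~ %g, bf16 eps ~ %g\n", fp16_eps, bf16_eps);
    return 0;
}
```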
Thank you for the great work! I will ask Mark and Charles to take another look at the PR.
Edit: almost forgot. @tobiasvanderwerff Can you update the docs also? Mainly this one I think https://github.com/pytorch/ao/tree/main/torchao/dtypes/floatx. You can mention that you extended the original kernel to work with BF16 😄. I think there might be other places where we mention that the FPx kernel only works with FP16.
+1 to adding docs, left a few small comments - this is close to getting landed
const size_t M_Global,
const size_t N_Global,
const size_t K_Global,
float *Reduction_Workspace,  // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32)
int Split_K)
{
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
    static_assert(std::is_same<InputDataType, half>::value || std::is_same<InputDataType, __nv_bfloat16>::value, "Type must be float or __nv_bfloat16");
In the warning, did you mean float16 instead of float?
Yes!
"r"(b[0]), "r"(b[1]), | ||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); | ||
if constexpr (USE_BF16) { | ||
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32" |
TODO for myself is to add some comments explaining this asm.
Any check for `__CUDA_ARCH__` in `fp6_linear.cu` will always fail because `__CUDA_ARCH__` is undefined since all of the functions in `fp6_linear.cu` are host functions
@msaroufim @gau-nernst I updated the docs at https://github.com/pytorch/ao/tree/main/torchao/dtypes/floatx. I didn't update the benchmark results with BF16, however, because I don't have access to the same machine that the existing benchmark results were produced on. I'm a bit short on time today, so I may have missed some additional docs that should be updated; I'll double-check this later. I also intend to run an sm75 test, because I think that would be good to check. I still need to address some of your comments; I'll also do this a bit later.
If this is not done, the kernel may still run but fail silently, leading to unexpected behavior.
There are currently several ways of using `__CUDA_ARCH__` that lead to undefined behavior. See https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-arch for details of how `__CUDA_ARCH__` should not be used.
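For illustration, here is a generic sketch (not the PR's code; `fpx_gemm_kernel` and `device_supports_bf16_mma` are made-up names) of why the architecture guard belongs in device code, with a runtime check on the host side:

```cuda
#include <cuda_runtime.h>

// Device code: __CUDA_ARCH__ is defined per target architecture during the device compilation pass.
__global__ void fpx_gemm_kernel(/* ... */) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // BF16 tensor-core path, compiled only for sm_80 and newer.
#endif
}

// Host code: __CUDA_ARCH__ is never defined here, so a compile-time guard silently
// evaluates to false. Query the device at runtime instead.
bool device_supports_bf16_mma(int device) {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    return prop.major >= 8;  // Ampere (sm_80) or newer
}
```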
The new changes look good! Figuring out how to handle `__CUDA_ARCH__` took more effort than expected, but now we all understand it better 😄.
I think there are some problems with CUDA CI right now. Will need to wait for that to be fixed.
cool work!
Closes #998
This PR adds BF16 support for the existing FPx CUDA kernel (here), which was originally written for FP16 (see the paper for details). Since most recent models are trained and released in BF16, having BF16 support potentially improves accuracy for FPx models.
Most important changes
More details on modified exponent bias calculations
Adapting the FP6 kernel for BF16 introduces a complication regarding the exponent bias in the dequantization step, compared to the original implementation for FP16. This required me to come up with a custom solution. For context: in sections 4.2.2 and 5.3.1 of the FP6 paper, they mention that, to produce an equivalent FP16 value, the exponent would need to be $E^{\text{fp16}} = E^{\text{fpx}} + \text{bias}^{\text{fp16}} - \text{bias}^{\text{fpx}}$.
Adding the constant bias terms is computationally expensive to do during dequantization, so instead they set $E^{\text{fp16}} = E^{\text{fpx}}$ during dequantization, followed by a multiplication later on with $2^{\text{bias}^{\text{fp16}} - \text{bias}^{\text{fpx}}}$ to get the equivalent result in a more efficient way. This works fine for FP16: since $\text{bias}^{\text{fp16}} = 15$ and $\text{bias}^{\text{fpx}} = 3$ for FP6, we get $2^{\text{bias}^{\text{fp16}} - \text{bias}^{\text{fpx}}} = 2^{15-3} = 2^{12}$, which fits into a single 32-bit integer value.
Translating this to BF16, however, is a little trickier. Since $\text{bias}^{\text{bf16}} = 127$ (see Wikipedia), this would amount to multiplying by $2^{\text{bias}^{\text{bf16}} - \text{bias}^{\text{fpx}}} = 2^{127-3} = 2^{124}$, which is too large to fit into a 32-bit or even a 64-bit number. To address this, I experimented with a few approaches, which I highlight below.
Approach 1: Type punning (current choice)
EDIT: The solution below is now simplified by using the CUDA ldexpf function. Thanks to @gau-nernst for pointing me in the right direction!
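The actual type-punning/ldexpf code is collapsed above; as a minimal sketch of the ldexpf-based idea (assuming the dequantized value is available as an FP32 value that still lacks the exponent-bias correction; the function name is illustrative):

```cuda
#include <cuda_bf16.h>

// Sketch only: apply the missing scale 2^(bias^bf16 - bias^fp6) = 2^124 by adjusting
// the FP32 exponent with ldexpf instead of multiplying by an unrepresentable constant.
__device__ __nv_bfloat16 apply_exponent_bias_bf16(float dequant_unbiased) {
    constexpr int kExpOffset = 127 - 3;                   // bias^bf16 - bias^fp6 = 124
    float scaled = ldexpf(dequant_unbiased, kExpOffset);  // dequant * 2^124
    return __float2bfloat16(scaled);
}
```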
Approach 2: Decompose the exponent bias into smaller values
A second approach is based on the observation that instead of multiplying by $2^{124}$ once, we can use the fact that $2^{124} = 2^{31} \cdot 2^{31} \cdot 2^{31} \cdot 2^{31}$. Since $2^{31}$ does fit into an ordinary 32-bit unsigned int, we can multiply by this value four times, like so:
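The snippet referenced here is not included above; a minimal sketch of what it could look like (illustrative names, not the PR's actual code):

```cuda
// Sketch only: 2^124 does not fit in a 32- or 64-bit integer, but 2^31 does,
// so the bias correction is applied as four successive multiplications by 2^31.
__device__ float apply_exponent_bias_bf16_v2(float dequant_unbiased) {
    const float two_pow_31 = static_cast<float>(1u << 31);  // 2^31 fits in a 32-bit unsigned int
    // (2^31)^4 = 2^124 = 2^(bias^bf16 - bias^fp6)
    dequant_unbiased *= two_pow_31;
    dequant_unbiased *= two_pow_31;
    dequant_unbiased *= two_pow_31;
    dequant_unbiased *= two_pow_31;
    return dequant_unbiased;
}
```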
This approach works, but I think it is less preferable than approach 1, since it introduces 3 additional multiplications.
Benchmark
Tested on:
PyTorch 2.6.0.dev20241022
Final remarks
CC @gau-nernst