FloatQuantization subclass #228
Comments
At some point we probably need to port the QPyTorch float quantization routine. The code for that is actually quite simple (https://github.com/Tiiiger/QPyTorch/blob/f58bba72113e696099ef3e15e06cf421a06ff289/qtorch/quant/quant_cpu/quant_cpu.cpp#L267-L300). I tried it out locally and it seems we can implement that in pure PyTorch (a rough sketch follows the note below), and potentially …
(NOTE: PyTorch now has uint32 I think, but there are still no ops for it; maybe the needed bit ops can just be enabled on uint32 and the other unsigned dtypes.)
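To make this concrete, here is a rough pure-PyTorch sketch of round-to-nearest float quantization in the spirit of the QPyTorch kernel linked above (the function name is made up; it assumes `man_bits < 23`, saturates on overflow, and ignores subnormals of the target format):

```python
import torch

def float_quantize_nearest(x: torch.Tensor, exp_bits: int, man_bits: int) -> torch.Tensor:
    assert x.dtype == torch.float32 and 0 < man_bits < 23
    bits = x.contiguous().view(torch.int32)   # reinterpret the FP32 bit pattern
    drop = 23 - man_bits                      # FP32 mantissa bits to discard

    # round-to-nearest-even directly on the bit pattern
    keep_lsb = (bits >> drop) & 1
    bias = (1 << (drop - 1)) - 1 + keep_lsb
    rounded = ((bits + bias) >> drop) << drop
    y = rounded.view(torch.float32)

    # saturate to the largest magnitude representable with exp_bits/man_bits
    max_exp = 2 ** (exp_bits - 1) - 1
    max_val = (2.0 - 2.0 ** (-man_bits)) * 2.0 ** max_exp
    return y.clamp(-max_val, max_val)
```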
@vkuzo has excellent FP32<->FPx conversion functions for MX dtypes (only for FP4_E2M1, FP6_E3M2, FP6_E2M3, but it should be straightforward to extend them to other custom FPx formats). For quantization-aware training/fine-tuning, this should be enough, as we probably don't need bit-packing. For inference, @vayuda is working on generic bit-packing, which should be useful (although personally I would prefer explicit bit-packing for each FPx dtype). I'm not sure if the current state of the torch compiler can fuse FPx dequant into a Triton matmul kernel. If not, we probably need to write a custom fused dequant+matmul Triton kernel (which can be fun to do 😃)
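To illustrate the dequant side, here is a minimal lookup-table sketch for FP4_E2M1 (this is not @vkuzo's implementation; the low-nibble-first packing order and the `scale` argument are assumptions for illustration):

```python
import torch

# the 16 values representable by FP4 E2M1 (sign bit is the high bit of the 4-bit code)
FP4_E2M1_LUT = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)

def dequant_fp4_e2m1(packed: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """packed: uint8 tensor holding two FP4 codes per byte, low nibble first (assumed)."""
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    codes = torch.stack([lo, hi], dim=-1).flatten(-2)   # one 4-bit code per element
    return FP4_E2M1_LUT.to(packed.device)[codes.long()] * scale
```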
Prototype FP8 quant using native PyTorch fp8 dtypes: https://github.com/gau-nernst/ao/blob/fp8wo/torchao/prototype/fp8wo/__init__.py. Benchmarked with Llama2-7B-chat on a 4070Ti SUPER; tokens/s is measured with …
Error with FP8 E5M2 FNUZ: …

The speed degradation compared to INT8 seems to be because torch.compile cannot fuse the FP8->BF16 weight cast into the matmul. Minimal repro:

```python
import torch
import torch._inductor.config

# core dump if either of these flags are enabled
# torch._inductor.config.force_mixed_mm = True
# torch._inductor.config.use_mixed_mm = True

def f(a, b, s):
    # up-cast the FP8 weight to the activation dtype on the fly, then apply per-column scales
    return torch.mm(a, b.to(a.dtype)) * s

fp16_act = torch.randn(1, 32).to(torch.bfloat16).cuda()
fp8_weight = torch.randn(32, 32).to(torch.float8_e5m2).cuda()
scales = torch.randn(32).to(torch.bfloat16).cuda()

torch.compile(f, mode="max-autotune", fullgraph=True)(fp16_act, fp8_weight, scales)
```

Codegen output: https://gist.github.com/gau-nernst/4afd0f5b97368ecf26d54b5f3415b004. When either of the mixed-mm flags above is enabled, the process core dumps.
One reason why FP8 quant is not so performant is probably that there is no optimized FP8->BF16 dtype conversion that can be fused with Triton (either at the Triton level or the torch.compile level; need to investigate further...)
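For context, here is a minimal sketch of how the `fp8_weight` and `scales` in the repro could be produced from a real weight (per-output-column absmax scaling; this is my own illustration, not the prototype linked above):

```python
import torch

def quantize_fp8_weight_only(w: torch.Tensor, fp8_dtype=torch.float8_e5m2):
    # w is laid out [in_features, out_features] so that torch.mm(act, w) works,
    # matching f() in the repro above
    fp8_max = torch.finfo(fp8_dtype).max
    scale = w.abs().amax(dim=0).clamp(min=1e-12) / fp8_max   # one scale per output column
    fp8_weight = (w / scale).to(fp8_dtype)
    return fp8_weight, scale

w = torch.randn(32, 32, dtype=torch.bfloat16, device="cuda")
fp8_weight, scales = quantize_fp8_weight_only(w)
# f(fp16_act, fp8_weight, scales) now approximates torch.mm(fp16_act, w)
```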
As I was reviewing #223, I was reminded of this PR: #214.
And I'd be curious what range of floating point numbers we can just express using subclasses.
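As a starting point for discussion, here is a minimal, inference-only sketch of what such a subclass could look like for FP8 weights with a per-channel scale, using the usual `_make_wrapper_subclass` / `__torch_dispatch__` pattern (the class and method names are made up for illustration):

```python
import torch

class FP8WeightTensor(torch.Tensor):
    """Hypothetical wrapper subclass: float8_e4m3fn payload plus a higher-precision scale."""

    @staticmethod
    def __new__(cls, fp8_data, scale):
        return torch.Tensor._make_wrapper_subclass(
            cls, fp8_data.shape, dtype=scale.dtype, device=fp8_data.device
        )

    def __init__(self, fp8_data, scale):
        self.fp8_data = fp8_data
        self.scale = scale

    @classmethod
    def from_float(cls, w):
        # per-output-channel absmax scaling into the e4m3 range (max magnitude 448)
        scale = w.abs().amax(dim=1, keepdim=True) / torch.finfo(torch.float8_e4m3fn).max
        return cls((w / scale).to(torch.float8_e4m3fn), scale)

    def dequantize(self):
        return self.fp8_data.to(self.scale.dtype) * self.scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # naive fallback: dequantize any FP8WeightTensor argument and run the op
        # on plain tensors (no fusion; this only shows the subclass plumbing)
        unwrap = lambda a: a.dequantize() if isinstance(a, cls) else a
        args = [unwrap(a) for a in args]
        kwargs = {k: unwrap(v) for k, v in (kwargs or {}).items()}
        return func(*args, **kwargs)

w = torch.randn(64, 32, dtype=torch.bfloat16)
x = torch.randn(8, 32, dtype=torch.bfloat16)
out = torch.nn.functional.linear(x, FP8WeightTensor.from_float(w))
```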