
FloatQuantization subclass #228

Open · msaroufim opened this issue May 8, 2024 · 4 comments

@msaroufim
Member

As I was reviewing #223, I was reminded of PR #214, and I'd be curious what range of floating-point numbers we can express using just subclasses.
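
For context, here is a hypothetical minimal sketch of the subclass idea: wrap low-precision storage and a scale behind a regular-looking tensor and dequantize on dispatch. This is not torchao's actual API; the class name and fallback strategy are made up for illustration.

```python
import torch

class FP8WeightTensor(torch.Tensor):
    # Hypothetical wrapper subclass: stores FP8 data plus a scale, and
    # dequantizes on the fly whenever an op touches it (correct but slow).
    @staticmethod
    def __new__(cls, fp8_data, scale):
        return torch.Tensor._make_wrapper_subclass(
            cls, fp8_data.shape, dtype=scale.dtype, device=fp8_data.device
        )

    def __init__(self, fp8_data, scale):
        self.fp8_data = fp8_data
        self.scale = scale

    def dequantize(self):
        return self.fp8_data.to(self.scale.dtype) * self.scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Naive fallback: dequantize every wrapped argument, then run the op as usual.
        unwrap = lambda t: t.dequantize() if isinstance(t, cls) else t
        args = tuple(unwrap(a) for a in args)
        kwargs = {k: unwrap(v) for k, v in (kwargs or {}).items()}
        return func(*args, **kwargs)

# Usage: quantize a weight, then use it in a regular matmul.
w = torch.randn(32, 16)
scale = w.abs().amax() / 448.0  # 448 = max normal value of float8_e4m3fn
w_q = FP8WeightTensor((w / scale).to(torch.float8_e4m3fn), scale)
x = torch.randn(4, 32)
out = torch.mm(x, w_q)  # dispatch dequantizes w_q, then runs the regular mm
```

With a wrapper like this, supporting a given FPx format mostly comes down to what the quantize/dequantize pair can do, which is what the rest of this thread discusses.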

msaroufim added the enhancement (New feature or request) label on May 8, 2024
@gau-nernst
Collaborator

gau-nernst commented May 12, 2024

At some point we probably need to port float_quantize() from https://github.com/Tiiiger/QPyTorch (FP6-LLM uses it to do FP16->FP6). The main logic is handling rounding correctly (we cannot just "erase" the unwanted bits).

The code for that is actually quite simple (https://github.com/Tiiiger/QPyTorch/blob/f58bba72113e696099ef3e15e06cf421a06ff289/qtorch/quant/quant_cpu/quant_cpu.cpp#L267-L300). I tried it out locally and it seems we can implement it in pure PyTorch, and torch.compile() could potentially generate efficient CPU/GPU kernels for it (everything is an elementwise op).

  • update: a pure PyTorch impl is probably not a good idea, since PyTorch doesn't have uint32 ops, and bit-wise ops on int32 can be problematic (>> on negative values is implementation-defined in C++, usually with sign extension; << can even be undefined behavior).
  • CUDA has some bit-wise math to perform fp32->fp16 on the CPU side (see float2half() and __internal_float2half() in cuda_fp16.hpp). It can probably serve as inspiration for writing our own FP32/16 -> FPx conversion (together with qtorch).

(NOTE: float_quantize() from qtorch does not handle bit-packing. The output stays in the original dtype, so it is effectively fake quantization.)
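
For illustration, here is a minimal fake-quantization sketch that sidesteps integer bit manipulation entirely by going through torch.frexp/torch.ldexp. It only demonstrates round-to-nearest-even on the mantissa (the part the bit-level code has to get right) and ignores exponent saturation, subnormals, and NaN/inf, so it is not a drop-in replacement for float_quantize().

```python
import torch

def fake_quantize_mantissa(x: torch.Tensor, mbits: int) -> torch.Tensor:
    # Round FP32 values to `mbits` explicit mantissa bits with round-to-nearest-even,
    # keeping the result in FP32 (no bit-packing), in the spirit of fake quantization.
    # Sketch only: exponent clamping, subnormals, and NaN/inf handling are omitted.
    mantissa, exponent = torch.frexp(x)  # x = mantissa * 2**exponent, |mantissa| in [0.5, 1)
    scale = 2.0 ** (mbits + 1)           # keeps 1 implicit + mbits explicit mantissa bits
    mantissa = torch.round(mantissa * scale) / scale  # torch.round is round-half-to-even
    return torch.ldexp(mantissa, exponent)

x = torch.randn(8)
print(fake_quantize_mantissa(x, mbits=2))  # e.g. FP6 E3M2 has 2 explicit mantissa bits
```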

@vadimkantorov

(PyTorch now has uint32, I think, but still no ops for it; maybe the needed bit-wise ops could simply be enabled on uint32 (and the other unsigned dtypes).)
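
For concreteness, a quick demonstration of the int32 sign-extension behavior mentioned above: in PyTorch, >> (bitwise_right_shift) on signed integer tensors is an arithmetic shift, so extracting the FP32 sign bit needs an extra mask (sketch only).

```python
import torch

bits = torch.tensor([-1.0]).view(torch.int32)  # reinterpret the FP32 bits of -1.0 as int32
print(bits >> 31)                              # tensor([-1]): arithmetic shift sign-extends
print((bits >> 31) & 1)                        # tensor([1]): masking recovers just the sign bit
```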

@gau-nernst
Collaborator

@vkuzo has excellent FP32<->FPx conversion functions for the MX dtypes (only for FP4_E2M1, FP6_E3M2 and FP6_E2M3, but they should be straightforward to extend to other custom FPx formats). For quantization-aware training/fine-tuning this should be enough, since we probably don't need bit-packing. For inference, @vayuda is working on generic bit-packing, which should be useful (although personally I would prefer explicit bit-packing for each FPx dtype; see the sketch below). I'm not sure whether the current torch compiler can fuse FPx dequant into a Triton matmul kernel; if not, we probably need to write a custom fused dequant+matmul Triton kernel (which could be fun to do 😃)
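
As an example of what explicit per-dtype bit-packing could look like, here is a hypothetical sketch for a 4-bit format, packing two 4-bit codes per uint8 byte (the names and layout are made up for illustration, not an existing torchao API):

```python
import torch

def pack_4bit(codes: torch.Tensor) -> torch.Tensor:
    # codes: uint8 tensor of 4-bit codes (values 0-15), even number of elements
    assert codes.dtype == torch.uint8 and codes.numel() % 2 == 0
    codes = codes.reshape(-1, 2)
    return (codes[:, 0] << 4) | codes[:, 1]  # two codes per byte

def unpack_4bit(packed: torch.Tensor) -> torch.Tensor:
    # inverse of pack_4bit: split each byte back into its two 4-bit codes
    return torch.stack([(packed >> 4) & 0xF, packed & 0xF], dim=-1).reshape(-1)

codes = torch.randint(0, 16, (8,), dtype=torch.uint8)
assert torch.equal(unpack_4bit(pack_4bit(codes)), codes)
```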

@gau-nernst
Collaborator

Prototype FP8 quant using native PyTorch fp8 dtypes: https://github.com/gau-nernst/ao/blob/fp8wo/torchao/prototype/fp8wo/__init__.py
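
For readers who haven't opened the link, here is a hypothetical minimal sketch of per-channel FP8 E4M3 weight-only quantization in eager PyTorch (an assumed scheme for illustration; the actual prototype may differ in details):

```python
import torch

def quantize_fp8_weight(w: torch.Tensor):
    # Assumed scheme: scale each output channel so its absmax maps to 448,
    # the max normal value of float8_e4m3fn, then cast the weight to FP8.
    scales = w.abs().amax(dim=1, keepdim=True) / 448.0
    return (w / scales).to(torch.float8_e4m3fn), scales

def dequant_linear(x: torch.Tensor, w_fp8: torch.Tensor, scales: torch.Tensor):
    # Weight-only: dequantize back to the activation dtype, then do a normal matmul.
    return x @ (w_fp8.to(x.dtype) * scales).T

w = torch.randn(64, 32, dtype=torch.bfloat16)
x = torch.randn(4, 32, dtype=torch.bfloat16)
w_fp8, scales = quantize_fp8_weight(w)
print((dequant_linear(x, w_fp8, scales) - x @ w.T).abs().max())  # quantization error
```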

Llama2-7B-chat on a 4070 Ti SUPER. tokens/s is measured with torchao/_models/llama/generate.py. PyTorch 2.4.0.dev20240610+cu124.

| Quant type (weight only) | tokens/s | Bandwidth (GB/s) |
| --- | --- | --- |
| BF16 | 46.74 | 617.64 |
| INT8 | 88.22 | 584.05 |
| FP8 E4M3 FN | 79.33 | 1048.31 |
| FP8 E4M3 FNUZ | 78.73 (output is gibberish) | 1040.44 |
| FP8 E5M2 | 82.53 | 1090.65 |
| FP8 E5M2 FNUZ | Error (see below) | — |

Error with FP8 E5M2 FNUZ: Unsupported conversion from f16 to f8E5M2FNUZ with rounding mode rtne

The speed degradation compared to INT8 seems to be because torch.compile cannot fuse act_bf16 @ weight_fp8.to(torch.bfloat16) * scales into a single kernel (which is also why the measured memory bandwidth is so high).

import torch
import torch._inductor.config

# core dump if either of these flags is enabled
# torch._inductor.config.force_mixed_mm = True
# torch._inductor.config.use_mixed_mm = True

def f(a, b, s):
    # weight-only dequant (FP8 -> BF16) + matmul + per-channel scaling
    return torch.mm(a, b.to(a.dtype)) * s

fp16_act = torch.randn(1, 32).to(torch.bfloat16).cuda()        # activation (BF16)
fp8_weight = torch.randn(32, 32).to(torch.float8_e5m2).cuda()  # weight stored as FP8 E5M2
scales = torch.randn(32).to(torch.bfloat16).cuda()             # per-channel scales
torch.compile(f, mode="max-autotune", fullgraph=True)(fp16_act, fp8_weight, scales)

Codegen output: https://gist.github.com/gau-nernst/4afd0f5b97368ecf26d54b5f3415b004

When either the use_mixed_mm or the force_mixed_mm flag is set, I get a core dump (issue already opened at PyTorch core: pytorch/pytorch#128381):

loc("/tmp/torchinductor_thien/h3/ch3inellku5joxa2jz4iwhfnrcquf7fbmuq53uw4vr6kuoctvtzo.py":76:21): error:  size mismatch when packing elements for LLVM struct expected 4 but got 8
python: /root/.triton/llvm/llvm-6f44bb77-centos-x64/include/llvm/ADT/ArrayRef.h:257: const T& llvm::ArrayRef<T>::operator[](size_t) const [with T = mlir::Type; size_t = long unsigned int]: Assertion `Index < Length && "Invalid index!"' failed.

One reason FP8 quant is not so performant is probably that there is no optimized FP8->BF16 dtype conversion that can be fused by Triton (either at the Triton level or the torch.compile level; needs further investigation...)
