
Add weight tensor-wise scaling for INT8 quantized and mixed-precision training #1010

Open

gau-nernst opened this issue on Oct 4, 2024 · 1 comment
Labels: enhancement (New feature or request), good first issue (Good for newcomers)

Comments

@gau-nernst (Collaborator)

https://github.com/pytorch/ao/tree/main/torchao/prototype/quantized_training

Currently the INT8 training recipes only support row-wise scaling for the weight. This should be strictly better than (or at least on par with) tensor-wise scaling for the weight in terms of accuracy (see the comparison sketch after the list below). However, it causes some issues in the backward pass, especially with FSDP2 if we want to support INT8 all-gather (cc pytorch/torchtitan#578). Some pointers:

  • For pre-training, INT8 tensor-wise scaling for the weight "should" be OK. This is basically SwitchBack. BitNet uses 1.58-bit tensor-wise scaling and demonstrates good results.
  • For fine-tuning, it will be bad out-of-the-box (imagine INT8 tensor-wise scaling for PTQ). It "might" be OK after fine-tuning; this will need some experiments.
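
To make the accuracy trade-off concrete, here is a minimal sketch (plain PyTorch, not torchao code; the function names are illustrative) comparing the round-trip error of tensor-wise vs. row-wise symmetric INT8 quantization on a weight whose rows have varying magnitudes:

```python
import torch

def quantize_dequant_tensorwise(w: torch.Tensor) -> torch.Tensor:
    # one scale for the whole tensor
    scale = w.abs().amax() / 127
    w_int8 = (w / scale).round().clamp(-128, 127)
    return w_int8 * scale

def quantize_dequant_rowwise(w: torch.Tensor) -> torch.Tensor:
    # one scale per output row, broadcast over columns
    scale = w.abs().amax(dim=1, keepdim=True) / 127
    w_int8 = (w / scale).round().clamp(-128, 127)
    return w_int8 * scale

# rows with very different magnitudes: the case where a single scale hurts
w = torch.randn(4096, 4096) * torch.logspace(-2, 0, 4096).view(-1, 1)
err_tensor = (quantize_dequant_tensorwise(w) - w).abs().mean()
err_row = (quantize_dequant_rowwise(w) - w).abs().mean()
print(f"tensor-wise error: {err_tensor:.6f}, row-wise error: {err_row:.6f}")
```

Row-wise error is never worse than tensor-wise, since each row gets a scale at least as tight as the global one; the gap shrinks when row magnitudes are similar.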

Opening this issue to welcome new contributors. It shouldn't be too difficult, I think.

For context, here is the key difference between quantized training and mixed-precision training:

  • INT8 quantized training: only the INT8 weight is kept, no high-precision weight. Activations are not quantized. Stochastic rounding is used for the weight update.
  • INT8 mixed-precision training: the high-precision weight is kept. Weights (and activations) are dynamically quantized to INT8 to use the INT8 tensor cores.
    • For this new feature (INT8 tensor-wise scaling for the weight), I think activations should still use row-wise scaling, since there don't seem to be any benefits to using tensor-wise scaling for activations. See the sketch after this list.
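
As a rough illustration of that combination, here is a minimal sketch (assumptions only, not the torchao implementation) of a mixed-precision linear with row-wise activation scales and a single tensor-wise weight scale. The INT8 matmul is emulated in float for portability; a real kernel would use an INT8 tensor-core GEMM.

```python
import torch

def int8_mixed_precision_linear(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # activations: one scale per row (per token)
    x_scale = x.abs().amax(dim=-1, keepdim=True) / 127            # (M, 1)
    x_int8 = (x / x_scale).round().clamp(-128, 127).to(torch.int8)

    # weight: a single scalar scale for the whole tensor
    w_scale = w.abs().amax() / 127                                 # scalar tensor
    w_int8 = (w / w_scale).round().clamp(-128, 127).to(torch.int8)

    # INT8 matmul (emulated in float here); both scales factor out of the product
    acc = x_int8.float() @ w_int8.float().T                        # (M, N)
    return acc * (x_scale * w_scale)                               # broadcast both scales

x = torch.randn(8, 4096)
w = torch.randn(1024, 4096)
out = int8_mixed_precision_linear(x, w)
print((out - x @ w.T).abs().mean())  # quantization error vs full precision
```

The key point is that both scales factor out of the integer matmul, so they can be applied to the accumulator afterwards regardless of whether the weight scale is per-row or per-tensor.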
gau-nernst added the enhancement and good first issue labels on Oct 4, 2024
@vayuda (Collaborator) commented on Oct 9, 2024

So basically, for quantize_int8_rowwise we would pass in a quantization granularity that can be set to either row-wise or tensor-wise. In the tensor-wise case, even though the scale is just one float, storing it as a tensor lets it broadcast, so the rest of the functions wouldn't really need to change (besides also adding the granularity param to from_float()).

Seems easy to do, but I was wondering if the change is more involved.
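
A rough sketch of what that could look like (quantize_int8_rowwise and from_float are the real prototype names, but the function and signature below are hypothetical assumptions): the granularity argument selects the scale shape, and keeping the tensor-wise scale as a (1, 1) tensor means dequantization and everything downstream works through broadcasting without further changes.

```python
import torch

def quantize_int8(w: torch.Tensor, granularity: str = "row"):
    if granularity == "row":
        scale = w.abs().amax(dim=1, keepdim=True) / 127   # (out_features, 1)
    elif granularity == "tensor":
        scale = (w.abs().amax() / 127).reshape(1, 1)      # single scale, kept as a tensor
    else:
        raise ValueError(f"unsupported granularity: {granularity}")
    int_data = (w / scale).round().clamp(-128, 127).to(torch.int8)
    # dequant and all later ops rely only on broadcasting, so both cases share one code path
    return int_data, scale

w = torch.randn(256, 512)
int_data, scale = quantize_int8(w, granularity="tensor")
w_dq = int_data.float() * scale   # same expression as the row-wise case
```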
