Quantized Training #554

Open
msaroufim opened this issue Jul 29, 2024 · 4 comments
Labels
good first issue Good for newcomers

Comments

@msaroufim
Member

msaroufim commented Jul 29, 2024

Inspired by a recent back-and-forth with @gau-nernst, we should add some quantized training recipes in AO for small models (in the 600M-parameter range).

Character.ai recently shared that they're working on quantized training (https://research.character.ai/optimizing-inference/), and per @stephenroller they train models from scratch in int8: https://x.com/stephenroller/status/1816636257717436779

Historically we've invested more in QAT, which @andrewor14 has led. QAT is mostly a technique to reduce the perplexity degradation from an eventual post-training quantization.

Quantized training, on the other hand, actually quantizes the model at training time, so memory savings are observed for both training and inference.

So when discussing quantized training, there are a few aspects:

  1. Weights: can be one of fp16, fp8, int8, int4 and below
  2. Activations: most likely limited to fp8 or fp16
  3. Optimizer states: can be one of fp32, fp16, bf16, fp8, int8 and below
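
To make the memory side of these choices concrete, here is a back-of-the-envelope sketch at the ~600M-parameter scale. It assumes an Adam-style optimizer with two states per parameter and counts only weights and optimizer states; gradients, activations, quantization scales, and any high-precision master weights are deliberately ignored. The helper `weight_and_optim_bytes` is hypothetical.

```python
# Back-of-the-envelope memory for weights + optimizer states only.
# Assumption: an Adam-style optimizer with 2 states per parameter; gradients,
# activations, quantization scales, and high-precision master weights are ignored.
BYTES_PER_ELEMENT = {
    "fp32": 4.0,
    "bf16": 2.0,
    "fp16": 2.0,
    "fp8": 1.0,
    "int8": 1.0,
    "int4": 0.5,  # two 4-bit values packed per byte
}


def weight_and_optim_bytes(num_params: int, weight_dtype: str, optim_dtype: str,
                           num_optim_states: int = 2) -> float:
    """Hypothetical helper: bytes for the weights plus the optimizer states."""
    weights = num_params * BYTES_PER_ELEMENT[weight_dtype]
    states = num_params * num_optim_states * BYTES_PER_ELEMENT[optim_dtype]
    return weights + states


if __name__ == "__main__":
    n = 600_000_000  # the ~600M-parameter scale discussed above
    for w, o in [("bf16", "fp32"), ("int8", "fp32"), ("int8", "int8"), ("int4", "int8")]:
        gb = weight_and_optim_bytes(n, w, o) / 1e9
        print(f"weights={w:<5} optim={o:<5} -> {gb:.1f} GB")
```

Under these assumptions, BF16 weights with FP32 Adam states take about 6.0 GB, while INT8 weights with INT8 states take about 1.8 GB. Real savings will be smaller once the ignored tensors are counted.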

If one were to ship this work, a bad combination can be ruled out at small scale (~600M parameters), but a good idea needs to be continuously tested across the 8B to 405B range, so each of these combinations will need loss curves.

When choosing the starting point, we could either pretrain a model using quantized training or just finetune one; as long as the loss curves match the fp16 baselines, we are good. We'd of course also need to validate that the memory savings are there and measure the speedups/slowdowns.

And while we can merge a lot of the dtype conversion code in AO and have a toy training loop there, what I'm more optimistic about is having an end-to-end training recipe in https://github.com/pytorch/torchtitan (@awgu) and an end-to-end finetuning recipe in https://github.com/pytorch/torchtune (@ebsmothers @joecummings).

@msaroufim msaroufim added the good first issue Good for newcomers label Jul 29, 2024
@gau-nernst
Collaborator

gau-nernst commented Jul 29, 2024

Just want to add: there are also the activation/computation dtype and the gradient dtype. In my exploration, I still use BF16 for activations/computation and BF16 for gradients, to match weight-only quant inference. Activations/computation can also be in a lower-precision dtype, such as INT8 activation with INT8 weight to match dynamic quant inference, or FP8 activation with FP8 weight to match the current FP8 training recipe.

Lower-precision gradients might not be possible? Will need to check existing work on this.
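
A minimal sketch of the "INT8 weight, BF16 computation" setup described above, assuming per-output-channel symmetric absmax scales. Weights are stored as int8 plus one scale per row and dequantized before the matmul, so the computation itself stays in high precision (plain Python floats stand in for BF16). The function names are hypothetical, and only the forward pass is shown.

```python
# Sketch: INT8 weight-only linear. Weights are stored as int8 with one symmetric
# absmax scale per output channel and dequantized before the matmul, so the
# computation itself stays in high precision (plain Python floats stand in for BF16).
from typing import List, Tuple

Matrix = List[List[float]]


def quantize_weight_int8(w: Matrix) -> Tuple[List[List[int]], List[float]]:
    """Per-output-channel (row) symmetric quantization to the range [-127, 127]."""
    q, scales = [], []
    for row in w:
        amax = max(abs(v) for v in row)
        scale = amax / 127 if amax > 0 else 1.0
        q.append([max(-127, min(127, round(v / scale))) for v in row])
        scales.append(scale)
    return q, scales


def weight_only_linear(x: Matrix, wq: List[List[int]], scales: List[float]) -> Matrix:
    """y = x @ dequant(W).T, with the matmul done in high precision."""
    w = [[v * s for v in row] for row, s in zip(wq, scales)]  # dequantize
    return [[sum(a * b for a, b in zip(xrow, wrow)) for wrow in w] for xrow in x]


if __name__ == "__main__":
    w = [[0.5, -1.0], [2.0, 0.25]]
    x = [[1.0, 2.0]]
    wq, s = quantize_weight_int8(w)
    print(weight_only_linear(x, wq, s))  # close to the unquantized result [[-1.5, 2.5]]
```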

@gau-nernst
Collaborator

gau-nernst commented Aug 6, 2024

Some extra info for future reference

For evaluating the effectiveness of quantized training

  • For INT8 quantized training, it's better to compare against BF16 training from scratch / pre-training, since INT8 PTQ is already very good.
  • For fine-tuning workflow, compare INT4 quantized training against the current INT4 QAT recipe.

Digging into AQT INT8 (will update as I read more). Many things can be customized, but the basic config is:

  • Quantize both weight and activation to INT8 to use INT8 matmul. Channel-wise symmetric quantization so that the inner matmul can be INT8, and the scale is absorbed post-matmul.
  • Signed numbers are scaled to [-127, 127]. Unsigned numbers (e.g. after ReLU) are scaled to [0, 255]. Q: how does a UINT8 x INT8 matmul work?
  • Calculating scale (https://arxiv.org/abs/2105.03536):
    • For weights: absmax scaling.
    • For activations: train for the first 20% without activation quantization, tracking an EMA of the absmax. For the remaining 80%, apply activation quantization with that absmax EMA, and don't update the activation scale anymore. (Note: the paper was for ResNet-50. This might not be needed for transformers using LayerNorm/RMSNorm.)
  • Seems like the master weight is still in BF16 -> similar to QAT. The backward pass is quantized too -> similar to the current FP8 recipe.
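
The reason channel-wise scales can be "absorbed post-matmul" is that, for Y = X @ W.T, each output is y_ij = sum_k x_ik * w_jk ≈ s_x[i] * s_w[j] * sum_k qx_ik * qw_jk. The row scale of X depends only on i and the channel scale of W only on j, so both factor out of the sum over the inner dimension k. The sketch below shows this for the basic symmetric config. The helper names are hypothetical, and the unsigned [0, 255] path for post-ReLU activations is not covered.

```python
# Sketch of the basic AQT-style INT8 config: X quantized per row (token), W per
# output channel, both symmetric to [-127, 127]. The integer dot products are
# accumulated first (INT32 on real hardware) and the scales are applied afterwards.
from typing import List, Tuple

Matrix = List[List[float]]


def absmax_int8(vec: List[float]) -> Tuple[List[int], float]:
    """Symmetric absmax quantization of one row/channel."""
    amax = max(abs(v) for v in vec)
    scale = amax / 127 if amax > 0 else 1.0
    return [round(v / scale) for v in vec], scale


def int8_linear(x: Matrix, w: Matrix) -> Matrix:
    """Y = X @ W.T using integer dot products and post-matmul rescaling."""
    xq = [absmax_int8(row) for row in x]  # one scale per token (row of X)
    wq = [absmax_int8(row) for row in w]  # one scale per output channel (row of W)
    y = []
    for q_x, s_x in xq:
        y_row = []
        for q_w, s_w in wq:
            acc = sum(a * b for a, b in zip(q_x, q_w))  # pure integer accumulation
            y_row.append(acc * s_x * s_w)               # scales applied after the matmul
        y.append(y_row)
    return y


if __name__ == "__main__":
    x = [[1.0, -2.0, 0.5]]
    w = [[0.3, 0.1, -0.2], [1.0, 1.0, 1.0]]
    print(int8_linear(x, w))  # approximately [[0.0, -0.5]]
```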

@gau-nernst
Collaborator

gau-nernst commented Aug 16, 2024

Found this interesting paper - Jetfire. ICML 2024 poster spotlight. With code release and custom CUDA kernels 😮

INT8 for everything, including activations and gradients. Tile-wise quantization. Also uses 127 for scaling. Master weights in FP32.
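
To illustrate why tile-wise granularity helps, here is a sketch of symmetric INT8 quantization with one absmax scale per tile x tile block. The tile size, the test matrix, and the helper names are illustrative assumptions, not Jetfire's code. With a single outlier, a per-tensor scale (tile covering the whole matrix) rounds most small values to zero, while per-tile scales confine the outlier's effect to its own tile.

```python
# Sketch of tile-wise symmetric INT8 quantization (one absmax scale per
# tile x tile block). Tile size and helper names are illustrative assumptions.
from typing import Dict, List, Tuple

Matrix = List[List[float]]


def quantize_tilewise(x: Matrix, tile: int) -> Tuple[List[List[int]], Dict[Tuple[int, int], float]]:
    rows, cols = len(x), len(x[0])
    q = [[0] * cols for _ in range(rows)]
    scales: Dict[Tuple[int, int], float] = {}
    for r0 in range(0, rows, tile):
        for c0 in range(0, cols, tile):
            block = [(r, c) for r in range(r0, min(r0 + tile, rows))
                     for c in range(c0, min(c0 + tile, cols))]
            amax = max(abs(x[r][c]) for r, c in block)
            scale = amax / 127 if amax > 0 else 1.0
            scales[(r0 // tile, c0 // tile)] = scale
            for r, c in block:
                q[r][c] = round(x[r][c] / scale)
    return q, scales


def dequantize_tilewise(q: List[List[int]], scales: Dict[Tuple[int, int], float], tile: int) -> Matrix:
    return [[v * scales[(r // tile, c // tile)] for c, v in enumerate(row)]
            for r, row in enumerate(q)]


def max_abs_error(a: Matrix, b: Matrix) -> float:
    return max(abs(u - v) for ra, rb in zip(a, b) for u, v in zip(ra, rb))


if __name__ == "__main__":
    # One large outlier in the top-left tile, small values elsewhere.
    x = [[100.0, 0.0, 0.3, -0.2],
         [0.0, 0.0, 0.1, 0.4],
         [0.25, -0.1, 0.3, 0.2],
         [0.4, 0.15, -0.35, 0.05]]
    for tile in (4, 2):  # tile=4 covers the whole 4x4 matrix, i.e. per-tensor
        q, s = quantize_tilewise(x, tile)
        print(tile, max_abs_error(x, dequantize_tilewise(q, s, tile)))
```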

Jetfire also led me to an earlier paper: SwitchBack. Tim Dettmers is one of the authors 😆.

Dynamic quantization for everything (weight is still in high precision). Row-wise quant (i.e. batch dim - per token) for activation (forward) and grad output (backward). Tensor-wise quant for weight. INT8 matmul for forward (Y = X @ W.T) and input grad backward (X_grad = Y_grad @ W), while weight grad is FP16 matmul (W_grad = Y_grad.T @ X)
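
The three SwitchBack matmuls can be sketched as follows, assuming dynamic symmetric absmax INT8 with row-wise scales for X and Y_grad and a single tensor-wise scale for W. The function names are hypothetical and plain Python floats stand in for FP16. Because a row scale depends only on the output row and the tensor scale is constant, both can be applied after the integer matmul; the weight-gradient matmul is left in high precision, as described above.

```python
# Sketch of SwitchBack-style INT8 linear layer math (forward + backward):
#   forward:      Y      = X @ W.T        INT8, X row-wise (per token), W tensor-wise
#   input grad:   X_grad = Y_grad @ W     INT8, Y_grad row-wise, W tensor-wise
#   weight grad:  W_grad = Y_grad.T @ X   high precision (FP16 in the paper)
# Helper names are hypothetical; plain Python floats stand in for FP16.
from typing import List, Tuple

Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    bt = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in bt] for row in a]


def quant_rowwise(a: Matrix) -> Tuple[List[List[int]], List[float]]:
    """Dynamic symmetric INT8 with one absmax scale per row."""
    q, scales = [], []
    for row in a:
        amax = max(abs(v) for v in row)
        s = amax / 127 if amax > 0 else 1.0
        q.append([round(v / s) for v in row])
        scales.append(s)
    return q, scales


def quant_tensorwise(b: Matrix) -> Tuple[List[List[int]], float]:
    """Dynamic symmetric INT8 with a single absmax scale for the whole tensor."""
    amax = max(abs(v) for row in b for v in row)
    s = amax / 127 if amax > 0 else 1.0
    return [[round(v / s) for v in row] for row in b], s


def int8_mm(a: Matrix, b: Matrix) -> Matrix:
    """A @ B with A row-wise and B tensor-wise quantized; rescale after the integer matmul."""
    aq, sa = quant_rowwise(a)
    bq, sb = quant_tensorwise(b)
    acc = matmul(aq, bq)  # integer accumulation
    return [[v * sa[i] * sb for v in row] for i, row in enumerate(acc)]


def switchback_forward(x: Matrix, w: Matrix) -> Matrix:
    return int8_mm(x, transpose(w))


def switchback_backward(x: Matrix, w: Matrix, grad_y: Matrix) -> Tuple[Matrix, Matrix]:
    grad_x = int8_mm(grad_y, w)            # INT8 matmul
    grad_w = matmul(transpose(grad_y), x)  # kept in high precision
    return grad_x, grad_w
```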

@jerryzh168
Contributor

@gau-nernst thanks for the pointers. These feel like good motivations to enable training with AffineQuantizedTensor, since it is general enough to support all kinds of quantization granularity (per block, row-wise, per token), for both dynamic quant and weight-only quant. cc @andrewor14
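
One way to see how a single abstraction can cover all of these granularities: if granularity is expressed as a block size, then per-tensor, per-row/per-token, and per-tile quantization are just different block shapes. The sketch below is a hypothetical plain-Python illustration of affine (asymmetric, scale plus zero point) quantization parameterized that way; it does not reproduce torchao's AffineQuantizedTensor API.

```python
# Sketch: affine (asymmetric) quantization where granularity is just a block size.
#   block_size == (rows, cols) -> per-tensor
#   block_size == (1, cols)    -> per-row / per-token
#   block_size == (b, b)       -> per-block / per-tile
# Function names are hypothetical; this is not torchao's AffineQuantizedTensor API.
from typing import Dict, List, Tuple

Matrix = List[List[float]]
Params = Dict[Tuple[int, int], Tuple[float, int]]


def affine_quantize(x: Matrix, block_size: Tuple[int, int],
                    qmin: int = 0, qmax: int = 255) -> Tuple[List[List[int]], Params]:
    rows, cols = len(x), len(x[0])
    br, bc = block_size
    q = [[0] * cols for _ in range(rows)]
    params: Params = {}
    for r0 in range(0, rows, br):
        for c0 in range(0, cols, bc):
            idx = [(r, c) for r in range(r0, min(r0 + br, rows))
                   for c in range(c0, min(c0 + bc, cols))]
            lo = min(0.0, min(x[r][c] for r, c in idx))  # keep 0.0 exactly representable
            hi = max(0.0, max(x[r][c] for r, c in idx))
            scale = (hi - lo) / (qmax - qmin) if hi > lo else 1.0
            zero_point = max(qmin, min(qmax, qmin - round(lo / scale)))
            params[(r0 // br, c0 // bc)] = (scale, zero_point)
            for r, c in idx:
                q[r][c] = max(qmin, min(qmax, round(x[r][c] / scale) + zero_point))
    return q, params


def affine_dequantize(q: List[List[int]], params: Params, block_size: Tuple[int, int]) -> Matrix:
    br, bc = block_size
    out = []
    for r, row in enumerate(q):
        out_row = []
        for c, v in enumerate(row):
            scale, zero_point = params[(r // br, c // bc)]
            out_row.append((v - zero_point) * scale)
        out.append(out_row)
    return out
```

For a 2x4 input, block sizes (2, 4), (1, 4), and (1, 2) yield 1, 2, and 4 (scale, zero_point) pairs respectively, and the same quantize/dequantize code handles all three cases.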

yanbing-j pushed a commit to yanbing-j/ao that referenced this issue Dec 9, 2024
* split cpu eval CI by dtype

* fix

* differentiate names with checks

* keep one name the same as old

* fix
yanbing-j pushed a commit to yanbing-j/ao that referenced this issue Dec 9, 2024
* make --device fast the default

* Update iOS.md (pytorch#517)

* Update iOS.md

* Update iOS.md

* Pip to pip3 (pytorch#504)

* remove macos-12 test

* pip to pip3

* break aoti CI jobs separately (pytorch#500)

* init

* fixes

* more fixes

* fixes

* fix

* fix

* bug fix

* add objcopy update

* suppress int8

* undefined variable

---------

Co-authored-by: Michael Gschwind <mikekg@meta.com>

* Support llama3 in chat in run.cpp  (pytorch#486)

* refactor chat runner in preparation for llama3

* add sketch for llama3 prompt template and move to returning tokens

* fix tiktoken

* fixes to chat

* add default llama_ver

* Add tests for quantize json, add cuda device specification and precision to cuda.json (pytorch#519)

* remove code for no KV Cache path (pytorch#527)

* Update ADVANCED-USERS.md (pytorch#529)

Update Advanced Users description to reflect changes in the repo since the description was initially created.

* runner-aoti on cuda (pytorch#531)

* runner-aoti on cuda

* transfer results back to CPU

* transfer results back to CPU

* runner-aoti on cuda

* Update runner_build.md (pytorch#530)

Update description of runner and build process in runner_build.md

* clean up runner code a little (pytorch#532)

* clean up runner code a little

* update

* update

* pull out generate loop in chat

* updates

* edit docs

* typo

* move int8 linear class and function into qops.py (pytorch#534)

* add dtype tests for runner-aoti + runner-et (pytorch#539)

* add dtype tests for runner-aoti + runner-et

* typo

* Quantized embedding (pytorch#536)

* move int8 linear class and function into qops.py

* move Quantized Embedding to qops.py

* Move Linear int4 to qops (pytorch#537)

* move int8 linear class and function into qops.py

* move Quantized Embedding to qops.py

* move int4 linear to qops

* Revert "add dtype tests for runner-aoti + runner-et (pytorch#539)" (pytorch#548)

This reverts commit a7a24577a65be67ac9ae4dc05452f35d9c49e5d1.

* fix generate for llama3 (pytorch#538)

* fix generate for llama3

* switch more things to C

* remove C++ header

* add delegation visualization instructions (pytorch#551)

* Add dtype runner aoti (pytorch#552)

* add dtype tests for runner-aoti + runner-et

* typo

* add dtype test runner-aoti

* test sdpa with fp16 (pytorch#553)

* test sdpa with fp16

* kv cache fp32

* typo

* update (pytorch#560)

* Only support newest versions of lm-eval (pytorch#556)

Summary:
remove support for lm-eval 0.3 to reduce the options we have

Test Plan:
CI

Reviewers:

Subscribers:

Tasks:

Tags:

* split cpu eval CI by dtype (pytorch#554)

* split cpu eval CI by dtype

* fix

* differentiate names with checks

* keep one name the same as old

* fix

* Removing duplicate HF issue message from README (pytorch#559)

Co-authored-by: Michael Gschwind <61328285+mikekgfb@users.noreply.github.com>

* doc updates (pytorch#567)

* Add VM-safe MPS check

---------

Co-authored-by: Anthony Shoumikhin <anthony@shoumikh.in>
Co-authored-by: metascroy <161522778+metascroy@users.noreply.github.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Co-authored-by: lucylq <lfq@meta.com>
Co-authored-by: Jerry Zhang <jerryzh168@gmail.com>
Co-authored-by: Jack-Khuu <jack.khuu.7@gmail.com>
yanbing-j pushed a commit to yanbing-j/ao that referenced this issue Dec 9, 2024
* code beautification

* code beautification, move functions together

* make --device fast the default (pytorch#515)

* add unpacking support (pytorch#525)

* add unpacking support

* fix typos and linter

* perform parallel prefill when possible (pytorch#568)

* perform parallel prefill when possible

* typo

* disable hack

* remove print

* remove debug messages which prevent export

* fixes

* stream results in generate.py (#571)

* remove logging interfering with export

---------

Co-authored-by: Anthony Shoumikhin <anthony@shoumikh.in>
Co-authored-by: metascroy <161522778+metascroy@users.noreply.github.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Co-authored-by: lucylq <lfq@meta.com>
Co-authored-by: Jerry Zhang <jerryzh168@gmail.com>
Co-authored-by: Jack-Khuu <jack.khuu.7@gmail.com>

3 participants