Support INT8 mixed-precision training from torchao? #578

Open
gau-nernst opened this issue Sep 14, 2024 · 8 comments
Labels
enhancement New feature or request

Comments

@gau-nernst

Recently I worked on INT8 mixed-precision training in torchao. The relevant PR is pytorch/ao#748.

Preliminary results show that with torchtitan, it improves speed by 20% on 8x A100 with no noticeable difference in loss curve. See the PR for more details.

Would you be open to adding an experimental flag for this in torchtitan, similar to Float8 training? This would also make it possible to profile and improve INT8 training performance directly in torchtitan.

cc @msaroufim

@tianyu-l
Contributor

cc: @weifengpy

@tianyu-l added the enhancement (New feature or request) label Sep 30, 2024
@weifengpy
Contributor

@gau-nernst nice work! I took a look at the original torchao PR

  • did you land it as a torchtune full finetuning recipe? I am mostly curious about gaps for checkpointing if any
  • if performing rowwise scaling, how did you deal with backward in FSDP2? asking because rowwise becomes column-wise in the backward.
  • is perf the main reason to adopt int8? the PR mentioned memory savings. Is it because of int8 casting in fsdp hooks? otherwise it's bfloat16 casting

@gau-nernst
Author

@weifengpy

  1. There is an ongoing PR: "Integrate INT8 mixed-precision from torchao 0.7" (torchtune#1552)

I am mostly curious about gaps for checkpointing if any

What do you mean by this?

  2. Right now all-gather is still in BF16, so quantization is done after all-gather: https://github.com/pytorch/ao/blob/ae3e7c68eae7085e13241cb3d6b39481868dd162/torchao/prototype/quantized_training/int8_mixed_precision.py#L117-L126

Yea this ("rowwise became column wise in the backward") is the main problem preventing me from implementing INT8 all-gather.

  • If FSDP2 provides a way to customize all-gather to behave differently in forward and backward pass, at least we can do INT8 all-gather in forward, and fallback to BF16 all-gather for backward. But this is not possible for now (not a priority).
  • The way I'm doing right now can be seen as approximate matmul with INT8 matmul + row-wise + column-wise scaling. Notice that in backward pass, I use the original weight for column-wise scaling. But actually, I think it also makes sense to use row-wise quantized weight (from forward) for column-wise scaling in backward (i.e. dequant and re-quant). Will need some experiments to validate this.
  3. Yes, speed is the main reason. The original PR description includes the revamped README for training with INT8 (which includes INT8 weight-only quantized training i.e. only keep INT8 weight), hence memory is mentioned there.
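To make the "INT8 matmul + row-wise + column-wise scaling" idea concrete, here is a minimal CPU-only sketch. It is not the torchao implementation: the helper names `quantize_int8` and `int8_mm_approx` are hypothetical, and the INT8 matmul is emulated with INT32 tensors instead of a real INT8 kernel. It also shows why the weight needs a different scaling axis in backward: the reduction dimension changes from in_features (forward) to out_features (backward).

```python
import torch

def quantize_int8(x, dim):
    """Symmetric absmax INT8 quantization, one scale per slice reduced over `dim`."""
    scale = x.abs().amax(dim=dim, keepdim=True).clamp(min=1e-12) / 127.0
    x_i8 = (x / scale).round().clamp(-128, 127).to(torch.int8)
    return x_i8, scale

def int8_mm_approx(a, b):
    """Approximate a @ b with row-wise scales for `a` and column-wise scales for `b`."""
    a_i8, a_scale = quantize_int8(a, dim=1)  # scale shape (M, 1)
    b_i8, b_scale = quantize_int8(b, dim=0)  # scale shape (1, N)
    # Emulate an INT8 matmul with INT32 accumulation (a real kernel would run on GPU).
    acc = a_i8.to(torch.int32) @ b_i8.to(torch.int32)
    return acc.to(torch.float32) * a_scale * b_scale

def rel_err(approx, ref):
    return ((approx - ref).norm() / ref.norm()).item()

torch.manual_seed(0)
x = torch.randn(16, 64)       # (tokens, in_features)
w = torch.randn(32, 64)       # (out_features, in_features)
grad_y = torch.randn(16, 32)  # (tokens, out_features)

# Forward: y = x @ w.T reduces over in_features, so w is scaled per row
# (each row of w is a column of w.T).
y_int8 = int8_mm_approx(x, w.T)
fwd_err = rel_err(y_int8, x @ w.T)

# Backward: grad_x = grad_y @ w reduces over out_features, so the same
# weight now needs one scale per column.
grad_x_int8 = int8_mm_approx(grad_y, w)
bwd_err = rel_err(grad_x_int8, grad_y @ w)

print(f"forward rel. error:  {fwd_err:.4f}")
print(f"backward rel. error: {bwd_err:.4f}")
```

Because the two passes need different scales for the same weight, a single INT8 copy produced before all-gather cannot serve both passes exactly, which is the obstacle described above.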

Some extra thoughts.

  • For pre-training, it might be possible to do INT8 tensor-wise scaling. I was working on BitNet ("BitNet b1.58 training", ao#930), which performs 1.58-bit tensor-wise scaling, and initial results look quite promising. SwitchBack also trains CLIP with INT8 tensor-wise scaling.
  • However, for fine-tuning, I don't think INT8 tensor-wise scaling is suitable. Tried it a bit with PTQ and the accuracy drop is quite significant (perhaps can be recovered after some finetuning). The current INT8 mixed-precision training (i.e. INT8 matmul + row-wise + column-wise scaling) works well for fine-tuning (see torchtune PR), likely because the approximate matmul is pretty good.
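One possible reason tensor-wise scaling hurts a pre-trained weight more than row-wise scaling is outliers: a single large row sets the scale for the whole tensor. The sketch below illustrates that effect only. The outlier row is synthetic, `quant_dequant` is a hypothetical helper, and this is not the PTQ experiment mentioned above.

```python
import torch

def quant_dequant(x, dim=None):
    """INT8 absmax quantize then dequantize. dim=None means one scale per tensor."""
    amax = x.abs().amax() if dim is None else x.abs().amax(dim=dim, keepdim=True)
    scale = amax.clamp(min=1e-12) / 127.0
    return (x / scale).round().clamp(-128, 127) * scale

torch.manual_seed(0)
w = torch.randn(64, 128)
w[0] *= 50.0  # synthetic outlier row (an assumption, for illustration only)

# Measure the error on the ordinary rows only.
normal = w[1:]
err_tensor = ((quant_dequant(w) - w)[1:].norm() / normal.norm()).item()
err_row = ((quant_dequant(w, dim=1) - w)[1:].norm() / normal.norm()).item()

print(f"tensor-wise rel. error on normal rows: {err_tensor:.4f}")
print(f"row-wise rel. error on normal rows:    {err_row:.4f}")
```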

@weifengpy
Contributor

weifengpy commented Oct 2, 2024

thanks for explaining everything in detail

What do you mean by this?

I thought model.parameters() might be some INT8 related tensor subclass

  • when saving state dict, whether we convert the sharded tensor subclass to higher precision
  • when loading state dict, whether we quantize on full tensor first, and chunk/shard after
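For the "quantize on full tensor first, then shard" question, the order matters only for some scaling granularities. The following single-process sketch simulates dim-0 sharding (the FSDP2 default) with `Tensor.chunk`. The helper `quantize_int8` is hypothetical, and the `max` over local values stands in for an all-reduce with MAX across ranks.

```python
import torch

def quantize_int8(x, dim=None):
    """INT8 absmax quantization. dim=None means a single tensor-wise scale."""
    amax = x.abs().amax() if dim is None else x.abs().amax(dim=dim, keepdim=True)
    scale = amax.clamp(min=1e-12) / 127.0
    return (x / scale).round().clamp(-128, 127).to(torch.int8), scale

torch.manual_seed(0)
w = torch.randn(8, 16)
world_size = 4
shards = w.chunk(world_size, dim=0)  # simulated dim-0 shards

# Row-wise scaling: every shard holds complete rows, so quantizing the full
# tensor and then sharding matches sharding first and quantizing locally.
full_i8, _ = quantize_int8(w, dim=1)
local_i8 = torch.cat([quantize_int8(s, dim=1)[0] for s in shards])
rowwise_same = torch.equal(full_i8, local_i8)

# Tensor-wise scaling: each shard sees only its local absmax, so the ranks
# must agree on a global absmax before quantizing.
local_amax = torch.stack([s.abs().amax() for s in shards])
global_amax = local_amax.max()  # stand-in for all_reduce(op=MAX)
amax_matches_full = torch.equal(global_amax, w.abs().amax())
locals_disagree = not bool(torch.all(local_amax == global_amax))

print(rowwise_same, amax_matches_full, locals_disagree)
```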

Right now all-gather is still in BF16, so quantization will be done after all-gather

yeah. INT8 all-gather might be the main justification to land into torchtitan, since this repo is used to demonstrate composability with distributed api

for rowwise, if backward is too hard, are you comfortable with supporting INT8 all-gather with fully_shard(reshard_after_forward=False) ? In that case, we do not have all-gather in the backward

For pre-training, it might be possible to do INT8 tensor-wise scaling

if the numerics do not become too bad with tensor-wise scaling, it's a great demonstration for INT8 all-gather

@gau-nernst
Author

I thought model.parameters() might be some INT8 related tensor subclass

Oh yea right now I don't have any special logic with it. So that state_dict will be a tensor subclass wrapper holding the original high precision weight (NOT int8). For INT8 mixed-precision training, I only inject custom matmul logic, weights stay the same (same as FP8 training).
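As a rough illustration of "only inject custom matmul logic, weights stay the same", the sketch below wraps an emulated INT8 matmul in a `torch.autograd.Function` while the weight remains an ordinary FP32 tensor. The class and helper names are hypothetical, and the choice to quantize all three matmuls is an assumption; the torchao implementation uses a tensor subclass and real INT8 kernels.

```python
import torch

def quantize_int8(x, dim):
    scale = x.abs().amax(dim=dim, keepdim=True).clamp(min=1e-12) / 127.0
    return (x / scale).round().clamp(-128, 127).to(torch.int8), scale

def int8_mm_approx(a, b):
    """Approximate a @ b: row-wise scales for `a`, column-wise scales for `b`."""
    a_i8, a_scale = quantize_int8(a, dim=1)
    b_i8, b_scale = quantize_int8(b, dim=0)
    acc = a_i8.to(torch.int32) @ b_i8.to(torch.int32)  # emulated INT8 matmul
    return acc.to(torch.float32) * a_scale * b_scale

class Int8LinearFn(torch.autograd.Function):
    """Hypothetical linear op: high-precision tensors in and out, INT8 math inside."""

    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)  # the original weight is saved, not INT8
        return int8_mm_approx(x, weight.T)

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        grad_x = int8_mm_approx(grad_out, weight)  # column-wise scales on weight
        grad_w = int8_mm_approx(grad_out.T, x)
        return grad_x, grad_w

torch.manual_seed(0)
x = torch.randn(16, 64, requires_grad=True)
w = torch.randn(32, 64, requires_grad=True)
upstream = torch.randn(16, 32)

(Int8LinearFn.apply(x, w) * upstream).sum().backward()
gx_int8, gw_int8 = x.grad.clone(), w.grad.clone()

# Reference gradients from the plain FP32 matmul.
x.grad, w.grad = None, None
((x @ w.T) * upstream).sum().backward()
gx_err = ((gx_int8 - x.grad).norm() / x.grad.norm()).item()
gw_err = ((gw_int8 - w.grad).norm() / w.grad.norm()).item()
print(w.dtype, f"{gx_err:.4f}", f"{gw_err:.4f}")
```

Since the parameter itself never changes dtype, a state dict built from it contains high-precision values, which matches the checkpointing behavior described above.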

supporting INT8 all-gather with fully_shard(reshard_after_forward=False) ? In that case, we do not have all-gather in the backward

Does that mean INT8 post-all-gathered weights remain in memory starting from forward until backward? If that's the case, we can just do what I suggested earlier?

use row-wise quantized weight (from forward) for column-wise scaling in backward (i.e. dequant and re-quant)

More concretely:

| Pass | Original | Suggested change |
| --- | --- | --- |
| Forward | FP32 weight -> all-gather -> row-wise quantize to INT8 | FP32 weight -> row-wise quantize to INT8 -> all-gather |
| Backward | FP32 weight -> all-gather -> column-wise quantize to INT8 | FP32 weight -> row-wise quantize to INT8 -> all-gather -> dequant -> column-wise quantize to INT8 |

In other words, it differs in which version of the weight will be used for column-wise quantization in backward: whether to use the original weight, or use the row-wise quantized weight used in forward.
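The two backward variants can be compared numerically on a single process. This is a sketch under stated assumptions: random Gaussian data, hypothetical helper names, and an INT8 matmul emulated with INT32 tensors. It measures how far each variant's grad_input is from the FP32 reference and does not replace the training experiments mentioned above.

```python
import torch

def quantize_int8(x, dim):
    scale = x.abs().amax(dim=dim, keepdim=True).clamp(min=1e-12) / 127.0
    return (x / scale).round().clamp(-128, 127).to(torch.int8), scale

def dequantize(x_i8, scale):
    return x_i8.to(torch.float32) * scale

def scaled_int8_mm(a_i8, a_scale, b_i8, b_scale):
    acc = a_i8.to(torch.int32) @ b_i8.to(torch.int32)  # emulated INT8 matmul
    return acc.to(torch.float32) * a_scale * b_scale

def rel_err(approx, ref):
    return ((approx - ref).norm() / ref.norm()).item()

torch.manual_seed(0)
grad_y = torch.randn(16, 32)  # (tokens, out_features)
w = torch.randn(32, 64)       # (out_features, in_features)
grad_x_ref = grad_y @ w

g_i8, g_scale = quantize_int8(grad_y, dim=1)  # row-wise scales for grad_y

# Original: column-wise quantize the original (all-gathered) weight.
w_col_i8, w_col_scale = quantize_int8(w, dim=0)
err_orig = rel_err(scaled_int8_mm(g_i8, g_scale, w_col_i8, w_col_scale), grad_x_ref)

# Suggested: reuse the row-wise INT8 weight from forward (what an INT8
# all-gather would deliver), dequantize it, then re-quantize column-wise.
w_row_i8, w_row_scale = quantize_int8(w, dim=1)
w_re_i8, w_re_scale = quantize_int8(dequantize(w_row_i8, w_row_scale), dim=0)
err_sugg = rel_err(scaled_int8_mm(g_i8, g_scale, w_re_i8, w_re_scale), grad_x_ref)

print(f"original  (quantize FP32 weight): {err_orig:.4f}")
print(f"suggested (dequant + re-quant):   {err_sugg:.4f}")
```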

Otherwise, to just demonstrate INT8 all-gather, I think it is easier (and save efforts) to do INT8 tensor-wise scaling 🤣.

@weifengpy
Contributor

Otherwise, to just demonstrate INT8 all-gather, I think it is easier (and save efforts) to do INT8 tensor-wise scaling

agree, having tensor-wise scaling is already a good thing. I will bring this topic for discussion with the team

@vkuzo
Contributor

vkuzo commented Oct 3, 2024

I think long term it's great to unify training APIs in torchao, to enable torchtitan to work with float8/int8/mx/etc training in the same way. I'm working on this, no ETA yet.

Short term if someone wants to add int8 training to torchtitan as an experimental feature - SGTM personally, but I'll also defer to torchtitan folks on that.

@weifengpy
Contributor

I will bring this topic for discussion with the team

we would love to have this feature after discussion. we can start with tensor-wise scaling. it's also consistent with our float8 offering

@mori360
