-
Notifications
You must be signed in to change notification settings - Fork 198
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
BitNet b1.58 training #930
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/930
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 77d5c0d with merge base 68e1886 (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
elif args.quantize == "int8_mixed_precision": | ||
quantize_(model.layers, int8_mixed_precision_training(), set_inductor_config=False) | ||
|
||
elif args.quantize == "bitnet": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
optional: this is "change model architecture and then do quantization" which is pretty different from just "quantization". For code clarity, maybe we can either have an explicit preprocessing step to be called separately, or call the arg something like rmsnorm_model_surgery_then_quantize_bitnet
?
|
||
def _pack_i2_in_i8(x: Tensor): | ||
# NOTE: this is signed integer, so we have to mask before bit-shift | ||
return (x[:, ::4] << 6) | ((x[:, 1::4] & 0b11) << 4) | ((x[:, 2::4] & 0b11) << 2) | (x[:, 3::4] & 0b11) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
readability nit: write it out line by line with comments to make easier to understand for code readers?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice!
LGTM for prototype but feel free to wait for other reviews if needed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Excellent work as usual! Feel free to merge this whenever you're ready @gau-nernst
* first upstream of BitNet * fix type annotation * skip bitnet test on cpu. add bitnet to benchmark script * add bitnet option to example training script. update backward * add FSDP2 test * remove FSDP2 mixed-precision workaround. cleanup test * fix typo * adjust tolerance * update command * add precompute scale for FSDP2 * fix typing * add test for precompute scale * rename * separate BitNet model surgery * minor fixes. add note on packing
This PR adds training code for BitNet b1.58 (ternary weights - 1.58 bit. The first version of BitNet is binary weights). This is implemented as tensor subclass and integrate nicely with the
quantize_()
API. I also added 2 extra optimizations:Not optimized for inference (yet). A good baseline for inference would be something like A8W2 kernel from GemLite
BitNet b1.58
BitNet b1.58 uses ternary weights: each parameter can only take on 3 distinct values {-1, 0, +1}, thus making a BitNet model very compact. BitNet uses tensor-wise abs-mean scaling for weights (quantize to ternary) and row-wise abs-max scaling for activations (quantize to INT8).
BitNet is originally trained with QAT: the weights and activations are fake-quantized, and straight-through estimator (STE) is used to calculate gradients with respect to floating point weights. This process adds extra overhead over standard training. Our implementation utilizes INT8 Tensor Cores to make up for this loss in speed. In fact, our implementation is faster than BF16 training in most cases.
Usage
Note: following the BitNet Training Tips, Code and FAQ, user should insert extra RMSNorm before each
nn.Linear
layers and also remove the original RMSNorm before attention and MLP modules. Callingquantize_(model, bitnet_training())
will NOT perform this for you. You can take a look at our example training scriptbenchmarks/quantized_training/pretrain_llama2.py
on how to do this for our Llama model.When used with FSDP2 training, you can pre-compute BitNet weight scales for the next iteration to synchronize all scales with a single all-reduce operation. This should be done after optimizer step.
Results
Convergence check Ran with my experimental repo https://github.com/gau-nernst/quantized-training. Llama2-1.1B (based on TinyLlama) for 1B tokens on FineWeb-Edu using 1x 4090. Baseline is full BF16 training. Each step is 4x8192 tokens.
Note: at this scale, we don't expect the loss curves to have a gap. According to Figure 2 of FAQ, the gap only appears at around 10B tokens.
Sanity benchmark with built-in training script Using
benchmarks/quantized_training/pretrain_llama2.py
. Llama2-1B on TinyStories, w/ 4090, 1k steps. Each step is 16x2048 tokens. PyTorch 2.4.0. Full BF16 trainingThe train loss is a bit strange, but I think training on TinyStories is not so reliable. Perhaps it's just numerics. Side note that the speedup is impressive is because INT8 tensor cores is very fast on 4090 (up to 3.5x faster than BF16 tensor cores).
FSDP2 benchmark w/ torchtitan Using https://github.com/gau-nernst/torchtitan/tree/bitnet. Llama3-8B on C4, default config, 4x A100.
torch==2.6.0.dev20240924
Note: due to the way torchtitan initialize weights, it's a bit troublesome to add extra RMSNorm layers as recommended by the paper. Thus, for benchmarks in torchtitan, I don't add extra RMSNorm.