Add experimental INT8 quantized training #644
Merged
Commits (45)
All commits are by gau-nernst:

3d42329 initial commit
eca170a add tests
dd162a8 add training
b286f5d support py3.9
8a84aca skip test for torch<2.3
ea47c7d fix pytest
f20486b fix adamw
3415244 add some FSDP ops
5d0e658 add more fsdp ops
d753476 more ops
9c77800 add benchmark script
158eb61 some organisation
db0290f add FSDP test
1c32b78 clean up
ff69121 update FSDP test
45342ba add compile test (things are crashing)
f1587a2 fix bias
7f9102a substantial update to tests
0428330 fix compile for FSDP
001422c update readme. rename file
2eb2787 speed up CI
d39caba fix typo
de6aa25 fix typo
adbe47d typos. unset some dynamo flags
3fdf776 update readme
ea0ee4f remove requires_grad, since it is unnecessary
36d0e1a remove note
2360a97 Merge branch 'pytorch:main' into qt_int8
9e19104 Merge branch 'main' into qt_int8
6bc7621 don't set inductor flags
6646c0b rename
00e25cf update README
927a6d1 rename optimizer
8377707 Merge branch 'main' into qt_int8
de49e8b update benchmark script
f80ac97 make compile explicit
e375c3d update docs
6396a95 Merge branch 'main' into qt_int8
662c61f use torch.optim.Adam to avoid FSDP optim compile bug
cc90298 update docs
f1c588b update doc
f444fa6 update docs
640ec2d fix CI test
dad6560 skip test
4924e8d fix compiled test
benchmarks/quantized_training/pretrain_llama2.py (new file, 150 lines)
# pre-train a mini Llama2 on TinyStories with INT8 quantized training
# pip install transformers sentencepiece wandb
#
# BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile
# INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_weight_only

import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import argparse
from pathlib import Path

import numpy as np
import torch
import wandb
from tqdm import tqdm
from transformers import LlamaConfig, LlamaForCausalLM

from torchao.prototype import low_bit_optim
from torchao.prototype.quantized_training import int8_weight_only_quantized_training
from torchao.quantization.quant_api import quantize_


def get_loss(model: LlamaForCausalLM, batch: torch.Tensor):
    return model(batch, labels=batch).loss


def get_tinystories():
    save_path = Path("tinystories.bin")

    if not save_path.exists():
        import sentencepiece as spm
        from huggingface_hub import hf_hub_download

        tokenizer_path = hf_hub_download("meta-llama/Llama-2-7b", "tokenizer.model")
        tokenizer = spm.SentencePieceProcessor(tokenizer_path)
        assert tokenizer.vocab_size() < (1 << 16)  # make sure we can use uint16

        # do everything in memory. we have enough RAM
        filepath = hf_hub_download(
            "roneneldan/TinyStories",
            "TinyStoriesV2-GPT4-train.txt",
            repo_type="dataset",
        )
        stories = open(filepath).read().split("\n<|endoftext|>\n")

        tokens_list = []
        chunk_size = 10_000
        for i in tqdm(range(0, len(stories), chunk_size), desc="Tokenizing TinyStories"):
            chunk = stories[i : min(i + chunk_size, len(stories))]
            tokens_list.extend(tokenizer.Encode(chunk, add_bos=True, add_eos=True, num_threads=4))

        total_size = sum(len(x) for x in tokens_list)
        mmap_tokens = np.memmap(save_path, dtype=np.uint16, mode="w+", shape=total_size)
        i = 0
        for tokens in tokens_list:
            mmap_tokens[i : i + len(tokens)] = tokens
            i += len(tokens)
        mmap_tokens.flush()

    tokens = np.memmap(save_path, dtype=np.uint16, mode="r")
    return torch.from_numpy(tokens)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # default config is 470M
    parser.add_argument("--d_model", type=int, default=1024)
    parser.add_argument("--depth", type=int, default=24)
    parser.add_argument("--ffn_size", type=int, default=4096)
    parser.add_argument("--head_dim", type=int, default=64)

    parser.add_argument("--quantize")
    parser.add_argument("--activation_checkpointing", action="store_true")
    parser.add_argument("--compile", action="store_true")

    parser.add_argument("--n_steps", type=int, default=1000)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--seq_len", type=int, default=2048)

    parser.add_argument("--optim", default="AdamW")
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--weight_decay", type=float, default=1e-2)

    parser.add_argument("--project", default="int8_quantized_training")
    parser.add_argument("--run_name")
    parser.add_argument("--seed", type=int)
    args = parser.parse_args()

    if args.seed is not None:
        torch.manual_seed(args.seed)

    config = LlamaConfig(
        hidden_size=args.d_model,
        intermediate_size=args.ffn_size,
        num_hidden_layers=args.depth,
        num_attention_heads=args.d_model // args.head_dim,
        max_position_embeddings=args.seq_len,
        use_cache=False,
    )
    model = LlamaForCausalLM(config).bfloat16().cuda()
    if args.activation_checkpointing:
        model.gradient_checkpointing_enable()
    if args.quantize == "int8_weight_only":
        quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False)
    elif args.quantize is not None:
        raise ValueError(f"Unsupported quantize={args.quantize}")
    print(f"No. of params: {sum(p.numel() for p in model.parameters()):,}")
    print(f"No. of buffers: {sum(p.numel() for p in model.buffers()):,}")

    # only use optimizers from torchao.prototype.low_bit_optim to support quantized training
    if args.optim == "AdamW":
        args.optim = "_AdamW"
    optim = getattr(low_bit_optim, args.optim)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    data = get_tinystories().cuda()
    run = wandb.init(dir="/tmp", config=args, project=args.project, name=args.run_name)

    step = 0
    log_interval = 50
    pbar = tqdm(total=args.n_steps, dynamic_ncols=True)
    model.train()
    _get_loss = torch.compile(get_loss) if args.compile else get_loss

    while step < args.n_steps:
        # randomly select a continuous chunk, then reshape it
        idx = torch.randint(0, data.shape[0] - args.batch_size * args.seq_len, (1,)).item()
        batch = data[idx : idx + args.batch_size * args.seq_len].view(args.batch_size, args.seq_len).long()

        loss = _get_loss(model, batch)
        loss.backward()

        if step % log_interval == 0:
            log_dict = dict(
                loss=loss.item(),
                lr=optim.param_groups[0]["lr"],
                max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9,
                max_memory_active=torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1e9,
            )
            run.log(log_dict, step=step)
            pbar.set_postfix(loss=log_dict["loss"])

        optim.step()
        optim.zero_grad()

        step += 1
        pbar.update()

    run.finish()
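
For readers skimming the diff, the core pattern the benchmark above exercises is small: quantize an existing bf16 model in place, then train it with an optimizer from torchao.prototype.low_bit_optim. A minimal sketch under those assumptions follows; the toy model, dummy loss, and hyperparameters are illustrative only and not part of the PR.

import torch
from torch import nn

from torchao.prototype.low_bit_optim import _AdamW
from torchao.prototype.quantized_training import int8_weight_only_quantized_training
from torchao.quantization.quant_api import quantize_

# any module built from nn.Linear layers; this tiny MLP is only for illustration
model = nn.Sequential(nn.Linear(64, 128), nn.GELU(), nn.Linear(128, 64)).bfloat16().cuda()

# swap the nn.Linear weights for INT8 quantized-training tensor subclasses, in place
quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False)

# use an optimizer from torchao.prototype.low_bit_optim so the quantized weights are updated correctly
optim = _AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)

for _ in range(10):
    x = torch.randn(4, 64, device="cuda", dtype=torch.bfloat16)
    loss = model(x).float().pow(2).mean()  # dummy loss, stands in for the language-modeling loss
    loss.backward()
    optim.step()
    optim.zero_grad()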
New test file for INT8 quantized training (225 lines)
import copy

import pytest
import torch
import torch.nn.functional as F
from torch import nn
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import TestCase, instantiate_parametrized_tests, parametrize, run_tests

from torchao.prototype.low_bit_optim import _AdamW
from torchao.prototype.quantized_training import Int8QTLinearWeight, int8_weight_only_quantized_training
from torchao.quantization.quant_api import quantize_
from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4

if not TORCH_VERSION_AFTER_2_3:
    pytest.skip("Requires torch>=2.4", allow_module_level=True)


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


def _reset():
    # using TF32 will cause mixed mm to segfault with triton backend
    # fixed in nightly by https://github.com/pytorch/pytorch/pull/133173
    # also required for correctness check
    torch.set_float32_matmul_precision("highest")
    torch._dynamo.reset()


# we always use `quantize_(set_inductor_config=False)` to reduce compile time in CI.
class TestQuantizedTraining(TestCase):
    @parametrize("device", _DEVICES)
    def test_int8_stochastic_rounding(self, device):
        x = torch.randn(32, device=device)
        x_samples = x.view(1, -1).repeat(100_000, 1)

        x_int8, x_scale = Int8QTLinearWeight.quantize(x_samples, stochastic_rounding=True)
        x_dequant_samples = x_int8 * x_scale.view(-1, 1)
        x_dequant_mean = x_dequant_samples.mean(0)

        # a more rigorous test would be to do a hypothesis testing.
        # due to the statistical nature, this assertion may still fail, though very rarely.
        torch.testing.assert_close(x_dequant_mean, x, atol=1e-4, rtol=1e-4)

    @parametrize("leading_dims", [(), (2,), (2, 4)])
    @parametrize("bias", [False, True])
    @parametrize("device", _DEVICES)
    def test_int8_linear(self, leading_dims, bias, device):
        _reset()
        embed_dim = 32

        linear_fp32 = nn.Linear(embed_dim, embed_dim, bias=bias, device=device)
        linear_int8 = copy.deepcopy(linear_fp32)
        quantize_(linear_int8, int8_weight_only_quantized_training(), set_inductor_config=False)
        linear_fp32.weight.data = linear_int8.weight.data.dequantize()

        input_fp32 = torch.randn(leading_dims + (embed_dim,), device=device)
        input_int8 = input_fp32.clone()
        input_fp32.requires_grad_(True)
        input_int8.requires_grad_(True)

        # test forward
        out_fp32 = linear_fp32(input_fp32)
        out_int8 = linear_int8(input_int8)
        torch.testing.assert_close(out_fp32, out_int8)

        # test backward
        grad = torch.randn(leading_dims + (embed_dim,), device=device)
        out_fp32.backward(grad)
        out_int8.backward(grad)
        torch.testing.assert_close(input_fp32.grad, input_int8.grad)
        torch.testing.assert_close(linear_fp32.weight.grad, linear_int8.weight.grad)
        if bias:
            torch.testing.assert_close(linear_fp32.bias.grad, linear_int8.bias.grad)

    @parametrize("leading_dims", [(), (2,), (2, 4)])
    @parametrize("bias", [False, True])
    @parametrize("device", _DEVICES)
    def test_int8_linear_compile(self, leading_dims, bias, device):
        _reset()
        embed_dim = 128

        linear_eager = nn.Linear(embed_dim, embed_dim, bias=bias, device=device)
        quantize_(linear_eager, int8_weight_only_quantized_training(), set_inductor_config=False)
        linear_compiled = copy.deepcopy(linear_eager)
        linear_compiled.compile()

        input_eager = torch.randn(leading_dims + (embed_dim,), device=device) * 10
        input_compiled = input_eager.clone()
        input_eager.requires_grad_(True)
        input_compiled.requires_grad_(True)

        out_eager = linear_eager(input_eager)
        out_compiled = linear_compiled(input_compiled)
        torch.testing.assert_close(out_eager, out_compiled)

        grad = torch.randn(leading_dims + (embed_dim,), device=device)
        out_eager.backward(grad)
        out_compiled.backward(grad)
        torch.testing.assert_close(input_eager.grad, input_compiled.grad)
        torch.testing.assert_close(linear_eager.weight.grad, linear_compiled.weight.grad)
        if bias:
            torch.testing.assert_close(linear_eager.bias.grad, linear_compiled.bias.grad)

    @parametrize("compile", [False, True])
    @parametrize("device", _DEVICES)
    def test_int8_linear_training(self, compile, device):
        _reset()
        bsize = 4
        embed_dim = 32
        n_classes = 10

        model_fp32 = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 2, bias=False),
            nn.GELU(),
            nn.Linear(embed_dim * 2, n_classes),
        ).to(device)
        model_int8 = copy.deepcopy(model_fp32)
        # don't set inductor flags to speed up CI time
        quantize_(model_int8, int8_weight_only_quantized_training(), set_inductor_config=False)

        if compile:
            model_fp32.compile()
            model_int8.compile()

        optim_fp32 = _AdamW(model_fp32.parameters())
        optim_int8 = _AdamW(model_int8.parameters())

        for _ in range(5):
            inputs = torch.randn(bsize, embed_dim, device=device)
            labels = torch.randint(n_classes, size=(bsize,), device=device)
            loss_fp32 = F.cross_entropy(model_fp32(inputs), labels)
            loss_int8 = F.cross_entropy(model_int8(inputs), labels)

            rel_error = abs(loss_int8.item() - loss_fp32.item()) / abs(loss_fp32.item())
            assert rel_error < 2e-3, rel_error

            loss_fp32.backward()
            optim_fp32.step()
            optim_fp32.zero_grad()

            loss_int8.backward()
            optim_int8.step()
            optim_int8.zero_grad()


class TestFSDP2(FSDPTest):
    @property
    def world_size(self) -> int:
        return 2

    @skip_if_lt_x_gpu(2)
    def test_fsdp2(self):
        # FSDP2 + compiled quantized training fails with PyTorch 2.4
        compile_layer_choices = [False]
        if TORCH_VERSION_AFTER_2_4:
            compile_layer_choices.append(True)

        self.run_subtests(
            {"compile_layer": compile_layer_choices},
            self._test_fsdp2,
        )

    def _test_fsdp2(self, compile_layer):
        import torch.distributed as dist
        from torch.distributed._composable.fsdp import fully_shard
        from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer

        _reset()
        batch_size = 3
        vocab_size = 32
        seq_len = 64
        model_args = ModelArgs(
            n_layers=2,
            n_heads=2,
            dim=128,
            vocab_size=vocab_size,
            max_seq_len=seq_len,
            dropout_p=0,
        )
        torch.manual_seed(42)
        base_model = Transformer(model_args).cuda()
        quantize_(base_model, int8_weight_only_quantized_training(), set_inductor_config=False)
        fsdp_model = copy.deepcopy(base_model)

        if compile_layer:
            for layer in base_model.layers:
                layer.compile()

        for layer in fsdp_model.layers:
            if compile_layer:
                layer.compile()
            fully_shard(layer)
        fully_shard(fsdp_model)

        base_optim = torch.optim.Adam(base_model.parameters(), lr=1e-2, foreach=False, fused=False)
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2, foreach=False, fused=False)

        torch.manual_seed(42 + self.rank + 1)
        for iter_idx in range(5):
            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
            fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            fsdp_loss = fsdp_model(inp).sum()
            fsdp_loss.backward()
            fsdp_optim.step()

            base_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            base_loss = base_model(inp).sum()
            base_loss.backward()
            for param in base_model.parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
            base_optim.step()

            # due to stochastic rounding, use a pretty large tolerance here
            rel_error = (fsdp_loss - base_loss).abs() / base_loss.abs()
            assert rel_error < 0.05, rel_error


instantiate_parametrized_tests(TestQuantizedTraining)


if __name__ == "__main__":
    run_tests()
torchao/prototype/low_bit_optim/__init__.py (1 line changed)
@@ -1,3 +1,3 @@
 from .adam import Adam8bit, Adam4bit, AdamFp8
-from .adamw import AdamW8bit, AdamW4bit, AdamWFp8
+from .adamw import _AdamW, AdamW8bit, AdamW4bit, AdamWFp8
 from .cpu_offload import CPUOffloadOptimizer
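
This one-line export is what the benchmark script relies on: it remaps --optim AdamW to _AdamW and resolves the class by name on the package, so _AdamW must be importable from torchao.prototype.low_bit_optim. A small sketch of that lookup pattern, with a dummy parameter list and illustrative hyperparameters:

import torch
from torchao.prototype import low_bit_optim

# the benchmark maps "AdamW" to "_AdamW" and then resolves it by name on the package,
# which only works because _AdamW is re-exported from the package __init__
optim_cls = getattr(low_bit_optim, "_AdamW")
optim = optim_cls([torch.nn.Parameter(torch.zeros(8))], lr=3e-4, weight_decay=1e-2)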
It feels like we could probably add this to the standard test suite, so we can use it to sanity-check whether FSDP is supported for any dtype/tensor subclass.
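
One possible shape for that, as a rough sketch only: reuse the FSDP2-vs-replicated parity loop from TestFSDP2._test_fsdp2 above, but take the quantization function as an argument so any dtype/tensor subclass can be plugged in. The class name, the quantize_fn parameter, and the defaults below are hypothetical and not part of this PR or of torchao.

import copy

import torch
import torch.distributed as dist
from torch.distributed._composable.fsdp import fully_shard
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer


class FSDPSubclassParityTest(FSDPTest):
    # hypothetical reusable helper: quantize_fn mutates the model in place, e.g.
    # lambda m: quantize_(m, int8_weight_only_quantized_training(), set_inductor_config=False)
    def _check_fsdp2_parity(self, quantize_fn, tolerance=0.05, n_steps=5):
        torch.manual_seed(42)
        args = ModelArgs(n_layers=2, n_heads=2, dim=128, vocab_size=32, max_seq_len=64, dropout_p=0)
        base_model = Transformer(args).cuda()
        quantize_fn(base_model)
        fsdp_model = copy.deepcopy(base_model)

        # shard each layer, then the root module, as in the test above
        for layer in fsdp_model.layers:
            fully_shard(layer)
        fully_shard(fsdp_model)

        base_optim = torch.optim.Adam(base_model.parameters(), lr=1e-2, foreach=False, fused=False)
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2, foreach=False, fused=False)

        torch.manual_seed(42 + self.rank + 1)
        for _ in range(n_steps):
            inp = torch.randint(0, args.vocab_size, (3, args.max_seq_len), device="cuda")

            fsdp_loss = fsdp_model(inp).sum()
            fsdp_loss.backward()
            fsdp_optim.step()
            fsdp_optim.zero_grad()

            # replicated baseline: average gradients manually across ranks
            base_loss = base_model(inp).sum()
            base_loss.backward()
            for param in base_model.parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
            base_optim.step()
            base_optim.zero_grad()

            # schemes with stochastic rounding need a loose tolerance
            rel_error = (fsdp_loss - base_loss).abs() / base_loss.abs()
            assert rel_error < tolerance, rel_error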