Add experimental INT8 quantized training #644

Merged: 45 commits, Aug 16, 2024

Commits
3d42329  initial commit (gau-nernst, Aug 9, 2024)
eca170a  add tests (gau-nernst, Aug 9, 2024)
dd162a8  add training (gau-nernst, Aug 9, 2024)
b286f5d  support py3.9 (gau-nernst, Aug 9, 2024)
8a84aca  skip test for torch<2.3 (gau-nernst, Aug 9, 2024)
ea47c7d  fix pytest (gau-nernst, Aug 9, 2024)
f20486b  fix adamw (gau-nernst, Aug 9, 2024)
3415244  add some FSDP ops (gau-nernst, Aug 9, 2024)
5d0e658  add more fsdp ops (gau-nernst, Aug 10, 2024)
d753476  more ops (gau-nernst, Aug 10, 2024)
9c77800  add benchmark script (gau-nernst, Aug 10, 2024)
158eb61  some organisation (gau-nernst, Aug 10, 2024)
db0290f  add FSDP test (gau-nernst, Aug 10, 2024)
1c32b78  clean up (gau-nernst, Aug 10, 2024)
ff69121  update FSDP test (gau-nernst, Aug 10, 2024)
45342ba  add compile test (things are crashing) (gau-nernst, Aug 10, 2024)
f1587a2  fix bias (gau-nernst, Aug 10, 2024)
7f9102a  substantial update to tests (gau-nernst, Aug 10, 2024)
0428330  fix compile for FSDP (gau-nernst, Aug 10, 2024)
001422c  update readme. rename file (gau-nernst, Aug 10, 2024)
2eb2787  speed up CI (gau-nernst, Aug 10, 2024)
d39caba  fix typo (gau-nernst, Aug 10, 2024)
de6aa25  fix typo (gau-nernst, Aug 10, 2024)
adbe47d  typos. unset some dynamo flags (gau-nernst, Aug 10, 2024)
3fdf776  update readme (gau-nernst, Aug 10, 2024)
ea0ee4f  remove requires_grad, since it is unnecessary (gau-nernst, Aug 11, 2024)
36d0e1a  remove note (gau-nernst, Aug 11, 2024)
2360a97  Merge branch 'pytorch:main' into qt_int8 (gau-nernst, Aug 11, 2024)
9e19104  Merge branch 'main' into qt_int8 (gau-nernst, Aug 13, 2024)
6bc7621  don't set inductor flags (gau-nernst, Aug 13, 2024)
6646c0b  rename (gau-nernst, Aug 13, 2024)
00e25cf  update README (gau-nernst, Aug 13, 2024)
927a6d1  rename optimizer (gau-nernst, Aug 13, 2024)
8377707  Merge branch 'main' into qt_int8 (gau-nernst, Aug 14, 2024)
de49e8b  update benchmark script (gau-nernst, Aug 14, 2024)
f80ac97  make compile explicit (gau-nernst, Aug 14, 2024)
e375c3d  update docs (gau-nernst, Aug 14, 2024)
6396a95  Merge branch 'main' into qt_int8 (gau-nernst, Aug 16, 2024)
662c61f  use torch.optim.Adam to avoid FSDP optim compile bug (gau-nernst, Aug 16, 2024)
cc90298  update docs (gau-nernst, Aug 16, 2024)
f1c588b  update doc (gau-nernst, Aug 16, 2024)
f444fa6  update docs (gau-nernst, Aug 16, 2024)
640ec2d  fix CI test (gau-nernst, Aug 16, 2024)
dad6560  skip test (gau-nernst, Aug 16, 2024)
4924e8d  fix compiled test (gau-nernst, Aug 16, 2024)

150 changes: 150 additions & 0 deletions benchmarks/quantized_training/pretrain_llama2.py
@@ -0,0 +1,150 @@
# pre-train a mini Llama2 on TinyStories with INT8 quantized training
# pip install transformers sentencepiece wandb
#
# BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile
# INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_weight_only

import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import argparse
from pathlib import Path

import numpy as np
import torch
import wandb
from tqdm import tqdm
from transformers import LlamaConfig, LlamaForCausalLM

from torchao.prototype import low_bit_optim
from torchao.prototype.quantized_training import int8_weight_only_quantized_training
from torchao.quantization.quant_api import quantize_


def get_loss(model: LlamaForCausalLM, batch: torch.Tensor):
    return model(batch, labels=batch).loss


def get_tinystories():
    save_path = Path("tinystories.bin")

    if not save_path.exists():
        import sentencepiece as spm
        from huggingface_hub import hf_hub_download

        tokenizer_path = hf_hub_download("meta-llama/Llama-2-7b", "tokenizer.model")
        tokenizer = spm.SentencePieceProcessor(tokenizer_path)
        assert tokenizer.vocab_size() < (1 << 16)  # make sure we can use uint16

        # do everything in memory. we have enough RAM
        filepath = hf_hub_download(
            "roneneldan/TinyStories",
            "TinyStoriesV2-GPT4-train.txt",
            repo_type="dataset",
        )
        stories = open(filepath).read().split("\n<|endoftext|>\n")

        tokens_list = []
        chunk_size = 10_000
        for i in tqdm(range(0, len(stories), chunk_size), desc="Tokenizing TinyStories"):
            chunk = stories[i : min(i + chunk_size, len(stories))]
            tokens_list.extend(tokenizer.Encode(chunk, add_bos=True, add_eos=True, num_threads=4))

        total_size = sum(len(x) for x in tokens_list)
        mmap_tokens = np.memmap(save_path, dtype=np.uint16, mode="w+", shape=total_size)
        i = 0
        for tokens in tokens_list:
            mmap_tokens[i : i + len(tokens)] = tokens
            i += len(tokens)
        mmap_tokens.flush()

    tokens = np.memmap(save_path, dtype=np.uint16, mode="r")
    return torch.from_numpy(tokens)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # default config is 470M
    parser.add_argument("--d_model", type=int, default=1024)
    parser.add_argument("--depth", type=int, default=24)
    parser.add_argument("--ffn_size", type=int, default=4096)
    parser.add_argument("--head_dim", type=int, default=64)

    parser.add_argument("--quantize")
    parser.add_argument("--activation_checkpointing", action="store_true")
    parser.add_argument("--compile", action="store_true")

    parser.add_argument("--n_steps", type=int, default=1000)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--seq_len", type=int, default=2048)

    parser.add_argument("--optim", default="AdamW")
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--weight_decay", type=float, default=1e-2)

    parser.add_argument("--project", default="int8_quantized_training")
    parser.add_argument("--run_name")
    parser.add_argument("--seed", type=int)
    args = parser.parse_args()

    if args.seed is not None:
        torch.manual_seed(args.seed)

    config = LlamaConfig(
        hidden_size=args.d_model,
        intermediate_size=args.ffn_size,
        num_hidden_layers=args.depth,
        num_attention_heads=args.d_model // args.head_dim,
        max_position_embeddings=args.seq_len,
        use_cache=False,
    )
    model = LlamaForCausalLM(config).bfloat16().cuda()
    if args.activation_checkpointing:
        model.gradient_checkpointing_enable()
    if args.quantize == "int8_weight_only":
        quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False)
    elif args.quantize is not None:
        raise ValueError(f"Unsupported quantize={args.quantize}")
    print(f"No. of params: {sum(p.numel() for p in model.parameters()):,}")
    print(f"No. of buffers: {sum(p.numel() for p in model.buffers()):,}")

    # only use optimizers from torchao.prototype.low_bit_optim to support quantized training
    if args.optim == "AdamW":
        args.optim = "_AdamW"
    optim = getattr(low_bit_optim, args.optim)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    data = get_tinystories().cuda()
    run = wandb.init(dir="/tmp", config=args, project=args.project, name=args.run_name)

    step = 0
    log_interval = 50
    pbar = tqdm(total=args.n_steps, dynamic_ncols=True)
    model.train()
    _get_loss = torch.compile(get_loss) if args.compile else get_loss

    while step < args.n_steps:
        # randomly select a continuous chunk, then reshape it
        idx = torch.randint(0, data.shape[0] - args.batch_size * args.seq_len, (1,)).item()
        batch = data[idx : idx + args.batch_size * args.seq_len].view(args.batch_size, args.seq_len).long()

        loss = _get_loss(model, batch)
        loss.backward()

        if step % log_interval == 0:
            log_dict = dict(
                loss=loss.item(),
                lr=optim.param_groups[0]["lr"],
                max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9,
                max_memory_active=torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1e9,
            )
            run.log(log_dict, step=step)
            pbar.set_postfix(loss=log_dict["loss"])

        optim.step()
        optim.zero_grad()

        step += 1
        pbar.update()

    run.finish()
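
For reference, the core idea behind int8_weight_only_quantized_training() is INT8 weight quantization with stochastic rounding, which the test file below exercises via Int8QTLinearWeight.quantize(..., stochastic_rounding=True). The following is a minimal illustrative sketch only, not the PR's actual Int8QTLinearWeight.quantize implementation; the per-row absmax scale convention and the helper names int8_quantize_stochastic / int8_dequantize are assumptions made for this example.

# Minimal sketch (not the PR's implementation): symmetric per-row INT8
# quantization with stochastic rounding, the idea exercised by
# test_int8_stochastic_rounding below.
import torch


def int8_quantize_stochastic(w: torch.Tensor):
    # per-row absmax scale so that values map into roughly [-127, 127]
    scale = w.abs().amax(dim=-1).clamp(min=1e-12) / 127.0
    w_scaled = w / scale.unsqueeze(-1)
    # round down or up with probability equal to the fractional part,
    # so the quantized value is unbiased in expectation
    floor = w_scaled.floor()
    rounded = floor + (torch.rand_like(w_scaled) < (w_scaled - floor))
    w_int8 = rounded.clamp(-128, 127).to(torch.int8)
    return w_int8, scale


def int8_dequantize(w_int8: torch.Tensor, scale: torch.Tensor):
    # mirrors the check in test_int8_stochastic_rounding: int8 values times per-row scale
    return w_int8 * scale.view(-1, 1)


if __name__ == "__main__":
    x = torch.randn(32)
    samples = x.view(1, -1).repeat(100_000, 1)
    w_int8, scale = int8_quantize_stochastic(samples)
    # averaging many stochastically rounded samples approximately recovers x
    print((int8_dequantize(w_int8, scale).mean(0) - x).abs().max())
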
225 changes: 225 additions & 0 deletions test/prototype/test_quantized_training.py
@@ -0,0 +1,225 @@
import copy

import pytest
import torch
import torch.nn.functional as F
from torch import nn
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import TestCase, instantiate_parametrized_tests, parametrize, run_tests

from torchao.prototype.low_bit_optim import _AdamW
from torchao.prototype.quantized_training import Int8QTLinearWeight, int8_weight_only_quantized_training
from torchao.quantization.quant_api import quantize_
from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4

if not TORCH_VERSION_AFTER_2_3:
    pytest.skip("Requires torch>=2.4", allow_module_level=True)


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


def _reset():
    # using TF32 will cause mixed mm to segfault with triton backend
    # fixed in nightly by https://github.com/pytorch/pytorch/pull/133173
    # also required for correctness check
    torch.set_float32_matmul_precision("highest")
    torch._dynamo.reset()


# we always use `quantize_(set_inductor_config=False)` to reduce compile time in CI.
class TestQuantizedTraining(TestCase):
    @parametrize("device", _DEVICES)
    def test_int8_stochastic_rounding(self, device):
        x = torch.randn(32, device=device)
        x_samples = x.view(1, -1).repeat(100_000, 1)

        x_int8, x_scale = Int8QTLinearWeight.quantize(x_samples, stochastic_rounding=True)
        x_dequant_samples = x_int8 * x_scale.view(-1, 1)
        x_dequant_mean = x_dequant_samples.mean(0)

        # a more rigorous test would be to do hypothesis testing.
        # due to the statistical nature, this assertion may still fail, though very rarely.
        torch.testing.assert_close(x_dequant_mean, x, atol=1e-4, rtol=1e-4)

    @parametrize("leading_dims", [(), (2,), (2, 4)])
    @parametrize("bias", [False, True])
    @parametrize("device", _DEVICES)
    def test_int8_linear(self, leading_dims, bias, device):
        _reset()
        embed_dim = 32

        linear_fp32 = nn.Linear(embed_dim, embed_dim, bias=bias, device=device)
        linear_int8 = copy.deepcopy(linear_fp32)
        quantize_(linear_int8, int8_weight_only_quantized_training(), set_inductor_config=False)
        linear_fp32.weight.data = linear_int8.weight.data.dequantize()

        input_fp32 = torch.randn(leading_dims + (embed_dim,), device=device)
        input_int8 = input_fp32.clone()
        input_fp32.requires_grad_(True)
        input_int8.requires_grad_(True)

        # test forward
        out_fp32 = linear_fp32(input_fp32)
        out_int8 = linear_int8(input_int8)
        torch.testing.assert_close(out_fp32, out_int8)

        # test backward
        grad = torch.randn(leading_dims + (embed_dim,), device=device)
        out_fp32.backward(grad)
        out_int8.backward(grad)
        torch.testing.assert_close(input_fp32.grad, input_int8.grad)
        torch.testing.assert_close(linear_fp32.weight.grad, linear_int8.weight.grad)
        if bias:
            torch.testing.assert_close(linear_fp32.bias.grad, linear_int8.bias.grad)

    @parametrize("leading_dims", [(), (2,), (2, 4)])
    @parametrize("bias", [False, True])
    @parametrize("device", _DEVICES)
    def test_int8_linear_compile(self, leading_dims, bias, device):
        _reset()
        embed_dim = 128

        linear_eager = nn.Linear(embed_dim, embed_dim, bias=bias, device=device)
        quantize_(linear_eager, int8_weight_only_quantized_training(), set_inductor_config=False)
        linear_compiled = copy.deepcopy(linear_eager)
        linear_compiled.compile()

        input_eager = torch.randn(leading_dims + (embed_dim,), device=device) * 10
        input_compiled = input_eager.clone()
        input_eager.requires_grad_(True)
        input_compiled.requires_grad_(True)

        out_eager = linear_eager(input_eager)
        out_compiled = linear_compiled(input_compiled)
        torch.testing.assert_close(out_eager, out_compiled)

        grad = torch.randn(leading_dims + (embed_dim,), device=device)
        out_eager.backward(grad)
        out_compiled.backward(grad)
        torch.testing.assert_close(input_eager.grad, input_compiled.grad)
        torch.testing.assert_close(linear_eager.weight.grad, linear_compiled.weight.grad)
        if bias:
            torch.testing.assert_close(linear_eager.bias.grad, linear_compiled.bias.grad)

    @parametrize("compile", [False, True])
    @parametrize("device", _DEVICES)
    def test_int8_linear_training(self, compile, device):
        _reset()
        bsize = 4
        embed_dim = 32
        n_classes = 10

        model_fp32 = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 2, bias=False),
            nn.GELU(),
            nn.Linear(embed_dim * 2, n_classes),
        ).to(device)
        model_int8 = copy.deepcopy(model_fp32)
        # don't set inductor flags to speed up CI time
        quantize_(model_int8, int8_weight_only_quantized_training(), set_inductor_config=False)

        if compile:
            model_fp32.compile()
            model_int8.compile()

        optim_fp32 = _AdamW(model_fp32.parameters())
        optim_int8 = _AdamW(model_int8.parameters())

        for _ in range(5):
            inputs = torch.randn(bsize, embed_dim, device=device)
            labels = torch.randint(n_classes, size=(bsize,), device=device)
            loss_fp32 = F.cross_entropy(model_fp32(inputs), labels)
            loss_int8 = F.cross_entropy(model_int8(inputs), labels)

            rel_error = abs(loss_int8.item() - loss_fp32.item()) / abs(loss_fp32.item())
            assert rel_error < 2e-3, rel_error

            loss_fp32.backward()
            optim_fp32.step()
            optim_fp32.zero_grad()

            loss_int8.backward()
            optim_int8.step()
            optim_int8.zero_grad()


class TestFSDP2(FSDPTest):

[Inline review comment from a Contributor]: it feels like we could probably add this as part of the standard test suite, so that it can be used to sanity-check whether FSDP is supported for any dtype / tensor subclass.

    @property
    def world_size(self) -> int:
        return 2

    @skip_if_lt_x_gpu(2)
    def test_fsdp2(self):
        # FSDP2 + compiled quantized training fails with PyTorch 2.4
        compile_layer_choices = [False]
        if TORCH_VERSION_AFTER_2_4:
            compile_layer_choices.append(True)

        self.run_subtests(
            {"compile_layer": compile_layer_choices},
            self._test_fsdp2,
        )

    def _test_fsdp2(self, compile_layer):
        import torch.distributed as dist
        from torch.distributed._composable.fsdp import fully_shard
        from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer

        _reset()
        batch_size = 3
        vocab_size = 32
        seq_len = 64
        model_args = ModelArgs(
            n_layers=2,
            n_heads=2,
            dim=128,
            vocab_size=vocab_size,
            max_seq_len=seq_len,
            dropout_p=0,
        )
        torch.manual_seed(42)
        base_model = Transformer(model_args).cuda()
        quantize_(base_model, int8_weight_only_quantized_training(), set_inductor_config=False)
        fsdp_model = copy.deepcopy(base_model)

        if compile_layer:
            for layer in base_model.layers:
                layer.compile()

        for layer in fsdp_model.layers:
            if compile_layer:
                layer.compile()
            fully_shard(layer)
        fully_shard(fsdp_model)

        base_optim = torch.optim.Adam(base_model.parameters(), lr=1e-2, foreach=False, fused=False)
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2, foreach=False, fused=False)

        torch.manual_seed(42 + self.rank + 1)
        for iter_idx in range(5):
            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
            fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            fsdp_loss = fsdp_model(inp).sum()
            fsdp_loss.backward()
            fsdp_optim.step()

            base_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            base_loss = base_model(inp).sum()
            base_loss.backward()
            for param in base_model.parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
            base_optim.step()

            # due to stochastic rounding, use a pretty large tolerance here
            rel_error = (fsdp_loss - base_loss).abs() / base_loss.abs()
            assert rel_error < 0.05, rel_error


instantiate_parametrized_tests(TestQuantizedTraining)


if __name__ == "__main__":
    run_tests()
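
Regarding the inline review comment above about folding this into a standard test suite: a reusable FSDP2 parity check, parameterized over the tensor-subclass transform, could look roughly like the sketch below. This is not part of the PR; the helper name run_fsdp2_parity_check and the quantize_fn parameter are hypothetical, and the body simply generalizes _test_fsdp2 above. It assumes it is called from inside an FSDPTest subclass, so a distributed process group is already initialized.

# Hypothetical helper generalizing _test_fsdp2 above: an FSDP2-vs-replicated
# parity check for any tensor-subclass transform applied via quantize_fn.
# Assumes a process group is already set up (e.g. by FSDPTest).
import copy

import torch
import torch.distributed as dist
from torch.distributed._composable.fsdp import fully_shard
from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer


def run_fsdp2_parity_check(quantize_fn, rank, n_iters=5, tolerance=0.05):
    torch.manual_seed(42)
    model_args = ModelArgs(n_layers=2, n_heads=2, dim=128, vocab_size=32, max_seq_len=64, dropout_p=0)
    base_model = Transformer(model_args).cuda()
    # e.g. quantize_fn = lambda m: quantize_(m, int8_weight_only_quantized_training())
    quantize_fn(base_model)
    fsdp_model = copy.deepcopy(base_model)

    for layer in fsdp_model.layers:
        fully_shard(layer)
    fully_shard(fsdp_model)

    base_optim = torch.optim.Adam(base_model.parameters(), lr=1e-2, foreach=False, fused=False)
    fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2, foreach=False, fused=False)

    torch.manual_seed(42 + rank + 1)
    for _ in range(n_iters):
        inp = torch.randint(0, model_args.vocab_size, (3, model_args.max_seq_len), device="cuda")

        fsdp_optim.zero_grad()
        fsdp_loss = fsdp_model(inp).sum()
        fsdp_loss.backward()
        fsdp_optim.step()

        base_optim.zero_grad()
        base_loss = base_model(inp).sum()
        base_loss.backward()
        # the unsharded reference model averages gradients manually across ranks
        for param in base_model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
        base_optim.step()

        rel_error = (fsdp_loss - base_loss).abs() / base_loss.abs()
        assert rel_error < tolerance, rel_error
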
2 changes: 1 addition & 1 deletion torchao/prototype/low_bit_optim/__init__.py
@@ -1,3 +1,3 @@
 from .adam import Adam8bit, Adam4bit, AdamFp8
-from .adamw import AdamW8bit, AdamW4bit, AdamWFp8
+from .adamw import _AdamW, AdamW8bit, AdamW4bit, AdamWFp8
 from .cpu_offload import CPUOffloadOptimizer