Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BitNet b1.58 training #930

Merged
merged 19 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 49 additions & 13 deletions benchmarks/quantized_training/pretrain_llama2.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# pre-train a mini Llama2 on TinyStories with INT8 quantized training
# pip install huggingface_hub sentencepiece wandb
#
# BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile
# INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_weight_only
# INT8 MP: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_mixed_precision
# BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile
# INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile --quantize int8_weight_only
# INT8 MP: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile --quantize int8_mixed_precision
# BitNet: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile --quantize bitnet --modify_rmsnorm_for_bitnet

import os

Expand All @@ -20,14 +21,14 @@
from torch.utils.checkpoint import checkpoint
from tqdm import tqdm

from torchao._models.llama.model import ModelArgs, Transformer, transformer_configs
from torchao import quantize_
from torchao._models.llama.model import ModelArgs, Transformer, transformer_configs, RMSNorm
from torchao.prototype import low_bit_optim
from torchao.prototype.quantized_training import (
bitnet_training,
int8_mixed_precision_training,
int8_weight_only_quantized_training,
)
from torchao.quantization.quant_api import quantize_


# not official models
transformer_configs.update(
Expand Down Expand Up @@ -92,10 +93,14 @@ def get_tinystories():
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="470M", choices=transformer_configs.keys())
parser.add_argument("--bf16_model", action="store_true")
parser.add_argument("--bf16_amp", action="store_true")
parser.add_argument("--quantize")
parser.add_argument("--activation_checkpointing", action="store_true")
parser.add_argument("--compile", action="store_true")

parser.add_argument("--modify_rmsnorm_for_bitnet", action="store_true")

parser.add_argument("--n_steps", type=int, default=1000)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--seq_len", type=int, default=2048)
Expand All @@ -104,7 +109,7 @@ def get_tinystories():
parser.add_argument("--lr", type=float, default=3e-4)
parser.add_argument("--weight_decay", type=float, default=1e-2)

parser.add_argument("--project", default="int8_quantized_training")
parser.add_argument("--project", default="quantized_training")
parser.add_argument("--run_name")
parser.add_argument("--seed", type=int)
parser.add_argument("--log_interval", type=int, default=10)
Expand All @@ -115,19 +120,47 @@ def get_tinystories():

config = ModelArgs.from_name(args.model)
config.block_size = args.seq_len
model = Transformer(config).bfloat16().cuda()
model = Transformer(config)
if args.bf16_model:
model.bfloat16()
model.cuda()
with torch.device("cuda"):
model.setup_caches(args.batch_size, args.seq_len, training=True)
if args.activation_checkpointing:
for layer in model.layers:
enable_activation_checkpointing(layer)

# as recommended by https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
# section 3
if args.modify_rmsnorm_for_bitnet:
# remove old RMSNorm
for layer in model.layers:
layer.attention_norm = torch.nn.Identity()
layer.ffn_norm = torch.nn.Identity()

# insert new RMSNorm
def insert_rmsnorm(module: torch.nn.Module):
for name, child in module.named_children():
if isinstance(child, torch.nn.Linear):
w = child.weight
norm = RMSNorm(child.in_features).to(device=w.device, dtype=w.dtype)
setattr(module, name, torch.nn.Sequential(norm, child))
else:
insert_rmsnorm(child)

insert_rmsnorm(model.layers)

# don't apply int8_mixed_precision to LM head, since it can cause convergence issue.
# TODO: might want to do the same for int8_weight_only to standardize.
if args.quantize == "int8_weight_only":
quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False)

elif args.quantize == "int8_mixed_precision":
quantize_(model.layers, int8_mixed_precision_training(), set_inductor_config=False)

elif args.quantize == "bitnet":
quantize_(model.layers, bitnet_training(), set_inductor_config=False)

elif args.quantize is not None:
raise ValueError(f"Unsupported quantize={args.quantize}")

Expand Down Expand Up @@ -155,7 +188,8 @@ def get_tinystories():
idx = torch.randint(0, data.shape[0] - args.batch_size * args.seq_len, (1,)).item()
batch = data[idx : idx + args.batch_size * args.seq_len].view(args.batch_size, args.seq_len).long()

loss = _get_loss(model, batch)
with torch.autocast("cuda", torch.bfloat16, enabled=args.bf16_amp):
loss = _get_loss(model, batch)
loss.backward()

if step % args.log_interval == 0:
Expand All @@ -165,10 +199,6 @@ def get_tinystories():
max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9,
max_memory_reserved=torch.cuda.max_memory_reserved() / 1e9,
)
if step > 0:
time1 = time.time()
log_dict["tokens_per_second"] = (args.log_interval * args.batch_size * args.seq_len) / (time1 - time0)
time0 = time1
run.log(log_dict, step=step)
pbar.set_postfix(loss=log_dict["loss"])

Expand All @@ -178,4 +208,10 @@ def get_tinystories():
step += 1
pbar.update()

if step % args.log_interval == 0:
time1 = time.time()
log_dict = dict(tokens_per_second=(args.log_interval * args.batch_size * args.seq_len) / (time1 - time0))
time0 = time1
run.log(log_dict, step=step)

run.finish()
144 changes: 114 additions & 30 deletions test/prototype/test_quantized_training.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6

if not TORCH_VERSION_AT_LEAST_2_4:
pytest.skip("Requires torch>=2.4", allow_module_level=True)
Expand All @@ -11,7 +11,7 @@
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import TestCase, instantiate_parametrized_tests, parametrize, run_tests
Expand All @@ -20,6 +20,7 @@
from torchao.prototype.low_bit_optim import _AdamW
from torchao.prototype.quantized_training import (
Int8MixedPrecisionTrainingConfig,
bitnet_training,
int8_mixed_precision_training,
int8_weight_only_quantized_training,
quantize_int8_rowwise,
Expand Down Expand Up @@ -165,7 +166,7 @@ def test_int8_mixed_precision_training(self, compile, config):
embed_dim = 64
device = "cuda"

linear = nn.Linear(embed_dim, embed_dim).cuda()
linear = nn.Linear(embed_dim, embed_dim, device=device)
linear_int8mp = copy.deepcopy(linear)
quantize_(linear_int8mp, int8_mixed_precision_training(config), set_inductor_config=False)

Expand All @@ -187,6 +188,70 @@ def snr(ref, actual):
assert snr(inputs_ref.grad, inputs_int8mp.grad) > 20
assert snr(linear.weight.grad, linear_int8mp.weight.grad) > 20

@parametrize("compile", [False, True])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_bitnet_training(self, compile):
# reference implementation
# https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
# Figure 3
class BitLinear(nn.Linear):
def activation_quant(self, x):
scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
return (x * scale).round().clamp_(-128, 127) / scale

def weight_quant(self, x):
scale = 1.0 / x.abs().mean().clamp_(min=1e-5)
return (x * scale).round().clamp_(-1, 1) / scale

def forward(self, x):
w = self.weight
x = x + (self.activation_quant(x) - x).detach()
w = w + (self.weight_quant(w) - w).detach()
return F.linear(x, w, self.bias)

_reset()
bsize = 4
embed_dim = 32
device = "cuda"

# only use 1 matmul shape to reduce triton autotune time
model_ref = nn.Sequential(
nn.Linear(embed_dim, embed_dim, bias=False),
nn.GELU(),
nn.Linear(embed_dim, embed_dim),
).to(device)
model = copy.deepcopy(model_ref)
quantize_(model, bitnet_training(), set_inductor_config=False)

# change model_ref to use BitLinear
model_ref[0].__class__ = BitLinear
model_ref[2].__class__ = BitLinear

if compile:
model_ref.compile()
model.compile()

optim_ref = torch.optim.AdamW(model_ref.parameters())
optim = torch.optim.AdamW(model.parameters())

for i in range(5):
inputs = torch.randn(bsize, embed_dim, device=device)
labels = torch.randint(embed_dim, size=(bsize,), device=device)
loss_ref = F.cross_entropy(model_ref(inputs), labels)
loss = F.cross_entropy(model(inputs), labels)

torch.testing.assert_close(loss, loss_ref)

loss_ref.backward()
optim_ref.step()
optim_ref.zero_grad()

loss.backward()
for p in model.parameters():
assert p.grad is not None
optim.step()
optim.zero_grad()


_FSDP_WORLD_SIZE = 2

Expand All @@ -198,35 +263,36 @@ def world_size(self) -> int:

@skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
def test_fsdp2_correctness(self):
mp_policy = MixedPrecisionPolicy()

# quantize_fn, mp_policy, tolerance
test_args = [
(
int8_weight_only_quantized_training(), # quantize_fn for base model
int8_weight_only_quantized_training(), # quantize_fn for FSDP model
MixedPrecisionPolicy(),
0.05, # tolerance. due to stochastic rounding, use a pretty large tolerance here
),
(
int8_mixed_precision_training(),
int8_mixed_precision_training(),
MixedPrecisionPolicy(),
1e-6,
),
(
# It's complicated (though possible) to simulate FSDP BF16 mixed-precision for base_model.
# We would need to cast all params to BF16 in forward and backward pass, while keeping
# the params in FP32 for optim step.
# torch.autocast() will only do this for F.linear() layer (and its backward).
# To keep it simple, we just use a larger tolerance here.
int8_mixed_precision_training(),
int8_mixed_precision_training(Int8MixedPrecisionTrainingConfig(fsdp_param_dtype=torch.bfloat16)),
MixedPrecisionPolicy(param_dtype=torch.bfloat16),
1e-2,
),
# high tolerance due to stochastic rounding
(int8_weight_only_quantized_training, mp_policy, 0.05),
(int8_mixed_precision_training, mp_policy, 1e-6),
(bitnet_training, mp_policy, 1e-5),
]

# FSDP2 mixed-precision requires https://github.com/pytorch/pytorch/pull/136129
if TORCH_VERSION_AT_LEAST_2_6:
# It's complicated (though possible) to simulate FSDP BF16 mixed-precision for base_model.
# We would need to cast all params to BF16 in forward and backward pass, while keeping
# the params in FP32 for optim step.
# torch.autocast() will only do this for F.linear() layer (and its backward).
# To keep it simple, we just use a larger tolerance here.
bf16_mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)

extra_args = [
(int8_weight_only_quantized_training, bf16_mp_policy, 1e-2),
(int8_mixed_precision_training, bf16_mp_policy, 1e-2),
(bitnet_training, bf16_mp_policy, 1e-2),
]
test_args.extend(extra_args)

self.run_subtests({"args": test_args}, self._run_subtest)

def _run_subtest(self, args):
base_quantize_fn, fsdp_quantize_fn, mp_policy, tolerance = args
quantize_fn, mp_policy, tolerance = args

batch_size = 3
vocab_size = 32
Expand All @@ -245,8 +311,8 @@ def _run_subtest(self, args):
base_model = Transformer(model_args).cuda()
fsdp_model = copy.deepcopy(base_model)

quantize_(base_model.layers, base_quantize_fn, set_inductor_config=False)
quantize_(fsdp_model.layers, fsdp_quantize_fn, set_inductor_config=False)
quantize_(base_model.layers, quantize_fn(), set_inductor_config=False)
quantize_(fsdp_model.layers, quantize_fn(), set_inductor_config=False)

for layer in fsdp_model.layers:
fully_shard(layer, mp_policy=mp_policy)
Expand Down Expand Up @@ -275,7 +341,25 @@ def _run_subtest(self, args):
base_optim.step()

rel_error = (fsdp_loss - base_loss).abs() / base_loss.abs()
assert rel_error < tolerance, (iter_idx, rel_error)
assert rel_error < tolerance, (quantize_fn.__name__, mp_policy, iter_idx, rel_error)

@skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
def test_precompute_bitnet_scale(self):
from torchao.prototype.quantized_training.bitnet import get_bitnet_scale, precompute_bitnet_scale_for_fsdp

model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).cuda()
model_fsdp = copy.deepcopy(model)
quantize_(model_fsdp, bitnet_training())
fully_shard(model_fsdp)

precompute_bitnet_scale_for_fsdp(model_fsdp)

torch.testing.assert_close(
get_bitnet_scale(model[0].weight), model_fsdp[0].weight._local_tensor._precomputed_scale
)
torch.testing.assert_close(
get_bitnet_scale(model[2].weight), model_fsdp[2].weight._local_tensor._precomputed_scale
)


instantiate_parametrized_tests(TestQuantizedTraining)
Expand Down
Loading
Loading