diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index fc87c2cd6e..0085f24264 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -1,9 +1,10 @@ # pre-train a mini Llama2 on TinyStories with INT8 quantized training # pip install huggingface_hub sentencepiece wandb # -# BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile -# INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_weight_only -# INT8 MP: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_mixed_precision +# BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile +# INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile --quantize int8_weight_only +# INT8 MP: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile --quantize int8_mixed_precision +# BitNet: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile --quantize bitnet --modify_rmsnorm_for_bitnet import os @@ -20,14 +21,14 @@ from torch.utils.checkpoint import checkpoint from tqdm import tqdm -from torchao._models.llama.model import ModelArgs, Transformer, transformer_configs +from torchao import quantize_ +from torchao._models.llama.model import ModelArgs, Transformer, transformer_configs, RMSNorm from torchao.prototype import low_bit_optim from torchao.prototype.quantized_training import ( + bitnet_training, int8_mixed_precision_training, int8_weight_only_quantized_training, ) -from torchao.quantization.quant_api import quantize_ - # not official models transformer_configs.update( @@ -92,10 +93,14 @@ def get_tinystories(): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", default="470M", choices=transformer_configs.keys()) + parser.add_argument("--bf16_model", action="store_true") + parser.add_argument("--bf16_amp", action="store_true") parser.add_argument("--quantize") parser.add_argument("--activation_checkpointing", action="store_true") parser.add_argument("--compile", action="store_true") + parser.add_argument("--modify_rmsnorm_for_bitnet", action="store_true") + parser.add_argument("--n_steps", type=int, default=1000) parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--seq_len", type=int, default=2048) @@ -104,7 +109,7 @@ def get_tinystories(): parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--weight_decay", type=float, default=1e-2) - parser.add_argument("--project", default="int8_quantized_training") + parser.add_argument("--project", default="quantized_training") parser.add_argument("--run_name") parser.add_argument("--seed", type=int) parser.add_argument("--log_interval", type=int, default=10) @@ -115,19 +120,47 @@ def get_tinystories(): config = ModelArgs.from_name(args.model) config.block_size = args.seq_len - model = Transformer(config).bfloat16().cuda() + model = Transformer(config) + if args.bf16_model: + model.bfloat16() + model.cuda() with torch.device("cuda"): model.setup_caches(args.batch_size, args.seq_len, training=True) if args.activation_checkpointing: for layer in model.layers: enable_activation_checkpointing(layer) + # as recommended by https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf + # section 3 + if args.modify_rmsnorm_for_bitnet: + # remove old RMSNorm + for layer in model.layers: + layer.attention_norm = torch.nn.Identity() + layer.ffn_norm = torch.nn.Identity() + + # insert new RMSNorm + def insert_rmsnorm(module: torch.nn.Module): + for name, child in module.named_children(): + if isinstance(child, torch.nn.Linear): + w = child.weight + norm = RMSNorm(child.in_features).to(device=w.device, dtype=w.dtype) + setattr(module, name, torch.nn.Sequential(norm, child)) + else: + insert_rmsnorm(child) + + insert_rmsnorm(model.layers) + # don't apply int8_mixed_precision to LM head, since it can cause convergence issue. # TODO: might want to do the same for int8_weight_only to standardize. if args.quantize == "int8_weight_only": quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False) + elif args.quantize == "int8_mixed_precision": quantize_(model.layers, int8_mixed_precision_training(), set_inductor_config=False) + + elif args.quantize == "bitnet": + quantize_(model.layers, bitnet_training(), set_inductor_config=False) + elif args.quantize is not None: raise ValueError(f"Unsupported quantize={args.quantize}") @@ -155,7 +188,8 @@ def get_tinystories(): idx = torch.randint(0, data.shape[0] - args.batch_size * args.seq_len, (1,)).item() batch = data[idx : idx + args.batch_size * args.seq_len].view(args.batch_size, args.seq_len).long() - loss = _get_loss(model, batch) + with torch.autocast("cuda", torch.bfloat16, enabled=args.bf16_amp): + loss = _get_loss(model, batch) loss.backward() if step % args.log_interval == 0: @@ -165,10 +199,6 @@ def get_tinystories(): max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9, max_memory_reserved=torch.cuda.max_memory_reserved() / 1e9, ) - if step > 0: - time1 = time.time() - log_dict["tokens_per_second"] = (args.log_interval * args.batch_size * args.seq_len) / (time1 - time0) - time0 = time1 run.log(log_dict, step=step) pbar.set_postfix(loss=log_dict["loss"]) @@ -178,4 +208,10 @@ def get_tinystories(): step += 1 pbar.update() + if step % args.log_interval == 0: + time1 = time.time() + log_dict = dict(tokens_per_second=(args.log_interval * args.batch_size * args.seq_len) / (time1 - time0)) + time0 = time1 + run.log(log_dict, step=step) + run.finish() diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index b07ade0b54..faecb6b2d2 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -1,6 +1,6 @@ import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6 if not TORCH_VERSION_AT_LEAST_2_4: pytest.skip("Requires torch>=2.4", allow_module_level=True) @@ -11,7 +11,7 @@ import torch.distributed as dist import torch.nn.functional as F from torch import nn -from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy +from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import TestCase, instantiate_parametrized_tests, parametrize, run_tests @@ -20,6 +20,7 @@ from torchao.prototype.low_bit_optim import _AdamW from torchao.prototype.quantized_training import ( Int8MixedPrecisionTrainingConfig, + bitnet_training, int8_mixed_precision_training, int8_weight_only_quantized_training, quantize_int8_rowwise, @@ -165,7 +166,7 @@ def test_int8_mixed_precision_training(self, compile, config): embed_dim = 64 device = "cuda" - linear = nn.Linear(embed_dim, embed_dim).cuda() + linear = nn.Linear(embed_dim, embed_dim, device=device) linear_int8mp = copy.deepcopy(linear) quantize_(linear_int8mp, int8_mixed_precision_training(config), set_inductor_config=False) @@ -187,6 +188,70 @@ def snr(ref, actual): assert snr(inputs_ref.grad, inputs_int8mp.grad) > 20 assert snr(linear.weight.grad, linear_int8mp.weight.grad) > 20 + @parametrize("compile", [False, True]) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_bitnet_training(self, compile): + # reference implementation + # https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf + # Figure 3 + class BitLinear(nn.Linear): + def activation_quant(self, x): + scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5) + return (x * scale).round().clamp_(-128, 127) / scale + + def weight_quant(self, x): + scale = 1.0 / x.abs().mean().clamp_(min=1e-5) + return (x * scale).round().clamp_(-1, 1) / scale + + def forward(self, x): + w = self.weight + x = x + (self.activation_quant(x) - x).detach() + w = w + (self.weight_quant(w) - w).detach() + return F.linear(x, w, self.bias) + + _reset() + bsize = 4 + embed_dim = 32 + device = "cuda" + + # only use 1 matmul shape to reduce triton autotune time + model_ref = nn.Sequential( + nn.Linear(embed_dim, embed_dim, bias=False), + nn.GELU(), + nn.Linear(embed_dim, embed_dim), + ).to(device) + model = copy.deepcopy(model_ref) + quantize_(model, bitnet_training(), set_inductor_config=False) + + # change model_ref to use BitLinear + model_ref[0].__class__ = BitLinear + model_ref[2].__class__ = BitLinear + + if compile: + model_ref.compile() + model.compile() + + optim_ref = torch.optim.AdamW(model_ref.parameters()) + optim = torch.optim.AdamW(model.parameters()) + + for i in range(5): + inputs = torch.randn(bsize, embed_dim, device=device) + labels = torch.randint(embed_dim, size=(bsize,), device=device) + loss_ref = F.cross_entropy(model_ref(inputs), labels) + loss = F.cross_entropy(model(inputs), labels) + + torch.testing.assert_close(loss, loss_ref) + + loss_ref.backward() + optim_ref.step() + optim_ref.zero_grad() + + loss.backward() + for p in model.parameters(): + assert p.grad is not None + optim.step() + optim.zero_grad() + _FSDP_WORLD_SIZE = 2 @@ -198,35 +263,36 @@ def world_size(self) -> int: @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE) def test_fsdp2_correctness(self): + mp_policy = MixedPrecisionPolicy() + + # quantize_fn, mp_policy, tolerance test_args = [ - ( - int8_weight_only_quantized_training(), # quantize_fn for base model - int8_weight_only_quantized_training(), # quantize_fn for FSDP model - MixedPrecisionPolicy(), - 0.05, # tolerance. due to stochastic rounding, use a pretty large tolerance here - ), - ( - int8_mixed_precision_training(), - int8_mixed_precision_training(), - MixedPrecisionPolicy(), - 1e-6, - ), - ( - # It's complicated (though possible) to simulate FSDP BF16 mixed-precision for base_model. - # We would need to cast all params to BF16 in forward and backward pass, while keeping - # the params in FP32 for optim step. - # torch.autocast() will only do this for F.linear() layer (and its backward). - # To keep it simple, we just use a larger tolerance here. - int8_mixed_precision_training(), - int8_mixed_precision_training(Int8MixedPrecisionTrainingConfig(fsdp_param_dtype=torch.bfloat16)), - MixedPrecisionPolicy(param_dtype=torch.bfloat16), - 1e-2, - ), + # high tolerance due to stochastic rounding + (int8_weight_only_quantized_training, mp_policy, 0.05), + (int8_mixed_precision_training, mp_policy, 1e-6), + (bitnet_training, mp_policy, 1e-5), ] + + # FSDP2 mixed-precision requires https://github.com/pytorch/pytorch/pull/136129 + if TORCH_VERSION_AT_LEAST_2_6: + # It's complicated (though possible) to simulate FSDP BF16 mixed-precision for base_model. + # We would need to cast all params to BF16 in forward and backward pass, while keeping + # the params in FP32 for optim step. + # torch.autocast() will only do this for F.linear() layer (and its backward). + # To keep it simple, we just use a larger tolerance here. + bf16_mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16) + + extra_args = [ + (int8_weight_only_quantized_training, bf16_mp_policy, 1e-2), + (int8_mixed_precision_training, bf16_mp_policy, 1e-2), + (bitnet_training, bf16_mp_policy, 1e-2), + ] + test_args.extend(extra_args) + self.run_subtests({"args": test_args}, self._run_subtest) def _run_subtest(self, args): - base_quantize_fn, fsdp_quantize_fn, mp_policy, tolerance = args + quantize_fn, mp_policy, tolerance = args batch_size = 3 vocab_size = 32 @@ -245,8 +311,8 @@ def _run_subtest(self, args): base_model = Transformer(model_args).cuda() fsdp_model = copy.deepcopy(base_model) - quantize_(base_model.layers, base_quantize_fn, set_inductor_config=False) - quantize_(fsdp_model.layers, fsdp_quantize_fn, set_inductor_config=False) + quantize_(base_model.layers, quantize_fn(), set_inductor_config=False) + quantize_(fsdp_model.layers, quantize_fn(), set_inductor_config=False) for layer in fsdp_model.layers: fully_shard(layer, mp_policy=mp_policy) @@ -275,7 +341,25 @@ def _run_subtest(self, args): base_optim.step() rel_error = (fsdp_loss - base_loss).abs() / base_loss.abs() - assert rel_error < tolerance, (iter_idx, rel_error) + assert rel_error < tolerance, (quantize_fn.__name__, mp_policy, iter_idx, rel_error) + + @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE) + def test_precompute_bitnet_scale(self): + from torchao.prototype.quantized_training.bitnet import get_bitnet_scale, precompute_bitnet_scale_for_fsdp + + model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).cuda() + model_fsdp = copy.deepcopy(model) + quantize_(model_fsdp, bitnet_training()) + fully_shard(model_fsdp) + + precompute_bitnet_scale_for_fsdp(model_fsdp) + + torch.testing.assert_close( + get_bitnet_scale(model[0].weight), model_fsdp[0].weight._local_tensor._precomputed_scale + ) + torch.testing.assert_close( + get_bitnet_scale(model[2].weight), model_fsdp[2].weight._local_tensor._precomputed_scale + ) instantiate_parametrized_tests(TestQuantizedTraining) diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md index 1dde72598c..b1b32a8be6 100644 --- a/torchao/prototype/quantized_training/README.md +++ b/torchao/prototype/quantized_training/README.md @@ -20,6 +20,8 @@ There are 3 main benefits of using low-precision dtype for training (the extent [`benchmarks/quantized_training/pretrain_llama2.py`](../../../benchmarks/quantized_training/pretrain_llama2.py) demonstrates an end-to-end Llama2 pre-training on single GPU for strategies implemented in this folder. +All features in this folder are tested to work with PyTorch 2.4+ unless otherwise stated. Training with FSDP2 is also supported, but if you use FDSP2 mixed-precision with `param_dtype` != model dtype, PyTorch 2.6+ is required. + ## INT8 quantized training Typically, quantized weights cannot be trained directly due to quantization error: a small change in the quantized weight will be round down to zero. To tackle this problem, we use **stochastic rounding** for weight update. In simple terms, stochastic rounding will round up or down randomly, but with a higher chance if it is closer to that direction. For example, 0.8 will have 80% chance of rounding up and 20% of rounding down. It also follows that on average, stochastic rounding will estimate the floating point value exactly. @@ -35,7 +37,7 @@ Usage ```python from torchao.prototype.quantized_training import int8_weight_only_quantized_training from torchao.prototype.low_bit_optim import _AdamW -from torchao.quantization import quantize_ +from torchao import quantize_ model = ... quantize_(model, int8_weight_only_quantized_training()) @@ -56,7 +58,7 @@ BF16 compile | 10.25 | 9000 INT8 QT eager | 10.12 | 5600 INT8 QT compile | 9.84 | 8700 -## INT8 mixed-precision +## INT8 mixed-precision training On NVIDIA GPUs, INT8 Tensor Cores is approximately 2x faster than their BF16/FP16 counterparts. In mixed-precision training, we can down-cast activations and weights dynamically to INT8 to leverage faster matmuls. However, since INT8 has very limited range [-128,127], we perform row-wise quantization, similar to how INT8 post-training quantization (PTQ) is done. Weight is still in original precision. @@ -64,7 +66,7 @@ On NVIDIA GPUs, INT8 Tensor Cores is approximately 2x faster than their BF16/FP1 ```python from torchao.prototype.quantized_training import int8_mixed_precision_training, Int8MixedPrecisionTrainingConfig -from torchao.quantization import quantize_ +from torchao import quantize_ model = ... @@ -104,34 +106,40 @@ INT8 mixed-precision | ~29k | 19.47 | 2.90 See [#748](https://github.com/pytorch/ao/pull/748) for more results. -### FSDP support +## BitNet b1.58 + +[BitNet b1.58](https://arxiv.org/abs/2402.17764) uses ternary weights: each parameter can only take on 3 distinct values {-1, 0, +1}, thus making a BitNet model very compact. BitNet uses tensor-wise abs-mean scaling for weights (quantize to ternary) and row-wise abs-max scaling for activations (quantize to INT8). -Out of the box, this INT8 mixed-precision training is not compatible with FSDP2 `MixedPrecisionPolicy(param_dtype=param_dtype)`, where `param_dtype` != model dtype. As a workaround, you will need to manually specify the FSDP2's `param_dtype` in `Int8MixedPrecisionTrainingConfig` +BitNet is originally trained with QAT: the weights and activations are fake-quantized, and straight-through estimator (STE) is used to calculate gradients with respect to floating point weights. This process adds extra overhead over standard training. Our implementation utilizes INT8 Tensor Cores to make up for this loss in speed. In fact, our implementation is faster than BF16 training in most cases. + +Usage ```python -from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy -from torchao.prototype.quantized_training import int8_mixed_precision_training, Int8MixedPrecisionTrainingConfig -from torchao.quantization import quantize_ +from torchao.prototype.quantized_training import bitnet_training +from torchao import quantize_ -model = ... # FP32 model +model = ... +quantize_(model, bitnet_training()) +``` -# setup configs -mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16) -int8mp_config = Int8MixedPrecisionTrainingConfig(fsdp_param_dtype=mp_policy.param_dtype) +Note: following the [BitNet Training Tips, Code and FAQ](https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf), user should insert extra RMSNorm before each `nn.Linear` layers and also remove the original RMSNorm before attention and MLP modules. Calling `quantize_(model, bitnet_training())` will NOT perform this for you. You can take a look at our example training script [`benchmarks/quantized_training/pretrain_llama2.py`](../../../benchmarks/quantized_training/pretrain_llama2.py) on how to do this for our Llama model. -# exclude LM head -quantize_(model.layers, int8_mixed_precision_training(int8mp_config)) +When used with FSDP2 training, you can pre-compute BitNet weight scales for the next iteration to synchronize all scales with a single all-reduce operation. This should be done after optimizer step. -# shard the model w/ FSDP2 -for layer in model.layers: - fully_shard(layer, mp_policy=mp_policy) -fully_shard(model, mp_policy=mp_policy) +```python +from torchao.prototype.quantized_training import precompute_bitnet_scale_for_fsdp -# train model as usual +for _ in range(n_steps): + model(inputs).sum().backward() + optim.step() + precompute_bitnet_scale_for_fsdp(model) ``` +See [#930](https://github.com/pytorch/ao/pull/930) for some benchmark results. + ## Future ideas +- Extend INT8 weight-only to support tensor-wise scaling, as well as other INTx dtypes. - Tile-wise INT8 quantization to keep quantized weight for both forward and backward pass (similar to JetFire). - INT4 weight only (with group-wise quantization). This can be used with INT4 tinygemm deployment in mind (or other optimized INT4 kernels). - FP8 activation x FP8 weight. The current FP8 training recipe can be seen as a form of QAT, which maintains a high-precision copy of model weights. We can eliminate the high-precision copy. diff --git a/torchao/prototype/quantized_training/__init__.py b/torchao/prototype/quantized_training/__init__.py index ccf2f5375d..c3c9b7cfaf 100644 --- a/torchao/prototype/quantized_training/__init__.py +++ b/torchao/prototype/quantized_training/__init__.py @@ -1,3 +1,4 @@ +from .bitnet import BitNetTrainingLinearWeight, bitnet_training, precompute_bitnet_scale_for_fsdp from .int8 import ( Int8QuantizedTrainingLinearWeight, int8_weight_only_quantized_training, diff --git a/torchao/prototype/quantized_training/bitnet.py b/torchao/prototype/quantized_training/bitnet.py new file mode 100644 index 0000000000..ffba7f252e --- /dev/null +++ b/torchao/prototype/quantized_training/bitnet.py @@ -0,0 +1,353 @@ +# this file implements BitNet b1.58 https://arxiv.org/abs/2402.17764 +# a reference implementation is available at +# https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf + +from typing import Any, Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torch.utils._pytree as pytree +from torch import Tensor, nn +from torch.utils._triton import has_triton +from torch.distributed._tensor import DTensor + +from torchao.quantization.quant_api import _get_linear_subclass_inserter +from torchao.utils import TorchAOBaseTensor + +from .int8 import quantize_int8_rowwise + +if has_triton(): + from .int8_mm import scaled_int8_mm + +else: + + # This is less performant than the explicit hand-written Triton kernel, though things might + # change in the future. + # Multiplying col_scale first is faster than the other way round. + def scaled_int8_mm(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor) -> Tensor: + return torch._int_mm(A, B) * col_scale.view(-1) * row_scale.view(-1, 1) + + +aten = torch.ops.aten + + +class BitNetTrainingLinearWeight(TorchAOBaseTensor): + @staticmethod + @torch._dynamo.disable + def __new__(cls, data: Tensor, precomputed_scale: Optional[Tensor] = None): + return Tensor._make_wrapper_subclass( + cls, + data.shape, + dtype=data.dtype, + device=data.device, + ) + + @torch._dynamo.disable + def __init__(self, data: Tensor, precomputed_scale: Optional[Tensor] = None): + self._data = data + self._precomputed_scale = precomputed_scale + + def __tensor_flatten__(self): + if self._precomputed_scale is not None: + return ["_data", "_precomputed_scale"], [] + else: + return ["_data"], [] + + @classmethod + def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): + return cls(tensor_data_dict["_data"], tensor_data_dict.get("_precomputed_scale", None), *tensor_attributes) + + def __repr__(self): + return f"{self.__class__.__name__}(data={self._data})" + + # adapated from FP8 implementation of WeightWithDynamicFloat8CastTensor + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + out = func( + *pytree.tree_map_only(cls, lambda x: x._data, args), + **pytree.tree_map_only(cls, lambda x: x._data, kwargs), + ) + + # NOTE: _precomputed_scale does not propagate through any ops + if func is aten.copy_.default: + # return original object + return args[0] + elif func in { + aten.t.default, + aten.detach.default, + aten.empty_like.default, + aten.new_zeros.default, + aten.slice.Tensor, + aten.view.default, + aten.as_strided.default, + aten._to_copy.default, + aten._pin_memory.default, + aten.split.Tensor, + aten.clone.default, + }: + # return new wrapped object + return pytree.tree_map_only(Tensor, lambda x: cls(x), out) + else: + # return new unwrapped object + return out + + # new signature https://github.com/pytorch/pytorch/pull/136129 + # we need default None for module and mp_policy so this method still works with PyTorch 2.4 and 2.5 + def fsdp_pre_all_gather(self, mesh, module=None, mp_policy=None): + # quantize and pack into 2-bit to save comm bandwidth + if self._precomputed_scale is not None: + scale = self._precomputed_scale + + else: + scale = get_bitnet_scale(self._data) + dist.all_reduce(scale, op=dist.ReduceOp.AVG) + + # NOTE: scale is in FP32 + data_i8 = quantize_bitnet_weight(self._data, scale) + data_i2 = _pack_i2_in_i8(data_i8) + return (data_i2,), (scale,) + + def fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, + *, + out: Optional[Tensor] = None, + ): + (data_i2,) = all_gather_outputs + (scale,) = metadata + scale = scale.to(param_dtype) + if out is not None: + assert isinstance(out, BitNetPacked2bitLinearWeight) + out.scale = scale + return + return BitNetPacked2bitLinearWeight(data_i2, scale), all_gather_outputs + + +@BitNetTrainingLinearWeight.implements(F.linear) +def _(func, types, args, kwargs): + if torch.is_autocast_enabled("cuda"): + dtype = torch.get_autocast_gpu_dtype() + args = tuple(x.to(dtype) if x is not None else x for x in args) + return _BitNetTrainingLinear.apply(*args, **kwargs) + + +def get_bitnet_scale(x: Tensor): + "Tensor-wise abs-mean. Always return FP32." + return x.float().abs().mean() + + +def quantize_bitnet_weight(w: Tensor, scale: Tensor, eps: float = 1e-5) -> Tensor: + w = w.float() / scale.clip(eps) + w = w.round().clip(-1, 1).to(torch.int8) + return w + + +@torch.no_grad() +def precompute_bitnet_scale_for_fsdp(module: nn.Module): + """Calculate scale for all BitNetTrainingLinearWeight parameters. + This should be run after the optimizer step. It performs a single all-reduce for all + parameters to reduce overhead. + """ + bitnet_params = [ + p + for p in module.parameters() + if isinstance(p, DTensor) and isinstance(p._local_tensor, BitNetTrainingLinearWeight) + ] + if len(bitnet_params) == 0: + return + + # NOTE: use torch.compile to save memory and increase speed? + bitnet_scales = [get_bitnet_scale(x) for x in bitnet_params] # local absmean + bitnet_scales = torch.stack(bitnet_scales) + bitnet_scales = bitnet_scales.full_tensor() # global absmean + + for i, p in enumerate(bitnet_params): + p._local_tensor._precomputed_scale = bitnet_scales[i] + + +class _BitNetTrainingLinear(torch.autograd.Function): + @staticmethod + def forward(ctx, input: Tensor, weight: BitNetTrainingLinearWeight, bias: Optional[Tensor] = None): + batch_dims = input.shape[:-1] + input = input.view(-1, weight.shape[1]) + + # https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf + # Figure 3 + input_i8, row_scale = quantize_int8_rowwise(input, eps=1e-5) + + # NOTE: use FP32 scale for weight quantization, but cast scale to possibly lower precision + # for matmul and backward + tensor_scale = get_bitnet_scale(weight._data) + weight_i8 = quantize_bitnet_weight(weight._data, tensor_scale) + tensor_scale = tensor_scale.to(weight.dtype) + + ctx.save_for_backward(input_i8, row_scale, weight_i8, tensor_scale) + + # use int8 tensor cores + out = scaled_int8_mm(input_i8.contiguous(), weight_i8.contiguous().T, row_scale, tensor_scale) + out = out.view(*batch_dims, weight.shape[0]) + + out = out + bias if bias is not None else out + return out + + @staticmethod + def backward(ctx, grad_output): + input_i8, row_scale, weight_i8, tensor_scale = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + + batch_dims = grad_output.shape[:-1] + grad_output = grad_output.view(-1, weight_i8.shape[0]) + + # NOTE: we can potentially speedup training by also quantizing the backward pass + # to use INT8 tensor cores + if ctx.needs_input_grad[0]: + # mixed mm + grad_input = (grad_output @ weight_i8.to(grad_output.dtype)) * tensor_scale + grad_input = grad_input.view(*batch_dims, weight_i8.shape[1]) + + if ctx.needs_input_grad[1]: + # NOTE: we use quantized activation for this calculation + grad_weight = grad_output.T @ (input_i8 * row_scale.view(-1, 1)) + + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum(0) + + return grad_input, grad_weight, grad_bias + + +def bitnet_training(): + return _get_linear_subclass_inserter(BitNetTrainingLinearWeight, allow_requires_grad=True) + + +def _pack_i2_in_i8(x: Tensor): + # perform packing: [xxxx xxaa, xxxx xxxbb, xxxx xxcc, xxxx xxdd] -> [aabb ccdd] + # for each value, xxxx can be either all 0s or all 1s because these are signed numbers. + # thus, we have to mask out the 2 least significant bits (right-most) before bit-shift. + # e.g. 1111 1111 (value=-1) -> 0000 0011 -> 0011 0000 + + x0 = x[:, ::4] << 6 # don't need to mask this number because we shift it to the left-most + x1 = (x[:, 1::4] & 0b11) << 4 + x2 = (x[:, 2::4] & 0b11) << 2 + x3 = x[:, 3::4] & 0b11 + return x0 | x1 | x2 | x3 + + +def _unpack_i2_in_i8(x: Tensor): + # NOTE: this is signed integer, so left-shift then right-shift will perform sign extension correctly + # e.g. aa10bbcc -> 10bbcc00 -> 11111110 + return torch.stack([x >> 6, x << 2 >> 6, x << 4 >> 6, x << 6 >> 6], dim=-1).view(x.shape[0], -1) + + +# currently this class mainly serves as a container for quantized FSDP2 all-gather, +# so only a minimal set of ops are implemented. this can be extended for inference. +class BitNetPacked2bitLinearWeight(TorchAOBaseTensor): + @staticmethod + @torch._dynamo.disable + def __new__(cls, int_data: Tensor, scale: Tensor): + M, N = int_data.shape + shape = (M, N * 4) + return Tensor._make_wrapper_subclass( + cls, + shape, + dtype=scale.dtype, + device=scale.device, + ) + + @torch._dynamo.disable + def __init__(self, int_data: Tensor, scale: Tensor): + assert int_data.dtype is torch.int8 + assert scale.shape == () + self.int_data = int_data + self.scale = scale + + def __tensor_flatten__(self): + return ["int_data", "scale"], [] + + @classmethod + def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): + return cls(tensor_data_dict["int_data"], tensor_data_dict["scale"], *tensor_attributes) + + def __repr__(self): + return f"{self.__class__.__name__}(data={self.dequantize()})" + + def dequantize(self, out_dtype=None): + out = _unpack_i2_in_i8(self.int_data) * self.scale + if out_dtype is not None: + out = out.to(out_dtype) + return out + + +@BitNetPacked2bitLinearWeight.implements(F.linear) +def _(func, types, args, kwargs): + return _BitNetPacked2bitLinear.apply(*args, **kwargs) + + +@BitNetPacked2bitLinearWeight.implements( + [ + aten.detach.default, + aten.clone.default, + ] +) +def _(func, types, args, kwargs): + return BitNetPacked2bitLinearWeight( + func(args[0].int_data, *args[1:], **kwargs), + func(args[0].scale, *args[1:], **kwargs), + ) + + +# this is a workaround to make it work with FSDP2. +# end-users should not call this op directly. +@BitNetPacked2bitLinearWeight.implements(aten.as_strided.default) +def _(func, types, args, kwargs): + return BitNetPacked2bitLinearWeight(args[0].int_data, args[0].scale) + + +class _BitNetPacked2bitLinear(torch.autograd.Function): + @staticmethod + def forward(ctx, input: Tensor, weight: BitNetPacked2bitLinearWeight, bias: Optional[Tensor] = None): + batch_dims = input.shape[:-1] + input = input.view(-1, weight.shape[1]) + + # https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf + # Figure 3 + input_i8, row_scale = quantize_int8_rowwise(input, eps=1e-5) + weight_i2, tensor_scale = weight.int_data, weight.scale + + ctx.save_for_backward(input_i8, row_scale, weight_i2, tensor_scale) + + # use int8 tensor cores + # NOTE: is doing dequant inside matmul faster when M is large? + weight_i8 = _unpack_i2_in_i8(weight_i2) + out = scaled_int8_mm(input_i8.contiguous(), weight_i8.contiguous().T, row_scale, tensor_scale) + out = out.view(*batch_dims, weight.shape[0]) + + out = out + bias if bias is not None else out + return out + + @staticmethod + def backward(ctx, grad_output): + input_i8, row_scale, weight_i2, tensor_scale = ctx.saved_tensors + weight_i8 = _unpack_i2_in_i8(weight_i2) + grad_input = grad_weight = grad_bias = None + + batch_dims = grad_output.shape[:-1] + grad_output = grad_output.view(-1, weight_i8.shape[0]) + + # NOTE: we can potentially speedup training by also quantizing the backward pass + # to use INT8 tensor cores + if ctx.needs_input_grad[0]: + # mixed mm + grad_input = (grad_output @ weight_i8.to(grad_output.dtype)) * tensor_scale + grad_input = grad_input.view(*batch_dims, weight_i8.shape[1]) + + if ctx.needs_input_grad[1]: + # NOTE: we use quantized activation for this calculation + grad_weight = grad_output.T @ (input_i8 * row_scale.view(-1, 1)) + + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum(0) + + return grad_input, grad_weight, grad_bias diff --git a/torchao/prototype/quantized_training/int8.py b/torchao/prototype/quantized_training/int8.py index 828655f04c..1273eda83f 100644 --- a/torchao/prototype/quantized_training/int8.py +++ b/torchao/prototype/quantized_training/int8.py @@ -14,7 +14,7 @@ @torch.no_grad() -def quantize_int8_rowwise(tensor: Tensor, stochastic_rounding: bool = False): +def quantize_int8_rowwise(tensor: Tensor, stochastic_rounding: bool = False, eps: float = 1e-12): """Normal rounding will always round down small changes in weight update. To tackle this problem, stochastic rounding can be used, which has a low chance, but not zero, of rounding up. The probability of rounding up is equal to x - ⌊x⌋, which indicates how close the value is to the next @@ -29,7 +29,7 @@ def quantize_int8_rowwise(tensor: Tensor, stochastic_rounding: bool = False): """ # absmax symmetric quantization scale = tensor.abs().amax(1) / 127 # same dtype as tensor - inv_scale = 1.0 / scale.float().clip(1e-12) + inv_scale = 1.0 / scale.float().clip(eps) tensor = tensor.float() * inv_scale.view(-1, 1) # slightly faster than divide directly if stochastic_rounding: @@ -99,8 +99,14 @@ def __repr__(self): f"requires_grad={self.requires_grad})" ) - def fsdp_pre_all_gather(self, mesh): - return (self.int_data, self.scale), None + # require https://github.com/pytorch/pytorch/pull/136129 for mixed-precision param_dtype + # we need default None for module and mp_policy so this method still works with PyTorch 2.4 and 2.5 + def fsdp_pre_all_gather(self, mesh, module=None, mp_policy=None): + scale = self.scale + if mp_policy is not None: + scale = scale.to(mp_policy.param_dtype) + + return (self.int_data, scale), None def fsdp_post_all_gather( self, diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index 0f96e348ba..8cc02b53c0 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -11,15 +11,15 @@ from .int8 import quantize_int8_rowwise if has_triton(): - from .int8_mm import int8_mm_dequant + from .int8_mm import scaled_int8_mm else: # This is less performant than the explicit hand-written Triton kernel, though things might # change in the future. - # Multiplying B_scale first is faster than the other way round. - def int8_mm_dequant(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwise: Tensor) -> Tensor: - return torch._int_mm(A, B) * B_scale_colwise * A_scale_rowwise.view(-1, 1) + # Multiplying col_scale first is faster than the other way round. + def scaled_int8_mm(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor) -> Tensor: + return torch._int_mm(A, B) * col_scale.view(-1) * row_scale.view(-1, 1) class Int8MixedPrecisionTrainingConfig(NamedTuple): @@ -27,10 +27,6 @@ class Int8MixedPrecisionTrainingConfig(NamedTuple): grad_input: bool = True grad_weight: bool = True - # workaround for FSDP2 with `MixedPrecisionPolicy(param_dtype)` - # see `Int8MixedPrecisionTrainingLinearWeight.fsdp_pre_all_gather()` for more details. - fsdp_param_dtype: Optional[torch.dtype] = None - _DEFAULT_CONFIG = Int8MixedPrecisionTrainingConfig() @@ -114,15 +110,15 @@ def unwrap(x: cls): # return new unwrapped object return out - def fsdp_pre_all_gather(self, mesh): + # require https://github.com/pytorch/pytorch/pull/136129 for mixed-precision param_dtype + # we need default None for module and mp_policy so this method still works with PyTorch 2.4 and 2.5 + def fsdp_pre_all_gather(self, mesh, module=None, mp_policy=None): # TODO: pre-quantize weight here -> reduce comm bandwidth. # we will need another tensor subclass to hold the quantized weight. + data = self._data + if mp_policy is not None: + data = data.to(mp_policy.param_dtype) - # doing dtype casting to `param_dtype` in `fsdp_post_all_gather()` will give wrong results. - # as a workaround, we do it in `fsdp_pre_all_gather()` instead. since `param_dtype` is not - # exposed to `fsdp_pre_all_gather()`, we need to specify it in the config. - # this workaround can be removed once we implement INT8 communication. - data = self._data.to(dtype=self.config.fsdp_param_dtype) return (data,), (self.config,) def fsdp_post_all_gather( @@ -171,7 +167,7 @@ def _dynamic_int8_mm(A: Tensor, B: Tensor) -> Tensor: # A may have more than 2 dims, while B must be exactly 2-dim A_i8, A_scale_rowwise = quantize_int8_rowwise(A.view(-1, A.shape[-1])) B_t_i8, B_scale_colwise = quantize_int8_rowwise(B.T) - out = int8_mm_dequant( + out = scaled_int8_mm( A_i8.contiguous(), B_t_i8.contiguous().T, A_scale_rowwise.contiguous(), diff --git a/torchao/prototype/quantized_training/int8_mm.py b/torchao/prototype/quantized_training/int8_mm.py index b316e82208..74d3027daa 100644 --- a/torchao/prototype/quantized_training/int8_mm.py +++ b/torchao/prototype/quantized_training/int8_mm.py @@ -51,19 +51,27 @@ @triton.autotune(configs=configs, key=["M", "N", "K", "stride_ak", "stride_bk"]) @triton.jit -def _int8_mm_dequant_kernel( - A_ptr, B_ptr, C_ptr, - A_scale_rowwise_ptr, - B_scale_colwise_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, +def _scaled_int8_mm_kernel( + A_ptr, + B_ptr, + C_ptr, + row_scale_ptr, + col_scale_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr = 8, EVEN_K: tl.constexpr = True, + COL_SCALE_SCALAR: tl.constexpr = False, ): # based on triton.ops.matmul pid = tl.program_id(0) @@ -104,41 +112,60 @@ def _int8_mm_dequant_kernel( idx_n = rn[None, :] mask = (idx_m < M) & (idx_n < N) - a_scale = tl.load(A_scale_rowwise_ptr + idx_m, mask=idx_m < M).to(tl.float32) - b_scale = tl.load(B_scale_colwise_ptr + idx_n, mask=idx_n < N).to(tl.float32) - acc = acc.to(tl.float32) * a_scale * b_scale + row_scale = tl.load(row_scale_ptr + idx_m, mask=idx_m < M).to(tl.float32) + if COL_SCALE_SCALAR: + # hack to support BitNet. col_scale is now a scalar + col_scale = tl.load(col_scale_ptr).to(tl.float32) + else: + col_scale = tl.load(col_scale_ptr + idx_n, mask=idx_n < N).to(tl.float32) + acc = acc.to(tl.float32) * row_scale * col_scale # inductor generates a suffix xindex = idx_m * stride_cm + idx_n * stride_cn tl.store(C_ptr + tl.broadcast_to(xindex, mask.shape), acc, mask) -lib.define("int8_mm_dequant(Tensor A, Tensor B, Tensor A_scale, Tensor B_scale) -> Tensor") +lib.define("scaled_int8_mm(Tensor A, Tensor B, Tensor A_scale, Tensor B_scale) -> Tensor") -def int8_mm_dequant(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwise: Tensor) -> Tensor: +def scaled_int8_mm(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor) -> Tensor: + """Compute `(A @ B) * row_scale * col_scale`, where `A` and `B` are INT8 to utilize + INT8 tensor cores. `col_scale` can be a scalar. + """ assert A.dtype is torch.int8 and B.dtype is torch.int8 - assert A_scale_rowwise.dtype is B_scale_colwise.dtype + assert row_scale.dtype is col_scale.dtype assert A.shape[1] == B.shape[0] - assert A_scale_rowwise.squeeze().shape == (A.shape[0],) - assert B_scale_colwise.squeeze().shape == (B.shape[1],) - assert A_scale_rowwise.is_contiguous() - assert B_scale_colwise.is_contiguous() - return torch.ops.torchao.int8_mm_dequant(A, B, A_scale_rowwise, B_scale_colwise) + assert row_scale.squeeze().shape == (A.shape[0],) + assert col_scale.squeeze().shape in ((B.shape[1],), ()) + assert row_scale.is_contiguous() + assert col_scale.is_contiguous() + return torch.ops.torchao.scaled_int8_mm(A, B, row_scale, col_scale) -@torch.library.impl(lib, "int8_mm_dequant", "Meta") -def _(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwise: Tensor): - return torch.empty((A.shape[0], B.shape[1]), device=A.device, dtype=A_scale_rowwise.dtype) +@torch.library.impl(lib, "scaled_int8_mm", "Meta") +def _(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor): + return torch.empty((A.shape[0], B.shape[1]), device=A.device, dtype=row_scale.dtype) -@torch.library.impl(lib, "int8_mm_dequant", "CUDA") -def int8_mm_dequant_cuda(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwise: Tensor): +@torch.library.impl(lib, "scaled_int8_mm", "CUDA") +def scaled_int8_mm_cuda(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor): M, K = A.shape _, N = B.shape - C = torch.empty(M, N, device=A.device, dtype=A_scale_rowwise.dtype) + C = torch.empty(M, N, device=A.device, dtype=row_scale.dtype) grid = lambda meta: (triton.cdiv(meta["M"], meta["BLOCK_M"]) * triton.cdiv(meta["N"], meta["BLOCK_N"]),) - _int8_mm_dequant_kernel[grid]( - A, B, C, A_scale_rowwise, B_scale_colwise, M, N, K, *A.stride(), *B.stride(), *C.stride(), EVEN_K=K % 2 == 0 + _scaled_int8_mm_kernel[grid]( + A, + B, + C, + row_scale, + col_scale, + M, + N, + K, + *A.stride(), + *B.stride(), + *C.stride(), + EVEN_K=K % 2 == 0, + COL_SCALE_SCALAR=col_scale.numel() == 1, ) return C