From 6995ef3e64032aefa353612fc43a77f2b6b5dd53 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 24 Sep 2024 22:01:13 +0800 Subject: [PATCH 01/15] first upstream of BitNet --- test/prototype/test_quantized_training.py | 66 +++- .../prototype/quantized_training/README.md | 24 +- .../prototype/quantized_training/__init__.py | 1 + .../prototype/quantized_training/bitnet.py | 306 ++++++++++++++++++ torchao/prototype/quantized_training/int8.py | 4 +- .../int8_mixed_precision.py | 10 +- .../prototype/quantized_training/int8_mm.py | 81 +++-- 7 files changed, 454 insertions(+), 38 deletions(-) create mode 100644 torchao/prototype/quantized_training/bitnet.py diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index bffff16fc1..0d0176aad2 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -11,7 +11,7 @@ import torch.distributed as dist import torch.nn.functional as F from torch import nn -from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy +from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import TestCase, instantiate_parametrized_tests, parametrize, run_tests @@ -20,6 +20,7 @@ from torchao.prototype.low_bit_optim import _AdamW from torchao.prototype.quantized_training import ( Int8MixedPrecisionTrainingConfig, + bitnet_training, int8_mixed_precision_training, int8_weight_only_quantized_training, quantize_int8_rowwise, @@ -200,6 +201,69 @@ def test_int8_mixed_precision_training(self, compile, config): optim_int8mp.step() optim_int8mp.zero_grad() + @parametrize("compile", [False, True]) + def test_bitnet_training(self, compile): + # reference implementation + # https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf + # Figure 3 + class BitLinear(nn.Linear): + def activation_quant(self, x): + scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5) + return (x * scale).round().clamp_(-128, 127) / scale + + def weight_quant(self, x): + scale = 1.0 / x.abs().mean().clamp_(min=1e-5) + return (x * scale).round().clamp_(-1, 1) / scale + + def forward(self, x): + w = self.weight + x = x + (self.activation_quant(x) - x).detach() + w = w + (self.weight_quant(w) - w).detach() + return F.linear(x, w, self.bias) + + _reset() + bsize = 4 + embed_dim = 32 + device = "cuda" + + # only use 1 matmul shape to reduce triton autotune time + model_ref = nn.Sequential( + nn.Linear(embed_dim, embed_dim, bias=False), + nn.GELU(), + nn.Linear(embed_dim, embed_dim), + ).to(device) + model = copy.deepcopy(model_ref) + quantize_(model, bitnet_training(), set_inductor_config=False) + + # change model_ref to use BitLinear + model_ref[0].__class__ = BitLinear + model_ref[2].__class__ = BitLinear + + if compile: + model_ref.compile() + model.compile() + + optim_ref = torch.optim.AdamW(model_ref.parameters()) + optim = torch.optim.AdamW(model.parameters()) + + for i in range(5): + inputs = torch.randn(bsize, embed_dim, device=device) + labels = torch.randint(embed_dim, size=(bsize,), device=device) + loss_ref = F.cross_entropy(model_ref(inputs), labels) + loss = F.cross_entropy(model(inputs), labels) + + torch.testing.assert_close(loss, loss_ref) + + loss_ref.backward() + optim_ref.step() + optim_ref.zero_grad() + + loss.backward() + for p in model.parameters(): + assert p.grad is not None + optim.step() + optim.zero_grad() + _FSDP_WORLD_SIZE = 2 diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md index 1dde72598c..52bde67406 100644 --- a/torchao/prototype/quantized_training/README.md +++ b/torchao/prototype/quantized_training/README.md @@ -35,7 +35,7 @@ Usage ```python from torchao.prototype.quantized_training import int8_weight_only_quantized_training from torchao.prototype.low_bit_optim import _AdamW -from torchao.quantization import quantize_ +from torchao import quantize_ model = ... quantize_(model, int8_weight_only_quantized_training()) @@ -64,7 +64,7 @@ On NVIDIA GPUs, INT8 Tensor Cores is approximately 2x faster than their BF16/FP1 ```python from torchao.prototype.quantized_training import int8_mixed_precision_training, Int8MixedPrecisionTrainingConfig -from torchao.quantization import quantize_ +from torchao import quantize_ model = ... @@ -111,7 +111,7 @@ Out of the box, this INT8 mixed-precision training is not compatible with FSDP2 ```python from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy from torchao.prototype.quantized_training import int8_mixed_precision_training, Int8MixedPrecisionTrainingConfig -from torchao.quantization import quantize_ +from torchao import quantize_ model = ... # FP32 model @@ -130,6 +130,24 @@ fully_shard(model, mp_policy=mp_policy) # train model as usual ``` +## BitNet b1.58 + +[BitNet b1.58](https://arxiv.org/abs/2402.17764) uses ternary weights: each parameter can only take on 3 distinct values {-1, 0, +1}, thus making a BitNet model very compact. BitNet uses tensor-wise abs-mean scaling for weights (quantize to ternary) and row-wise abs-max scaling for activations (quantize to INT8). + +BitNet is originally trained with QAT: the weights and activations are fake-quantized, and straight-through estimator (STE) is used to calculate gradients with respect to floating point weights. This process adds extra overhead over standard straining. Our implementation utilizes INT8 Tensor Cores to make up for this loss in speed. In fact, our implementation is faster than BF16 training in most cases. + +Usage + +```python +from torchao.prototype.quantized_training import bitnet_training +from torchao import quantize_ + +model = ... +quantize_(model, bitnet_training()) +``` + +Note: following the [BitNet Training Tips, Code and FAQ](https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf), user should insert extra RMSNorm before each `nn.Linear` layers and also remove the original RMSNorm before attention and MLP modules. Calling `quantize_(model, bitnet_training())` will NOT perform this for you. + ## Future ideas - Tile-wise INT8 quantization to keep quantized weight for both forward and backward pass (similar to JetFire). diff --git a/torchao/prototype/quantized_training/__init__.py b/torchao/prototype/quantized_training/__init__.py index ccf2f5375d..cdaf39ab35 100644 --- a/torchao/prototype/quantized_training/__init__.py +++ b/torchao/prototype/quantized_training/__init__.py @@ -1,3 +1,4 @@ +from .bitnet import BitNetTrainingLinearWeight, bitnet_training from .int8 import ( Int8QuantizedTrainingLinearWeight, int8_weight_only_quantized_training, diff --git a/torchao/prototype/quantized_training/bitnet.py b/torchao/prototype/quantized_training/bitnet.py new file mode 100644 index 0000000000..26337f1f45 --- /dev/null +++ b/torchao/prototype/quantized_training/bitnet.py @@ -0,0 +1,306 @@ +# this file implements BitNet b1.58 https://arxiv.org/abs/2402.17764 +# a reference implementation is available at +# https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf + +from typing import Any, Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torch.utils._pytree as pytree +from torch import Tensor +from torch.utils._triton import has_triton + +from torchao.quantization.quant_api import _get_linear_subclass_inserter +from torchao.utils import TorchAOBaseTensor + +from .int8 import quantize_int8_rowwise + +if has_triton(): + from .int8_mm import scaled_int8_mm + +else: + + # This is less performant than the explicit hand-written Triton kernel, though things might + # change in the future. + # Multiplying col_scale first is faster than the other way round. + def scaled_int8_mm(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor) -> Tensor: + return torch._int_mm(A, B) * col_scale * row_scale.view(-1, 1) + + +aten = torch.ops.aten + + +class BitNetTrainingLinearWeight(TorchAOBaseTensor): + @staticmethod + @torch._dynamo.disable + def __new__(cls, data: Tensor): + return Tensor._make_wrapper_subclass( + cls, + data.shape, + dtype=data.dtype, + device=data.device, + ) + + @torch._dynamo.disable + def __init__(self, data: Tensor): + self._data = data + + def __tensor_flatten__(self): + return ["_data"], [] + + @classmethod + def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): + return cls(tensor_data_dict["_data"], *tensor_attributes) + + def __repr__(self): + return f"{self.__class__.__name__}(data={self._data})" + + # adapated from FP8 implementation of WeightWithDynamicFloat8CastTensor + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + out = func( + *pytree.tree_map_only(cls, lambda x: x._data, args), + **pytree.tree_map_only(cls, lambda x: x._data, kwargs), + ) + + if func is aten.copy_.default: + # return original object + return args[0] + elif func in { + aten.t.default, + aten.detach.default, + aten.empty_like.default, + aten.new_zeros.default, + aten.slice.Tensor, + aten.view.default, + aten.as_strided.default, + aten._to_copy.default, + aten._pin_memory.default, + aten.split.Tensor, + aten.clone.default, + }: + # return new wrapped object + return pytree.tree_map_only(Tensor, lambda x: cls(x), out) + else: + # return new unwrapped object + return out + + def fsdp_pre_all_gather(self, mesh): + # quantize and pack into 2-bit to save comm bandwidth + # TODO: precompute absmean similar to float8 + data = BitNetPacked2bitLinearWeight.from_float(self._data, all_reduce=True) + return (data.int_data,), (data.scale,) + + def fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, + *, + out: Optional[Tensor] = None, + ): + (int_data,) = all_gather_outputs + (scale,) = metadata + if out is not None: + assert isinstance(out, BitNetPacked2bitLinearWeight) + out.scale = scale + return + return BitNetPacked2bitLinearWeight(int_data, scale), all_gather_outputs + + +@BitNetTrainingLinearWeight.implements(F.linear) +def _(func, types, args, kwargs): + if torch.is_autocast_enabled("cuda"): + dtype = torch.get_autocast_gpu_dtype() + args = tuple(x.to(dtype) if x is not None else x for x in args) + return _BitNetTrainingLinear.apply(*args, **kwargs) + + +def quantize_bitnet_weight(w: Tensor, eps: float = 1e-5, all_reduce: bool = False) -> Tensor: + dtype = w.dtype + w = w.float() + scale = w.abs().mean() # tensor-wise abs-mean. FP32 + + if all_reduce and dist.is_initialized(): + dist.all_reduce(scale, op=dist.ReduceOp.AVG) + + w = w / scale.clip(eps) + w = w.round().clip(-1, 1).to(torch.int8) + return w, scale.to(dtype) + + +class _BitNetTrainingLinear(torch.autograd.Function): + @staticmethod + def forward(ctx, input: Tensor, weight: BitNetTrainingLinearWeight, bias: Tensor | None = None): + batch_dims = input.shape[:-1] + input = input.view(-1, weight.shape[1]) + + # https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf + # Figure 3 + input_i8, row_scale = quantize_int8_rowwise(input, eps=1e-5) + weight_i8, tensor_scale = quantize_bitnet_weight(weight._data) + ctx.save_for_backward(input, weight_i8, tensor_scale) + + # use int8 tensor cores + out = scaled_int8_mm(input_i8.contiguous(), weight_i8.contiguous().T, row_scale, tensor_scale) + out = out.view(*batch_dims, weight.shape[0]) + + out = out + bias if bias is not None else out + return out + + @staticmethod + def backward(ctx, grad_output): + input, weight_i8, tensor_scale = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + + batch_dims = grad_output.shape[:-1] + grad_output = grad_output.view(-1, weight_i8.shape[0]) + input = input.view(-1, weight_i8.shape[1]) + + # NOTE: we can potentially speedup training by also quantizing the backward pass + # to use INT8 tensor cores + if ctx.needs_input_grad[0]: + # mixed mm + grad_input = (grad_output @ weight_i8.to(grad_output.dtype)) * tensor_scale + grad_input = grad_input.view(*batch_dims, weight_i8.shape[1]) + + if ctx.needs_input_grad[1]: + grad_weight = grad_output.T @ input + + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum(0) + + return grad_input, grad_weight, grad_bias + + +def bitnet_training(): + return _get_linear_subclass_inserter(BitNetTrainingLinearWeight, allow_requires_grad=True) + + +def _pack_i2_to_i8(x: Tensor): + # NOTE: this is signed integer, so we have to mask before bit-shift + return (x[:, ::4] << 6) | ((x[:, 1::4] & 0b11) << 4) | ((x[:, 2::4] & 0b11) << 2) | (x[:, 3::4] & 0b11) + + +def _unpack_i8_to_i2(x: Tensor): + # NOTE: this is signed integer, so left-shift then right-shift will perform sign extension correctly + # e.g. aa10bbcc -> 10bbcc00 -> 11111110 + return torch.stack([x >> 6, x << 2 >> 6, x << 4 >> 6, x << 6 >> 6], dim=-1).view(x.shape[0], -1) + + +# currently this class mainly serves as a container for quantized FSDP2 all-gather, +# so only a minimal set of ops are implemented. this can be extended for inference. +class BitNetPacked2bitLinearWeight(TorchAOBaseTensor): + @staticmethod + @torch._dynamo.disable + def __new__(cls, int_data: Tensor, scale: Tensor): + M, N = int_data.shape + shape = (M, N * 4) + return Tensor._make_wrapper_subclass( + cls, + shape, + dtype=scale.dtype, + device=scale.device, + ) + + @torch._dynamo.disable + def __init__(self, int_data: Tensor, scale: Tensor): + assert int_data.dtype is torch.int8 + assert scale.shape == () + self.int_data = int_data + self.scale = scale + + def __tensor_flatten__(self): + return ["int_data", "scale"], [] + + @classmethod + def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): + return cls(tensor_data_dict["int_data"], tensor_data_dict["scale"], *tensor_attributes) + + def __repr__(self): + return f"{self.__class__.__name__}(data={self.dequantize()})" + + @classmethod + def from_float(cls, tensor: Tensor, *, eps: float = 1e-5, all_reduce: bool = False): + int_data, scale = quantize_bitnet_weight(tensor, eps=eps, all_reduce=all_reduce) + int_data = _pack_i2_to_i8(int_data) + return BitNetPacked2bitLinearWeight(int_data, scale) + + def dequantize(self, out_dtype=None): + out = _unpack_i8_to_i2(self.int_data) * self.scale + if out_dtype is not None: + out = out.to(out_dtype) + return out + + +@BitNetPacked2bitLinearWeight.implements(F.linear) +def _(func, types, args, kwargs): + return _BitNetPacked2bitLinear.apply(*args, **kwargs) + + +@BitNetPacked2bitLinearWeight.implements( + [ + aten.detach.default, + aten.clone.default, + ] +) +def _(func, types, args, kwargs): + return BitNetPacked2bitLinearWeight( + func(args[0].int_data, *args[1:], **kwargs), + func(args[0].scale, *args[1:], **kwargs), + ) + + +# this is a workaround to make it work with FSDP2. +# end-users should not call this op directly. +@BitNetPacked2bitLinearWeight.implements(aten.as_strided.default) +def _(func, types, args, kwargs): + return BitNetPacked2bitLinearWeight(args[0].int_data, args[0].scale) + + +class _BitNetPacked2bitLinear(torch.autograd.Function): + @staticmethod + def forward(ctx, input: Tensor, weight: BitNetPacked2bitLinearWeight, bias: Tensor | None = None): + batch_dims = input.shape[:-1] + input = input.view(-1, weight.shape[1]) + + # https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf + # Figure 3 + input_i8, row_scale = quantize_int8_rowwise(input, eps=1e-5) + weight_i2, tensor_scale = weight.int_data, weight.scale + ctx.save_for_backward(input, weight_i2, tensor_scale) + + # use int8 tensor cores + # NOTE: is doing dequant inside matmul faster when M is large? + weight_i8 = _unpack_i8_to_i2(weight_i2) + out = scaled_int8_mm(input_i8.contiguous(), weight_i8.contiguous().T, row_scale, tensor_scale) + out = out.view(*batch_dims, weight.shape[0]) + + out = out + bias if bias is not None else out + return out + + @staticmethod + def backward(ctx, grad_output): + input, weight_i2, tensor_scale = ctx.saved_tensors + weight_i8 = _unpack_i8_to_i2(weight_i2) + grad_input = grad_weight = grad_bias = None + + batch_dims = grad_output.shape[:-1] + grad_output = grad_output.view(-1, weight_i8.shape[0]) + input = input.view(-1, weight_i8.shape[1]) + + # NOTE: we can potentially speedup training by also quantizing the backward pass + # to use INT8 tensor cores + if ctx.needs_input_grad[0]: + # mixed mm + grad_input = (grad_output @ weight_i8.to(grad_output.dtype)) * tensor_scale + grad_input = grad_input.view(*batch_dims, weight_i8.shape[1]) + + if ctx.needs_input_grad[1]: + grad_weight = grad_output.T @ input + + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum(0) + + return grad_input, grad_weight, grad_bias diff --git a/torchao/prototype/quantized_training/int8.py b/torchao/prototype/quantized_training/int8.py index 828655f04c..672c3a223e 100644 --- a/torchao/prototype/quantized_training/int8.py +++ b/torchao/prototype/quantized_training/int8.py @@ -14,7 +14,7 @@ @torch.no_grad() -def quantize_int8_rowwise(tensor: Tensor, stochastic_rounding: bool = False): +def quantize_int8_rowwise(tensor: Tensor, stochastic_rounding: bool = False, eps: float = 1e-12): """Normal rounding will always round down small changes in weight update. To tackle this problem, stochastic rounding can be used, which has a low chance, but not zero, of rounding up. The probability of rounding up is equal to x - ⌊x⌋, which indicates how close the value is to the next @@ -29,7 +29,7 @@ def quantize_int8_rowwise(tensor: Tensor, stochastic_rounding: bool = False): """ # absmax symmetric quantization scale = tensor.abs().amax(1) / 127 # same dtype as tensor - inv_scale = 1.0 / scale.float().clip(1e-12) + inv_scale = 1.0 / scale.float().clip(eps) tensor = tensor.float() * inv_scale.view(-1, 1) # slightly faster than divide directly if stochastic_rounding: diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index 0f96e348ba..9cf73fc2a5 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -11,15 +11,15 @@ from .int8 import quantize_int8_rowwise if has_triton(): - from .int8_mm import int8_mm_dequant + from .int8_mm import scaled_int8_mm else: # This is less performant than the explicit hand-written Triton kernel, though things might # change in the future. - # Multiplying B_scale first is faster than the other way round. - def int8_mm_dequant(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwise: Tensor) -> Tensor: - return torch._int_mm(A, B) * B_scale_colwise * A_scale_rowwise.view(-1, 1) + # Multiplying col_scale first is faster than the other way round. + def scaled_int8_mm(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor) -> Tensor: + return torch._int_mm(A, B) * col_scale * row_scale.view(-1, 1) class Int8MixedPrecisionTrainingConfig(NamedTuple): @@ -171,7 +171,7 @@ def _dynamic_int8_mm(A: Tensor, B: Tensor) -> Tensor: # A may have more than 2 dims, while B must be exactly 2-dim A_i8, A_scale_rowwise = quantize_int8_rowwise(A.view(-1, A.shape[-1])) B_t_i8, B_scale_colwise = quantize_int8_rowwise(B.T) - out = int8_mm_dequant( + out = scaled_int8_mm( A_i8.contiguous(), B_t_i8.contiguous().T, A_scale_rowwise.contiguous(), diff --git a/torchao/prototype/quantized_training/int8_mm.py b/torchao/prototype/quantized_training/int8_mm.py index b316e82208..74d3027daa 100644 --- a/torchao/prototype/quantized_training/int8_mm.py +++ b/torchao/prototype/quantized_training/int8_mm.py @@ -51,19 +51,27 @@ @triton.autotune(configs=configs, key=["M", "N", "K", "stride_ak", "stride_bk"]) @triton.jit -def _int8_mm_dequant_kernel( - A_ptr, B_ptr, C_ptr, - A_scale_rowwise_ptr, - B_scale_colwise_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, +def _scaled_int8_mm_kernel( + A_ptr, + B_ptr, + C_ptr, + row_scale_ptr, + col_scale_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr = 8, EVEN_K: tl.constexpr = True, + COL_SCALE_SCALAR: tl.constexpr = False, ): # based on triton.ops.matmul pid = tl.program_id(0) @@ -104,41 +112,60 @@ def _int8_mm_dequant_kernel( idx_n = rn[None, :] mask = (idx_m < M) & (idx_n < N) - a_scale = tl.load(A_scale_rowwise_ptr + idx_m, mask=idx_m < M).to(tl.float32) - b_scale = tl.load(B_scale_colwise_ptr + idx_n, mask=idx_n < N).to(tl.float32) - acc = acc.to(tl.float32) * a_scale * b_scale + row_scale = tl.load(row_scale_ptr + idx_m, mask=idx_m < M).to(tl.float32) + if COL_SCALE_SCALAR: + # hack to support BitNet. col_scale is now a scalar + col_scale = tl.load(col_scale_ptr).to(tl.float32) + else: + col_scale = tl.load(col_scale_ptr + idx_n, mask=idx_n < N).to(tl.float32) + acc = acc.to(tl.float32) * row_scale * col_scale # inductor generates a suffix xindex = idx_m * stride_cm + idx_n * stride_cn tl.store(C_ptr + tl.broadcast_to(xindex, mask.shape), acc, mask) -lib.define("int8_mm_dequant(Tensor A, Tensor B, Tensor A_scale, Tensor B_scale) -> Tensor") +lib.define("scaled_int8_mm(Tensor A, Tensor B, Tensor A_scale, Tensor B_scale) -> Tensor") -def int8_mm_dequant(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwise: Tensor) -> Tensor: +def scaled_int8_mm(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor) -> Tensor: + """Compute `(A @ B) * row_scale * col_scale`, where `A` and `B` are INT8 to utilize + INT8 tensor cores. `col_scale` can be a scalar. + """ assert A.dtype is torch.int8 and B.dtype is torch.int8 - assert A_scale_rowwise.dtype is B_scale_colwise.dtype + assert row_scale.dtype is col_scale.dtype assert A.shape[1] == B.shape[0] - assert A_scale_rowwise.squeeze().shape == (A.shape[0],) - assert B_scale_colwise.squeeze().shape == (B.shape[1],) - assert A_scale_rowwise.is_contiguous() - assert B_scale_colwise.is_contiguous() - return torch.ops.torchao.int8_mm_dequant(A, B, A_scale_rowwise, B_scale_colwise) + assert row_scale.squeeze().shape == (A.shape[0],) + assert col_scale.squeeze().shape in ((B.shape[1],), ()) + assert row_scale.is_contiguous() + assert col_scale.is_contiguous() + return torch.ops.torchao.scaled_int8_mm(A, B, row_scale, col_scale) -@torch.library.impl(lib, "int8_mm_dequant", "Meta") -def _(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwise: Tensor): - return torch.empty((A.shape[0], B.shape[1]), device=A.device, dtype=A_scale_rowwise.dtype) +@torch.library.impl(lib, "scaled_int8_mm", "Meta") +def _(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor): + return torch.empty((A.shape[0], B.shape[1]), device=A.device, dtype=row_scale.dtype) -@torch.library.impl(lib, "int8_mm_dequant", "CUDA") -def int8_mm_dequant_cuda(A: Tensor, B: Tensor, A_scale_rowwise: Tensor, B_scale_colwise: Tensor): +@torch.library.impl(lib, "scaled_int8_mm", "CUDA") +def scaled_int8_mm_cuda(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor): M, K = A.shape _, N = B.shape - C = torch.empty(M, N, device=A.device, dtype=A_scale_rowwise.dtype) + C = torch.empty(M, N, device=A.device, dtype=row_scale.dtype) grid = lambda meta: (triton.cdiv(meta["M"], meta["BLOCK_M"]) * triton.cdiv(meta["N"], meta["BLOCK_N"]),) - _int8_mm_dequant_kernel[grid]( - A, B, C, A_scale_rowwise, B_scale_colwise, M, N, K, *A.stride(), *B.stride(), *C.stride(), EVEN_K=K % 2 == 0 + _scaled_int8_mm_kernel[grid]( + A, + B, + C, + row_scale, + col_scale, + M, + N, + K, + *A.stride(), + *B.stride(), + *C.stride(), + EVEN_K=K % 2 == 0, + COL_SCALE_SCALAR=col_scale.numel() == 1, ) return C From 46668bf9a723ff3d466b8656583f3d28477b42c9 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 24 Sep 2024 22:20:45 +0800 Subject: [PATCH 02/15] fix type annotation --- torchao/prototype/quantized_training/bitnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/quantized_training/bitnet.py b/torchao/prototype/quantized_training/bitnet.py index 26337f1f45..45093332ac 100644 --- a/torchao/prototype/quantized_training/bitnet.py +++ b/torchao/prototype/quantized_training/bitnet.py @@ -132,7 +132,7 @@ def quantize_bitnet_weight(w: Tensor, eps: float = 1e-5, all_reduce: bool = Fals class _BitNetTrainingLinear(torch.autograd.Function): @staticmethod - def forward(ctx, input: Tensor, weight: BitNetTrainingLinearWeight, bias: Tensor | None = None): + def forward(ctx, input: Tensor, weight: BitNetTrainingLinearWeight, bias: Optional[Tensor] = None): batch_dims = input.shape[:-1] input = input.view(-1, weight.shape[1]) @@ -261,7 +261,7 @@ def _(func, types, args, kwargs): class _BitNetPacked2bitLinear(torch.autograd.Function): @staticmethod - def forward(ctx, input: Tensor, weight: BitNetPacked2bitLinearWeight, bias: Tensor | None = None): + def forward(ctx, input: Tensor, weight: BitNetPacked2bitLinearWeight, bias: Optional[Tensor] = None): batch_dims = input.shape[:-1] input = input.view(-1, weight.shape[1]) From c5e382d01f65bb1daf8bd272553a8c0423033ec4 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 24 Sep 2024 22:51:42 +0800 Subject: [PATCH 03/15] skip bitnet test on cpu. add bitnet to benchmark script --- .../quantized_training/pretrain_llama2.py | 31 ++++++++++++++++--- test/prototype/test_quantized_training.py | 1 + 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index fc87c2cd6e..cbd580ee4e 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -4,6 +4,7 @@ # BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile # INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_weight_only # INT8 MP: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_mixed_precision +# BitNet: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize bitnet import os @@ -20,14 +21,14 @@ from torch.utils.checkpoint import checkpoint from tqdm import tqdm -from torchao._models.llama.model import ModelArgs, Transformer, transformer_configs +from torchao import quantize_ +from torchao._models.llama.model import ModelArgs, Transformer, transformer_configs, RMSNorm from torchao.prototype import low_bit_optim from torchao.prototype.quantized_training import ( + bitnet_training, int8_mixed_precision_training, int8_weight_only_quantized_training, ) -from torchao.quantization.quant_api import quantize_ - # not official models transformer_configs.update( @@ -104,7 +105,7 @@ def get_tinystories(): parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--weight_decay", type=float, default=1e-2) - parser.add_argument("--project", default="int8_quantized_training") + parser.add_argument("--project", default="quantized_training") parser.add_argument("--run_name") parser.add_argument("--seed", type=int) parser.add_argument("--log_interval", type=int, default=10) @@ -126,8 +127,30 @@ def get_tinystories(): # TODO: might want to do the same for int8_weight_only to standardize. if args.quantize == "int8_weight_only": quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False) + elif args.quantize == "int8_mixed_precision": quantize_(model.layers, int8_mixed_precision_training(), set_inductor_config=False) + + elif args.quantize == "bitnet": + quantize_(model.layers, bitnet_training(), set_inductor_config=False) + + # remove old RMSNorm + for layer in model.layers: + layer.attention_norm = torch.nn.Identity() + layer.ffn_norm = torch.nn.Identity() + + # insert new RMSNorm + def insert_rmsnorm(module: torch.nn.Module): + for name, child in module.named_children(): + if isinstance(child, torch.nn.Linear): + w = child.weight + norm = RMSNorm(child.in_features).to(device=w.device, dtype=w.dtype) + setattr(module, name, torch.nn.Sequential(norm, child)) + else: + insert_rmsnorm(child) + + insert_rmsnorm(model.layers) + elif args.quantize is not None: raise ValueError(f"Unsupported quantize={args.quantize}") diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 0d0176aad2..65391531c9 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -202,6 +202,7 @@ def test_int8_mixed_precision_training(self, compile, config): optim_int8mp.zero_grad() @parametrize("compile", [False, True]) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_bitnet_training(self, compile): # reference implementation # https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf From e3660eb3ce0e89129a5f03cbd16afef7c1fb4b4a Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 25 Sep 2024 09:59:09 +0800 Subject: [PATCH 04/15] add bitnet option to example training script. update backward --- .../quantized_training/pretrain_llama2.py | 20 +++++++++++++------ .../prototype/quantized_training/README.md | 5 ++++- .../prototype/quantized_training/bitnet.py | 18 +++++++++-------- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index cbd580ee4e..6bb181959d 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -93,6 +93,8 @@ def get_tinystories(): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", default="470M", choices=transformer_configs.keys()) + parser.add_argument("--bf16_model", action="store_true") + parser.add_argument("--bf16_amp", action="store_true") parser.add_argument("--quantize") parser.add_argument("--activation_checkpointing", action="store_true") parser.add_argument("--compile", action="store_true") @@ -116,7 +118,10 @@ def get_tinystories(): config = ModelArgs.from_name(args.model) config.block_size = args.seq_len - model = Transformer(config).bfloat16().cuda() + model = Transformer(config) + if args.bf16_model: + model.bfloat16() + model.cuda() with torch.device("cuda"): model.setup_caches(args.batch_size, args.seq_len, training=True) if args.activation_checkpointing: @@ -178,7 +183,8 @@ def insert_rmsnorm(module: torch.nn.Module): idx = torch.randint(0, data.shape[0] - args.batch_size * args.seq_len, (1,)).item() batch = data[idx : idx + args.batch_size * args.seq_len].view(args.batch_size, args.seq_len).long() - loss = _get_loss(model, batch) + with torch.autocast("cuda", torch.bfloat16, enabled=args.bf16_amp): + loss = _get_loss(model, batch) loss.backward() if step % args.log_interval == 0: @@ -188,10 +194,6 @@ def insert_rmsnorm(module: torch.nn.Module): max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9, max_memory_reserved=torch.cuda.max_memory_reserved() / 1e9, ) - if step > 0: - time1 = time.time() - log_dict["tokens_per_second"] = (args.log_interval * args.batch_size * args.seq_len) / (time1 - time0) - time0 = time1 run.log(log_dict, step=step) pbar.set_postfix(loss=log_dict["loss"]) @@ -201,4 +203,10 @@ def insert_rmsnorm(module: torch.nn.Module): step += 1 pbar.update() + if step % args.log_interval == 0: + time1 = time.time() + log_dict = dict(tokens_per_second=(args.log_interval * args.batch_size * args.seq_len) / (time1 - time0)) + time0 = time1 + run.log(log_dict, step=step) + run.finish() diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md index 52bde67406..b5d502571a 100644 --- a/torchao/prototype/quantized_training/README.md +++ b/torchao/prototype/quantized_training/README.md @@ -146,10 +146,13 @@ model = ... quantize_(model, bitnet_training()) ``` -Note: following the [BitNet Training Tips, Code and FAQ](https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf), user should insert extra RMSNorm before each `nn.Linear` layers and also remove the original RMSNorm before attention and MLP modules. Calling `quantize_(model, bitnet_training())` will NOT perform this for you. +Note: following the [BitNet Training Tips, Code and FAQ](https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf), user should insert extra RMSNorm before each `nn.Linear` layers and also remove the original RMSNorm before attention and MLP modules. Calling `quantize_(model, bitnet_training())` will NOT perform this for you. You can take a look at our example training script [`benchmarks/quantized_training/pretrain_llama2.py`](../../../benchmarks/quantized_training/pretrain_llama2.py) on how to do this for our Llama model. + +See [#930](https://github.com/pytorch/ao/pull/930) for some benchmark results. ## Future ideas +- Extend INT8 weight-only to support tensor-wise scaling, as well as other INTx dtypes. - Tile-wise INT8 quantization to keep quantized weight for both forward and backward pass (similar to JetFire). - INT4 weight only (with group-wise quantization). This can be used with INT4 tinygemm deployment in mind (or other optimized INT4 kernels). - FP8 activation x FP8 weight. The current FP8 training recipe can be seen as a form of QAT, which maintains a high-precision copy of model weights. We can eliminate the high-precision copy. diff --git a/torchao/prototype/quantized_training/bitnet.py b/torchao/prototype/quantized_training/bitnet.py index 45093332ac..092c1d5e40 100644 --- a/torchao/prototype/quantized_training/bitnet.py +++ b/torchao/prototype/quantized_training/bitnet.py @@ -140,7 +140,8 @@ def forward(ctx, input: Tensor, weight: BitNetTrainingLinearWeight, bias: Option # Figure 3 input_i8, row_scale = quantize_int8_rowwise(input, eps=1e-5) weight_i8, tensor_scale = quantize_bitnet_weight(weight._data) - ctx.save_for_backward(input, weight_i8, tensor_scale) + + ctx.save_for_backward(input_i8, row_scale, weight_i8, tensor_scale) # use int8 tensor cores out = scaled_int8_mm(input_i8.contiguous(), weight_i8.contiguous().T, row_scale, tensor_scale) @@ -151,12 +152,11 @@ def forward(ctx, input: Tensor, weight: BitNetTrainingLinearWeight, bias: Option @staticmethod def backward(ctx, grad_output): - input, weight_i8, tensor_scale = ctx.saved_tensors + input_i8, row_scale, weight_i8, tensor_scale = ctx.saved_tensors grad_input = grad_weight = grad_bias = None batch_dims = grad_output.shape[:-1] grad_output = grad_output.view(-1, weight_i8.shape[0]) - input = input.view(-1, weight_i8.shape[1]) # NOTE: we can potentially speedup training by also quantizing the backward pass # to use INT8 tensor cores @@ -166,7 +166,8 @@ def backward(ctx, grad_output): grad_input = grad_input.view(*batch_dims, weight_i8.shape[1]) if ctx.needs_input_grad[1]: - grad_weight = grad_output.T @ input + # NOTE: we use quantized activation for this calculation + grad_weight = grad_output.T @ (input_i8 * row_scale.view(-1, 1)) if ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0) @@ -269,7 +270,8 @@ def forward(ctx, input: Tensor, weight: BitNetPacked2bitLinearWeight, bias: Opti # Figure 3 input_i8, row_scale = quantize_int8_rowwise(input, eps=1e-5) weight_i2, tensor_scale = weight.int_data, weight.scale - ctx.save_for_backward(input, weight_i2, tensor_scale) + + ctx.save_for_backward(input_i8, row_scale, weight_i8, tensor_scale) # use int8 tensor cores # NOTE: is doing dequant inside matmul faster when M is large? @@ -282,13 +284,12 @@ def forward(ctx, input: Tensor, weight: BitNetPacked2bitLinearWeight, bias: Opti @staticmethod def backward(ctx, grad_output): - input, weight_i2, tensor_scale = ctx.saved_tensors + input_i8, row_scale, weight_i2, tensor_scale = ctx.saved_tensors weight_i8 = _unpack_i8_to_i2(weight_i2) grad_input = grad_weight = grad_bias = None batch_dims = grad_output.shape[:-1] grad_output = grad_output.view(-1, weight_i8.shape[0]) - input = input.view(-1, weight_i8.shape[1]) # NOTE: we can potentially speedup training by also quantizing the backward pass # to use INT8 tensor cores @@ -298,7 +299,8 @@ def backward(ctx, grad_output): grad_input = grad_input.view(*batch_dims, weight_i8.shape[1]) if ctx.needs_input_grad[1]: - grad_weight = grad_output.T @ input + # NOTE: we use quantized activation for this calculation + grad_weight = grad_output.T @ (input_i8 * row_scale.view(-1, 1)) if ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0) From 974cd544f7c822b0350170beebf58eaa9ed1c2e0 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 25 Sep 2024 10:29:48 +0800 Subject: [PATCH 05/15] add FSDP2 test --- test/prototype/test_quantized_training.py | 25 +++++++++++++++++-- .../prototype/quantized_training/README.md | 4 +++ .../prototype/quantized_training/bitnet.py | 14 ++++++++--- torchao/utils.py | 1 + 4 files changed, 38 insertions(+), 6 deletions(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 65391531c9..237448791b 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -1,6 +1,6 @@ import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6 if not TORCH_VERSION_AT_LEAST_2_4: pytest.skip("Requires torch>=2.4", allow_module_level=True) @@ -300,7 +300,28 @@ def test_fsdp2_correctness(self): MixedPrecisionPolicy(param_dtype=torch.bfloat16), 1e-2, ), + ( + bitnet_training(), + bitnet_training(), + MixedPrecisionPolicy(), + 1e-6, + ), ] + + # FSDP2 mixed-precision requires this commit + # https://github.com/pytorch/pytorch/pull/136129 + # TODO: add FSDP2 mixed-precision test for int8_weight_only + if TORCH_VERSION_AT_LEAST_2_6: + extra_args = [ + ( + bitnet_training(), + bitnet_training(), + MixedPrecisionPolicy(param_dtype=torch.bfloat16), + 1e-2, + ), + ] + test_args.extend(extra_args) + self.run_subtests({"args": test_args}, self._run_subtest) def _run_subtest(self, args): @@ -353,7 +374,7 @@ def _run_subtest(self, args): base_optim.step() rel_error = (fsdp_loss - base_loss).abs() / base_loss.abs() - assert rel_error < tolerance, (iter_idx, rel_error) + assert rel_error < tolerance, (args, iter_idx, rel_error) instantiate_parametrized_tests(TestQuantizedTraining) diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md index b5d502571a..d7498b24a0 100644 --- a/torchao/prototype/quantized_training/README.md +++ b/torchao/prototype/quantized_training/README.md @@ -20,6 +20,8 @@ There are 3 main benefits of using low-precision dtype for training (the extent [`benchmarks/quantized_training/pretrain_llama2.py`](../../../benchmarks/quantized_training/pretrain_llama2.py) demonstrates an end-to-end Llama2 pre-training on single GPU for strategies implemented in this folder. +All features in this folder are tested to work with PyTorch 2.4+ unless otherwise stated. + ## INT8 quantized training Typically, quantized weights cannot be trained directly due to quantization error: a small change in the quantized weight will be round down to zero. To tackle this problem, we use **stochastic rounding** for weight update. In simple terms, stochastic rounding will round up or down randomly, but with a higher chance if it is closer to that direction. For example, 0.8 will have 80% chance of rounding up and 20% of rounding down. It also follows that on average, stochastic rounding will estimate the floating point value exactly. @@ -146,6 +148,8 @@ model = ... quantize_(model, bitnet_training()) ``` +Training with FSDP2 is also supported. To use FDSP2 mixed-precision with `param_dtype` != model dtype, PyTorch 2.6+ is required. + Note: following the [BitNet Training Tips, Code and FAQ](https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf), user should insert extra RMSNorm before each `nn.Linear` layers and also remove the original RMSNorm before attention and MLP modules. Calling `quantize_(model, bitnet_training())` will NOT perform this for you. You can take a look at our example training script [`benchmarks/quantized_training/pretrain_llama2.py`](../../../benchmarks/quantized_training/pretrain_llama2.py) on how to do this for our Llama model. See [#930](https://github.com/pytorch/ao/pull/930) for some benchmark results. diff --git a/torchao/prototype/quantized_training/bitnet.py b/torchao/prototype/quantized_training/bitnet.py index 092c1d5e40..37c4bbb772 100644 --- a/torchao/prototype/quantized_training/bitnet.py +++ b/torchao/prototype/quantized_training/bitnet.py @@ -86,11 +86,17 @@ def __torch_dispatch__(cls, func, types, args, kwargs): # return new unwrapped object return out - def fsdp_pre_all_gather(self, mesh): + # require https://github.com/pytorch/pytorch/pull/136129 for mixed-precision param_dtype + # we need default None for module and mp_policy so this method still works with PyTorch 2.4 and 2.5 + def fsdp_pre_all_gather(self, mesh, module=None, mp_policy=None): + data = self._data + if mp_policy is not None: + data = data.to(mp_policy.param_dtype) + # quantize and pack into 2-bit to save comm bandwidth # TODO: precompute absmean similar to float8 - data = BitNetPacked2bitLinearWeight.from_float(self._data, all_reduce=True) - return (data.int_data,), (data.scale,) + packed_data = BitNetPacked2bitLinearWeight.from_float(data, all_reduce=True) + return (packed_data.int_data,), (packed_data.scale,) def fsdp_post_all_gather( self, @@ -271,7 +277,7 @@ def forward(ctx, input: Tensor, weight: BitNetPacked2bitLinearWeight, bias: Opti input_i8, row_scale = quantize_int8_rowwise(input, eps=1e-5) weight_i2, tensor_scale = weight.int_data, weight.scale - ctx.save_for_backward(input_i8, row_scale, weight_i8, tensor_scale) + ctx.save_for_backward(input_i8, row_scale, weight_i2, tensor_scale) # use int8 tensor cores # NOTE: is doing dequant inside matmul faster when M is large? diff --git a/torchao/utils.py b/torchao/utils.py index 1f4f66e1f4..7598c048a5 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -317,6 +317,7 @@ def is_fbcode(): def torch_version_at_least(min_version): return is_fbcode() or compare_versions(torch.__version__, min_version) >= 0 +TORCH_VERSION_AT_LEAST_2_6 = torch_version_at_least("2.6.0") TORCH_VERSION_AT_LEAST_2_5 = torch_version_at_least("2.5.0") TORCH_VERSION_AT_LEAST_2_4 = torch_version_at_least("2.4.0") TORCH_VERSION_AT_LEAST_2_3 = torch_version_at_least("2.3.0") From 36e60c42ea3353bed25f35132283d8213bba97ed Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 25 Sep 2024 10:44:35 +0800 Subject: [PATCH 06/15] remove FSDP2 mixed-precision workaround. cleanup test --- test/prototype/test_quantized_training.py | 59 +++++++------------ .../prototype/quantized_training/README.md | 32 +--------- torchao/prototype/quantized_training/int8.py | 10 +++- .../int8_mixed_precision.py | 16 ++--- 4 files changed, 36 insertions(+), 81 deletions(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 237448791b..377264da2b 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -276,56 +276,37 @@ def world_size(self) -> int: @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE) def test_fsdp2_correctness(self): + mp_policy = MixedPrecisionPolicy() + + # quantize_fn, mp_policy, tolerance test_args = [ - ( - int8_weight_only_quantized_training(), # quantize_fn for base model - int8_weight_only_quantized_training(), # quantize_fn for FSDP model - MixedPrecisionPolicy(), - 0.05, # tolerance. due to stochastic rounding, use a pretty large tolerance here - ), - ( - int8_mixed_precision_training(), - int8_mixed_precision_training(), - MixedPrecisionPolicy(), - 1e-6, - ), - ( - # It's complicated (though possible) to simulate FSDP BF16 mixed-precision for base_model. - # We would need to cast all params to BF16 in forward and backward pass, while keeping - # the params in FP32 for optim step. - # torch.autocast() will only do this for F.linear() layer (and its backward). - # To keep it simple, we just use a larger tolerance here. - int8_mixed_precision_training(), - int8_mixed_precision_training(Int8MixedPrecisionTrainingConfig(fsdp_param_dtype=torch.bfloat16)), - MixedPrecisionPolicy(param_dtype=torch.bfloat16), - 1e-2, - ), - ( - bitnet_training(), - bitnet_training(), - MixedPrecisionPolicy(), - 1e-6, - ), + # high tolerance due to stochastic rounding + (int8_weight_only_quantized_training(), mp_policy, 0.05), + (int8_mixed_precision_training(), mp_policy, 1e-6), + (bitnet_training(), mp_policy, 1e-6), ] # FSDP2 mixed-precision requires this commit # https://github.com/pytorch/pytorch/pull/136129 - # TODO: add FSDP2 mixed-precision test for int8_weight_only if TORCH_VERSION_AT_LEAST_2_6: + # It's complicated (though possible) to simulate FSDP BF16 mixed-precision for base_model. + # We would need to cast all params to BF16 in forward and backward pass, while keeping + # the params in FP32 for optim step. + # torch.autocast() will only do this for F.linear() layer (and its backward). + # To keep it simple, we just use a larger tolerance here. + bf16_mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16) + extra_args = [ - ( - bitnet_training(), - bitnet_training(), - MixedPrecisionPolicy(param_dtype=torch.bfloat16), - 1e-2, - ), + (int8_weight_only_quantized_training(), bf16_mp_policy, 1e-2), + (int8_mixed_precision_training(), bf16_mp_policy, 1e-2), + (bitnet_training(), bf16_mp_policy, 1e-2), ] test_args.extend(extra_args) self.run_subtests({"args": test_args}, self._run_subtest) def _run_subtest(self, args): - base_quantize_fn, fsdp_quantize_fn, mp_policy, tolerance = args + quantize_fn, mp_policy, tolerance = args batch_size = 3 vocab_size = 32 @@ -344,8 +325,8 @@ def _run_subtest(self, args): base_model = Transformer(model_args).cuda() fsdp_model = copy.deepcopy(base_model) - quantize_(base_model.layers, base_quantize_fn, set_inductor_config=False) - quantize_(fsdp_model.layers, fsdp_quantize_fn, set_inductor_config=False) + quantize_(base_model.layers, quantize_fn, set_inductor_config=False) + quantize_(fsdp_model.layers, quantize_fn, set_inductor_config=False) for layer in fsdp_model.layers: fully_shard(layer, mp_policy=mp_policy) diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md index d7498b24a0..3338594406 100644 --- a/torchao/prototype/quantized_training/README.md +++ b/torchao/prototype/quantized_training/README.md @@ -20,7 +20,7 @@ There are 3 main benefits of using low-precision dtype for training (the extent [`benchmarks/quantized_training/pretrain_llama2.py`](../../../benchmarks/quantized_training/pretrain_llama2.py) demonstrates an end-to-end Llama2 pre-training on single GPU for strategies implemented in this folder. -All features in this folder are tested to work with PyTorch 2.4+ unless otherwise stated. +All features in this folder are tested to work with PyTorch 2.4+ unless otherwise stated. Training with FSDP2 is also supported, but if you use FDSP2 mixed-precision with `param_dtype` != model dtype, PyTorch 2.6+ is required. ## INT8 quantized training @@ -58,7 +58,7 @@ BF16 compile | 10.25 | 9000 INT8 QT eager | 10.12 | 5600 INT8 QT compile | 9.84 | 8700 -## INT8 mixed-precision +## INT8 mixed-precision training On NVIDIA GPUs, INT8 Tensor Cores is approximately 2x faster than their BF16/FP16 counterparts. In mixed-precision training, we can down-cast activations and weights dynamically to INT8 to leverage faster matmuls. However, since INT8 has very limited range [-128,127], we perform row-wise quantization, similar to how INT8 post-training quantization (PTQ) is done. Weight is still in original precision. @@ -106,32 +106,6 @@ INT8 mixed-precision | ~29k | 19.47 | 2.90 See [#748](https://github.com/pytorch/ao/pull/748) for more results. -### FSDP support - -Out of the box, this INT8 mixed-precision training is not compatible with FSDP2 `MixedPrecisionPolicy(param_dtype=param_dtype)`, where `param_dtype` != model dtype. As a workaround, you will need to manually specify the FSDP2's `param_dtype` in `Int8MixedPrecisionTrainingConfig` - -```python -from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy -from torchao.prototype.quantized_training import int8_mixed_precision_training, Int8MixedPrecisionTrainingConfig -from torchao import quantize_ - -model = ... # FP32 model - -# setup configs -mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16) -int8mp_config = Int8MixedPrecisionTrainingConfig(fsdp_param_dtype=mp_policy.param_dtype) - -# exclude LM head -quantize_(model.layers, int8_mixed_precision_training(int8mp_config)) - -# shard the model w/ FSDP2 -for layer in model.layers: - fully_shard(layer, mp_policy=mp_policy) -fully_shard(model, mp_policy=mp_policy) - -# train model as usual -``` - ## BitNet b1.58 [BitNet b1.58](https://arxiv.org/abs/2402.17764) uses ternary weights: each parameter can only take on 3 distinct values {-1, 0, +1}, thus making a BitNet model very compact. BitNet uses tensor-wise abs-mean scaling for weights (quantize to ternary) and row-wise abs-max scaling for activations (quantize to INT8). @@ -148,8 +122,6 @@ model = ... quantize_(model, bitnet_training()) ``` -Training with FSDP2 is also supported. To use FDSP2 mixed-precision with `param_dtype` != model dtype, PyTorch 2.6+ is required. - Note: following the [BitNet Training Tips, Code and FAQ](https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf), user should insert extra RMSNorm before each `nn.Linear` layers and also remove the original RMSNorm before attention and MLP modules. Calling `quantize_(model, bitnet_training())` will NOT perform this for you. You can take a look at our example training script [`benchmarks/quantized_training/pretrain_llama2.py`](../../../benchmarks/quantized_training/pretrain_llama2.py) on how to do this for our Llama model. See [#930](https://github.com/pytorch/ao/pull/930) for some benchmark results. diff --git a/torchao/prototype/quantized_training/int8.py b/torchao/prototype/quantized_training/int8.py index 672c3a223e..1273eda83f 100644 --- a/torchao/prototype/quantized_training/int8.py +++ b/torchao/prototype/quantized_training/int8.py @@ -99,8 +99,14 @@ def __repr__(self): f"requires_grad={self.requires_grad})" ) - def fsdp_pre_all_gather(self, mesh): - return (self.int_data, self.scale), None + # require https://github.com/pytorch/pytorch/pull/136129 for mixed-precision param_dtype + # we need default None for module and mp_policy so this method still works with PyTorch 2.4 and 2.5 + def fsdp_pre_all_gather(self, mesh, module=None, mp_policy=None): + scale = self.scale + if mp_policy is not None: + scale = scale.to(mp_policy.param_dtype) + + return (self.int_data, scale), None def fsdp_post_all_gather( self, diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index 9cf73fc2a5..ec3dd0cb81 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -27,10 +27,6 @@ class Int8MixedPrecisionTrainingConfig(NamedTuple): grad_input: bool = True grad_weight: bool = True - # workaround for FSDP2 with `MixedPrecisionPolicy(param_dtype)` - # see `Int8MixedPrecisionTrainingLinearWeight.fsdp_pre_all_gather()` for more details. - fsdp_param_dtype: Optional[torch.dtype] = None - _DEFAULT_CONFIG = Int8MixedPrecisionTrainingConfig() @@ -114,15 +110,15 @@ def unwrap(x: cls): # return new unwrapped object return out - def fsdp_pre_all_gather(self, mesh): + # require https://github.com/pytorch/pytorch/pull/136129 for mixed-precision param_dtype + # we need default None for module and mp_policy so this method still works with PyTorch 2.4 and 2.5 + def fsdp_pre_all_gather(self, mesh, module=None, mp_policy=None): # TODO: pre-quantize weight here -> reduce comm bandwidth. # we will need another tensor subclass to hold the quantized weight. + data = self._data + if mp_policy is not None: + data = data.to(mp_policy.param_dtype) - # doing dtype casting to `param_dtype` in `fsdp_post_all_gather()` will give wrong results. - # as a workaround, we do it in `fsdp_pre_all_gather()` instead. since `param_dtype` is not - # exposed to `fsdp_pre_all_gather()`, we need to specify it in the config. - # this workaround can be removed once we implement INT8 communication. - data = self._data.to(dtype=self.config.fsdp_param_dtype) return (data,), (self.config,) def fsdp_post_all_gather( From 5d1244804c3bc83eda6acb17c69936632b04eeca Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 25 Sep 2024 10:46:42 +0800 Subject: [PATCH 07/15] fix typo --- test/prototype/test_quantized_training.py | 3 +-- torchao/prototype/quantized_training/README.md | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 377264da2b..7263b9ae18 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -286,8 +286,7 @@ def test_fsdp2_correctness(self): (bitnet_training(), mp_policy, 1e-6), ] - # FSDP2 mixed-precision requires this commit - # https://github.com/pytorch/pytorch/pull/136129 + # FSDP2 mixed-precision requires https://github.com/pytorch/pytorch/pull/136129 if TORCH_VERSION_AT_LEAST_2_6: # It's complicated (though possible) to simulate FSDP BF16 mixed-precision for base_model. # We would need to cast all params to BF16 in forward and backward pass, while keeping diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md index 3338594406..10a3d57ccc 100644 --- a/torchao/prototype/quantized_training/README.md +++ b/torchao/prototype/quantized_training/README.md @@ -110,7 +110,7 @@ See [#748](https://github.com/pytorch/ao/pull/748) for more results. [BitNet b1.58](https://arxiv.org/abs/2402.17764) uses ternary weights: each parameter can only take on 3 distinct values {-1, 0, +1}, thus making a BitNet model very compact. BitNet uses tensor-wise abs-mean scaling for weights (quantize to ternary) and row-wise abs-max scaling for activations (quantize to INT8). -BitNet is originally trained with QAT: the weights and activations are fake-quantized, and straight-through estimator (STE) is used to calculate gradients with respect to floating point weights. This process adds extra overhead over standard straining. Our implementation utilizes INT8 Tensor Cores to make up for this loss in speed. In fact, our implementation is faster than BF16 training in most cases. +BitNet is originally trained with QAT: the weights and activations are fake-quantized, and straight-through estimator (STE) is used to calculate gradients with respect to floating point weights. This process adds extra overhead over standard training. Our implementation utilizes INT8 Tensor Cores to make up for this loss in speed. In fact, our implementation is faster than BF16 training in most cases. Usage From 62218eff72661ceb0e5068f1901c3fd7fddd704c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 25 Sep 2024 11:43:36 +0800 Subject: [PATCH 08/15] adjust tolerance --- test/prototype/test_quantized_training.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 7263b9ae18..d62d547991 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -281,9 +281,9 @@ def test_fsdp2_correctness(self): # quantize_fn, mp_policy, tolerance test_args = [ # high tolerance due to stochastic rounding - (int8_weight_only_quantized_training(), mp_policy, 0.05), - (int8_mixed_precision_training(), mp_policy, 1e-6), - (bitnet_training(), mp_policy, 1e-6), + (int8_weight_only_quantized_training, mp_policy, 0.05), + (int8_mixed_precision_training, mp_policy, 1e-6), + (bitnet_training, mp_policy, 1e-5), ] # FSDP2 mixed-precision requires https://github.com/pytorch/pytorch/pull/136129 @@ -296,9 +296,9 @@ def test_fsdp2_correctness(self): bf16_mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16) extra_args = [ - (int8_weight_only_quantized_training(), bf16_mp_policy, 1e-2), - (int8_mixed_precision_training(), bf16_mp_policy, 1e-2), - (bitnet_training(), bf16_mp_policy, 1e-2), + (int8_weight_only_quantized_training, bf16_mp_policy, 1e-2), + (int8_mixed_precision_training, bf16_mp_policy, 1e-2), + (bitnet_training, bf16_mp_policy, 1e-2), ] test_args.extend(extra_args) @@ -324,8 +324,8 @@ def _run_subtest(self, args): base_model = Transformer(model_args).cuda() fsdp_model = copy.deepcopy(base_model) - quantize_(base_model.layers, quantize_fn, set_inductor_config=False) - quantize_(fsdp_model.layers, quantize_fn, set_inductor_config=False) + quantize_(base_model.layers, quantize_fn(), set_inductor_config=False) + quantize_(fsdp_model.layers, quantize_fn(), set_inductor_config=False) for layer in fsdp_model.layers: fully_shard(layer, mp_policy=mp_policy) @@ -354,7 +354,7 @@ def _run_subtest(self, args): base_optim.step() rel_error = (fsdp_loss - base_loss).abs() / base_loss.abs() - assert rel_error < tolerance, (args, iter_idx, rel_error) + assert rel_error < tolerance, (quantize_fn.__name__, mp_policy, iter_idx, rel_error) instantiate_parametrized_tests(TestQuantizedTraining) From 6990b61f93d6ce8bd04a0f78e2ce31579bf8253b Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 25 Sep 2024 13:33:24 +0800 Subject: [PATCH 09/15] update command --- benchmarks/quantized_training/pretrain_llama2.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index 6bb181959d..b7e111a44f 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -1,10 +1,10 @@ # pre-train a mini Llama2 on TinyStories with INT8 quantized training # pip install huggingface_hub sentencepiece wandb # -# BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile -# INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_weight_only -# INT8 MP: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_mixed_precision -# BitNet: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize bitnet +# BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile +# INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile --quantize int8_weight_only +# INT8 MP: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile --quantize int8_mixed_precision +# BitNet: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile --quantize bitnet import os From 0e8f898eb63b44c514c5aba35e9c7be86ffba1f4 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 25 Sep 2024 14:30:01 +0800 Subject: [PATCH 10/15] add precompute scale for FSDP2 --- .../prototype/quantized_training/README.md | 11 +++ .../prototype/quantized_training/__init__.py | 2 +- .../prototype/quantized_training/bitnet.py | 91 +++++++++++++------ 3 files changed, 73 insertions(+), 31 deletions(-) diff --git a/torchao/prototype/quantized_training/README.md b/torchao/prototype/quantized_training/README.md index 10a3d57ccc..b1b32a8be6 100644 --- a/torchao/prototype/quantized_training/README.md +++ b/torchao/prototype/quantized_training/README.md @@ -124,6 +124,17 @@ quantize_(model, bitnet_training()) Note: following the [BitNet Training Tips, Code and FAQ](https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf), user should insert extra RMSNorm before each `nn.Linear` layers and also remove the original RMSNorm before attention and MLP modules. Calling `quantize_(model, bitnet_training())` will NOT perform this for you. You can take a look at our example training script [`benchmarks/quantized_training/pretrain_llama2.py`](../../../benchmarks/quantized_training/pretrain_llama2.py) on how to do this for our Llama model. +When used with FSDP2 training, you can pre-compute BitNet weight scales for the next iteration to synchronize all scales with a single all-reduce operation. This should be done after optimizer step. + +```python +from torchao.prototype.quantized_training import precompute_bitnet_scale_for_fsdp + +for _ in range(n_steps): + model(inputs).sum().backward() + optim.step() + precompute_bitnet_scale_for_fsdp(model) +``` + See [#930](https://github.com/pytorch/ao/pull/930) for some benchmark results. ## Future ideas diff --git a/torchao/prototype/quantized_training/__init__.py b/torchao/prototype/quantized_training/__init__.py index cdaf39ab35..c3c9b7cfaf 100644 --- a/torchao/prototype/quantized_training/__init__.py +++ b/torchao/prototype/quantized_training/__init__.py @@ -1,4 +1,4 @@ -from .bitnet import BitNetTrainingLinearWeight, bitnet_training +from .bitnet import BitNetTrainingLinearWeight, bitnet_training, precompute_bitnet_scale_for_fsdp from .int8 import ( Int8QuantizedTrainingLinearWeight, int8_weight_only_quantized_training, diff --git a/torchao/prototype/quantized_training/bitnet.py b/torchao/prototype/quantized_training/bitnet.py index 37c4bbb772..0a47c8db0d 100644 --- a/torchao/prototype/quantized_training/bitnet.py +++ b/torchao/prototype/quantized_training/bitnet.py @@ -8,8 +8,9 @@ import torch.distributed as dist import torch.nn.functional as F import torch.utils._pytree as pytree -from torch import Tensor +from torch import Tensor, nn from torch.utils._triton import has_triton +from torch.distributed._tensor import DTensor from torchao.quantization.quant_api import _get_linear_subclass_inserter from torchao.utils import TorchAOBaseTensor @@ -34,7 +35,7 @@ def scaled_int8_mm(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor) - class BitNetTrainingLinearWeight(TorchAOBaseTensor): @staticmethod @torch._dynamo.disable - def __new__(cls, data: Tensor): + def __new__(cls, data: Tensor, precomputed_scale: Tensor | None = None): return Tensor._make_wrapper_subclass( cls, data.shape, @@ -43,15 +44,19 @@ def __new__(cls, data: Tensor): ) @torch._dynamo.disable - def __init__(self, data: Tensor): + def __init__(self, data: Tensor, precomputed_scale: Tensor | None = None): self._data = data + self._precomputed_scale = precomputed_scale def __tensor_flatten__(self): - return ["_data"], [] + if self._precomputed_scale is not None: + return ["_data", "_precomputed_scale"], [] + else: + return ["_data"], [] @classmethod def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): - return cls(tensor_data_dict["_data"], *tensor_attributes) + return cls(tensor_data_dict["_data"], tensor_data_dict.get("_precomputed_scale", None), *tensor_attributes) def __repr__(self): return f"{self.__class__.__name__}(data={self._data})" @@ -64,6 +69,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): **pytree.tree_map_only(cls, lambda x: x._data, kwargs), ) + # NOTE: _precomputed_scale does not propagate through any ops if func is aten.copy_.default: # return original object return args[0] @@ -86,17 +92,21 @@ def __torch_dispatch__(cls, func, types, args, kwargs): # return new unwrapped object return out - # require https://github.com/pytorch/pytorch/pull/136129 for mixed-precision param_dtype + # new signature https://github.com/pytorch/pytorch/pull/136129 # we need default None for module and mp_policy so this method still works with PyTorch 2.4 and 2.5 def fsdp_pre_all_gather(self, mesh, module=None, mp_policy=None): - data = self._data - if mp_policy is not None: - data = data.to(mp_policy.param_dtype) - # quantize and pack into 2-bit to save comm bandwidth - # TODO: precompute absmean similar to float8 - packed_data = BitNetPacked2bitLinearWeight.from_float(data, all_reduce=True) - return (packed_data.int_data,), (packed_data.scale,) + if self._precomputed_scale is not None: + scale = self._precomputed_scale + + else: + scale = get_bitnet_scale(self._data) + dist.all_reduce(scale, op=dist.ReduceOp.AVG) + + # NOTE: scale is in FP32 + data_i8 = quantize_bitnet_weight(self._data, scale) + data_i2 = _pack_i2_to_i8(data_i8) + return (data_i2,), (scale,) def fsdp_post_all_gather( self, @@ -106,13 +116,14 @@ def fsdp_post_all_gather( *, out: Optional[Tensor] = None, ): - (int_data,) = all_gather_outputs + (data_i2,) = all_gather_outputs (scale,) = metadata + scale = scale.to(param_dtype) if out is not None: assert isinstance(out, BitNetPacked2bitLinearWeight) out.scale = scale return - return BitNetPacked2bitLinearWeight(int_data, scale), all_gather_outputs + return BitNetPacked2bitLinearWeight(data_i2, scale), all_gather_outputs @BitNetTrainingLinearWeight.implements(F.linear) @@ -123,17 +134,38 @@ def _(func, types, args, kwargs): return _BitNetTrainingLinear.apply(*args, **kwargs) -def quantize_bitnet_weight(w: Tensor, eps: float = 1e-5, all_reduce: bool = False) -> Tensor: - dtype = w.dtype - w = w.float() - scale = w.abs().mean() # tensor-wise abs-mean. FP32 +def get_bitnet_scale(x: Tensor): + "Tensor-wise abs-mean. Always return FP32." + return x.float().abs().mean() - if all_reduce and dist.is_initialized(): - dist.all_reduce(scale, op=dist.ReduceOp.AVG) - w = w / scale.clip(eps) +def quantize_bitnet_weight(w: Tensor, scale: Tensor, eps: float = 1e-5) -> Tensor: + w = w.float() / scale.clip(eps) w = w.round().clip(-1, 1).to(torch.int8) - return w, scale.to(dtype) + return w + + +@torch.no_grad() +def precompute_bitnet_scale_for_fsdp(module: nn.Module): + """Calculate scale for all BitNetTrainingLinearWeight parameters. + This should be run after the optimizer step. It performs a single all-reduce for all + parameters to reduce overhead. + """ + bitnet_params = [ + p + for p in module.parameters() + if isinstance(p, DTensor) and isinstance(p.to_local(), BitNetTrainingLinearWeight) + ] + if len(bitnet_params) == 0: + return + + # NOTE: use torch.compile to save memory and increase speed? + bitnet_scales = [get_bitnet_scale(x) for x in bitnet_params] # local absmean + bitnet_scales = torch.stack(bitnet_scales) + bitnet_scales = bitnet_scales.full_tensor() # global absmean + + for i, p in enumerate(bitnet_params): + p._local_tensor._precomputed_scale = bitnet_scales[i] class _BitNetTrainingLinear(torch.autograd.Function): @@ -145,7 +177,12 @@ def forward(ctx, input: Tensor, weight: BitNetTrainingLinearWeight, bias: Option # https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf # Figure 3 input_i8, row_scale = quantize_int8_rowwise(input, eps=1e-5) - weight_i8, tensor_scale = quantize_bitnet_weight(weight._data) + + # NOTE: use FP32 scale for weight quantization, but cast scale to possibly lower precision + # for matmul and backward + tensor_scale = get_bitnet_scale(weight._data) + weight_i8 = quantize_bitnet_weight(weight._data, tensor_scale) + tensor_scale = tensor_scale.to(weight.dtype) ctx.save_for_backward(input_i8, row_scale, weight_i8, tensor_scale) @@ -228,12 +265,6 @@ def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=No def __repr__(self): return f"{self.__class__.__name__}(data={self.dequantize()})" - @classmethod - def from_float(cls, tensor: Tensor, *, eps: float = 1e-5, all_reduce: bool = False): - int_data, scale = quantize_bitnet_weight(tensor, eps=eps, all_reduce=all_reduce) - int_data = _pack_i2_to_i8(int_data) - return BitNetPacked2bitLinearWeight(int_data, scale) - def dequantize(self, out_dtype=None): out = _unpack_i8_to_i2(self.int_data) * self.scale if out_dtype is not None: From d375fb256291fe5dd168bb3c70935b69a8592676 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 25 Sep 2024 15:04:30 +0800 Subject: [PATCH 11/15] fix typing --- torchao/prototype/quantized_training/bitnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/quantized_training/bitnet.py b/torchao/prototype/quantized_training/bitnet.py index 0a47c8db0d..a7320b5bba 100644 --- a/torchao/prototype/quantized_training/bitnet.py +++ b/torchao/prototype/quantized_training/bitnet.py @@ -35,7 +35,7 @@ def scaled_int8_mm(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor) - class BitNetTrainingLinearWeight(TorchAOBaseTensor): @staticmethod @torch._dynamo.disable - def __new__(cls, data: Tensor, precomputed_scale: Tensor | None = None): + def __new__(cls, data: Tensor, precomputed_scale: Optional[Tensor] = None): return Tensor._make_wrapper_subclass( cls, data.shape, @@ -44,7 +44,7 @@ def __new__(cls, data: Tensor, precomputed_scale: Tensor | None = None): ) @torch._dynamo.disable - def __init__(self, data: Tensor, precomputed_scale: Tensor | None = None): + def __init__(self, data: Tensor, precomputed_scale: Optional[Tensor] = None): self._data = data self._precomputed_scale = precomputed_scale From c966fcec0c5517ccb1f6524bc48e096cc2a44abc Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 25 Sep 2024 19:23:39 +0800 Subject: [PATCH 12/15] add test for precompute scale --- test/prototype/test_quantized_training.py | 18 ++++++++++++++++++ torchao/prototype/quantized_training/bitnet.py | 2 +- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index d62d547991..2e8103bb00 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -356,6 +356,24 @@ def _run_subtest(self, args): rel_error = (fsdp_loss - base_loss).abs() / base_loss.abs() assert rel_error < tolerance, (quantize_fn.__name__, mp_policy, iter_idx, rel_error) + @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE) + def test_precompute_bitnet_scale(self): + from torchao.prototype.quantized_training.bitnet import get_bitnet_scale, precompute_bitnet_scale_for_fsdp + + model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).cuda() + model_fsdp = copy.deepcopy(model) + quantize_(model_fsdp, bitnet_training()) + fully_shard(model_fsdp) + + precompute_bitnet_scale_for_fsdp(model_fsdp) + + torch.testing.assert_close( + get_bitnet_scale(model[0].weight), model_fsdp[0].weight._local_tensor._precomputed_scale + ) + torch.testing.assert_close( + get_bitnet_scale(model[2].weight), model_fsdp[2].weight._local_tensor._precomputed_scale + ) + instantiate_parametrized_tests(TestQuantizedTraining) diff --git a/torchao/prototype/quantized_training/bitnet.py b/torchao/prototype/quantized_training/bitnet.py index a7320b5bba..ef6774d400 100644 --- a/torchao/prototype/quantized_training/bitnet.py +++ b/torchao/prototype/quantized_training/bitnet.py @@ -154,7 +154,7 @@ def precompute_bitnet_scale_for_fsdp(module: nn.Module): bitnet_params = [ p for p in module.parameters() - if isinstance(p, DTensor) and isinstance(p.to_local(), BitNetTrainingLinearWeight) + if isinstance(p, DTensor) and isinstance(p._local_tensor, BitNetTrainingLinearWeight) ] if len(bitnet_params) == 0: return From 521b04eb4e937eff7f7c1e92ece681a02904d2e0 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 25 Sep 2024 19:59:57 +0800 Subject: [PATCH 13/15] rename --- torchao/prototype/quantized_training/bitnet.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchao/prototype/quantized_training/bitnet.py b/torchao/prototype/quantized_training/bitnet.py index ef6774d400..023f32cac0 100644 --- a/torchao/prototype/quantized_training/bitnet.py +++ b/torchao/prototype/quantized_training/bitnet.py @@ -105,7 +105,7 @@ def fsdp_pre_all_gather(self, mesh, module=None, mp_policy=None): # NOTE: scale is in FP32 data_i8 = quantize_bitnet_weight(self._data, scale) - data_i2 = _pack_i2_to_i8(data_i8) + data_i2 = _pack_i2_in_i8(data_i8) return (data_i2,), (scale,) def fsdp_post_all_gather( @@ -222,12 +222,12 @@ def bitnet_training(): return _get_linear_subclass_inserter(BitNetTrainingLinearWeight, allow_requires_grad=True) -def _pack_i2_to_i8(x: Tensor): +def _pack_i2_in_i8(x: Tensor): # NOTE: this is signed integer, so we have to mask before bit-shift return (x[:, ::4] << 6) | ((x[:, 1::4] & 0b11) << 4) | ((x[:, 2::4] & 0b11) << 2) | (x[:, 3::4] & 0b11) -def _unpack_i8_to_i2(x: Tensor): +def _unpack_i2_in_i8(x: Tensor): # NOTE: this is signed integer, so left-shift then right-shift will perform sign extension correctly # e.g. aa10bbcc -> 10bbcc00 -> 11111110 return torch.stack([x >> 6, x << 2 >> 6, x << 4 >> 6, x << 6 >> 6], dim=-1).view(x.shape[0], -1) @@ -266,7 +266,7 @@ def __repr__(self): return f"{self.__class__.__name__}(data={self.dequantize()})" def dequantize(self, out_dtype=None): - out = _unpack_i8_to_i2(self.int_data) * self.scale + out = _unpack_i2_in_i8(self.int_data) * self.scale if out_dtype is not None: out = out.to(out_dtype) return out @@ -312,7 +312,7 @@ def forward(ctx, input: Tensor, weight: BitNetPacked2bitLinearWeight, bias: Opti # use int8 tensor cores # NOTE: is doing dequant inside matmul faster when M is large? - weight_i8 = _unpack_i8_to_i2(weight_i2) + weight_i8 = _unpack_i2_in_i8(weight_i2) out = scaled_int8_mm(input_i8.contiguous(), weight_i8.contiguous().T, row_scale, tensor_scale) out = out.view(*batch_dims, weight.shape[0]) @@ -322,7 +322,7 @@ def forward(ctx, input: Tensor, weight: BitNetPacked2bitLinearWeight, bias: Opti @staticmethod def backward(ctx, grad_output): input_i8, row_scale, weight_i2, tensor_scale = ctx.saved_tensors - weight_i8 = _unpack_i8_to_i2(weight_i2) + weight_i8 = _unpack_i2_in_i8(weight_i2) grad_input = grad_weight = grad_bias = None batch_dims = grad_output.shape[:-1] From 2ca74600fcdfff3a1e2c848a7a6a3956ec3b577e Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 29 Sep 2024 08:18:51 +0800 Subject: [PATCH 14/15] separate BitNet model surgery --- .../quantized_training/pretrain_llama2.py | 33 +++++++++++-------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index b7e111a44f..0085f24264 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -4,7 +4,7 @@ # BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile # INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile --quantize int8_weight_only # INT8 MP: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile --quantize int8_mixed_precision -# BitNet: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile --quantize bitnet +# BitNet: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile --quantize bitnet --modify_rmsnorm_for_bitnet import os @@ -99,6 +99,8 @@ def get_tinystories(): parser.add_argument("--activation_checkpointing", action="store_true") parser.add_argument("--compile", action="store_true") + parser.add_argument("--modify_rmsnorm_for_bitnet", action="store_true") + parser.add_argument("--n_steps", type=int, default=1000) parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--seq_len", type=int, default=2048) @@ -128,22 +130,14 @@ def get_tinystories(): for layer in model.layers: enable_activation_checkpointing(layer) - # don't apply int8_mixed_precision to LM head, since it can cause convergence issue. - # TODO: might want to do the same for int8_weight_only to standardize. - if args.quantize == "int8_weight_only": - quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False) - - elif args.quantize == "int8_mixed_precision": - quantize_(model.layers, int8_mixed_precision_training(), set_inductor_config=False) - - elif args.quantize == "bitnet": - quantize_(model.layers, bitnet_training(), set_inductor_config=False) - + # as recommended by https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf + # section 3 + if args.modify_rmsnorm_for_bitnet: # remove old RMSNorm for layer in model.layers: layer.attention_norm = torch.nn.Identity() layer.ffn_norm = torch.nn.Identity() - + # insert new RMSNorm def insert_rmsnorm(module: torch.nn.Module): for name, child in module.named_children(): @@ -153,9 +147,20 @@ def insert_rmsnorm(module: torch.nn.Module): setattr(module, name, torch.nn.Sequential(norm, child)) else: insert_rmsnorm(child) - + insert_rmsnorm(model.layers) + # don't apply int8_mixed_precision to LM head, since it can cause convergence issue. + # TODO: might want to do the same for int8_weight_only to standardize. + if args.quantize == "int8_weight_only": + quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False) + + elif args.quantize == "int8_mixed_precision": + quantize_(model.layers, int8_mixed_precision_training(), set_inductor_config=False) + + elif args.quantize == "bitnet": + quantize_(model.layers, bitnet_training(), set_inductor_config=False) + elif args.quantize is not None: raise ValueError(f"Unsupported quantize={args.quantize}") From 1b5f6161e9bfedf5eb253fd35719c32d2af691b2 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 29 Sep 2024 09:07:20 +0800 Subject: [PATCH 15/15] minor fixes. add note on packing --- test/prototype/test_quantized_training.py | 2 +- torchao/prototype/quantized_training/bitnet.py | 14 +++++++++++--- .../quantized_training/int8_mixed_precision.py | 2 +- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index b26be28e54..faecb6b2d2 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -166,7 +166,7 @@ def test_int8_mixed_precision_training(self, compile, config): embed_dim = 64 device = "cuda" - linear = nn.Linear(embed_dim, embed_dim).cuda() + linear = nn.Linear(embed_dim, embed_dim, device=device) linear_int8mp = copy.deepcopy(linear) quantize_(linear_int8mp, int8_mixed_precision_training(config), set_inductor_config=False) diff --git a/torchao/prototype/quantized_training/bitnet.py b/torchao/prototype/quantized_training/bitnet.py index 023f32cac0..ffba7f252e 100644 --- a/torchao/prototype/quantized_training/bitnet.py +++ b/torchao/prototype/quantized_training/bitnet.py @@ -26,7 +26,7 @@ # change in the future. # Multiplying col_scale first is faster than the other way round. def scaled_int8_mm(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor) -> Tensor: - return torch._int_mm(A, B) * col_scale * row_scale.view(-1, 1) + return torch._int_mm(A, B) * col_scale.view(-1) * row_scale.view(-1, 1) aten = torch.ops.aten @@ -223,8 +223,16 @@ def bitnet_training(): def _pack_i2_in_i8(x: Tensor): - # NOTE: this is signed integer, so we have to mask before bit-shift - return (x[:, ::4] << 6) | ((x[:, 1::4] & 0b11) << 4) | ((x[:, 2::4] & 0b11) << 2) | (x[:, 3::4] & 0b11) + # perform packing: [xxxx xxaa, xxxx xxxbb, xxxx xxcc, xxxx xxdd] -> [aabb ccdd] + # for each value, xxxx can be either all 0s or all 1s because these are signed numbers. + # thus, we have to mask out the 2 least significant bits (right-most) before bit-shift. + # e.g. 1111 1111 (value=-1) -> 0000 0011 -> 0011 0000 + + x0 = x[:, ::4] << 6 # don't need to mask this number because we shift it to the left-most + x1 = (x[:, 1::4] & 0b11) << 4 + x2 = (x[:, 2::4] & 0b11) << 2 + x3 = x[:, 3::4] & 0b11 + return x0 | x1 | x2 | x3 def _unpack_i2_in_i8(x: Tensor): diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index ec3dd0cb81..8cc02b53c0 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -19,7 +19,7 @@ # change in the future. # Multiplying col_scale first is faster than the other way round. def scaled_int8_mm(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor) -> Tensor: - return torch._int_mm(A, B) * col_scale * row_scale.view(-1, 1) + return torch._int_mm(A, B) * col_scale.view(-1) * row_scale.view(-1, 1) class Int8MixedPrecisionTrainingConfig(NamedTuple):