diff --git a/docs/quantized_ops.md b/docs/quantized_ops.md
index cffe87d298a6..e9b9ed41831a 100644
--- a/docs/quantized_ops.md
+++ b/docs/quantized_ops.md
@@ -103,8 +103,8 @@ orig_model.linear = q_linear
 | per-channel | N/A | W4A16 | Yes |
 | per-channel | per-token | W8A8 | No |
 | per-channel | per-token | W4A8 | No |
-| blockwise | N/A | W8A16 | No |
-| blockwise | N/A | W4A16 | No |
+| blockwise | N/A | W8A16 | Yes |
+| blockwise | N/A | W4A16 | Yes |
 | blockwise | per-token | W8A8 | No |
 | blockwise | per-token | W4A8 | No |
 
diff --git a/test/quantized_ops/test_quantized_matmul.py b/test/quantized_ops/test_quantized_matmul.py
index a41eafa90781..af886196c333 100644
--- a/test/quantized_ops/test_quantized_matmul.py
+++ b/test/quantized_ops/test_quantized_matmul.py
@@ -24,35 +24,66 @@ def __init__(self, input_dim, output_dim):
   def weight_quantization_rtn(self,
                               linear,
                               n_bits=8,
+                              block_size=-1,
                               quant_method=torch.per_channel_symmetric):
     '''
     Quantize linear weight using Round-To-Nearest(RTN) algorithm.
     '''
     assert isinstance(self.linear, torch.nn.Linear)
     w_fp = linear.weight.data
-    min_val, max_val = torch.aminmax(w_fp, dim=1)  # min_val, max_val [out_dim]
-    int_min = -2**(n_bits - 1)
-    int_max = 2**(n_bits - 1) - 1
-    scaler, zero_point = determine_qparams(
-        min_val,
-        max_val,
-        int_min,
-        int_max,
-        dtype=torch.int8,
-        eps=torch.Tensor([1e-5]),
-        has_customized_qrange=False,
-        qscheme=quant_method)
-    w_int = torch.ops.quantized_decomposed.quantize_per_channel(
-        w_fp, scaler, zero_point, 0, int_min, int_max, torch.int8)
-    return w_int, scaler.to(w_fp.dtype), zero_point
-
-  def replace_with_xla_quantized_matmul(self, n_bit=8):
+    if block_size == -1:
+      min_val, max_val = torch.aminmax(
+          w_fp, dim=1)  # min_val, max_val [out_dim]
+      int_min = -2**(n_bits - 1)
+      int_max = 2**(n_bits - 1) - 1
+      scaler, zero_point = determine_qparams(
+          min_val,
+          max_val,
+          int_min,
+          int_max,
+          dtype=torch.int8,
+          eps=torch.Tensor([1e-5]),
+          has_customized_qrange=False,
+          qscheme=quant_method)
+      w_int = torch.ops.quantized_decomposed.quantize_per_channel(
+          w_fp, scaler, zero_point, 0, int_min, int_max, torch.int8)
+      return w_int, scaler.to(w_fp.dtype), zero_point
+    else:
+      assert w_fp.shape[1] % block_size == 0
+      output_dim = w_fp.shape[0]
+      input_dim = w_fp.shape[1]
+      w_fp = w_fp.reshape(output_dim * input_dim // block_size, block_size)
+      min_val, max_val = torch.aminmax(
+          w_fp, dim=1)  # min_val, max_val [out_dim]
+      int_min = -2**(n_bits - 1)
+      int_max = 2**(n_bits - 1) - 1
+      scaler, zero_point = determine_qparams(
+          min_val,
+          max_val,
+          int_min,
+          int_max,
+          dtype=torch.int8,
+          eps=torch.Tensor([1e-5]),
+          has_customized_qrange=False,
+          qscheme=quant_method)
+      w_int = torch.ops.quantized_decomposed.quantize_per_channel(
+          w_fp, scaler, zero_point, 0, int_min, int_max, torch.int8)
+      w_int = w_int.reshape(output_dim, input_dim // block_size,
+                            block_size).permute(1, 2, 0)
+      scaler = scaler.to(w_fp.dtype).reshape(output_dim,
+                                             input_dim // block_size).permute(
+                                                 1, 0)
+      return w_int, scaler, zero_point
+
+  def replace_with_xla_quantized_matmul(self, n_bit=8, block_size=-1):
     assert isinstance(self.linear, torch.nn.Linear)
-    w_int, scaler, _ = self.weight_quantization_rtn(self.linear, n_bit)
+    w_int, scaler, _ = self.weight_quantization_rtn(
+        self.linear, n_bits=n_bit, block_size=block_size)
     use_int4_weight = n_bit == 4
     q_linear = XlaQuantizedLinear(
         self.linear.in_features,
         self.linear.out_features,
+        block_size=block_size,
         int4_weight=use_int4_weight)
     q_linear.load_quantized_weight(w_int, scaler)
     self.linear = q_linear
@@ -95,7 +126,7 @@ def test_q_linear_module_dynamo(self):
     m = m.to(device)
     m_dynamo = torch.compile(m, backend="openxla")
     out_quant_dynamo = m_dynamo(x.to(device))
-    self.assertTrue(torch.allclose(out_fp, out_quant, atol=0.01))
+    self.assertTrue(torch.allclose(out_fp, out_quant, atol=0.02))
     self.assertTrue(torch.allclose(out_quant_dynamo.cpu(), out_quant))
 
   def test_q_linear_hlo(self):
@@ -146,6 +177,54 @@ def test_int4_per_channel_linear_module(self):
         self.assertGreater(
             self._calc_cosine_dist(out_quant_xla.cpu(), out_quant), 0.999999)
 
+  def test_blockwise_matmul_op(self):
+    input_features = 6
+    out_features = 8
+    block_size = 2
+    batch_size = 3
+    for n_bit in [4]:
+      with self.subTest(n_bit=n_bit):
+        weight = torch.randint(-8, 7, (input_features // block_size, block_size,
+                                       out_features)).to(torch.int8)
+        weight_scaler = torch.ones(input_features // block_size, out_features)
+        x = torch.rand(batch_size, input_features)
+
+        # Fake quantize output.
+        w_dq = (weight * weight_scaler.unsqueeze(1)).reshape(
+            input_features, out_features)
+        fake_quant_out = torch.matmul(x, w_dq)
+        # Eager output.
+        torch_out = torch.ops.xla.quantized_matmul(
+            x, weight, weight_scaler, block_size=block_size)
+        self.assertGreater(
+            self._calc_cosine_dist(fake_quant_out, torch_out), 0.99999)
+        # XLA Output.
+        if not (n_bit == 4 and xr.device_type() != 'TPU'):
+          x = x.to(device)
+          weight = weight.to(device)
+          weight_scaler = weight_scaler.to(device)
+          xla_out = torch.ops.xla.quantized_matmul(
+              x, weight, weight_scaler, block_size=block_size)
+          self.assertTrue(torch.allclose(torch_out, xla_out.cpu(), atol=0.03))
+
+  def test_blockwise_linear_module(self):
+    for n_bit in [4, 8]:
+      with self.subTest(n_bit=n_bit):
+        m = M(6, 8)
+        x = torch.randn(3, 6)
+        out_fp = m(x)
+        m.replace_with_xla_quantized_matmul(n_bit=n_bit, block_size=2)
+        out_quant = m(x)
+        self.assertGreater(self._calc_cosine_dist(out_fp, out_quant), 0.99)
+
+        # Dot with int4 weight is only supported on TPU
+        if not (n_bit == 4 and xr.device_type() != 'TPU'):
+          m = m.to(device)
+          x = x.to(device)
+          out_quant_xla = m(x)
+          self.assertGreater(
+              self._calc_cosine_dist(out_quant_xla.cpu(), out_quant), 0.999999)
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/torch_xla/experimental/xla_quantized_matmul.py b/torch_xla/experimental/xla_quantized_matmul.py
index 7b0b41c8d741..5de4eb644cef 100644
--- a/torch_xla/experimental/xla_quantized_matmul.py
+++ b/torch_xla/experimental/xla_quantized_matmul.py
@@ -5,7 +5,7 @@
 from torch_xla.core.xla_model import XLA_LIB
 
 XLA_LIB.define(
-    "quantized_matmul(Tensor x, Tensor w, Tensor scale, int? blocksize=-1, bool? int4_weight=False, bool? quantize_activation=False) -> Tensor"
+    "quantized_matmul(Tensor x, Tensor w, Tensor scale, int? block_size=-1, bool? int4_weight=False, bool? quantize_activation=False) -> Tensor"
 )
 
 
@@ -21,11 +21,32 @@ def _check_per_channel_quant_weight_dtype_shapes(input_dim, output_dim, w,
       0], f"weight scaler shape is expect to be [out_channel,], got {w_scaler.shape}, weight shape {w_shape}."
 
 
+def _check_blockwise_quant_weight_dtype_shapes(input_dim, output_dim,
+                                               block_size, w, w_scaler):
+  assert w.dtype == torch.int8, (
+      f"Weight dtype is expected to be torch.int8, got {w.dtype}.")
+  assert w.dim() == 3, (
+      f"Weight tensor is expected to be 3D, got {w.dim()}D Tensor.")
+  w_shape = list(w.shape)
+  assert input_dim % block_size == 0, (
+      f"input_dim should be divisible by block_size, "
+      f"got input_dim: {input_dim}, block_size: {block_size}.")
+  assert w_shape[0] == input_dim / block_size and w_shape[1] == block_size, (
+      f"Weight shape is expected to be [input_dim / block_size, block_size, output_dim], "
+      f"input_dim: {input_dim}, block_size: {block_size}, output_dim: {output_dim}, "
+      f"but got {w_shape}.")
+  assert w_scaler.dim() == 2, (
+      f"weight scaler is expected to be 2D, got {w_scaler.dim()}D Tensor.")
+  assert w_scaler.shape[0] == w_shape[0] and w_scaler.shape[1] == w_shape[-1], (
+      f"weight scaler shape is expected to be [in_channel / block_size, out_channel], "
+      f"got {w_scaler.shape}, weight shape {w_shape}.")
+
+
 @impl(XLA_LIB, "quantized_matmul", "XLA")
 def quantized_matmul_xla(x: torch.Tensor,
                          w: torch.Tensor,
                          scaler: torch.Tensor,
-                         blocksize: int = -1,
+                         block_size: int = -1,
                          int4_weight: bool = False):
   """Quantized Matrix Multiply op on XLA devices.
 
@@ -33,34 +54,53 @@ def quantized_matmul_xla(x: torch.Tensor,
     x: torch.Tensor - Activation of Matmul [..., in_channel].
     w: torch.Tensor - Weight Tensor.
        per-channel quant: torch.int8 x [out_channel, in_channel].
+       blockwise quant: torch.int8 x [in_channel / block_size, block_size, out_channel].
     scaler: torch.Tensor - Weight scaler.
        per-channel quant: [out_channel,].
-    blocksize: blocksize for blockwise quantization, -1 for per-channel quantization.
+       blockwise quant: [in_channel / block_size, out_channel].
+    block_size: The block size for blockwise quantization, -1 for per-channel quantization.
     int4_weight: if the weights are int4, the int4 weights need to be stored in a int8
                  container (unpacked).
   """
-  assert blocksize == -1, "blockwise quantization is not supported yet."
   if int4_weight:
     # Reinterpret cast the weight to s4 dtype in XLA.
     w = torch_xla._XLAC._xla_cast_int4(w, w.cpu().flatten().numpy().tolist())
-  # Per-channel quant.
-  _check_per_channel_quant_weight_dtype_shapes(x.shape[-1], scaler.shape[0], w,
-                                               scaler)
-  return F.linear(x, w) * scaler
+  if block_size == -1:
+    # Per-channel quant.
+    _check_per_channel_quant_weight_dtype_shapes(x.shape[-1], scaler.shape[0],
+                                                 w, scaler)
+    return F.linear(x, w) * scaler
+  else:
+    # Blockwise quant.
+    _check_blockwise_quant_weight_dtype_shapes(x.shape[-1], w.shape[-1],
+                                               block_size, w, scaler)
+    x = x.reshape(*x.shape[:-1], x.shape[-1] // block_size, block_size)
+    out = torch.einsum('scn,...sc->...sn', w, x)
+    out = torch.einsum('sn,...sn->...n', scaler, out)
+    return out
 
 
 @impl(XLA_LIB, "quantized_matmul", "CompositeExplicitAutograd")
 def quantized_matmul(x: torch.Tensor,
                      w: torch.Tensor,
                      scaler: torch.Tensor,
-                     blocksize: int = -1,
+                     block_size: int = -1,
                      int4_weight: bool = False):
-  assert blocksize == -1, "blockwise quantization is not supported yet."
-  # Per-channel quant.
-  _check_per_channel_quant_weight_dtype_shapes(x.shape[-1], scaler.shape[0], w,
-                                               scaler)
-  w = w.to(x.dtype)
-  return torch.mul(F.linear(x, w), scaler)
+  if block_size == -1:
+    # Per-channel quant.
+    _check_per_channel_quant_weight_dtype_shapes(x.shape[-1], scaler.shape[0],
+                                                 w, scaler)
+    w = w.to(x.dtype)
+    return torch.mul(F.linear(x, w), scaler)
+  else:
+    # Blockwise quant.
+    _check_blockwise_quant_weight_dtype_shapes(x.shape[-1], w.shape[-1],
+                                               block_size, w, scaler)
+    x = x.reshape(*x.shape[:-1], x.shape[-1] // block_size, block_size)
+    w = w.to(x.dtype)
+    out = torch.einsum('scn,...sc->...sn', w, x)
+    out = torch.einsum('sn,...sn->...n', scaler, out)
+    return out
 
 
 class XlaQuantizedLinear(torch.nn.Module):
@@ -68,13 +108,12 @@ class XlaQuantizedLinear(torch.nn.Module):
   def __init__(self,
                input_dim,
                output_dim,
-               blocksize=-1,
+               block_size=-1,
                int4_weight: bool = False):
     super().__init__()
-    assert blocksize == -1, "Only per-channel quantization is supported."
     self.input_dim = input_dim
     self.output_dim = output_dim
-    self.blocksize = blocksize
+    self.block_size = block_size
     self.int4_weight = int4_weight
     self.register_buffer('weight',
                          torch.zeros(output_dim, input_dim).to(torch.int8))
@@ -82,22 +121,32 @@ def __init__(self,
 
   def load_quantized_weight(self, weight, weight_scaler):
     '''
-    Weight shape: [output_channel, input_channel]
-    Weight scaler shape: [output_channel]
+    weight (Tensor):
+      per-channel quant: [out_channel, in_channel].
+      blockwise quant: [in_channel / block_size, block_size, out_channel].
+
+    weight_scaler (Tensor):
+      per-channel quant: [out_channel,].
+      blockwise quant: [in_channel / block_size, out_channel].
     '''
-    if self.blocksize == -1:
+    if self.block_size == -1:
       # Per-channel quant.
       _check_per_channel_quant_weight_dtype_shapes(self.input_dim,
                                                    self.output_dim, weight,
                                                    weight_scaler)
-      self.weight = weight
-      self.weight_scaler = weight_scaler
     else:
-      assert False, "Only per-channel quantization is supported."
+      # Blockwise quant.
+      _check_blockwise_quant_weight_dtype_shapes(self.input_dim,
+                                                 self.output_dim,
+                                                 self.block_size, weight,
+                                                 weight_scaler)
+    self.weight = weight
+    self.weight_scaler = weight_scaler
 
   def forward(self, x):
-    if self.blocksize == -1:
-      return torch.ops.xla.quantized_matmul(
-          x, self.weight, self.weight_scaler, int4_weight=self.int4_weight)
-    else:
-      assert False, "Only per-channel quantization is supported."
+    return torch.ops.xla.quantized_matmul(
+        x,
+        self.weight,
+        self.weight_scaler,
+        block_size=self.block_size,
+        int4_weight=self.int4_weight)
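Note on the math in the new blockwise path: `quantized_matmul` reshapes the activation into `block_size` chunks, contracts each chunk against its int8 weight block with the first einsum, then applies the per-`(block, out_channel)` scale and sums over blocks with the second. The snippet below is a minimal standalone sketch of that computation in plain PyTorch (no torch_xla); the shapes mirror `test_blockwise_matmul_op`, and the names `x_blocked`, `partial`, and `w_dq` are illustrative only.

```python
import torch

# Shapes mirroring test_blockwise_matmul_op: in_channel=6, block_size=2, out_channel=8.
batch, in_channel, block_size, out_channel = 3, 6, 2, 8
n_blocks = in_channel // block_size

# Blockwise layout expected by the op:
#   weight: [in_channel / block_size, block_size, out_channel]
#   scaler: [in_channel / block_size, out_channel]
w_int = torch.randint(-8, 8, (n_blocks, block_size, out_channel), dtype=torch.int8)
scaler = torch.rand(n_blocks, out_channel)
x = torch.randn(batch, in_channel)

# Reference: dequantize back to [in_channel, out_channel] and do a plain matmul.
w_dq = (w_int * scaler.unsqueeze(1)).reshape(in_channel, out_channel)
ref = x @ w_dq

# Blockwise path: split the activation into blocks, contract each block against
# its weight slice, then scale per (block, out_channel) and sum over blocks.
x_blocked = x.reshape(batch, n_blocks, block_size)
partial = torch.einsum('scn,bsc->bsn', w_int.to(x.dtype), x_blocked)
out = torch.einsum('sn,bsn->bn', scaler, partial)

assert torch.allclose(ref, out, atol=1e-4)
```

Splitting the contraction into two einsums keeps the block dimension alive until the scale is applied, so the float scale touches only the small `[..., n_blocks, out_channel]` intermediate rather than a dequantized copy of the full weight.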
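On the weight-preparation side, the blockwise branch of `weight_quantization_rtn` is a per-row quantization over a reshaped weight followed by a layout permute into the shapes the op checks for. A rough sketch of that transform follows, using a simplified symmetric int8 scale as a stand-in for `determine_qparams` / `quantize_per_channel` (an assumption, not the PR's exact numerics):

```python
import torch

# Start from the usual per-channel style Linear weight: [out_channel, in_channel].
out_dim, in_dim, block_size = 8, 6, 2
w_fp = torch.randn(out_dim, in_dim)

# Quantize each (out_channel, block) row independently. A simple symmetric
# int8 scale is used here as a stand-in for determine_qparams (assumption).
w_rows = w_fp.reshape(out_dim * in_dim // block_size, block_size)
scale = w_rows.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / 127.0
w_int = torch.clamp(torch.round(w_rows / scale), -128, 127).to(torch.int8)

# Permute into the layout quantized_matmul expects:
#   weight [in / block_size, block_size, out], scaler [in / block_size, out].
w_int = w_int.reshape(out_dim, in_dim // block_size, block_size).permute(1, 2, 0)
scale = scale.reshape(out_dim, in_dim // block_size).permute(1, 0)
print(w_int.shape, scale.shape)  # torch.Size([3, 2, 8]) torch.Size([3, 8])
```

The resulting `w_int` (`[in_channel / block_size, block_size, out_channel]`, int8) and `scale` (`[in_channel / block_size, out_channel]`) match what `_check_blockwise_quant_weight_dtype_shapes` and `XlaQuantizedLinear.load_quantized_weight` expect when `block_size != -1`.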