From b260263c9ce81f8c8fac91daef7a6225528a5688 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Tue, 2 Jul 2024 03:57:01 +0000 Subject: [PATCH 1/5] add blockwise quant --- test/quantized_ops/test_quantized_matmul.py | 113 +++++++++++++++--- .../experimental/xla_quantized_matmul.py | 107 ++++++++++++----- 2 files changed, 173 insertions(+), 47 deletions(-) diff --git a/test/quantized_ops/test_quantized_matmul.py b/test/quantized_ops/test_quantized_matmul.py index a41eafa90781..021e83e9fe10 100644 --- a/test/quantized_ops/test_quantized_matmul.py +++ b/test/quantized_ops/test_quantized_matmul.py @@ -24,35 +24,66 @@ def __init__(self, input_dim, output_dim): def weight_quantization_rtn(self, linear, n_bits=8, + block_size=-1, quant_method=torch.per_channel_symmetric): ''' Quantize linear weight using Round-To-Nearest(RTN) algorithm. ''' assert isinstance(self.linear, torch.nn.Linear) w_fp = linear.weight.data - min_val, max_val = torch.aminmax(w_fp, dim=1) # min_val, max_val [out_dim] - int_min = -2**(n_bits - 1) - int_max = 2**(n_bits - 1) - 1 - scaler, zero_point = determine_qparams( - min_val, - max_val, - int_min, - int_max, - dtype=torch.int8, - eps=torch.Tensor([1e-5]), - has_customized_qrange=False, - qscheme=quant_method) - w_int = torch.ops.quantized_decomposed.quantize_per_channel( - w_fp, scaler, zero_point, 0, int_min, int_max, torch.int8) - return w_int, scaler.to(w_fp.dtype), zero_point - - def replace_with_xla_quantized_matmul(self, n_bit=8): + if block_size == -1: + min_val, max_val = torch.aminmax( + w_fp, dim=1) # min_val, max_val [out_dim] + int_min = -2**(n_bits - 1) + int_max = 2**(n_bits - 1) - 1 + scaler, zero_point = determine_qparams( + min_val, + max_val, + int_min, + int_max, + dtype=torch.int8, + eps=torch.Tensor([1e-5]), + has_customized_qrange=False, + qscheme=quant_method) + w_int = torch.ops.quantized_decomposed.quantize_per_channel( + w_fp, scaler, zero_point, 0, int_min, int_max, torch.int8) + return w_int, scaler.to(w_fp.dtype), zero_point + else: + assert w_fp.shape[1] % block_size == 0 + output_dim = w_fp.shape[0] + input_dim = w_fp.shape[1] + w_fp = w_fp.reshape(output_dim * input_dim // block_size, block_size) + min_val, max_val = torch.aminmax( + w_fp, dim=1) # min_val, max_val [out_dim] + int_min = -2**(n_bits - 1) + int_max = 2**(n_bits - 1) - 1 + scaler, zero_point = determine_qparams( + min_val, + max_val, + int_min, + int_max, + dtype=torch.int8, + eps=torch.Tensor([1e-5]), + has_customized_qrange=False, + qscheme=quant_method) + w_int = torch.ops.quantized_decomposed.quantize_per_channel( + w_fp, scaler, zero_point, 0, int_min, int_max, torch.int8) + w_int = w_int.reshape(output_dim, input_dim // block_size, + block_size).permute(1, 2, 0) + scaler = scaler.to(w_fp.dtype).reshape(output_dim, + input_dim // block_size).permute( + 1, 0) + return w_int, scaler, zero_point + + def replace_with_xla_quantized_matmul(self, n_bit=8, block_size=-1): assert isinstance(self.linear, torch.nn.Linear) - w_int, scaler, _ = self.weight_quantization_rtn(self.linear, n_bit) + w_int, scaler, _ = self.weight_quantization_rtn( + self.linear, n_bits=n_bit, block_size=block_size) use_int4_weight = n_bit == 4 q_linear = XlaQuantizedLinear( self.linear.in_features, self.linear.out_features, + block_size=block_size, int4_weight=use_int4_weight) q_linear.load_quantized_weight(w_int, scaler) self.linear = q_linear @@ -146,6 +177,52 @@ def test_int4_per_channel_linear_module(self): self.assertGreater( self._calc_cosine_dist(out_quant_xla.cpu(), out_quant), 0.999999) + def 
test_blockwise_matmul_op(self):
+    input_features = 6
+    out_features = 8
+    block_size = 2
+    batch_size = 3
+    weight = torch.randint(
+        -8, 7,
+        (input_features // block_size, block_size, out_features)).to(torch.int8)
+    weight_scaler = torch.ones(input_features // block_size, out_features)
+    x = torch.rand(batch_size, input_features)
+
+    # Fake quantize output.
+    w_dq = (weight * weight_scaler.unsqueeze(1)).reshape(
+        input_features, out_features)
+    fake_quant_out = torch.matmul(x, w_dq)
+    # Eager output.
+    torch_out = torch.ops.xla.quantized_matmul(
+        x, weight, weight_scaler, block_size=block_size)
+    # XLA Output.
+    x = x.to(device)
+    weight = weight.to(device)
+    weight_scaler = weight_scaler.to(device)
+    xla_out = torch.ops.xla.quantized_matmul(
+        x, weight, weight_scaler, block_size=block_size)
+    self.assertGreater(
+        self._calc_cosine_dist(fake_quant_out, torch_out), 0.99999)
+    self.assertTrue(torch.allclose(torch_out, xla_out.cpu()))
+
+  def test_blockwise_linear_module(self):
+    for n_bit in [4, 8]:
+      with self.subTest(n_bit=n_bit):
+        m = M(6, 8)
+        x = torch.randn(3, 6)
+        out_fp = m(x)
+        m.replace_with_xla_quantized_matmul(n_bit=n_bit, block_size=2)
+        out_quant = m(x)
+        self.assertGreater(self._calc_cosine_dist(out_fp, out_quant), 0.99)
+
+        # Dot with int4 weight is only supported on TPU
+        if not (n_bit == 4 and xr.device_type() != 'TPU'):
+          m = m.to(device)
+          x = x.to(device)
+          out_quant_xla = m(x)
+          self.assertGreater(
+              self._calc_cosine_dist(out_quant_xla.cpu(), out_quant), 0.999999)
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/torch_xla/experimental/xla_quantized_matmul.py b/torch_xla/experimental/xla_quantized_matmul.py
index 7b0b41c8d741..746d84be1f1f 100644
--- a/torch_xla/experimental/xla_quantized_matmul.py
+++ b/torch_xla/experimental/xla_quantized_matmul.py
@@ -5,7 +5,7 @@
 from torch_xla.core.xla_model import XLA_LIB
 
 XLA_LIB.define(
-    "quantized_matmul(Tensor x, Tensor w, Tensor scale, int? blocksize=-1, bool? int4_weight=False, bool? quantize_activation=False) -> Tensor"
+    "quantized_matmul(Tensor x, Tensor w, Tensor scale, int? block_size=-1, bool? int4_weight=False, bool? quantize_activation=False) -> Tensor"
 )
 
 
@@ -21,11 +21,32 @@ def _check_per_channel_quant_weight_dtype_shapes(input_dim, output_dim, w,
       0], f"weight scaler shape is expect to be [out_channel,], got {w_scaler.shape}, weight shape {w_shape}."
 
 
+def _check_blockwise_quant_weight_dtype_shapes(input_dim, output_dim, + block_size, w, w_scaler): + assert w.dtype == torch.int8, ( + f"Weight dtype is expected to be torch.int8, got {w.dtype}.") + assert w.dim() == 3, ( + f"Weight tensor is expected to be 3D, got {w.dim()}D Tensor.") + w_shape = list(w.shape) + assert input_dim % block_size == 0, ( + f"input_dim should be divisible by block_size, " + f"got input_dim: {input_dim}, block_size: {block_size}.") + assert w_shape[0] == input_dim / block_size and w_shape[1] == block_size, ( + f"Weight shape is expected to be [input_dim / block_size, block_size, output_dim], " + f"input_dim: {input_dim}, block_size: {block_size}, output_dim: {output_dim}, " + f"but got {w_shape}.") + assert w_scaler.dim() == 2, ( + f"weight scaler is expected to be 2D, got {w_scaler.dim()}D Tensor.") + assert w_scaler.shape[0] == w_shape[0] and w_scaler.shape[1] == w_shape[-1], ( + f"weight scaler shape is expect to be [in_channel / block_size, out_channel], " + f"got {w_scaler.shape}, weight shape {w_shape}.") + + @impl(XLA_LIB, "quantized_matmul", "XLA") def quantized_matmul_xla(x: torch.Tensor, w: torch.Tensor, scaler: torch.Tensor, - blocksize: int = -1, + block_size: int = -1, int4_weight: bool = False): """Quantized Matrix Multiply op on XLA devices. @@ -33,34 +54,53 @@ def quantized_matmul_xla(x: torch.Tensor, x: torch.Tensor - Activation of Matmul [..., in_channel]. w: torch.Tensor - Weight Tensor. per-channel quant: torch.int8 x [out_channel, in_channel]. + block_wise quant: torch.int8 x [in_channel / block_size, block_size, out_channel]. scaler: torch.Tensor - Weight scaler. per-channel quant: [out_channel,]. - blocksize: blocksize for blockwise quantization, -1 for per-channel quantization. + blockwise quant: [in_channel / block_size, out_channel]. + block_size: The blocksize for blockwise quantization, -1 for per-channel quantization. int4_weight: if the weights are int4, the int4 weights need to be stored in a int8 container (unpacked). """ - assert blocksize == -1, "blockwise quantization is not supported yet." if int4_weight: # Reinterpret cast the weight to s4 dtype in XLA. w = torch_xla._XLAC._xla_cast_int4(w, w.cpu().flatten().numpy().tolist()) - # Per-channel quant. - _check_per_channel_quant_weight_dtype_shapes(x.shape[-1], scaler.shape[0], w, - scaler) - return F.linear(x, w) * scaler + if block_size == -1: + # Per-channel quant. + _check_per_channel_quant_weight_dtype_shapes(x.shape[-1], scaler.shape[0], + w, scaler) + return F.linear(x, w) * scaler + else: + # Blockwise quant. + _check_blockwise_quant_weight_dtype_shapes(x.shape[-1], w.shape[-1], + block_size, w, scaler) + x = x.reshape(*x.shape[:-1], x.shape[-1] // block_size, block_size) + out = torch.einsum('scn,...sc->...sn', w, x) + out = torch.einsum('sn,...sn->...n', scaler, out) + return out @impl(XLA_LIB, "quantized_matmul", "CompositeExplicitAutograd") def quantized_matmul(x: torch.Tensor, w: torch.Tensor, scaler: torch.Tensor, - blocksize: int = -1, + block_size: int = -1, int4_weight: bool = False): - assert blocksize == -1, "blockwise quantization is not supported yet." - # Per-channel quant. - _check_per_channel_quant_weight_dtype_shapes(x.shape[-1], scaler.shape[0], w, - scaler) - w = w.to(x.dtype) - return torch.mul(F.linear(x, w), scaler) + if block_size == -1: + # Per-channel quant. + _check_per_channel_quant_weight_dtype_shapes(x.shape[-1], scaler.shape[0], + w, scaler) + w = w.to(x.dtype) + return torch.mul(F.linear(x, w), scaler) + else: + # Blockwise quant. 
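+    # The activation is reshaped to [..., in_channel // block_size, block_size]
+    # so that each block of input channels lines up with one row of the weight
+    # scaler; the first einsum contracts within each block, the second applies
+    # the per-block scalers and sums over the blocks.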
+ _check_blockwise_quant_weight_dtype_shapes(x.shape[-1], w.shape[-1], + block_size, w, scaler) + x = x.reshape(*x.shape[:-1], x.shape[-1] // block_size, block_size) + w = w.to(x.dtype) + out = torch.einsum('scn,...sc->...sn', w, x) + out = torch.einsum('sn,...sn->...n', scaler, out) + return out class XlaQuantizedLinear(torch.nn.Module): @@ -68,13 +108,12 @@ class XlaQuantizedLinear(torch.nn.Module): def __init__(self, input_dim, output_dim, - blocksize=-1, + block_size=-1, int4_weight: bool = False): super().__init__() - assert blocksize == -1, "Only per-channel quantization is supported." self.input_dim = input_dim self.output_dim = output_dim - self.blocksize = blocksize + self.block_size = block_size self.int4_weight = int4_weight self.register_buffer('weight', torch.zeros(output_dim, input_dim).to(torch.int8)) @@ -82,22 +121,32 @@ def __init__(self, def load_quantized_weight(self, weight, weight_scaler): ''' - Weight shape: [output_channel, input_channel] - Weight scaler shape: [output_channel] + Weight shape: + per-channel quant: [out_channel, in_channel]. + block_wise quant: [in_channel / block_size, block_size, out_channel]. + + Weight scaler shape: + per-channel quant: [out_channel,]. + blockwise quant: [in_channel / block_size, out_channel]. ''' - if self.blocksize == -1: + if self.block_size == -1: # Per-channel quant. _check_per_channel_quant_weight_dtype_shapes(self.input_dim, self.output_dim, weight, weight_scaler) - self.weight = weight - self.weight_scaler = weight_scaler else: - assert False, "Only per-channel quantization is supported." + # Blockwise quant. + _check_blockwise_quant_weight_dtype_shapes(self.input_dim, + self.output_dim, + self.block_size, weight, + weight_scaler) + self.weight = weight + self.weight_scaler = weight_scaler def forward(self, x): - if self.blocksize == -1: - return torch.ops.xla.quantized_matmul( - x, self.weight, self.weight_scaler, int4_weight=self.int4_weight) - else: - assert False, "Only per-channel quantization is supported." + return torch.ops.xla.quantized_matmul( + x, + self.weight, + self.weight_scaler, + block_size=self.block_size, + int4_weight=self.int4_weight) From 9febaa88d89b90276494702c1cb778fdb0a535e0 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Tue, 2 Jul 2024 06:08:43 +0000 Subject: [PATCH 2/5] update test --- test/quantized_ops/test_quantized_matmul.py | 47 +++++++++++---------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/test/quantized_ops/test_quantized_matmul.py b/test/quantized_ops/test_quantized_matmul.py index 021e83e9fe10..509afeccff30 100644 --- a/test/quantized_ops/test_quantized_matmul.py +++ b/test/quantized_ops/test_quantized_matmul.py @@ -182,28 +182,31 @@ def test_blockwise_matmul_op(self): out_features = 8 block_size = 2 batch_size = 3 - weight = torch.randint( - -8, 7, - (input_features // block_size, block_size, out_features)).to(torch.int8) - weight_scaler = torch.ones(input_features // block_size, out_features) - x = torch.rand(batch_size, input_features) - - # Fake quantize output. - w_dq = (weight * weight_scaler.unsqueeze(1)).reshape( - input_features, out_features) - fake_quant_out = torch.matmul(x, w_dq) - # Eager output. - torch_out = torch.ops.xla.quantized_matmul( - x, weight, weight_scaler, block_size=block_size) - # XLA Output. 
- x = x.to(device) - weight = weight.to(device) - weight_scaler = weight_scaler.to(device) - xla_out = torch.ops.xla.quantized_matmul( - x, weight, weight_scaler, block_size=block_size) - self.assertGreater( - self._calc_cosine_dist(fake_quant_out, torch_out), 0.99999) - self.assertTrue(torch.allclose(torch_out, xla_out.cpu())) + for n_bit in [4]: + with self.subTest(n_bit=n_bit): + weight = torch.randint( + -8, 7, + (input_features // block_size, block_size, out_features)).to(torch.int8) + weight_scaler = torch.ones(input_features // block_size, out_features) + x = torch.rand(batch_size, input_features) + + # Fake quantize output. + w_dq = (weight * weight_scaler.unsqueeze(1)).reshape( + input_features, out_features) + fake_quant_out = torch.matmul(x, w_dq) + # Eager output. + torch_out = torch.ops.xla.quantized_matmul( + x, weight, weight_scaler, block_size=block_size) + self.assertGreater( + self._calc_cosine_dist(fake_quant_out, torch_out), 0.99999) + # XLA Output. + if not (n_bit == 4 and xr.device_type() != 'TPU'): + x = x.to(device) + weight = weight.to(device) + weight_scaler = weight_scaler.to(device) + xla_out = torch.ops.xla.quantized_matmul( + x, weight, weight_scaler, block_size=block_size) + self.assertTrue(torch.allclose(torch_out, xla_out.cpu(), atol=0.03)) def test_blockwise_linear_module(self): for n_bit in [4, 8]: From 683c8a861d083b438a89fca3c0ea02bab3288d22 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Tue, 2 Jul 2024 06:09:49 +0000 Subject: [PATCH 3/5] update readme --- docs/quantized_ops.md | 4 ++-- test/quantized_ops/test_quantized_matmul.py | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/docs/quantized_ops.md b/docs/quantized_ops.md index cffe87d298a6..e9b9ed41831a 100644 --- a/docs/quantized_ops.md +++ b/docs/quantized_ops.md @@ -103,8 +103,8 @@ orig_model.linear = q_linear | per-channel | N/A | W4A16 | Yes | | per-channel | per-token | W8A8 | No | | per-channel | per-token | W4A8 | No | -| blockwise | N/A | W8A16 | No | -| blockwise | N/A | W4A16 | No | +| blockwise | N/A | W8A16 | Yes | +| blockwise | N/A | W4A16 | Yes | | blockwise | per-token | W8A8 | No | | blockwise | per-token | W4A8 | No | diff --git a/test/quantized_ops/test_quantized_matmul.py b/test/quantized_ops/test_quantized_matmul.py index 509afeccff30..b7f4f63bb7fa 100644 --- a/test/quantized_ops/test_quantized_matmul.py +++ b/test/quantized_ops/test_quantized_matmul.py @@ -184,9 +184,8 @@ def test_blockwise_matmul_op(self): batch_size = 3 for n_bit in [4]: with self.subTest(n_bit=n_bit): - weight = torch.randint( - -8, 7, - (input_features // block_size, block_size, out_features)).to(torch.int8) + weight = torch.randint(-8, 7, (input_features // block_size, block_size, + out_features)).to(torch.int8) weight_scaler = torch.ones(input_features // block_size, out_features) x = torch.rand(batch_size, input_features) @@ -198,7 +197,7 @@ def test_blockwise_matmul_op(self): torch_out = torch.ops.xla.quantized_matmul( x, weight, weight_scaler, block_size=block_size) self.assertGreater( - self._calc_cosine_dist(fake_quant_out, torch_out), 0.99999) + self._calc_cosine_dist(fake_quant_out, torch_out), 0.99999) # XLA Output. 
if not (n_bit == 4 and xr.device_type() != 'TPU'): x = x.to(device) From 6bae4a0c100cfaeea528e2c5ba8355354f416e26 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Tue, 2 Jul 2024 16:18:19 +0000 Subject: [PATCH 4/5] fix test --- test/quantized_ops/test_quantized_matmul.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/quantized_ops/test_quantized_matmul.py b/test/quantized_ops/test_quantized_matmul.py index b7f4f63bb7fa..af886196c333 100644 --- a/test/quantized_ops/test_quantized_matmul.py +++ b/test/quantized_ops/test_quantized_matmul.py @@ -126,7 +126,7 @@ def test_q_linear_module_dynamo(self): m = m.to(device) m_dynamo = torch.compile(m, backend="openxla") out_quant_dynamo = m_dynamo(x.to(device)) - self.assertTrue(torch.allclose(out_fp, out_quant, atol=0.01)) + self.assertTrue(torch.allclose(out_fp, out_quant, atol=0.02)) self.assertTrue(torch.allclose(out_quant_dynamo.cpu(), out_quant)) def test_q_linear_hlo(self): From 5b7e48e7ee21f9d993251393877db613c747e46b Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Mon, 8 Jul 2024 20:57:45 +0000 Subject: [PATCH 5/5] update doc str --- torch_xla/experimental/xla_quantized_matmul.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/xla_quantized_matmul.py b/torch_xla/experimental/xla_quantized_matmul.py index 746d84be1f1f..5de4eb644cef 100644 --- a/torch_xla/experimental/xla_quantized_matmul.py +++ b/torch_xla/experimental/xla_quantized_matmul.py @@ -121,11 +121,11 @@ def __init__(self, def load_quantized_weight(self, weight, weight_scaler): ''' - Weight shape: + weight (Tensor): per-channel quant: [out_channel, in_channel]. block_wise quant: [in_channel / block_size, block_size, out_channel]. - Weight scaler shape: + weight_scaler (Tensor): per-channel quant: [out_channel,]. blockwise quant: [in_channel / block_size, out_channel]. '''
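
For reference, the blockwise math added in this series can be checked with plain PyTorch. The sketch below is illustrative only -- the tensor names, sizes, and tolerance are invented here, not taken from the patches -- but it mirrors the reshape-plus-two-einsums formulation used by quantized_matmul and confirms it matches a dequantize-then-matmul reference. Applying the scaler in a second einsum, rather than folding it into the weight, appears to be what lets the weight stay in its int8/int4 form through the first contraction.

import torch

# Shapes follow the convention introduced in this patch series:
#   weight:        [in_channel // block_size, block_size, out_channel], int8
#   weight_scaler: [in_channel // block_size, out_channel]
in_features, out_features, block_size, batch = 6, 8, 2, 3
weight = torch.randint(
    -8, 8, (in_features // block_size, block_size, out_features),
    dtype=torch.int8)
scaler = torch.rand(in_features // block_size, out_features)
x = torch.rand(batch, in_features)

# Reference: dequantize the weight blockwise, then do an ordinary matmul.
w_dq = (weight * scaler.unsqueeze(1)).reshape(in_features, out_features)
ref = torch.matmul(x, w_dq)

# Blockwise formulation: split the activation into blocks, contract each
# block against its weight slice, then apply the per-block scalers and sum
# over the blocks.
x_blocked = x.reshape(*x.shape[:-1], in_features // block_size, block_size)
partial = torch.einsum('scn,...sc->...sn', weight.to(x.dtype), x_blocked)
out = torch.einsum('sn,...sn->...n', scaler, partial)

print(torch.allclose(ref, out, atol=1e-5))  # True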