From 0fd7957a077b14f124a2e34e57268da69a396d94 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Mon, 9 Sep 2024 13:05:57 -0700 Subject: [PATCH] Add 2, 3, 4, 5 bit custom ops (#828) Summary: Pull Request resolved: https://github.com/pytorch/ao/pull/828 Refactors some of the custom op example code to add support for more kernel variants. Reviewed By: digantdesai Differential Revision: D62248716 --- .../examples/torch_custom_op/run_custom_op.py | 84 ++++--- .../torch_custom_op/torch_custom_op.cpp | 183 +++++++++++++-- .../torch_custom_op/torch_custom_op.py | 220 ++++++++++++------ 3 files changed, 360 insertions(+), 127 deletions(-) diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py index e933dd3aac..77cc35f0db 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py +++ b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py @@ -4,16 +4,21 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from torch_custom_op import quantize, replace_linear_with_quantized_linear -import torch import copy -group_size = 16 +import torch +from torch_custom_op import ( + linear_a8sz_w_lowbit_reference_impl, + replace_linear_with_quantized_linear, +) + +group_size = 256 m = 1 n = 4096 k = 4096 -nbit = 4 -n_layers = 10 +nbit = 5 +has_weight_zeros = True +n_layers = 5 print("Creating random model") layers = [torch.nn.Linear(k, n, bias=False) for _ in range(n_layers)] @@ -22,8 +27,15 @@ print("Quantizing random model") quantized_model = copy.deepcopy(model) -quantized_model = quantized_model.eval() -replace_linear_with_quantized_linear(quantized_model, kwargs={"group_size": group_size, "nbit": nbit}) +quantized_model = quantized_model.eval() +replace_linear_with_quantized_linear( + quantized_model, + kwargs={ + "group_size": group_size, + "nbit": nbit, + "has_weight_zeros": has_weight_zeros, + }, +) print("Creating random activations") activations = torch.randn(m, k, dtype=torch.float32) @@ -48,36 +60,42 @@ fn(activations) -print("Checking correctness on layer 0") - -rtol=1e-05 - -# default is 1e-8, but PyTorch and C++ (and ARM neon) have different rounding -# conventions for ties (PyTorch rounds half to even and C++ rounds half to odd) -# TODO(T200109708): address this -atol=1e-05 - +print("\nChecking correctness on layer 0") linear = model[0] quantized_linear = quantized_model[0] -weight_qvals, weight_scales = quantize(linear.weight, group_size, quantized_linear.nbit, scale_only=True) - -activation_qvals, activations_scales, activations_zeros = quantize(activations, k, 8, False) -activations_dequantized = activations_scales * (activation_qvals - activations_zeros) -weights_dequantized = (weight_qvals.reshape(-1, group_size) * weight_scales.reshape(-1, 1)).reshape(n, k) with torch.no_grad(): result = quantized_linear(activations) - expected_result = torch.matmul(activations_dequantized, weights_dequantized.transpose(1, 0)) + expected_result = linear_a8sz_w_lowbit_reference_impl( + linear.weight, activations, group_size, nbit, has_weight_zeros + ) non_quantized_result = linear(activations) -if not (torch.allclose(result, expected_result, rtol=rtol, atol=atol)): - rand_idxs = torch.randint(0, result.shape[1], (5,)) - print("rand_idxs: ", rand_idxs) - print("kernel_result[rand_idxs]: ", result[0][rand_idxs]) - print("expected_result[rand_idxs]: ", 
expected_result[0][rand_idxs]) - assert False -else: - print("Correctness check passed") - -print("kernel_result[0:5]: ", result[0][0:5]) -print("non_quantized_result[0:5]: ", non_quantized_result[0][0:5]) + +# Check that entries in result match entries in expected_result +num_mismatch_at_low_tol = 0 +num_total = result.reshape(-1).shape[0] +for i in range(num_total): + actual_val = result.reshape(-1)[i] + expected_val = expected_result.reshape(-1)[i] + if not torch.allclose(actual_val, expected_val): + num_mismatch_at_low_tol += 1 + + # If results are not close at a relaxed tolerance, exit with failure + if not torch.allclose(actual_val, expected_val, atol=1e-6): + assert False, "Correctness check failed" + +# Assert at most 5% of entries are not close at a low tolerance +assert num_mismatch_at_low_tol / num_total <= 0.05, "Correctness check failed" +print( + "Correctness check passed. All results are close, and ", + (num_total - num_mismatch_at_low_tol), + "/", + num_total, + " entries are close at a low tolerance.", +) +print("Quantization errors:") +print("\tL1 error: ", torch.mean(torch.abs(result - non_quantized_result)).item()) +print("\tL2 error: ", torch.mean((result - non_quantized_result) ** 2).item()) +print("\tquantized_result[0:5]: ", result[0][0:5]) +print("\tnon_quantized_result[0:5]: ", non_quantized_result[0][0:5]) diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.cpp b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.cpp index 0ac19ec9f4..3f7488d915 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.cpp +++ b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.cpp @@ -11,7 +11,7 @@ #include template -at::Tensor pack_weights_cpu( +at::Tensor pack_weights_without_zeros_cpu( const at::Tensor& weight_qvals, const at::Tensor& weight_scales, // TODO(T200095131): convert to int64_t when supported by AOTI @@ -54,9 +54,8 @@ at::Tensor pack_weights_cpu( auto packed_weight_data_size = get_packed_weight_data_size(ukernel_config, n, k, group_size); - auto options = torch::TensorOptions().dtype(torch::kInt8); - - at::Tensor packed_weights = torch::empty({packed_weight_data_size}, options); + at::Tensor packed_weights = + torch::empty({packed_weight_data_size}, torch::kInt8); pack_weight_data_operator( ukernel_config, pack_weight_tiling_params, @@ -72,7 +71,74 @@ at::Tensor pack_weights_cpu( } template -at::Tensor pack_weights_meta( +at::Tensor pack_weights_with_zeros_cpu( + const at::Tensor& weight_qvals, + const at::Tensor& weight_scales, + const at::Tensor& weight_zeros, + // TODO(T200095131): convert to int64_t when supported by AOTI + // group_size is a meta tensor with size (group_size) + const at::Tensor& group_size_tensor) { + int64_t group_size = group_size_tensor.size(0); + + TORCH_CHECK( + weight_qvals.dtype() == torch::kInt8, "weight_qvals must be int8"); + TORCH_CHECK(weight_qvals.dim() == 2, "weight_qvals must be 2D"); + + // In PyTorch, weights are nxk in row-major format (with activations being + // right-multiplied). + // In kernel, activations are left-multiplied by kxn transposed + // weights in column-major format. 
+ // Note the underlying data is the same in both cases + int n = weight_qvals.size(0); + int k = weight_qvals.size(1); + + TORCH_CHECK( + weight_scales.dtype() == torch::kFloat32, + "weight_scales must be float32"); + TORCH_CHECK(weight_scales.dim() == 1, "weight_scales must be 1D"); + TORCH_CHECK( + weight_scales.size(0) == ((n * k) / group_size), + "expected 1 scale per group"); + TORCH_CHECK( + weight_zeros.dtype() == torch::kInt8, "weight_zeros must be int8"); + TORCH_CHECK(weight_zeros.dim() == 1, "weight_zeros must be 1D"); + TORCH_CHECK( + weight_zeros.size(0) == ((n * k) / group_size), + "expected 1 zero per group"); + + using namespace torchao::operators::cpu::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight; + + auto ukernel_config = get_ukernel_config< + weight_nbit, + true /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/>(); + auto pack_weight_tiling_params = get_default_pack_weight_data_tiling_params( + ukernel_config, n, /*target_panels_per_thread=*/1); + + torchao::set_num_threads(torch::get_num_threads()); + + auto packed_weight_data_size = + get_packed_weight_data_size(ukernel_config, n, k, group_size); + at::Tensor packed_weights = + torch::empty({packed_weight_data_size}, torch::kInt8); + pack_weight_data_operator( + ukernel_config, + pack_weight_tiling_params, + packed_weights.data_ptr(), + n, + k, + group_size, + weight_qvals.const_data_ptr(), + weight_scales.const_data_ptr(), + weight_zeros.const_data_ptr()); + + return packed_weights; +} + +template +at::Tensor pack_weights_without_zeros_meta( const at::Tensor& weight_qvals, const at::Tensor& weight_scales, // TODO(T200095131): convert to int64_t when supported by AOTI @@ -98,6 +164,33 @@ at::Tensor pack_weights_meta( } template +at::Tensor pack_weights_with_zeros_meta( + const at::Tensor& weight_qvals, + const at::Tensor& weight_scales, + const at::Tensor& weight_zeros, + // TODO(T200095131): convert to int64_t when supported by AOTI + // group_size is a meta tensor with size (group_size) + const at::Tensor& group_size_tensor) { + int64_t group_size = group_size_tensor.size(0); + + int n = weight_qvals.size(0); + int k = weight_qvals.size(1); + + using namespace torchao::operators::cpu::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight; + + auto ukernel_config = get_ukernel_config< + weight_nbit, + true /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/>(); + + auto packed_weight_data_size = + get_packed_weight_data_size(ukernel_config, n, k, group_size); + return torch::empty({packed_weight_data_size}).to("meta"); +} + +template at::Tensor linear_cpu( const at::Tensor& packed_weights, // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to @@ -123,7 +216,7 @@ at::Tensor linear_cpu( auto ukernel_config = get_ukernel_config< weight_nbit, - false /*has_weight_zeros*/, + has_weight_zeros /*has_weight_zeros*/, false /*has_bias*/, false /*has_clamp*/>(); auto linear_tiling_params = get_default_linear_tiling_params( @@ -167,7 +260,7 @@ at::Tensor linear_cpu( return output_tensor; } -template +template at::Tensor linear_meta( const at::Tensor& packed_weights, // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to @@ -187,26 +280,78 @@ at::Tensor linear_meta( } TORCH_LIBRARY(torchao, m) { + // Pack weights without zeros + m.def( + "_pack_weights_a8sz_w2s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor"); + m.def( + "_pack_weights_a8sz_w3s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor"); 
+ m.def( + "_pack_weights_a8sz_w4s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor"); + m.def( + "_pack_weights_a8sz_w5s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor"); + // Pack weights with zeros + m.def( + "_pack_weights_a8sz_w2sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor"); + m.def( + "_pack_weights_a8sz_w3sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor"); + m.def( + "_pack_weights_a8sz_w4sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor"); + m.def( + "_pack_weights_a8sz_w5sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor"); + // Linear weights without zeros + m.def( + "_linear_a8sz_w2s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor"); + m.def( + "_linear_a8sz_w3s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor"); + m.def( + "_linear_a8sz_w4s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor"); + m.def( + "_linear_a8sz_w5s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor"); + // Linear weights with zeros m.def( - "_pack_weights_3bit(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor"); + "_linear_a8sz_w2sz(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor"); m.def( - "_linear_3bit(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor"); + "_linear_a8sz_w3sz(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor"); m.def( - "_pack_weights_4bit(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor"); + "_linear_a8sz_w4sz(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor"); m.def( - "_linear_4bit(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor"); + "_linear_a8sz_w5sz(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor"); } TORCH_LIBRARY_IMPL(torchao, CPU, m) { - m.impl("_pack_weights_3bit", &pack_weights_cpu<3>); - m.impl("_linear_3bit", &linear_cpu<3>); - m.impl("_pack_weights_4bit", &pack_weights_cpu<4>); - m.impl("_linear_4bit", &linear_cpu<4>); + m.impl("_pack_weights_a8sz_w2s", &pack_weights_without_zeros_cpu<2>); + m.impl("_pack_weights_a8sz_w3s", &pack_weights_without_zeros_cpu<3>); + m.impl("_pack_weights_a8sz_w4s", &pack_weights_without_zeros_cpu<4>); + m.impl("_pack_weights_a8sz_w5s", &pack_weights_without_zeros_cpu<5>); + m.impl("_pack_weights_a8sz_w2sz", &pack_weights_with_zeros_cpu<2>); + m.impl("_pack_weights_a8sz_w3sz", &pack_weights_with_zeros_cpu<3>); + m.impl("_pack_weights_a8sz_w4sz", &pack_weights_with_zeros_cpu<4>); + m.impl("_pack_weights_a8sz_w5sz", &pack_weights_with_zeros_cpu<5>); + m.impl("_linear_a8sz_w2s", &linear_cpu<2, false>); + m.impl("_linear_a8sz_w3s", &linear_cpu<3, false>); + m.impl("_linear_a8sz_w4s", &linear_cpu<4, false>); + m.impl("_linear_a8sz_w5s", &linear_cpu<5, false>); + m.impl("_linear_a8sz_w2sz", &linear_cpu<2, true>); + m.impl("_linear_a8sz_w3sz", &linear_cpu<3, true>); + m.impl("_linear_a8sz_w4sz", &linear_cpu<4, true>); + m.impl("_linear_a8sz_w5sz", &linear_cpu<5, true>); } TORCH_LIBRARY_IMPL(torchao, Meta, m) { - m.impl("_pack_weights_3bit", 
&pack_weights_meta<3>); - m.impl("_linear_3bit", &linear_meta<3>); - m.impl("_pack_weights_4bit", &pack_weights_meta<4>); - m.impl("_linear_4bit", &linear_meta<4>); + m.impl("_pack_weights_a8sz_w2s", &pack_weights_without_zeros_meta<2>); + m.impl("_pack_weights_a8sz_w3s", &pack_weights_without_zeros_meta<3>); + m.impl("_pack_weights_a8sz_w4s", &pack_weights_without_zeros_meta<4>); + m.impl("_pack_weights_a8sz_w5s", &pack_weights_without_zeros_meta<5>); + m.impl("_pack_weights_a8sz_w2sz", &pack_weights_with_zeros_meta<2>); + m.impl("_pack_weights_a8sz_w3sz", &pack_weights_with_zeros_meta<3>); + m.impl("_pack_weights_a8sz_w4sz", &pack_weights_with_zeros_meta<4>); + m.impl("_pack_weights_a8sz_w5sz", &pack_weights_with_zeros_meta<5>); + m.impl("_linear_a8sz_w2s", &linear_meta<2, false>); + m.impl("_linear_a8sz_w3s", &linear_meta<3, false>); + m.impl("_linear_a8sz_w4s", &linear_meta<4, false>); + m.impl("_linear_a8sz_w5s", &linear_meta<5, false>); + m.impl("_linear_a8sz_w2sz", &linear_meta<2, true>); + m.impl("_linear_a8sz_w3sz", &linear_meta<3, true>); + m.impl("_linear_a8sz_w4sz", &linear_meta<4, true>); + m.impl("_linear_a8sz_w5sz", &linear_meta<5, true>); } diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py index 6a4ff1a731..7e25eed261 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py +++ b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.py @@ -40,54 +40,57 @@ def quantize(vals: torch.Tensor, group_size: int, nbit: int, scale_only: bool): return group_qvals, group_scales, group_zeros -class Chn8ActGrp3WgtQuantizedLinear(nn.Module): - nbit = 3 - - def __init__(self, squeeze_unsqueeze_dim0=False): - super().__init__() - self.squeeze_unsqueeze_dim0 = squeeze_unsqueeze_dim0 - - def initialize_from_unpacked_weights(self, weight_qvals, weight_scales, group_size): - n, k = weight_qvals.shape - - # TODO(T200095131): convert self.n, self.k, self.group_size to - # int when supported by AOTI - self.n = torch.empty(n) - self.k = torch.empty(k) - self.group_size = torch.empty(group_size) - self.packed_weights = torch.ops.torchao._pack_weights_3bit( - weight_qvals, weight_scales, self.group_size +def linear_a8sz_w_lowbit_reference_impl( + weights, activations, group_size, nbit, has_weight_zeros +): + n, k = weights.shape + m, k = activations.shape + assert m == 1 + assert k % group_size == 0 + + if has_weight_zeros: + weight_qvals, weight_scales, weight_zeros = quantize( + weights, group_size, nbit, scale_only=False ) - - def initialize_from_packed_weights(self, packed_weights, n, k, group_size): - # TODO(T200095131): convert self.n, self.k, self.group_size to - # int when supported by AOTI - self.n = torch.empty(n) - self.k = torch.empty(k) - self.group_size = torch.empty(group_size) - self.packed_weights = packed_weights - - def forward(self, x): - if self.squeeze_unsqueeze_dim0: - x = x.squeeze(0) - - res = torch.ops.torchao._linear_3bit( - self.packed_weights, self.n, self.k, self.group_size, x + weights_dequantized = ( + weight_scales.reshape(-1, 1) + * (weight_qvals.reshape(-1, group_size) - weight_zeros.reshape(-1, 1)) + ).reshape(n, k) + else: + weight_qvals, weight_scales = quantize( + weights, group_size, nbit, scale_only=True ) + weights_dequantized = ( + weight_scales.reshape(-1, 1) * (weight_qvals.reshape(-1, group_size)) + ).reshape(n, k) - if self.squeeze_unsqueeze_dim0: - res = 
res.unsqueeze(0) - return res - - -class Chn8ActGrp4WgtQuantizedLinear(nn.Module): - nbit = 4 - - def __init__(self, squeeze_unsqueeze_dim0=False): + activation_qvals, activations_scales, activations_zeros = quantize( + activations, k, 8, False + ) + activations_dequantized = activations_scales * ( + activation_qvals - activations_zeros + ).reshape(m, k) + return torch.matmul(activations_dequantized, weights_dequantized.transpose(1, 0)) + + +class _quantized_linear(nn.Module): + def __init__( + self, + nbit, + has_weight_zeros, + pack_weight_op, + linear_op, + squeeze_unsqueeze_dim0=False, + ): super().__init__() self.squeeze_unsqueeze_dim0 = squeeze_unsqueeze_dim0 + self.nbit = nbit - def initialize_from_unpacked_weights(self, weight_qvals, weight_scales, group_size): + self._has_weight_zeros = has_weight_zeros + self._pack_weights_op = pack_weight_op + self._linear_op = linear_op + + def pack_weights(self, weight_qvals, weight_scales_and_zeros, group_size): n, k = weight_qvals.shape # TODO(T200095131): convert self.n, self.k, self.group_size to @@ -95,25 +98,23 @@ def initialize_from_unpacked_weights(self, weight_qvals, weight_scales, group_si self.n = torch.empty(n) self.k = torch.empty(k) self.group_size = torch.empty(group_size) - self.packed_weights = torch.ops.torchao._pack_weights_4bit( - weight_qvals, weight_scales, self.group_size - ) - def initialize_from_packed_weights(self, packed_weights, n, k, group_size): - # TODO(T200095131): convert self.n, self.k, self.group_size to - # int when supported by AOTI - self.n = torch.empty(n) - self.k = torch.empty(k) - self.group_size = torch.empty(group_size) - self.packed_weights = packed_weights + if self._has_weight_zeros: + weight_scales, weight_zeros = weight_scales_and_zeros + self.packed_weights = self._pack_weights_op( + weight_qvals, weight_scales, weight_zeros, self.group_size + ) + else: + weight_scales = weight_scales_and_zeros + self.packed_weights = self._pack_weights_op( + weight_qvals, weight_scales, self.group_size + ) def forward(self, x): if self.squeeze_unsqueeze_dim0: x = x.squeeze(0) - res = torch.ops.torchao._linear_4bit( - self.packed_weights, self.n, self.k, self.group_size, x - ) + res = self._linear_op(self.packed_weights, self.n, self.k, self.group_size, x) if self.squeeze_unsqueeze_dim0: res = res.unsqueeze(0) @@ -123,6 +124,7 @@ def forward(self, x): def replace_linear_with_quantized_linear(module: nn.Module, kwargs={}): group_size = kwargs["group_size"] nbit = kwargs["nbit"] + has_weight_zeros = kwargs["has_weight_zeros"] squeeze_unsqueeze_dim0 = ( kwargs["squeeze_unsqueeze_dim0"] if "squeeze_unsqueeze_dim0" in kwargs @@ -132,30 +134,98 @@ def replace_linear_with_quantized_linear(module: nn.Module, kwargs={}): for name, child in module.named_children(): if isinstance(child, nn.Linear): assert child.bias is None - n, k = child.weight.shape - weight_qvals, weight_scales = quantize( - child.weight, group_size=group_size, nbit=nbit, scale_only=True - ) - if nbit == 3: - setattr( - module, name, Chn8ActGrp3WgtQuantizedLinear(squeeze_unsqueeze_dim0) + if not has_weight_zeros: + weight_qvals, weight_scales = quantize( + child.weight, group_size=group_size, nbit=nbit, scale_only=True ) - getattr(module, name).initialize_from_unpacked_weights( - weight_qvals, - weight_scales, - group_size, + weight_scales_and_zeros = weight_scales + else: + weight_qvals, weight_scales, weight_zeros = quantize( + child.weight, group_size=group_size, nbit=nbit, scale_only=False ) + weight_scales_and_zeros = (weight_scales, 
weight_zeros.to(torch.int8)) + + qlinear = None + if nbit == 2: + if has_weight_zeros: + qlinear = _quantized_linear( + nbit=nbit, + has_weight_zeros=has_weight_zeros, + pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w2sz, + linear_op=torch.ops.torchao._linear_a8sz_w2sz, + squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, + ) + else: + qlinear = _quantized_linear( + nbit=nbit, + has_weight_zeros=has_weight_zeros, + pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w2s, + linear_op=torch.ops.torchao._linear_a8sz_w2s, + squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, + ) + elif nbit == 3: + if has_weight_zeros: + qlinear = _quantized_linear( + nbit=nbit, + has_weight_zeros=has_weight_zeros, + pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w3sz, + linear_op=torch.ops.torchao._linear_a8sz_w3sz, + squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, + ) + else: + qlinear = _quantized_linear( + nbit=nbit, + has_weight_zeros=has_weight_zeros, + pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w3s, + linear_op=torch.ops.torchao._linear_a8sz_w3s, + squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, + ) elif nbit == 4: - setattr( - module, name, Chn8ActGrp4WgtQuantizedLinear(squeeze_unsqueeze_dim0) - ) - getattr(module, name).initialize_from_unpacked_weights( - weight_qvals, - weight_scales, - group_size, - ) + if has_weight_zeros: + qlinear = _quantized_linear( + nbit=nbit, + has_weight_zeros=has_weight_zeros, + pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w4sz, + linear_op=torch.ops.torchao._linear_a8sz_w4sz, + squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, + ) + else: + qlinear = _quantized_linear( + nbit=nbit, + has_weight_zeros=has_weight_zeros, + pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w4s, + linear_op=torch.ops.torchao._linear_a8sz_w4s, + squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, + ) + elif nbit == 5: + if has_weight_zeros: + qlinear = _quantized_linear( + nbit=nbit, + has_weight_zeros=has_weight_zeros, + pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w5sz, + linear_op=torch.ops.torchao._linear_a8sz_w5sz, + squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, + ) + else: + qlinear = _quantized_linear( + nbit=nbit, + has_weight_zeros=has_weight_zeros, + pack_weight_op=torch.ops.torchao._pack_weights_a8sz_w5s, + linear_op=torch.ops.torchao._linear_a8sz_w5s, + squeeze_unsqueeze_dim0=squeeze_unsqueeze_dim0, + ) else: - raise ValueError(f"Unsupported nbit: {nbit}") + raise ValueError( + f"Unsupported nbit ({nbit}) and has_weight_zeros ({has_weight_zeros}) combination" + ) + + assert qlinear is not None + setattr(module, name, qlinear) + getattr(module, name).pack_weights( + weight_qvals, + weight_scales_and_zeros, + group_size, + ) else: replace_linear_with_quantized_linear(child, kwargs)
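
Usage note (not part of the patch): a minimal sketch, distilled from the updated run_custom_op.py above, showing how the new low-bit kernel variants are exercised. It assumes the custom ops have been built and that torch_custom_op is importable from the example directory; the bit width, group size, and shapes below are illustrative rather than required values.

    import copy

    import torch
    from torch_custom_op import (
        linear_a8sz_w_lowbit_reference_impl,
        replace_linear_with_quantized_linear,
    )

    # Illustrative settings; any nbit in {2, 3, 4, 5} is registered by the patch,
    # and k must be divisible by group_size.
    group_size, nbit, has_weight_zeros = 256, 3, True
    k, n = 4096, 4096

    # Build a single-layer float model and a quantized copy of it.
    linear = torch.nn.Linear(k, n, bias=False)
    model = torch.nn.Sequential(linear).eval()
    quantized_model = copy.deepcopy(model)
    replace_linear_with_quantized_linear(
        quantized_model,
        kwargs={
            "group_size": group_size,
            "nbit": nbit,
            "has_weight_zeros": has_weight_zeros,
        },
    )

    # Run the quantized kernel and compare against the pure-PyTorch reference impl.
    activations = torch.randn(1, k, dtype=torch.float32)
    with torch.no_grad():
        result = quantized_model(activations)
        expected = linear_a8sz_w_lowbit_reference_impl(
            linear.weight, activations, group_size, nbit, has_weight_zeros
        )
    print("mean abs diff:", torch.mean(torch.abs(result - expected)).item())

As in the patched example script, exact equality is not expected: the kernel and the reference implementation may round ties differently, so results should be compared with a small tolerance rather than checked bit-for-bit.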