diff --git a/csrc/cpu/aten/Linear.cpp b/csrc/cpu/aten/Linear.cpp index df9ea718e..1a9642eb6 100644 --- a/csrc/cpu/aten/Linear.cpp +++ b/csrc/cpu/aten/Linear.cpp @@ -356,38 +356,38 @@ at::Tensor ipex_linear_eltwise( input, weight, bias, eltwise, op_context, out_features); } -DEFINE_DISPATCH(woq_linear_packB_stub); DEFINE_DISPATCH(woq_tpp_gemm_packB_stub); at::Tensor woq_linear_pack_weight( const at::Tensor& weight, - const at::Tensor& scales, - const at::Tensor& zero_points, + std::vector& weight_shape, + bool is_int4, + int64_t group_size, int64_t lowp_mode) { // TPP kernel does not support edge cases // It generates packed weight in 4d (Nc, Kc, block_k, block_n) - auto N = weight.size(0), K = weight.size(1); + auto N = weight_shape[0], K = weight_shape[1]; // For TPP kernel, we only consider even K if (K % 2 == 0) { - bool is_int4 = weight.scalar_type() == c10::kQUInt4x2; - // int num_threads = at::get_num_threads(); size_t block_n = 32; - if (lowp_mode == 0) { - block_n = 16; - } - size_t block_k = 64; + size_t block_k = group_size > 0 ? std::min(group_size, (int64_t)64) : 64; while (K % block_k != 0) { block_k /= 2; } assert(block_k > 0); if (is_int4) { + if (block_k % 4 && lowp_mode == 3) { + // This case is not supported by kernel + return weight; + } // Create a new non-quantized tensor in data type uint8 (Byte) // One uint8 holds two int4 values. Compressed along K. // N is padded to the nearest multiple of block_n. + // Note that weight is already compressed int64_t K_int4_compressed = K / 2; int64_t N_int4 = N % block_n ? N / block_n * block_n + block_n : N; at::Tensor weight_int4 = at::empty( {N_int4, K_int4_compressed}, device(c10::kCPU).dtype(c10::kByte)); - int64_t weight_size_bytes = weight.numel() / 2; + int64_t weight_size_bytes = weight.numel(); int64_t weight_int4_size_bytes = weight_int4.numel(); int64_t pad_size_bytes = weight_int4_size_bytes - weight_size_bytes; std::memcpy(weight_int4.data_ptr(), weight.data_ptr(), weight_size_bytes); @@ -395,57 +395,25 @@ at::Tensor woq_linear_pack_weight( (uint8_t*)weight_int4.data_ptr() + weight_size_bytes, 0, pad_size_bytes); - auto packed_b = woq_tpp_gemm_packB_stub( + return woq_tpp_gemm_packB_stub( kCPU, weight_int4, is_int4, block_n, block_k, lowp_mode); - if (packed_b.defined()) { - return packed_b; - } } - if (!(N % block_n) && !(K % block_k)) { - auto packed_b = woq_tpp_gemm_packB_stub( + if (N % block_n) { + return weight; + } else { + return woq_tpp_gemm_packB_stub( kCPU, weight, is_int4, block_n, block_k, lowp_mode); - if (packed_b.defined()) { - return packed_b; - } } } - return woq_linear_packB_stub(kCPU, weight, scales, zero_points); + return weight; } -DEFINE_DISPATCH(woq_linear_unpackB_stub); DEFINE_DISPATCH(woq_tpp_gemm_unpackB_stub); at::Tensor woq_linear_unpack_weight( const at::Tensor& weight, bool is_int4, int64_t lowp_mode) { - if (weight.dim() > 2) { - auto unpacked_b = - woq_tpp_gemm_unpackB_stub(kCPU, weight, is_int4, lowp_mode); - if (unpacked_b.defined()) { - return unpacked_b; - } - } - return woq_linear_unpackB_stub(kCPU, weight); -} - -DEFINE_DISPATCH(woq_gemm_kernel_stub); -void woq_linear_kernel_output( - const at::Tensor& self, - const at::Tensor& weight, - const at::Tensor& scales_float, - const at::Tensor& zero_points_float, - const at::Tensor& bias, - int64_t lowp_mode, - at::Tensor& output) { - woq_gemm_kernel_stub( - kCPU, - self, - weight, - scales_float, - zero_points_float, - bias, - lowp_mode, - output); + return woq_tpp_gemm_unpackB_stub(kCPU, weight, is_int4, lowp_mode); } 
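// Illustrative sketch (not the IPEX kernel, helper name is hypothetical): how the
// packing path above prepares an already nibble-compressed int4 weight. Two int4
// values share one uint8, so an [N, K] int4 matrix is stored as [N, K/2] bytes;
// N is then padded up to the next multiple of block_n and the tail is zero-filled,
// mirroring the memcpy/memset pair in woq_linear_pack_weight. The real code then
// hands the result to woq_tpp_gemm_packB_stub for 4D blocking, omitted here.
#include <cstdint>
#include <cstring>
#include <vector>

std::vector<uint8_t> pad_compressed_int4_weight(
    const std::vector<uint8_t>& weight, // [N, K/2], two int4 values per byte
    int64_t N,
    int64_t K,
    int64_t block_n) {
  const int64_t K_compressed = K / 2;
  const int64_t N_padded = (N + block_n - 1) / block_n * block_n;
  std::vector<uint8_t> padded(N_padded * K_compressed);
  const size_t weight_bytes = weight.size(); // == N * K_compressed
  const size_t pad_bytes = padded.size() - weight_bytes;
  std::memcpy(padded.data(), weight.data(), weight_bytes);
  std::memset(padded.data() + weight_bytes, 0, pad_bytes);
  return padded;
}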
DEFINE_DISPATCH(woq_tpp_gemm_kernel_stub); @@ -456,48 +424,26 @@ at::Tensor woq_linear_kernel( const std::vector& zps_list, const std::vector& bias_list, bool is_int4, + int64_t group_size, int64_t lowp_mode, - int64_t num_concats) { - if (weight.dim() > 2) { - auto out = woq_tpp_gemm_kernel_stub( - kCPU, - self, - weight, - scales_list, - zps_list, - bias_list, - is_int4, - lowp_mode, - num_concats, - WOQ_FUSE_NONE, // no post op fusion - std::vector()); - if (out.defined()) { - return out; - } - } - auto input_size = self.sizes(); - std::vector output_size(input_size.begin(), input_size.end() - 1); - output_size.push_back(weight.size(0)); - auto output = at::empty(output_size, self.options()); - output.set_requires_grad(self.requires_grad()); - woq_linear_kernel_output( + int64_t num_concats, + int64_t act_quant_mode) { + int64_t quant_w_mode = group_size > 0 ? 1 : 0; + return woq_tpp_gemm_kernel_stub( + kCPU, self, weight, - scales_list[0], - zps_list[0], - bias_list[0], + scales_list, + zps_list, + bias_list, + is_int4, lowp_mode, - output); - if (num_concats > 1) { - // View as [..., num_concats, N/num_concats], transpose then make contiguous - // Finally view back as output shape - auto out_shape = output.sizes().vec(); - out_shape.insert(out_shape.end() - 1, num_concats); - out_shape.back() /= num_concats; - return output.view(out_shape).transpose(0, -2).contiguous().view( - output.sizes().vec()); - } - return output; + num_concats, + WOQ_FUSE_NONE, // no post op fusion + std::vector(), + act_quant_mode, + quant_w_mode, + group_size); } at::Tensor woq_linear_forward( @@ -510,32 +456,6 @@ at::Tensor woq_linear_forward( ->run(input); } -DEFINE_DISPATCH(woq_gemm_eltwise_kernel_stub); -void woq_linear_eltwise_kernel_output( - const at::Tensor& self, - const at::Tensor& weight, - const at::Tensor& scales_float, - const at::Tensor& zero_points_float, - const at::Tensor& bias, - const c10::string_view& post_op, - const torch::List>& scalars, - const c10::optional& algorithm, - int64_t lowp_mode, - at::Tensor& output) { - woq_gemm_eltwise_kernel_stub( - kCPU, - self, - weight, - scales_float, - zero_points_float, - bias, - post_op, - scalars, - algorithm, - lowp_mode, - output); -} - at::Tensor woq_linear_eltwise_kernel( const at::Tensor& self, const at::Tensor& weight, @@ -546,44 +466,28 @@ at::Tensor woq_linear_eltwise_kernel( const torch::List>& scalars, const c10::optional& algorithm, bool is_int4, + int64_t group_size, int64_t lowp_mode, - int64_t num_concats) { + int64_t num_concats, + int64_t act_quant_mode) { int64_t post_op_fusion_type = post_op == "gelu" ? WOQ_FUSE_GELU : WOQ_FUSE_NONE; - if (weight.dim() > 2) { - auto out = woq_tpp_gemm_kernel_stub( - kCPU, - self, - weight, - scales_list, - zps_list, - bias_list, - is_int4, - lowp_mode, - num_concats, - post_op_fusion_type, - std::vector()); - if (out.defined()) { - return out; - } - } - auto input_size = self.sizes(); - std::vector output_size(input_size.begin(), input_size.end() - 1); - output_size.push_back(weight.size(0)); - auto output = at::empty(output_size, self.options()); - output.set_requires_grad(self.requires_grad()); - woq_linear_eltwise_kernel_output( + int64_t quant_w_mode = group_size > 0 ? 
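// Minimal sketch of the weight-quantization granularity implied by group_size in
// the new kernel signatures above: group_size <= 0 selects per-channel
// quantization (one scale/zero-point per output channel), group_size > 0 selects
// per-K-block quantization (one scale/zero-point per output channel per group of
// group_size input channels), encoded as quant_w_mode = group_size > 0 ? 1 : 0.
// The helper name is hypothetical and only documents the resulting scale layout.
#include <cstdint>
#include <utility>

// Returns {quant_w_mode, number of scale values per output channel}.
std::pair<int64_t, int64_t> weight_quant_layout(int64_t K, int64_t group_size) {
  if (group_size <= 0) {
    return {0 /* per channel */, 1};
  }
  const int64_t k_blocks = (K + group_size - 1) / group_size;
  return {1 /* per K block */, k_blocks};
}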
1 : 0; + return woq_tpp_gemm_kernel_stub( + kCPU, self, weight, - scales_list[0], - zps_list[0], - bias_list[0], - post_op, - scalars, - algorithm, + scales_list, + zps_list, + bias_list, + is_int4, lowp_mode, - output); - return output; + num_concats, + post_op_fusion_type, + std::vector(), + act_quant_mode, + quant_w_mode, + group_size); } at::Tensor woq_linear_gelu_forward( @@ -604,87 +508,27 @@ at::Tensor woq_linear_add_kernel( const std::vector& zps_list, const std::vector& bias_list, bool is_int4, + int64_t group_size, int64_t lowp_mode, int64_t num_concats, - at::Tensor& accumu, - const c10::optional& alpha) { - c10::Scalar a = alpha.has_value() ? alpha.value() : 1.0f; - if (weight.dim() > 2) { - auto output = woq_tpp_gemm_kernel_stub( - kCPU, - self, - weight, - scales_list, - zps_list, - bias_list, - is_int4, - lowp_mode, - num_concats, - WOQ_FUSE_NONE, // no eltwise post op - std::vector()); - if (output.defined()) { - at::add_out(accumu, output, accumu, a); - return accumu; - } - } - auto input_size = self.sizes(); - std::vector output_size(input_size.begin(), input_size.end() - 1); - output_size.push_back(weight.size(0)); - auto output = at::empty(output_size, self.options()); - output.set_requires_grad(self.requires_grad()); - woq_linear_kernel_output( - self, - weight, - scales_list[0], - zps_list[0], - bias_list[0], - lowp_mode, - output); - at::add_out(accumu, output, accumu, a); - return accumu; -} - -at::Tensor woq_linear_add_kernel( - const at::Tensor& self, - const at::Tensor& weight, - const std::vector& scales_list, - const std::vector& zps_list, - const std::vector& bias_list, - bool is_int4, - int64_t lowp_mode, - int64_t num_concats, - const std::vector& others) { - if (weight.dim() > 2) { - auto out = woq_tpp_gemm_kernel_stub( - kCPU, - self, - weight, - scales_list, - zps_list, - bias_list, - is_int4, - lowp_mode, - num_concats, - WOQ_FUSE_ADD, // post op add - others); - if (out.defined()) { - return out; - } - } - auto input_size = self.sizes(); - std::vector output_size(input_size.begin(), input_size.end() - 1); - output_size.push_back(weight.size(0)); - auto output = at::empty(output_size, self.options()); - output.set_requires_grad(self.requires_grad()); - woq_linear_kernel_output( + const std::vector& others, + int64_t act_quant_mode) { + int64_t quant_w_mode = group_size > 0 ? 
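// Sketch of the post-op fusion selected by the WOQ_FUSE_* flag passed to
// woq_tpp_gemm_kernel_stub above, assuming those constants carry the same values
// as the FUSE_GELU/FUSE_ADD/FUSE_ADD_ADD defines (1/2/3) added in WoqTppKrnl.cpp:
// GELU is applied to the GEMM output, or one/two residual tensors from "others"
// are added. Plain element-wise C++ for clarity only; the kernel fuses this into
// the GEMM epilogue.
#include <cmath>
#include <vector>

void apply_woq_post_op(
    std::vector<float>& y,
    int fusion_type, // 1 = GELU, 2 = add, 3 = add-add (assumed values)
    const std::vector<std::vector<float>>& others) {
  if (fusion_type == 1) { // GELU, erf formulation
    for (float& v : y) {
      v = 0.5f * v * (1.0f + std::erf(v / std::sqrt(2.0f)));
    }
  } else if (fusion_type == 2 || fusion_type == 3) {
    const size_t n_adds = fusion_type == 2 ? 1 : 2;
    for (size_t t = 0; t < n_adds; ++t) {
      for (size_t i = 0; i < y.size(); ++i) {
        y[i] += others[t][i];
      }
    }
  }
}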
1 : 0; + return woq_tpp_gemm_kernel_stub( + kCPU, self, weight, - scales_list[0], - zps_list[0], - bias_list[0], + scales_list, + zps_list, + bias_list, + is_int4, lowp_mode, - output); - return at::add(output, others[0]); + num_concats, + WOQ_FUSE_ADD, // post op add + others, + act_quant_mode, + quant_w_mode, + group_size); } at::Tensor woq_linear_add_add_kernel( @@ -694,41 +538,27 @@ at::Tensor woq_linear_add_add_kernel( const std::vector& zps_list, const std::vector& bias_list, bool is_int4, + int64_t group_size, int64_t lowp_mode, int64_t num_concats, - const std::vector& others) { - if (weight.dim() > 2) { - auto out = woq_tpp_gemm_kernel_stub( - kCPU, - self, - weight, - scales_list, - zps_list, - bias_list, - is_int4, - lowp_mode, - num_concats, - WOQ_FUSE_ADD_ADD, // post op add-add - others); - if (out.defined()) { - return out; - } - } - auto input_size = self.sizes(); - std::vector output_size(input_size.begin(), input_size.end() - 1); - output_size.push_back(weight.size(0)); - auto output = at::empty(output_size, self.options()); - output.set_requires_grad(self.requires_grad()); - woq_linear_kernel_output( + const std::vector& others, + int64_t act_quant_mode) { + int64_t quant_w_mode = group_size > 0 ? 1 : 0; + return woq_tpp_gemm_kernel_stub( + kCPU, self, weight, - scales_list[0], - zps_list[0], - bias_list[0], + scales_list, + zps_list, + bias_list, + is_int4, lowp_mode, - output); - auto y = at::add(output, others[0]); - return at::add(y, others[1]); + num_concats, + WOQ_FUSE_ADD_ADD, // post op add-add + others, + act_quant_mode, + quant_w_mode, + group_size); } at::Tensor woq_linear_add_forward( @@ -752,6 +582,74 @@ at::Tensor woq_linear_add_add_forward( ->run_add_add(input, others); } +at::Tensor matmul_i8i8i32(const at::Tensor& input, const at::Tensor& weight) { + // x:s8 * w:s8 -> y:s32 + // No bias + TORCH_CHECK( + input.scalar_type() == c10::kChar, + "matmul_i8i8i32: input dtype should be signed int8 but found ", + input.scalar_type()); + TORCH_CHECK( + weight.scalar_type() == c10::kChar, + "matmul_i8i8i32: weight dtype should be signed int8 but found ", + weight.scalar_type()); + TORCH_CHECK( + input.dim() == 2 && weight.dim() == 2, + "matmul_i8i8i32: Expect Input and weight are 2d but got ", + input.dim(), + " and ", + weight.dim()); + TORCH_CHECK( + input.size(1) == weight.size(1), + "matmul_i8i8i32: Input shape and weight shape do not match, got ", + input.sizes(), + " and ", + weight.sizes()); + auto output_shape = input.sizes().vec(); + output_shape.back() = weight.size(0); + auto output = at::empty(output_shape, input.options().dtype(c10::kInt)); + auto input_contig = input.contiguous(); + auto weight_contig = weight.t().contiguous(); + // Create ideep tensors for oneDNN computation + auto src = ideep::tensor( + {input_contig.sizes().vec(), + ideep::tensor::data_type::s8, + input_contig.strides().vec()}, + input_contig.data_ptr()); + auto wei = ideep::tensor( + {weight_contig.sizes().vec(), + ideep::tensor::data_type::s8, + weight_contig.strides().vec()}, + weight_contig.data_ptr()); + auto dst = ideep::tensor( + {output.sizes().vec(), + ideep::tensor::data_type::s32, + output.strides().vec()}, + output.data_ptr()); + // Create primitive desc + auto engine = ideep::engine::cpu_engine(); + ideep::attr_t op_attr; + op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + auto src_desc = src.get_desc(); + auto wei_desc = wei.get_desc(); + auto dst_desc = dst.get_desc(); + auto prim_desc = dnnl::matmul::primitive_desc( + engine, src_desc, wei_desc, dst_desc, 
op_attr); + // Reorder weight + auto expected_weight = wei.reorder_if_differ_in(prim_desc.weights_desc()); + // Prepare args for primitive + ideep::tensor scratchpad(prim_desc.scratchpad_desc()); + ideep::exec_args args; + args.insert({DNNL_ARG_SRC, src}); + args.insert({DNNL_ARG_WEIGHTS, expected_weight}); + args.insert({DNNL_ARG_DST, dst}); + args.insert({DNNL_ARG_SCRATCHPAD, scratchpad}); + // Create primitve and execute + auto primitive = dnnl::matmul(prim_desc); + primitive.execute(ideep::stream::default_stream(), args); + return output; +} + } // namespace cpu } // namespace torch_ipex @@ -863,6 +761,15 @@ at::Tensor woq_linear_add_add_forward( cpu_cached_cast(target_type, others)); } +at::Tensor matmul_i8i8i32(const at::Tensor& input, const at::Tensor& weight) { + c10::impl::ExcludeDispatchKeyGuard no_autocastCPU(DispatchKey::AutocastCPU); + static auto op = torch::Dispatcher::singleton() + .findSchemaOrThrow("torch_ipex::matmul_i8i8i32", "") + .typed(); + // input is int8. Don't cast to autocast dtype + return op.call(input, weight); +} + } // namespace autocast } // namespace torch_ipex @@ -948,6 +855,14 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) { "linear_eltwise_backward", c10::DispatchKey::CPU, torch_ipex::cpu::linear_eltwise_backward); + // bnb + m.def("matmul_i8i8i32(Tensor input, Tensor weight) -> Tensor"); + m.impl( + "matmul_i8i8i32", c10::DispatchKey::CPU, torch_ipex::cpu::matmul_i8i8i32); + m.impl( + "matmul_i8i8i32", + c10::DispatchKey::AutocastCPU, + torch_ipex::autocast::matmul_i8i8i32); } } // namespace diff --git a/csrc/cpu/aten/Linear.h b/csrc/cpu/aten/Linear.h index db3900b64..3ef0febf8 100644 --- a/csrc/cpu/aten/Linear.h +++ b/csrc/cpu/aten/Linear.h @@ -81,8 +81,9 @@ at::Tensor ipex_linear_eltwise( // WOQ linear ops at::Tensor woq_linear_pack_weight( const at::Tensor& weight, - const at::Tensor& scale, - const at::Tensor& zero_points, + std::vector& weight_shape, + bool is_4bit, + int64_t group_size, int64_t lowp_mode); at::Tensor woq_linear_unpack_weight( @@ -90,15 +91,6 @@ at::Tensor woq_linear_unpack_weight( bool is_int4, int64_t lowp_mode); -void woq_linear_kernel_output( - const at::Tensor& self, - const at::Tensor& weight, - const at::Tensor& scales_float, - const at::Tensor& zero_points_float, - const at::Tensor& bias, - int64_t lowp_mode, - at::Tensor& output); - at::Tensor woq_linear_kernel( const at::Tensor& self, const at::Tensor& weight, @@ -106,20 +98,10 @@ at::Tensor woq_linear_kernel( const std::vector& zps_list, const std::vector& bias_list, bool is_int4, + int64_t group_size, int64_t lowp_mode, - int64_t num_concats); - -void woq_linear_eltwise_kernel_output( - const at::Tensor& self, - const at::Tensor& weight, - const at::Tensor& scales_float, - const at::Tensor& zero_points_float, - const at::Tensor& bias, - const c10::string_view& post_op, - const torch::List>& scalars, - const c10::optional& algorithm, - int64_t lowp_mode, - at::Tensor& output); + int64_t num_concats, + int64_t act_quant_mode); at::Tensor woq_linear_eltwise_kernel( const at::Tensor& self, @@ -131,20 +113,10 @@ at::Tensor woq_linear_eltwise_kernel( const torch::List>& scalars, const c10::optional& algorithm, bool is_int4, - int64_t lowp_mode, - int64_t num_concats); - -at::Tensor woq_linear_add_kernel( - const at::Tensor& self, - const at::Tensor& weight, - const std::vector& scales_list, - const std::vector& zps_list, - const std::vector& bias_list, - bool is_int4, + int64_t group_size, int64_t lowp_mode, int64_t num_concats, - at::Tensor& accumu, - const c10::optional& 
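// Reference-semantics sketch for the matmul_i8i8i32 op registered above:
// output[m][n] = sum_k input[m][k] * weight[n][k], with signed int8 operands,
// int32 accumulation and no bias. This plain loop only documents what the oneDNN
// matmul primitive computes; it is not the implementation.
#include <cstdint>
#include <vector>

std::vector<int32_t> matmul_i8i8i32_ref(
    const std::vector<int8_t>& input,  // [M, K], row-major
    const std::vector<int8_t>& weight, // [N, K], row-major, as passed to the op
    int64_t M, int64_t K, int64_t N) {
  std::vector<int32_t> out(M * N, 0);
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      int32_t acc = 0;
      for (int64_t k = 0; k < K; ++k) {
        acc += static_cast<int32_t>(input[m * K + k]) *
               static_cast<int32_t>(weight[n * K + k]);
      }
      out[m * N + n] = acc;
    }
  }
  return out;
}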
alpha); + int64_t act_quant_mode); at::Tensor woq_linear_add_kernel( const at::Tensor& self, @@ -153,9 +125,11 @@ at::Tensor woq_linear_add_kernel( const std::vector& zps_list, const std::vector& bias_list, bool is_int4, + int64_t group_size, int64_t lowp_mode, int64_t num_concats, - const std::vector& others); + const std::vector& others, + int64_t act_quant_mode); at::Tensor woq_linear_add_add_kernel( const at::Tensor& self, @@ -164,9 +138,11 @@ at::Tensor woq_linear_add_add_kernel( const std::vector& zps_list, const std::vector& bias_list, bool is_int4, + int64_t group_size, int64_t lowp_mode, int64_t num_concats, - const std::vector& others); + const std::vector& others, + int64_t act_quant_mode); namespace { void woq_gemm_kernel_impl( @@ -240,7 +216,10 @@ using woq_tpp_gemm_kernel_fn = at::Tensor (*)( int64_t, int64_t, int64_t, - const std::vector&); + const std::vector&, + int64_t, + int64_t, + int64_t); using woq_tpp_gemm_packB_fn = at::Tensor (*)(const at::Tensor&, bool, size_t, size_t, int64_t); diff --git a/csrc/cpu/aten/kernels/WoqLinearKrnl.cpp b/csrc/cpu/aten/kernels/WoqLinearKrnl.cpp index a861a336b..de9b84050 100644 --- a/csrc/cpu/aten/kernels/WoqLinearKrnl.cpp +++ b/csrc/cpu/aten/kernels/WoqLinearKrnl.cpp @@ -2658,6 +2658,16 @@ void woq_gemm_kernel_impl( zero_points_float_ptr); } } + } else { + auto qw = woq_linear_unpackB_impl(weight); + auto w = qw.dequantize().to(self_.scalar_type()).to(c10::kFloat); + auto x = self.to(c10::ScalarType::Float); + auto out = at::linear(x, w); + if (bias.defined()) { + auto b = bias.to(self_.scalar_type()).to(c10::kFloat); + out = at::add(out, b); + } + output = out.to(self.scalar_type()); } } else { // kPerChannelAffineFloatQParams @@ -2805,7 +2815,7 @@ void woq_gemm_kernel_impl( } else { at::linear_out(output, self, w); } - } else { + } else if (self_.scalar_type() == at::kBFloat16) { auto w = weight.dequantize(); auto x = self.to(c10::ScalarType::Float); // This is to align with the AVX512 kernel @@ -2818,6 +2828,15 @@ void woq_gemm_kernel_impl( out = at::add(out, bias); } output = out.to(self.scalar_type()); + } else { + auto w = weight.dequantize().to(self_.scalar_type()).to(c10::kFloat); + auto x = self.to(c10::ScalarType::Float); + auto out = at::linear(x, w); + if (bias.defined()) { + auto b = bias.to(self_.scalar_type()).to(c10::kFloat); + out = at::add(out, b); + } + output = out.to(self.scalar_type()); } #endif diff --git a/csrc/cpu/aten/kernels/WoqTppKrnl.cpp b/csrc/cpu/aten/kernels/WoqTppKrnl.cpp index f8554c715..81890dafb 100644 --- a/csrc/cpu/aten/kernels/WoqTppKrnl.cpp +++ b/csrc/cpu/aten/kernels/WoqTppKrnl.cpp @@ -2,6 +2,7 @@ // #include #include #include +#include #include #include #include "csrc/cpu/tpp/woq/tla.h" @@ -20,6 +21,15 @@ namespace { using namespace tpp; using TensorList = std::vector; +#define FUSE_GELU 1 +#define FUSE_ADD 2 +#define FUSE_ADD_ADD 3 + +#define LOWP_MODE_NONE 0 +#define LOWP_MODE_FP16 1 +#define LOWP_MODE_BF16 2 +#define LOWP_MODE_INT8 3 + // We only build optimized kernels if AVX512_FP16 is supported and gcc>=12.3 // Otherwise we just return empty results // TODO(Weiwen) Merge WoqTppKrnl.cpp and WoqLinearKrnl.cpp and put the latter in @@ -31,6 +41,15 @@ using TensorList = std::vector; constexpr long PREFETCH_K_DIST = 64; // TODO(jgong5): do not hard-code constexpr long LOOP_K_UNROLL = 4; // TODO(jgong5): do not hard-code +#define UNQUANT_A -1 +#define QUANT_A_PER_TENSOR 0 +#define QUANT_A_PER_K_BLOCK 1 +#define QUANT_A_PER_M 2 +#define QUANT_A_PER_M_K_BLOCK 3 + +#define QUANT_W_PER_CHANNEL 0 
+#define QUANT_W_PER_K_BLOCK 1 + template inline VAT load_dequant_zp_only_int4(uint8_t* p, VAT vzps, LUT lut) { TLA_ASSERT(false, "not implemented"); @@ -389,6 +408,7 @@ template < long ldb, bool transA = false, bool ACC = false, + int quant_a_mode = -1, long PREFETCH_K_DIST = 0, typename Enabled = void> struct GemmMicroKernel { @@ -413,6 +433,7 @@ template < long ldb, bool transA, bool ACC, + int quant_a_mode, long PREFETCH_K_DIST> struct GemmMicroKernel< T, @@ -424,6 +445,7 @@ struct GemmMicroKernel< ldb, transA, ACC, + quant_a_mode, PREFETCH_K_DIST, typename std::enable_if_t< std::is_same::value || std::is_same::value>> { @@ -563,7 +585,14 @@ struct GemmMicroKernel< }; #ifdef __AVX512VNNI__ -template +template < + long M, + long N, + long ldb, + bool transA, + bool ACC, + int quant_a_mode, + long PREFETCH_K_DIST> struct GemmMicroKernel< /*Tin*/ uint8_t, /*Tout*/ float, @@ -574,6 +603,7 @@ struct GemmMicroKernel< ldb, transA, ACC, + quant_a_mode, PREFETCH_K_DIST> { template static inline void call( @@ -585,8 +615,9 @@ struct GemmMicroKernel< long ldc, float* scales, int8_t* zps, - float scale_a, - int32_t zp_a) { + float* scale_a, + int32_t* zp_a, + int32_t k_groups) { auto pqB = GetVLAPtr(B, {ldb, 2}); // [K/4,N,4] packed in 4-bit static_assert(N % 16 == 0, "N must be a multiple of 16"); @@ -654,10 +685,32 @@ struct GemmMicroKernel< constexpr const int col = i % COLS; // compute (qC - compensate * zp_a) * scale_a * scale_b // where compensate = sum(qB) - vc[i] = _mm512_sub_epi32( - vc[i], _mm512_mullo_epi32(vcompensate[col], _mm512_set1_epi32(zp_a))); - __m512 vc_float = _mm512_cvtepi32_ps(vc[i]); - vc_float = _mm512_mul_ps(vc_float, _mm512_set1_ps(scale_a)); + __m512 vc_float; + if constexpr ( + quant_a_mode == QUANT_A_PER_TENSOR || + quant_a_mode == QUANT_A_PER_K_BLOCK) { + vc[i] = _mm512_sub_epi32( + vc[i], + _mm512_mullo_epi32(vcompensate[col], _mm512_set1_epi32(*zp_a))); + vc_float = _mm512_cvtepi32_ps(vc[i]); + vc_float = _mm512_mul_ps(vc_float, _mm512_set1_ps(*scale_a)); + } else if constexpr (quant_a_mode == QUANT_A_PER_M) { + vc[i] = _mm512_sub_epi32( + vc[i], + _mm512_mullo_epi32( + vcompensate[col], _mm512_set1_epi32(*(zp_a + row)))); + vc_float = _mm512_cvtepi32_ps(vc[i]); + vc_float = _mm512_mul_ps(vc_float, _mm512_set1_ps(*(scale_a + row))); + } else { + vc[i] = _mm512_sub_epi32( + vc[i], + _mm512_mullo_epi32( + vcompensate[col], _mm512_set1_epi32(*(zp_a + row * k_groups)))); + vc_float = _mm512_cvtepi32_ps(vc[i]); + vc_float = _mm512_mul_ps( + vc_float, _mm512_set1_ps(*(scale_a + row * k_groups))); + } + vc_float = _mm512_mul_ps(vc_float, vscales[col]); if constexpr (ACC) { auto vc_old = _mm512_loadu_ps(C + row * ldc + col * 16); @@ -1055,6 +1108,7 @@ template < bool transA, bool ACC, bool is_int4, + int quant_a_mode, long PREFETCH_K_DIST = 0> class DequantGemmTPP { public: @@ -1069,8 +1123,9 @@ class DequantGemmTPP { TZero* zps, Tout* C, bool no_tile_cfg = true, - float scale_a = 1.0, - int32_t zp_a = 0) { + float* scale_a = nullptr, + int32_t* zp_a = nullptr, + int32_t k_groups = -1) { TLA_ASSERT(false, "not implemented"); } @@ -1092,6 +1147,7 @@ template < bool transA, bool ACC, bool is_int4, + int quant_a_mode, long PREFETCH_K_DIST> class DequantGemmTPP< Tin, @@ -1104,6 +1160,7 @@ class DequantGemmTPP< transA, ACC, is_int4, + quant_a_mode, PREFETCH_K_DIST> { public: DequantGemmTPP(long M, long K, long lda, long ldc) @@ -1133,8 +1190,9 @@ class DequantGemmTPP< Tin* zps, Tout* C, bool no_tile_cfg = true, - float scale_a = 1.0, - int32_t zp_a = 0) { + float* scale_a = 
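// Scalar sketch of the dequantization math used by the AVX512-VNNI microkernel
// above. With an asymmetric uint8 activation (scale_a, zp_a) and an int8 weight
// column with per-channel scale_b, the fp32 dot product is recovered as
//   scale_a * scale_b * (sum_k qa[k]*qb[k] - zp_a * sum_k qb[k]),
// where sum_k qb[k] is the "compensate" term accumulated alongside the int32
// GEMM. Plain C++ for clarity; the kernel does this with _mm512 intrinsics and,
// depending on quant_a_mode, selects per-row or per-K-block scale_a / zp_a.
#include <cstdint>
#include <vector>

float dequant_dot(
    const std::vector<uint8_t>& qa, // quantized activation row, length K
    const std::vector<int8_t>& qb,  // quantized weight column, length K
    float scale_a, int32_t zp_a, float scale_b) {
  int32_t acc = 0;
  int32_t compensate = 0; // sum of quantized weights along K
  for (size_t k = 0; k < qa.size(); ++k) {
    acc += static_cast<int32_t>(qa[k]) * static_cast<int32_t>(qb[k]);
    compensate += static_cast<int32_t>(qb[k]);
  }
  return static_cast<float>(acc - compensate * zp_a) * scale_a * scale_b;
}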
nullptr, + int32_t* zp_a = nullptr, + int32_t k_groups = -1) { if (M < SMALL_BATCH_THRESHOLD && ((std::is_same() && std::is_same()) || (std::is_same() && std::is_same()))) { @@ -1153,6 +1211,7 @@ class DequantGemmTPP< ldb, transA, ACC, + quant_a_mode, PREFETCH_K_DIST>:: template call( K, @@ -1178,6 +1237,7 @@ class DequantGemmTPP< ldb, transA, ACC, + quant_a_mode, PREFETCH_K_DIST>:: template call( K, @@ -1228,6 +1288,7 @@ template < long ldb, bool transA, bool ACC, + int quant_a_mode, long PREFETCH_K_DIST> class DequantGemmTPP< /*Tin*/ uint8_t, @@ -1240,6 +1301,7 @@ class DequantGemmTPP< transA, ACC, /*is_int4*/ true, + quant_a_mode, PREFETCH_K_DIST> { using TBrgemmTPP = BrgemmTPP; @@ -1271,8 +1333,9 @@ class DequantGemmTPP< int8_t* zps, float* C, bool no_tile_cfg = true, - float scale_a = 1.0, - int32_t zp_a = 0) { + float* scale_a = nullptr, + int32_t* zp_a = nullptr, + int32_t k_groups = -1) { auto qA = GetVLAPtr(A, {lda}); #ifdef __AVX512VNNI__ if (M < SMALL_BATCH_THRESHOLD) { @@ -1280,6 +1343,17 @@ class DequantGemmTPP< BLOCK_M * N / 16 >= 16 ? BLOCK_M / 2 : BLOCK_M; for (long m = 0; m < M; m += PREFERRED_BLOCK_M) { long block_m = std::min(M - m, PREFERRED_BLOCK_M); + float* scale_a_m; + int32_t* zp_a_m; + if constexpr ( + quant_a_mode == QUANT_A_PER_M || + quant_a_mode == QUANT_A_PER_M_K_BLOCK) { + scale_a_m = scale_a + m * k_groups; + zp_a_m = zp_a + m * k_groups; + } else { + scale_a_m = scale_a; + zp_a_m = zp_a; + } enumerate_dispatcher::call( block_m, [&](auto i) { @@ -1293,6 +1367,7 @@ class DequantGemmTPP< ldb, /*transA*/ false, ACC, + quant_a_mode, PREFETCH_K_DIST>:: template call( K, @@ -1303,8 +1378,9 @@ class DequantGemmTPP< ldc, scales, zps, - scale_a, - zp_a); + scale_a_m, + zp_a_m, + k_groups); }, [&](auto i) { range_dispatcher::call( @@ -1320,6 +1396,7 @@ class DequantGemmTPP< ldb, /*transA*/ false, ACC, + quant_a_mode, PREFETCH_K_DIST>:: template call( K, @@ -1330,8 +1407,9 @@ class DequantGemmTPP< ldc, scales, zps, - scale_a, - zp_a); + scale_a_m, + zp_a_m, + k_groups); }, [&](auto j) { failing_fallback(); }); }); @@ -1351,7 +1429,19 @@ class DequantGemmTPP< for (long m = 0; m < M; ++m) { #pragma omp simd for (long n = 0; n < N; ++n) { - float c = (qC[m][n] - compensation[n] * zp_a) * scale_a * scales[n]; + float* scale_a_m; + int32_t* zp_a_m; + if constexpr ( + quant_a_mode == QUANT_A_PER_M || + quant_a_mode == QUANT_A_PER_M_K_BLOCK) { + scale_a_m = scale_a + m * k_groups; + zp_a_m = zp_a + m * k_groups; + } else { + scale_a_m = scale_a; + zp_a_m = zp_a; + } + float c = (qC[m][n] - compensation[n] * (*zp_a_m)) * (*scale_a_m) * + scales[n]; if constexpr (ACC) { C[m * ldc + n] += c; } else { @@ -1382,10 +1472,6 @@ class DequantGemmTPP< long ldc; }; -#define FUSE_GELU 1 -#define FUSE_ADD 2 -#define FUSE_ADD_ADD 3 - // If T != TComp // T -> TComp -> GEMM -> TComp -> bias/PostOp -> Tout // If T == TComp (we can save intermediate output buffer and schedule M/N/K @@ -1397,21 +1483,24 @@ template < typename TGemmOut, typename Tout, typename TScale, - typename TZero> + typename TZero, + int quant_a_mode = -1, + int quant_w_mode = 0> void qlinear_woq_affine_impl( const at::Tensor& x, const at::Tensor& qw_packed, const at::Tensor& scales, // dtype is TComp const at::Tensor& zps, // dtype is TComp const at::Tensor& b, // dtype is TComp - at::Tensor y, + at::Tensor& y, bool is_int4, int k_splits, int num_concats, int fusion_type, const TensorList& others_list, - float scale_a = 1.0f, - int32_t zp_a = 0) { + int64_t quant_block_k, + at::Tensor t_scale_a = at::empty({1}, 
at::kFloat), + at::Tensor t_zp_a = at::empty({1}, at::kInt)) { auto x_sizes = x.sizes(); auto w_sizes = qw_packed.sizes(); auto M = x_sizes[0]; @@ -1421,6 +1510,10 @@ void qlinear_woq_affine_impl( auto Kb = w_sizes[2]; auto N = Nc * Nb; auto K = Kc * Kb; + assert(quant_block_k % Kb == 0); + auto quant_block_multiple = quant_block_k == 0 ? 1 : quant_block_k / Kb; + auto quant_k_blocks = + quant_block_k == 0 ? 1 : (K + quant_block_k - 1) / quant_block_k; TLA_ASSERT(Nb % 16 == 0, "Nb must be a multiple of 16"); TLA_ASSERT( @@ -1458,14 +1551,18 @@ void qlinear_woq_affine_impl( auto ldy = num_concats <= 1 ? N : Nc / num_concats * Nb; auto ldc = (no_y_buf || k_splits > 1) ? ldy : Nb; + auto scales_a_ptr = t_scale_a.data_ptr(); + auto zps_a_ptr = t_zp_a.data_ptr(); auto px = GetVLAPtr(x, {Kc, Kb}); auto pw = GetVLAPtr( (uint8_t*)qw_packed.data_ptr(), {Kc, Kb * (is_int4 ? Nb / 2 : Nb)}); auto py = GetVLAPtr(y, {Nc, Nb}); /*[M, Nc, Nb]*/ auto py_concat = GetVLAPtr( y, {M, Nc / num_concats, Nb}); /*[num_concats, M, Nc/num_concats, Nb]*/ - auto pscales = GetVLAPtr(scales, {Nb}); - auto pzps = GetVLAPtr(zps, {Nb}); + int scales_kc = quant_w_mode == QUANT_W_PER_CHANNEL ? QUANT_W_PER_K_BLOCK + : quant_k_blocks; + auto pscales = GetVLAPtr(scales, {scales_kc, Nb}); + auto pzps = GetVLAPtr(zps, {scales_kc, Nb}); auto pb = GetVLAPtr(b, {Nb}); auto tin0 = others_list.size() > 0 ? others_list[0] : at::Tensor{}; auto pin0 = GetVLAPtr(tin0, {Nc, Nb}); /*[M, Nc, Nb]*/ @@ -1559,6 +1656,7 @@ void qlinear_woq_affine_impl( /*transA*/ false, /*ACC*/ true, is_int4, + quant_a_mode, PREFETCH_K_DIST>( /*M*/ BLOCK_M, /*K*/ Kb, @@ -1575,6 +1673,7 @@ void qlinear_woq_affine_impl( /*transA*/ false, /*ACC*/ true, is_int4, + quant_a_mode, 0>( /*M*/ BLOCK_M, /*K*/ Kb, @@ -1591,6 +1690,7 @@ void qlinear_woq_affine_impl( /*transA*/ false, /*ACC*/ true, is_int4, + quant_a_mode, PREFETCH_K_DIST>( /*M*/ BLOCK_M_rem, /*K*/ Kb, @@ -1607,6 +1707,7 @@ void qlinear_woq_affine_impl( /*transA*/ false, /*ACC*/ true, is_int4, + quant_a_mode, 0>( /*M*/ BLOCK_M_rem, /*K*/ Kb, @@ -1648,6 +1749,38 @@ void qlinear_woq_affine_impl( int m = idx[0]; int kc = idx[1]; int nc = idx[2]; + float* scale_a = nullptr; + int32_t* zp_a = nullptr; + int32_t k_groups = -1; + int32_t quant_offset = kc / quant_block_multiple; + if constexpr (std::is_same()) { + if constexpr (quant_a_mode == QUANT_A_PER_TENSOR) { + scale_a = scales_a_ptr; + zp_a = zps_a_ptr; + } else if constexpr ( + quant_a_mode == QUANT_A_PER_K_BLOCK) { + scale_a = scales_a_ptr + quant_offset; + zp_a = zps_a_ptr + quant_offset; + } else if constexpr (quant_a_mode == QUANT_A_PER_M) { + scale_a = scales_a_ptr + m; + zp_a = zps_a_ptr + m; + k_groups = 1; + } else { + scale_a = + scales_a_ptr + m * quant_k_blocks + quant_offset; + zp_a = zps_a_ptr + m * quant_k_blocks + quant_offset; + k_groups = quant_k_blocks; + } + } + TScale* scale_w = nullptr; + TZero* zp_w = nullptr; + if constexpr (quant_w_mode == QUANT_W_PER_CHANNEL) { + scale_w = pscales[nc][0]; + zp_w = pzps[nc][0]; + } else { + scale_w = pscales[nc][quant_offset]; + zp_w = pzps[nc][quant_offset]; + } bool is_rem = (m + BLOCK_M > M); TGemmOut* y_ptr = num_concats <= 1 ? 
(TGemmOut*)py[m][nc] @@ -1666,22 +1799,24 @@ void qlinear_woq_affine_impl( dequant_gemm_tpp( x_ptr, pw[nc][kc], - pscales[nc], - pzps[nc], + scale_w, + zp_w, y_ptr, true, scale_a, - zp_a); + zp_a, + k_groups); } else { dequant_gemm_no_prefetch_tpp( x_ptr, pw[nc][kc], - pscales[nc], - pzps[nc], + scale_w, + zp_w, y_ptr, true, scale_a, - zp_a); + zp_a, + k_groups); if (fusion_type > 0) { post_ops_fn(m, nc); } @@ -1699,23 +1834,25 @@ void qlinear_woq_affine_impl( dequant_gemm_rem_tpp( x_ptr, pw[nc][kc], - pscales[nc], - pzps[nc], + scale_w, + zp_w, y_ptr, false, scale_a, - zp_a); + zp_a, + k_groups); dequant_gemm_tpp.config(); } else { dequant_gemm_no_prefetch_rem_tpp( x_ptr, pw[nc][kc], - pscales[nc], - pzps[nc], + scale_w, + zp_w, y_ptr, false, scale_a, - zp_a); + zp_a, + k_groups); dequant_gemm_no_prefetch_tpp.config(); if (fusion_type > 0) { post_ops_rem_fn(m, nc); @@ -1790,6 +1927,38 @@ void qlinear_woq_affine_impl( } for (int kc = kc_start; kc < kc_end; kc++) { TComp* x_ptr = (TComp*)px[m][kc]; + float* scale_a = nullptr; + int32_t* zp_a = nullptr; + int32_t k_groups = -1; + int32_t quant_offset = kc / quant_block_multiple; + if constexpr (std::is_same()) { + if constexpr (quant_a_mode == QUANT_A_PER_TENSOR) { + scale_a = scales_a_ptr; + zp_a = zps_a_ptr; + } else if constexpr ( + quant_a_mode == QUANT_A_PER_K_BLOCK) { + scale_a = scales_a_ptr + quant_offset; + zp_a = zps_a_ptr + quant_offset; + } else if constexpr (quant_a_mode == QUANT_A_PER_M) { + scale_a = scales_a_ptr + m; + zp_a = zps_a_ptr + m; + k_groups = 1; + } else { + scale_a = + scales_a_ptr + m * quant_k_blocks + quant_offset; + zp_a = zps_a_ptr + m * quant_k_blocks + quant_offset; + k_groups = quant_k_blocks; + } + } + TScale* scale_w = nullptr; + TZero* zp_w = nullptr; + if constexpr (quant_w_mode == QUANT_W_PER_CHANNEL) { + scale_w = pscales[nc][0]; + zp_w = pzps[nc][0]; + } else { + scale_w = pscales[nc][quant_offset]; + zp_w = pzps[nc][quant_offset]; + } if (!is_rem) { alignas(64) TComp x_buf[BLOCK_M][Kb]; if (!no_x_buf) { @@ -1800,22 +1969,24 @@ void qlinear_woq_affine_impl( dequant_gemm_tpp( x_ptr, pw[nc][kc], - pscales[nc], - pzps[nc], + scale_w, + zp_w, y_ptr, true, scale_a, - zp_a); + zp_a, + k_groups); } else { dequant_gemm_no_prefetch_tpp( x_ptr, pw[nc][kc], - pscales[nc], - pzps[nc], + scale_w, + zp_w, y_ptr, true, scale_a, - zp_a); + zp_a, + k_groups); } } else { alignas(64) TComp x_buf[BLOCK_M][Kb]; @@ -1827,23 +1998,25 @@ void qlinear_woq_affine_impl( dequant_gemm_rem_tpp( x_ptr, pw[nc][kc], - pscales[nc], - pzps[nc], + scale_w, + zp_w, y_ptr, false, scale_a, - zp_a); + zp_a, + k_groups); dequant_gemm_tpp.config(); } else { dequant_gemm_no_prefetch_rem_tpp( x_ptr, pw[nc][kc], - pscales[nc], - pzps[nc], + scale_w, + zp_w, y_ptr, false, scale_a, - zp_a); + zp_a, + k_groups); dequant_gemm_no_prefetch_tpp.config(); } } @@ -1898,11 +2071,6 @@ void qlinear_woq_affine_impl( [](auto tuple) { failing_fallback(); }); } -#define LOWP_MODE_NONE 0 -#define LOWP_MODE_FP16 1 -#define LOWP_MODE_BF16 2 -#define LOWP_MODE_INT8 3 - /** * @brief pack the weight in quantized format. * @param qw quantized weight with shape [N, K] @@ -2084,6 +2252,183 @@ void compute_int8_qparams_per_tensor( *zp = (int32_t)(-std::nearbyint(min / *scale)); } +template +inline scalar_t max_propagate_nan(scalar_t a, scalar_t b) { + if (at::_isnan(a)) { + return a; + } + return a > b ? a : b; +} + +template +inline scalar_t min_propagate_nan(scalar_t a, scalar_t b) { + if (at::_isnan(a)) { + return a; + } + return a < b ? 
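// Illustrative index arithmetic for the activation scale/zero-point lookup in the
// GEMM loop above. The packed weight is blocked along K in chunks of Kb, while
// activation quantization is blocked in chunks of quant_block_k (a multiple of
// Kb), so weight block kc maps to quantization group kc / (quant_block_k / Kb).
// Mode values follow the QUANT_A_* macros; the helper is not part of the patch
// and only documents the layout the kernel indexes directly.
#include <cstdint>

int64_t activation_qparam_index(
    int quant_a_mode, // 0 per-tensor, 1 per-K-block, 2 per-M, 3 per-M-K-block
    int64_t m,        // row index
    int64_t kc,       // weight K-block index (blocks of size Kb)
    int64_t Kb,
    int64_t quant_block_k,
    int64_t K) {
  const int64_t quant_block_multiple = quant_block_k / Kb;
  const int64_t quant_k_blocks = (K + quant_block_k - 1) / quant_block_k;
  const int64_t quant_offset = kc / quant_block_multiple;
  switch (quant_a_mode) {
    case 0: return 0;                                 // one value for all
    case 1: return quant_offset;                      // per K block
    case 2: return m;                                 // per row
    case 3: return m * quant_k_blocks + quant_offset; // per row, per K block
    default: return 0;
  }
}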
a : b; +} + +template +std::pair compute_int8_qparams_per_block( + const at::Tensor& t, + int quant_block_k, + int quant_a_mode) { + int M = t.size(0); + int K = t.size(1); + if (quant_a_mode == QUANT_A_PER_M) { + auto grouped_min = std::get<0>(t.min(-1)); + auto grouped_max = std::get<0>(t.max(-1)); + auto zeros = at::zeros_like(grouped_min); + auto min = at::minimum(grouped_min, zeros); + auto max = at::maximum(grouped_max, zeros); + auto scales = (max - min) / 255; + auto zps = -at::round(min / scales); + return std::make_pair( + std::move(scales.to(c10::kFloat)), std::move(zps.to(c10::kInt))); + } + int k_rem = K % quant_block_k; + int block_k = quant_block_k; + auto grouped = + t.index({at::indexing::Slice(), at::indexing::Slice(0, K - k_rem)}) + .view({M, K / quant_block_k, quant_block_k}); + at::Tensor grouped_min, grouped_max; + if (quant_a_mode == QUANT_A_PER_K_BLOCK) { + grouped_min = std::get<0>(std::get<0>(grouped.min(-1)).min(0)); + grouped_max = std::get<0>(std::get<0>(grouped.max(-1)).max(0)); + } else { + grouped_min = std::get<0>(grouped.min(-1)); + grouped_max = std::get<0>(grouped.max(-1)); + } + auto zeros = at::zeros_like(grouped_min); + auto min = at::minimum(grouped_min, zeros); + auto max = at::maximum(grouped_max, zeros); + auto scales = (max - min) / 255.0f; + auto zps = -at::round(min / scales); + if (k_rem) { + auto grouped_rem = + t.index({at::indexing::Slice(), at::indexing::Slice(K - k_rem, K)}) + .view({M, 1, k_rem}); + at::Tensor grouped_rem_min, grouped_rem_max; + if (quant_a_mode == QUANT_A_PER_K_BLOCK) { + grouped_rem_min = std::get<0>(std::get<0>(grouped_rem.min(-1)).min(0)); + grouped_rem_max = std::get<0>(std::get<0>(grouped_rem.max(-1)).max(0)); + } else { + grouped_rem_min = std::get<0>(grouped_rem.min(-1)); + grouped_rem_max = std::get<0>(grouped_rem.max(-1)); + } + auto min_rem = at::minimum(grouped_rem_min, at::tensor({0})); + auto max_rem = at::maximum(grouped_rem_max, at::tensor({0})); + auto scales_rem = (max_rem - min_rem) / 255; + auto zps_rem = -at::round(min_rem / scales_rem); + scales = at::cat({scales, scales_rem}, 1).contiguous(); + zps = at::cat({zps, zps_rem}, 1).contiguous(); + } + return std::make_pair( + std::move(scales.to(c10::kFloat)), std::move(zps.to(c10::kInt))); +} + +template <> +std::pair compute_int8_qparams_per_block( + const at::Tensor& t, + int quant_block_k, + int quant_a_mode) { + auto in_ptr = t.data_ptr(); + int M = t.size(0); + int K = t.size(1); + int Kc = (K + quant_block_k - 1) / quant_block_k; + auto vecsize = at::vec::Vectorized::size(); + at::Tensor scales, zps; + if (quant_a_mode == QUANT_A_PER_K_BLOCK) { + scales = at::empty({Kc}, t.options().dtype(at::kFloat)); + zps = at::empty({Kc}, t.options().dtype(at::kInt)); + } else if (quant_a_mode == QUANT_A_PER_M) { + scales = at::empty({M}, t.options().dtype(at::kFloat)); + zps = at::empty({M}, t.options().dtype(at::kInt)); + } else { + scales = at::empty({M, Kc}, t.options().dtype(at::kFloat)); + zps = at::empty({M, Kc}, t.options().dtype(at::kInt)); + } + auto scales_ptr = scales.data_ptr(); + auto zps_ptr = zps.data_ptr(); + auto compute_minmax = [vecsize, scales_ptr, zps_ptr]( + at::BFloat16* ptr, + int M, + int K, + int scale_offset, + int zp_offset, + int ld) { + float min_val = std::numeric_limits::infinity(); + float max_val = -std::numeric_limits::infinity(); + auto in_ptr_ = ptr; + auto min_vec = at::vec::Vectorized(min_val); + auto max_vec = at::vec::Vectorized(max_val); + for (int m = 0; m < M; m++) { + auto in_ptr0 = in_ptr_; + int k; + for (k = 0; 
k < K / vecsize * vecsize; k += vecsize) { + auto tmp0 = at::vec::Vectorized::loadu(in_ptr0, vecsize); + at::vec::Vectorized res_vec1(0); + at::vec::Vectorized res_vec2(0); + std::tie(res_vec1, res_vec2) = at::vec::convert_bfloat16_float(tmp0); + auto tmp1 = res_vec1; + min_vec = at::vec::minimum(min_vec, tmp1); + max_vec = at::vec::maximum(tmp1, max_vec); + in_ptr0 += vecsize; + } + for (; k < K; k++) { + auto tmp0 = in_ptr0[k]; + min_val = std::min(min_val, (float)tmp0); + max_val = std::max(max_val, (float)tmp0); + } + in_ptr_ += ld; + } + min_val = min_propagate_nan( + min_val, + at::vec::vec_reduce_all( + [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + return at::vec::minimum(x, y); + }, + min_vec)); + max_val = max_propagate_nan( + max_val, + at::vec::vec_reduce_all( + [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + return at::vec::maximum(x, y); + }, + max_vec)); + scales_ptr[scale_offset] = (max_val - min_val) / 255.0f; + zps_ptr[zp_offset] = + (int32_t)(-std::nearbyint(min_val / scales_ptr[scale_offset])); + }; + if (quant_a_mode == QUANT_A_PER_K_BLOCK) { +#pragma omp parallel for + for (int kc = 0; kc < Kc; kc++) { + int offset = kc * quant_block_k; + int block_k = std::min(quant_block_k, K - offset); + compute_minmax(in_ptr + offset, M, block_k, kc, kc, K); + } + } else if (quant_a_mode == QUANT_A_PER_M) { +#pragma omp parallel for + for (int m = 0; m < M; m++) { + int offset = m * K; + compute_minmax(in_ptr + offset, 1, K, m, m, K); + } + } else { +#pragma omp parallel for collapse(2) + for (int m = 0; m < M; m++) { + for (int kc = 0; kc < Kc; kc++) { + auto in_ptr0 = in_ptr + m * K + kc * quant_block_k; + auto scale_offset = m * Kc + kc; + auto zp_offset = m * Kc + kc; + int block_k = std::min(quant_block_k, K - kc * quant_block_k); + compute_minmax(in_ptr0, 1, block_k, scale_offset, zp_offset, K); + } + } + } + return std::make_pair( + std::move(scales), std::move(zps)); +} + template at::Tensor quantize_per_tensor(const at::Tensor& t, float scale, int32_t zp) { // TODO(jgong5): optimize me @@ -2169,9 +2514,9 @@ at::Tensor quantize_per_tensor( i0 += static_cast(1)) { auto tmp0 = in_ptr0[static_cast(i0)]; auto tmp1 = static_cast(tmp0); - auto tmp2 = static_cast(0.05); + auto tmp2 = static_cast(scale); auto tmp3 = tmp1 / tmp2; - auto tmp4 = static_cast(1.0); + auto tmp4 = static_cast(zp); auto tmp5 = tmp3 + tmp4; auto tmp6 = std::nearbyint(tmp5); auto tmp7 = static_cast(tmp6); @@ -2199,6 +2544,144 @@ at::Tensor quantize_per_tensor( #endif } +template +at::Tensor quantize_per_block( + const at::Tensor& t, + const at::Tensor& scale, + const at::Tensor& zp, + int quant_block_k, + int quant_a_mode) { + int block_k = quant_block_k; + auto grouped = t.view({-1, t.size(-1) / block_k, block_k}); + at::Tensor out; + if (quant_a_mode == QUANT_A_PER_K_BLOCK) { + out = at::clamp( + at::round(grouped / scale.unsqueeze(1)) + zp.unsqueeze(1), 0, 255); + } else if (quant_a_mode == QUANT_A_PER_M) { + out = at::clamp( + at::round(grouped / scale.unsqueeze(1).unsqueeze(2)) + + zp.unsqueeze(1).unsqueeze(2), + 0, + 255); + } else { + out = at::clamp( + at::round(grouped / scale.unsqueeze(-1)) + zp.unsqueeze(-1), 0, 255); + } + return out.to(at::kByte); +} + +template <> +at::Tensor quantize_per_block( + const at::Tensor& t, + const at::Tensor& scale, + const at::Tensor& zp, + int quant_block_k, + int quant_a_mode) { + // t is shape of [M, K] and contiguous tensor + int64_t M = t.size(0); + int64_t K = t.size(1); + at::Tensor out = at::empty_like(t, at::kByte); + int Kc = (K + 
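// Scalar form of the per-tensor activation quantization that the fix above
// restores (the previous code had the scale and zero point hard-coded as 0.05 and
// 1.0): q = clamp(round(x / scale + zp), 0, 255). Simplified sketch; the kernel
// version is vectorized and handles bf16 inputs.
#include <algorithm>
#include <cmath>
#include <cstdint>

uint8_t quantize_value(float x, float scale, int32_t zp) {
  const float q = std::nearbyint(x / scale + static_cast<float>(zp));
  return static_cast<uint8_t>(std::min(255.0f, std::max(0.0f, q)));
}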
quant_block_k - 1) / quant_block_k; + auto scale_ptr = scale.data_ptr(); + auto zp_ptr = zp.data_ptr(); + auto in_ptr = t.data_ptr(); + auto out_ptr = out.data_ptr(); + auto vecsize = at::vec::Vectorized::size(); + auto quantize_block = [vecsize]( + at::BFloat16* in_ptr, + uint8_t* out_ptr, + int block_k, + float scale_, + int zp_) { + int k; + for (k = 0; k < block_k / vecsize * vecsize; k += vecsize) { + auto in_ptr0 = in_ptr + k; + auto out_ptr0 = out_ptr + k; + auto tmp0 = at::vec::Vectorized::loadu(in_ptr0, vecsize); + at::vec::Vectorized res_vec1(0); + at::vec::Vectorized res_vec2(0); + std::tie(res_vec1, res_vec2) = at::vec::convert_bfloat16_float(tmp0); + auto tmp1 = res_vec1; + auto tmp2 = at::vec::Vectorized(static_cast(scale_)); + auto tmp3 = tmp1 / tmp2; + auto tmp4 = at::vec::Vectorized(static_cast(zp_)); + auto tmp5 = tmp3 + tmp4; + auto tmp6 = tmp5.round(); + auto tmp7 = (tmp6); + auto tmp8 = at::vec::Vectorized(static_cast(0.0)); + auto tmp9 = at::vec::maximum(tmp7, tmp8); + auto tmp10 = at::vec::Vectorized(static_cast(255.0)); + auto tmp11 = at::vec::minimum(tmp9, tmp10); + auto tmp12 = (tmp11); + auto tmp13 = at::vec::convert_float_to_uint8(tmp12); + tmp13.store(out_ptr0, vecsize); + } + for (; k < block_k; k++) { + auto tmp0 = in_ptr[k]; + auto tmp1 = static_cast(tmp0); + auto tmp2 = static_cast(scale_); + auto tmp3 = tmp1 / tmp2; + auto tmp4 = static_cast(zp_); + auto tmp5 = tmp3 + tmp4; + auto tmp6 = std::nearbyint(tmp5); + auto tmp7 = static_cast(tmp6); + auto tmp8 = static_cast(0.0); + auto tmp9 = 0; + if (at::_isnan(tmp7)) { + tmp9 = tmp7; + } + tmp9 = tmp7 > tmp8 ? tmp7 : tmp8; + auto tmp10 = static_cast(255.0); + auto tmp11 = 0; + if (at::_isnan(tmp9)) { + tmp11 = tmp9; + } + tmp11 = tmp9 < tmp10 ? tmp9 : tmp10; + auto tmp12 = static_cast(tmp11); + auto tmp13 = static_cast(tmp12); + out_ptr[k] = tmp13; + } + }; + if (quant_a_mode == QUANT_A_PER_K_BLOCK) { +#pragma omp parallel for collapse(2) + for (int m = 0; m < M; m++) { + for (int kc = 0; kc < Kc; kc++) { + auto in_ptr0 = in_ptr + m * K + kc * quant_block_k; + auto out_ptr0 = out_ptr + m * K + kc * quant_block_k; + auto scale_ = scale_ptr[kc]; + auto zp_ = zp_ptr[kc]; + int block_k = std::min(quant_block_k, (int)K - kc * quant_block_k); + quantize_block(in_ptr0, out_ptr0, block_k, scale_, zp_); + } + } + } else if (quant_a_mode == QUANT_A_PER_M) { +#pragma omp parallel for collapse(2) + for (int m = 0; m < M; m++) { + for (int kc = 0; kc < Kc; kc++) { + auto in_ptr0 = in_ptr + m * K + kc * quant_block_k; + auto out_ptr0 = out_ptr + m * K + kc * quant_block_k; + auto scale_ = scale_ptr[m]; + auto zp_ = zp_ptr[m]; + int block_k = std::min(quant_block_k, (int)K - kc * quant_block_k); + quantize_block(in_ptr0, out_ptr0, block_k, scale_, zp_); + } + } + } else { +#pragma omp parallel for collapse(2) + for (int m = 0; m < M; m++) { + for (int kc = 0; kc < Kc; kc++) { + auto in_ptr0 = in_ptr + m * K + kc * quant_block_k; + auto out_ptr0 = out_ptr + m * K + kc * quant_block_k; + auto scale_ = scale_ptr[m * Kc + kc]; + auto zp_ = zp_ptr[m * Kc + kc]; + int block_k = std::min(quant_block_k, (int)K - kc * quant_block_k); + quantize_block(in_ptr0, out_ptr0, block_k, scale_, zp_); + } + } + } + return out; +} + /** * @brief quantized linear with weight in affine quantized format (scale + * zero-point) but activation in floating point format. 
@@ -2227,7 +2710,10 @@ at::Tensor qlinear_woq_affine( int64_t lowp_mode, int64_t num_concats, int64_t fusion_type, - const TensorList& others_list) { + const TensorList& others_list, + int64_t quant_a_mode = -1, + int64_t quant_w_mode = 0, + int64_t quant_block_k = 0) { const int64_t k_splits = 0; // int8_idx is only valid with zp_list when lowp_mode == LOWP_MODE_INT8 constexpr size_t fp32_idx = 0, fp16_idx = 1, bf16_idx = 2, int8_idx = 3; @@ -2246,10 +2732,20 @@ at::Tensor qlinear_woq_affine( out_sizes.back() = N; auto y = at::empty(out_sizes, x.options()); auto x_reshape = x.reshape({M, K}); - enumerate_dispatcher:: + product_dispatcher< + std::tuple, + std::tuple< + enumerate_dispatcher< + at::ScalarType, + at::kFloat, + at::kBFloat16, + at::kHalf>, + range_dispatcher>>:: call( - x.scalar_type(), - [&](auto act_dtype) { + std::make_tuple(x.scalar_type(), quant_w_mode), + [&](auto tuple) { + auto act_dtype = std::get<0>(tuple); + auto quant_w_mode_ = std::get<1>(tuple); using act_type = typename c10::impl::ScalarTypeToCPPType::type; auto try_compute_in_half = [&]() { @@ -2260,7 +2756,9 @@ at::Tensor qlinear_woq_affine( /*TGemmOut*/ half, act_type, half, - half>( + half, + UNQUANT_A, + quant_w_mode_>( x_reshape, qw, scales_list[fp16_idx], @@ -2271,7 +2769,8 @@ at::Tensor qlinear_woq_affine( k_splits, num_concats, fusion_type, - others_list); + others_list, + quant_block_k); #else qlinear_woq_affine_impl< act_type, @@ -2279,7 +2778,9 @@ at::Tensor qlinear_woq_affine( /*TGemmOut*/ float, act_type, float, - float>( + float, + UNQUANT_A, + quant_w_mode_>( x_reshape, qw, scales_list[fp32_idx], @@ -2290,7 +2791,8 @@ at::Tensor qlinear_woq_affine( k_splits, num_concats, fusion_type, - others_list); + others_list, + quant_block_k); #endif }; if (lowp_mode == LOWP_MODE_NONE) { @@ -2303,7 +2805,9 @@ at::Tensor qlinear_woq_affine( /*TGemmOut*/ float, bfloat16, bfloat16, - bfloat16>( + bfloat16, + UNQUANT_A, + quant_w_mode_>( x_reshape, qw, scales_list[bf16_idx], @@ -2314,7 +2818,8 @@ at::Tensor qlinear_woq_affine( k_splits, num_concats, fusion_type, - others_list); + others_list, + quant_block_k); } else { qlinear_woq_affine_impl< float, @@ -2322,7 +2827,9 @@ at::Tensor qlinear_woq_affine( /*TGemmOut*/ float, float, float, - float>( + float, + UNQUANT_A, + quant_w_mode_>( x_reshape, qw, scales_list[fp32_idx], @@ -2333,7 +2840,8 @@ at::Tensor qlinear_woq_affine( k_splits, num_concats, fusion_type, - others_list); + others_list, + quant_block_k); } } else if (lowp_mode == LOWP_MODE_FP16) { try_compute_in_half(); @@ -2346,7 +2854,9 @@ at::Tensor qlinear_woq_affine( /*TGemmOut*/ float, act_type, bfloat16, - bfloat16>( + bfloat16, + UNQUANT_A, + quant_w_mode_>( x_reshape, qw, scales_list[bf16_idx], @@ -2357,74 +2867,159 @@ at::Tensor qlinear_woq_affine( k_splits, num_concats, fusion_type, - others_list); + others_list, + quant_block_k); } else { try_compute_in_half(); } } else { TLA_ASSERT(lowp_mode == LOWP_MODE_INT8, "invalid lowp_mode"); TLA_ASSERT(is_int4, "LOWP_MODE_INT8 only support is_int4=true"); - float scale_a; - int32_t zp_a; - auto x_reshape_contig = x_reshape.contiguous(); - compute_int8_qparams_per_tensor( - x_reshape_contig, &scale_a, &zp_a); - auto x_quantized = quantize_per_tensor( - x_reshape_contig, scale_a, zp_a); - qlinear_woq_affine_impl< - uint8_t, - uint8_t, - /*TGemmOut*/ float, - act_type, - float, - int8_t>( - x_quantized, - qw, - scales_list[fp32_idx], - zp_list[int8_idx], - biases[fp32_idx], - y, - is_int4, - k_splits, - num_concats, - fusion_type, - others_list, - scale_a, - 
zp_a); + if (quant_a_mode == QUANT_A_PER_TENSOR) { + float scale_a; + int32_t zp_a; + auto x_reshape_contig = x_reshape.contiguous(); + compute_int8_qparams_per_tensor( + x_reshape_contig, &scale_a, &zp_a); + auto x_quantized = quantize_per_tensor( + x_reshape_contig, scale_a, zp_a); + auto scale_a_t = at::full({1}, scale_a, at::kFloat); + auto zp_a_t = at::full({1}, zp_a, at::kInt); + qlinear_woq_affine_impl< + uint8_t, + uint8_t, + /*TGemmOut*/ float, + act_type, + float, + int8_t, + QUANT_A_PER_TENSOR, + quant_w_mode_>( + x_quantized, + qw, + scales_list[fp32_idx], + zp_list[int8_idx], + biases[fp32_idx], + y, + is_int4, + k_splits, + num_concats, + fusion_type, + others_list, + quant_block_k, + scale_a_t, + zp_a_t); + } else { + auto block_k = w_sizes[2]; + auto x_reshape_contig = x_reshape.contiguous(); + auto [scale_a, zp_a] = + compute_int8_qparams_per_block( + x_reshape_contig, quant_block_k, quant_a_mode); + auto x_quantized = quantize_per_block( + x_reshape_contig, + scale_a, + zp_a, + quant_block_k, + quant_a_mode); + range_dispatcher< + long, + QUANT_A_PER_K_BLOCK, + QUANT_A_PER_M_K_BLOCK>:: + call( + quant_a_mode, + [&](auto quant_a_mode_) { + qlinear_woq_affine_impl< + uint8_t, + uint8_t, + /*TGemmOut*/ float, + act_type, + float, + int8_t, + quant_a_mode_, + quant_w_mode_>( + x_quantized, + qw, + scales_list[fp32_idx], + zp_list[int8_idx], + biases[fp32_idx], + y, + is_int4, + k_splits, + num_concats, + fusion_type, + others_list, + quant_block_k, + scale_a, + zp_a); + }, + [&](auto quant_a_mode_) { failing_fallback(); }); + } } }, - failing_fallback); + [](auto tuple) { failing_fallback(); }); return y; } else { TLA_ASSERT( qw.dim() == 2, "weight must be in 4D blocked format or 2D plain format"); + auto K = x.size(-1); + auto M = x.numel() / K; + auto N = qw.size(0); auto compute_dtype = x.scalar_type(); if (lowp_mode == LOWP_MODE_FP16) { compute_dtype = at::kHalf; } else if (lowp_mode == LOWP_MODE_BF16) { - compute_dtype = at::kBFloat16; + compute_dtype = K >= SMALL_BATCH_THRESHOLD ? 
at::kBFloat16 : at::kHalf; } + at::Tensor scale, zp; + scale = scales_list[fp32_idx].unsqueeze(-1); + zp = zp_list[fp32_idx].unsqueeze(-1); auto w = [&]() { if (is_int4) { using namespace at::indexing; - auto w_int8 = at::empty( - {qw.size(0), qw.size(1) * 2}, qw.options().dtype(at::kByte)); + auto w_int8 = + at::empty({N, qw.size(1) * 2}, qw.options().dtype(at::kByte)); w_int8.index({Slice(), Slice(None, None, 2)}) .copy_(qw.bitwise_and(0xf)); w_int8.index({Slice(), Slice(1, None, 2)}) .copy_(qw.bitwise_right_shift(4)); - return (w_int8.to(at::kFloat) - zp_list[fp32_idx]) * - scales_list[fp32_idx]; + at::Tensor dqw; + if (quant_w_mode == 0) { + dqw = (w_int8.to(at::kFloat) - zp) * scale; + } else { + int64_t num_blocks = scale.size(-2); + auto w_int8_view = w_int8.view({N, num_blocks, -1}); + dqw = (w_int8_view.to(at::kFloat) - zp) * scale; + dqw = dqw.view({N, -1}); + } + if (K != qw.size(1) * 2) { + TORCH_CHECK( + K < qw.size(1) * 2, + 'WOQ Linear kernel: Unexpected weight shape'); + auto dqw_narrowed = dqw.narrow(1, 0, K); + return dqw_narrowed; + } + return dqw; } else { - return (qw.to(at::kFloat) - zp_list[fp32_idx]) * - scales_list[fp32_idx]; + at::Tensor dqw; + if (quant_w_mode == 0) { + dqw = (qw.to(at::kFloat) - zp) * scale; + } else { + int64_t num_blocks = scale.size(-2); + auto w_int8_view = qw.view({N, num_blocks, -1}); + dqw = (w_int8_view.to(at::kFloat) - zp) * scale; + dqw = dqw.view({N, -1}); + } + return dqw; } }() .to(compute_dtype); - auto x_fp = x.to(compute_dtype); - auto y = at::linear(x_fp, w); + auto x_reshape = x.reshape({M, K}); + auto x_fp = x_reshape.to(compute_dtype); + // PyTorch does not support computing in half yet + auto y = compute_dtype == at::kHalf + ? at::linear(x_fp.to(c10::kFloat), w.to(c10::kFloat)) + : at::linear(x_fp, w); if (biases[0].defined()) { auto b_index = compute_dtype == at::kFloat ? fp32_idx : compute_dtype == at::kHalf ? fp16_idx @@ -2435,8 +3030,11 @@ at::Tensor qlinear_woq_affine( y = at::gelu(y); } else if (fusion_type == FUSE_ADD || fusion_type == FUSE_ADD_ADD) { for (auto& tin : others_list) - y = at::add(y, tin); + y = at::add(y, tin.view(y.sizes())); } + auto out_sizes = x.sizes().vec(); + out_sizes.back() = N; + y = y.view(out_sizes); if (num_concats > 1) { y = y.view({-1, num_concats, y.size(-1) / num_concats}) .transpose(0, 1) @@ -2449,7 +3047,7 @@ at::Tensor qlinear_woq_affine( #else // defined(CPU_CAPABILITY_AVX512_FP16) && defined(COMPILER_PREREQ_MET) -static at::Tensor empty_tensor; +#define SMALL_BATCH_THRESHOLD 32 at::Tensor qlinear_woq_affine( const at::Tensor& x, @@ -2461,8 +3059,97 @@ at::Tensor qlinear_woq_affine( int64_t lowp_mode, int64_t num_concats, int64_t fusion_type, - const TensorList& others_list) { - return empty_tensor; + const TensorList& others_list, + int64_t quant_a_mode = -1, + int64_t quant_w_mode = 0, + int64_t quant_block_k = 0) { + constexpr size_t fp32_idx = 0, fp16_idx = 1, bf16_idx = 2, int8_idx = 3; + auto biases = bias_list.empty() + ? TensorList({at::Tensor(), at::Tensor(), at::Tensor()}) + : bias_list; + TLA_ASSERT( + qw.dim() == 2, "weight must be in 4D blocked format or 2D plain format"); + auto K = x.size(-1); + auto M = x.numel() / K; + auto N = qw.size(0); + auto compute_dtype = x.scalar_type(); + if (lowp_mode == LOWP_MODE_FP16) { + compute_dtype = at::kHalf; + } else if (lowp_mode == LOWP_MODE_BF16) { + compute_dtype = K >= SMALL_BATCH_THRESHOLD ? 
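// Minimal sketch of the plain 2D weight fallback above: the packed int4 weight
// stores the even input channel in the low nibble and the odd input channel in
// the high nibble of each byte; both are unpacked and dequantized as
// (q - zp) * scale, either per output channel or per group of group_size input
// channels. Illustration only (std-only, assumes a flat [N, K/group] scale
// layout, no padding or narrowing handling).
#include <cstdint>
#include <vector>

std::vector<float> dequant_int4_weight(
    const std::vector<uint8_t>& qw,   // [N, K/2], two int4 values per byte
    const std::vector<float>& scales, // per channel: [N]; per group: [N, K/group]
    const std::vector<float>& zps,    // same layout as scales
    int64_t N, int64_t K, int64_t group_size /* <= 0 means per channel */) {
  std::vector<float> w(N * K);
  const int64_t groups =
      group_size > 0 ? (K + group_size - 1) / group_size : 1;
  for (int64_t n = 0; n < N; ++n) {
    for (int64_t k = 0; k < K; ++k) {
      const uint8_t byte = qw[n * (K / 2) + k / 2];
      const int q = (k % 2 == 0) ? (byte & 0xf) : (byte >> 4);
      const int64_t g = group_size > 0 ? k / group_size : 0;
      const float s = scales[n * groups + g];
      const float z = zps[n * groups + g];
      w[n * K + k] = (static_cast<float>(q) - z) * s;
    }
  }
  return w;
}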
at::kBFloat16 : at::kHalf; + } + at::Tensor scale, zp; + scale = scales_list[fp32_idx].unsqueeze(-1); + zp = zp_list[fp32_idx].unsqueeze(-1); + auto w = + [&]() { + if (is_int4) { + using namespace at::indexing; + auto w_int8 = + at::empty({N, qw.size(1) * 2}, qw.options().dtype(at::kByte)); + w_int8.index({Slice(), Slice(None, None, 2)}) + .copy_(qw.bitwise_and(0xf)); + w_int8.index({Slice(), Slice(1, None, 2)}) + .copy_(qw.bitwise_right_shift(4)); + at::Tensor dqw; + if (quant_w_mode == 0) { + dqw = (w_int8.to(at::kFloat) - zp) * scale; + } else { + int64_t num_blocks = scale.size(-2); + auto w_int8_view = w_int8.view({N, num_blocks, -1}); + dqw = (w_int8_view.to(at::kFloat) - zp) * scale; + dqw = dqw.view({N, -1}); + } + if (K != qw.size(1) * 2) { + TORCH_CHECK( + K < qw.size(1) * 2, + 'WOQ Linear kernel: Unexpected weight shape'); + auto dqw_narrowed = dqw.narrow(1, 0, K); + return dqw_narrowed; + } + return dqw; + } else { + at::Tensor dqw; + if (quant_w_mode == 0) { + dqw = (qw.to(at::kFloat) - zp) * scale; + } else { + int64_t num_blocks = scale.size(-2); + auto w_int8_view = qw.view({N, num_blocks, -1}); + dqw = (w_int8_view.to(at::kFloat) - zp) * scale; + dqw = dqw.view({N, -1}); + } + return dqw; + } + }() + .to(compute_dtype); + auto x_reshape = x.reshape({M, K}); + auto x_fp = x_reshape.to(compute_dtype); + // PyTorch does not support computing in half yet + auto y = compute_dtype == at::kHalf + ? at::linear(x_fp.to(c10::kFloat), w.to(c10::kFloat)) + : at::linear(x_fp, w); + if (biases[0].defined()) { + auto b_index = compute_dtype == at::kFloat ? fp32_idx + : compute_dtype == at::kHalf ? fp16_idx + : bf16_idx; + y = at::add(y, biases[b_index]); + } + if (fusion_type == FUSE_GELU) { + y = at::gelu(y); + } else if (fusion_type == FUSE_ADD || fusion_type == FUSE_ADD_ADD) { + for (auto& tin : others_list) + y = at::add(y, tin.view(y.sizes())); + } + auto out_sizes = x.sizes().vec(); + out_sizes.back() = N; + y = y.view(out_sizes); + if (num_concats > 1) { + y = y.view({-1, num_concats, y.size(-1) / num_concats}) + .transpose(0, 1) + .contiguous() + .view({-1, y.size(-1)}); + } + return y.to(x.scalar_type()); } at::Tensor qlinear_woq_pack( @@ -2471,14 +3158,14 @@ at::Tensor qlinear_woq_pack( size_t block_n, size_t block_k, int64_t lowp_mode) { - return empty_tensor; + return qw; } at::Tensor qlinear_woq_unpack( const at::Tensor& qw_packed, bool is_int4, int64_t lowp_mode) { - return empty_tensor; + return qw_packed; } #endif // defined(CPU_CAPABILITY_AVX512_FP16) && defined(COMPILER_PREREQ_MET) diff --git a/csrc/cpu/jit/cpu/kernels/ContextLinearWoq.h b/csrc/cpu/jit/cpu/kernels/ContextLinearWoq.h index 595a14127..9fc9a5188 100644 --- a/csrc/cpu/jit/cpu/kernels/ContextLinearWoq.h +++ b/csrc/cpu/jit/cpu/kernels/ContextLinearWoq.h @@ -7,6 +7,7 @@ namespace cpu { namespace detail { struct ContextLinearWoq final { at::Tensor at_weight_; + std::vector weight_shape_; c10::optional at_bias_; // The list contains three dtype versions of bias, scale and zp // i.e., fp32, fp16, bf16 @@ -15,37 +16,75 @@ struct ContextLinearWoq final { std::vector scales_list_; std::vector zero_points_list_; bool is_int4_; + int64_t group_size_; int64_t lowp_mode_; int64_t num_concats_; - // Original weight shape. 
Weight may be padded after packing - c10::optional> orig_wei_shape_; + int64_t act_quant_mode_; ContextLinearWoq() = delete; ContextLinearWoq( at::Tensor&& at_weight, + std::vector&& weight_shape, at::Tensor&& scales_float, at::Tensor&& zero_point_float, c10::optional&& bias, bool is_int4 = false, + int64_t group_size = -1, int64_t lowp_mode = 0, int64_t num_concats = 1, - c10::optional>&& orig_wei_shape = c10::nullopt) + int64_t act_quant_mode = 0) : at_weight_(std::move(at_weight)), + weight_shape_(std::move(weight_shape)), at_bias_(std::move(bias)), is_int4_(is_int4), + group_size_(group_size), lowp_mode_(lowp_mode), num_concats_(num_concats), - orig_wei_shape_(std::move(orig_wei_shape)) { + act_quant_mode_(act_quant_mode) { // Make three dtype versions of scale, zp and bias // There is one more dtype for zp - auto scales_fp16 = scales_float.to(c10::kHalf); - auto scales_bf16 = scales_float.to(c10::kBFloat16); - scales_list_ = {scales_float, scales_fp16, scales_bf16}; - auto zp_fp16 = zero_point_float.to(c10::kHalf); - auto zp_bf16 = zero_point_float.to(c10::kBFloat16); - auto zp_int8 = zero_point_float.to(c10::kChar); - zero_points_list_ = {zero_point_float, zp_fp16, zp_bf16, zp_int8}; + if (group_size > 0) { + // Reshape scales/zps for data locality in kernel + // [N, #block_k] -> [N / block_n, block_n, #block_k] + // -> [#block_n, #block_k, block_n] + at::Tensor scales_perm, zp_perm; + if (at_weight_.dim() == 4) { + // packed weight in 4d (Nc, Kc, block_k, block_n) + int64_t block_n = at_weight_.size(-1); + if (is_int4) { + block_n *= 2; + } + TORCH_CHECK(scales_float.size(0) % block_n == 0); + std::vector reshape_dim = { + scales_float.size(0) / block_n, block_n, scales_float.size(1)}; + scales_perm = scales_float.view(reshape_dim) + .permute({0, 2, 1}) + .contiguous() + .to(c10::kFloat); + zp_perm = + zero_point_float.view(reshape_dim).permute({0, 2, 1}).contiguous(); + } else { + scales_perm = scales_float.to(c10::kFloat); + zp_perm = zero_point_float; + } + auto scales_fp16 = scales_perm.to(c10::kHalf); + auto scales_bf16 = scales_perm.to(c10::kBFloat16); + scales_list_ = {scales_perm, scales_fp16, scales_bf16}; + auto zp_fp16 = zp_perm.to(c10::kHalf); + auto zp_bf16 = zp_perm.to(c10::kBFloat16); + auto zp_int8 = zp_perm.to(c10::kChar); + zero_points_list_ = {zp_perm, zp_fp16, zp_bf16, zp_int8}; + } else { + auto scales_fp32 = scales_float.to(c10::kFloat); + auto scales_fp16 = scales_float.to(c10::kHalf); + auto scales_bf16 = scales_float.to(c10::kBFloat16); + scales_list_ = {scales_fp32, scales_fp16, scales_bf16}; + auto zp_fp16 = zero_point_float.to(c10::kHalf); + auto zp_bf16 = zero_point_float.to(c10::kBFloat16); + auto zp_int8 = zero_point_float.to(c10::kChar); + zero_points_list_ = {zero_point_float, zp_fp16, zp_bf16, zp_int8}; + } if (at_bias_.has_value() && at_bias_.value().defined()) { auto& orig_bias = at_bias_.value(); auto bias_fp32 = at_bias_.value().to(c10::kFloat); diff --git a/csrc/cpu/jit/cpu/kernels/LinearWoqPacked.cpp b/csrc/cpu/jit/cpu/kernels/LinearWoqPacked.cpp index 37ed51d86..67e6026a9 100644 --- a/csrc/cpu/jit/cpu/kernels/LinearWoqPacked.cpp +++ b/csrc/cpu/jit/cpu/kernels/LinearWoqPacked.cpp @@ -11,16 +11,32 @@ namespace woq_linear { c10::intrusive_ptr createWoqLinearPrePackOpContext( at::Tensor&& weight, + std::vector&& weight_shape, + at::Tensor&& scales, + at::Tensor&& zero_points, c10::optional&& bias, c10::optional batch_size, + bool is_int4, + int64_t group_size, int64_t lowp_mode, - int64_t num_concats) { + int64_t num_concats, + int64_t 
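// Sketch of the scale/zero-point re-layout done in the ContextLinearWoq
// constructor above for grouped (group_size > 0) quantization: the [N, Kg] scale
// matrix (Kg = number of K groups) is viewed as [N/block_n, block_n, Kg] and
// permuted to [N/block_n, Kg, block_n] so that the block_n values a GEMM tile
// needs are contiguous. The helper below only documents the resulting flat index
// for a contiguous permuted buffer; it is not part of the patch.
#include <cstdint>

int64_t permuted_scale_index(
    int64_t n, int64_t kg, int64_t block_n, int64_t Kg) {
  const int64_t nc = n / block_n;       // which N block
  const int64_t nb = n % block_n;       // position inside the N block
  return (nc * Kg + kg) * block_n + nb; // [Nc][Kg][block_n] layout
}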
act_quant_mode) { RECORD_FUNCTION( "ipex_prepack::createWoqLinearPrePackOpContext", c10::ArrayRef({})); return IpexWoqLinearOpContext::create_context( - std::move(weight), std::move(bias), batch_size, lowp_mode, num_concats); + std::move(weight), + std::move(weight_shape), + std::move(scales), + std::move(zero_points), + std::move(bias), + batch_size, + is_int4, + group_size, + lowp_mode, + num_concats, + act_quant_mode); } c10::intrusive_ptr createWoqLinearPrePackOpContextInt4( @@ -29,77 +45,126 @@ c10::intrusive_ptr createWoqLinearPrePackOpContextInt4( at::Tensor&& zero_points, c10::optional&& bias, c10::optional batch_size, + int64_t group_size, // group_size along input channel int64_t lowp_mode, - int64_t num_concats) { + int64_t num_concats, + int64_t act_quant_mode) { RECORD_FUNCTION( "ipex_prepack::createWoqLinearPrePackOpContextInt4", c10::ArrayRef({})); + // clang-format off // From - // Weight dtype = int32 (uint4 * 8), scale dtype = fp16, zero points dtype = - // int32 (int4 * 8) To Weight dtype = quint4x2, scale dtype = fp32, zero - // points dtype = fp32 There might be an extra output channel in weight and - // scales bool extra_o_channel = false; // scales.numel() > - // zero_points.numel() * 8; + // Weight dtype = int32 (uint4 * 8) or uint8 (4bit * 2), scale dtype = fp16, + // zero points dtype = int32 (int4 * 8) + // To + // Weight dtype = quint4x2, scale dtype = fp32, zero points dtype = fp32 + // There might be an extra output channel in weight and scales. + // clang-format on auto scales_fp32 = scales.squeeze().to(c10::ScalarType::Float); - auto zp_fp32 = zero_points.scalar_type() == c10::kFloat - ? zero_points.squeeze() - : at::empty_like(scales_fp32); - // Convert compressed zero points to float + at::Tensor zp_fp32; + if (zero_points.scalar_type() == c10::kInt) { - if (zero_points.numel() == scales_fp32.numel() / 8 || - zero_points.numel() == scales_fp32.numel() / 8 + 1) { + // Two cases: (1) each int32 contains 8 values of zero points + // (2) each int32 is a single value of zero point + if (zero_points.numel() != scales_fp32.numel()) { + // Assume group_size > 0 and zero point data are compressed + TORCH_CHECK(scales_fp32.dim() == 2 && zero_points.dim() == 2) + TORCH_CHECK(scales_fp32.size(0) == zero_points.size(0)) + auto num_row = scales_fp32.size(0); + auto num_col = scales_fp32.size(1); + auto num_col_zp = zero_points.size(1); + // Convert compressed zero points to float + zp_fp32 = at::empty_like(scales_fp32); float* zp_fp32_ptr = reinterpret_cast(zp_fp32.data_ptr()); uint32_t* zp_int32_ptr = reinterpret_cast(zero_points.data_ptr()); - for (size_t i = 0; i < zero_points.numel(); ++i) { - uint32_t zp_uint4x8 = zp_int32_ptr[i]; - for (size_t j = 0; j < 8; ++j) { - zp_fp32_ptr[i * 8 + j] = (float)((zp_uint4x8 >> (j * 4)) & 0xf); + for (size_t i = 0; i < num_row; ++i) { + for (size_t j = 0; j < num_col; ++j) { + zp_fp32_ptr[i * num_col + j] = + (float)((zp_int32_ptr[i * num_col_zp + j / 8] >> ((j % 8) * 4)) & 0xf); } } } else if (zero_points.numel() == scales_fp32.numel()) { - zp_fp32 = zero_points.to(c10::kFloat).squeeze(); + // Not compressed + zp_fp32 = zero_points.squeeze().to(c10::kFloat); } else { TORCH_CHECK(false, "IPEX WOQ INT4: unexpected zero points size"); } + } else { + zp_fp32 = zero_points.squeeze().to(c10::kFloat); } // Support two cases here: // 1. fp32/bf16 weight after calibration - // 2. int4 weight after calibration, quantized and compressed, as int32 + // 2. 
int4 weight after calibration, quantized and compressed, as int32/uint8 at::Tensor weight_int4; - if (weight.scalar_type() == c10::kInt) { + std::vector weight_shape(2); + if (weight.scalar_type() == c10::kInt || weight.scalar_type() == c10::kByte) { // Create empty weight with desired options then copy data int64_t N = weight.size(0); - int64_t K_int32 = weight.size(1); - int64_t K = K_int32 * 8; // int32 = int4 * 8 - std::vector weight_size = {N, K}; - // Create an empty quint4x2 weight with scales and zero points - weight_int4 = at::_empty_per_channel_affine_quantized( - weight_size, - scales_fp32, - zp_fp32, - 0, - device(c10::kCPU).dtype(c10::kQUInt4x2)); + int64_t K_compressed = weight.size(1); + int64_t K_uint8 = + weight.scalar_type() == c10::kInt ? K_compressed * 8 / 2 : K_compressed; + weight_shape[0] = N; + weight_shape[1] = K_uint8 * 2; + std::vector weight_size = {N, K_uint8}; + // Create an empty uint8 weight to hold int4 data + weight_int4 = at::empty(weight_size, device(c10::kCPU).dtype(c10::kByte)); + auto sizeof_dtype = weight.scalar_type() == c10::kInt + ? sizeof(uint32_t) + : sizeof(unsigned char); std::memcpy( weight_int4.data_ptr(), weight.data_ptr(), - weight.numel() * sizeof(uint32_t)); - } else if (weight.scalar_type() == c10::kBFloat16) { - // Load bf16 weight and quantize + weight.numel() * sizeof_dtype); + } else if ( + weight.scalar_type() == c10::kBFloat16 || + weight.scalar_type() == c10::kFloat || + weight.scalar_type() == c10::kHalf) { + weight_shape[0] = weight.size(0); + weight_shape[1] = weight.size(1); auto weight_fp32 = weight.to(c10::kFloat); - weight_int4 = at::quantize_per_channel( - weight_fp32, scales_fp32, zp_fp32, 0, c10::kQUInt4x2); - } else if (weight.scalar_type() == c10::kFloat) { - weight_int4 = at::quantize_per_channel( - weight, scales_fp32, zp_fp32, 0, c10::kQUInt4x2); + at::Tensor weight_int4_as_uint8; + if (group_size > 0) { + auto weight_view = + weight_fp32.view({-1, weight.size(1) / group_size, group_size}); + auto scale_view = scales_fp32.unsqueeze(2); + auto zp_view = zp_fp32.unsqueeze(2); + weight_int4_as_uint8 = + at::round(weight_view / scale_view + zp_view).to(c10::kByte); + } else { + auto scale_view = scales_fp32.unsqueeze(1); + auto zp_view = zp_fp32.unsqueeze(1); + weight_int4_as_uint8 = + at::round(weight / scale_view + zp_view).to(c10::kByte); + } + weight_int4_as_uint8 = weight_int4_as_uint8.view(weight_shape); + using at::indexing::None; + using at::indexing::Slice; + at::Tensor even_columns = + weight_int4_as_uint8.index({Slice(), Slice(1, None, 2)}); + even_columns = even_columns.bitwise_left_shift(4); + at::Tensor odd_columns = + weight_int4_as_uint8.index({Slice(), Slice(None, None, 2)}); + weight_int4 = even_columns.bitwise_or(odd_columns); + } else { + TORCH_CHECK( + false, + "IPEX WOQ INT4: unexpected weight data type: ", + weight.scalar_type()); } return IpexWoqLinearOpContext::create_context( std::move(weight_int4), + std::move(weight_shape), + std::move(scales_fp32), + std::move(zp_fp32), std::move(bias), batch_size, + /*is_int4*/ true, + group_size, lowp_mode, - num_concats); + num_concats, + act_quant_mode); } at::Tensor woq_linear_run( @@ -113,73 +178,83 @@ at::Tensor woq_linear_run( ContextLinearWoq create( at::Tensor& weight, + std::vector& weight_shape, at::Tensor& scales, at::Tensor& zero_points, const c10::optional& bias, const c10::optional batch_size, + bool is_int4, + int64_t group_size, int64_t lowp_mode, - int64_t num_concats) { - auto packed_weight = - woq_linear_pack_weight(weight, scales, 
zero_points, lowp_mode); - bool is_int4 = weight.scalar_type() == c10::kQUInt4x2; + int64_t num_concats, + int64_t act_quant_mode) { + auto packed_weight = woq_linear_pack_weight( + weight, weight_shape, is_int4, group_size, lowp_mode); auto packed_shape = packed_weight.sizes(); int64_t N = weight.size(0); int64_t K = weight.size(1); - bool weight_is_padded = (packed_shape.size() == 4 && is_int4 && - packed_shape[0] * packed_shape[3] * 2 != N) || + // If OC is not a multiple of BLOCK_N, it may be padded. + bool oc_is_padded = (packed_shape.size() == 4 && is_int4 && + packed_shape[0] * packed_shape[3] * 2 != N) || (packed_shape.size() == 4 && !is_int4 && packed_shape[0] * packed_shape[3] != N) || (packed_shape.size() == 2 && packed_shape[0] != N); auto zero_points_float = zero_points.to(c10::kFloat); - if (weight_is_padded) { + if (oc_is_padded) { int64_t padded_N = packed_shape.size() == 4 ? (is_int4 ? packed_shape[0] * packed_shape[3] * 2 : packed_shape[0] * packed_shape[3]) : packed_shape[0]; - auto scales_padded = at::pad(scales, {0, padded_N - N}, "constant", 1.f); + std::vector pad_vec = scales.dim() == 1 + ? std::vector({0, padded_N - N}) + : std::vector({0, 0, 0, padded_N - N}); + auto scales_padded = at::pad(scales, pad_vec, "constant", 1.f); auto zero_points_padded = - at::pad(zero_points_float, {0, padded_N - N}, "constant", 0.f); + at::pad(zero_points_float, pad_vec, "constant", 0.f); if (bias.has_value()) { auto bias_padded = at::pad(bias.value(), {0, padded_N - N}, "constant", 0.f); return ContextLinearWoq( std::move(packed_weight), + std::move(weight_shape), std::move(scales_padded), std::move(zero_points_padded), c10::make_optional(bias_padded), is_int4, + group_size, lowp_mode, num_concats, - c10::make_optional(weight.sizes().vec())); + act_quant_mode); } else { return ContextLinearWoq( std::move(packed_weight), + std::move(weight_shape), std::move(scales_padded), std::move(zero_points_padded), c10::nullopt, is_int4, + group_size, lowp_mode, num_concats, - c10::make_optional(weight.sizes().vec())); + act_quant_mode); } } return ContextLinearWoq( std::move(packed_weight), + std::move(weight_shape), std::move(scales), std::move(zero_points_float), bias.has_value() ? c10::make_optional(*bias) : c10::nullopt, is_int4, + group_size, lowp_mode, num_concats, - weight_is_padded ? c10::make_optional(weight.sizes().vec()) - : c10::nullopt); + act_quant_mode); } at::Tensor run(ContextLinearWoq& context, const at::Tensor& input) { // TPP kernel packs weight to 4d (Nc, Kc, block_k, block_n) - auto w_k = context.at_weight_.dim() == 2 - ? 
context.at_weight_.size(1) - : context.at_weight_.size(1) * context.at_weight_.size(2); + auto w_k = context.weight_shape_[1]; TORCH_CHECK( input.size(input.dim() - 1) == w_k, "WOQ linear: input and weight shapes do not match, got k = ", @@ -188,8 +263,7 @@ at::Tensor run(ContextLinearWoq& context, const at::Tensor& input) { w_k, " respectively."); auto input_ = input.contiguous(); - // if weight is not padded, context.orig_wei_shape_ has no value - if (context.orig_wei_shape_.has_value()) { + if (context.weight_shape_[0] != context.at_weight_.size(0)) { auto res = woq_linear_kernel( input_, context.at_weight_, @@ -197,11 +271,13 @@ at::Tensor run(ContextLinearWoq& context, const at::Tensor& input) { context.zero_points_list_, context.bias_list_, context.is_int4_, + context.group_size_, context.lowp_mode_, - context.num_concats_); + context.num_concats_, + context.act_quant_mode_); // weight shape is [N by K], output shape is [M by N] or [batch by M by N] - int64_t N = context.orig_wei_shape_.value()[0]; - return at::slice(res, /*dim*/ -1, /*start*/ 0, /*end*/ N, /*step*/ 1); + int64_t N = context.weight_shape_[0]; + return at::narrow(res, /*dim*/ -1, /*start*/ 0, /*end*/ N); } return woq_linear_kernel( input_, @@ -210,8 +286,10 @@ at::Tensor run(ContextLinearWoq& context, const at::Tensor& input) { context.zero_points_list_, context.bias_list_, context.is_int4_, + context.group_size_, context.lowp_mode_, - context.num_concats_); + context.num_concats_, + context.act_quant_mode_); } // Called by IpexWoqLinearOpContext::run_eltwise @@ -222,9 +300,7 @@ at::Tensor run_eltwise( const torch::List>& scalars, const c10::optional& algorithm) { // TPP kernel packs weight to 4d (Nc, Kc, block_k, block_n) - auto w_k = context.at_weight_.dim() == 2 - ? context.at_weight_.size(1) - : context.at_weight_.size(1) * context.at_weight_.size(2); + auto w_k = context.weight_shape_[1]; TORCH_CHECK( input.size(input.dim() - 1) == w_k, "WOQ linear: input and weight shapes do not match, got k = ", @@ -243,89 +319,10 @@ at::Tensor run_eltwise( scalars, algorithm, context.is_int4_, - context.lowp_mode_, - context.num_concats_); -} - -// Registered as JIT op -at::Tensor woq_linear_eltwise_run( - const at::Tensor& input, - const at::Tensor& op_context, - const c10::string_view& post_op, - const torch::List>& scalars, - const c10::optional& algorithm) { - static std::map postop_to_record_name_map = { - {"relu", "torch_ipex::woq_linear_relu_run"}, - {"gelu", "torch_ipex::woq_linear_gelu_run"}, - }; - RECORD_FUNCTION( - postop_to_record_name_map[post_op], c10::ArrayRef({})); - return reinterpret_cast( - op_context.data_ptr()[0]) - ->run_eltwise(input, post_op, scalars, algorithm); -} - -// Called by IpexWoqLinearOpContext::run_add -at::Tensor run_add( - ContextLinearWoq& context, - const at::Tensor& input, - at::Tensor& accumu, - const c10::optional& alpha) { - // TPP kernel packs weight to 4d (Nc, Kc, block_k, block_n) - auto w_k = context.at_weight_.dim() == 2 - ? 
context.at_weight_.size(1) - : context.at_weight_.size(1) * context.at_weight_.size(2); - TORCH_CHECK( - input.size(input.dim() - 1) == w_k, - "WOQ linear: input and weight shapes do not match, got k = ", - input.size(input.dim() - 1), - " and ", - w_k, - " respectively."); - auto input_ = input.contiguous(); - return woq_linear_add_kernel( - input_, - context.at_weight_, - context.scales_list_, - context.zero_points_list_, - context.bias_list_, - context.is_int4_, + context.group_size_, context.lowp_mode_, context.num_concats_, - accumu, - alpha); -} - -// Called by IpexWoqLinearOpContext::run_add_relu -at::Tensor run_add_relu( - ContextLinearWoq& context, - const at::Tensor& input, - at::Tensor& accumu, - const c10::optional& alpha) { - // TPP kernel packs weight to 4d (Nc, Kc, block_k, block_n) - auto w_k = context.at_weight_.dim() == 2 - ? context.at_weight_.size(1) - : context.at_weight_.size(1) * context.at_weight_.size(2); - TORCH_CHECK( - input.size(input.dim() - 1) == w_k, - "WOQ linear: input and weight shapes do not match, got k = ", - input.size(input.dim() - 1), - " and ", - w_k, - " respectively."); - auto input_ = input.contiguous(); - auto output = woq_linear_kernel( - input_, - context.at_weight_, - context.scales_list_, - context.zero_points_list_, - context.bias_list_, - context.is_int4_, - context.lowp_mode_, - context.num_concats_); - at::add_out(accumu, output, accumu, alpha.value()); - at::relu_(accumu); - return accumu; + context.act_quant_mode_); } // Called by IpexWoqLinearOpContext::run_add @@ -334,9 +331,7 @@ at::Tensor run_add( const at::Tensor& input, const std::vector& others) { // TPP kernel packs weight to 4d (Nc, Kc, block_k, block_n) - auto w_k = context.at_weight_.dim() == 2 - ? context.at_weight_.size(1) - : context.at_weight_.size(1) * context.at_weight_.size(2); + auto w_k = context.weight_shape_[1]; TORCH_CHECK( input.size(input.dim() - 1) == w_k, "WOQ linear: input and weight shapes do not match, got k = ", @@ -352,9 +347,11 @@ at::Tensor run_add( context.zero_points_list_, context.bias_list_, context.is_int4_, + context.group_size_, context.lowp_mode_, context.num_concats_, - others); + others, + context.act_quant_mode_); } // Called by IpexWoqLinearOpContext::run_add_add @@ -363,9 +360,7 @@ at::Tensor run_add_add( const at::Tensor& input, const std::vector& others) { // TPP kernel packs weight to 4d (Nc, Kc, block_k, block_n) - auto w_k = context.at_weight_.dim() == 2 - ? 
context.at_weight_.size(1) - : context.at_weight_.size(1) * context.at_weight_.size(2); + auto w_k = context.weight_shape_[1]; TORCH_CHECK( input.size(input.dim() - 1) == w_k, "WOQ linear: input and weight shapes do not match, got k = ", @@ -381,35 +376,11 @@ at::Tensor run_add_add( context.zero_points_list_, context.bias_list_, context.is_int4_, + context.group_size_, context.lowp_mode_, context.num_concats_, - others); -} - -// Registered as JIT op -at::Tensor woq_linear_add_run( - const at::Tensor& input, - at::Tensor& accumu, - const c10::optional& alpha, - const at::Tensor& op_context) { - RECORD_FUNCTION( - "torch_ipex::woq_linear_add_run", c10::ArrayRef({})); - return reinterpret_cast( - op_context.data_ptr()[0]) - ->run_add(input, accumu, alpha); -} - -// Registered as JIT op -at::Tensor woq_linear_add_relu_run( - const at::Tensor& input, - at::Tensor& accumu, - const c10::optional& alpha, - const at::Tensor& op_context) { - RECORD_FUNCTION( - "torch_ipex::woq_linear_add_relu_run", c10::ArrayRef({})); - return reinterpret_cast( - op_context.data_ptr()[0]) - ->run_add_relu(input, accumu, alpha); + others, + context.act_quant_mode_); } at::Tensor pack(ContextLinearWoq& context, const at::Tensor& tensor) { @@ -428,22 +399,16 @@ at::Tensor unpack(ContextLinearWoq& context, const at::Tensor& tensor) { auto zero_points = context.zero_points_list_[0]; if (context.is_int4_) { auto unpacked_shape = unpacked_weight.sizes().vec(); // = N * K/2 - auto shape = context.orig_wei_shape_.has_value() - ? context.orig_wei_shape_.value() - : std::vector({unpacked_shape[0], unpacked_shape[1] * 2}); - at::Tensor qweight = at::_empty_per_channel_affine_quantized( - shape, - scales, - zero_points, - 0, - device(c10::kCPU).dtype(c10::kQUInt4x2)); + auto shape = context.weight_shape_; + shape.back() /= 2; + at::Tensor qweight = + at::empty(shape, device(c10::kCPU).dtype(c10::kByte)); assert(qweight.numel() % 2 == 0); std::memcpy( - qweight.data_ptr(), unpacked_weight.data_ptr(), qweight.numel() / 2); + qweight.data_ptr(), unpacked_weight.data_ptr(), qweight.numel()); return qweight; } else { // int8 - return at::_make_per_channel_quantized_tensor( - unpacked_weight.int_repr(), scales, zero_points.to(c10::kInt), 0); + return unpacked_weight; } } return unpacked_weight; diff --git a/csrc/cpu/jit/cpu/kernels/LinearWoqPacked.h b/csrc/cpu/jit/cpu/kernels/LinearWoqPacked.h index a54442aef..6f573ec9f 100644 --- a/csrc/cpu/jit/cpu/kernels/LinearWoqPacked.h +++ b/csrc/cpu/jit/cpu/kernels/LinearWoqPacked.h @@ -12,10 +12,16 @@ namespace woq_linear { // WOQ = weight-only quantization c10::intrusive_ptr createWoqLinearPrePackOpContext( at::Tensor&& weight, + std::vector&& weight_shape, + at::Tensor&& scales, + at::Tensor&& zero_points, c10::optional&& bias, c10::optional batch_size, + bool is_int4, + int64_t group_size, int64_t lowp_mode, - int64_t num_concats); + int64_t num_concats, + int64_t act_quant_mode); c10::intrusive_ptr createWoqLinearPrePackOpContextInt4( at::Tensor&& weight, @@ -23,8 +29,10 @@ c10::intrusive_ptr createWoqLinearPrePackOpContextInt4( at::Tensor&& zero_points, c10::optional&& bias, c10::optional batch_size, + int64_t group_size, int64_t lowp_mode, - int64_t num_concats); + int64_t num_concats, + int64_t act_quant_mode); at::Tensor woq_linear_run( const at::Tensor& input, @@ -32,12 +40,16 @@ at::Tensor woq_linear_run( ContextLinearWoq create( at::Tensor& weight, + std::vector& weight_shape, at::Tensor& scales, at::Tensor& zero_points, const c10::optional& bias, const c10::optional batch_size, + 
bool is_int4, + int64_t group_size, int64_t lowp_mode, - int64_t num_concats); + int64_t num_concats, + int64_t act_quant_mode); at::Tensor run(ContextLinearWoq& context, const at::Tensor& input); @@ -48,25 +60,6 @@ at::Tensor run_eltwise( const torch::List>& scalars, const c10::optional& algorithm); -at::Tensor woq_linear_eltwise_run( - const at::Tensor& input, - const at::Tensor& op_context, - const c10::string_view& post_op, - const torch::List>& scalars, - const c10::optional& algorithm); - -at::Tensor run_add( - ContextLinearWoq& context, - const at::Tensor& input, - at::Tensor& accumu, - const c10::optional& alpha); - -at::Tensor run_add_relu( - ContextLinearWoq& context, - const at::Tensor& input, - at::Tensor& accumu, - const c10::optional& alpha); - at::Tensor run_add( ContextLinearWoq& context, const at::Tensor& input, diff --git a/csrc/cpu/jit/cpu/kernels/OpContext.cpp b/csrc/cpu/jit/cpu/kernels/OpContext.cpp index ead9a3660..3014bd9b9 100644 --- a/csrc/cpu/jit/cpu/kernels/OpContext.cpp +++ b/csrc/cpu/jit/cpu/kernels/OpContext.cpp @@ -362,125 +362,30 @@ void IpexConvTransposeOpContext::load_from_ctx( // For weight-only quantization c10::intrusive_ptr IpexWoqLinearOpContext::create_context( at::Tensor&& weight, + std::vector&& weight_shape, + at::Tensor&& scales_fp32, + at::Tensor&& zp_fp32, c10::optional&& bias, c10::optional batch_size, + bool is_int4, + int64_t group_size, int64_t lowp_mode, - int64_t num_concats) { - auto N = weight.size(0); - const auto qtype = weight.qscheme(); - if (weight.scalar_type() == c10::ScalarType::QInt8) { - // extract scales from weight - std::vector weight_scales_float(1, 0.0); - if (qtype == c10::kPerTensorAffine) { - weight_scales_float[0] = weight.q_scale(); - } else if (qtype == c10::kPerChannelAffine) { - weight_scales_float.resize(N); - for (const auto i : c10::irange(N)) { - weight_scales_float[i] = weight.q_per_channel_scales()[i].item(); - } - } - - at::Tensor scales = at::empty( - {static_cast(weight_scales_float.size())}, - at::device(c10::kCPU).dtype(c10::kFloat)); - std::copy( - weight_scales_float.begin(), - weight_scales_float.end(), - scales.data_ptr()); - - // extract zero_points from weight - std::vector weight_zero_points_int32(1, 0); - if (qtype == c10::kPerTensorAffine) { - weight_zero_points_int32[0] = weight.q_zero_point(); - } else if (qtype == c10::kPerChannelAffine) { - weight_zero_points_int32.resize(N); - for (const auto i : c10::irange(N)) { - weight_zero_points_int32[i] = - weight.q_per_channel_zero_points()[i].item(); - } - } - at::Tensor zero_points_int32 = at::empty( - {static_cast(weight_zero_points_int32.size())}, - at::device(c10::kCPU).dtype(c10::kInt)); - std::copy( - weight_zero_points_int32.begin(), - weight_zero_points_int32.end(), - zero_points_int32.data_ptr()); - - // convert zero_points from int32_t to float - std::vector weight_zero_points_float(1, 0); - if (qtype == c10::kPerTensorAffine) { - weight_zero_points_float[0] = (float)weight.q_zero_point(); - } else if (qtype == c10::kPerChannelAffine) { - weight_zero_points_float.resize(N); - for (const auto i : c10::irange(N)) { - weight_zero_points_float[i] = - (float)weight.q_per_channel_zero_points()[i].item(); - } - } - at::Tensor zero_points_float = at::empty( - {static_cast(weight_zero_points_float.size())}, - at::device(c10::kCPU).dtype(c10::kFloat)); - std::copy( - weight_zero_points_float.begin(), - weight_zero_points_float.end(), - zero_points_float.data_ptr()); - - auto op_context = torch_ipex::cpu::detail::woq_linear::create( - weight, - 
scales, - zero_points_int32, - bias, - batch_size, - lowp_mode, - num_concats); - return c10::make_intrusive( - batch_size, std::move(op_context)); - } else { - // extract scales from weight - std::vector weight_scales_float(1, 0.0); - if (qtype == c10::kPerChannelAffineFloatQParams) { - weight_scales_float.resize(N); - for (const auto i : c10::irange(N)) { - weight_scales_float[i] = weight.q_per_channel_scales()[i].item(); - } - } - - at::Tensor scales = at::empty( - {static_cast(weight_scales_float.size())}, - at::device(c10::kCPU).dtype(c10::kFloat)); - std::copy( - weight_scales_float.begin(), - weight_scales_float.end(), - scales.data_ptr()); - - // extract zero_points from weight - std::vector weight_zero_points_float(1, 0); - if (qtype == c10::kPerChannelAffineFloatQParams) { - weight_zero_points_float.resize(N); - for (const auto i : c10::irange(N)) { - weight_zero_points_float[i] = - weight.q_per_channel_zero_points()[i].item(); - } - } - at::Tensor zero_points_float = at::empty( - {static_cast(weight_zero_points_float.size())}, - at::device(c10::kCPU).dtype(c10::kFloat)); - std::copy( - weight_zero_points_float.begin(), - weight_zero_points_float.end(), - zero_points_float.data_ptr()); - auto op_context = torch_ipex::cpu::detail::woq_linear::create( - weight, - scales, - zero_points_float, - bias, - batch_size, - lowp_mode, - num_concats); - return c10::make_intrusive( - batch_size, std::move(op_context)); - } + int64_t num_concats, + int64_t act_quant_mode) { + auto op_context = torch_ipex::cpu::detail::woq_linear::create( + weight, + weight_shape, + scales_fp32, + zp_fp32, + bias, + batch_size, + is_int4, + group_size, + lowp_mode, + num_concats, + act_quant_mode); + return c10::make_intrusive( + batch_size, std::move(op_context)); } at::Tensor IpexWoqLinearOpContext::get_data_handle() { @@ -502,22 +407,6 @@ at::Tensor IpexWoqLinearOpContext::run_eltwise( op_context_, input, post_op, scalars, algorithm); } -at::Tensor IpexWoqLinearOpContext::run_add( - const at::Tensor& input, - at::Tensor& accumu, - const c10::optional& alpha) { - return torch_ipex::cpu::detail::woq_linear::run_add( - op_context_, input, accumu, alpha); -} - -at::Tensor IpexWoqLinearOpContext::run_add_relu( - const at::Tensor& input, - at::Tensor& accumu, - const c10::optional& alpha) { - return torch_ipex::cpu::detail::woq_linear::run_add_relu( - op_context_, input, accumu, alpha); -} - at::Tensor IpexWoqLinearOpContext::run_add( const at::Tensor& input, const std::vector& others) { @@ -544,6 +433,46 @@ c10::optional IpexWoqLinearOpContext::get_at_bias() { return op_context_.at_bias_; } +at::Tensor IpexWoqLinearOpContext::get_scales() { + if (op_context_.group_size_ > 0 && op_context_.at_weight_.dim() == 4) { + // [#block_n, #block_k, n_block_size] -> [#block_n, n_block_size, #block_k] + // -> [N, #block_k] + auto scales = op_context_.scales_list_[0].permute({0, 2, 1}).contiguous(); + scales = scales.view({-1, scales.size(-1)}); + if (scales.size(0) > op_context_.weight_shape_[0]) { + return scales.narrow(0, 0, op_context_.weight_shape_[0]); + } + return scales; + } + if (op_context_.scales_list_[0].size(0) > op_context_.weight_shape_[0]) { + return op_context_.scales_list_[0].narrow( + 0, 0, op_context_.weight_shape_[0]); + } + return op_context_.scales_list_[0]; +} + +at::Tensor IpexWoqLinearOpContext::get_zero_points() { + if (op_context_.group_size_ > 0 && op_context_.at_weight_.dim() == 4) { + // [#block_n, #block_k, n_block_size] -> [#block_n, n_block_size, #block_k] + // -> [N, #block_k] + auto zp = 
op_context_.zero_points_list_[0].permute({0, 2, 1}).contiguous(); + zp = zp.view({-1, zp.size(-1)}); + if (zp.size(0) > op_context_.weight_shape_[0]) { + return zp.narrow(0, 0, op_context_.weight_shape_[0]); + } + return zp; + } + if (op_context_.zero_points_list_[0].size(0) > op_context_.weight_shape_[0]) { + return op_context_.zero_points_list_[0].narrow( + 0, 0, op_context_.weight_shape_[0]); + } + return op_context_.zero_points_list_[0]; +} + +std::vector IpexWoqLinearOpContext::get_weight_shape() { + return op_context_.weight_shape_; +} + detail::ContextLinearWoq& IpexWoqLinearOpContext::get_context() { return op_context_; } diff --git a/csrc/cpu/jit/cpu/kernels/OpContext.h b/csrc/cpu/jit/cpu/kernels/OpContext.h index e43eb5cf5..6c15601ad 100644 --- a/csrc/cpu/jit/cpu/kernels/OpContext.h +++ b/csrc/cpu/jit/cpu/kernels/OpContext.h @@ -361,11 +361,17 @@ class IpexLinearMKLOpContext final : public MKLOpContext { // Weight-only quantization using SerializationTypeWoqLinearPrePack = std::tuple< - at::Tensor, - c10::optional, - c10::optional, - int64_t, - int64_t>; + at::Tensor, // weight + std::vector, // weight shape + at::Tensor, // scales + at::Tensor, // zero points + c10::optional, // bias + c10::optional, // batch size + bool, // is_int4 + int64_t, // group size + int64_t, // lowp_mode + int64_t, // num_concats + int64_t>; // act_quant_mode class WoqLinearOpContext : public torch::jit::CustomClassHolder { protected: @@ -374,13 +380,22 @@ class WoqLinearOpContext : public torch::jit::CustomClassHolder { public: SerializationTypeWoqLinearPrePack unpack() { auto orig_weight_ = this->to_public(this->get_at_packed_weight()); + auto weight_shape_ = this->get_weight_shape(); auto orig_bias_ = this->get_context().at_bias_; + auto scales = this->get_scales(); + auto zero_points = this->get_zero_points(); return std::make_tuple( orig_weight_, + weight_shape_, + scales, + zero_points, orig_bias_, batch_size_, + this->get_context().is_int4_, + this->get_context().group_size_, this->get_context().lowp_mode_, - this->get_context().num_concats_); + this->get_context().num_concats_, + this->get_context().act_quant_mode_); } virtual at::Tensor get_data_handle() = 0; @@ -393,16 +408,6 @@ class WoqLinearOpContext : public torch::jit::CustomClassHolder { const torch::List>& scalars, const c10::optional& algorithm) = 0; - virtual at::Tensor run_add( - const at::Tensor& input, - at::Tensor& accumu, - const c10::optional& alpha) = 0; - - virtual at::Tensor run_add_relu( - const at::Tensor& input, - at::Tensor& accumu, - const c10::optional& alpha) = 0; - virtual at::Tensor run_add( const at::Tensor& input, const std::vector& others) = 0; @@ -417,6 +422,12 @@ class WoqLinearOpContext : public torch::jit::CustomClassHolder { virtual c10::optional get_at_bias() = 0; + virtual at::Tensor get_scales() = 0; + + virtual at::Tensor get_zero_points() = 0; + + virtual std::vector get_weight_shape() = 0; + virtual at::Tensor pack(const at::Tensor& tensor) = 0; virtual detail::ContextLinearWoq& get_context() = 0; @@ -453,16 +464,6 @@ class IpexWoqLinearOpContext final : public WoqLinearOpContext { const torch::List>& scalars, const c10::optional& algorithm) override; - virtual at::Tensor run_add( - const at::Tensor& input, - at::Tensor& accumu, - const c10::optional& alpha) override; - - virtual at::Tensor run_add_relu( - const at::Tensor& input, - at::Tensor& accumu, - const c10::optional& alpha) override; - virtual at::Tensor run_add( const at::Tensor& input, const std::vector& others) override; @@ -477,16 +478,28 
@@ class IpexWoqLinearOpContext final : public WoqLinearOpContext { virtual c10::optional get_at_bias() override; + virtual at::Tensor get_scales() override; + + virtual at::Tensor get_zero_points() override; + + virtual std::vector get_weight_shape() override; + virtual at::Tensor pack(const at::Tensor& tensor) override; virtual detail::ContextLinearWoq& get_context() override; static c10::intrusive_ptr create_context( at::Tensor&& weight, + std::vector&& weight_shape, + at::Tensor&& scales_fp32, + at::Tensor&& zp_fp32, c10::optional&& bias, c10::optional batch_size, + bool is_int4, + int64_t group_size, int64_t lowp_mode, - int64_t num_concats); + int64_t num_concats, + int64_t act_quant_mode); virtual void load_from_ctx( c10::intrusive_ptr other) override; diff --git a/csrc/cpu/jit/cpu/kernels/RegisterOpContextClass.cpp b/csrc/cpu/jit/cpu/kernels/RegisterOpContextClass.cpp index 34229af3a..20300577e 100644 --- a/csrc/cpu/jit/cpu/kernels/RegisterOpContextClass.cpp +++ b/csrc/cpu/jit/cpu/kernels/RegisterOpContextClass.cpp @@ -128,16 +128,29 @@ TORCH_LIBRARY(ipex_prepack, m) { [](SerializationTypeWoqLinearPrePack state) -> c10::intrusive_ptr { // __setstate__ return createWoqLinearPrePackOpContext( - std::move(std::get<0>(state)), - std::move(std::get<1>(state)), - std::move(std::get<2>(state)), - std::move(std::get<3>(state)), - std::move(std::get<4>(state))); + std::move(std::get<0>(state)), // weight + std::move(std::get<1>(state)), // weight shape + std::move(std::get<2>(state)), // scales + std::move(std::get<3>(state)), // zero points + std::move(std::get<4>(state)), // bias + std::move(std::get<5>(state)), // batch size + std::move(std::get<6>(state)), // is_int4 + std::move(std::get<7>(state)), // group size + std::move(std::get<8>(state)), // lowp_mode + std::move(std::get<9>(state)), // num_concats + std::move(std::get<10>(state))); // act_quant_mode }) .def( "get_weight", &torch_ipex::cpu::WoqLinearOpContext::get_at_packed_weight) .def("get_bias", &torch_ipex::cpu::WoqLinearOpContext::get_at_bias) + .def("get_scales", &torch_ipex::cpu::WoqLinearOpContext::get_scales) + .def( + "get_zero_points", + &torch_ipex::cpu::WoqLinearOpContext::get_zero_points) + .def( + "get_weight_shape", + &torch_ipex::cpu::WoqLinearOpContext::get_weight_shape) .def("pack", &torch_ipex::cpu::WoqLinearOpContext::pack) .def("to_public", &torch_ipex::cpu::WoqLinearOpContext::to_public) .def( @@ -162,10 +175,10 @@ TORCH_LIBRARY(ipex_prepack, m) { "bool input_is_channels_last, int[] input_sizes) " "-> __torch__.torch.classes.ipex_prepack.ConvTransposeOpContext"); m.def( - "weight_only_qlinear_prepack(Tensor W, Tensor? B, int? batch_size, int lowp_mode, int num_concats) " + "weight_only_qlinear_prepack(Tensor W, int[] W_shape, Tensor scales, Tensor zero_points, Tensor? B, int? batch_size, bool is_int4, int group_size, int lowp_mode, int num_concats, int act_quant_mode) " "-> __torch__.torch.classes.ipex_prepack.WoqLinearOpContext"); m.def( - "weight_only_qlinear_prepack_int4(Tensor W, Tensor scales, Tensor zero_points, Tensor? B, int? batch_size, int lowp_mode, int num_concats) " + "weight_only_qlinear_prepack_int4(Tensor W, Tensor scales, Tensor zero_points, Tensor? B, int? 
batch_size, int group_size, int lowp_mode, int num_concats, int act_quant_mode) " "-> __torch__.torch.classes.ipex_prepack.WoqLinearOpContext"); } @@ -176,7 +189,7 @@ TORCH_LIBRARY_IMPL(ipex_prepack, CPU, m) { m.impl( "conv_transpose_prepack", TORCH_FN(createConvTransposePrePackOpContext)); } -TORCH_LIBRARY_IMPL(ipex_prepack, QuantizedCPU, m) { +TORCH_LIBRARY_IMPL(ipex_prepack, CPU, m) { m.impl( "weight_only_qlinear_prepack", TORCH_FN(createWoqLinearPrePackOpContext)); } diff --git a/examples/cpu/inference/python/llm/README.md b/examples/cpu/inference/python/llm/README.md index 9ba7c57cc..609ec3b54 100644 --- a/examples/cpu/inference/python/llm/README.md +++ b/examples/cpu/inference/python/llm/README.md @@ -174,7 +174,7 @@ Here is how to use it: # Step 1: Generate modified weights and quantization info python utils/run_gptq.py --model --output-dir ./saved_results ``` -It may take a few hours to finish. Modified weights and their quantization info are stored in `gptq_checkpoint.pt`. +It may take a few hours to finish. Modified weights and their quantization info are stored in `gptq_checkpoint_g128.pt`, where g128 means group size for input channel is 128 by default. Then generate model for weight only quantization with INT4 weights and run tasks. ```bash # Step 2: Generate quantized model with INT4 weights diff --git a/examples/cpu/inference/python/llm/single_instance/run_falcon_quantization.py b/examples/cpu/inference/python/llm/single_instance/run_falcon_quantization.py index c27d29baa..bcd98b769 100644 --- a/examples/cpu/inference/python/llm/single_instance/run_falcon_quantization.py +++ b/examples/cpu/inference/python/llm/single_instance/run_falcon_quantization.py @@ -66,6 +66,17 @@ " data type or lowp-mode. If `--low-precision-checkpoint` is given, weight" " data type is always INT4 and this argument is not needed.", ) +parser.add_argument( + "--group-size", + default=-1, + type=int, + help="For weight-only quantization only. Specifies the group size along" + " input channel for block-wise quantization of weight. It must be a" + " positive power of 2 or -1. If it is -1, weight is quantized per" + " output channel. Otherwise, weight is quantized per block with block size" + " = [1, group_size]. If `--low-precision-checkpoint` is given, group" + " size is determined automatically and this argument has no effect.", +) parser.add_argument( "--low-precision-checkpoint", default="", @@ -74,6 +85,20 @@ " modified weights, scales, zero points, etc. For better accuracy of weight only" " quantization with INT4 weight." ) +parser.add_argument( + "--act-quant-mode", + choices=["PER_TENSOR", "PER_IC_BLOCK", "PER_BATCH", "PER_BATCH_IC_BLOCK"], + default="PER_IC_BLOCK", + type=str, + help="Quantization mode for activation with different granularity. " + "For lowp-mode=INT8 only. For other cases, it has no effect. " + "Assume the activation tensor has shape batch_size x input_channel. " + "PER_TENSOR(0): quantize per tensor; " + "PER_IC_BLOCK(1): quantize per group along IC with group size = IC_BLOCK; " + "PER_BATCH(2): quantize per batch; " + "PER_BATCH_IC_BLOCK(3): quantize per block of size 1 x IC_BLOCK. " + "IC_BLOCK is determined by IC automatically." 
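The four `--act-quant-mode` choices above differ only in how the activation tensor is grouped before quantization parameters are computed. The following is an illustrative sketch only, not the kernel implementation: the `IC_BLOCK` value of 64 is a placeholder (the kernel derives IC_BLOCK from IC or the group size), and min/max reduction stands in for whatever statistics the kernel actually uses.

```python
import torch

batch, ic, IC_BLOCK = 4, 4096, 64  # IC_BLOCK is a placeholder; the kernel chooses it automatically
x = torch.randn(batch, ic)
xb = x.view(batch, ic // IC_BLOCK, IC_BLOCK)  # assumes IC is divisible by IC_BLOCK

# PER_TENSOR: one (min, max) pair for the whole tensor
per_tensor = (x.min(), x.max())
# PER_IC_BLOCK: one pair per IC_BLOCK input channels, shared across the batch
per_ic_block = (xb.amin(dim=(0, 2)), xb.amax(dim=(0, 2)))
# PER_BATCH: one pair per row
per_batch = (x.amin(dim=1), x.amax(dim=1))
# PER_BATCH_IC_BLOCK: one pair per 1 x IC_BLOCK block
per_batch_ic_block = (xb.amin(dim=2), xb.amax(dim=2))
```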
+) args = parser.parse_args() @@ -180,8 +205,17 @@ else: lowp_mode = ipex.quantization.WoqLowpMode.BF16 + act_quant_mode_dict = { + "PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR, + "PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK, + "PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH, + "PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK, + } qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( - weight_dtype=weight_dtype, lowp_mode=lowp_mode + weight_dtype=weight_dtype, + lowp_mode=lowp_mode, + act_quant_mode=act_quant_mode_dict[args.act_quant_mode], + group_size=args.group_size ) if args.low_precision_checkpoint != "": low_precision_checkpoint = torch.load(args.low_precision_checkpoint) diff --git a/examples/cpu/inference/python/llm/single_instance/run_gpt-j_quantization.py b/examples/cpu/inference/python/llm/single_instance/run_gpt-j_quantization.py index 9d9127786..d16ac0281 100644 --- a/examples/cpu/inference/python/llm/single_instance/run_gpt-j_quantization.py +++ b/examples/cpu/inference/python/llm/single_instance/run_gpt-j_quantization.py @@ -68,6 +68,17 @@ " data type or lowp-mode. If `--low-precision-checkpoint` is given, weight" " data type is always INT4 and this argument is not needed.", ) +parser.add_argument( + "--group-size", + default=-1, + type=int, + help="For weight-only quantization only. Specifies the group size along" + " input channel for block-wise quantization of weight. It must be a" + " positive power of 2 or -1. If it is -1, weight is quantized per" + " output channel. Otherwise, weight is quantized per block with block size" + " = [1, group_size]. If `--low-precision-checkpoint` is given, group" + " size is determined automatically and this argument has no effect.", +) parser.add_argument( "--low-precision-checkpoint", default="", @@ -76,6 +87,20 @@ " modified weights, scales, zero points, etc. For better accuracy of weight only" " quantization with INT4 weight.", ) +parser.add_argument( + "--act-quant-mode", + choices=["PER_TENSOR", "PER_IC_BLOCK", "PER_BATCH", "PER_BATCH_IC_BLOCK"], + default="PER_IC_BLOCK", + type=str, + help="Quantization mode for activation with different granularity. " + "For lowp-mode=INT8 only. For other cases, it has no effect. " + "Assume the activation tensor has shape batch_size x input_channel. " + "PER_TENSOR(0): quantize per tensor; " + "PER_IC_BLOCK(1): quantize per group along IC with group size = IC_BLOCK; " + "PER_BATCH(2): quantize per batch; " + "PER_BATCH_IC_BLOCK(3): quantize per block of size 1 x IC_BLOCK. " + "IC_BLOCK is determined by IC automatically." 
+) args = parser.parse_args() # amp autocast @@ -170,8 +195,17 @@ else: lowp_mode = ipex.quantization.WoqLowpMode.BF16 + act_quant_mode_dict = { + "PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR, + "PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK, + "PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH, + "PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK, + } qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( - weight_dtype=weight_dtype, lowp_mode=lowp_mode + weight_dtype=weight_dtype, + lowp_mode=lowp_mode, + act_quant_mode=act_quant_mode_dict[args.act_quant_mode], + group_size=args.group_size ) if args.low_precision_checkpoint != "": low_precision_checkpoint = torch.load(args.low_precision_checkpoint) diff --git a/examples/cpu/inference/python/llm/single_instance/run_gpt-neox_quantization.py b/examples/cpu/inference/python/llm/single_instance/run_gpt-neox_quantization.py index 0d8befc7f..1bebafc55 100644 --- a/examples/cpu/inference/python/llm/single_instance/run_gpt-neox_quantization.py +++ b/examples/cpu/inference/python/llm/single_instance/run_gpt-neox_quantization.py @@ -70,6 +70,17 @@ " data type or lowp-mode. If `--low-precision-checkpoint` is given, weight" " data type is always INT4 and this argument is not needed.", ) +parser.add_argument( + "--group-size", + default=-1, + type=int, + help="For weight-only quantization only. Specifies the group size along" + " input channel for block-wise quantization of weight. It must be a" + " positive power of 2 or -1. If it is -1, weight is quantized per" + " output channel. Otherwise, weight is quantized per block with block size" + " = [1, group_size]. If `--low-precision-checkpoint` is given, group" + " size is determined automatically and this argument has no effect.", +) parser.add_argument( "--low-precision-checkpoint", default="", @@ -78,6 +89,20 @@ " modified weights, scales, zero points, etc. For better accuracy of weight only" " quantization with INT4 weight." ) +parser.add_argument( + "--act-quant-mode", + choices=["PER_TENSOR", "PER_IC_BLOCK", "PER_BATCH", "PER_BATCH_IC_BLOCK"], + default="PER_IC_BLOCK", + type=str, + help="Quantization mode for activation with different granularity. " + "For lowp-mode=INT8 only. For other cases, it has no effect. " + "Assume the activation tensor has shape batch_size x input_channel. " + "PER_TENSOR(0): quantize per tensor; " + "PER_IC_BLOCK(1): quantize per group along IC with group size = IC_BLOCK; " + "PER_BATCH(2): quantize per batch; " + "PER_BATCH_IC_BLOCK(3): quantize per block of size 1 x IC_BLOCK. " + "IC_BLOCK is determined by IC automatically." 
+) args = parser.parse_args() @@ -168,8 +193,17 @@ else: lowp_mode = ipex.quantization.WoqLowpMode.BF16 + act_quant_mode_dict = { + "PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR, + "PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK, + "PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH, + "PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK, + } qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( - weight_dtype=weight_dtype, lowp_mode=lowp_mode + weight_dtype=weight_dtype, + lowp_mode=lowp_mode, + act_quant_mode=act_quant_mode_dict[args.act_quant_mode], + group_size=args.group_size ) if args.low_precision_checkpoint != "": low_precision_checkpoint = torch.load(args.low_precision_checkpoint) diff --git a/examples/cpu/inference/python/llm/single_instance/run_llama_quantization.py b/examples/cpu/inference/python/llm/single_instance/run_llama_quantization.py index 728d4f954..0ac1c1c9d 100644 --- a/examples/cpu/inference/python/llm/single_instance/run_llama_quantization.py +++ b/examples/cpu/inference/python/llm/single_instance/run_llama_quantization.py @@ -69,6 +69,17 @@ " data type or lowp-mode. If `--low-precision-checkpoint` is given, weight" " data type is always INT4 and this argument is not needed.", ) +parser.add_argument( + "--group-size", + default=-1, + type=int, + help="For weight-only quantization only. Specifies the group size along" + " input channel for block-wise quantization of weight. It must be a" + " positive power of 2 or -1. If it is -1, weight is quantized per" + " output channel. Otherwise, weight is quantized per block with block size" + " = [1, group_size]. If `--low-precision-checkpoint` is given, group" + " size is determined automatically and this argument has no effect.", +) parser.add_argument( "--low-precision-checkpoint", default="", @@ -77,6 +88,20 @@ " modified weights, scales, zero points, etc. For better accuracy of weight only" " quantization with INT4 weight.", ) +parser.add_argument( + "--act-quant-mode", + choices=["PER_TENSOR", "PER_IC_BLOCK", "PER_BATCH", "PER_BATCH_IC_BLOCK"], + default="PER_IC_BLOCK", + type=str, + help="Quantization mode for activation with different granularity. " + "For lowp-mode=INT8 only. For other cases, it has no effect. " + "Assume the activation tensor has shape batch_size x input_channel. " + "PER_TENSOR(0): quantize per tensor; " + "PER_IC_BLOCK(1): quantize per group along IC with group size = IC_BLOCK; " + "PER_BATCH(2): quantize per batch; " + "PER_BATCH_IC_BLOCK(3): quantize per block of size 1 x IC_BLOCK. " + "IC_BLOCK is determined by IC automatically." 
+) args = parser.parse_args() @@ -290,8 +315,17 @@ def calib_func(prepared_model): else: lowp_mode = ipex.quantization.WoqLowpMode.BF16 + act_quant_mode_dict = { + "PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR, + "PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK, + "PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH, + "PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK, + } qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( - weight_dtype=weight_dtype, lowp_mode=lowp_mode + weight_dtype=weight_dtype, + lowp_mode=lowp_mode, + act_quant_mode=act_quant_mode_dict[args.act_quant_mode], + group_size=args.group_size ) if args.low_precision_checkpoint != "": low_precision_checkpoint = torch.load(args.low_precision_checkpoint) diff --git a/examples/cpu/inference/python/llm/single_instance/run_opt_quantization.py b/examples/cpu/inference/python/llm/single_instance/run_opt_quantization.py index 71d2ec423..db860b394 100644 --- a/examples/cpu/inference/python/llm/single_instance/run_opt_quantization.py +++ b/examples/cpu/inference/python/llm/single_instance/run_opt_quantization.py @@ -63,6 +63,17 @@ " data type or lowp-mode. If `--low-precision-checkpoint` is given, weight" " data type is always INT4 and this argument is not needed.", ) +parser.add_argument( + "--group-size", + default=-1, + type=int, + help="For weight-only quantization only. Specifies the group size along" + " input channel for block-wise quantization of weight. It must be a" + " positive power of 2 or -1. If it is -1, weight is quantized per" + " output channel. Otherwise, weight is quantized per block with block size" + " = [1, group_size]. If `--low-precision-checkpoint` is given, group" + " size is determined automatically and this argument has no effect.", +) parser.add_argument( "--low-precision-checkpoint", default="", @@ -71,6 +82,20 @@ " modified weights, scales, zero points, etc. For better accuracy of weight only" " quantization with INT4 weight." ) +parser.add_argument( + "--act-quant-mode", + choices=["PER_TENSOR", "PER_IC_BLOCK", "PER_BATCH", "PER_BATCH_IC_BLOCK"], + default="PER_IC_BLOCK", + type=str, + help="Quantization mode for activation with different granularity. " + "For lowp-mode=INT8 only. For other cases, it has no effect. " + "Assume the activation tensor has shape batch_size x input_channel. " + "PER_TENSOR(0): quantize per tensor; " + "PER_IC_BLOCK(1): quantize per group along IC with group size = IC_BLOCK; " + "PER_BATCH(2): quantize per batch; " + "PER_BATCH_IC_BLOCK(3): quantize per block of size 1 x IC_BLOCK. " + "IC_BLOCK is determined by IC automatically." 
+) args = parser.parse_args() @@ -165,8 +190,17 @@ else: lowp_mode = ipex.quantization.WoqLowpMode.BF16 + act_quant_mode_dict = { + "PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR, + "PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK, + "PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH, + "PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK, + } qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( - weight_dtype=weight_dtype, lowp_mode=lowp_mode + weight_dtype=weight_dtype, + lowp_mode=lowp_mode, + act_quant_mode=act_quant_mode_dict[args.act_quant_mode], + group_size=args.group_size ) if args.low_precision_checkpoint != "": low_precision_checkpoint = torch.load(args.low_precision_checkpoint) diff --git a/examples/cpu/inference/python/llm/utils/run_gptq.py b/examples/cpu/inference/python/llm/utils/run_gptq.py index 5f871927f..1a9959e8b 100644 --- a/examples/cpu/inference/python/llm/utils/run_gptq.py +++ b/examples/cpu/inference/python/llm/utils/run_gptq.py @@ -21,6 +21,7 @@ ) parser.add_argument("--dataset", nargs="?", default="lambada", const="lambada") parser.add_argument("--output-dir", nargs="?", default="./saved_results") +parser.add_argument("--group-size", default=128, type=int) parser.add_argument("--calib-iters", default=512, type=int, help="calibration iters.") args = parser.parse_args() @@ -176,7 +177,7 @@ def calib_func(prepared_model): '.*': { # re.match "weight": { 'bits': 4, # only support 4-bit for now - 'group_size': -1, # only support per-channel for now + 'group_size': args.group_size, 'scheme': 'asym', # only support asym for now 'algorithm': 'GPTQ', # RTN/AWQ/TEQ }, @@ -218,5 +219,6 @@ def calib_func(prepared_model): scale_dtype=torch.float16, ) Path(args.output_dir).mkdir(parents=True, exist_ok=True) -torch.save(compressed_model.state_dict(), args.output_dir + "/gptq_checkpoint.pt") -print('\n Checkpoint saved to', args.output_dir + "/gptq_checkpoint.pt \n") +output_file_name = f"gptq_checkpoint_g{args.group_size}.pt" +torch.save(compressed_model.state_dict(), args.output_dir + "/" + output_file_name) +print('\n Checkpoint saved to', args.output_dir + "/" + output_file_name + "\n") diff --git a/intel_extension_for_pytorch/nn/modules/weight_only_quantization.py b/intel_extension_for_pytorch/nn/modules/weight_only_quantization.py index 1b1745893..899e2e34a 100644 --- a/intel_extension_for_pytorch/nn/modules/weight_only_quantization.py +++ b/intel_extension_for_pytorch/nn/modules/weight_only_quantization.py @@ -1,41 +1,15 @@ import torch from torch import nn -from torch.ao.nn.quantized.modules.utils import _clamp_weights -from ...quantization._qconfig import get_weight_only_quant_qconfig_mapping from intel_extension_for_pytorch.nn.utils._weight_prepack import ( may_import_deepspeed_modules, _all_reduce_and_bias_add, _pre_ipex_gemm, ) - - -# Port from PyTorch with a few changes -def _quantize_weight(float_wt, observer): - wt_scale, wt_zp = observer.calculate_qparams() - dtype = observer.dtype - if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]: - qweight = torch.quantize_per_tensor( - float_wt, float(wt_scale), int(wt_zp), dtype - ) - qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp) - elif observer.qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]: - wt_axis = observer.ch_axis - qweight = torch.quantize_per_channel( - float_wt, wt_scale.to(torch.double), wt_zp.to(torch.int64), wt_axis, dtype - ) - qweight = _clamp_weights(qweight, observer, wt_scale, 
wt_zp) - elif observer.qscheme in [torch.per_channel_affine_float_qparams]: - qweight = torch.quantize_per_channel( - float_wt, - wt_scale.to(torch.float), - wt_zp.to(torch.float), - observer.ch_axis, - dtype, - ) - qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp) - else: - raise ValueError("Unexpected qscheme " + observer.qscheme) - return qweight +from intel_extension_for_pytorch.quantization import ( + QConfigWoq, + quantize_per_channel, + quantize_per_block, +) class IpexWoqLinear(nn.Module): @@ -55,6 +29,8 @@ def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8): self._op_context = None self._lowp_mode = 0 self._num_concats = 1 + self._act_quant_mode = 0 + self._group_size = -1 def pre_ipex_gemm(self, input): return input @@ -79,15 +55,20 @@ def extra_repr(self): extra_repr_str += ", bias={}".format(self.bias) extra_repr_str += ", lowp_mode={}".format(self._lowp_mode) extra_repr_str += ", num_concats={}".format(self._num_concats) + extra_repr_str += ", act_quant_mode={}".format(self._act_quant_mode) + extra_repr_str += ", group_size={}".format(self._group_size) return extra_repr_str @classmethod - def from_float(cls, mod): + def from_float(cls, mod, scales=None, zero_points=None): r"""Create a weight-only quantized module from a float module or qparams_dict Args: - mod (Module): a float module, either produced by torch.ao.quantization - utilities or provided by the user + mod (Module): an instance of nn.Linear or its subclasses. + scales: the scales Tensor for quantizing weight. If it is None, + scales are found by min/max of the weight. + zero_points: the zero points Tensor for quantizing weight. If it is None, + zero points are found by min/max of the weight. """ float_modules = [torch.nn.Linear] deepspeed_modules = may_import_deepspeed_modules() @@ -102,47 +83,59 @@ def from_float(cls, mod): + f" or their subclasses, but found {type(mod)}" ) assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" - lowp_mode = 0 - if mod.qconfig is not None and mod.qconfig.weight is not None: - weight_observer = mod.qconfig.weight() - if hasattr(mod.qconfig, "lowp_mode"): - lowp_mode = mod.qconfig.lowp_mode - if mod.qconfig.lowp_mode == 3 and weight_observer.dtype == torch.qint8: - # lowp_mode=3 (INT8) is not yet supported for INT8 weight - # Fall back to lowp_mode=2 in this case - # TODO(Weiwen) Support lowp_mode=3 - lowp_mode = 2 - print( - "Warning: lowp_mode=3(INT8) is not supported yet in this case. " - "Falling back to 2(BF16)." - ) - else: - weight_observer = ( - get_weight_only_quant_qconfig_mapping().global_qconfig.weight() + qconfig = mod.qconfig + if qconfig is None or not isinstance(qconfig, QConfigWoq): + return mod + + lowp_mode = qconfig.lowp_mode + if qconfig.lowp_mode == 3 and qconfig.weight_dtype != torch.quint4x2: + # lowp_mode=3 (INT8) is enabled for INT4 weight only + # Fall back to lowp_mode=2 in other case + # TODO(Weiwen) Support lowp_mode=3 + lowp_mode = 2 + print( + "Warning: lowp_mode=3(INT8) is not supported yet in this case. " + "Falling back to 2(BF16)." 
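A minimal sketch of driving the updated `from_float` path from user code. The import path is taken from the file above (the class may also be re-exported elsewhere), and the layer size, group size, and lowp_mode are arbitrary illustrative values.

```python
import torch
from torch import nn
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.nn.modules.weight_only_quantization import IpexWoqLinear

linear = nn.Linear(4096, 4096)
# Attach a WOQ qconfig; from_float returns the module unchanged if qconfig is not a QConfigWoq
qconfig_mapping = ipex.quantization.get_weight_only_quant_qconfig_mapping(
    weight_dtype=torch.quint4x2,
    lowp_mode=ipex.quantization.WoqLowpMode.BF16,
    group_size=128,
)
linear.qconfig = qconfig_mapping.global_qconfig
# scales/zero_points are optional; when omitted they are derived from the weight's min/max
woq_linear = IpexWoqLinear.from_float(linear)
```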
) + act_quant_mode = qconfig.act_quant_mode num_concats = 1 if hasattr(mod, "_num_concats"): num_concats = mod._num_concats - dtype = weight_observer.dtype - assert dtype in [torch.quint8, torch.qint8, torch.quint4x2], ( - "The only supported dtypes for " - "weight-only quantized linear are quint8, qint8 and quint4x2 got: {}".format( - dtype + dtype = qconfig.weight_dtype + is_int4 = dtype == torch.quint4x2 + group_size = qconfig.group_size + + if group_size == -1: + qweight, scales, zero_points = quantize_per_channel( + mod.weight, is_int4, scales, zero_points + ) + else: + qweight, scales, zero_points = quantize_per_block( + mod.weight, is_int4, group_size, scales, zero_points ) - ) - weight_observer(mod.weight) - qweight = _quantize_weight(mod.weight.float(), weight_observer) if not hasattr(mod, "in_features"): mod.in_features = mod.weight.size()[1] if not hasattr(mod, "out_features"): mod.out_features = mod.weight.size()[0] - qlinear = cls._init_cls(mod, dtype, qweight, lowp_mode, num_concats) + qlinear = cls._init_cls( + mod, + dtype, + qweight, + scales, + zero_points, + group_size, + lowp_mode, + num_concats, + act_quant_mode, + ) del qweight return qlinear @classmethod - def from_float_and_int4_weight(cls, mod, qweight, scales, zero_points, bias=None): + def from_float_and_int4_weight( + cls, mod, qweight, scales, zero_points, bias=None, group_size=-1 + ): r"""Create a weight-only quantized module from a float module and int4 weight Args: @@ -168,8 +161,12 @@ def from_float_and_int4_weight(cls, mod, qweight, scales, zero_points, bias=None assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" lowp_mode = 0 - if mod.qconfig is not None and hasattr(mod.qconfig, "lowp_mode"): - lowp_mode = mod.qconfig.lowp_mode + act_quant_mode = 0 + if mod.qconfig is not None: + if hasattr(mod.qconfig, "lowp_mode"): + lowp_mode = mod.qconfig.lowp_mode + if hasattr(mod.qconfig, "act_quant_mode"): + act_quant_mode = mod.qconfig.act_quant_mode num_concats = 1 if hasattr(mod, "_num_concats"): num_concats = mod._num_concats @@ -192,23 +189,57 @@ def from_float_and_int4_weight(cls, mod, qweight, scales, zero_points, bias=None if bias is None: bias = mod.bias qlinear._op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack_int4( - qweight, scales, zero_points, bias, None, int(lowp_mode), num_concats + qweight, + scales, + zero_points, + bias, + None, + group_size, + int(lowp_mode), + num_concats, + act_quant_mode, ) qlinear._lowp_mode = lowp_mode qlinear._num_concats = num_concats + qlinear._act_quant_mode = act_quant_mode + qlinear._group_size = group_size del qweight return qlinear @classmethod - def _init_cls(cls, mod, dtype, qweight, lowp_mode, num_concats): + def _init_cls( + cls, + mod, + dtype, + qweight, + scales, + zero_points, + group_size, + lowp_mode, + num_concats, + act_quant_mode, + ): qlinear = cls( mod.in_features, mod.out_features, mod.bias is not None, dtype=dtype ) + is_int4 = dtype == torch.quint4x2 qlinear._op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack( - qweight, mod.bias, None, int(lowp_mode), num_concats + qweight, + [mod.out_features, mod.in_features], + scales, + zero_points, + mod.bias, + None, + is_int4, + group_size, + int(lowp_mode), + num_concats, + act_quant_mode, ) qlinear._lowp_mode = lowp_mode qlinear._num_concats = num_concats + qlinear._act_quant_mode = act_quant_mode + qlinear._group_size = group_size return qlinear @@ -239,15 +270,33 @@ def _init_from_mod(cls, mod, dtype): ) @classmethod - def _init_cls(cls, mod, 
dtype, qweight, lowp_mode, num_concats): + def _init_cls( + cls, + mod, + dtype, + qweight, + scales, + zero_points, + group_size, + lowp_mode, + num_concats, + act_quant_mode, + ): qlinear = cls._init_from_mod(mod, dtype) + is_int4 = dtype == torch.quint4x2 qlinear._op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack( qweight, + [mod.out_features, mod.in_features], + scales, + zero_points, None, # Set bias to None when prepacking. Please refer to the comment in __init__ of _IPEXLinearAllreduce None, # batch_size + is_int4, + group_size, lowp_mode, num_concats, + act_quant_mode, ) return qlinear diff --git a/intel_extension_for_pytorch/quantization/__init__.py b/intel_extension_for_pytorch/quantization/__init__.py index 6b8829645..7145c392d 100644 --- a/intel_extension_for_pytorch/quantization/__init__.py +++ b/intel_extension_for_pytorch/quantization/__init__.py @@ -7,5 +7,13 @@ get_smooth_quant_qconfig_mapping, get_weight_only_quant_qconfig_mapping, WoqLowpMode, + WoqActQuantMode, + QConfigWoq, ) from ._autotune import autotune +from ._quantize_utils import ( + quantize_per_channel, + dequantize_per_channel, + quantize_per_block, + dequantize_per_block, +) diff --git a/intel_extension_for_pytorch/quantization/_qconfig.py b/intel_extension_for_pytorch/quantization/_qconfig.py index bcfb58cdc..dbca1b49d 100644 --- a/intel_extension_for_pytorch/quantization/_qconfig.py +++ b/intel_extension_for_pytorch/quantization/_qconfig.py @@ -8,7 +8,10 @@ QConfig, QConfigMapping, ) -from ._smooth_quant import SmoothQuantActivationObserver, SmoothQuantWeightObserver +from ._smooth_quant import ( + SmoothQuantActivationObserver, + SmoothQuantWeightObserver, +) _default_weight_observer = PerChannelMinMaxObserver.with_args( @@ -36,12 +39,19 @@ default_dynamic_qconfig_mapping = QConfigMapping().set_global(default_dynamic_qconfig) +# Define QConfig for SmoothQuant by extending PyTorch's QConfig +QConfigSmoothQuant = namedtuple( + "QConfigSmoothQuant", [*QConfig._fields, "share_weight_observers"] +) + + def get_smooth_quant_qconfig_mapping( alpha=0.5, act_observer=None, act_ic_observer=None, wei_observer=None, wei_ic_observer=None, + share_weight_observers=True, ): """ Configuration with SmoothQuant for static quantization of large language models (LLM) @@ -53,21 +63,27 @@ def get_smooth_quant_qconfig_mapping( HistogramObserver by default. For nn.Linear with SmoothQuant enabled, q-param is calculated based on act_ic_observer's and wei_ic_observer's min/max. It is not affected by this argument. + Example: ``torch.ao.quantization.MinMaxObserver`` act_ic_observer: Per-input-channel Observer for activation. For nn.Linear with SmoothQuant enabled only. PerChannelMinMaxObserver by default. + Example: ``torch.ao.quantization.PerChannelMinMaxObserver.with_args(ch_axis=1)`` wei_observer: Observer for weight of all weighted ops. For nn.Linear with SmoothQuant enabled, it calculates q-params after applying scaling factors. PerChannelMinMaxObserver by default. + Example: ``torch.ao.quantization.PerChannelMinMaxObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_channel_symmetric + )`` wei_ic_observer: Per-input-channel Observer for weight. For nn.Linear with SmoothQuant enabled only. PerChannelMinMaxObserver by default. 
+ Example: ``torch.ao.quantization.PerChannelMinMaxObserver.with_args(ch_axis=1)`` Returns: torch.ao.quantization.QConfig """ - qconfig = QConfig( + qconfig = QConfigSmoothQuant( activation=SmoothQuantActivationObserver.with_args( reduce_range=False, alpha=alpha, @@ -81,6 +97,7 @@ def get_smooth_quant_qconfig_mapping( wei_observer=wei_observer, wei_ic_observer=wei_ic_observer, ), + share_weight_observers=share_weight_observers, ) return QConfigMapping().set_global(qconfig) @@ -93,15 +110,69 @@ class WoqLowpMode(IntEnum): INT8 = 3 -QConfigWoq = namedtuple("QConfigWoq", [*QConfig._fields, "lowp_mode"]) +class WoqActQuantMode(IntEnum): + NONE = -1 + PER_TENSOR = 0 + PER_IC_BLOCK = 1 # IC = Input Channel + PER_BATCH = 2 + PER_BATCH_IC_BLOCK = 3 + + +QConfigWoq = namedtuple( + "QConfigWoq", + [*QConfig._fields, "lowp_mode", "act_quant_mode", "weight_dtype", "group_size"], +) def get_weight_only_quant_qconfig_mapping( - *, weight_dtype: torch.dtype = torch.qint8, lowp_mode: int = WoqLowpMode.NONE + *, + weight_dtype: torch.dtype = torch.qint8, + lowp_mode: int = WoqLowpMode.NONE, + act_quant_mode: int = WoqActQuantMode.PER_IC_BLOCK, + group_size: int = -1 ): + """ + Configuration for weight-only quantization (WOQ) for LLM. + Arguments: + weight_dtype: Data type for weight, torch.qint8 (INT8) or torch.quint4x2 (INT4) + lowp_mode: Specify the lowest precision data type for computation. Data types + that have even lower precision won't be used. + Not necessarily related to activation or weight dtype. + - NONE(0): Use the activation data type for computation. + - FP16(1): Use float16 (a.k.a. half) as the lowest precision for computation. + - BF16(2): Use bfloat16 as the lowest precision for computation. + - INT8(3): Use INT8 as the lowest precision for computation. + Activation is quantized to int8 at runtime in this case. + Note that lowp_mode=INT8(3) is only available when weight_dtype=torch.quint4x2. + In other cases, it will fall back to lowp_mode=BF16(2). + act_quant_mode: Quantization granularity of activation. It only works for lowp_mode=INT8. + It has no effect in other cases. The tensor is divided into groups, and + each group is quantized with its own quantization parameters. + Suppose the activation has shape batch_size by input_channel (IC). + - PER_TENSOR(0): Use the same quantization parameters for the entire tensor. + - PER_IC_BLOCK(1): Tensor is divided along IC with group size = IC_BLOCK. + - PER_BATCH(2): Tensor is divided along batch_size with group size = 1. + - PER_BATCH_IC_BLOCK(3): Tensor is divided into blocks of 1 x IC_BLOCK. + Note that IC_BLOCK is determined by group_size automatically. + group_size: Controls quantization granularity along the input channel (IC) dimension of weight. + Must be a positive power of 2 (i.e., 2^k, k > 0) or -1. + If group_size = -1: + If act_quant_mode = PER_TENSOR or PER_BATCH: + No grouping along IC for both activation and weight + If act_quant_mode = PER_IC_BLOCK or PER_BATCH_IC_BLOCK: + No grouping along IC for weight. For activation, + IC_BLOCK is determined automatically by IC. + If group_size > 0: + act_quant_mode can be any of the modes above. If act_quant_mode is PER_IC_BLOCK + or PER_BATCH_IC_BLOCK, weight is grouped along IC by group_size. + The IC_BLOCK for activation is determined by group_size automatically. + Each group has its own quantization parameters. 
+ """ + assert group_size == -1 or ( + group_size > 0 and (group_size & (group_size - 1)) == 0 + ), "Group size must be -1 or a positive power of 2, but got {}".format(group_size) dtype_to_qscheme = { torch.qint8: torch.per_channel_affine, - torch.quint8: torch.per_channel_affine, # It is required to use per_channel_affine_float_qparams for quint4x2 by PyTorch torch.quint4x2: torch.per_channel_affine_float_qparams, } @@ -112,6 +183,9 @@ def get_weight_only_quant_qconfig_mapping( dtype=weight_dtype, qscheme=weight_qscheme ), lowp_mode=lowp_mode, + act_quant_mode=act_quant_mode, + weight_dtype=weight_dtype, + group_size=group_size, ) weight_only_quant_qconfig_mapping = QConfigMapping().set_global( _weight_only_quant_qconfig diff --git a/intel_extension_for_pytorch/quantization/_quantization_state.py b/intel_extension_for_pytorch/quantization/_quantization_state.py index 9dcd8c3ec..3dfd60608 100644 --- a/intel_extension_for_pytorch/quantization/_quantization_state.py +++ b/intel_extension_for_pytorch/quantization/_quantization_state.py @@ -391,7 +391,27 @@ def op_convert_before_hook( act_key = str(self.idx) if act_key in self.idx_to_smooth_quant_scaling_factor: act_scaling_factors = self.idx_to_smooth_quant_scaling_factor[act_key] - if act_scaling_factors is not None: + # if users modifies qconf.json and cancals quantization of the linear, + # then any_arg_quant_or_dequant_needed[0] is False. Don't insert mul in this case. + if act_scaling_factors is not None and any_arg_quant_or_dequant_needed[0]: + w_key = str(self.idx) + "_0" + act_scaling_factors = ( + act_scaling_factors[w_key] + if len(act_scaling_factors) > 1 + else next(iter(act_scaling_factors.values())) + ) + # update arg_quant_infos + scale = ( + arg_quant_infos[0][0][w_key] + if len(arg_quant_infos[0][0]) > 1 + else next(iter(arg_quant_infos[0][0].values())) + ) + zp = ( + arg_quant_infos[0][1][w_key] + if len(arg_quant_infos[0][1]) > 1 + else next(iter(arg_quant_infos[0][1].values())) + ) + arg_quant_infos = [(scale, zp, arg_quant_infos[0][2])] args = list(args) new_act = torch.mul(args[0], act_scaling_factors) args[0] = new_act @@ -1050,19 +1070,29 @@ def _maybe_insert_input_observers(self, seen_q_op_info: SeenQOpInfo): str(seen_q_op_info.idx) + "_" + str(w_tensor_id) ] # Duplicate input: - # (1) In modules like MHA, multiple linear layers may share the same activation tensor - # In other words, multiple weight tensors share one activation tensor - # In this case, we regard these weights as a single big tensor (i.e., concat along OC axis). - # When calculating scaling factor, consider per-IC min/max of the big tensor - # So, these weights share the same per-IC observer + # (1) In some cases, multiple linear layers share the same activation (like QKV). + # - If qconfig specifies share_weight_observers=True (default), we regard these + # weights as a single big tensor (i.e., concat along OC axis) during + # calibration. So, these weights share the same per-IC observer. + # But weights are not actually concated for computation. + # - If qconfig specifies share_weight_observers=False, they use different observers. # (2) It is also possible that linear shares activation with some non-weighted op. # In that case, x_obs.weight_obs is not set. Also check it here. 
+ w_id_str = str(seen_q_op_info.idx) + "_" + str(w_tensor_id) if not found_duplicate_input or x_obs.weight_obs is None: - x_obs.weight_obs = w_obs.ic_obs + x_obs.weight_obs = {w_id_str: w_obs.ic_obs} else: - # The input (activation) has been used by other linear ops - # Weight should share the same per-IC observer with that linear - w_obs.ic_obs = x_obs.weight_obs + # The input (activation) is shared by more than one linear layers + if getattr(qconfig, "share_weight_observers", True): + # Weights of these layers share the same per-IC observer + assert ( + isinstance(x_obs.weight_obs, dict) + and len(x_obs.weight_obs) == 1 + ) + w_obs.ic_obs = next(iter(x_obs.weight_obs.values())) + else: + # Weights of these layers use different observers + x_obs.weight_obs.update({w_id_str: w_obs.ic_obs}) # In all cases, weight observer holds a reference to activation's per-IC observer w_obs.act_obs = x_obs.ic_obs # For all linear ops, set smooth_quant_enabled to true diff --git a/intel_extension_for_pytorch/quantization/_quantize_utils.py b/intel_extension_for_pytorch/quantization/_quantize_utils.py index 71a2938f7..00e5a8f36 100644 --- a/intel_extension_for_pytorch/quantization/_quantize_utils.py +++ b/intel_extension_for_pytorch/quantization/_quantize_utils.py @@ -783,3 +783,241 @@ def unwrap_proxy(a): swap_child_modules(module) module.__class__ = QuantizationDispatchModule return module + + +def quantize_per_channel(t: torch.Tensor, is_int4, scales=None, zero_points=None): + r""" + Quantize a weight tensor of Linear modules per channel. + Assume the tensor shape is [output channel, input channel], + each output channel has its own quantization parameters. + + Args: + input: The tensor to be quantized + is_int4: int4 or int8 + + Returns: + A tuple of + - The quantized tensor + - Scales + - Zero points + """ + assert t.ndim == 2 + + def get_qparams(scales, zps): + if scales is not None and zps is not None: + return scales, zps + eps = torch.tensor([torch.finfo(torch.float32).eps]) + zeros = torch.zeros(t.shape[0], dtype=t.dtype, device=t.device) + mins = torch.minimum(t.min(dim=1)[0], zeros) + maxs = torch.maximum(t.max(dim=1)[0], zeros) + scales = (maxs - mins) / 15 if is_int4 else (maxs - mins) / 255 + scales = torch.max(scales, eps) + zps = -torch.round(mins / scales) + if not is_int4: + zps -= 128 + return scales, zps + + scales, zps = get_qparams(scales, zero_points) + qmin = 0 if is_int4 else -127 + qmax = 15 if is_int4 else 127 + qt = torch.clamp( + torch.round(t / scales.unsqueeze(1)) + zps.unsqueeze(1), min=qmin, max=qmax + ) + qt = qt.to(torch.uint8) if is_int4 else qt.to(torch.int8) + if is_int4: + if qt.size(-1) % 2: + qt = torch.nn.functional.pad(qt, (0, 1), value=0) + qt = qt[:, 1::2].bitwise_left_shift(4).bitwise_or_(qt[:, ::2]) + return qt.contiguous(), scales, zps + + +def dequantize_per_channel( + qt: torch.Tensor, + scales: torch.Tensor, + zps: torch.Tensor, + is_int4, + weight_shape=None, +): + r""" + Dequantize a weight tensor of Linear modules per channel. + Assume the tensor shape is [output channel, input channel], + each output channel has its own quantization parameters. + + Args: + qt: The tensor to be dequantized + scales: Scales for dequantization + zps: Zero points for dequantization + is_int4: int4 or int8 + weight_shape: True weight shape. INT4 tensor's input channel may + be padded to even, so we need this to return the correct weight. 
+ + Returns: + The dequantized tensor + """ + assert qt.ndim == 2 + scales = scales.squeeze() + zps = zps.squeeze() + if is_int4: + t = torch.empty( + qt.shape[0], qt.shape[1] * 2, dtype=torch.uint8, device=qt.device + ) + t[:, ::2] = qt.bitwise_and(0xF) + t[:, 1::2] = qt.bitwise_right_shift(4) + t = (t.to(torch.float) - zps.unsqueeze(-1)) * scales.unsqueeze(-1) + if weight_shape is not None: + t = t[: weight_shape[0], : weight_shape[1]].contiguous() + return t + else: + return (qt.to(torch.float) - zps.unsqueeze(-1)) * scales.unsqueeze(-1) + + +def quantize_per_block( + input: torch.Tensor, is_int4, group_size, scales=None, zero_points=None +): + r""" + Quantize a weight tensor of Linear modules per block. + Assume the tensor shape is [output channel, input channel], + block shape is [1, group_size]. + + Args: + input: The tensor to be quantized + is_int4: int4 or int8 + group_size: Size of group along input channel + scales: Scales for quantization. If None, find by min/max. + zero_points: zero points for quantization. If None, find by min/max. + + Returns: + A tuple of + - The quantized tensor + - Scales in shape [N, #block_k] + - Zero points in shape [N, #block_k] + """ + assert ( + input.dim() == 2 + ), f"{__name__}: Expect input has 2 dimensions but got {input.dim()}" + assert group_size > 0, f"{__name__}: Expect group_size > 0 but got {group_size}" + N = input.size(0) + K = input.size(1) + k_rem = K % group_size + has_rem = k_rem != 0 + + def get_qparams(scales, zps): + if scales is not None and zps is not None: + return scales, zps + eps = torch.tensor([torch.finfo(torch.float32).eps]) + t_com = input[:, : K - k_rem].view(N, K // group_size, group_size) + mins = torch.minimum(t_com.min(dim=-1)[0], torch.tensor([0])) + maxs = torch.maximum(t_com.max(dim=-1)[0], torch.tensor([0])) + scales = (maxs - mins) / 15 if is_int4 else (maxs - mins) / 255 + scales = torch.max(scales, eps) + zps = -torch.round(mins / scales) + if k_rem != 0: + t_rem = input[:, K - k_rem :].view(N, 1, k_rem) + mins_rem = torch.minimum(t_rem.min(dim=-1)[0], torch.tensor([0])) + maxs_rem = torch.maximum(t_rem.max(dim=-1)[0], torch.tensor([0])) + scales_rem = ( + (maxs_rem - mins_rem) / 15 if is_int4 else (maxs_rem - mins_rem) / 255 + ) + zps_rem = -torch.round(mins_rem / scales_rem) + scales = torch.cat([scales, scales_rem], dim=-1) + zps = torch.cat([zps, zps_rem], dim=-1) + if not is_int4: + zps -= 128 + return scales, zps + + scales, zps = get_qparams(scales, zero_points) + qmin = 0 if is_int4 else -127 + qmax = 15 if is_int4 else 127 + Kc = (K + group_size - 1) // group_size + t_com = input[:, : K - k_rem].view(N, K // group_size, group_size) + scales_com = scales[:, : Kc - has_rem] + zps_com = zps[:, : Kc - has_rem] + qt = torch.clamp( + torch.round(t_com / scales_com.unsqueeze(-1)) + zps_com.unsqueeze(-1), + min=qmin, + max=qmax, + ) + qt = qt.view(N, K // group_size * group_size) + if k_rem != 0: + t_rem = input[:, K - k_rem :].view(N, 1, k_rem) + scales_rem = scales[:, Kc - has_rem :] + zps_rem = zps[:, Kc - has_rem :] + qt_rem = torch.clamp( + torch.round(t_rem / scales_rem.unsqueeze(-1)) + zps_rem.unsqueeze(-1), + min=qmin, + max=qmax, + ) + qt_rem = qt_rem.view(N, k_rem) + qt = torch.cat([qt, qt_rem], dim=1).contiguous() + qt = qt.to(torch.uint8) if is_int4 else qt.to(torch.int8) + qt = qt.view(N, K) + if is_int4: + if qt.size(-1) % 2: + qt = torch.nn.functional.pad(qt, (0, 1), value=0) + qt = qt[:, 1::2].bitwise_left_shift(4).bitwise_or_(qt[:, ::2]) + return qt.contiguous(), scales, zps + + +def 
dequantize_per_block( + qt: torch.Tensor, + scales: torch.Tensor, + zps: torch.Tensor, + is_int4, + group_size, + weight_shape=None, +): + r""" + Dequantize a weight tensor of Linear modules per block. + Assume the tensor shape is [output channel, input channel], + block shape is [1, group_size]. + + Args: + qt: The tensor to be dequantized + scales: Scales in shape [N, #block_k] + zps: Zero points in shape [N, #block_k] + is_int4: int4 or int8 + group_size: Size of group along input channel + weight_shape: True weight shape. The quantized tensor may be padded along + output/input channels, so we need this to return the correct weight + + Returns: + The dequantized tensor + """ + N = qt.size(0) + K = qt.size(1) * 2 if is_int4 else qt.size(1) + if scales.dim() > 2: + scales = scales.squeeze() + zps = zps.squeeze() + if is_int4: + t = torch.empty( + qt.shape[0], qt.shape[1] * 2, dtype=torch.uint8, device=qt.device + ) + t[:, ::2] = qt.bitwise_and(0xF) + t[:, 1::2] = qt.bitwise_right_shift(4) + qt = t + k_rem = K % group_size + has_rem = k_rem != 0 + Kc = (K + group_size - 1) // group_size + qt_com = qt[:, : K - k_rem].view(N, K // group_size, group_size) + scales_com = scales[:, : Kc - has_rem] + zps_com = zps[:, : Kc - has_rem] + t = ( + ((qt_com.to(torch.float) - zps_com.unsqueeze(-1)) * scales_com.unsqueeze(-1)) + .view(N, K - k_rem) + .contiguous() + ) + if k_rem: + qt_rem = qt[:, K - k_rem :].view(N, 1, k_rem) + scales_rem = scales[:, Kc - has_rem :] + zps_rem = zps[:, Kc - has_rem :] + t_rem = ( + ( + (qt_rem.to(torch.float) - zps_rem.unsqueeze(-1)) + * scales_rem.unsqueeze(-1) + ) + .view(N, k_rem) + .contiguous() + ) + t = torch.cat([t, t_rem], dim=1).contiguous() + if weight_shape is not None: + t = t[: weight_shape[0], : weight_shape[1]].contiguous() + return t diff --git a/intel_extension_for_pytorch/quantization/_smooth_quant.py b/intel_extension_for_pytorch/quantization/_smooth_quant.py index a56e5a4c2..ae16301e8 100644 --- a/intel_extension_for_pytorch/quantization/_smooth_quant.py +++ b/intel_extension_for_pytorch/quantization/_smooth_quant.py @@ -24,9 +24,6 @@ class SmoothQuantActivationObserver(UniformQuantizationObserverBase): just act as a normal observer """ - # As a 1d tensor, not diagonal - scaling_factors: torch.Tensor - def __init__( self, act_observer=None, @@ -63,14 +60,7 @@ def __init__( eps=eps, ) else: - assert isinstance(act_ic_observer, UniformQuantizationObserverBase), ( - f"act_ic_observer should be an instance of UniformQuantizationObserverBase " - f"or its subclass but got {type(act_ic_observer)}" - ) - assert hasattr( - act_ic_observer, "ch_axis" - ), "act_ic_observer should be a per-channel observer and observe input channel axis" - self.ic_obs = act_ic_observer + self.ic_obs = act_ic_observer() if act_observer is None: self.act_obs = HistogramObserver( dtype=dtype, @@ -82,8 +72,7 @@ def __init__( eps=eps, ) else: - assert isinstance(act_observer, UniformQuantizationObserverBase) - self.act_obs = act_observer + self.act_obs = act_observer() # if smooth_quant_enabled is false, this observer acts as # a normal per-tensor observer self.smooth_quant_enabled = smooth_quant_enabled @@ -92,10 +81,14 @@ def __init__( # They are for checks, like `_check_observer_has_run` self.min_val = self.act_obs.min_val self.max_val = self.act_obs.max_val + # Dict of tensors. Keys are weight IDs. 
Factors are 1d tensors, not diagonal + self.scaling_factors = {} def forward(self, x_orig): if not self.smooth_quant_enabled: return self.act_obs.forward(x_orig) + # Run act_obs to indicate the observer has run + self.act_obs.forward(x_orig) # Call per-channel observer on IC to find scaling factor return self.ic_obs.forward(x_orig) @@ -103,31 +96,37 @@ def forward(self, x_orig): def calculate_qparams(self): if not self.smooth_quant_enabled: return self.act_obs.calculate_qparams() - # Get weight per IC min/max from weight observer - wei_min_per_ic = self.weight_obs.min_val - wei_max_per_ic = self.weight_obs.max_val - act_min_per_ic = self.ic_obs.min_val - act_max_per_ic = self.ic_obs.max_val - x_abs_max_per_ic = ( - torch.max(torch.abs(act_min_per_ic), torch.abs(act_max_per_ic)) + 1e-6 - ) - w_abs_max_per_ic = ( - torch.max(torch.abs(wei_min_per_ic), torch.abs(wei_max_per_ic)) + 1e-6 - ) - # Note: activation's scaling factors are reciprocals of weight's - self.scaling_factors = torch.pow(w_abs_max_per_ic, 1 - self.alpha) / torch.pow( - x_abs_max_per_ic, self.alpha - ) - # Apply scaling factors to each IC's min/max - act_min_per_ic_new = act_min_per_ic * self.scaling_factors.reshape( - act_min_per_ic.shape - ) - act_max_per_ic_new = act_max_per_ic * self.scaling_factors.reshape( - act_max_per_ic.shape - ) - min_val_per_tensor = torch.min(act_min_per_ic_new) - max_val_per_tensor = torch.max(act_max_per_ic_new) - return self._calculate_qparams(min_val_per_tensor, max_val_per_tensor) + scales, zero_points = {}, {} + for k in self.weight_obs.keys(): + # Get weight per IC min/max from weight observer + wei_min_per_ic = self.weight_obs[k].min_val + wei_max_per_ic = self.weight_obs[k].max_val + act_min_per_ic = self.ic_obs.min_val + act_max_per_ic = self.ic_obs.max_val + x_abs_max_per_ic = ( + torch.max(torch.abs(act_min_per_ic), torch.abs(act_max_per_ic)) + 1e-6 + ) + w_abs_max_per_ic = ( + torch.max(torch.abs(wei_min_per_ic), torch.abs(wei_max_per_ic)) + 1e-6 + ) + # Note: activation's scaling factors are reciprocals of weight's + scaling_factor = torch.pow(w_abs_max_per_ic, 1 - self.alpha) / torch.pow( + x_abs_max_per_ic, self.alpha + ) + self.scaling_factors.update({k: scaling_factor}) + # Apply scaling factors to each IC's min/max + act_min_per_ic_new = act_min_per_ic * scaling_factor.reshape( + act_min_per_ic.shape + ) + act_max_per_ic_new = act_max_per_ic * scaling_factor.reshape( + act_max_per_ic.shape + ) + min_val_per_tensor = torch.min(act_min_per_ic_new) + max_val_per_tensor = torch.max(act_max_per_ic_new) + scale, zp = self._calculate_qparams(min_val_per_tensor, max_val_per_tensor) + scales.update({k: scale}) + zero_points.update({k: zp}) + return scales, zero_points def get_scaling_factors(self): if not self.smooth_quant_enabled: @@ -198,8 +197,7 @@ def __init__( eps=eps, ) else: - assert isinstance(wei_observer, UniformQuantizationObserverBase) - self.oc_obs = wei_observer + self.oc_obs = wei_observer() if wei_ic_observer is None: self.ic_obs = PerChannelMinMaxObserver( ch_axis=1, @@ -212,14 +210,7 @@ def __init__( eps=eps, ) else: - assert isinstance(wei_ic_observer, UniformQuantizationObserverBase), ( - f"wei_ic_observer should be an instance of UniformQuantizationObserverBase " - f"or its subclass but got {type(wei_ic_observer)}" - ) - assert hasattr( - wei_ic_observer, "ch_axis" - ), "wei_ic_observer should be a per-channel observer and observe input channel axis" - self.ic_obs = wei_ic_observer + self.ic_obs = wei_ic_observer() # if smooth_quant_enabled is false, this observer 
acts as # a normal observer self.smooth_quant_enabled = smooth_quant_enabled diff --git a/intel_extension_for_pytorch/quantization/_utils.py b/intel_extension_for_pytorch/quantization/_utils.py index 026e930c8..bac5424b5 100644 --- a/intel_extension_for_pytorch/quantization/_utils.py +++ b/intel_extension_for_pytorch/quantization/_utils.py @@ -14,6 +14,7 @@ from ._quantization_state_utils import QTensorInfo from ._smooth_quant import SmoothQuantActivationObserver, SmoothQuantWeightObserver +from ._qconfig import QConfigSmoothQuant from intel_extension_for_pytorch.nn.modules import MergedEmbeddingBagWithCat add_and_mul_ops = set( @@ -659,7 +660,7 @@ def _create_observer(setting): ] for key in smooth_quant_sub_obs_keys: if key in setting: - setting[key] = _create_observer(setting[key])() + setting[key] = _create_observer(setting[key]) return observer.with_args(**setting) else: raise NameError("torch.quantization.observer %s not found" % setting["name"]) @@ -692,12 +693,29 @@ def save_quant_state(quant_state_map, configure_file): cur_tensor_infos["inf_dtype"] = str(tensor_info.inf_dtype) cur_tensor_infos["force_dtype"] = str(force_dtype) if tensor_info.id in v.tensor_id_to_scale_zp: - cur_tensor_infos["scale"] = v.tensor_id_to_scale_zp[ - tensor_info.id - ][0].tolist() - cur_tensor_infos["zero_point"] = v.tensor_id_to_scale_zp[ - tensor_info.id - ][1].tolist() + if isinstance( + v.tensor_id_to_scale_zp[tensor_info.id][0], torch.Tensor + ): + cur_tensor_infos["scale"] = v.tensor_id_to_scale_zp[ + tensor_info.id + ][0].tolist() + cur_tensor_infos[ + "zero_point" + ] = v.tensor_id_to_scale_zp[tensor_info.id][1].tolist() + else: + scales_dict = v.tensor_id_to_scale_zp[tensor_info.id][0] + zp_dict = v.tensor_id_to_scale_zp[tensor_info.id][1] + assert isinstance(scales_dict, dict) and isinstance( + zp_dict, dict + ) + scales_to_save = {} + zp_to_save = {} + for key, val in scales_dict.items(): + scales_to_save.update({key: val.tolist()}) + for key, val in zp_dict.items(): + zp_to_save.update({key: val.tolist()}) + cur_tensor_infos["scale"] = scales_to_save + cur_tensor_infos["zero_point"] = zp_to_save if ( str(tensor_info.id) in v.tensor_id_to_smooth_quant_scaling_factor @@ -706,11 +724,18 @@ def save_quant_state(quant_state_map, configure_file): ] is not None ): + scaling_factor_dict = ( + v.tensor_id_to_smooth_quant_scaling_factor[ + str(tensor_info.id) + ] + ) + assert isinstance(scaling_factor_dict, dict) + scaling_factors_to_save = {} + for key, val in scaling_factor_dict.items(): + scaling_factors_to_save.update({key: val.tolist()}) cur_tensor_infos[ "smooth_quant_scaling_factor" - ] = v.tensor_id_to_smooth_quant_scaling_factor[ - str(tensor_info.id) - ].tolist() + ] = scaling_factors_to_save smooth_quant_enabled = True input_tensor_infos.append(cur_tensor_infos) info["input_tensor_infos"] = input_tensor_infos @@ -749,18 +774,49 @@ def save_quant_state(quant_state_map, configure_file): cur_tensor_infos["orig_dtype"] = str(tensor_info.orig_dtype) cur_tensor_infos["inf_dtype"] = str(tensor_info.inf_dtype) if tensor_info.id in v.tensor_id_to_scale_zp: - cur_tensor_infos["scale"] = v.tensor_id_to_scale_zp[ - tensor_info.id - ][0].tolist() - cur_tensor_infos["zero_point"] = v.tensor_id_to_scale_zp[ - tensor_info.id - ][1].tolist() - if tensor_info.id in v.tensor_id_to_smooth_quant_scaling_factor: + if isinstance( + v.tensor_id_to_scale_zp[tensor_info.id][0], torch.Tensor + ): + cur_tensor_infos["scale"] = v.tensor_id_to_scale_zp[ + tensor_info.id + ][0].tolist() + cur_tensor_infos[ + "zero_point" 
+ ] = v.tensor_id_to_scale_zp[tensor_info.id][1].tolist() + else: + scales_dict = v.tensor_id_to_scale_zp[tensor_info.id][0] + zp_dict = v.tensor_id_to_scale_zp[tensor_info.id][1] + assert isinstance(scales_dict, dict) and isinstance( + zp_dict, dict + ) + scales_to_save = {} + zp_to_save = {} + for key, val in scales_dict.items(): + scales_to_save.update({key: val.tolist()}) + for key, val in zp_dict.items(): + zp_to_save.update({key: val.tolist()}) + cur_tensor_infos["scale"] = scales_to_save + cur_tensor_infos["zero_point"] = zp_to_save + if ( + str(tensor_info.id) + in v.tensor_id_to_smooth_quant_scaling_factor + ): + scaling_factors = ( + v.tensor_id_to_smooth_quant_scaling_factor[ + str(tensor_info.id) + ] + ) + scaling_factors_to_save = None + if scaling_factors is not None: + assert isinstance( + scaling_factors, dict + ), f"Expect scaling factors is a dict but found {type(scaling_factors)}" + scaling_factors_to_save = {} + for key, val in scaling_factors.items(): + scaling_factors_to_save.update({key: val.tolist()}) cur_tensor_infos[ "smooth_quant_scaling_factor" - ] = v.tensor_id_to_smooth_quant_scaling_factor[ - tensor_info.id - ].tolist() + ] = scaling_factors_to_save output_tensor_infos.append(cur_tensor_infos) info["output_tensor_infos"] = output_tensor_infos # qconfig @@ -779,6 +835,9 @@ def save_quant_state(quant_state_map, configure_file): info["activation_observer"][ "act_ic_observer" ] = _get_observer_setting(op_info.qconfig.activation().ic_obs) + info["share_weight_observers"] = getattr( + op_info.qconfig, "share_weight_observers", True + ) info["weight_observer"] = _get_observer_setting( op_info.qconfig.weight() ) @@ -876,16 +935,41 @@ def load_qconf_summary_to_model(model, qconf_summary): dtype_dict[tensor_info["force_dtype"]] ) if "scale" in tensor_info: - scale = torch.FloatTensor(tensor_info["scale"]) - zp = torch.LongTensor(tensor_info["zero_point"]) + if isinstance(tensor_info["scale"], list): + scale = torch.FloatTensor(tensor_info["scale"]) + zp = torch.LongTensor(tensor_info["zero_point"]) + else: + scale, zp = {}, {} + scale_to_load = tensor_info["scale"] + zp_to_load = tensor_info["zero_point"] + assert isinstance(scale_to_load, dict) and isinstance( + zp_to_load, dict + ), ( + "Expect scales and zero points to load are dicts but " + f"found types {type(scale_to_load)} and {type(zp_to_load)}" + ) + for key, val in scale_to_load.items(): + s = torch.FloatTensor(val) + scale.update({key: s}) + for key, val in zp_to_load.items(): + z = torch.LongTensor(val) + zp.update({key: z}) v.tensor_id_to_scale_zp[tensor_info["id"]] = (scale, zp) if "smooth_quant_scaling_factor" in tensor_info: - scaling_factor = torch.FloatTensor( - tensor_info["smooth_quant_scaling_factor"] + scaling_factors = {} + scaling_factors_to_load = tensor_info[ + "smooth_quant_scaling_factor" + ] + assert isinstance(scaling_factors_to_load, dict), ( + f"Expect scaling factors to load are a dict but found type " + f"{type(scaling_factors_to_load)}" ) + for key, val in scaling_factors_to_load.items(): + scaling_factor = torch.FloatTensor(val) + scaling_factors.update({key: scaling_factor}) v.tensor_id_to_smooth_quant_scaling_factor[ str(tensor_info["id"]) - ] = scaling_factor + ] = scaling_factors else: input_tensor_infos.append(None) input_force_dtype_infos.append(None) @@ -929,17 +1013,43 @@ def load_qconf_summary_to_model(model, qconf_summary): ) insert_fake_quant_after_outputs.append(False) if "scale" in tensor_info: - scale = torch.FloatTensor(tensor_info["scale"]) - zp = 
torch.LongTensor(tensor_info["zero_point"]) + if isinstance(tensor_info["scale"], list): + scale = torch.FloatTensor(tensor_info["scale"]) + zp = torch.LongTensor(tensor_info["zero_point"]) + else: + scale, zp = {}, {} + scale_to_load = tensor_info["scale"] + zp_to_load = tensor_info["zero_point"] + assert isinstance(scale_to_load, dict) and isinstance( + zp_to_load, dict + ), ( + "Expect scales and zero points to load are dicts but " + f"found types {type(scale_to_load)} and {type(zp_to_load)}" + ) + for key, val in scale_to_load.items(): + s = torch.FloatTensor(val) + scale.update({key: s}) + for key, val in zp_to_load.items(): + z = torch.LongTensor(val) + zp.update({key: z}) v.tensor_id_to_scale_zp[tensor_info["id"]] = (scale, zp) else: output_tensor_infos.append(None) activation_observer = q_op_info["activation_observer"] weight_observer = q_op_info["weight_observer"] - qconfig = QConfig( - activation=_create_observer(activation_observer), - weight=_create_observer(weight_observer), - ) + activation_obs = _create_observer(activation_observer) + if isinstance(activation_obs(), SmoothQuantActivationObserver): + share_weight_observers = q_op_info.get("share_weight_observers", True) + qconfig = QConfigSmoothQuant( + activation=activation_obs, + weight=_create_observer(weight_observer), + share_weight_observers=share_weight_observers, + ) + else: + qconfig = QConfig( + activation=activation_obs, + weight=_create_observer(weight_observer), + ) # overide the cur model's info v.idx_to_seen_q_op_infos[int(i)].input_tensor_infos = input_tensor_infos v.idx_to_seen_q_op_infos[ diff --git a/intel_extension_for_pytorch/transformers/models/cpu/fusions/linear_fusion.py b/intel_extension_for_pytorch/transformers/models/cpu/fusions/linear_fusion.py index 38094f095..1634482fe 100644 --- a/intel_extension_for_pytorch/transformers/models/cpu/fusions/linear_fusion.py +++ b/intel_extension_for_pytorch/transformers/models/cpu/fusions/linear_fusion.py @@ -6,6 +6,8 @@ from intel_extension_for_pytorch.nn.modules import IpexWoqLinear from intel_extension_for_pytorch.quantization import ( get_weight_only_quant_qconfig_mapping, + dequantize_per_channel, + dequantize_per_block, ) @@ -155,15 +157,6 @@ def forward(self, x): self.linear.bias if self.linear.bias is not None else x.new_empty(0), self.linear.out_features, ) - if ( - self.woq - and hasattr(self.linear, "_op_context") - and self.linear._op_context is not None - ): - return torch.ops.torch_ipex.woq_linear_gelu( - x, - self.linear._op_context.get_data_handle(), - ) else: # fallback path x = self.linear(x) @@ -193,6 +186,15 @@ def forward(self, x): self.linear.weight, self.linear.bias if self.linear.bias is not None else x.new_empty(0), ) + if ( + self.woq + and hasattr(self.linear, "_op_context") + and self.linear._op_context is not None + ): + return torch.ops.torch_ipex.woq_linear_gelu( + x, + self.linear._op_context.get_data_handle(), + ) else: # fallback path x = self.gelu(self.linear(x)) return x @@ -214,14 +216,7 @@ def __init__(self, module, tpp=False, woq=False): isinstance(linear, IpexWoqLinear) for linear in self.linear_list ): # Quantization is done before lowering to CPU. - # We assume weights are all in shape [N, K] and per-channel quantized, axis = 0. - # And it must be one of the two cases below. 
- # Case 1: - # - weight dtype = qint8, qscheme = torch.per_channel_affine, - # - scales dtype = float, zero points dtype = int - # Case 2: - # - weight dtype = quint4x2, qscheme = torch.per_channel_affine_float_qparams, - # - scales dtype = float, zero points dtype = float + # We assume weights are all in shape [N, K]. # We need to unpack weights then concat them weights_list = [] scales_list = [] @@ -229,8 +224,13 @@ def __init__(self, module, tpp=False, woq=False): bias_list = [] w_dtype = self.linear_list[0].dtype lowp_mode = self.linear_list[0]._lowp_mode + act_quant_mode = self.linear_list[0]._act_quant_mode + group_size = self.linear_list[0]._group_size qconfig_mapping = get_weight_only_quant_qconfig_mapping( - weight_dtype=w_dtype, lowp_mode=lowp_mode + weight_dtype=w_dtype, + lowp_mode=lowp_mode, + act_quant_mode=act_quant_mode, + group_size=group_size, ) qconfig = qconfig_mapping.global_qconfig for i in range(self.num_concat): @@ -244,32 +244,41 @@ def __init__(self, module, tpp=False, woq=False): weights_list = [] break qw = linear._op_context.to_public(linear._op_context.get_weight()) - if ( - qw.qscheme() - not in [ - torch.per_channel_affine, - torch.per_channel_affine_float_qparams, - ] - or qw.q_per_channel_axis() != 0 - ): - warnings.warn( - "Concat linear fusion for CPU WOQ failed " - "because quantization type of weight is not supported. " - "Falling back to separate linears." + scales = linear._op_context.get_scales() + zero_points = linear._op_context.get_zero_points() + is_int4 = w_dtype == torch.quint4x2 + weight_shape = linear._op_context.get_weight_shape() + if group_size > 0: + weights_list.append( + dequantize_per_block( + qw, scales, zero_points, is_int4, group_size, weight_shape + ) ) - weights_list = [] - break - s = qw.q_per_channel_scales().float() - z = qw.q_per_channel_zero_points().float() - weights_list.append(qw.dequantize().float()) - scales_list.append(s) - zeros_list.append(z) - bias_list.append(linear._op_context.get_bias()) + else: + weights_list.append( + dequantize_per_channel( + qw, scales, zero_points, is_int4, weight_shape + ) + ) + # OC of Weight may be padded to a multiple of block_n. So are scales and zero points. 
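The unpacking above relies on the quantize/dequantize helpers this patch exports from intel_extension_for_pytorch.quantization. A minimal round-trip sketch, with illustrative shapes and assuming the patched package is installed:

import torch
from intel_extension_for_pytorch.quantization import (
    quantize_per_channel,
    dequantize_per_channel,
    quantize_per_block,
    dequantize_per_block,
)

w = torch.randn(64, 128)  # [out_features, in_features]
is_int4 = True

# Per-channel: one scale/zero point per output channel; INT4 values are packed two per uint8
qw, scales, zps = quantize_per_channel(w, is_int4)
w_dq = dequantize_per_channel(qw, scales, zps, is_int4, w.shape)

# Per-block: qparams have shape [N, ceil(K / group_size)]
group_size = 32
qw_b, scales_b, zps_b = quantize_per_block(w, is_int4, group_size)
w_dq_b = dequantize_per_block(qw_b, scales_b, zps_b, is_int4, group_size, weight_shape=w.shape)

# Both round trips recover the original shape, up to quantization error
assert w_dq.shape == w.shape and w_dq_b.shape == w.shape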
+ bias = linear._op_context.get_bias() + assert scales.shape == zero_points.shape + assert bias is None or bias.shape[0] == scales.shape[0] + if weight_shape[0] < scales.shape[0]: + original_n = weight_shape[0] + scales_list.append(scales.narrow(0, 0, original_n).contiguous()) + zeros_list.append(zero_points.narrow(0, 0, original_n).contiguous()) + bias_list.append(bias.narrow(0, 0, original_n).contiguous()) + else: + assert weight_shape[0] == scales.shape[0] + scales_list.append(scales) + zeros_list.append(zero_points) + bias_list.append(bias) w_dtype = linear.dtype if weights_list: concat_weight = torch.concat(weights_list, 0) - concat_scales = torch.concat(scales_list, -1) - concat_zeros = torch.concat(zeros_list, -1) + concat_scales = torch.concat(scales_list, 0) + concat_zeros = torch.concat(zeros_list, 0) use_bias = all(bias_list) concat_bias = torch.concat(bias_list, 0) if use_bias else None mod = nn.Linear( @@ -281,11 +290,17 @@ def __init__(self, module, tpp=False, woq=False): mod._num_concats = len(weights_list) if w_dtype == torch.quint4x2: self.concat_linear = IpexWoqLinear.from_float_and_int4_weight( - mod, concat_weight, concat_scales, concat_zeros + mod, + concat_weight, + concat_scales, + concat_zeros, + group_size=group_size, ) else: # qint8 assert w_dtype == torch.qint8 - self.concat_linear = IpexWoqLinear.from_float(mod) + self.concat_linear = IpexWoqLinear.from_float( + mod, concat_scales, concat_zeros + ) else: for i in range(self.num_concat): attr_name = f"linear_{i}" diff --git a/intel_extension_for_pytorch/transformers/optimize.py b/intel_extension_for_pytorch/transformers/optimize.py index 956aadbd2..17650fc27 100644 --- a/intel_extension_for_pytorch/transformers/optimize.py +++ b/intel_extension_for_pytorch/transformers/optimize.py @@ -349,22 +349,25 @@ def ipex_quantization_flow( ): from intel_extension_for_pytorch.quantization import prepare, convert - if not _is_woq_qconfig(qconfig) and sample_inputs is None: + is_woq = _is_woq_qconfig(qconfig) + if not is_woq and sample_inputs is None: sample_inputs = get_dummy_input(_model) prepared_model = prepare( _model.eval(), qconfig, example_inputs=sample_inputs, inplace=True ) + if static_qconfig_file is not None: prepared_model.load_qconf_summary(qconf_summary=static_qconfig_file) print("ipex.optimize_transformers is doing the static quantization") else: print("ipex.optimize_transformers is doing the weight only quantization") + with torch.no_grad(), torch.cpu.amp.autocast( enabled=True if dtype is torch.bfloat16 else False ): convert_model = convert(prepared_model.eval(), inplace=True).eval() - if _is_woq_qconfig(qconfig) and dtype is torch.bfloat16: + if is_woq and dtype is torch.bfloat16: convert_model = convert_model.to(dtype) return convert_model diff --git a/intel_extension_for_pytorch/utils/weight_only_quantization.py b/intel_extension_for_pytorch/utils/weight_only_quantization.py index 138b6b4aa..818b165d9 100644 --- a/intel_extension_for_pytorch/utils/weight_only_quantization.py +++ b/intel_extension_for_pytorch/utils/weight_only_quantization.py @@ -52,7 +52,13 @@ def _get_linear_parameters(attr_name, state_dict, checkpoint_config): scales = state_dict.get(s_key, None) qzeros = state_dict.get(z_key, None) bias = state_dict.get(b_key, None) - return qweight, scales, qzeros, bias + group_size = -1 + if qweight is not None and scales is not None: + assert scales.dim() == 2, "Unexpected scales tensor dimension" + if scales.size(-1) != 1: + # qweight is compressed along the last dim int4 * 8 -> int32 + group_size 
= qweight.size(-1) * 8 // scales.size(-1) + return qweight, scales, qzeros, bias, group_size def _convert_woq_with_low_precision_checkpoint( @@ -108,13 +114,13 @@ def _convert_woq_with_low_precision_checkpoint( def _convert(mod, attr_name): if isinstance(mod, torch.nn.Linear): mod.qconfig = qconfig_mapping.global_qconfig - qweight, scales, qzeros, bias = _get_linear_parameters( + qweight, scales, qzeros, bias, group_size = _get_linear_parameters( attr_name, state_dict, checkpoint_config ) if any(i is None for i in [qweight, scales, qzeros]): return mod mod_new = IpexWoqLinear.from_float_and_int4_weight( - mod, qweight, scales, qzeros, bias + mod, qweight, scales, qzeros, bias, group_size=group_size ) return mod_new diff --git a/tests/cpu/test_deepspeed.py b/tests/cpu/test_deepspeed.py index b470bd11b..2526a3920 100644 --- a/tests/cpu/test_deepspeed.py +++ b/tests/cpu/test_deepspeed.py @@ -368,7 +368,6 @@ def test_llama_with_optimize_transformers(self): inplace=True, deployment_mode=True, ) - print(model) if not hasattr(model, "trace_graph"): AssertionError(False) _IPEXAttentionCPU = ( diff --git a/tests/cpu/test_ipex_optimize_transformers.py b/tests/cpu/test_ipex_optimize_transformers.py index d0c311862..5a3685678 100644 --- a/tests/cpu/test_ipex_optimize_transformers.py +++ b/tests/cpu/test_ipex_optimize_transformers.py @@ -7,6 +7,7 @@ import copy import re import tempfile +from intel_extension_for_pytorch.quantization import prepare, convert try: import transformers @@ -249,11 +250,15 @@ def test_model_replacement_codegen_torchcompile(self): self.model_replacement_check(m, True, torchcompile=True) def _model_replacement_check_woq(self, model): - qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping() + qconfig_mapping = ipex.quantization.get_weight_only_quant_qconfig_mapping() + orig_model = copy.deepcopy(model) + orig_woq_model = prepare(orig_model, qconfig_mapping, inplace=True) + orig_woq_model = convert(orig_woq_model, inplace=True) + model = ipex.optimize_transformers( model, dtype=torch.float, - quantization_config=qconfig, + quantization_config=qconfig_mapping, deployment_mode=True, inplace=True, ) @@ -294,7 +299,14 @@ def _model_replacement_check_woq(self, model): # Ensure model can run without errors with torch.no_grad(): example_inputs = _get_gptj_example_inputs() - model(*example_inputs) + y = model(*example_inputs) + y_ref = orig_woq_model( + input_ids=example_inputs[0], + attention_mask=example_inputs[1], + position_ids=example_inputs[2], + use_cache=True, + ) + self.assertEqual(y[0], y_ref[0], prec=1e-4) def test_weight_only_quant_flow_for_gptj(self): config = AutoConfig.from_pretrained( diff --git a/tests/cpu/test_quantization_default_recipe.py b/tests/cpu/test_quantization_default_recipe.py index ab4e22df9..77013ff69 100644 --- a/tests/cpu/test_quantization_default_recipe.py +++ b/tests/cpu/test_quantization_default_recipe.py @@ -10,12 +10,21 @@ QConfigMapping, ) import copy +import unittest +from common_utils import TestCase import intel_extension_for_pytorch as ipex from test_ao_jit_llga_utils import JitLlgaTestCase, LLGA_FUSION_GROUP from torch.testing._internal.common_utils import run_tests from torch.ao.nn.quantized.modules.utils import _quantize_weight -from intel_extension_for_pytorch.quantization import prepare, convert +from intel_extension_for_pytorch.quantization import ( + prepare, + convert, + dequantize_per_channel, + dequantize_per_block, + quantize_per_channel, + quantize_per_block, +) class TestDefaultRecipe(JitLlgaTestCase): @@ -306,9 +315,10 
@@ def __init__(self): super().__init__() self.dense = nn.Linear(4, 4) self.relu = nn.ReLU() + self.dense2 = nn.Linear(4, 4) def forward(self, x): - return self.relu(self.dense(x)) + return self.dense2(self.relu(self.dense(x))) m = Mod().eval() x = torch.rand(1, 4) @@ -318,12 +328,13 @@ def forward(self, x): ) custom_config = { "alpha": 0.75, - "act_observer": torch.ao.quantization.MinMaxObserver(), - "act_ic_observer": per_channel_observer(ch_axis=-1), - "wei_observer": per_channel_observer( + "act_observer": torch.ao.quantization.MinMaxObserver, + "act_ic_observer": per_channel_observer.with_args(ch_axis=-1), + "wei_observer": per_channel_observer.with_args( dtype=torch.qint8, qscheme=torch.per_channel_symmetric ), - "wei_ic_observer": per_channel_observer(ch_axis=1), + "wei_ic_observer": per_channel_observer.with_args(ch_axis=1), + "share_weight_observers": False, } for use_custom_config in [False, True]: kwargs = custom_config if use_custom_config else {} @@ -345,6 +356,17 @@ def forward(self, x): ].weight_tensor_id_to_observer, } observer_info_dict = {} + observer_info_dict["share_weight_observers"] = ( + prepared_model._fqn_to_auto_quant_state_map[" "] + .idx_to_seen_q_op_infos[0] + .qconfig.share_weight_observers + ) + sub_observer_ids = { + "act_ic_obs": [], + "act_obs": [], + "wei_oc_obs": [], + "wei_ic_obs": [], + } for key, obs in observer_info.items(): observer_info_dict[key] = { "smooth_quant_enabled": obs.smooth_quant_enabled, @@ -352,6 +374,17 @@ def forward(self, x): "ic_obs": type(obs.ic_obs), "act_obs": type(obs.act_obs), } + if isinstance( + obs, + ipex.quantization._smooth_quant.SmoothQuantActivationObserver, + ): + sub_observer_ids["act_ic_obs"].append(id(obs.ic_obs)) + sub_observer_ids["act_obs"].append(id(obs.act_obs)) + else: + sub_observer_ids["wei_oc_obs"].append(id(obs.oc_obs)) + sub_observer_ids["wei_ic_obs"].append(id(obs.ic_obs)) + for _, id_list in sub_observer_ids.items(): + assert all([id_list[0] != id for id in id_list[1:]]) for data in calib_dataset: prepared_model(data) @@ -382,6 +415,17 @@ def forward(self, x): ].weight_tensor_id_to_observer, } observer_info_dict_2 = {} + observer_info_dict_2["share_weight_observers"] = ( + prepared_model_2._fqn_to_auto_quant_state_map[" "] + .idx_to_seen_q_op_infos[0] + .qconfig.share_weight_observers + ) + sub_observer_ids = { + "act_ic_obs": [], + "act_obs": [], + "wei_oc_obs": [], + "wei_ic_obs": [], + } for key, obs in observer_info_2.items(): observer_info_dict_2[key] = { "smooth_quant_enabled": obs.smooth_quant_enabled, @@ -389,6 +433,17 @@ def forward(self, x): "ic_obs": type(obs.ic_obs), "act_obs": type(obs.act_obs), } + if isinstance( + obs, + ipex.quantization._smooth_quant.SmoothQuantActivationObserver, + ): + sub_observer_ids["act_ic_obs"].append(id(obs.ic_obs)) + sub_observer_ids["act_obs"].append(id(obs.act_obs)) + else: + sub_observer_ids["wei_oc_obs"].append(id(obs.oc_obs)) + sub_observer_ids["wei_ic_obs"].append(id(obs.ic_obs)) + for _, id_list in sub_observer_ids.items(): + assert all([id_list[0] != id for id in id_list[1:]]) q_model_2 = ipex.quantization.convert(prepared_model_2) @@ -399,12 +454,109 @@ def forward(self, x): assert torch.allclose(out_ref, out_2) + # Scales and zero points should be updated after rerunning calibration + scale_zp_0 = prepared_model_2._fqn_to_auto_quant_state_map[ + " " + ].tensor_id_to_scale_zp + scale_zp_0 = copy.deepcopy(scale_zp_0) + for data in calib_dataset: + prepared_model_2(data + 1) + prepared_model_2.save_qconf_summary(qconf_summary=qconf_filename) + 
scale_zp_1 = prepared_model_2._fqn_to_auto_quant_state_map[ + " " + ].tensor_id_to_scale_zp + assert scale_zp_0 != scale_zp_1 + # Check observers if use_custom_config: assert ( observer_info_dict == observer_info_dict_2 ), "Error: SmoothQuant observer info lost after saving/loading qconf JSON" + def test_smooth_quant_cancel_by_qconf_summary(self): + class Mod(nn.Module): + def __init__(self): + super().__init__() + self.dense = nn.Linear(4, 4) + self.relu = nn.ReLU() + + def forward(self, x): + return self.relu(self.dense(x)) + + m = Mod().eval() + x = torch.rand(1, 4) + calib_dataset = [torch.rand(1, 4) for _ in range(5)] + qconfig_mapping = ipex.quantization.get_smooth_quant_qconfig_mapping() + prepared_model = ipex.quantization.prepare( + m, qconfig_mapping, example_inputs=x, inplace=False + ) + for data in calib_dataset: + prepared_model(data) + + with tempfile.NamedTemporaryFile() as fp: + qconf_filename = fp.name + prepared_model.save_qconf_summary(qconf_summary=qconf_filename) + import json + + with open(qconf_filename, "r") as qconf_file: + parsed = json.load(qconf_file) + parsed[" "]["q_op_infos"]["0"]["input_tensor_infos"][0][ + "force_dtype" + ] = "torch.float32" + + with open(qconf_filename, "w") as qconf_file: + json.dump(parsed, qconf_file, indent=4) + + prepared_model_2 = ipex.quantization.prepare( + m, qconfig_mapping, example_inputs=x, inplace=False + ) + prepared_model_2.load_qconf_summary(qconf_summary=qconf_filename) + converted_model = ipex.quantization.convert(prepared_model_2) + with torch.no_grad(): + jit_model = torch.jit.trace(converted_model, x) + jit_model = torch.jit.freeze(jit_model) + for _ in range(2): + jit_model(x) + graph = jit_model.graph_for(x) + for n in graph.nodes(): + assert n.kind() != "aten::mul" + + def test_smooth_quant_share_weight_observers(self): + class Mod(nn.Module): + def __init__(self): + super().__init__() + self.q_proj = nn.Linear(4, 4) + self.k_proj = nn.Linear(4, 4) + self.v_proj = nn.Linear(4, 4) + self.relu = nn.ReLU() + + def forward(self, x): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + return self.relu(torch.concat([q, k, v], axis=1)) + + m = Mod().eval() + x = torch.rand(1, 4) + calib_dataset = [torch.rand(1, 4) for _ in range(5)] + for share_weight_observers in [True, False]: + qconfig_mapping = ipex.quantization.get_smooth_quant_qconfig_mapping( + share_weight_observers=share_weight_observers + ) + prepared_model = ipex.quantization.prepare( + m, qconfig_mapping, example_inputs=x, inplace=True + ) + for data in calib_dataset: + prepared_model(data) + q_model = ipex.quantization.convert(prepared_model) + with torch.no_grad(): + q_model = torch.jit.trace(q_model, x) + q_model = torch.jit.freeze(q_model) + graph = q_model.graph_for(x) + num_mul = [n.kind() for n in graph.nodes()].count("aten::mul") + assert num_mul == 1 if share_weight_observers else 3 + q_model(x) + def test_none_example_input_for_quantization(self): class M(nn.Module): def __init__(self): @@ -431,6 +583,8 @@ def forward(self, x): with self.assertRaises(AssertionError): prepared_model = ipex.quantization.prepare(m, qconfig_mapping) + +class WeightOnlyQuantizationTester(TestCase): def test_weight_only_quantization(self): class M(nn.Module): def __init__(self, input_channel, output_channel, has_bias): @@ -445,12 +599,11 @@ def test(feature, has_bias): m = model.eval() data = torch.rand(1, feature[0], feature[1]) weight = model.linear.weight - weight_observer = ( - 
ipex.quantization.get_weight_only_quant_qconfig_mapping().global_qconfig.weight() + is_int4 = False + weight_int8, w_scales, w_zero_points = quantize_per_channel(weight, is_int4) + weight_fp32 = dequantize_per_channel( + weight_int8, w_scales, w_zero_points.int(), is_int4, weight.shape ) - weight_observer(weight) - weight_int8 = _quantize_weight(weight, weight_observer) - weight_fp32 = weight_int8.dequantize() if has_bias: bias = model.linear.bias output1 = torch.matmul(data, weight_fp32.T) + bias @@ -510,46 +663,39 @@ def test(feature, has_bias, w_dtype): ) prepared_model = prepare(m, qconfig, example_inputs=data, inplace=False) + is_int4 = w_dtype == torch.quint4x2 with torch.no_grad(): weight = m.linear.weight - weight_observer = qconfig.global_qconfig.weight() - weight_observer(weight) - weight_int8 = _quantize_weight(weight, weight_observer) - weight_fp32 = weight_int8.dequantize() + weight_int8, w_scales, w_zero_points = quantize_per_channel( + weight, is_int4 + ) + weight_fp32 = dequantize_per_channel( + weight_int8, w_scales, w_zero_points.int(), is_int4, weight.shape + ) weight_bf16 = weight_fp32.bfloat16() weight_fp16 = weight_fp32.half() data_bf16 = data.bfloat16() data_fp16 = data_bf16.half() bias_fp32 = m.linear.bias - use_tpp = tpp_is_used(feature[2], feature[1]) - if use_tpp: - # if M >= 32, compute in bf16 - # if M < 32, compute in fp32 or fp16. Depends on fp16 support. - if feature[0] >= 32: - output1 = torch.matmul( - data_bf16.float(), weight_bf16.float().T - ).bfloat16() - if has_bias: - output1 = output1 + bias_fp32.bfloat16() - else: - output1_fp32 = torch.matmul( - data_bf16.float(), weight_bf16.float().T - ) - if has_bias: - output1_fp32 = output1_fp32 + bias_fp32 - output1_fp16 = torch.matmul( - data_fp16.float(), weight_fp16.float().T - ).half() - if has_bias: - output1_fp16 = output1_fp16 + bias_fp32.half() + # if M >= 32, compute in bf16 + # if M < 32, compute in fp32 or fp16. Depends on fp16 support. 
+ if feature[0] >= 32: + output1 = torch.matmul( + data_bf16.float(), weight_bf16.float().T + ).bfloat16() + if has_bias: + output1 = output1 + bias_fp32.bfloat16() else: - if feature[0] <= 4: - output1 = torch.matmul(data_bf16.float(), weight_fp32.T) - else: - output1 = torch.matmul(data_bf16.float(), weight_bf16.float().T) + output1_fp32 = torch.matmul( + data_bf16.float(), weight_bf16.float().T + ) if has_bias: - output1 = output1 + bias_fp32 - output1 = output1.bfloat16() + output1_fp32 = output1_fp32 + bias_fp32 + output1_fp16 = torch.matmul( + data_fp16.float(), weight_fp16.float().T + ).half() + if has_bias: + output1_fp16 = output1_fp16 + bias_fp32.half() with torch.autocast( device_type="cpu", enabled=True, dtype=torch.bfloat16 ): @@ -563,7 +709,7 @@ def test(feature, has_bias, w_dtype): woq_model = torch.jit.freeze(woq_model) output2 = woq_model(data) output2 = output2.bfloat16() - if use_tpp and feature[0] < 32: + if feature[0] < 32: try: torch.testing.assert_close( output1_fp32.bfloat16(), output2, atol=0.01, rtol=0.1 @@ -577,9 +723,9 @@ def test(feature, has_bias, w_dtype): shape_list = [ [3, 31, 31], - # [4, 4096, 4096], # not supported by TPP yet (block_n = 16 issue) - [9, 4095, 4095], - [196, 4095, 16383], + [4, 64, 64], + [9, 128, 128], + [196, 63, 255], ] use_bias_list = [True, False] w_dtype_list = [torch.qint8, torch.quint4x2] @@ -686,8 +832,11 @@ def test(feature, has_bias): weight_dtype=torch.quint4x2 ).global_qconfig.weight() weight_observer(weight) - weight_int4 = _quantize_weight(weight, weight_observer) - weight_fp32 = weight_int4.dequantize() + is_int4 = True + weight_int4, w_scales, w_zero_points = quantize_per_channel(weight, is_int4) + weight_fp32 = dequantize_per_channel( + weight_int4, w_scales, w_zero_points, is_int4, weight.shape + ) if has_bias: bias = model.linear.bias output1 = torch.matmul(data, weight_fp32.T) + bias @@ -799,7 +948,7 @@ def forward(self, x, others): output1, output2.to(output1.dtype), atol=1.5e-2, rtol=1e-3 ) - def test_weight_only_quantization_lowp_compute(self): + def test_weight_only_quantization_lowp_mode_functionality(self): from intel_extension_for_pytorch.quantization import WoqLowpMode class M(nn.Module): @@ -812,7 +961,12 @@ def forward(self, x): data = torch.rand(4, 64) m = M() - for mode in [WoqLowpMode.FP16, WoqLowpMode.BF16, WoqLowpMode.INT8]: + for mode in [ + WoqLowpMode.NONE, + WoqLowpMode.FP16, + WoqLowpMode.BF16, + WoqLowpMode.INT8, + ]: kwargs = {"lowp_mode": mode} if mode == WoqLowpMode.INT8: kwargs["weight_dtype"] = torch.quint4x2 @@ -826,6 +980,52 @@ def forward(self, x): and woq_model.linear._lowp_mode == mode ), "Weight-only quantization: low precision gemm flag is not correctly set" + def test_weight_only_quantization_int8_lowp_mode_correctness(self): + from intel_extension_for_pytorch.quantization import WoqLowpMode + + class M(nn.Module): + def __init__(self): + super(M, self).__init__() + self.linear = torch.nn.Linear(64, 128) + + def forward(self, x): + return self.linear(x) + + # When lowp_mode=BF16, only case of batch size >= 32 uses BF16. 
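For reference, lowp_mode is selected purely through the qconfig. A short sketch of the combination exercised by the surrounding tests, assuming the patched intel_extension_for_pytorch:

import torch
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import WoqLowpMode

# FP16/BF16 only lower the compute precision; INT8 additionally quantizes the
# activation at runtime and requires INT4 weight (otherwise it falls back to BF16).
kwargs = {"lowp_mode": WoqLowpMode.INT8, "weight_dtype": torch.quint4x2}
qconfig_mapping = ipex.quantization.get_weight_only_quant_qconfig_mapping(**kwargs)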
+ data = torch.rand(32, 64) + m = M() + + lowp_mode_list = [WoqLowpMode.NONE, WoqLowpMode.FP16, WoqLowpMode.BF16] + act_dtype_list = [torch.bfloat16, torch.half] + compute_dtype_list = [None, torch.half, torch.bfloat16] + cases = itertools.product(lowp_mode_list, act_dtype_list) + # lowp_mode does not affect weight observer for int8 + qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping() + weight = copy.deepcopy(m.linear.weight) + weight_observer = qconfig.global_qconfig.weight() + weight_observer(weight) + weight_int8 = _quantize_weight(weight, weight_observer) + weight_fp32 = weight_int8.dequantize() + bias_fp32 = copy.deepcopy(m.linear.bias) + for lowp_mode, act_dtype in cases: + if lowp_mode == WoqLowpMode.NONE: + compute_dtype_list[0] = act_dtype + compute_dtype = compute_dtype_list[int(lowp_mode)] + qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( + lowp_mode=lowp_mode, + weight_dtype=torch.qint8, + ) + prepared_model = prepare(m, qconfig, example_inputs=data, inplace=False) + with torch.no_grad(): + woq_model = convert(prepared_model) + y = woq_model(data.to(act_dtype)) + weight_for_compute = weight_fp32.to(compute_dtype).float() + act_for_compute = data.to(act_dtype).to(compute_dtype).float() + bias_for_compute = bias_fp32.to(compute_dtype).float() + y_ref = act_for_compute @ weight_for_compute.T + bias_for_compute + y_ref = y_ref.to(act_dtype) + torch.testing.assert_close(y, y_ref, atol=0.005, rtol=0.01) + def test_weight_only_quantization_num_concats(self): class Mod(nn.Module): def __init__(self): @@ -870,6 +1070,250 @@ def forward(self, x): output2 = qm2(data) torch.testing.assert_close(output1, output2, atol=1e-2, rtol=1e-4) + def _fakequant_by_group(self, t, quant_a_mode, groupsize): + assert quant_a_mode >= 0 and quant_a_mode <= 3 + if quant_a_mode == 0: + obs = torch.ao.quantization.MinMaxObserver(torch.quint8) + obs(t) + scale, zero_point = obs.calculate_qparams() + return ( + torch.quantize_per_tensor( + t.to(torch.float), scale, zero_point, torch.quint8 + ) + .dequantize() + .to(t.dtype) + ) + orig_shape = t.shape + if t.shape[-1] % groupsize: + pad_len = t.shape[-1] // groupsize * groupsize + groupsize - t.shape[-1] + t = torch.nn.functional.pad(t, (0, pad_len), value=0) + grouped = t.view(-1, t.shape[-1] // groupsize, groupsize) + if quant_a_mode == 1: + grouped_min = grouped.min(dim=-1)[0].min(dim=0)[0] + grouped_max = grouped.max(dim=-1)[0].max(dim=0)[0] + elif quant_a_mode == 2: + grouped_min = grouped.min(dim=-1)[0].min(dim=1)[0] + grouped_max = grouped.max(dim=-1)[0].max(dim=1)[0] + else: + grouped_min = grouped.min(dim=-1)[0] + grouped_max = grouped.max(dim=-1)[0] + zeros = torch.zeros_like(grouped_min) + min = torch.minimum(grouped_min, zeros) + max = torch.maximum(grouped_max, zeros) + eps = torch.tensor([torch.finfo(torch.float32).eps]) + scales = (max - min) / 255 + scales = torch.max(scales, eps) + zps = -torch.round(min / scales) + if quant_a_mode == 1: + qt = torch.clamp( + torch.round(grouped / scales.unsqueeze(1)) + zps.unsqueeze(1), + min=0, + max=255, + ) + out = ( + ((qt - zps.unsqueeze(1)) * scales.unsqueeze(1)) + .to(t.dtype) + .view(t.shape) + ) + if orig_shape != out.shape: + out = out[: orig_shape[0], : orig_shape[1]].contiguous() + return out + elif quant_a_mode == 2: + qt = torch.clamp( + torch.round(grouped / scales.unsqueeze(1).unsqueeze(2)) + + zps.unsqueeze(1).unsqueeze(2), + min=0, + max=255, + ) + out = ( + ( + (qt - zps.unsqueeze(1).unsqueeze(2)) + * scales.unsqueeze(1).unsqueeze(2) + ) + .to(t.dtype) + 
.view(t.shape) + ) + if orig_shape != out.shape: + out = out[: orig_shape[0], : orig_shape[1]].contiguous() + return out + else: + qt = torch.clamp( + torch.round(grouped / scales.unsqueeze(-1)) + zps.unsqueeze(-1), + min=0, + max=255, + ) + out = ( + ((qt - zps.unsqueeze(-1)) * scales.unsqueeze(-1)) + .to(t.dtype) + .view(t.shape) + ) + if orig_shape != out.shape: + out = out[: orig_shape[0], : orig_shape[1]].contiguous() + return out + + def test_weight_only_quantization_act_quant_mode(self): + M, N, K = 4, 64, 128 + groupsize = 64 + + class Mod(nn.Module): + def __init__(self, has_bias): + super(Mod, self).__init__() + self.linear = torch.nn.Linear(K, N, has_bias) + + def forward(self, x): + return self.linear(x) + + def test(has_bias, act_quant_mode): + dtype = torch.bfloat16 + model = Mod(has_bias) + m = model.eval() + m2 = copy.deepcopy(m) + data = torch.rand(M, K) * 0.5 + qconfig_mapping = ipex.quantization.get_weight_only_quant_qconfig_mapping( + weight_dtype=torch.quint4x2, + lowp_mode=ipex.quantization.WoqLowpMode.INT8, + act_quant_mode=act_quant_mode, + ) + fake_quant_x = self._fakequant_by_group(data, act_quant_mode, groupsize) + prepared_model = prepare(m2, qconfig_mapping, inplace=True) + with torch.no_grad(), torch.autocast( + device_type="cpu", enabled=True, dtype=dtype + ): + woq_model = convert(prepared_model) + # Behavior of WOQ Linear to simulate: + # Quantize weight to int4 by float qparams at quantization time + # Quantize activation to int8 at runtime + # Convert weight and its zero points to INT8 for computation + qw = woq_model.linear._op_context.to_public( + woq_model.linear._op_context.get_weight() + ) + w_scales = woq_model.linear._op_context.get_scales() + w_zero_points = woq_model.linear._op_context.get_zero_points() + w = copy.deepcopy(m.linear.weight.data) + is_int4 = True + qw, _, _ = quantize_per_channel(w, is_int4, w_scales, w_zero_points) + fake_quant_w = dequantize_per_channel( + qw, w_scales, w_zero_points.int(), is_int4, w.shape + ) + m.linear.weight.data = fake_quant_w + y_ref = m(fake_quant_x).to(dtype) + y = woq_model(data) + try: + torch.testing.assert_close(y, y_ref, atol=1e-2 * 5, rtol=1e-1 * 2) + except Exception: + # The fallback kernel does not support act quant mode + # It computes in fp32 by dequantizing weight. 
+                    fake_quant_w = qw.dequantize()
+                    y_ref = data @ fake_quant_w.T + (m.linear.bias if has_bias else 0)
+                    y_ref = y_ref.to(dtype)
+                    torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-1)
+
+        has_bias_list = [False, True]
+        quant_mode_list = [0, 1, 2, 3]
+        cases = itertools.product(has_bias_list, quant_mode_list)
+        for has_bias, quant_mode in cases:
+            test(has_bias, quant_mode)
+
+    def test_weight_only_quantization_group_size(self):
+        # M, N, K = 4, 64, 128
+
+        class Mod(nn.Module):
+            def __init__(self, ic, oc, has_bias):
+                super(Mod, self).__init__()
+                self.linear = torch.nn.Linear(ic, oc, has_bias)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        def test(shape, has_bias, act_quant_mode, group_size):
+            M, N, K = shape
+            dtype = torch.bfloat16
+            model = Mod(K, N, has_bias)
+            m = model.eval()
+            m2 = copy.deepcopy(m)
+            data = torch.rand(M, K) * 0.5
+            if group_size == -1 and act_quant_mode != 0:
+                # these cases are covered by another test case for act_quant_mode
+                return
+            qconfig_mapping = ipex.quantization.get_weight_only_quant_qconfig_mapping(
+                weight_dtype=torch.quint4x2,
+                lowp_mode=ipex.quantization.WoqLowpMode.INT8,
+                act_quant_mode=act_quant_mode,
+                group_size=group_size,
+            )
+            fake_quant_x = self._fakequant_by_group(data, act_quant_mode, group_size)
+            prepared_model = prepare(m2, qconfig_mapping, inplace=True)
+            with torch.no_grad(), torch.autocast(
+                device_type="cpu", enabled=True, dtype=dtype
+            ):
+                woq_model = convert(prepared_model)
+                # Behavior of WOQ Linear to simulate:
+                # Quantize weight to int4 by float qparams at quantization time
+                # Quantize activation to int8 at runtime
+                # Convert weight and its zero points to INT8 for computation
+                w = copy.deepcopy(m.linear.weight.data)
+                is_int4 = True
+                if group_size == -1:
+                    qw, w_scales, w_zero_points = quantize_per_channel(
+                        w, is_int4, None, None
+                    )
+                    fake_quant_w = dequantize_per_channel(
+                        qw, w_scales, w_zero_points.int(), is_int4, w.shape
+                    )
+                else:
+                    qw, w_scales, w_zero_points = quantize_per_block(
+                        w, is_int4, group_size, None, None
+                    )
+                    fake_quant_w = dequantize_per_block(
+                        qw,
+                        w_scales,
+                        w_zero_points,
+                        is_int4,
+                        group_size,
+                        weight_shape=w.shape,
+                    )
+                m.linear.weight.data = fake_quant_w
+                y_ref = m(fake_quant_x).to(dtype)
+                y = woq_model(data)
+                try:
+                    torch.testing.assert_close(y, y_ref, atol=1e-2 * 5, rtol=1e-1 * 2)
+                except Exception:
+                    # The fallback kernel does not support act quant mode
+                    # It computes in fp32 by dequantizing weight.
+                    # fake_quant_w = qw.dequantize()
+                    y_ref = data @ fake_quant_w.T + (m.linear.bias if has_bias else 0)
+                    y_ref = y_ref.to(dtype)
+                    torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-1)
+
+        MNK_list = [(4, 64, 128), (4, 32, 127), (9, 31, 256)]
+        has_bias_list = [False, True]
+        quant_mode_list = [0, 1, 2, 3]
+        group_size_list = [-1, 32, 64, 128]
+        cases = itertools.product(
+            MNK_list, has_bias_list, quant_mode_list, group_size_list
+        )
+        for shape, has_bias, act_quant_mode, group_size in cases:
+            test(shape, has_bias, act_quant_mode, group_size)
+
+
+class QuantizedOpsTester(TestCase):
+    def test_matmul_i8i8i32(self):
+        x = torch.randn(4, 8)
+        w = torch.randn(4, 8)
+        x_min, x_max = x.aminmax()
+        x_scale = torch.max(x_max, x_min.neg()) / 127
+        qx = torch.round(x / x_scale).to(torch.int8)
+        w_min, w_max = w.aminmax(dim=1)
+        w_scale = torch.max(w_max, w_min.neg()) / 127
+        qw = torch.round(w / w_scale.unsqueeze(-1)).to(torch.int8)
+        for use_bf16 in [False, True]:
+            dtype = torch.bfloat16 if use_bf16 else torch.float32
+            with torch.cpu.amp.autocast(enabled=use_bf16, dtype=dtype):
+                qy = torch.ops.torch_ipex.matmul_i8i8i32(qx, qw)
+                qy_ref = torch.nn.functional.linear(qx.to(dtype), qw.to(dtype))
+                self.assertEqual(qy.to(dtype), qy_ref)
+
 
 if __name__ == "__main__":
+    test = unittest.main()
     run_tests()
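For context on what the group_size cases above exercise, here is a minimal standalone sketch of group-wise (per-block) asymmetric uint4 weight fake-quantization. It is an illustration only: the helper name fake_quantize_weight_per_block is hypothetical, and the real quantize_per_block / dequantize_per_block test utilities may differ in rounding, padding, and zero-point handling.

# Sketch (not the IPEX implementation): each weight row is split into groups of
# group_size along K, and every group gets its own float scale and zero point so
# that the group's min/max fit the 16 uint4 levels. Quantize then dequantize to
# obtain the fake-quantized reference weight the tests compare against.
import torch


def fake_quantize_weight_per_block(w: torch.Tensor, group_size: int) -> torch.Tensor:
    # w: [N, K] fp32 weight; K is padded so each row splits evenly into groups.
    N, K = w.shape
    pad = (group_size - K % group_size) % group_size
    w_pad = torch.nn.functional.pad(w, (0, pad))
    grouped = w_pad.view(N, -1, group_size)  # [N, K/group_size, group_size]
    g_min = grouped.amin(dim=-1, keepdim=True)
    g_max = grouped.amax(dim=-1, keepdim=True)
    # Asymmetric uint4: 16 levels; zero point maps g_min to the quantized value 0.
    scales = (g_max - g_min).clamp_min(torch.finfo(torch.float32).eps) / 15
    zps = torch.round(-g_min / scales)
    q = torch.clamp(torch.round(grouped / scales) + zps, 0, 15)
    deq = (q - zps) * scales
    # Drop the padding and restore the original [N, K] shape.
    return deq.view(N, -1)[:, :K].contiguous()


if __name__ == "__main__":
    w = torch.randn(64, 128)
    w_dq = fake_quantize_weight_per_block(w, group_size=32)
    # With 4-bit groups the worst-case error is about half a quantization step.
    print((w - w_dq).abs().max())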