From 18d36ef39cd956a0069743b8e8f833ee316bf007 Mon Sep 17 00:00:00 2001
From: "Wang, Zhe"
Date: Sun, 4 Feb 2024 11:38:12 +0800
Subject: [PATCH] Qbits woq ref impl for debug (#1248)

---
 .../dispatcher/include/bestla_packq_impl.hpp  |  32 ++++++
 .../include/bestla_weightonly_dispatcher.hpp  |   1 -
 .../csrc/dispatcher/src/bestla_packq_impl.cpp |  97 ++++++++++++++++-
 .../src/bestla_weightonly_dispatcher.cpp      |   1 -
 .../llm/operator/csrc/qbits.cpp               |  12 ++-
 .../llm/operator/csrc/qbits_ut/test_packq.py  |  29 +++++
 .../llm/quantization/autograd/functions.py    | 102 +++++++++++++-----
 7 files changed, 242 insertions(+), 32 deletions(-)
 create mode 100644 intel_extension_for_transformers/llm/operator/csrc/dispatcher/include/bestla_packq_impl.hpp

diff --git a/intel_extension_for_transformers/llm/operator/csrc/dispatcher/include/bestla_packq_impl.hpp b/intel_extension_for_transformers/llm/operator/csrc/dispatcher/include/bestla_packq_impl.hpp
new file mode 100644
index 00000000000..984236964d1
--- /dev/null
+++ b/intel_extension_for_transformers/llm/operator/csrc/dispatcher/include/bestla_packq_impl.hpp
@@ -0,0 +1,32 @@
+// Copyright (c) 2023 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+#include "bestla_weightonly_dispatcher.hpp"
+namespace woq {
+
+enum PACKW_ACQUIRE_TYPE {
+  SIZE = 0,
+  BLOCKSIZE,
+  K,
+  N,
+  ACT_SHUFFLE,
+  G_IDX,
+  WEI_TYPE,
+  CMPT_TYPE,
+  SCALE_TYPE,
+};
+
+void bestla_packq(woq_packq_param* p, woq_packq_ctx* ctx);
+torch::Tensor get_packw_info(torch::Tensor& packw, PACKW_ACQUIRE_TYPE ACQ_T);
+}  // namespace woq
diff --git a/intel_extension_for_transformers/llm/operator/csrc/dispatcher/include/bestla_weightonly_dispatcher.hpp b/intel_extension_for_transformers/llm/operator/csrc/dispatcher/include/bestla_weightonly_dispatcher.hpp
index b8c7c36ea32..545188252b2 100644
--- a/intel_extension_for_transformers/llm/operator/csrc/dispatcher/include/bestla_weightonly_dispatcher.hpp
+++ b/intel_extension_for_transformers/llm/operator/csrc/dispatcher/include/bestla_weightonly_dispatcher.hpp
@@ -69,6 +69,5 @@ static std::map<std::string, BTLA_DTYPE> scale2bestladt_map{
     {"fp32", BTLA_DTYPE::F32}, {"bf16", BTLA_DTYPE::BF16}, {"fp8_e8m0", BTLA_DTYPE::F8_E8M0}};
 
 void dispatch_woq_task(woq_config_param* p, woq_runtime_ctx* ctx, WOQ_TASK task);
-void bestla_packq(woq_packq_param* p, woq_packq_ctx* ctx);
 void set_woq_workspace(torch::Tensor* workspace);
 }  // namespace woq
diff --git a/intel_extension_for_transformers/llm/operator/csrc/dispatcher/src/bestla_packq_impl.cpp b/intel_extension_for_transformers/llm/operator/csrc/dispatcher/src/bestla_packq_impl.cpp
index c0bb021bc25..ae7a466a842 100644
--- a/intel_extension_for_transformers/llm/operator/csrc/dispatcher/src/bestla_packq_impl.cpp
+++ b/intel_extension_for_transformers/llm/operator/csrc/dispatcher/src/bestla_packq_impl.cpp
@@ -1,5 +1,5 @@
 #include "bestla/bestla_prologue_b.h"
-#include "../include/bestla_weightonly_dispatcher.hpp"
+#include "../include/bestla_packq_impl.hpp"
 namespace woq {
 template
@@ -17,6 +17,101 @@ void execute_qpack(woq_packq_param* p, woq_packq_ctx* ctx) {
                     p->asym ? ctx->zp->data_ptr() : nullptr, &qpackw, &dispatcher_utils::DefaultThreading);
 }
 
+std::string get_dtype_str(BTLA_DTYPE dtype) {
+  switch (dtype) {
+    case BTLA_DTYPE::F32:
+      return "fp32";
+    case BTLA_DTYPE::BF16:
+      return "bf16";
+    case BTLA_DTYPE::S4_CLIP:
+      return "int4_clip";
+    case BTLA_DTYPE::S4_FULLRANGE:
+      return "int4_fullrange";
+    case BTLA_DTYPE::F4_NF4:
+      return "nf4";
+    case BTLA_DTYPE::F4_E2M1:
+      return "fp4_e2m1";
+    case BTLA_DTYPE::F4_BNB:
+      return "fp4_e2m1_bnb";
+    case BTLA_DTYPE::S8:
+      return "int8";
+    case BTLA_DTYPE::F8_E5M2:
+      return "fp8_e5m2";
+    case BTLA_DTYPE::F8_E4M3:
+      return "fp8_e4m3";
+    case BTLA_DTYPE::F8_E8M0:
+      return "fp8_e8m0";
+    default:
+      TORCH_CHECK(false, "QBits: unrecognized data type.")
+      break;
+  }
+}
+
+std::string get_cmpt_str(bestla::gemm::CompType cmpt) {
+  using bestla::gemm::CompType;
+  switch (cmpt) {
+    case CompType::COMP_INT8_US_INT32:
+    case CompType::COMP_INT8_US_FP32:
+      return "int8";
+    case CompType::COMP_FP32:
+      return "fp32";
+    case CompType::COMP_BF16_FP32:
+      return "bf16";
+    default:
+      TORCH_CHECK(false, "QBits: unrecognized compute type.");
+      break;
+  }
+}
+
+std::vector<int> get_ascii_vec(std::string str) {
+  std::vector<int> ret;
+  for (char c : str) ret.push_back(static_cast<int>(c));
+  return ret;
+}
+
+torch::Tensor get_packw_info(torch::Tensor& packw, PACKW_ACQUIRE_TYPE ACQ_T) {
+  torch::Tensor output;
+  auto packw_ptr = dynamic_cast<bestla::storage::gemm::StorageWeightKBlockNInteger*>(
+      bestla::storage::gemm::PackedWeightParser::deserialBuffer(packw.data_ptr()));
+  output = torch::empty(1, torch::kInt64);
+  switch (ACQ_T) {
+    case SIZE:
+      return output.index_put_({0}, static_cast<int64_t>(packw_ptr->mSize));
+    case BLOCKSIZE:
+      return output.index_put_({0}, static_cast<int64_t>(packw_ptr->mBlockSize));
+    case K:
+      return output.index_put_({0}, static_cast<int64_t>(packw_ptr->mK));
+    case N:
+      return output.index_put_({0}, static_cast<int64_t>(packw_ptr->mN));
+    case ACT_SHUFFLE:
+      return output.index_put_({0}, static_cast<int64_t>(packw_ptr->ShfIndice() != nullptr ? 1 : 0));
+    case G_IDX: {
+      auto tensor_size = packw_ptr->mShuffleIndices.size();
+      TORCH_CHECK(packw_ptr->ShfIndice() != nullptr, "QBits: not pack g_idx tensor.");
+      output = torch::empty(tensor_size, torch::kInt32);
+      memcpy(output.data_ptr(), packw_ptr->ShfIndice(), tensor_size * sizeof(int));
+    } break;
+    case WEI_TYPE:
+    case SCALE_TYPE: {
+      BTLA_DTYPE acquire_dt = ACQ_T == WEI_TYPE ? packw_ptr->mDType : packw_ptr->SDtype();
+      auto ascii_vec = get_ascii_vec(get_dtype_str(acquire_dt));
+      output = torch::empty(ascii_vec.size(), torch::kInt32);
+      memcpy(output.data_ptr(), ascii_vec.data(), ascii_vec.size() * sizeof(int));
+    } break;
+    case CMPT_TYPE: {
+      auto CType = bestla::gemm::CoreAttr::get_mask_val(packw_ptr->mCoreId, bestla::gemm::CoreAttr::COMP_MASK,
+                                                        bestla::gemm::CoreAttr::COMP_SHIFT);
+      auto ascii_vec = get_ascii_vec(get_cmpt_str(static_cast<bestla::gemm::CompType>(CType)));
+      output = torch::empty(ascii_vec.size(), torch::kInt32);
+      memcpy(output.data_ptr(), ascii_vec.data(), ascii_vec.size() * sizeof(int));
+    } break;
+    default:
+      TORCH_CHECK(false, "QBits: unsupported acquire_type");
+      break;
+  }
+  return output;
+}
+
 void bestla_packq(woq_packq_param* p, woq_packq_ctx* ctx) {
   TORCH_CHECK(p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int4_fullrange",
               "Qbits: only support Integer WOQ in PACKQ");
diff --git a/intel_extension_for_transformers/llm/operator/csrc/dispatcher/src/bestla_weightonly_dispatcher.cpp b/intel_extension_for_transformers/llm/operator/csrc/dispatcher/src/bestla_weightonly_dispatcher.cpp
index 8629e5b255b..74bc5f63ea6 100644
--- a/intel_extension_for_transformers/llm/operator/csrc/dispatcher/src/bestla_weightonly_dispatcher.cpp
+++ b/intel_extension_for_transformers/llm/operator/csrc/dispatcher/src/bestla_weightonly_dispatcher.cpp
@@ -49,7 +49,6 @@ void woq_dequantize(woq_config_param* p, woq_runtime_ctx* ctx) {
   using PrologueB = typename Launcher::PrologueB;
   using WType = typename Launcher::PrologueB::StorageWeight;
   static PrologueB kernel;
-  // TODO(zhe): using unified StorageWeightKBlockNInteger after sync with neural-speed(with NFloat ProB feature).
   if (ctx->transpose) {
     kernel.unpackTransposeWeight(ctx->deseries_wei->mN, ctx->deseries_wei->mK,
                                  dynamic_cast<WType*>(ctx->deseries_wei),
diff --git a/intel_extension_for_transformers/llm/operator/csrc/qbits.cpp b/intel_extension_for_transformers/llm/operator/csrc/qbits.cpp
index d47dbfa6dac..c8637cad78b 100755
--- a/intel_extension_for_transformers/llm/operator/csrc/qbits.cpp
+++ b/intel_extension_for_transformers/llm/operator/csrc/qbits.cpp
@@ -14,6 +14,7 @@
 #include "dispatcher/include/dispatcher_utils.hpp"
 #include "dispatcher/include/bestla_gemm_dispatcher.hpp"
 #include "dispatcher/include/bestla_weightonly_dispatcher.hpp"
+#include "dispatcher/include/bestla_packq_impl.hpp"
 #include "include/dropout.hpp"
 #include
 #include
@@ -45,8 +46,8 @@ static void inline init_woq_config_param(woq::woq_config_param* p, woq::woq_runt
     case woq::WOQ_QUANTIZE:
     case woq::WOQ_DEQUANTIZE:
       p->src_dt = dispatcher_utils::QBITS_FP32;
-      p->dst_dt = dispatcher_utils::QBITS_FP32;  // bestla doesn't care about dst_dt in quantize/dequant task,so set fp32
-      // as default.
+      p->dst_dt = dispatcher_utils::QBITS_FP32;  // bestla doesn't care about dst_dt in quantize/dequant task, so set
+      // fp32 as default.
       break;
     case woq::WOQ_LINEAR:
       p->src_dt = get_qbits_dt(ctx->activation);
@@ -122,7 +123,7 @@ static void set_woq_workspace(const torch::Tensor& workspace) {
 }
 
 static void bestlaop_gemm(const torch::Tensor& matA, const torch::Tensor& matB, const torch::Tensor& matC,
-                           bool matB_trans) {
+                          bool matB_trans) {
   TORCH_CHECK(matA.dim() == 2 && matB.dim() == 2 && matC.dim() == 2,
               "Qbits: only support 2-dim input-tensor in bestla gemm op.");
   bestla_gemm::bestla_gemm_runtime_ctx ctx;
@@ -138,6 +139,10 @@ static void bestlaop_gemm(const torch::Tensor& matA, const torch::Tensor& matB,
   return bestla_gemm::dispatch_bestla_gemm(&ctx);
 }
 
+static torch::Tensor acquire_woq_packw_info(torch::Tensor& packw, int64_t acquire_type) {
+  return woq::get_packw_info(packw, static_cast<woq::PACKW_ACQUIRE_TYPE>(acquire_type));
+}
+
 static torch::Tensor qbits_dropout_fwd(torch::Tensor& output, double p) { return dropout_fwd(output, p); }
 
 static void qbits_dropout_bwd(torch::Tensor& grad, torch::Tensor& scale) { dropout_bwd(grad, scale); }
@@ -149,6 +154,7 @@ TORCH_LIBRARY(bestlaop, m) {
   m.def("woq_packq", &woq_packq);
   m.def("set_woq_workspace", &set_woq_workspace);
   m.def("matmul", &bestlaop_gemm);
+  m.def("acquire_woq_packw_info", &acquire_woq_packw_info);
 }
 
 TORCH_LIBRARY(qbits_customop, m) {
diff --git a/intel_extension_for_transformers/llm/operator/csrc/qbits_ut/test_packq.py b/intel_extension_for_transformers/llm/operator/csrc/qbits_ut/test_packq.py
index df9d4ea1477..99462e68849 100644
--- a/intel_extension_for_transformers/llm/operator/csrc/qbits_ut/test_packq.py
+++ b/intel_extension_for_transformers/llm/operator/csrc/qbits_ut/test_packq.py
@@ -16,6 +16,7 @@
 # limitations under the License.
 
 from ut_utils import *
+from enum import Enum
 
 
 def convert_idx(g_idx, k, blocksize):
@@ -27,6 +28,18 @@ def convert_idx(g_idx, k, blocksize):
     return ret_idx
 
 
+class acquire_type(Enum):
+    SIZE = 0
+    BLOCKSIZE = 1
+    K = 2
+    N = 3
+    ACT_SHUFFLE = 4
+    G_IDX = 5
+    WEI_TYPE = 6
+    CMPT_TYPE = 7
+    SCALE_TYPE = 8
+
+
 @pytest.mark.parametrize("m", [256])
 @pytest.mark.parametrize("n", [1024])
 @pytest.mark.parametrize("k", [512])
@@ -66,3 +79,19 @@ def test(m, k, n, weight_type, scale_type, compute_type, asym, blocksize, dump_t
         assert (abs(ref_dst - tar_dst).max() < 8)
     else:
         assert (abs(ref_dst - tar_dst).max() < 10)
+    packw_size = torch.ops.bestlaop.acquire_woq_packw_info(
+        packw, acquire_type.SIZE.value)[0].item()
+    if packw_size != packw.size()[0]:
+        assert (0)
+    packw_wei_type = torch.ops.bestlaop.acquire_woq_packw_info(
+        packw, acquire_type.WEI_TYPE.value)
+    packw_wei_type_str = ''.join(chr(ascii_code)
+                                 for ascii_code in packw_wei_type.tolist())
+    if packw_wei_type_str != weight_type:
+        assert (0)
+    enable_act_shuffle = torch.ops.bestlaop.acquire_woq_packw_info(
+        packw, acquire_type.ACT_SHUFFLE.value)[0] != 0
+    assert (enable_act_shuffle)
+    acquire_g_idx = torch.ops.bestlaop.acquire_woq_packw_info(
+        packw, acquire_type.G_IDX.value)
+    assert (abs(acquire_g_idx - cvt_idx).max() == 0)
diff --git a/intel_extension_for_transformers/llm/quantization/autograd/functions.py b/intel_extension_for_transformers/llm/quantization/autograd/functions.py
index 3b5d0506a35..9a8f2481e14 100644
--- a/intel_extension_for_transformers/llm/quantization/autograd/functions.py
+++ b/intel_extension_for_transformers/llm/quantization/autograd/functions.py
@@ -16,11 +16,49 @@
 # limitations under the License.
 
+import os
 import operator
 import torch
 from functools import reduce
 from torch import Tensor
 from typing import Tuple, Optional, List
+from enum import Enum
+
+
+class qbits_acquire_type(Enum):
+    SIZE = 0
+    BLOCKSIZE = 1
+    K = 2
+    N = 3
+    ACT_SHUFFLE = 4
+    G_IDX = 5
+    WEI_TYPE = 6
+    CMPT_TYPE = 7
+    SCALE_TYPE = 8
+
+
+def qbits_woq_linear_ref_impl(activation, packw, bias, compute_type, weight_type, scale_type):
+    assert (activation.is_contiguous())
+    assert (packw.is_contiguous())
+    activation = activation.to(torch.float32)
+    n = torch.ops.bestlaop.acquire_woq_packw_info(
+        packw, qbits_acquire_type.N.value)[0].item()
+    k = activation.shape[1]
+    revert_wei = torch.empty(k, n, dtype=torch.float)
+    torch.ops.bestlaop.woq_dequantize(
+        packw, revert_wei, False, compute_type, weight_type, scale_type)
+    enable_act_shuffle = torch.ops.bestlaop.acquire_woq_packw_info(
+        packw, qbits_acquire_type.ACT_SHUFFLE.value)[0] != 0
+    if enable_act_shuffle:
+        g_idx = torch.ops.bestlaop.acquire_woq_packw_info(
+            packw, qbits_acquire_type.G_IDX.value)
+        activation = torch.index_select(activation, 1, g_idx)
+    out = torch.matmul(activation, revert_wei)
+    if bias is not None:
+        assert (bias.is_contiguous())
+        assert (bias.dtype == torch.float32)
+        out += bias
+    return out
 
 
 def prod(iterable):
@@ -64,18 +102,23 @@ def forward(
 
         # 2. Matmul
         # output = torch.nn.functional.linear(A, B_dequant, bias)
-        torch.ops.bestlaop.woq_linear(
-            A,
-            B.data,
-            bias,
-            out,
-            out.shape[-1],
-            bias is not None,
-            compute_dtype,
-            weight_dtype,
-            scale_dtype,
-            False,
-        )
+        qbits_debug_flag = os.getenv('QBITS_DEBUG', 'NULL')
+        if qbits_debug_flag == 'NULL':
+            torch.ops.bestlaop.woq_linear(
+                A,
+                B.data,
+                bias,
+                out,
+                out.shape[-1],
+                bias is not None,
+                compute_dtype,
+                weight_dtype,
+                scale_dtype,
+                False,
+            )
+        else:
+            out = qbits_woq_linear_ref_impl(
+                A, B.data, bias, compute_dtype, weight_dtype, scale_dtype)
         output = out
 
         # 3. Save state
@@ -101,7 +144,8 @@ def forward(
     @staticmethod
     def backward(ctx, grad_output):
         if ctx.is_empty:
-            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
+            bias_grad = None if ctx.bias is None else torch.zeros_like(
+                ctx.bias)
             return (
                 torch.zeros_like(ctx.A),
                 torch.zeros_like(ctx.B),
@@ -114,7 +158,8 @@ def backward(ctx, grad_output):
         A, B = ctx.tensors
         grad_A, grad_B, grad_bias = None, None, None
 
-        B_dequant = torch.zeros(grad_output.shape[-1], A.shape[-1], dtype=torch.float)
+        B_dequant = torch.zeros(
+            grad_output.shape[-1], A.shape[-1], dtype=torch.float)
 
         torch.ops.bestlaop.woq_dequantize(
             B, B_dequant, True, ctx.compute_dtype, ctx.weight_dtype, ctx.scale_dtype
@@ -149,17 +194,22 @@ def matmul_kbit(
             A, B, out, bias, compute_dtype, weight_dtype, scale_dtype
         )
     else:
-        torch.ops.bestlaop.woq_linear(
-            A,
-            B.data,
-            bias,
-            out,
-            out.shape[-1],
-            bias is not None,
-            compute_dtype,
-            weight_dtype,
-            scale_dtype,
-            False,
-        )
+        qbits_debug_flag = os.getenv('QBITS_DEBUG', 'NULL')
+        if qbits_debug_flag == 'NULL':
+            torch.ops.bestlaop.woq_linear(
+                A,
+                B.data,
+                bias,
+                out,
+                out.shape[-1],
+                bias is not None,
+                compute_dtype,
+                weight_dtype,
+                scale_dtype,
+                False,
+            )
+        else:
+            out = qbits_woq_linear_ref_impl(
+                A, B.data, bias, compute_dtype, weight_dtype, scale_dtype)
 
     return out
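
Editor's note (not part of the patch): below is a minimal sketch of how the new torch.ops.bestlaop.acquire_woq_packw_info op might be used from Python, along the lines of the unit test above. It assumes the QBits extension is built and loaded so the bestlaop ops are registered, and that `packw` is a weight tensor already packed by the existing QBits pack/quantize path; the helper names describe_packed_weight and _ascii_to_str are illustrative, not part of the library.

# Illustrative only -- not part of the patch. Assumes the QBits extension is
# loaded (so torch.ops.bestlaop.* is registered) and that `packw` was produced
# by the existing QBits weight-packing path.
from enum import Enum

import torch


class acquire_type(Enum):
    # Mirrors PACKW_ACQUIRE_TYPE in bestla_packq_impl.hpp.
    SIZE = 0
    BLOCKSIZE = 1
    K = 2
    N = 3
    ACT_SHUFFLE = 4
    G_IDX = 5
    WEI_TYPE = 6
    CMPT_TYPE = 7
    SCALE_TYPE = 8


def _ascii_to_str(ascii_tensor):
    # WEI_TYPE / SCALE_TYPE / CMPT_TYPE are returned as ASCII codes in an int32 tensor.
    return ''.join(chr(code) for code in ascii_tensor.tolist())


def describe_packed_weight(packw):
    # Hypothetical helper: collects the metadata exposed by the new op into a dict.
    query = torch.ops.bestlaop.acquire_woq_packw_info
    return {
        "size": query(packw, acquire_type.SIZE.value)[0].item(),
        "blocksize": query(packw, acquire_type.BLOCKSIZE.value)[0].item(),
        "k": query(packw, acquire_type.K.value)[0].item(),
        "n": query(packw, acquire_type.N.value)[0].item(),
        "act_shuffle": query(packw, acquire_type.ACT_SHUFFLE.value)[0].item() != 0,
        "weight_dtype": _ascii_to_str(query(packw, acquire_type.WEI_TYPE.value)),
        "scale_dtype": _ascii_to_str(query(packw, acquire_type.SCALE_TYPE.value)),
        "compute_dtype": _ascii_to_str(query(packw, acquire_type.CMPT_TYPE.value)),
    }

With the patch applied, setting the QBITS_DEBUG environment variable to any value other than the literal 'NULL' routes matmul_kbit and MatMulKBit.forward through qbits_woq_linear_ref_impl (dequantize, optional g_idx shuffle, then torch.matmul) instead of the fused woq_linear kernel, so fused results can be compared against the reference path for debugging.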