This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Qbits woq ref impl for debug #1248

Merged
merged 5 commits into from
Feb 4, 2024
@@ -0,0 +1,32 @@
// Copyright (c) 2023 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "bestla_weightonly_dispatcher.hpp"
namespace woq {

enum PACKW_ACQUIRE_TYPE {
SIZE = 0,
BLOCKSIZE,
K,
N,
ACT_SHUFFLE,
G_IDX,
WEI_TYPE,
CMPT_TYPE,
SCALE_TYPE,
};

void bestla_packq(woq_packq_param* p, woq_packq_ctx* ctx);
torch::Tensor get_packw_info(torch::Tensor& packw, PACKW_ACQUIRE_TYPE ACQ_T);
} // namespace woq
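
For reference, the enum above is mirrored on the Python side and passed as a plain integer to the acquire_woq_packw_info op that this PR registers under torch.ops.bestlaop in qbits.cpp. A minimal sketch of querying the scalar fields of a packed weight follows; PackwAcquireType and query_packw_scalar are illustrative names, and packw is assumed to be a tensor previously produced by torch.ops.bestlaop.woq_packq with the QBits extension loaded.

import torch
from enum import Enum


class PackwAcquireType(Enum):
    # Mirrors woq::PACKW_ACQUIRE_TYPE above.
    SIZE = 0
    BLOCKSIZE = 1
    K = 2
    N = 3
    ACT_SHUFFLE = 4
    G_IDX = 5
    WEI_TYPE = 6
    CMPT_TYPE = 7
    SCALE_TYPE = 8


def query_packw_scalar(packw, acq):
    # SIZE, BLOCKSIZE, K, N and ACT_SHUFFLE come back as a one-element
    # int64 tensor; .item() unwraps it into a plain Python int.
    return torch.ops.bestlaop.acquire_woq_packw_info(packw, acq.value)[0].item()


# Usage sketch:
# n = query_packw_scalar(packw, PackwAcquireType.N)
# blocksize = query_packw_scalar(packw, PackwAcquireType.BLOCKSIZE)
# uses_act_shuffle = bool(query_packw_scalar(packw, PackwAcquireType.ACT_SHUFFLE))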
@@ -69,6 +69,5 @@ static std::map<std::string, BTLA_DTYPE> scale2bestladt_map{
{"fp32", BTLA_DTYPE::F32}, {"bf16", BTLA_DTYPE::BF16}, {"fp8_e8m0", BTLA_DTYPE::F8_E8M0}};

void dispatch_woq_task(woq_config_param* p, woq_runtime_ctx* ctx, WOQ_TASK task);
void bestla_packq(woq_packq_param* p, woq_packq_ctx* ctx);
void set_woq_workspace(torch::Tensor* workspace);
} // namespace woq
@@ -1,5 +1,5 @@
#include "bestla/bestla_prologue_b.h"
#include "../include/bestla_weightonly_dispatcher.hpp"
#include "../include/bestla_packq_impl.hpp"

namespace woq {
template <class GemmCore, BTLA_ISA ISA>
@@ -17,6 +17,101 @@ void execute_qpack(woq_packq_param* p, woq_packq_ctx* ctx) {
p->asym ? ctx->zp->data_ptr<int8_t>() : nullptr, &qpackw, &dispatcher_utils::DefaultThreading);
}

std::string get_dtype_str(BTLA_DTYPE dtype) {
switch (dtype) {
case BTLA_DTYPE::F32:
return "fp32";
case BTLA_DTYPE::BF16:
return "bf16";
case BTLA_DTYPE::S4_CLIP:
return "int4_clip";
case BTLA_DTYPE::S4_FULLRANGE:
return "int4_fullrange";
case BTLA_DTYPE::F4_NF4:
return "nf4";
case BTLA_DTYPE::F4_E2M1:
return "fp4_e2m1";
case BTLA_DTYPE::F4_BNB:
return "fp4_e2m1_bnb";
case BTLA_DTYPE::S8:
return "int8";
case BTLA_DTYPE::F8_E5M2:
return "fp8_e5m2";
case BTLA_DTYPE::F8_E4M3:
return "fp8_e4m3";
case BTLA_DTYPE::F8_E8M0:
return "fp8_e8m0";
default:
TORCH_CHECK(false, "QBits: unrecognized data type.")
break;
}
}

std::string get_cmpt_str(bestla::gemm::CompType cmpt) {
using bestla::gemm::CompType;
switch (cmpt) {
case CompType::COMP_INT8_US_INT32:
case CompType::COMP_INT8_US_FP32:
return "int8";
case CompType::COMP_FP32:
return "fp32";
case CompType::COMP_BF16_FP32:
return "bf16";
default:
TORCH_CHECK(false, "QBits: unrecognized compute type.");
break;
}
}

std::vector<int> get_ascii_vec(std::string str) {
std::vector<int32_t> ret;
for (char c : str) ret.push_back(static_cast<int32_t>(c));
return ret;
}

torch::Tensor get_packw_info(torch::Tensor& packw, PACKW_ACQUIRE_TYPE ACQ_T) {
torch::Tensor output;
auto packw_ptr = dynamic_cast<bestla::storage::gemm::StorageWeightKBlockNInteger*>(
bestla::storage::gemm::PackedWeightParser::deserialBuffer(packw.data_ptr()));
output = torch::empty(1, torch::kInt64);
switch (ACQ_T) {
case SIZE:
return output.index_put_({0}, static_cast<int64_t>(packw_ptr->mSize));
case BLOCKSIZE:
return output.index_put_({0}, static_cast<int64_t>(packw_ptr->mBlockSize));
case K:
return output.index_put_({0}, static_cast<int64_t>(packw_ptr->mK));
case N:
return output.index_put_({0}, static_cast<int64_t>(packw_ptr->mN));
case ACT_SHUFFLE:
return output.index_put_({0}, static_cast<int64_t>(packw_ptr->ShfIndice() != nullptr ? 1 : 0));
case G_IDX: {
auto tensor_size = packw_ptr->mShuffleIndices.size<int>();
TORCH_CHECK(packw_ptr->ShfIndice() != nullptr, "QBits: not pack g_idx tensor.");
output = torch::empty(tensor_size, torch::kInt32);
memcpy(output.data_ptr(), packw_ptr->ShfIndice(), tensor_size * sizeof(int));
} break;
case WEI_TYPE:
case SCALE_TYPE: {
BTLA_DTYPE acquire_dt = ACQ_T == WEI_TYPE ? packw_ptr->mDType : packw_ptr->SDtype();
auto ascii_vec = get_ascii_vec(get_dtype_str(acquire_dt));
output = torch::empty(ascii_vec.size(), torch::kInt32);
memcpy(output.data_ptr(), ascii_vec.data(), ascii_vec.size() * sizeof(int));
} break;
case CMPT_TYPE: {
auto CType = bestla::gemm::CoreAttr::get_mask_val(packw_ptr->mCoreId, bestla::gemm::CoreAttr::COMP_MASK,
bestla::gemm::CoreAttr::COMP_SHIFT);
auto ascii_vec = get_ascii_vec(get_cmpt_str(static_cast<bestla::gemm::CompType>(CType)));
output = torch::empty(ascii_vec.size(), torch::kInt32);
memcpy(output.data_ptr(), ascii_vec.data(), ascii_vec.size() * sizeof(int));
} break;
default:
TORCH_CHECK(false, "QBits: unsupported acquire_type");
break;
}
return output;
}

void bestla_packq(woq_packq_param* p, woq_packq_ctx* ctx) {
TORCH_CHECK(p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int4_fullrange",
"Qbits: only support Integer WOQ in PACKQ");
@@ -49,7 +49,6 @@ void woq_dequantize(woq_config_param* p, woq_runtime_ctx* ctx) {
using PrologueB = typename Launcher::PrologueB;
using WType = typename Launcher::PrologueB::StorageWeight;
static PrologueB kernel;
// TODO(zhe): using unified StorageWeightKBlockNInteger after sync with neural-speed(with NFloat ProB feature).
if (ctx->transpose) {
kernel.unpackTransposeWeight(ctx->deseries_wei->mN, ctx->deseries_wei->mK,
dynamic_cast<bestla::storage::gemm::StorageWeightKBlockNInteger*>(ctx->deseries_wei),
12 changes: 9 additions & 3 deletions intel_extension_for_transformers/llm/operator/csrc/qbits.cpp
@@ -14,6 +14,7 @@
#include "dispatcher/include/dispatcher_utils.hpp"
#include "dispatcher/include/bestla_gemm_dispatcher.hpp"
#include "dispatcher/include/bestla_weightonly_dispatcher.hpp"
#include "dispatcher/include/bestla_packq_impl.hpp"
#include "include/dropout.hpp"
#include <ATen/core/TensorBody.h>
#include <c10/core/ScalarType.h>
@@ -45,8 +46,8 @@ static void inline init_woq_config_param(woq::woq_config_param* p, woq::woq_runt
case woq::WOQ_QUANTIZE:
case woq::WOQ_DEQUANTIZE:
p->src_dt = dispatcher_utils::QBITS_FP32;
p->dst_dt = dispatcher_utils::QBITS_FP32; // bestla doesn't care about dst_dt in quantize/dequant task,so set fp32
// as default.
p->dst_dt = dispatcher_utils::QBITS_FP32; // bestla doesn't care about dst_dt in quantize/dequant task,so set
// fp32 as default.
break;
case woq::WOQ_LINEAR:
p->src_dt = get_qbits_dt(ctx->activation);
@@ -122,7 +123,7 @@ static void set_woq_workspace(const torch::Tensor& workspace) {
}

static void bestlaop_gemm(const torch::Tensor& matA, const torch::Tensor& matB, const torch::Tensor& matC,
bool matB_trans) {
bool matB_trans) {
TORCH_CHECK(matA.dim() == 2 && matB.dim() == 2 && matC.dim() == 2,
"Qbits: only support 2-dim input-tensor in bestla gemm op.");
bestla_gemm::bestla_gemm_runtime_ctx ctx;
@@ -138,6 +139,10 @@ static void bestlaop_gemm(const torch::Tensor& matA, const torch::Tensor& matB,
return bestla_gemm::dispatch_bestla_gemm(&ctx);
}

static torch::Tensor acquire_woq_packw_info(torch::Tensor& packw, int64_t acquire_type) {
return woq::get_packw_info(packw, static_cast<woq::PACKW_ACQUIRE_TYPE>(acquire_type));
}

static torch::Tensor qbits_dropout_fwd(torch::Tensor& output, double p) { return dropout_fwd(output, p); }

static void qbits_dropout_bwd(torch::Tensor& grad, torch::Tensor& scale) { dropout_bwd(grad, scale); }
@@ -149,6 +154,7 @@ TORCH_LIBRARY(bestlaop, m) {
m.def("woq_packq", &woq_packq);
m.def("set_woq_workspace", &set_woq_workspace);
m.def("matmul", &bestlaop_gemm);
m.def("acquire_woq_packw_info", &acquire_woq_packw_info);
}

TORCH_LIBRARY(qbits_customop, m) {
@@ -16,6 +16,7 @@
# limitations under the License.

from ut_utils import *
from enum import Enum


def convert_idx(g_idx, k, blocksize):
@@ -27,6 +28,18 @@ def convert_idx(g_idx, k, blocksize):
    return ret_idx


class acquire_type(Enum):
    SIZE = 0
    BLOCKSIZE = 1
    K = 2
    N = 3
    ACT_SHUFFLE = 4
    G_IDX = 5
    WEI_TYPE = 6
    CMPT_TYPE = 7
    SCALE_TYPE = 8


@pytest.mark.parametrize("m", [256])
@pytest.mark.parametrize("n", [1024])
@pytest.mark.parametrize("k", [512])
@@ -66,3 +79,19 @@ def test(m, k, n, weight_type, scale_type, compute_type, asym, blocksize, dump_t
        assert (abs(ref_dst - tar_dst).max() < 8)
    else:
        assert (abs(ref_dst - tar_dst).max() < 10)
    packw_size = torch.ops.bestlaop.acquire_woq_packw_info(
        packw, acquire_type.SIZE.value)[0].item()
    if packw_size != packw.size()[0]:
        assert (0)
    packw_wei_type = torch.ops.bestlaop.acquire_woq_packw_info(
        packw, acquire_type.WEI_TYPE.value)
    packw_wei_type_str = ''.join(chr(ascii_code)
                                 for ascii_code in packw_wei_type.tolist())
    if packw_wei_type_str != weight_type:
        assert (0)
    enable_act_shuffle = torch.ops.bestlaop.acquire_woq_packw_info(
        packw, acquire_type.ACT_SHUFFLE.value)[0] != 0
    assert (enable_act_shuffle)
    acquire_g_idx = torch.ops.bestlaop.acquire_woq_packw_info(
        packw, acquire_type.G_IDX.value)
    assert (abs(acquire_g_idx-cvt_idx).max() == 0)
@@ -16,11 +16,49 @@
# limitations under the License.


import os
import operator
import torch
from functools import reduce
from torch import Tensor
from typing import Tuple, Optional, List
from enum import Enum


class qbits_acquire_type(Enum):
    SIZE = 0
    BLOCKSIZE = 1
    K = 2
    N = 3
    ACT_SHUFFLE = 4
    G_IDX = 5
    WEI_TYPE = 6
    CMPT_TYPE = 7
    SCALE_TYPE = 8


def qbits_woq_linear_ref_impl(activation, packw, bias, compute_type, weight_type, scale_type):
    assert (activation.is_contiguous())
    assert (packw.is_contiguous())
    activation = activation.to(torch.float32)
    n = torch.ops.bestlaop.acquire_woq_packw_info(
        packw, qbits_acquire_type.N.value)[0].item()
    k = activation.shape[1]
    revert_wei = torch.empty(k, n, dtype=torch.float)
    torch.ops.bestlaop.woq_dequantize(
        packw, revert_wei, False, compute_type, weight_type, scale_type)
    enable_act_shuffle = torch.ops.bestlaop.acquire_woq_packw_info(
        packw, qbits_acquire_type.ACT_SHUFFLE.value)[0] != 0
    if enable_act_shuffle:
        g_idx = torch.ops.bestlaop.acquire_woq_packw_info(
            packw, qbits_acquire_type.G_IDX.value)
        activation = torch.index_select(activation, 1, g_idx)
    out = torch.matmul(activation, revert_wei)
    if bias is not None:
        assert (bias.is_contiguous())
        assert (bias.dtype == torch.float32)
        out += bias
    return out


def prod(iterable):
@@ -64,18 +102,23 @@ def forward(

        # 2. Matmul
        # output = torch.nn.functional.linear(A, B_dequant, bias)
        torch.ops.bestlaop.woq_linear(
            A,
            B.data,
            bias,
            out,
            out.shape[-1],
            bias is not None,
            compute_dtype,
            weight_dtype,
            scale_dtype,
            False,
        )
        qbits_debug_flag = os.getenv('QBITS_DEBUG', 'NULL')
        if qbits_debug_flag == 'NULL':
            torch.ops.bestlaop.woq_linear(
                A,
                B.data,
                bias,
                out,
                out.shape[-1],
                bias is not None,
                compute_dtype,
                weight_dtype,
                scale_dtype,
                False,
            )
        else:
            out = qbits_woq_linear_ref_impl(
                A, B.data, bias, compute_dtype, weight_dtype, scale_dtype)
        output = out

        # 3. Save state
@@ -101,7 +144,8 @@
    @staticmethod
    def backward(ctx, grad_output):
        if ctx.is_empty:
            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
            bias_grad = None if ctx.bias is None else torch.zeros_like(
                ctx.bias)
            return (
                torch.zeros_like(ctx.A),
                torch.zeros_like(ctx.B),
@@ -114,7 +158,8 @@ def backward(ctx, grad_output):
        A, B = ctx.tensors
        grad_A, grad_B, grad_bias = None, None, None

        B_dequant = torch.zeros(grad_output.shape[-1], A.shape[-1], dtype=torch.float)
        B_dequant = torch.zeros(
            grad_output.shape[-1], A.shape[-1], dtype=torch.float)

        torch.ops.bestlaop.woq_dequantize(
            B, B_dequant, True, ctx.compute_dtype, ctx.weight_dtype, ctx.scale_dtype
@@ -149,17 +194,22 @@ def matmul_kbit(
            A, B, out, bias, compute_dtype, weight_dtype, scale_dtype
        )
    else:
        torch.ops.bestlaop.woq_linear(
            A,
            B.data,
            bias,
            out,
            out.shape[-1],
            bias is not None,
            compute_dtype,
            weight_dtype,
            scale_dtype,
            False,
        )
        qbits_debug_flag = os.getenv('QBITS_DEBUG', 'NULL')
        if qbits_debug_flag == 'NULL':
            torch.ops.bestlaop.woq_linear(
                A,
                B.data,
                bias,
                out,
                out.shape[-1],
                bias is not None,
                compute_dtype,
                weight_dtype,
                scale_dtype,
                False,
            )
        else:
            out = qbits_woq_linear_ref_impl(
                A, B.data, bias, compute_dtype, weight_dtype, scale_dtype)

    return out
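
The debug switch added in this file is driven entirely by the QBITS_DEBUG environment variable: when it is unset (the 'NULL' default), the fused woq_linear kernel runs; when it is set to anything other than the literal string 'NULL', the dequantize-then-matmul reference path qbits_woq_linear_ref_impl runs instead. Below is a rough sketch of how the fused kernel and the reference path might be compared numerically outside of that gating; compare_fused_vs_reference is an illustrative name, the tensors are assumed to be prepared the same way as in forward() above, and qbits_woq_linear_ref_impl is the reference function added above, assumed to be in scope.

import torch


def compare_fused_vs_reference(A, packw, bias, out_shape,
                               compute_dtype, weight_dtype, scale_dtype):
    # Call the fused kernel directly; this bypasses the QBITS_DEBUG gating,
    # which only lives in the Python wrappers above.
    fused = torch.zeros(out_shape, dtype=torch.float)
    torch.ops.bestlaop.woq_linear(A, packw, bias, fused, out_shape[-1],
                                  bias is not None, compute_dtype,
                                  weight_dtype, scale_dtype, False)
    # Reference path: dequantize the packed weight, then torch.matmul.
    ref = qbits_woq_linear_ref_impl(A, packw, bias, compute_dtype,
                                    weight_dtype, scale_dtype)
    return (fused - ref).abs().max()


# Usage sketch, with tensors prepared as in forward():
# max_err = compare_fused_vs_reference(A, B.data, bias, out.shape,
#                                      compute_dtype, weight_dtype, scale_dtype)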