Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
Q4 perchannel (#271)
Browse files Browse the repository at this point in the history
* add s4 perchannel quant and inner product code.
  • Loading branch information
luoyu-intel authored Sep 8, 2023
1 parent 51a1b88 commit 4e164a8
Show file tree
Hide file tree
Showing 10 changed files with 602 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,12 @@ struct int4x2 : bit4x2 {
int4x2(int8_t v) : bit4x2(v) {}
int4x2() : bit4x2() {}
// Quantizes a signed 8-bit value to the signed 4-bit range [-8, 7].
//
// The input is divided by 16 with round-to-nearest (ties away from zero) and
// the quotient is then clamped on both ends.
//
// @param src  signed 8-bit input value.
// @return     the rounded quotient src / 16, clamped to [-8, 7].
static int8_t convert(int8_t src) {
  // Widen first so adding the rounding bias cannot overflow int8_t.
  int32_t dst = src;
  // Integer division truncates toward zero, so bias by half a step (8 of 16)
  // in the direction of the sign to obtain round-to-nearest.
  dst = dst >= 0 ? dst + 8 : dst - 8;
  dst /= 16;
  // +127 rounds up to 8, which does not fit in 4 signed bits; clamp both ends.
  if (dst > 7) dst = 7;
  if (dst < -8) dst = -8;
  return static_cast<int8_t>(dst);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ enum class WeightCompType : int {
WeightFp4E2M1ScaleFp32,
WeightNf4ScaleFp32,
WeightS8ScaleFp32PerChannelN,
WeightS4ClipScaleFp32PerChannelN,
End,
};

Expand Down Expand Up @@ -661,14 +662,15 @@ class WeightS8ScaleFp32PerChannelN : public WeightS8ScaleFp32<_GemmCore_T, ISA_T
if (zero_points != nullptr) {
std::memcpy(stor->mZPtr, zero_points, N * sizeof(zero_points[0]));
}
reduceWeight(N, K, B, ldb, scales, zero_points, stor->mRPtr);
WeightS8ScaleFp32<_GemmCore_T, ISA_T>::reorderWeight(N, K, B, ldb, stor->mWPtr);
utils::avector<float> deq(K * N);
WeightS8ScaleFp32<_GemmCore_T, ISA_T>::unpackWeight(N, K, stor, deq.data(), N);
reduceWeight(N, K, deq.data(), ldb, stor->mRPtr);
}
}

protected:
void reduceWeight(const int N, const int K, const int8_t* B, const int ldb, const float* scales,
const int8_t* zero_points, float* rptr) {
void reduceWeight(const int N, const int K, const float* B, const int ldb, float* rptr) {
utils::parallel::Parallel2DRowMajor _para;
utils::CpuBase cb;
_para.update(K, N, K, 16, cb.mNumThreads);
Expand All @@ -684,10 +686,9 @@ class WeightS8ScaleFp32PerChannelN : public WeightS8ScaleFp32<_GemmCore_T, ISA_T
int colremain = utils::remainsize(colidx, N, colsize);
const auto src = B + rowidx * ldb + colidx;
const auto dst = rptr + colidx;
using RowReduceSum = kernel::wrapper::QuantS8RowReduceSum<float>;
using RowReduceSum = kernel::wrapper::RowReduceSum<float>;
auto ret = RowReduceSum::template forward<ISA_T>( //
src, ldb, scales + colidx, zero_points != nullptr ? zero_points + colidx : nullptr, rowremain, colremain,
dst);
src, ldb, rowremain, colremain, dst);
assert(ret == JblasSuccess);
}
}
Expand Down Expand Up @@ -910,6 +911,166 @@ class WeightS4ScaleFp32 : public WeightS8ScaleFp32<_GemmCore_T, ISA_T> {
}
};

// Packed-weight storage combining the 4-bit weight buffer and fp32-scale / int8
// zero-point correction of StorageWeightS4ScaleFp32 with a float reduction
// buffer (StorageWeightReduce<float>). Correction and reduction buffers are
// sized NPad x 1, and the k-block size is set to K (see resize()).
class StorageWeightS4ScaleFp32PerChannelN : public StorageWeightS4ScaleFp32, public StorageWeightReduce<float> {
 public:
  // @param _type     GEMM core type forwarded to the base storage.
  // @param _s4_type  4-bit encoding. S4_CLIP tags the storage as
  //                  WeightS4ClipScaleFp32PerChannelN. S4_UNDEF (the default)
  //                  leaves mType as set by the base class. S4_FULLRANGE and
  //                  any other value are not supported and trip an assert.
  StorageWeightS4ScaleFp32PerChannelN(jblas::gemm::GemmCoreType _type, JBLAS_SIGN_INT_TYPE _s4_type = S4_UNDEF)
      : StorageWeightS4ScaleFp32(_type) {
    switch (_s4_type) {
      case S4_CLIP:
        mType = static_cast<int>(WeightCompType::WeightS4ClipScaleFp32PerChannelN);
        break;
      case S4_UNDEF:
        // This is the default argument, so it must be constructible without
        // aborting: PackedWeightParser builds the object this way right before
        // calling deserializeBuffer().
        // NOTE(review): presumably deserializeBuffer() restores mType from the
        // serialized header - confirm.
        break;
      case S4_FULLRANGE:
      default:
        // No per-channel variant exists for these encodings.
        assert(false);
        break;
    }
  }

  // Sizes every sub-buffer for an NPad x KPad packed weight.
  // @param NPad   padded N dimension.
  // @param KPad   padded K dimension.
  // @param K      unpadded K, used as the k-block size.
  // @param IsSym  forwarded to the correction storage.
  //               NOTE(review): presumably "symmetric" means no zero points
  //               are stored - confirm against StorageSimpleCorrection.
  void resize(int NPad, int KPad, int K, bool IsSym = true) {
    PackedWeightKBlock::resize(NPad, KPad, K);  // kblock==K
    StorageWeight4Bit::resize(NPad, KPad);
    StorageSimpleCorrection<float, int8_t>::resize(NPad, 1, IsSym);
    StorageWeightReduce<float>::resize(NPad, 1);
  }

 protected:
  // Total serialized payload: 4-bit weights + correction + reduction buffers.
  virtual size_t getDataSerializedSize() override {
    size_t totalsize = StorageWeight4Bit::myDataSerializedSize() +
                       StorageSimpleCorrection<float, int8_t>::myDataSerializedSize() +
                       StorageWeightReduce<float>::myDataSerializedSize();
    return totalsize;
  }

  // Writes the three sub-buffers to buf in the order weights, correction,
  // reduction. deserializeDataBuffer() must read them in the same order.
  // NOTE(review): relies on each mySerializeDataToBuffer() advancing wptr
  // (by-reference cursor) - confirm.
  virtual void serializeDataToBuffer(void* buf) override {
    auto wptr = reinterpret_cast<int8_t*>(buf);
    StorageWeight4Bit::mySerializeDataToBuffer(wptr);
    StorageSimpleCorrection<float, int8_t>::mySerializeDataToBuffer(wptr);
    StorageWeightReduce<float>::mySerializeDataToBuffer(wptr);
  }

  // Reads the three sub-buffers back from buf in serialization order.
  // @param memalloc  forwarded unchanged to each sub-storage.
  virtual void deserializeDataBuffer(void* buf, int memalloc) override {
    auto rptr = reinterpret_cast<int8_t*>(buf);
    StorageWeight4Bit::myDeserializeDataBuffer(rptr, memalloc);
    StorageSimpleCorrection<float, int8_t>::myDeserializeDataBuffer(rptr, memalloc);
    StorageWeightReduce<float>::myDeserializeDataBuffer(rptr, memalloc);
  }
};

// Weight prologue for signed 4-bit weights stored in
// StorageWeightS4ScaleFp32PerChannelN. Extends the S8 per-channel prologue with
// 4-bit compression on pack and 4-bit decompression on load.
//
// @tparam _GemmCore_T  GEMM core; supplies KTILE, NTILE, PACK_ROW and TYPE.
// @tparam ISA_T        instruction set used to select kernel implementations.
// @tparam S4_T         4-bit encoding passed to the quantize/decompress kernels.
template <class _GemmCore_T, JBLAS_ISA ISA_T, JBLAS_SIGN_INT_TYPE S4_T>
class WeightS4ScaleFp32PerChannelN : public WeightS8ScaleFp32PerChannelN<_GemmCore_T, ISA_T> {
 public:
  using Parent = WeightS8ScaleFp32PerChannelN<_GemmCore_T, ISA_T>;
  using Param = typename Parent::Param;
  using StorageWeight = StorageWeightS4ScaleFp32PerChannelN;

  // Allocates storage for an N x K weight, padding K up to a multiple of KTILE
  // and N up to a multiple of NTILE.
  // @return heap-allocated storage created with `new`.
  //         NOTE(review): ownership presumably passes to the caller - confirm.
  PackedWeight* createStorage(const int N, const int K, bool is_sym = true) override {
    int KPad = utils::padto(K, _GemmCore_T::KTILE);
    int NPad = utils::padto(N, _GemmCore_T::NTILE);
    auto ptr = new StorageWeight(_GemmCore_T::TYPE, S4_T);
    ptr->resize(NPad, KPad, K, is_sym);
    return ptr;
  }

  // Quantizes a row x col fp32 block by forwarding to QuantizeSignIntRowBlock
  // with the S4_T encoding.
  virtual void quantRowBlock(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst,
                             float* scales, int8_t* zero_points, int blocksize) {
    kernel::wrapper::QuantizeSignIntRowBlock::forward<ISA_T, S4_T>(srcptr, dstptr, row, col, ld_src, ld_dst, scales,
                                                                   zero_points, blocksize);
  }

  // Packs an already-quantized int8 weight into `ptr`.
  //
  // Steps: copy N scales (and N zero points when given), reorder B into a
  // temporary padded buffer, compress that buffer to 4 bits, then unpack the
  // stored weight to fp32 and reduce it into the storage's reduction buffer.
  // Does nothing when `ptr` is not a StorageWeight.
  //
  // @param N, K         weight dimensions.
  // @param B, ldb       int8 source weight and its leading dimension.
  // @param scales       N fp32 scales.
  // @param zero_points  N int8 zero points, or nullptr.
  // @param ptr          destination storage.
  virtual void packQWeight(const int N, const int K, const int8_t* B, const int ldb, const float* scales,
                           const int8_t* zero_points, PackedWeight* ptr) override {
    auto stor = dynamic_cast<StorageWeight*>(ptr);
    if (stor) {
      std::memcpy(stor->mSPtr, scales, N * sizeof(scales[0]));
      if (zero_points != nullptr) {
        std::memcpy(stor->mZPtr, zero_points, N * sizeof(zero_points[0]));
      }
      utils::avector<int8_t> reorded(stor->mKPad * stor->mNPad);
      WeightS8ScaleFp32<_GemmCore_T, ISA_T>::reorderWeight(N, K, B, ldb, reorded.data());
      compressWeight(stor->mNPad, stor->mKPad, reorded.data(), stor->mNPad, stor->mWPtr);
      utils::avector<float> deq(K * N);
      WeightS8ScaleFp32<_GemmCore_T, ISA_T>::unpackWeight(N, K, stor, deq.data(), N);
      // deq holds exactly K * N floats and was unpacked with N as the final
      // argument (NOTE(review): taken to be the output leading dimension), so
      // it must be reduced with stride N. Striding by the caller's ldb would
      // read past the end of deq whenever ldb > N.
      Parent::reduceWeight(N, K, deq.data(), N, stor->mRPtr);
    }
  }

  // Compresses a K x N int8 buffer into 4-bit pairs at dstptr, splitting the
  // work across OpenMP threads in KTILE x NTILE aligned tiles.
  // @param ldb  leading dimension of B; also used (halved) to address dstptr.
  void compressWeight(const int N, const int K, const int8_t* B, const int ldb, utils::bit4x2* dstptr) {
    utils::parallel::Parallel2DRowMajor _para;
    utils::CpuBase cb;
    _para.update(K, N, _GemmCore_T::KTILE, _GemmCore_T::NTILE, cb.mNumThreads);
    omp_set_num_threads(cb.mNumThreads);
#pragma omp parallel
    {
      int tidx = omp_get_thread_num();
      int colidx, rowidx, rowsize, colsize;
      _para.getIndex(tidx, &rowidx, &colidx, &rowsize, &colsize);
      if (rowsize > 0 && colsize > 0) {
        int rowremain = utils::remainsize(rowidx, K,
                                          rowsize);  // rowremain: src valid size. rowsize: padded size
        int colremain = utils::remainsize(colidx, N, colsize);
        // Two 4-bit values share one destination element, hence the "/ 2".
        auto ret = doCompress(B + rowidx * ldb + colidx, dstptr + rowidx * ldb / 2 + colidx / 2, rowremain, colremain,
                              ldb, ldb);
        assert(ret == JblasSuccess);
      }
    }
  }

  // Decompresses a k_size x n_size region of the packed 4-bit weight to int8.
  //
  // Walks N in NTILE steps, writing each tile to *dstptr + i * k_size, and
  // sets *dststep to k_size.
  // @return JblasSuccess, or JblasInvalidParam when _param.packedW is not a
  //         StorageWeight.
  virtual inline JBLAS_CODE getWeight(int8_t** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset,
                                      const Param& _param) override {
    auto wptr = dynamic_cast<const StorageWeight*>(_param.packedW);
    if (wptr) {
      auto KPad = wptr->mKPad;
      auto bptr = wptr->mWPtr + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2;
      for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) {
        kernel::wrapper::DecompressKBlockS4S8::forward<ISA_T, S4_T>(
            (utils::int4x2*)(bptr + i * KPad / 2), *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW,
            _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW,
            _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW);
      }
      *dststep = k_size;
      return JblasSuccess;
    }
    return JblasInvalidParam;
  }

  // Decompresses a k_size x n_size region of the packed 4-bit weight to fp32,
  // applying the stored scales and, when present, zero points.
  //
  // Uses DecompressPerNS4FP when PACK_ROW == 1 and DecompressPerNS4FPPackRow
  // otherwise. Sets *dststep to k_size.
  // @return JblasSuccess, or JblasInvalidParam when _param.packedW is not a
  //         StorageWeight.
  virtual inline JBLAS_CODE getWeight(float** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset,
                                      const Param& _param) override {
    auto wptr = dynamic_cast<const StorageWeight*>(_param.packedW);
    // TODO unpack vnni format to fp32
    if (wptr) {
      auto NPad = wptr->mNPad;
      auto KPad = wptr->mKPad;
      auto bptr = wptr->mWPtr + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2;
      for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) {
        if constexpr (_GemmCore_T::PACK_ROW == 1) {
          kernel::wrapper::DecompressPerNS4FP<float>::forward<ISA_T, float, S4_T>(
              (utils::int4x2*)(bptr + i * KPad / 2), *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW,
              _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW,
              _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW, wptr->mSPtr + n_offset + i,
              wptr->mZPtr != nullptr ? wptr->mZPtr + n_offset + i : nullptr, k_offset, wptr->mBlockSize, NPad);
        } else {
          kernel::wrapper::DecompressPerNS4FPPackRow<float>::forward<ISA_T, float, S4_T>(
              (utils::int4x2*)(bptr + i * KPad / 2), *dstptr + i * k_size, k_size, _GemmCore_T::NTILE,
              _GemmCore_T::NTILE, _GemmCore_T::NTILE, wptr->mSPtr + n_offset + i,
              wptr->mZPtr != nullptr ? wptr->mZPtr + n_offset + i : nullptr, k_offset, wptr->mBlockSize, NPad,
              _GemmCore_T::PACK_ROW);
        }
      }
      *dststep = k_size;
      return JblasSuccess;
    }
    return JblasInvalidParam;
  }

 protected:
  // Compresses one row x col int8 tile to 4-bit pairs via CompressS8S4.
  virtual JBLAS_CODE doCompress(const int8_t* srcptr, void* dstptr, int row, int col, int ld_src, int ld_dst) {
    return kernel::wrapper::CompressS8S4<_GemmCore_T::NTILE>::template forward<ISA_T>(
        srcptr, reinterpret_cast<utils::int4x2*>(dstptr), row, col, ld_src,
        ld_dst);  // ld_dst here not stride
  }
};

// Convenience alias: WeightS4ScaleFp32PerChannelN with the 4-bit encoding fixed
// to S4_CLIP, leaving only the GEMM core and ISA as template parameters.
template <class _GemmCore_T, JBLAS_ISA ISA_T>
using WeightS4ClipScaleFp32PerN = WeightS4ScaleFp32PerChannelN<_GemmCore_T, ISA_T, S4_CLIP>;

class StorageWeightS4ScaleBf16 : public prologue::weight_comp::PackedWeightKBlock,
public StorageWeight4Bit,
public StorageSimpleCorrection<utils::bf16, int8_t> {
Expand Down Expand Up @@ -1233,6 +1394,11 @@ class PackedWeightParser {
ptr->deserializeBuffer(rptr, memalloc);
return ptr;
}
case WeightCompType::WeightS4ClipScaleFp32PerChannelN: {
auto ptr = new StorageWeightS4ScaleFp32PerChannelN(jblas::gemm::GemmCoreType::Undef);
ptr->deserializeBuffer(rptr, memalloc);
return ptr;
}
default:
return nullptr;
}
Expand Down
115 changes: 104 additions & 11 deletions intel_extension_for_transformers/llm/library/jblas/jblas/kernel_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ static inline JBLAS_CODE padding_trans_interleave(const T_SRC* src, T_DST* dst,

template <typename SRC_DT, typename DST_DT>
static inline JBLAS_CODE dt_cvt_2D_write_back(const void* raw_srcptr, void* raw_dstptr, int row, int col, int srcstride,
int dststride, bool zeropadding) {
int dststride, bool zeropadding) {
for (int i = 0; i < row; i++) {
int j = 0;
for (; j < col; j++) {
Expand Down Expand Up @@ -276,8 +276,8 @@ inline JBLAS_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* dstptr,
scale1 = sptr[s1_idx];
}
if (zero_points != nullptr) {
dst0 = (float(get_s8<S4_T>(tmp.x)) - float((zero_points + kpos * NPad)[i * ld_dst + j + 0])) * scale0;
dst1 = (float(get_s8<S4_T>(tmp.y)) - float((zero_points + kpos * NPad)[i * ld_dst + j + 1])) * scale1;
dst0 = (float(get_s8<S4_T>(tmp.x)) - float((zero_points + kpos * NPad)[j + 0])) * scale0;
dst1 = (float(get_s8<S4_T>(tmp.y)) - float((zero_points + kpos * NPad)[j + 1])) * scale1;
} else {
dst0 = float(get_s8<S4_T>(tmp.x)) * scale0;
dst1 = float(get_s8<S4_T>(tmp.y)) * scale1;
Expand All @@ -297,6 +297,50 @@ inline JBLAS_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* dstptr,
return JblasSuccess;
}

// Reference kernel: expands packed signed 4-bit pairs into _DST_T values using
// one scale (and optional zero point) per column.
//
// Each source element carries two 4-bit values (.x and .y) that map to output
// columns c and c + 1. Every value is converted with get_s8<S4_T>, has the
// column's zero point subtracted when zero_points is non-null, and is then
// multiplied by the column's scale.
//
// @param srcptr       packed 4-bit input; row r starts at r * ld_src / 2.
// @param dstptr       output; row r starts at r * ld_dst.
// @param row, col     region size; columns are consumed two at a time.
// @param scales       per-column scales (bf16 or a float-convertible type).
// @param zero_points  per-column int8 zero points, or nullptr.
// @param k_offset, kblock, NPad  not read by this kernel.
// @return JblasSuccess always.
template <JBLAS_SIGN_INT_TYPE S4_T, typename _DST_T, typename _S_T>
inline JBLAS_CODE decompress_pern_s4_fp(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst,
                                        _S_T* scales, int8_t* zero_points, int k_offset, int kblock, int NPad) {
  constexpr bool kDstIsBf16 = std::is_same<_DST_T, utils::bf16>::value;
  constexpr bool kScaleIsBf16 = std::is_same<_S_T, utils::bf16>::value;
  const bool has_zero_points = zero_points != nullptr;

  for (int r = 0; r < row; r++) {
    for (int c = 0; c < col; c += 2) {
      auto packed = srcptr[r * ld_src / 2 + c / 2];

      // For a bf16 destination both halves of the pair read scale index c / 2;
      // otherwise each half reads its own column's scale.
      const int first_scale_idx = kDstIsBf16 ? c / 2 : c;
      const int second_scale_idx = kDstIsBf16 ? c / 2 : c + 1;

      float first_scale, second_scale;
      if constexpr (kScaleIsBf16) {
        first_scale = scales[first_scale_idx].tofloat();
        second_scale = scales[second_scale_idx].tofloat();
      } else {
        first_scale = scales[first_scale_idx];
        second_scale = scales[second_scale_idx];
      }

      float first = float(get_s8<S4_T>(packed.x));
      float second = float(get_s8<S4_T>(packed.y));
      if (has_zero_points) {
        first -= float(zero_points[c + 0]);
        second -= float(zero_points[c + 1]);
      }
      first *= first_scale;
      second *= second_scale;

      _DST_T* out = dstptr + r * ld_dst + c;
      if constexpr (kDstIsBf16) {
        utils::bf16 first_bf16, second_bf16;
        first_bf16.fromfloat(first);
        second_bf16.fromfloat(second);
        out[0] = first_bf16;
        out[1] = second_bf16;
      } else {
        out[0] = first;
        out[1] = second;
      }
    }
  }
  return JblasSuccess;
}
template <JBLAS_SIGN_INT_TYPE S4_T>
inline JBLAS_CODE decompress_kblock_s4_fp_packrow(utils::int4x2* srcptr, float* dstptr, int row, int col, int ld_src,
int ld_dst, float* scales, int8_t* zero_points, int k_offset,
Expand All @@ -312,9 +356,9 @@ inline JBLAS_CODE decompress_kblock_s4_fp_packrow(utils::int4x2* srcptr, float*
dstptr[i * ld_dst + j + 1] = float(get_s8<S4_T>(tmp.y)) * sptr[j + 1];
} else {
dstptr[i * ld_dst + j + 0] =
(float(get_s8<S4_T>(tmp.x)) - float((zero_points + kpos * NPad)[i * ld_dst + j + 0])) * sptr[j + 0];
(float(get_s8<S4_T>(tmp.x)) - float((zero_points + kpos * NPad)[j + 0])) * sptr[j + 0];
dstptr[i * ld_dst + j + 1] =
(float(get_s8<S4_T>(tmp.y)) - float((zero_points + kpos * NPad)[i * ld_dst + j + 0])) * sptr[j + 1];
(float(get_s8<S4_T>(tmp.y)) - float((zero_points + kpos * NPad)[j + 1])) * sptr[j + 1];
}
}
} else {
Expand All @@ -324,13 +368,51 @@ inline JBLAS_CODE decompress_kblock_s4_fp_packrow(utils::int4x2* srcptr, float*
auto sptr = scales + kpos * NPad + j;
auto tmp = srcptr[(i * ld_src + j * packrow + k) / 2];
if (zero_points == nullptr) {
dstptr[i * ld_dst + j + 0] = float(get_s8<S4_T>(tmp.x)) * sptr[0];
dstptr[i * ld_dst + j + 1] = float(get_s8<S4_T>(tmp.y)) * sptr[1];
dstptr[i * ld_dst + j * packrow + k + 0] = float(get_s8<S4_T>(tmp.x)) * sptr[0];
dstptr[i * ld_dst + j * packrow + k + 1] = float(get_s8<S4_T>(tmp.y)) * sptr[0];
} else {
dstptr[i * ld_dst + j * packrow + k + 0] =
(float(get_s8<S4_T>(tmp.x)) - float((zero_points + kpos * NPad)[j + 0])) * sptr[0];
dstptr[i * ld_dst + j * packrow + k + 1] =
(float(get_s8<S4_T>(tmp.y)) - float((zero_points + kpos * NPad)[j + 0])) * sptr[0];
}
}
}
}
}
return JblasSuccess;
}

template <JBLAS_SIGN_INT_TYPE S4_T>
inline JBLAS_CODE decompress_pern_s4_fp_packrow(utils::int4x2* srcptr, float* dstptr, int row, int col, int ld_src,
int ld_dst, float* scales, int8_t* zero_points, int k_offset,
int kblock, int NPad, int packrow) {
for (int i = 0; i < row; i += packrow) {
if (packrow == 1) {
auto sptr = scales;
for (int j = 0; j < col; j += 2) {
auto tmp = srcptr[i * ld_src / 2 + j / 2];
if (zero_points == nullptr) {
dstptr[i * ld_dst + j + 0] = float(get_s8<S4_T>(tmp.x)) * sptr[j + 0];
dstptr[i * ld_dst + j + 1] = float(get_s8<S4_T>(tmp.y)) * sptr[j + 1];
} else {
dstptr[i * ld_dst + j + 0] = (float(get_s8<S4_T>(tmp.x)) - float((zero_points)[j + 0])) * sptr[j + 0];
dstptr[i * ld_dst + j + 1] = (float(get_s8<S4_T>(tmp.y)) - float((zero_points)[j + 1])) * sptr[j + 1];
}
}
} else {
for (int j = 0; j < col; j++) {
auto sptr = scales + j;
for (int k = 0; k < packrow; k += 2) {
auto tmp = srcptr[(i * ld_src + j * packrow + k) / 2];
if (zero_points == nullptr) {
dstptr[i * ld_dst + j * packrow + k + 0] = float(get_s8<S4_T>(tmp.x)) * sptr[0];
dstptr[i * ld_dst + j * packrow + k + 1] = float(get_s8<S4_T>(tmp.y)) * sptr[0];
} else {
dstptr[i * ld_dst + j + 0] =
(float(get_s8<S4_T>(tmp.x)) - float((zero_points + kpos * NPad)[i * ld_dst + j + 0])) * sptr[0];
dstptr[i * ld_dst + j + 1] =
(float(get_s8<S4_T>(tmp.y)) - float((zero_points + kpos * NPad)[i * ld_dst + j + 0])) * sptr[1];
dstptr[i * ld_dst + j * packrow + k + 0] =
(float(get_s8<S4_T>(tmp.x)) - float((zero_points)[j + 0])) * sptr[0];
dstptr[i * ld_dst + j * packrow + k + 1] =
(float(get_s8<S4_T>(tmp.y)) - float((zero_points)[j + 0])) * sptr[0];
}
}
}
Expand Down Expand Up @@ -996,6 +1078,17 @@ static inline JBLAS_CODE quant_s8_row_reduce_sum(const int8_t* srcptr, int ldsrc
return JblasSuccess;
}

// Reference kernel: column-wise sum of a row x col matrix.
//
// Zeroes reduce[0 .. col) and then adds every source row into it, so that
// reduce[c] ends up as the sum over r of srcptr[r * ldsrc + c]. Rows are
// accumulated in ascending order.
//
// @param srcptr  source matrix.
// @param ldsrc   element stride between consecutive source rows.
// @param row     number of rows to accumulate.
// @param col     number of columns, and number of outputs written.
// @param reduce  output buffer with at least col elements; overwritten.
// @return JblasSuccess always.
template <typename _RT>
static inline JBLAS_CODE row_reduce_sum(const _RT* srcptr, int ldsrc, int row, int col, _RT* reduce) {
  std::memset(reduce, 0, sizeof(reduce[0]) * col);
  for (int r = 0; r < row; r++) {
    const _RT* line = srcptr + r * ldsrc;
    for (int c = 0; c < col; c++) {
      reduce[c] += line[c];
    }
  }
  return JblasSuccess;
}

static inline JBLAS_CODE remove_zeropoint_bias(float* accptr, int ldacc, int row, int col, uint8_t* zps, float* scales,
int lds, const float* reduce) {
for (int i = 0; i < row; i++) {
Expand Down
Loading

0 comments on commit 4e164a8

Please sign in to comment.