Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

[LLM Runtime] refactor itrex backend based on the latest Jblas #769

Merged
merged 94 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
94 commits
Select commit Hold shift + click to select a range
f7d03d8
qbits woq adaptation based on the latest jblas done.
zhewang1-intc Nov 22, 2023
5238aa4
fix python side op names
zhewang1-intc Nov 24, 2023
b4669ea
add jblas normal gemm in qbits
zhewang1-intc Nov 24, 2023
24fd3c3
better log & rm .clang-format
zhewang1-intc Nov 27, 2023
303a978
fix matmul bf16 ut
zhewang1-intc Nov 27, 2023
d854de3
update jblas
luoyu-intel Nov 27, 2023
27030ca
update jblas
luoyu-intel Nov 27, 2023
3415b38
add jblas new files
luoyu-intel Nov 27, 2023
c45e5d3
pass quant compile
luoyu-intel Nov 27, 2023
44c3a73
bypass mha, pass compilation.
luoyu-intel Nov 27, 2023
9f97da4
update jblas
luoyu-intel Nov 27, 2023
ba01508
update jblas
luoyu-intel Nov 28, 2023
049962b
update all integer weight to WeightNInteger
luoyu-intel Nov 28, 2023
8402853
fix quant size bug
luoyu-intel Nov 28, 2023
b02977e
update jblas
luoyu-intel Nov 28, 2023
fe23c12
update api
luoyu-intel Nov 28, 2023
2eca6ad
disable print flag
luoyu-intel Nov 28, 2023
2161622
add qkv
luoyu-intel Nov 28, 2023
5694f98
update jblas: merge GemmBaseRun and GemmKBlockRun
luoyu-intel Nov 29, 2023
e84efcc
change to 64-byte alignment
luoyu-intel Nov 29, 2023
90765ef
update jblas
luoyu-intel Nov 29, 2023
712d71a
change the object and tensor padding sizes
luoyu-intel Nov 29, 2023
2af7974
update jblas
luoyu-intel Nov 30, 2023
1e1a6b9
silu ffn support
zhewang1-intc Nov 30, 2023
e11e66f
init mha_dense
DDEle Nov 30, 2023
ffd29f1
fix int8 bug
luoyu-intel Nov 30, 2023
25a5050
fix mha
DDEle Dec 1, 2023
bb0148d
avoid set threads with CpuBase
DDEle Dec 1, 2023
240e883
update int kblock launcher
luoyu-intel Dec 4, 2023
f2cc214
fix fp16-bf16 mha
DDEle Dec 4, 2023
77f583e
fix int8 mha
DDEle Dec 4, 2023
575457f
mha clean code
DDEle Dec 4, 2023
5768742
update jblas
luoyu-intel Dec 4, 2023
1094666
modify kernel dispatch
luoyu-intel Dec 4, 2023
1512edc
mha uint64_t _core_id
DDEle Dec 4, 2023
754130c
dispatch amx_int8 and avx512_vnni
luoyu-intel Dec 4, 2023
9fa8e3d
fix scheduler
luoyu-intel Dec 4, 2023
f026e97
add ffn support for dispatch
luoyu-intel Dec 4, 2023
817ec91
update ip and ip_add
luoyu-intel Dec 4, 2023
d9677f4
update qbits
zhewang1-intc Dec 5, 2023
e11cefc
update jblas
luoyu-intel Dec 5, 2023
098e648
use 48x16 for bf16
luoyu-intel Dec 5, 2023
7d6a92a
fix api for MHA
luoyu-intel Dec 5, 2023
b379baa
speed bf16 next token
luoyu-intel Dec 5, 2023
8340c3a
update main pybind
zhenwei-intel Dec 5, 2023
8f7d50e
add ffn support
luoyu-intel Dec 5, 2023
62b4a46
update ip
luoyu-intel Dec 5, 2023
e6b5d96
update jblas
luoyu-intel Dec 5, 2023
c7eecb5
update ffn with epilogue template
zhewang1-intc Dec 5, 2023
3c530dd
add perchannel conversion
luoyu-intel Dec 5, 2023
1198053
update jblas
luoyu-intel Dec 6, 2023
e009cab
fall through asym to other compute types
luoyu-intel Dec 6, 2023
51eee67
support nf4 and fp4 weight
luoyu-intel Dec 6, 2023
2c4dce8
prevent fp4 and nf4 from comp=int8
luoyu-intel Dec 6, 2023
023e5a6
fix code
luoyu-intel Dec 6, 2023
6fb644f
add pack quantized weight function
luoyu-intel Dec 6, 2023
1a52022
ffn support multi-epilogue
yuchengliu1 Dec 6, 2023
5a45ce6
rebase main
luoyu-intel Dec 7, 2023
ceeae9a
add graph/core into cpplint check
airMeng Dec 7, 2023
aa8e997
fixed the alignment
luoyu-intel Dec 7, 2023
4aed4a4
clean common codes
luoyu-intel Dec 7, 2023
312a29c
add copyright
luoyu-intel Dec 7, 2023
961f3b4
fix mha ut & cpplint
DDEle Dec 7, 2023
2e17241
fix int type error
luoyu-intel Dec 7, 2023
47891f7
add constexpr
luoyu-intel Dec 7, 2023
890eaee
nolint false positive iwyu
DDEle Dec 7, 2023
281e0bd
fix type error of F4
luoyu-intel Dec 7, 2023
55f92a2
update jblas
luoyu-intel Dec 7, 2023
4f8d5f0
fix bug
luoyu-intel Dec 7, 2023
ab4894c
add blksize check before dispatch
luoyu-intel Dec 7, 2023
1dcea99
update jblas
luoyu-intel Dec 8, 2023
05e0633
update copyright
luoyu-intel Dec 8, 2023
f548e30
update jblas
luoyu-intel Dec 8, 2023
17e3a7e
mha ut improvement
DDEle Dec 8, 2023
02cd2b8
disable Wno-narrowing in qbits and fix all warning
zhewang1-intc Dec 8, 2023
45f2a01
add qbits cpplint check
zhewang1-intc Dec 8, 2023
6bb9101
update jblas and support e5m2/e4m3 wei + e8m0 scale woq feature in qbits
zhewang1-intc Dec 11, 2023
8054d89
update jblas
luoyu-intel Dec 11, 2023
4ea5cfb
update pybind format
zhenwei-intel Dec 12, 2023
65f84c7
fix clang-format.
zhewang1-intc Dec 12, 2023
42de32e
fix some cpplints
zhewang1-intc Dec 12, 2023
51aadd3
add fp8 weight and fp8 scale
luoyu-intel Dec 12, 2023
5d6959f
fix cpplint in conv and add some nolint
zhewang1-intc Dec 12, 2023
1b92bbc
more cpplints in conv.cpp
zhewang1-intc Dec 12, 2023
176b9b3
fp8 weight only valid for fp8 scale
luoyu-intel Dec 12, 2023
b59eb81
bug fix
luoyu-intel Dec 12, 2023
9108cfb
add gptq shuffle support
luoyu-intel Dec 12, 2023
bdef9df
fix f8 quant mantissa bits
zhewang1-intc Dec 12, 2023
9af4c5a
fix ffn template bug
luoyu-intel Dec 12, 2023
1520710
fix ip_add fusion bug
luoyu-intel Dec 12, 2023
1570c3b
change p=32 to p=16
luoyu-intel Dec 12, 2023
97877fc
fix shuffle buf
luoyu-intel Dec 12, 2023
192f979
determine whether symmetric from input zero_points tensor
airMeng Dec 12, 2023
7263631
fix ebits overflow when using uint8 to store
zhewang1-intc Dec 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/workflows/script/formatScan/cpplint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@ log_path=${log_dir}/cpplint.log
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/llm/runtime/deprecated/compile 2>&1 | tee ${log_path}
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/llm/runtime/deprecated/executor 2>&1 | tee -a ${log_path}
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/llm/runtime/deprecated/test 2>&1 | tee -a ${log_path}
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/llm/runtime/graph/application 2>&1 | tee -a ${log_path}
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/llm/runtime/graph 2>&1 | tee -a ${log_path}
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/library/kernels 2>&1 | tee -a ${log_path}
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/llm/runtime/graph/models 2>&1 | tee -a ${log_path}
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/llm/runtime/graph/vectors 2>&1 | tee -a ${log_path}
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/operator/csrc 2>&1 | tee -a ${log_path}
if [[ ! -f ${log_path} ]] || [[ $(grep -c "Total errors found:" ${log_path}) != 0 ]]; then
exit 1
fi
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Language: Cpp
airMeng marked this conversation as resolved.
Show resolved Hide resolved
BasedOnStyle: Google
DerivePointerAlignment: false
ColumnLimit: 120
SpaceBeforeParens: ControlStatements
SpaceBeforeRangeBasedForLoopColon: true
SortIncludes: false
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

#include <cstddef>
#include <type_traits>

#include "xbyak/xbyak.h"
#include "xbyak/xbyak_util.h"

Expand Down Expand Up @@ -50,6 +49,21 @@ class JitBase : protected Xbyak::CodeGenerator {
#endif
}

void padto_le(const Xbyak::Reg64& _src, int padding) {
// _src=_src/padding*padding
if (padding == 1) {
return;
}
for (int i = 1; i < 16; i++) {
if ((1 << i) == padding) {
shr(_src, i);
shl(_src, i);
return;
}
}
assert(0);
}

void generate_Nbitsmask(const Xbyak::Opmask& _msk, const Xbyak::Reg64& _pos, const Xbyak::Address& _total,
const Xbyak::Reg64& _tmp, const Xbyak::Reg64& _tmp1, int N) {
inLocalLabel();
Expand All @@ -59,9 +73,9 @@ class JitBase : protected Xbyak::CodeGenerator {
jb(".maskflag");
cmp(_tmp, 0);
jl(".zeroflag");
uint64_t allmask = ((uint64_t)1 << N) - 1;
uint64_t allmask = (static_cast<uint64_t>(1) << N) - 1;
if (N == 64) {
allmask = (uint64_t)-1;
allmask = static_cast<uint64_t>(-1);
}
mov(_tmp, allmask);
kmovq(_msk, _tmp);
Expand All @@ -87,13 +101,16 @@ class JitBase : protected Xbyak::CodeGenerator {
class JitAvx : protected JitBase {
protected:
static int constexpr VBits = 256;
static int constexpr VecBytes = VBits / 8;
static int constexpr RegCount = 16;
typedef Xbyak::Ymm vreg_t;
};

class JitAvx2 : protected JitAvx {
protected:
static int constexpr VBits = 256;
typedef Xbyak::Ymm vreg_t;
void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxor(x1, x2, op); }

void loadbf16_f32(const Xbyak::Ymm& dst, const Xbyak::Address& addr) {
vpmovzxwd(dst, addr);
Expand All @@ -104,8 +121,12 @@ class JitAvx2 : protected JitAvx {
class JitAvx512f : protected JitAvx2 {
protected:
static int constexpr VBits = 512;
static int constexpr VecBytes = VBits / 8;
static int constexpr RegCount = 32;
typedef Xbyak::Zmm vreg_t;

void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxorq(x1, x2, op); }

void interleave_2rows_4regs(Xbyak::Zmm* src_2regs, Xbyak::Zmm* tmp_2reg) {
vpunpcklwd(tmp_2reg[0], src_2regs[0], src_2regs[1]);
vpunpckhwd(tmp_2reg[1], src_2regs[0], src_2regs[1]);
Expand Down Expand Up @@ -192,18 +213,20 @@ class JitAvx512f : protected JitAvx2 {
}
};

class JitAvx512_bf16 : protected JitAvx512f {};

class JitAvx512_fp16 : protected JitAvx512f {};

class JitAvx512vnni : protected JitAvx512f {
protected:
void vpdpbusds_evex(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) {
void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) {
vpdpbusds(x1, x2, op, Xbyak::EvexEncoding);
}
};

class JitAvxvnni : protected JitAvx2 {
protected:
void vpdpbusds_vex(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) {
void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) {
vpdpbusds(x1, x2, op, Xbyak::VexEncoding);
}
};
Expand All @@ -216,6 +239,15 @@ class JitAmxtile : protected JitAvx512f {
uint16_t colb[16];
uint8_t rows[16];
};
static int constexpr TileCount = 8;

typedef long long (*configure_t)(void*);

static void generate_config(Xbyak::CodeGenerator* g) {
Xbyak::util::StackFrame st(g, 1, 0, 0);
auto& parambase = st.p[0];
g->ldtilecfg(g->ptr[parambase]);
}

static void configure_tiles(tileconfig_t& tc, int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum,
int CNum) {
Expand All @@ -224,19 +256,19 @@ class JitAmxtile : protected JitAvx512f {
// Configure C tiles
int t = 0;
for (; t < CNum; ++t) {
tc.rows[t] = uint8_t(TILE_M);
tc.colb[t] = uint16_t(TILE_N * 4);
tc.rows[t] = static_cast<uint8_t>(TILE_M);
tc.colb[t] = static_cast<uint16_t>(TILE_N * 4);
}
// Configure A tiles
for (; t < CNum + ANum; ++t) {
tc.rows[t] = uint8_t(TILE_M);
tc.colb[t] = uint16_t(TILE_K * elesize);
tc.rows[t] = static_cast<uint8_t>(TILE_M);
tc.colb[t] = static_cast<uint16_t>(TILE_K * elesize);
}
// Configure B tile. B effectively has 64 rows and 16 columns.
int kpack = 4 / elesize;
for (; t < CNum + ANum + BNum; ++t) {
tc.rows[t] = uint8_t(TILE_K / kpack);
tc.colb[t] = uint16_t(TILE_N * 4);
tc.rows[t] = static_cast<uint8_t>(TILE_K / kpack);
tc.colb[t] = static_cast<uint16_t>(TILE_N * 4);
}
}
};
Expand Down
115 changes: 69 additions & 46 deletions intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,59 +15,82 @@
#include <stdint.h>
enum JBLAS_CODE {
JblasSuccess = 0,
JblasInvalidParam = -1,
JblasInvalidISA = -2,
JblasRuntimeError = -3,
JblasNotSupport = -4,
JblasInvalidParam = 1,
JblasInvalidISA = 2,
JblasRuntimeError = 4,
JblasNotSupport = 8,
};
enum JBLAS_ISA {
JblasNoSIMD = 10,
JblasAVX = 11,
JblasAVX2 = 12,
JblasAVX_VNNI = 13,
JblasAVX512F = 14,
JblasAVX512_VNNI = 15,
JblasAMX_BF16 = 16,
JblasAMX_INT8 = 17,
JblasAVX512_FP16 = 18,
enum JBLAS_ISA : uint8_t {
JblasNoSIMD = 0,
JblasAVX,
JblasAVX2,
JblasAVX_VNNI,
JblasAVX512F,
JblasAVX512_VNNI,
JblasAMX_BF16,
JblasAMX_INT8,
JblasAVX512_FP16,
JblasAVX512_BF16,
};
enum JBLAS_DTYPE {
JblasF64 = 59,
JblasF32 = 60,
JblasBF16 = 61,
JblasS8 = 63,
JblasU8 = 64,
JblasF32F8 = 65,
};
enum JBLAS_FP8_ENCODING {
JblasFp8_e4m3 = 80,
JblasFp8_e5m2 = 81,
JblasFp8_e3m4 = 82,
enum class JBLAS_DTYPE : uint32_t {
EleBitsMask = 0xff,
EleBitsShift = 0,
EleBitsUndef = 0,
EleBits4 = 4,
EleBits8 = 8,
EleBits16 = 16,
EleBits32 = 32,
EleBits64 = 64,
TypeMask = 0xff00,
TypeShift = 8,
TypeFloat = 0 << TypeShift,
TypeInt = 1 << TypeShift,
SubTypeMask = 0xff0000,
SubTypeShift = 16,
SubType0 = 0 << SubTypeShift,
SubType1 = 1 << SubTypeShift,
SubType2 = 2 << SubTypeShift,
SubType3 = 3 << SubTypeShift,
F64 = EleBits64 | TypeFloat,
F32 = EleBits32 | TypeFloat,
F16 = EleBits16 | TypeFloat,
BF16 = EleBits16 | TypeFloat | SubType1,
F8_E4M3 = EleBits8 | TypeFloat,
F8_E5M2 = EleBits8 | TypeFloat | SubType1,
F8_E3M4 = EleBits8 | TypeFloat | SubType2,
F8_E8M0 = EleBits8 | TypeFloat | SubType3,
S8 = EleBits8 | TypeInt,
U8 = EleBits8 | TypeInt | SubType1,
S4_CLIP = EleBits4 | TypeInt,
S4_FULLRANGE = EleBits4 | TypeInt | SubType1,
F4_E2M1 = EleBits4 | TypeFloat,
F4_BNB = EleBits4 | TypeFloat | SubType1,
F4_NF4 = EleBits4 | TypeFloat | SubType2,
S32 = EleBits32 | TypeInt,
U32 = EleBits32 | TypeInt | SubType1,
};

enum JBLAS_LAYOUT { JblasRowMajor = 101, JblasColMajor = 102 };
enum JBLAS_TRANSPOSE {
JblasNoTrans = 111,
JblasTrans = 112,
JblasConjTrans = 113,
};
enum JBLAS_ELTWISEOP {
GELU,
SWISH,
TANH,
EXP,
LOW_PRECISION_EXP,
RELU,
LINEAR,
};
enum JBLAS_F4_TYPE {
F4_UNDEF,
FP4_BNB,
FP4_E2M1,
NF4,
};
enum JBLAS_SIGN_INT_TYPE {
S8,
S4_CLIP,
S4_FULLRANGE,
S4_UNDEF,
enum JBLAS_ELTWISEOP { GELU, SWISH, TANH, EXP, LOW_PRECISION_EXP, RELU, LINEAR };

enum class JBLAS_PROLOGUEB_IDS : uint32_t {
Undef = (uint32_t)-1,
Begin = 0,
NormalBegin = Begin,
WeightPack = NormalBegin,
NormalEnd,
KBlockBegin = NormalEnd,
WeightKBlockNInteger = KBlockBegin,
WeightKBlockNFloat,
WeightKBlockS8,
WeightKBlockS4,
WeightKBlockF4,
WeightKBlockF8,
KBlockEnd,
End,
};
Loading
Loading