Skip to content

Commit

Permalink
fix int8 bf16s
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Sep 26, 2024
1 parent d59410f commit 07f6755
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 41 deletions.
11 changes: 10 additions & 1 deletion src/layer/arm/gemm_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ Gemm_arm::Gemm_arm()
#endif // __ARM_NEON

#if NCNN_BF16
// support_bf16_storage = true;
support_bf16_storage = true;
#endif

nT = 0;
Expand Down Expand Up @@ -6037,6 +6037,15 @@ int Gemm_arm::create_pipeline_int8(const Option& opt)
}
#endif

#if __ARM_NEON
if (constant_broadcast_type_C == 3 && opt.use_packing_layout && CT_data.h % 4 == 0)
{
Mat C2;
ncnn::convert_packing(CT_data, C2, 4, opt);
CT_data = C2;
}
#endif

if (opt.lightmode)
C_data.release();
}
Expand Down
127 changes: 87 additions & 40 deletions src/layer/arm/gemm_int8_bf16s.h
Original file line number Diff line number Diff line change
Expand Up @@ -4914,11 +4914,11 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat&
}
if (broadcast_type_C == 3)
{
float32x4_t _c2;
float32x4_t _c3;
uint16x8_t _c01;
uint16x8_t _c23;
if (c_elempack == 1)
{
uint16x8_t _c01 = uint16x8_t();
_c01 = uint16x8_t();
_c01 = vsetq_lane_u16(pC[0], _c01, 0);
_c01 = vsetq_lane_u16(pC[c_hstep], _c01, 1);
_c01 = vsetq_lane_u16(pC[c_hstep * 2], _c01, 2);
Expand All @@ -4927,7 +4927,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat&
_c01 = vsetq_lane_u16(pC[c_hstep * 5], _c01, 5);
_c01 = vsetq_lane_u16(pC[c_hstep * 6], _c01, 6);
_c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7);
uint16x8_t _c23 = uint16x8_t();
_c23 = uint16x8_t();
_c23 = vsetq_lane_u16(pC[1], _c23, 0);
_c23 = vsetq_lane_u16(pC[c_hstep + 1], _c23, 1);
_c23 = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c23, 2);
Expand All @@ -4936,22 +4936,18 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat&
_c23 = vsetq_lane_u16(pC[c_hstep * 5 + 1], _c23, 5);
_c23 = vsetq_lane_u16(pC[c_hstep * 6 + 1], _c23, 6);
_c23 = vsetq_lane_u16(pC[c_hstep * 7 + 1], _c23, 7);
_c0 = bfloat2float(vget_low_u16(_c01));
_c1 = bfloat2float(vget_high_u16(_c01));
_c2 = bfloat2float(vget_low_u16(_c23));
_c3 = bfloat2float(vget_high_u16(_c23));
pC += 2;
}
else // if (c_elempack == 4)
{
uint16x8_t _c01 = vld1q_u16(pC);
uint16x8_t _c23 = vld1q_u16(pC + c_hstep * 4);
_c0 = bfloat2float(vget_low_u16(_c01));
_c1 = bfloat2float(vget_high_u16(_c01));
_c2 = bfloat2float(vget_low_u16(_c23));
_c3 = bfloat2float(vget_high_u16(_c23));
_c01 = vld1q_u16(pC);
_c23 = vld1q_u16(pC + c_hstep * 4);
pC += 8;
}
_c0 = bfloat2float(vget_low_u16(_c01));
_c1 = bfloat2float(vget_high_u16(_c01));
float32x4_t _c2 = bfloat2float(vget_low_u16(_c23));
float32x4_t _c3 = bfloat2float(vget_high_u16(_c23));
if (beta == 1.f)
{
_f0 = vaddq_f32(_f0, _c0);
Expand Down Expand Up @@ -8733,22 +8729,45 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
float32x4_t _cd = bfloat2float(vget_high_u16(_ccd));
float32x4_t _ce = bfloat2float(vget_low_u16(_cef));
float32x4_t _cf = bfloat2float(vget_high_u16(_cef));
_f0 = vaddq_f32(_f0, _c0);
_f1 = vaddq_f32(_f1, _c2);
_f2 = vaddq_f32(_f2, _c4);
_f3 = vaddq_f32(_f3, _c6);
_f4 = vaddq_f32(_f4, _c8);
_f5 = vaddq_f32(_f5, _ca);
_f6 = vaddq_f32(_f6, _cc);
_f7 = vaddq_f32(_f7, _ce);
_f8 = vaddq_f32(_f8, _c1);
_f9 = vaddq_f32(_f9, _c3);
_fa = vaddq_f32(_fa, _c5);
_fb = vaddq_f32(_fb, _c7);
_fc = vaddq_f32(_fc, _c9);
_fd = vaddq_f32(_fd, _cb);
_fe = vaddq_f32(_fe, _cd);
_ff = vaddq_f32(_ff, _cf);
if (beta == 1.f)
{
_f0 = vaddq_f32(_f0, _c0);
_f1 = vaddq_f32(_f1, _c2);
_f2 = vaddq_f32(_f2, _c4);
_f3 = vaddq_f32(_f3, _c6);
_f4 = vaddq_f32(_f4, _c8);
_f5 = vaddq_f32(_f5, _ca);
_f6 = vaddq_f32(_f6, _cc);
_f7 = vaddq_f32(_f7, _ce);
_f8 = vaddq_f32(_f8, _c1);
_f9 = vaddq_f32(_f9, _c3);
_fa = vaddq_f32(_fa, _c5);
_fb = vaddq_f32(_fb, _c7);
_fc = vaddq_f32(_fc, _c9);
_fd = vaddq_f32(_fd, _cb);
_fe = vaddq_f32(_fe, _cd);
_ff = vaddq_f32(_ff, _cf);
}
else
{
float32x4_t _beta = vdupq_n_f32(beta);
_f0 = vmlaq_f32(_f0, _c0, _beta);
_f1 = vmlaq_f32(_f1, _c2, _beta);
_f2 = vmlaq_f32(_f2, _c4, _beta);
_f3 = vmlaq_f32(_f3, _c6, _beta);
_f4 = vmlaq_f32(_f4, _c8, _beta);
_f5 = vmlaq_f32(_f5, _ca, _beta);
_f6 = vmlaq_f32(_f6, _cc, _beta);
_f7 = vmlaq_f32(_f7, _ce, _beta);
_f8 = vmlaq_f32(_f8, _c1, _beta);
_f9 = vmlaq_f32(_f9, _c3, _beta);
_fa = vmlaq_f32(_fa, _c5, _beta);
_fb = vmlaq_f32(_fb, _c7, _beta);
_fc = vmlaq_f32(_fc, _c9, _beta);
_fd = vmlaq_f32(_fd, _cb, _beta);
_fe = vmlaq_f32(_fe, _cd, _beta);
_ff = vmlaq_f32(_ff, _cf, _beta);
}
pC += 8;
}
else // if (c_elempack == 4)
Expand Down Expand Up @@ -9017,14 +9036,44 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
uint16x4_t _cc6 = vld1_u16(pC + c_hstep * 6);
uint16x4_t _cc7 = vld1_u16(pC + c_hstep * 7);
transpose4x8_u16(_cc0, _cc1, _cc2, _cc3, _cc4, _cc5, _cc6, _cc7);
_f0 = vaddq_f32(_f0, bfloat2float(_cc0));
_f1 = vaddq_f32(_f1, bfloat2float(_cc2));
_f2 = vaddq_f32(_f2, bfloat2float(_cc4));
_f3 = vaddq_f32(_f3, bfloat2float(_cc6));
_f4 = vaddq_f32(_f4, bfloat2float(_cc1));
_f5 = vaddq_f32(_f5, bfloat2float(_cc3));
_f6 = vaddq_f32(_f6, bfloat2float(_cc5));
_f7 = vaddq_f32(_f7, bfloat2float(_cc7));
_c0 = bfloat2float(_cc0);
_c1 = bfloat2float(_cc2);
float32x4_t _c2 = bfloat2float(_cc4);
float32x4_t _c3 = bfloat2float(_cc6);
if (beta == 1.f)
{
_f0 = vaddq_f32(_f0, _c0);
_f1 = vaddq_f32(_f1, _c1);
_f2 = vaddq_f32(_f2, _c2);
_f3 = vaddq_f32(_f3, _c3);
}
else
{
float32x4_t _beta = vdupq_n_f32(beta);
_f0 = vmlaq_f32(_f0, _c0, _beta);
_f1 = vmlaq_f32(_f1, _c1, _beta);
_f2 = vmlaq_f32(_f2, _c2, _beta);
_f3 = vmlaq_f32(_f3, _c3, _beta);
}
_c0 = bfloat2float(_cc1);
_c1 = bfloat2float(_cc3);
_c2 = bfloat2float(_cc5);
_c3 = bfloat2float(_cc7);
if (beta == 1.f)
{
_f4 = vaddq_f32(_f4, _c0);
_f5 = vaddq_f32(_f5, _c1);
_f6 = vaddq_f32(_f6, _c2);
_f7 = vaddq_f32(_f7, _c3);
}
else
{
float32x4_t _beta = vdupq_n_f32(beta);
_f4 = vmlaq_f32(_f4, _c0, _beta);
_f5 = vmlaq_f32(_f5, _c1, _beta);
_f6 = vmlaq_f32(_f6, _c2, _beta);
_f7 = vmlaq_f32(_f7, _c3, _beta);
}
pC += 4;
}
else // if (c_elempack == 4)
Expand Down Expand Up @@ -11108,8 +11157,6 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
uint16x8_t _c = vld1q_u16(pC);
_c0 = bfloat2float(vget_low_u16(_c));
float32x4_t _c1 = bfloat2float(vget_high_u16(_c));
_f0 = vaddq_f32(_f0, _c0);
_f1 = vaddq_f32(_f1, _c1);
if (beta == 1.f)
{
_f0 = vaddq_f32(_f0, _c0);
Expand Down

0 comments on commit 07f6755

Please sign in to comment.