diff --git a/src/layer/arm/gemm_arm.cpp b/src/layer/arm/gemm_arm.cpp
index 522ffd7d704..04f775f4479 100644
--- a/src/layer/arm/gemm_arm.cpp
+++ b/src/layer/arm/gemm_arm.cpp
@@ -47,7 +47,7 @@ Gemm_arm::Gemm_arm()
 #endif // __ARM_NEON
 
 #if NCNN_BF16
-    // support_bf16_storage = true;
+    support_bf16_storage = true;
 #endif
 
     nT = 0;
@@ -6037,6 +6037,15 @@ int Gemm_arm::create_pipeline_int8(const Option& opt)
     }
 #endif
 
+#if __ARM_NEON
+    if (constant_broadcast_type_C == 3 && opt.use_packing_layout && CT_data.h % 4 == 0)
+    {
+        Mat C2;
+        ncnn::convert_packing(CT_data, C2, 4, opt);
+        CT_data = C2;
+    }
+#endif
+
     if (opt.lightmode)
         C_data.release();
 }
diff --git a/src/layer/arm/gemm_int8_bf16s.h b/src/layer/arm/gemm_int8_bf16s.h
index b3d15675986..786b3310368 100644
--- a/src/layer/arm/gemm_int8_bf16s.h
+++ b/src/layer/arm/gemm_int8_bf16s.h
@@ -4914,11 +4914,11 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat&
                 }
                 if (broadcast_type_C == 3)
                 {
-                    float32x4_t _c2;
-                    float32x4_t _c3;
+                    uint16x8_t _c01;
+                    uint16x8_t _c23;
                     if (c_elempack == 1)
                     {
-                        uint16x8_t _c01 = uint16x8_t();
+                        _c01 = uint16x8_t();
                         _c01 = vsetq_lane_u16(pC[0], _c01, 0);
                         _c01 = vsetq_lane_u16(pC[c_hstep], _c01, 1);
                         _c01 = vsetq_lane_u16(pC[c_hstep * 2], _c01, 2);
@@ -4927,7 +4927,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat&
                         _c01 = vsetq_lane_u16(pC[c_hstep * 5], _c01, 5);
                         _c01 = vsetq_lane_u16(pC[c_hstep * 6], _c01, 6);
                         _c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7);
-                        uint16x8_t _c23 = uint16x8_t();
+                        _c23 = uint16x8_t();
                         _c23 = vsetq_lane_u16(pC[1], _c23, 0);
                         _c23 = vsetq_lane_u16(pC[c_hstep + 1], _c23, 1);
                         _c23 = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c23, 2);
@@ -4936,22 +4936,18 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat&
                         _c23 = vsetq_lane_u16(pC[c_hstep * 5 + 1], _c23, 5);
                         _c23 = vsetq_lane_u16(pC[c_hstep * 6 + 1], _c23, 6);
                         _c23 = vsetq_lane_u16(pC[c_hstep * 7 + 1], _c23, 7);
-                        _c0 = bfloat2float(vget_low_u16(_c01));
-                        _c1 = bfloat2float(vget_high_u16(_c01));
-                        _c2 = bfloat2float(vget_low_u16(_c23));
-                        _c3 = bfloat2float(vget_high_u16(_c23));
                         pC += 2;
                     }
                     else // if (c_elempack == 4)
                     {
-                        uint16x8_t _c01 = vld1q_u16(pC);
-                        uint16x8_t _c23 = vld1q_u16(pC + c_hstep * 4);
-                        _c0 = bfloat2float(vget_low_u16(_c01));
-                        _c1 = bfloat2float(vget_high_u16(_c01));
-                        _c2 = bfloat2float(vget_low_u16(_c23));
-                        _c3 = bfloat2float(vget_high_u16(_c23));
+                        _c01 = vld1q_u16(pC);
+                        _c23 = vld1q_u16(pC + c_hstep * 4);
                         pC += 8;
                     }
+                    _c0 = bfloat2float(vget_low_u16(_c01));
+                    _c1 = bfloat2float(vget_high_u16(_c01));
+                    float32x4_t _c2 = bfloat2float(vget_low_u16(_c23));
+                    float32x4_t _c3 = bfloat2float(vget_high_u16(_c23));
                     if (beta == 1.f)
                     {
                         _f0 = vaddq_f32(_f0, _c0);
@@ -8733,22 +8729,45 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
                         float32x4_t _cd = bfloat2float(vget_high_u16(_ccd));
                         float32x4_t _ce = bfloat2float(vget_low_u16(_cef));
                         float32x4_t _cf = bfloat2float(vget_high_u16(_cef));
-                        _f0 = vaddq_f32(_f0, _c0);
-                        _f1 = vaddq_f32(_f1, _c2);
-                        _f2 = vaddq_f32(_f2, _c4);
-                        _f3 = vaddq_f32(_f3, _c6);
-                        _f4 = vaddq_f32(_f4, _c8);
-                        _f5 = vaddq_f32(_f5, _ca);
-                        _f6 = vaddq_f32(_f6, _cc);
-                        _f7 = vaddq_f32(_f7, _ce);
-                        _f8 = vaddq_f32(_f8, _c1);
-                        _f9 = vaddq_f32(_f9, _c3);
-                        _fa = vaddq_f32(_fa, _c5);
-                        _fb = vaddq_f32(_fb, _c7);
-                        _fc = vaddq_f32(_fc, _c9);
-                        _fd = vaddq_f32(_fd, _cb);
-                        _fe = vaddq_f32(_fe, _cd);
-                        _ff = vaddq_f32(_ff, _cf);
+                        if (beta == 1.f)
+                        {
+                            _f0 = vaddq_f32(_f0, _c0);
+                            _f1 = vaddq_f32(_f1, _c2);
+                            _f2 = vaddq_f32(_f2, _c4);
+                            _f3 = vaddq_f32(_f3, _c6);
+                            _f4 = vaddq_f32(_f4, _c8);
+                            _f5 = vaddq_f32(_f5, _ca);
+                            _f6 = vaddq_f32(_f6, _cc);
+                            _f7 = vaddq_f32(_f7, _ce);
+                            _f8 = vaddq_f32(_f8, _c1);
+                            _f9 = vaddq_f32(_f9, _c3);
+                            _fa = vaddq_f32(_fa, _c5);
+                            _fb = vaddq_f32(_fb, _c7);
+                            _fc = vaddq_f32(_fc, _c9);
+                            _fd = vaddq_f32(_fd, _cb);
+                            _fe = vaddq_f32(_fe, _cd);
+                            _ff = vaddq_f32(_ff, _cf);
+                        }
+                        else
+                        {
+                            float32x4_t _beta = vdupq_n_f32(beta);
+                            _f0 = vmlaq_f32(_f0, _c0, _beta);
+                            _f1 = vmlaq_f32(_f1, _c2, _beta);
+                            _f2 = vmlaq_f32(_f2, _c4, _beta);
+                            _f3 = vmlaq_f32(_f3, _c6, _beta);
+                            _f4 = vmlaq_f32(_f4, _c8, _beta);
+                            _f5 = vmlaq_f32(_f5, _ca, _beta);
+                            _f6 = vmlaq_f32(_f6, _cc, _beta);
+                            _f7 = vmlaq_f32(_f7, _ce, _beta);
+                            _f8 = vmlaq_f32(_f8, _c1, _beta);
+                            _f9 = vmlaq_f32(_f9, _c3, _beta);
+                            _fa = vmlaq_f32(_fa, _c5, _beta);
+                            _fb = vmlaq_f32(_fb, _c7, _beta);
+                            _fc = vmlaq_f32(_fc, _c9, _beta);
+                            _fd = vmlaq_f32(_fd, _cb, _beta);
+                            _fe = vmlaq_f32(_fe, _cd, _beta);
+                            _ff = vmlaq_f32(_ff, _cf, _beta);
+                        }
                         pC += 8;
                     }
                     else // if (c_elempack == 4)
@@ -9017,14 +9036,44 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
                         uint16x4_t _cc6 = vld1_u16(pC + c_hstep * 6);
                         uint16x4_t _cc7 = vld1_u16(pC + c_hstep * 7);
                         transpose4x8_u16(_cc0, _cc1, _cc2, _cc3, _cc4, _cc5, _cc6, _cc7);
-                        _f0 = vaddq_f32(_f0, bfloat2float(_cc0));
-                        _f1 = vaddq_f32(_f1, bfloat2float(_cc2));
-                        _f2 = vaddq_f32(_f2, bfloat2float(_cc4));
-                        _f3 = vaddq_f32(_f3, bfloat2float(_cc6));
-                        _f4 = vaddq_f32(_f4, bfloat2float(_cc1));
-                        _f5 = vaddq_f32(_f5, bfloat2float(_cc3));
-                        _f6 = vaddq_f32(_f6, bfloat2float(_cc5));
-                        _f7 = vaddq_f32(_f7, bfloat2float(_cc7));
+                        _c0 = bfloat2float(_cc0);
+                        _c1 = bfloat2float(_cc2);
+                        float32x4_t _c2 = bfloat2float(_cc4);
+                        float32x4_t _c3 = bfloat2float(_cc6);
+                        if (beta == 1.f)
+                        {
+                            _f0 = vaddq_f32(_f0, _c0);
+                            _f1 = vaddq_f32(_f1, _c1);
+                            _f2 = vaddq_f32(_f2, _c2);
+                            _f3 = vaddq_f32(_f3, _c3);
+                        }
+                        else
+                        {
+                            float32x4_t _beta = vdupq_n_f32(beta);
+                            _f0 = vmlaq_f32(_f0, _c0, _beta);
+                            _f1 = vmlaq_f32(_f1, _c1, _beta);
+                            _f2 = vmlaq_f32(_f2, _c2, _beta);
+                            _f3 = vmlaq_f32(_f3, _c3, _beta);
+                        }
+                        _c0 = bfloat2float(_cc1);
+                        _c1 = bfloat2float(_cc3);
+                        _c2 = bfloat2float(_cc5);
+                        _c3 = bfloat2float(_cc7);
+                        if (beta == 1.f)
+                        {
+                            _f4 = vaddq_f32(_f4, _c0);
+                            _f5 = vaddq_f32(_f5, _c1);
+                            _f6 = vaddq_f32(_f6, _c2);
+                            _f7 = vaddq_f32(_f7, _c3);
+                        }
+                        else
+                        {
+                            float32x4_t _beta = vdupq_n_f32(beta);
+                            _f4 = vmlaq_f32(_f4, _c0, _beta);
+                            _f5 = vmlaq_f32(_f5, _c1, _beta);
+                            _f6 = vmlaq_f32(_f6, _c2, _beta);
+                            _f7 = vmlaq_f32(_f7, _c3, _beta);
+                        }
                         pC += 4;
                     }
                     else // if (c_elempack == 4)
@@ -11108,8 +11157,6 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
                     uint16x8_t _c = vld1q_u16(pC);
                     _c0 = bfloat2float(vget_low_u16(_c));
                     float32x4_t _c1 = bfloat2float(vget_high_u16(_c));
-                    _f0 = vaddq_f32(_f0, _c0);
-                    _f1 = vaddq_f32(_f1, _c1);
                     if (beta == 1.f)
                     {
                         _f0 = vaddq_f32(_f0, _c0);
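For reference, every beta branch introduced in these hunks follows the same scalar recipe, f = f + c * beta, taking the cheaper plain add when beta == 1 and a fused multiply-accumulate otherwise. A minimal standalone sketch of that pattern with NEON intrinsics (the accumulate_c helper name is hypothetical and not part of the patch, which inlines this per accumulator register):

#include <arm_neon.h>

// Accumulate bias vector _c into accumulator _f with optional beta scaling.
// Mirrors the pattern above: plain add when beta == 1, otherwise
// vmlaq_f32 computes _f + _c * _beta in one fused multiply-accumulate.
static inline float32x4_t accumulate_c(float32x4_t _f, float32x4_t _c, float beta)
{
    if (beta == 1.f)
        return vaddq_f32(_f, _c); // no multiply needed

    float32x4_t _beta = vdupq_n_f32(beta);
    return vmlaq_f32(_f, _c, _beta); // _f + _c * _beta
}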