diff --git a/src/layer/arm/gemm_int8.h b/src/layer/arm/gemm_int8.h index 09c94d5226f..020df8b9c84 100644 --- a/src/layer/arm/gemm_int8.h +++ b/src/layer/arm/gemm_int8.h @@ -5675,423 +5675,388 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& } } - if (out_elempack == 4) - { - int jj = 0; + int jj = 0; #if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); - int32x4_t _sum8 = vld1q_s32(pp + 32); - int32x4_t _sum9 = vld1q_s32(pp + 36); - int32x4_t _suma = vld1q_s32(pp + 40); - int32x4_t _sumb = vld1q_s32(pp + 44); - int32x4_t _sumc = vld1q_s32(pp + 48); - int32x4_t _sumd = vld1q_s32(pp + 52); - int32x4_t _sume = vld1q_s32(pp + 56); - int32x4_t _sumf = vld1q_s32(pp + 60); - -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale0); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale0); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_suma), _descale0); - float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sumb), _descale0); - float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); - float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); - float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); - float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); - float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); - float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); - float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); - float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + int32x4_t _sum8 = vld1q_s32(pp + 32); + int32x4_t _sum9 = vld1q_s32(pp + 36); + int32x4_t _suma = vld1q_s32(pp + 40); + int32x4_t _sumb = vld1q_s32(pp + 44); + int32x4_t _sumc = vld1q_s32(pp + 48); + int32x4_t _sumd = vld1q_s32(pp + 52); + int32x4_t _sume = vld1q_s32(pp + 56); + int32x4_t _sumf = vld1q_s32(pp + 60); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 
b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 #else - // from - // a0 b1 c2 d3 - // e4 f5 g6 h7 - // e0 f1 g2 h3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // g4 h5 e6 f7 - // g0 h1 e2 f3 - // c4 d5 a6 b7 - // a3 b2 c1 d0 - // e7 f6 g5 h4 - // e3 f2 g1 h0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // g7 h6 e5 f4 - // g3 h2 e1 f0 - // c7 d6 a5 b4 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - { - _sum8 = vrev64q_s32(_sum8); - _sum9 = vrev64q_s32(_sum9); - _suma = vrev64q_s32(_suma); - _sumb = vrev64q_s32(_sumb); - _sumc = vrev64q_s32(_sumc); - _sumd = vrev64q_s32(_sumd); - _sume = vrev64q_s32(_sume); - _sumf = vrev64q_s32(_sumf); - _sum8 = vextq_s32(_sum8, _sum8, 2); - _sum9 = vextq_s32(_sum9, _sum9, 2); - _suma = vextq_s32(_suma, _suma, 2); - _sumb = vextq_s32(_sumb, _sumb, 2); - _sumc = vextq_s32(_sumc, _sumc, 2); - _sumd = vextq_s32(_sumd, _sumd, 2); - _sume = vextq_s32(_sume, _sume, 2); - _sumf = vextq_s32(_sumf, _sumf, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); - int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); - int32x4x2_t _t2 = vzipq_s32(_sum2, _sume); - int32x4x2_t _t3 = vzipq_s32(_sum6, _suma); - int32x4x2_t _t4 = vzipq_s32(_sum3, _sumf); - int32x4x2_t _t5 = vzipq_s32(_sum7, _sumb); - int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); - int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); - _sum8 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum9 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _suma = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sumb = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sumc = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); - _sumd = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); - _sume = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); - _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); - _sum9 = vrev64q_s32(_sum9); - _sumb = vrev64q_s32(_sumb); - _sumd = vrev64q_s32(_sumd); - _sumf = vrev64q_s32(_sumf); - } - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale0); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale0); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale0); - float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale0); - 
float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale1); - float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale1); - float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_suma), _descale1); - float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sumb), _descale1); - float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); - float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); - float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); - float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); + // from + // a0 b1 c2 d3 + // e4 f5 g6 h7 + // e0 f1 g2 h3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // g4 h5 e6 f7 + // g0 h1 e2 f3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // e7 f6 g5 h4 + // e3 f2 g1 h0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // g7 h6 e5 f4 + // g3 h2 e1 f0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + { + _sum8 = vrev64q_s32(_sum8); + _sum9 = vrev64q_s32(_sum9); + _suma = vrev64q_s32(_suma); + _sumb = vrev64q_s32(_sumb); + _sumc = vrev64q_s32(_sumc); + _sumd = vrev64q_s32(_sumd); + _sume = vrev64q_s32(_sume); + _sumf = vrev64q_s32(_sumf); + _sum8 = vextq_s32(_sum8, _sum8, 2); + _sum9 = vextq_s32(_sum9, _sum9, 2); + _suma = vextq_s32(_suma, _suma, 2); + _sumb = vextq_s32(_sumb, _sumb, 2); + _sumc = vextq_s32(_sumc, _sumc, 2); + _sumd = vextq_s32(_sumd, _sumd, 2); + _sume = vextq_s32(_sume, _sume, 2); + _sumf = vextq_s32(_sumf, _sumf, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); + int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); + int32x4x2_t _t2 = vzipq_s32(_sum2, _sume); + int32x4x2_t _t3 = vzipq_s32(_sum6, _suma); + int32x4x2_t _t4 = vzipq_s32(_sum3, _sumf); + int32x4x2_t _t5 = vzipq_s32(_sum7, _sumb); + int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum9 = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); + _suma = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); + _sumb = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); + _sumc = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); + _sumd = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sume = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + _sum9 = vrev64q_s32(_sum9); + _sumb = vrev64q_s32(_sumb); + _sumd = vrev64q_s32(_sumd); + _sumf = vrev64q_s32(_sumf); + } #endif - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = 
vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c0); - _fa = vaddq_f32(_fa, _c0); - _fb = vaddq_f32(_fb, _c0); - _fc = vaddq_f32(_fc, _c0); - _fd = vaddq_f32(_fd, _c0); - _fe = vaddq_f32(_fe, _c0); - _ff = vaddq_f32(_ff, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - _f8 = vaddq_f32(_f8, _c1); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c1); - _fb = vaddq_f32(_fb, _c1); - _fc = vaddq_f32(_fc, _c1); - _fd = vaddq_f32(_fd, _c1); - _fe = vaddq_f32(_fe, _c1); - _ff = vaddq_f32(_ff, _c1); - } - if (broadcast_type_C == 3) + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale0); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale0); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_suma), _descale0); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sumb), _descale0); + float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); + float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); + float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); + float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c0); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c0); + _fc = vaddq_f32(_fc, _c0); + _fd = vaddq_f32(_fd, _c0); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c1); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c1); + _fb = vaddq_f32(_fb, _c1); + _fc = vaddq_f32(_fc, _c1); + _fd = vaddq_f32(_fd, _c1); + _fe = vaddq_f32(_fe, _c1); + _ff = vaddq_f32(_ff, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 4) { - if (c_elempack == 1) + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 4 * 2); + float32x4_t _c3 = vld1q_f32(pC + 4 * 3); + float32x4_t _c4 = vld1q_f32(pC + 4 * 4); + float32x4_t _c5 = vld1q_f32(pC + 4 * 5); + float32x4_t _c6 = vld1q_f32(pC + 4 * 6); + float32x4_t _c7 = vld1q_f32(pC + 4 * 7); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + 
_f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 4 + 4); + _c2 = vld1q_f32(pC + c_hstep * 4 + 4 * 2); + _c3 = vld1q_f32(pC + c_hstep * 4 + 4 * 3); + _c4 = vld1q_f32(pC + c_hstep * 4 + 4 * 4); + _c5 = vld1q_f32(pC + c_hstep * 4 + 4 * 5); + _c6 = vld1q_f32(pC + c_hstep * 4 + 4 * 6); + _c7 = vld1q_f32(pC + c_hstep * 4 + 4 * 7); + if (beta == 1.f) { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + c_hstep); - float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); - float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); - float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); - float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); - float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); - transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep * 4 + 4); - _c2 = vld1q_f32(pC + c_hstep * 5); - _c3 = vld1q_f32(pC + c_hstep * 5 + 4); - _c4 = vld1q_f32(pC + c_hstep * 6); - _c5 = vld1q_f32(pC + c_hstep * 6 + 4); - _c6 = vld1q_f32(pC + c_hstep * 7); - _c7 = vld1q_f32(pC + c_hstep * 7 + 4); - transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _c0, _beta); - _f9 = vmlaq_f32(_f9, _c1, _beta); - _fa = vmlaq_f32(_fa, _c2, _beta); - _fb = vmlaq_f32(_fb, _c3, _beta); - _fc = vmlaq_f32(_fc, _c4, _beta); - _fd = vmlaq_f32(_fd, _c5, _beta); - _fe = vmlaq_f32(_fe, _c6, _beta); - _ff = vmlaq_f32(_ff, _c7, _beta); - } - pC += 8; + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); } - else // if (c_elempack == 4) + else { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + 4 * 2); - float32x4_t _c3 = vld1q_f32(pC + 4 * 3); - float32x4_t _c4 = vld1q_f32(pC + 4 * 4); - float32x4_t _c5 = vld1q_f32(pC + 4 * 5); - float32x4_t _c6 = vld1q_f32(pC + 4 * 6); - float32x4_t _c7 = vld1q_f32(pC + 4 * 7); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = 
vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep * 4 + 4); - _c2 = vld1q_f32(pC + c_hstep * 4 + 4 * 2); - _c3 = vld1q_f32(pC + c_hstep * 4 + 4 * 3); - _c4 = vld1q_f32(pC + c_hstep * 4 + 4 * 4); - _c5 = vld1q_f32(pC + c_hstep * 4 + 4 * 5); - _c6 = vld1q_f32(pC + c_hstep * 4 + 4 * 6); - _c7 = vld1q_f32(pC + c_hstep * 4 + 4 * 7); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _c0, _beta); - _f9 = vmlaq_f32(_f9, _c1, _beta); - _fa = vmlaq_f32(_fa, _c2, _beta); - _fb = vmlaq_f32(_fb, _c3, _beta); - _fc = vmlaq_f32(_fc, _c4, _beta); - _fd = vmlaq_f32(_fd, _c5, _beta); - _fe = vmlaq_f32(_fe, _c6, _beta); - _ff = vmlaq_f32(_ff, _c7, _beta); - } - pC += 32; + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); } + pC += 32; } - if (broadcast_type_C == 4) + if (c_elempack == 1) { - float32x4_t _cc0 = vld1q_f32(pC); - float32x4_t _cc1 = vld1q_f32(pC + 4); - if (beta != 1.f) + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep); + float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); + float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); + float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); + float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 4 + 4); + _c2 = vld1q_f32(pC + c_hstep * 5); + _c3 = vld1q_f32(pC + c_hstep * 5 + 4); + _c4 = vld1q_f32(pC + c_hstep * 6); + _c5 = vld1q_f32(pC + c_hstep * 6 + 4); + _c6 = vld1q_f32(pC + c_hstep * 7); + _c7 = vld1q_f32(pC + c_hstep * 7 + 4); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff 
= vaddq_f32(_ff, _c7); + } + else { float32x4_t _beta = vdupq_n_f32(beta); - _cc0 = vmulq_f32(_cc0, _beta); - _cc1 = vmulq_f32(_cc1, _beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); } - _c0 = vdupq_laneq_f32(_cc0, 0); - _c1 = vdupq_laneq_f32(_cc0, 1); - float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); - float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); - float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); - float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); - float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); - float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); pC += 8; } } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - _f8 = vmulq_f32(_f8, _alpha); - _f9 = vmulq_f32(_f9, _alpha); - _fa = vmulq_f32(_fa, _alpha); - _fb = vmulq_f32(_fb, _alpha); - _fc = vmulq_f32(_fc, _alpha); - _fd = vmulq_f32(_fd, _alpha); - _fe = vmulq_f32(_fe, _alpha); - _ff = vmulq_f32(_ff, _alpha); - } - + if (broadcast_type_C == 4) + { + float32x4_t _cc0 = vld1q_f32(pC); + float32x4_t _cc1 = vld1q_f32(pC + 4); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + _f8 = vmulq_f32(_f8, _alpha); + _f9 = vmulq_f32(_f9, _alpha); + _fa = vmulq_f32(_fa, _alpha); + _fb = vmulq_f32(_fb, _alpha); + _fc = vmulq_f32(_fc, _alpha); + _fd = vmulq_f32(_fd, _alpha); + _fe = vmulq_f32(_fe, _alpha); + _ff = vmulq_f32(_ff, _alpha); + } + + if (out_elempack == 4) + { vst1q_f32(p0, _f0); 
vst1q_f32(p0 + 4, _f1); vst1q_f32(p0 + 8, _f2); @@ -6108,220 +6073,247 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& vst1q_f32(p0 + out_hstep * 4 + 20, _fd); vst1q_f32(p0 + out_hstep * 4 + 24, _fe); vst1q_f32(p0 + out_hstep * 4 + 28, _ff); - - pp += 64; p0 += 32; } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) + if (out_elempack == 1) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); + transpose4x4_ps(_f0, _f1, _f2, _f3); + transpose4x4_ps(_f4, _f5, _f6, _f7); + transpose4x4_ps(_f8, _f9, _fa, _fb); + transpose4x4_ps(_fc, _fd, _fe, _ff); + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f4); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep + 4, _f5); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 2 + 4, _f6); + vst1q_f32(p0 + out_hstep * 3, _f3); + vst1q_f32(p0 + out_hstep * 3 + 4, _f7); + vst1q_f32(p0 + out_hstep * 4, _f8); + vst1q_f32(p0 + out_hstep * 4 + 4, _fc); + vst1q_f32(p0 + out_hstep * 5, _f9); + vst1q_f32(p0 + out_hstep * 5 + 4, _fd); + vst1q_f32(p0 + out_hstep * 6, _fa); + vst1q_f32(p0 + out_hstep * 6 + 4, _fe); + vst1q_f32(p0 + out_hstep * 7, _fb); + vst1q_f32(p0 + out_hstep * 7 + 4, _ff); + p0 += 8; + } -#if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 + pp += 64; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 #else - // from - // a0 b1 c2 d3 - // e0 f1 g2 h3 - // c0 d1 a2 b3 - // g0 h1 e2 f3 - // a3 b2 c1 d0 - // e3 f2 g1 h0 - // c3 d2 a1 b0 - // g3 h2 e1 f0 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - { - _sum4 = vrev64q_s32(_sum4); - _sum5 = vrev64q_s32(_sum5); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); - _sum4 = vextq_s32(_sum4, _sum4, 2); - _sum5 = vextq_s32(_sum5, _sum5, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum7 = 
vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); - } -#endif // __ARM_FEATURE_DOTPROD - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); - float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + // from + // a0 b1 c2 d3 + // e0 f1 g2 h3 + // c0 d1 a2 b3 + // g0 h1 e2 f3 + // a3 b2 c1 d0 + // e3 f2 g1 h0 + // c3 d2 a1 b0 + // g3 h2 e1 f0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = 
vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c1); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c1); + _f7 = vaddq_f32(_f7, _c1); + } + if (broadcast_type_C == 3) + { + float32x4_t _c2; + float32x4_t _c3; + if (c_elempack == 4) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c1); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c1); - _f7 = vaddq_f32(_f7, _c1); + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c2 = vld1q_f32(pC + 8); + _c3 = vld1q_f32(pC + 12); } - if (broadcast_type_C == 3) + if (c_elempack == 1) { - float32x4_t _c2; - float32x4_t _c3; - if (c_elempack == 1) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep); - _c2 = vld1q_f32(pC + c_hstep * 2); - _c3 = vld1q_f32(pC + c_hstep * 3); - transpose4x4_ps(_c0, _c1, _c2, _c3); - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - _c2 = vld1q_f32(pC + 8); - _c3 = vld1q_f32(pC + 12); - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - if (c_elempack == 1) - { - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep * 5); - _c2 = vld1q_f32(pC + c_hstep * 6); - _c3 = vld1q_f32(pC + c_hstep * 7); - transpose4x4_ps(_c0, _c1, _c2, _c3); - pC += 4; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep * 4 + 4); - _c2 = vld1q_f32(pC + c_hstep * 4 + 8); - _c3 = vld1q_f32(pC + c_hstep * 4 + 12); - pC += 16; - } - if (beta == 1.f) - { - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c2); - _f7 = vaddq_f32(_f7, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f4 = vmlaq_f32(_f4, _c0, _beta); - _f5 = vmlaq_f32(_f5, _c1, _beta); - _f6 = vmlaq_f32(_f6, _c2, _beta); - _f7 = vmlaq_f32(_f7, _c3, _beta); - } + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep); + _c2 = vld1q_f32(pC + c_hstep * 2); + _c3 = vld1q_f32(pC + c_hstep * 3); + transpose4x4_ps(_c0, _c1, _c2, _c3); } - if (broadcast_type_C == 4) + if (beta == 1.f) { - float32x4_t _c = vld1q_f32(pC); - _c = vmulq_n_f32(_c, beta); -#if __aarch64__ - _c0 = vdupq_laneq_f32(_c, 0); - _c1 = vdupq_laneq_f32(_c, 1); - float32x4_t _c2 = vdupq_laneq_f32(_c, 2); - float32x4_t _c3 = vdupq_laneq_f32(_c, 3); -#else - _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); - _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); - float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); - float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); -#endif _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 4 + 4); + _c2 = vld1q_f32(pC + c_hstep * 4 + 8); + _c3 = vld1q_f32(pC + c_hstep * 4 + 12); + pC += 16; + } + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 5); + _c2 = vld1q_f32(pC + c_hstep * 6); + _c3 = vld1q_f32(pC + c_hstep * 7); + transpose4x4_ps(_c0, _c1, _c2, _c3); + pC += 4; 
+ } + if (beta == 1.f) + { _f4 = vaddq_f32(_f4, _c0); _f5 = vaddq_f32(_f5, _c1); _f6 = vaddq_f32(_f6, _c2); _f7 = vaddq_f32(_f7, _c3); - pC += 4; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); + float32x4_t _c = vld1q_f32(pC); + _c = vmulq_n_f32(_c, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_c, 0); + _c1 = vdupq_laneq_f32(_c, 1); + float32x4_t _c2 = vdupq_laneq_f32(_c, 2); + float32x4_t _c3 = vdupq_laneq_f32(_c, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); + _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); +#endif + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + pC += 4; } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + if (out_elempack == 4) + { vst1q_f32(p0, _f0); vst1q_f32(p0 + 4, _f1); vst1q_f32(p0 + 8, _f2); @@ -6330,1166 +6322,259 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& vst1q_f32(p0 + out_hstep * 4 + 4, _f5); vst1q_f32(p0 + out_hstep * 4 + 8, _f6); vst1q_f32(p0 + out_hstep * 4 + 12, _f7); - - pp += 32; p0 += 16; } - for (; jj + 1 < max_jj; jj += 2) + if (out_elempack == 1) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + transpose4x4_ps(_f0, _f1, _f2, _f3); + transpose4x4_ps(_f4, _f5, _f6, _f7); + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 3, _f3); + vst1q_f32(p0 + out_hstep * 4, _f4); + vst1q_f32(p0 + out_hstep * 5, _f5); + vst1q_f32(p0 + out_hstep * 6, _f6); + vst1q_f32(p0 + out_hstep * 7, _f7); + p0 += 4; + } + + pp += 32; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); #if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // e0 f0 g0 h0 - // e1 f1 g1 h1 + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 #else - // from - // a0 b1 c0 d1 - // e0 f1 g0 h1 - // a1 b0 c1 d0 - // e1 f0 g1 h0 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - { - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum2); - int32x4x2_t _t1 = vzipq_s32(_sum1, _sum3); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); - _sum2 
= vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t1.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t1.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - } + // from + // a0 b1 c0 d1 + // e0 f1 g0 h1 + // a1 b0 c1 d0 + // e1 f0 g1 h0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum2); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t1.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } #endif // __ARM_FEATURE_DOTPROD - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale1); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale1); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale1); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale1); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + float32x4_t _c2; + float32x4_t _c3; + if (c_elempack == 4) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c1); - _f3 = vaddq_f32(_f3, _c1); + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c2 = vld1q_f32(pC + c_hstep * 4); + _c3 = vld1q_f32(pC + c_hstep * 4 + 4); + pC += 8; } - if (broadcast_type_C == 3) + if (c_elempack == 1) { - float32x4_t _c2; - float32x4_t _c3; - if (c_elempack == 1) - { - float32x2_t _cc0 = vld1_f32(pC); - float32x2_t _cc1 = vld1_f32(pC + c_hstep); - float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); - float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); - float32x4_t _c01 = vcombine_f32(_cc0, _cc1); - float32x4_t _c23 = vcombine_f32(_cc2, _cc3); - float32x4x2_t _ccc0 = vuzpq_f32(_c01, _c23); - _c0 = _ccc0.val[0]; - _c1 = _ccc0.val[1]; - float32x2_t _cc4 = vld1_f32(pC + c_hstep * 4); - float32x2_t _cc5 = vld1_f32(pC + c_hstep * 5); - float32x2_t _cc6 = vld1_f32(pC + c_hstep * 6); - float32x2_t _cc7 = vld1_f32(pC + c_hstep * 7); - float32x4_t _c45 = vcombine_f32(_cc4, _cc5); - float32x4_t _c67 = vcombine_f32(_cc6, _cc7); - float32x4x2_t _ccc1 = vuzpq_f32(_c45, _c67); - _c2 = _ccc1.val[0]; - _c3 = _ccc1.val[1]; - pC += 2; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - _c2 = vld1q_f32(pC + c_hstep * 4); - _c3 = vld1q_f32(pC + c_hstep * 4 + 4); - pC += 8; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else 
- { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } + float32x2_t _cc0 = vld1_f32(pC); + float32x2_t _cc1 = vld1_f32(pC + c_hstep); + float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); + float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); + float32x4_t _c01 = vcombine_f32(_cc0, _cc1); + float32x4_t _c23 = vcombine_f32(_cc2, _cc3); + float32x4x2_t _ccc0 = vuzpq_f32(_c01, _c23); + _c0 = _ccc0.val[0]; + _c1 = _ccc0.val[1]; + float32x2_t _cc4 = vld1_f32(pC + c_hstep * 4); + float32x2_t _cc5 = vld1_f32(pC + c_hstep * 5); + float32x2_t _cc6 = vld1_f32(pC + c_hstep * 6); + float32x2_t _cc7 = vld1_f32(pC + c_hstep * 7); + float32x4_t _c45 = vcombine_f32(_cc4, _cc5); + float32x4_t _c67 = vcombine_f32(_cc6, _cc7); + float32x4x2_t _ccc1 = vuzpq_f32(_c45, _c67); + _c2 = _ccc1.val[0]; + _c3 = _ccc1.val[1]; + pC += 2; } - if (broadcast_type_C == 4) + if (beta == 1.f) { - float32x2_t _c = vld1_f32(pC); - _c = vmul_n_f32(_c, beta); - _c0 = vdupq_lane_f32(_c, 0); - _c1 = vdupq_lane_f32(_c, 1); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c1); - pC += 2; + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); + float32x2_t _c = vld1_f32(pC); + _c = vmul_n_f32(_c, beta); + _c0 = vdupq_lane_f32(_c, 0); + _c1 = vdupq_lane_f32(_c, 1); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 2; } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + if (out_elempack == 4) + { vst1q_f32(p0, _f0); vst1q_f32(p0 + 4, _f1); vst1q_f32(p0 + out_hstep * 4, _f2); vst1q_f32(p0 + out_hstep * 4 + 4, _f3); - - pp += 16; p0 += 8; } - for (; jj < max_jj; jj++) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale1); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) - { - _c0 = vsetq_lane_f32(pC[0], _c0, 0); - _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); - _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); - _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); - _c1 = vsetq_lane_f32(pC[c_hstep * 4], _c1, 0); - _c1 = vsetq_lane_f32(pC[c_hstep * 5], _c1, 1); - _c1 = vsetq_lane_f32(pC[c_hstep * 6], _c1, 2); - _c1 = vsetq_lane_f32(pC[c_hstep * 7], _c1, 3); - pC += 1; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep * 4); - pC += 4; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, 
_c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - } - if (broadcast_type_C == 4) - { - _c0 = vdupq_n_f32(pC[0] * beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - pC += 1; - } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - } - - vst1q_f32(p0, _f0); - vst1q_f32(p0 + out_hstep * 4, _f1); - - pp += 8; - p0 += 4; + if (out_elempack == 1) + { + float32x4x2_t _f01 = vzipq_f32(_f0, _f1); + float32x4x2_t _f23 = vzipq_f32(_f2, _f3); + vst1_f32(p0, vget_low_f32(_f01.val[0])); + vst1_f32(p0 + out_hstep, vget_high_f32(_f01.val[0])); + vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f01.val[1])); + vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f01.val[1])); + vst1_f32(p0 + out_hstep * 4, vget_low_f32(_f23.val[0])); + vst1_f32(p0 + out_hstep * 5, vget_high_f32(_f23.val[0])); + vst1_f32(p0 + out_hstep * 6, vget_low_f32(_f23.val[1])); + vst1_f32(p0 + out_hstep * 7, vget_high_f32(_f23.val[1])); + p0 += 2; } + + pp += 16; } - if (out_elempack == 1) + for (; jj < max_jj; jj++) { - int jj = 0; -#if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); - int32x4_t _sum8 = vld1q_s32(pp + 32); - int32x4_t _sum9 = vld1q_s32(pp + 36); - int32x4_t _suma = vld1q_s32(pp + 40); - int32x4_t _sumb = vld1q_s32(pp + 44); - int32x4_t _sumc = vld1q_s32(pp + 48); - int32x4_t _sumd = vld1q_s32(pp + 52); - int32x4_t _sume = vld1q_s32(pp + 56); - int32x4_t _sumf = vld1q_s32(pp + 60); - -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 - // e0 e1 e2 e3 - // e4 e5 e6 e7 - // f0 f1 f2 f3 - // f4 f5 f6 f7 - // g0 g1 g2 g3 - // g4 g5 g6 g7 - // h0 h1 h2 h3 - // h4 h5 h6 h7 - { - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); - int32x4x2_t _t2 = vzipq_s32(_sum8, _sum9); - int32x4x2_t _t3 = vzipq_s32(_suma, _sumb); - int32x4x2_t _t4 = vzipq_s32(_sum4, _sum5); - int32x4x2_t _t5 = vzipq_s32(_sum6, _sum7); - int32x4x2_t _t6 = vzipq_s32(_sumc, _sumd); - int32x4x2_t _t7 = vzipq_s32(_sume, _sumf); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); - _sum6 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); - _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); - _sum9 = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); - _suma = 
vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); - _sumb = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); - _sumc = vcombine_s32(vget_low_s32(_t4.val[1]), vget_low_s32(_t5.val[1])); - _sumd = vcombine_s32(vget_low_s32(_t6.val[1]), vget_low_s32(_t7.val[1])); - _sume = vcombine_s32(vget_high_s32(_t4.val[1]), vget_high_s32(_t5.val[1])); - _sumf = vcombine_s32(vget_high_s32(_t6.val[1]), vget_high_s32(_t7.val[1])); - } -#else - // from - // a0 b1 c2 d3 - // e4 f5 g6 h7 - // e0 f1 g2 h3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // g4 h5 e6 f7 - // g0 h1 e2 f3 - // c4 d5 a6 b7 - // a3 b2 c1 d0 - // e7 f6 g5 h4 - // e3 f2 g1 h0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // g7 h6 e5 f4 - // g3 h2 e1 f0 - // c7 d6 a5 b4 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 - // e0 e1 e2 e3 - // e4 e5 e6 e7 - // f0 f1 f2 f3 - // f4 f5 f6 f7 - // g0 g1 g2 g3 - // g4 g5 g6 g7 - // h0 h1 h2 h3 - // h4 h5 h6 h7 - { - _sum4 = vextq_s32(_sum4, _sum4, 2); - _sum5 = vextq_s32(_sum5, _sum5, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - _sumc = vextq_s32(_sumc, _sumc, 2); - _sumd = vextq_s32(_sumd, _sumd, 2); - _sume = vextq_s32(_sume, _sume, 2); - _sumf = vextq_s32(_sumf, _sumf, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); - int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); - int32x4x2_t _t2 = vzipq_s32(_sum3, _sumf); - int32x4x2_t _t3 = vzipq_s32(_sum7, _sumb); - int32x4x2_t _t4 = vzipq_s32(_sum2, _sume); - int32x4x2_t _t5 = vzipq_s32(_sum6, _suma); - int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); - int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum6 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); - _sum9 = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); - _suma = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); - _sumb = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); - _sumc = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); - _sumd = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); - _sume = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); - _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); - _suma = vrev64q_s32(_suma); - _sumb = vrev64q_s32(_sumb); - _sume = vrev64q_s32(_sume); - _sumf = vrev64q_s32(_sumf); - } -#endif // __ARM_FEATURE_DOTPROD - - float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale0, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale0, 0); - float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale0, 1); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale0, 1); - float32x4_t _f4 = 
vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale0, 2); - float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale0, 2); - float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale0, 3); - float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale0, 3); - float32x4_t _f8 = vmulq_laneq_f32(vcvtq_f32_s32(_sum8), _descale1, 0); - float32x4_t _f9 = vmulq_laneq_f32(vcvtq_f32_s32(_sum9), _descale1, 0); - float32x4_t _fa = vmulq_laneq_f32(vcvtq_f32_s32(_suma), _descale1, 1); - float32x4_t _fb = vmulq_laneq_f32(vcvtq_f32_s32(_sumb), _descale1, 1); - float32x4_t _fc = vmulq_laneq_f32(vcvtq_f32_s32(_sumc), _descale1, 2); - float32x4_t _fd = vmulq_laneq_f32(vcvtq_f32_s32(_sumd), _descale1, 2); - float32x4_t _fe = vmulq_laneq_f32(vcvtq_f32_s32(_sume), _descale1, 3); - float32x4_t _ff = vmulq_laneq_f32(vcvtq_f32_s32(_sumf), _descale1, 3); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c0); - _fa = vaddq_f32(_fa, _c0); - _fb = vaddq_f32(_fb, _c0); - _fc = vaddq_f32(_fc, _c0); - _fd = vaddq_f32(_fd, _c0); - _fe = vaddq_f32(_fe, _c0); - _ff = vaddq_f32(_ff, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); - float32x4_t _cc4 = vdupq_laneq_f32(_c1, 0); - float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1); - float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2); - float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3); - _f0 = vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc0); - _f2 = vaddq_f32(_f2, _cc1); - _f3 = vaddq_f32(_f3, _cc1); - _f4 = vaddq_f32(_f4, _cc2); - _f5 = vaddq_f32(_f5, _cc2); - _f6 = vaddq_f32(_f6, _cc3); - _f7 = vaddq_f32(_f7, _cc3); - _f8 = vaddq_f32(_f8, _cc4); - _f9 = vaddq_f32(_f9, _cc4); - _fa = vaddq_f32(_fa, _cc5); - _fb = vaddq_f32(_fb, _cc5); - _fc = vaddq_f32(_fc, _cc6); - _fd = vaddq_f32(_fd, _cc6); - _fe = vaddq_f32(_fe, _cc7); - _ff = vaddq_f32(_ff, _cc7); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + c_hstep); - float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); - float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); - float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); - float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); - float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep * 4 + 4); - _c2 = vld1q_f32(pC + c_hstep * 5); - _c3 = vld1q_f32(pC + c_hstep * 5 + 4); - _c4 = vld1q_f32(pC + c_hstep * 6); - _c5 = vld1q_f32(pC + c_hstep * 6 + 4); - _c6 = vld1q_f32(pC + 
c_hstep * 7); - _c7 = vld1q_f32(pC + c_hstep * 7 + 4); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _c0, _beta); - _f9 = vmlaq_f32(_f9, _c1, _beta); - _fa = vmlaq_f32(_fa, _c2, _beta); - _fb = vmlaq_f32(_fb, _c3, _beta); - _fc = vmlaq_f32(_fc, _c4, _beta); - _fd = vmlaq_f32(_fd, _c5, _beta); - _fe = vmlaq_f32(_fe, _c6, _beta); - _ff = vmlaq_f32(_ff, _c7, _beta); - } - pC += 8; - } - else // if (c_elempack == 4) - { - float32x4x4_t _cc0 = vld4q_f32(pC); - float32x4x4_t _cc1 = vld4q_f32(pC + 16); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _cc0.val[0]); - _f1 = vaddq_f32(_f1, _cc1.val[0]); - _f2 = vaddq_f32(_f2, _cc0.val[1]); - _f3 = vaddq_f32(_f3, _cc1.val[1]); - _f4 = vaddq_f32(_f4, _cc0.val[2]); - _f5 = vaddq_f32(_f5, _cc1.val[2]); - _f6 = vaddq_f32(_f6, _cc0.val[3]); - _f7 = vaddq_f32(_f7, _cc1.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _cc0.val[0], _beta); - _f1 = vmlaq_f32(_f1, _cc1.val[0], _beta); - _f2 = vmlaq_f32(_f2, _cc0.val[1], _beta); - _f3 = vmlaq_f32(_f3, _cc1.val[1], _beta); - _f4 = vmlaq_f32(_f4, _cc0.val[2], _beta); - _f5 = vmlaq_f32(_f5, _cc1.val[2], _beta); - _f6 = vmlaq_f32(_f6, _cc0.val[3], _beta); - _f7 = vmlaq_f32(_f7, _cc1.val[3], _beta); - } - _cc0 = vld4q_f32(pC + c_hstep * 4); - _cc1 = vld4q_f32(pC + c_hstep * 4 + 16); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _cc0.val[0]); - _f9 = vaddq_f32(_f9, _cc1.val[0]); - _fa = vaddq_f32(_fa, _cc0.val[1]); - _fb = vaddq_f32(_fb, _cc1.val[1]); - _fc = vaddq_f32(_fc, _cc0.val[2]); - _fd = vaddq_f32(_fd, _cc1.val[2]); - _fe = vaddq_f32(_fe, _cc0.val[3]); - _ff = vaddq_f32(_ff, _cc1.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _cc0.val[0], _beta); - _f9 = vmlaq_f32(_f9, _cc1.val[0], _beta); - _fa = vmlaq_f32(_fa, _cc0.val[1], _beta); - _fb = vmlaq_f32(_fb, _cc1.val[1], _beta); - _fc = vmlaq_f32(_fc, _cc0.val[2], _beta); - _fd = vmlaq_f32(_fd, _cc1.val[2], _beta); - _fe = vmlaq_f32(_fe, _cc0.val[3], _beta); - _ff = vmlaq_f32(_ff, _cc1.val[3], _beta); - } - pC += 32; - } - } - if (broadcast_type_C == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - if (beta != 1.f) - { - float32x4_t _beta = vdupq_n_f32(beta); - _c0 = vmulq_f32(_c0, _beta); - _c1 = vmulq_f32(_c1, _beta); - } - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c1); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c1); - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c0); - _fb = vaddq_f32(_fb, _c1); - _fc = vaddq_f32(_fc, _c0); - _fd = vaddq_f32(_fd, _c1); - _fe = vaddq_f32(_fe, _c0); - _ff = vaddq_f32(_ff, _c1); - pC += 8; - } - } + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - _f8 = vmulq_f32(_f8, _alpha); - _f9 = vmulq_f32(_f9, _alpha); - _fa = vmulq_f32(_fa, _alpha); - _fb = 
vmulq_f32(_fb, _alpha); - _fc = vmulq_f32(_fc, _alpha); - _fd = vmulq_f32(_fd, _alpha); - _fe = vmulq_f32(_fe, _alpha); - _ff = vmulq_f32(_ff, _alpha); - } - - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); - vst1q_f32(p0 + out_hstep, _f2); - vst1q_f32(p0 + out_hstep + 4, _f3); - vst1q_f32(p0 + out_hstep * 2, _f4); - vst1q_f32(p0 + out_hstep * 2 + 4, _f5); - vst1q_f32(p0 + out_hstep * 3, _f6); - vst1q_f32(p0 + out_hstep * 3 + 4, _f7); - vst1q_f32(p0 + out_hstep * 4, _f8); - vst1q_f32(p0 + out_hstep * 4 + 4, _f9); - vst1q_f32(p0 + out_hstep * 5, _fa); - vst1q_f32(p0 + out_hstep * 5 + 4, _fb); - vst1q_f32(p0 + out_hstep * 6, _fc); - vst1q_f32(p0 + out_hstep * 6 + 4, _fd); - vst1q_f32(p0 + out_hstep * 7, _fe); - vst1q_f32(p0 + out_hstep * 7 + 4, _ff); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale1); - pp += 64; - p0 += 8; - } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) + if (pC) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); - -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 c1 c2 c3 - // d0 d1 d2 d3 - // e0 e1 e2 e3 - // f0 f1 f2 f3 - // g0 g1 g2 g3 - // h0 h1 h2 h3 + if (broadcast_type_C == 0) { - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); - int32x4x2_t _t2 = vzipq_s32(_sum4, _sum5); - int32x4x2_t _t3 = vzipq_s32(_sum6, _sum7); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); } -#else - // from - // a0 b1 c2 d3 - // e0 f1 g2 h3 - // c0 d1 a2 b3 - // g0 h1 e2 f3 - // a3 b2 c1 d0 - // e3 f2 g1 h0 - // c3 d2 a1 b0 - // g3 h2 e1 f0 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 c1 c2 c3 - // d0 d1 d2 d3 - // e0 e1 e2 e3 - // f0 f1 f2 f3 - // g0 g1 g2 g3 - // h0 h1 h2 h3 + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - _sum2 = vextq_s32(_sum2, _sum2, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), 
vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); } -#endif // __ARM_FEATURE_DOTPROD - -#if __aarch64__ - float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale0, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale0, 1); - float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale0, 2); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale0, 3); - float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale1, 0); - float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale1, 1); - float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale1, 2); - float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale1, 3); -#else - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), vget_low_f32(_descale0), 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), vget_low_f32(_descale0), 1); - float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), vget_high_f32(_descale0), 0); - float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), vget_high_f32(_descale0), 1); - float32x4_t _f4 = vmulq_lane_f32(vcvtq_f32_s32(_sum4), vget_low_f32(_descale1), 0); - float32x4_t _f5 = vmulq_lane_f32(vcvtq_f32_s32(_sum5), vget_low_f32(_descale1), 1); - float32x4_t _f6 = vmulq_lane_f32(vcvtq_f32_s32(_sum6), vget_high_f32(_descale1), 0); - float32x4_t _f7 = vmulq_lane_f32(vcvtq_f32_s32(_sum7), vget_high_f32(_descale1), 1); -#endif - - if (pC) + if (broadcast_type_C == 3) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { -#if __aarch64__ - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); -#else - float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); - float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); - float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); - float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); -#endif - _f0 = vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc1); - _f2 = vaddq_f32(_f2, _cc2); - _f3 = vaddq_f32(_f3, _cc3); -#if __aarch64__ - _cc0 = vdupq_laneq_f32(_c1, 0); - _cc1 = vdupq_laneq_f32(_c1, 1); - _cc2 = vdupq_laneq_f32(_c1, 2); - _cc3 = vdupq_laneq_f32(_c1, 3); -#else - _cc0 = vdupq_lane_f32(vget_low_f32(_c1), 0); - _cc1 = vdupq_lane_f32(vget_low_f32(_c1), 1); - _cc2 = vdupq_lane_f32(vget_high_f32(_c1), 0); - _cc3 = vdupq_lane_f32(vget_high_f32(_c1), 1); -#endif - _f4 = vaddq_f32(_f4, _cc0); - _f5 = vaddq_f32(_f5, _cc1); - _f6 = vaddq_f32(_f6, _cc2); - _f7 = vaddq_f32(_f7, _cc3); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep); - float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); - float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, 
_c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep * 5); - _c2 = vld1q_f32(pC + c_hstep * 6); - _c3 = vld1q_f32(pC + c_hstep * 7); - if (beta == 1.f) - { - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c2); - _f7 = vaddq_f32(_f7, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f4 = vmlaq_f32(_f4, _c0, _beta); - _f5 = vmlaq_f32(_f5, _c1, _beta); - _f6 = vmlaq_f32(_f6, _c2, _beta); - _f7 = vmlaq_f32(_f7, _c3, _beta); - } - pC += 4; - } - else // if (c_elempack == 4) - { - float32x4x4_t _cc0 = vld4q_f32(pC); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _cc0.val[0]); - _f1 = vaddq_f32(_f1, _cc0.val[1]); - _f2 = vaddq_f32(_f2, _cc0.val[2]); - _f3 = vaddq_f32(_f3, _cc0.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _cc0.val[0], _beta); - _f1 = vmlaq_f32(_f1, _cc0.val[1], _beta); - _f2 = vmlaq_f32(_f2, _cc0.val[2], _beta); - _f3 = vmlaq_f32(_f3, _cc0.val[3], _beta); - } - _cc0 = vld4q_f32(pC + c_hstep * 4); - if (beta == 1.f) - { - _f4 = vaddq_f32(_f4, _cc0.val[0]); - _f5 = vaddq_f32(_f5, _cc0.val[1]); - _f6 = vaddq_f32(_f6, _cc0.val[2]); - _f7 = vaddq_f32(_f7, _cc0.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f4 = vmlaq_f32(_f4, _cc0.val[0], _beta); - _f5 = vmlaq_f32(_f5, _cc0.val[1], _beta); - _f6 = vmlaq_f32(_f6, _cc0.val[2], _beta); - _f7 = vmlaq_f32(_f7, _cc0.val[3], _beta); - } - pC += 16; - } - } - if (broadcast_type_C == 4) + if (c_elempack == 4) { _c0 = vld1q_f32(pC); - _c0 = vmulq_n_f32(_c0, beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); + _c1 = vld1q_f32(pC + c_hstep * 4); pC += 4; } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - } - - vst1q_f32(p0, _f0); - vst1q_f32(p0 + out_hstep, _f1); - vst1q_f32(p0 + out_hstep * 2, _f2); - vst1q_f32(p0 + out_hstep * 3, _f3); - vst1q_f32(p0 + out_hstep * 4, _f4); - vst1q_f32(p0 + out_hstep * 5, _f5); - vst1q_f32(p0 + out_hstep * 6, _f6); - vst1q_f32(p0 + out_hstep * 7, _f7); - - pp += 32; - p0 += 4; - } - for (; jj + 1 < max_jj; jj += 2) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - - // to - // a0 a1 b0 b1 - // c0 c1 d0 d1 - // e0 e1 f0 f1 - // g0 g1 h0 h1 - { - int32x4x2_t _sum02 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _sum13 = vzipq_s32(_sum2, _sum3); - _sum0 = _sum02.val[0]; - _sum1 = _sum02.val[1]; - _sum2 = _sum13.val[0]; - _sum3 = _sum13.val[1]; - } -#else - // from - // a0 b1 c0 d1 - // e0 f1 g0 h1 - // a1 b0 c1 d0 - // e1 f0 g1 h0 - - // to - // a0 a1 b0 b1 - // c0 c1 d0 d1 - // e0 e1 f0 f1 - // g0 g1 h0 h1 - { - int32x4x2_t _t0 = vuzpq_s32(_sum0, _sum1); - int32x4x2_t _t1 = vuzpq_s32(_sum2, _sum3); - int32x4x2_t _t2 = 
vzipq_s32(_t0.val[0], _t1.val[0]); - int32x4x2_t _t3 = vzipq_s32(_t1.val[1], _t0.val[1]); - _sum0 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); - } -#endif // __ARM_FEATURE_DOTPROD - - float32x4x2_t _descale01 = vzipq_f32(_descale0, _descale0); - float32x4x2_t _descale23 = vzipq_f32(_descale1, _descale1); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale01.val[0]); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale01.val[1]); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale23.val[0]); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale23.val[1]); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + if (c_elempack == 1) { - float32x4x2_t _cc0 = vzipq_f32(_c0, _c0); - float32x4x2_t _cc1 = vzipq_f32(_c1, _c1); - _f0 = vaddq_f32(_f0, _cc0.val[0]); - _f1 = vaddq_f32(_f1, _cc0.val[1]); - _f2 = vaddq_f32(_f2, _cc1.val[0]); - _f3 = vaddq_f32(_f3, _cc1.val[1]); + _c0 = vsetq_lane_f32(pC[0], _c0, 0); + _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); + _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); + _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); + _c1 = vsetq_lane_f32(pC[c_hstep * 4], _c1, 0); + _c1 = vsetq_lane_f32(pC[c_hstep * 5], _c1, 1); + _c1 = vsetq_lane_f32(pC[c_hstep * 6], _c1, 2); + _c1 = vsetq_lane_f32(pC[c_hstep * 7], _c1, 3); + pC += 1; } - if (broadcast_type_C == 3) + if (beta == 1.f) { - if (c_elempack == 1) - { - float32x2_t _cc0 = vld1_f32(pC); - float32x2_t _cc1 = vld1_f32(pC + c_hstep); - float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); - float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); - _c0 = vcombine_f32(_cc0, _cc1); - _c1 = vcombine_f32(_cc2, _cc3); - float32x2_t _cc4 = vld1_f32(pC + c_hstep * 4); - float32x2_t _cc5 = vld1_f32(pC + c_hstep * 5); - float32x2_t _cc6 = vld1_f32(pC + c_hstep * 6); - float32x2_t _cc7 = vld1_f32(pC + c_hstep * 7); - float32x4_t _c2 = vcombine_f32(_cc4, _cc5); - float32x4_t _c3 = vcombine_f32(_cc6, _cc7); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 2; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + c_hstep * 4); - float32x4_t _c3 = vld1q_f32(pC + c_hstep * 4 + 4); - float32x4x2_t _c01 = vzipq_f32(_c0, _c1); - float32x4x2_t _c23 = vzipq_f32(_c2, _c3); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c01.val[0]); - _f1 = vaddq_f32(_f1, _c01.val[1]); - _f2 = vaddq_f32(_f2, _c23.val[0]); - _f3 = vaddq_f32(_f3, _c23.val[1]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c01.val[0], _beta); - _f1 = vmlaq_f32(_f1, _c01.val[1], _beta); - _f2 = vmlaq_f32(_f2, _c23.val[0], _beta); - _f3 = vmlaq_f32(_f3, _c23.val[1], _beta); - } - pC += 8; - } + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); } - if (broadcast_type_C == 4) + else { - float32x2_t _cc0 = vld1_f32(pC); - _cc0 = vmul_n_f32(_cc0, 
beta); - _c0 = vcombine_f32(_cc0, _cc0); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - pC += 2; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); + _c0 = vdupq_n_f32(pC[0] * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 1; } + } - vst1_f32(p0, vget_low_f32(_f0)); - vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); - vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f1)); - vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f1)); - vst1_f32(p0 + out_hstep * 4, vget_low_f32(_f2)); - vst1_f32(p0 + out_hstep * 5, vget_high_f32(_f2)); - vst1_f32(p0 + out_hstep * 6, vget_low_f32(_f3)); - vst1_f32(p0 + out_hstep * 7, vget_high_f32(_f3)); + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } - pp += 16; - p0 += 2; + if (out_elempack == 4) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep * 4, _f1); + p0 += 4; } - for (; jj < max_jj; jj++) + if (out_elempack == 1) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale1); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) - { - _c0 = vsetq_lane_f32(pC[0], _c0, 0); - _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); - _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); - _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); - _c1 = vsetq_lane_f32(pC[c_hstep * 4], _c1, 0); - _c1 = vsetq_lane_f32(pC[c_hstep * 5], _c1, 1); - _c1 = vsetq_lane_f32(pC[c_hstep * 6], _c1, 2); - _c1 = vsetq_lane_f32(pC[c_hstep * 7], _c1, 3); - pC += 1; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep * 4); - pC += 4; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - } - if (broadcast_type_C == 4) - { - _c0 = vdupq_n_f32(pC[0] * beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - pC += 1; - } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - } - p0[0] = vgetq_lane_f32(_f0, 0); p0[out_hstep] = vgetq_lane_f32(_f0, 1); p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); @@ -7498,10 +6583,10 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& p0[out_hstep * 5] = vgetq_lane_f32(_f1, 1); p0[out_hstep * 6] = vgetq_lane_f32(_f1, 2); p0[out_hstep * 7] = vgetq_lane_f32(_f1, 3); - - pp += 8; p0++; } + + pp += 8; } } for (; ii + 3 < max_ii; ii += 4) @@ -7533,213 +6618,213 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& } } - if (out_elempack == 4) - { - int jj = 0; + int jj = 0; #if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - 
int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); - -#if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 #else - // from - // a0 b1 c2 d3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // c4 d5 a6 b7 - // a3 b2 c1 d0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // c7 d6 a5 b4 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - { - _sum4 = vrev64q_s32(_sum4); - _sum5 = vrev64q_s32(_sum5); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); - _sum4 = vextq_s32(_sum4, _sum4, 2); - _sum5 = vextq_s32(_sum5, _sum5, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); - } -#endif // __ARM_FEATURE_DOTPROD - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale); - float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale); - - if (pC) - { - if (broadcast_type_C == 0) + // from + // a0 b1 c2 d3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + 
_sum7 = vextq_s32(_sum7, _sum7, 2);
+                int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6);
+                int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4);
+                int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7);
+                int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5);
+                _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0]));
+                _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0]));
+                _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1]));
+                _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1]));
+                _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0]));
+                _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0]));
+                _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1]));
+                _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1]));
+                _sum1 = vrev64q_s32(_sum1);
+                _sum3 = vrev64q_s32(_sum3);
+                _sum5 = vrev64q_s32(_sum5);
+                _sum7 = vrev64q_s32(_sum7);
+            }
+#endif // __ARM_FEATURE_DOTPROD
+
+            float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale);
+            float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale);
+            float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale);
+            float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale);
+            float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale);
+            float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale);
+            float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale);
+            float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale);
+
+            if (pC)
+            {
+                if (broadcast_type_C == 0)
+                {
+                    _f0 = vaddq_f32(_f0, _c0);
+                    _f1 = vaddq_f32(_f1, _c0);
+                    _f2 = vaddq_f32(_f2, _c0);
+                    _f3 = vaddq_f32(_f3, _c0);
+                    _f4 = vaddq_f32(_f4, _c0);
+                    _f5 = vaddq_f32(_f5, _c0);
+                    _f6 = vaddq_f32(_f6, _c0);
+                    _f7 = vaddq_f32(_f7, _c0);
+                }
+                if (broadcast_type_C == 1 || broadcast_type_C == 2)
+                {
+                    _f0 = vaddq_f32(_f0, _c0);
+                    _f1 = vaddq_f32(_f1, _c0);
+                    _f2 = vaddq_f32(_f2, _c0);
+                    _f3 = vaddq_f32(_f3, _c0);
+                    _f4 = vaddq_f32(_f4, _c0);
+                    _f5 = vaddq_f32(_f5, _c0);
+                    _f6 = vaddq_f32(_f6, _c0);
+                    _f7 = vaddq_f32(_f7, _c0);
+                }
+                if (broadcast_type_C == 3)
+                {
+                    float32x4_t _c1;
+                    float32x4_t _c2;
+                    float32x4_t _c3;
+                    float32x4_t _c4;
+                    float32x4_t _c5;
+                    float32x4_t _c6;
+                    float32x4_t _c7;
+                    if (c_elempack == 4)
+                    {
+                        _c0 = vld1q_f32(pC);
+                        _c1 = vld1q_f32(pC + 4);
+                        _c2 = vld1q_f32(pC + 8);
+                        _c3 = vld1q_f32(pC + 12);
+                        _c4 = vld1q_f32(pC + 16);
+                        _c5 = vld1q_f32(pC + 20);
+                        _c6 = vld1q_f32(pC + 24);
+                        _c7 = vld1q_f32(pC + 28);
+                        pC += 32;
+                    }
+                    if (c_elempack == 1)
+                    {
+                        _c0 = vld1q_f32(pC);
+                        _c1 = vld1q_f32(pC + 4);
+                        _c2 = vld1q_f32(pC + c_hstep);
+                        _c3 = vld1q_f32(pC + c_hstep + 4);
+                        _c4 = vld1q_f32(pC + c_hstep * 2);
+                        _c5 = vld1q_f32(pC + c_hstep * 2 + 4);
+                        _c6 = vld1q_f32(pC + c_hstep * 3);
+                        _c7 = vld1q_f32(pC + c_hstep * 3 + 4);
+                        transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7);
+                        pC += 8;
+                    }
+                    if (beta == 1.f)
+                    {
+                        _f0 = vaddq_f32(_f0, _c0);
+                        _f1 = vaddq_f32(_f1, _c1);
+                        _f2 = vaddq_f32(_f2, _c2);
+                        _f3 = vaddq_f32(_f3, _c3);
+                        _f4 = vaddq_f32(_f4, _c4);
+                        _f5 = vaddq_f32(_f5, _c5);
+ _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); } - if (broadcast_type_C == 3) + else { - float32x4_t _c1; - float32x4_t _c2; - float32x4_t _c3; - float32x4_t _c4; - float32x4_t _c5; - float32x4_t _c6; - float32x4_t _c7; - if (c_elempack == 1) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - _c2 = vld1q_f32(pC + c_hstep); - _c3 = vld1q_f32(pC + c_hstep + 4); - _c4 = vld1q_f32(pC + c_hstep * 2); - _c5 = vld1q_f32(pC + c_hstep * 2 + 4); - _c6 = vld1q_f32(pC + c_hstep * 3); - _c7 = vld1q_f32(pC + c_hstep * 3 + 4); - transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); - pC += 8; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - _c2 = vld1q_f32(pC + 8); - _c3 = vld1q_f32(pC + 12); - _c4 = vld1q_f32(pC + 16); - _c5 = vld1q_f32(pC + 20); - _c6 = vld1q_f32(pC + 24); - _c7 = vld1q_f32(pC + 28); - pC += 32; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); } - if (broadcast_type_C == 4) + } + if (broadcast_type_C == 4) + { + float32x4_t _cc0 = vld1q_f32(pC); + float32x4_t _cc1 = vld1q_f32(pC + 4); + if (beta != 1.f) { - float32x4_t _cc0 = vld1q_f32(pC); - float32x4_t _cc1 = vld1q_f32(pC + 4); - if (beta != 1.f) - { - float32x4_t _beta = vdupq_n_f32(beta); - _cc0 = vmulq_f32(_cc0, _beta); - _cc1 = vmulq_f32(_cc1, _beta); - } - _c0 = vdupq_laneq_f32(_cc0, 0); - float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1); - float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); - float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); - float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); - float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); - float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); - float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - pC += 8; + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); } + _c0 = vdupq_laneq_f32(_cc0, 0); + float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; } + } - if (alpha != 1.f) - { - float32x4_t _alpha = 
vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + if (out_elempack == 4) + { vst1q_f32(p0, _f0); vst1q_f32(p0 + 4, _f1); vst1q_f32(p0 + 8, _f2); @@ -7748,899 +6833,648 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& vst1q_f32(p0 + 20, _f5); vst1q_f32(p0 + 24, _f6); vst1q_f32(p0 + 28, _f7); - - pp += 32; p0 += 32; } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) + if (out_elempack == 1) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + transpose4x4_ps(_f0, _f1, _f2, _f3); + transpose4x4_ps(_f4, _f5, _f6, _f7); + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f4); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep + 4, _f5); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 2 + 4, _f6); + vst1q_f32(p0 + out_hstep * 3, _f3); + vst1q_f32(p0 + out_hstep * 3 + 4, _f7); + p0 += 8; + } + + pp += 32; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); #if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 #else - // from - // a0 b1 c2 d3 - // c0 d1 a2 b3 - // a3 b2 c1 d0 - // c3 d2 a1 b0 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - { - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - _sum2 = vextq_s32(_sum2, _sum2, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); - int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - } + // from + // a0 b1 c2 d3 + // c0 d1 a2 b3 + // a3 b2 c1 d0 + // c3 d2 a1 b0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } #endif // __ARM_FEATURE_DOTPROD - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), 
_descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3) + { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; + if (c_elempack == 4) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c2 = vld1q_f32(pC + 8); + _c3 = vld1q_f32(pC + 12); + pC += 16; } - if (broadcast_type_C == 3) + if (c_elempack == 1) { - float32x4_t _c1; - float32x4_t _c2; - float32x4_t _c3; - if (c_elempack == 1) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep * 1); - _c2 = vld1q_f32(pC + c_hstep * 2); - _c3 = vld1q_f32(pC + c_hstep * 3); - transpose4x4_ps(_c0, _c1, _c2, _c3); - pC += 4; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - _c2 = vld1q_f32(pC + 8); - _c3 = vld1q_f32(pC + 12); - pC += 16; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep * 1); + _c2 = vld1q_f32(pC + c_hstep * 2); + _c3 = vld1q_f32(pC + c_hstep * 3); + transpose4x4_ps(_c0, _c1, _c2, _c3); + pC += 4; } - if (broadcast_type_C == 4) + if (beta == 1.f) { - float32x4_t _c = vld1q_f32(pC); - _c = vmulq_n_f32(_c, beta); -#if __aarch64__ - _c0 = vdupq_laneq_f32(_c, 0); - float32x4_t _c1 = vdupq_laneq_f32(_c, 1); - float32x4_t _c2 = vdupq_laneq_f32(_c, 2); - float32x4_t _c3 = vdupq_laneq_f32(_c, 3); -#else - _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); - float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); - float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); - float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); -#endif _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); _f3 = vaddq_f32(_f3, _c3); - pC += 4; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); + float32x4_t _c = vld1q_f32(pC); + _c = vmulq_n_f32(_c, beta); +#if __aarch64__ + _c0 = 
vdupq_laneq_f32(_c, 0); + float32x4_t _c1 = vdupq_laneq_f32(_c, 1); + float32x4_t _c2 = vdupq_laneq_f32(_c, 2); + float32x4_t _c3 = vdupq_laneq_f32(_c, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); + float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); +#endif + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 4; } + } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + if (out_elempack == 4) + { vst1q_f32(p0, _f0); vst1q_f32(p0 + 4, _f1); vst1q_f32(p0 + 8, _f2); vst1q_f32(p0 + 12, _f3); - - pp += 16; p0 += 16; } - for (; jj + 1 < max_jj; jj += 2) + if (out_elempack == 1) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); + transpose4x4_ps(_f0, _f1, _f2, _f3); + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 3, _f3); + p0 += 4; + } + + pp += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); #if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 #else - // from - // a0 b1 c0 d1 - // a1 b0 c1 d0 + // from + // a0 b1 c0 d1 + // a1 b0 c1 d0 - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - { - _sum1 = vrev64q_s32(_sum1); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); - _sum1 = vrev64q_s32(_sum1); - } + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + { + _sum1 = vrev64q_s32(_sum1); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + } #endif // __ARM_FEATURE_DOTPROD - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3) + { + float32x4_t _c1; + if (c_elempack == 4) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + pC += 8; } - if (broadcast_type_C == 3) + if (c_elempack == 1) { - float32x4_t _c1; - if (c_elempack == 1) - { - float32x2_t _cc0 = vld1_f32(pC); - float32x2_t _cc1 = vld1_f32(pC + c_hstep); - float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); - float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); - float32x4_t _c01 = vcombine_f32(_cc0, _cc1); - float32x4_t _c23 = vcombine_f32(_cc2, _cc3); - float32x4x2_t _cc = vuzpq_f32(_c01, _c23); - _c0 = _cc.val[0]; - _c1 = _cc.val[1]; - pC += 2; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); 
- _c1 = vld1q_f32(pC + 4); - pC += 8; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } + float32x2_t _cc0 = vld1_f32(pC); + float32x2_t _cc1 = vld1_f32(pC + c_hstep); + float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); + float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); + float32x4_t _c01 = vcombine_f32(_cc0, _cc1); + float32x4_t _c23 = vcombine_f32(_cc2, _cc3); + float32x4x2_t _cc = vuzpq_f32(_c01, _c23); + _c0 = _cc.val[0]; + _c1 = _cc.val[1]; + pC += 2; } - if (broadcast_type_C == 4) + if (beta == 1.f) { - float32x2_t _c = vld1_f32(pC); - _c = vmul_n_f32(_c, beta); - _c0 = vdupq_lane_f32(_c, 0); - float32x4_t _c1 = vdupq_lane_f32(_c, 1); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); - pC += 2; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); + float32x2_t _c = vld1_f32(pC); + _c = vmul_n_f32(_c, beta); + _c0 = vdupq_lane_f32(_c, 0); + float32x4_t _c1 = vdupq_lane_f32(_c, 1); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 2; } + } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + if (out_elempack == 4) + { vst1q_f32(p0, _f0); vst1q_f32(p0 + 4, _f1); - - pp += 8; p0 += 8; } - for (; jj < max_jj; jj++) + if (out_elempack == 1) { - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) - { - _c0 = vsetq_lane_f32(pC[0], _c0, 0); - _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); - _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); - _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); - pC += 1; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - pC += 4; - } - _f0 = vmlaq_n_f32(_f0, _c0, beta); - } - if (broadcast_type_C == 4) - { - _c0 = vdupq_n_f32(pC[0] * beta); - _f0 = vaddq_f32(_f0, _c0); - pC += 1; - } - } - - _f0 = vmulq_n_f32(_f0, alpha); - - vst1q_f32(p0, _f0); - - pp += 4; - p0 += 4; + float32x4x2_t _f01 = vzipq_f32(_f0, _f1); + vst1_f32(p0, vget_low_f32(_f01.val[0])); + vst1_f32(p0 + out_hstep, vget_high_f32(_f01.val[0])); + vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f01.val[1])); + vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f01.val[1])); + p0 += 2; } + + pp += 8; } - if (out_elempack == 1) + for (; jj < max_jj; jj++) { - int jj = 0; -#if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // 
c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 + if (pC) + { + if (broadcast_type_C == 0) { - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); - int32x4x2_t _t2 = vzipq_s32(_sum4, _sum5); - int32x4x2_t _t3 = vzipq_s32(_sum6, _sum7); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); - _sum6 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); + _f0 = vaddq_f32(_f0, _c0); } -#else - // from - // a0 b1 c2 d3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // c4 d5 a6 b7 - // a3 b2 c1 d0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // c7 d6 a5 b4 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - _sum2 = vextq_s32(_sum2, _sum2, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum6 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); + _f0 = vaddq_f32(_f0, _c0); } -#endif // __ARM_FEATURE_DOTPROD - - float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale, 0); - float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale, 1); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale, 1); - float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale, 2); - float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale, 2); - float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale, 3); - float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale, 3); - - if (pC) + if (broadcast_type_C == 3) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); - _f0 = 
vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc0); - _f2 = vaddq_f32(_f2, _cc1); - _f3 = vaddq_f32(_f3, _cc1); - _f4 = vaddq_f32(_f4, _cc2); - _f5 = vaddq_f32(_f5, _cc2); - _f6 = vaddq_f32(_f6, _cc3); - _f7 = vaddq_f32(_f7, _cc3); - } - if (broadcast_type_C == 3) + if (c_elempack == 4) { - if (c_elempack == 1) - { - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + c_hstep); - float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); - float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); - float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); - float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); - float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } - pC += 8; - } - else // if (c_elempack == 4) - { - float32x4x4_t _cc0 = vld4q_f32(pC); - float32x4x4_t _cc1 = vld4q_f32(pC + 16); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _cc0.val[0]); - _f1 = vaddq_f32(_f1, _cc1.val[0]); - _f2 = vaddq_f32(_f2, _cc0.val[1]); - _f3 = vaddq_f32(_f3, _cc1.val[1]); - _f4 = vaddq_f32(_f4, _cc0.val[2]); - _f5 = vaddq_f32(_f5, _cc1.val[2]); - _f6 = vaddq_f32(_f6, _cc0.val[3]); - _f7 = vaddq_f32(_f7, _cc1.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _cc0.val[0], _beta); - _f1 = vmlaq_f32(_f1, _cc1.val[0], _beta); - _f2 = vmlaq_f32(_f2, _cc0.val[1], _beta); - _f3 = vmlaq_f32(_f3, _cc1.val[1], _beta); - _f4 = vmlaq_f32(_f4, _cc0.val[2], _beta); - _f5 = vmlaq_f32(_f5, _cc1.val[2], _beta); - _f6 = vmlaq_f32(_f6, _cc0.val[3], _beta); - _f7 = vmlaq_f32(_f7, _cc1.val[3], _beta); - } - pC += 32; - } + _c0 = vld1q_f32(pC); + pC += 4; } - if (broadcast_type_C == 4) + if (c_elempack == 1) { - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + 4); - if (beta != 1.f) - { - float32x4_t _beta = vdupq_n_f32(beta); - _c0 = vmulq_f32(_c0, _beta); - _c1 = vmulq_f32(_c1, _beta); - } - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c1); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c1); - pC += 8; + _c0 = vsetq_lane_f32(pC[0], _c0, 0); + _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); + _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); + _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); + pC += 1; } + _f0 = vmlaq_n_f32(_f0, _c0, beta); } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); + _c0 = vdupq_n_f32(pC[0] * beta); + _f0 = vaddq_f32(_f0, _c0); + pC += 1; } + } - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); - vst1q_f32(p0 + out_hstep, _f2); - vst1q_f32(p0 + out_hstep + 4, _f3); - vst1q_f32(p0 + out_hstep * 2, _f4); - vst1q_f32(p0 + out_hstep * 2 + 4, _f5); - 
vst1q_f32(p0 + out_hstep * 3, _f6); - vst1q_f32(p0 + out_hstep * 3 + 4, _f7); + _f0 = vmulq_n_f32(_f0, alpha); - pp += 32; - p0 += 8; + if (out_elempack == 4) + { + vst1q_f32(p0, _f0); + p0 += 4; } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) + if (out_elempack == 1) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + p0[0] = vgetq_lane_f32(_f0, 0); + p0[out_hstep] = vgetq_lane_f32(_f0, 1); + p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); + p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); + p0++; + } -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 c1 c2 c3 - // d0 d1 d2 d3 - { - int32x4x2_t _r01 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _r23 = vzipq_s32(_sum2, _sum3); - _sum0 = vcombine_s32(vget_low_s32(_r01.val[0]), vget_low_s32(_r23.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_r01.val[0]), vget_high_s32(_r23.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_r01.val[1]), vget_low_s32(_r23.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_r01.val[1]), vget_high_s32(_r23.val[1])); - } -#else - // from - // a0 b1 c2 d3 - // c0 d1 a2 b3 - // a3 b2 c1 d0 - // c3 d2 a1 b0 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 c1 c2 c3 - // d0 d1 d2 d3 - { - _sum1 = vextq_s32(_sum1, _sum1, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); - int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - } -#endif // __ARM_FEATURE_DOTPROD + pp += 4; + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + // out_elempack == 1 + float* p0 = (float*)top_blob + (i + ii) * out_hstep + j; -#if __aarch64__ - float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale, 1); - float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale, 2); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale, 3); -#else - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), vget_low_f32(_descale), 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), vget_low_f32(_descale), 1); - float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), vget_high_f32(_descale), 0); - float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), vget_high_f32(_descale), 1); + const float descale0 = descales[ii]; + const float descale1 = descales[ii + 1]; +#if __ARM_NEON + float32x2_t _descale = vld1_f32((const float*)descales + ii); +#endif + + float c0; + float c1; +#if __ARM_NEON + float32x4_t _c0; + float32x4_t _c1; +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = pC[0] * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + c0 = pC[0] * beta; + c1 = pC[1] * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); + _c1 = vdupq_n_f32(c1); #endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const float*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + int jj = 0; +#if 
__ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); - if (pC) + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale, 0); + float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), _descale, 1); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), _descale, 1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - if (broadcast_type_C == 0) + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep); + float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { -#if __aarch64__ - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); -#else - float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); - float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); - float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); - float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); -#endif - _f0 = vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc1); - _f2 = vaddq_f32(_f2, _cc2); - _f3 = vaddq_f32(_f3, _cc3); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } - if (broadcast_type_C == 3) + else { - if (c_elempack == 1) - { - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + c_hstep * 1); - float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); - float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 4; - } - else // if (c_elempack == 4) - { - float32x4x4_t _c = vld4q_f32(pC); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c.val[0]); - _f1 = vaddq_f32(_f1, _c.val[1]); - _f2 = vaddq_f32(_f2, _c.val[2]); - _f3 = vaddq_f32(_f3, _c.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c.val[0], _beta); - _f1 = vmlaq_f32(_f1, _c.val[1], _beta); - _f2 = vmlaq_f32(_f2, _c.val[2], _beta); - _f3 = vmlaq_f32(_f3, _c.val[3], _beta); - } - pC += 16; - } + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } - if (broadcast_type_C == 4) + pC += 8; + } + if (broadcast_type_C == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + if (beta != 1.f) { - _c0 = vld1q_f32(pC); - _c0 = vmulq_n_f32(_c0, beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - pC += 4; + float32x4_t _beta 
= vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); } + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 8; } + } - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } - vst1q_f32(p0, _f0); - vst1q_f32(p0 + out_hstep, _f1); - vst1q_f32(p0 + out_hstep * 2, _f2); - vst1q_f32(p0 + out_hstep * 3, _f3); + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + out_hstep, _f2); + vst1q_f32(p0 + out_hstep + 4, _f3); - pp += 16; - p0 += 4; - } - for (; jj + 1 < max_jj; jj += 2) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); + pp += 16; + p0 += 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale, 1); - // to - // a0 a1 b0 b1 - // c0 c1 d0 d1 + if (pC) + { + if (broadcast_type_C == 0) { - int32x4x2_t _sum01 = vzipq_s32(_sum0, _sum1); - _sum0 = _sum01.val[0]; - _sum1 = _sum01.val[1]; + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); } -#else - // from - // a0 b1 c0 d1 - // a1 b0 c1 d0 - - // to - // a0 a1 b0 b1 - // c0 c1 d0 d1 + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - int32x4_t _t0 = vuzpq_s32(_sum0, _sum1).val[0]; - int32x4_t _t1 = vuzpq_s32(_sum1, _sum0).val[1]; - int32x4x2_t _t3 = vuzpq_s32(_t0, _t1); - _sum0 = _t3.val[0]; - _sum1 = _t3.val[1]; + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); } -#endif // __ARM_FEATURE_DOTPROD - - float32x4x2_t _descale01 = vzipq_f32(_descale, _descale); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale01.val[0]); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale01.val[1]); - - if (pC) + if (broadcast_type_C == 3) { - if (broadcast_type_C == 0) + // c_elempack == 1 + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep); + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - float32x4x2_t _cc0 = vzipq_f32(_c0, _c0); - _f0 = vaddq_f32(_f0, _cc0.val[0]); - _f1 = vaddq_f32(_f1, _cc0.val[1]); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) - { - float32x2_t _cc0 = vld1_f32(pC); - float32x2_t _cc1 = vld1_f32(pC + c_hstep); - float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); - float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); - _c0 = vcombine_f32(_cc0, _cc1); - float32x4_t _c1 = vcombine_f32(_cc2, _cc3); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - pC += 2; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + 4); - float32x4x2_t _c01 = vzipq_f32(_c0, _c1); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c01.val[0]); - _f1 = vaddq_f32(_f1, _c01.val[1]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c01.val[0], 
_beta); - _f1 = vmlaq_f32(_f1, _c01.val[1], _beta); - } - pC += 8; - } + _f1 = vaddq_f32(_f1, _c1); } - if (broadcast_type_C == 4) + else { - float32x2_t _cc0 = vld1_f32(pC); - _cc0 = vmul_n_f32(_cc0, beta); - _c0 = vcombine_f32(_cc0, _cc0); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - pC += 2; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } + pC += 4; } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); + _c0 = vld1q_f32(pC); + _c0 = vmulq_n_f32(_c0, beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 4; } - - vst1_f32(p0, vget_low_f32(_f0)); - vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); - vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f1)); - vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f1)); - - pp += 8; - p0 += 2; } - for (; jj < max_jj; jj++) + + if (alpha != 1.f) { - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep, _f1); + + pp += 8; + p0 += 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); - if (pC) + float32x2x2_t _descale01 = vzip_f32(_descale, _descale); + float32x4_t _descale0011 = vcombine_f32(_descale01.val[0], _descale01.val[1]); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0011); + + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) - { - _c0 = vsetq_lane_f32(pC[0], _c0, 0); - _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); - _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); - _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); - pC += 1; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - pC += 4; - } - _f0 = vmlaq_n_f32(_f0, _c0, beta); - } - if (broadcast_type_C == 4) - { - _c0 = vdupq_n_f32(pC[0] * beta); - _f0 = vaddq_f32(_f0, _c0); - pC += 1; - } + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4_t _c0011 = vcombine_f32(vget_low_f32(_c0), vget_high_f32(_c1)); + _f0 = vaddq_f32(_f0, _c0011); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = vcombine_f32(vld1_f32(pC), vld1_f32(pC + c_hstep)); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; + } + if (broadcast_type_C == 4) + { + float32x2_t _c = vld1_f32(pC); + _c0 = vcombine_f32(_c, _c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; } + } - _f0 = vmulq_n_f32(_f0, alpha); + _f0 = vmulq_n_f32(_f0, alpha); - p0[0] = vgetq_lane_f32(_f0, 0); - p0[out_hstep] = vgetq_lane_f32(_f0, 1); - p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); - p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); + vst1_f32(p0, vget_low_f32(_f0)); + vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); - pp += 4; - p0++; + pp += 4; + p0 += 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj++) + { + float f0 = pp[0] * descale0; + float f1 = pp[1] * descale1; + + if (pC) + { + if (broadcast_type_C == 0) + { + f0 += c0; + f1 += c0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + f1 += c1; + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + f0 += pC[0] * beta; + f1 += pC[c_hstep] * beta; + pC += 1; + } + if (broadcast_type_C == 4) + { + 
f0 += pC[0] * beta; + f1 += pC[0] * beta; + pC += 1; + } } + + f0 *= alpha; + f1 *= alpha; + + p0[0] = f0; + p0[out_hstep] = f1; + + pp += 2; + p0++; } } -#endif // __ARM_NEON - for (; ii + 1 < max_ii; ii += 2) + for (; ii < max_ii; ii += 1) { // out_elempack == 1 float* p0 = (float*)top_blob + (i + ii) * out_hstep + j; - const float descale0 = descales[ii]; - const float descale1 = descales[ii + 1]; + const float descale = descales[ii]; #if __ARM_NEON - float32x2_t _descale = vld1_f32((const float*)descales + ii); + float32x4_t _descale = vdupq_n_f32(descale); #endif float c0; - float c1; #if __ARM_NEON float32x4_t _c0; - float32x4_t _c1; #endif if (pC) { @@ -8655,10 +7489,8 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& { pC = (const float*)C + i + ii; c0 = pC[0] * beta; - c1 = pC[1] * beta; #if __ARM_NEON _c0 = vdupq_n_f32(c0); - _c1 = vdupq_n_f32(c1); #endif } if (broadcast_type_C == 3) @@ -8672,484 +7504,197 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& } } - // if (out_elempack == 1) - { - int jj = 0; + int jj = 0; #if __ARM_NEON -#if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + for (; jj + 15 < max_jj; jj += 16) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale, 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale, 0); - float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), _descale, 1); - float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), _descale, 1); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 8); + float32x4_t _c3 = vld1q_f32(pC + 12); + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c1); - _f3 = vaddq_f32(_f3, _c1); - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + c_hstep); - float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 8; + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } - if (broadcast_type_C == 4) + else { - _c0 = 
vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - if (beta != 1.f) - { - float32x4_t _beta = vdupq_n_f32(beta); - _c0 = vmulq_f32(_c0, _beta); - _c1 = vmulq_f32(_c1, _beta); - } - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c1); - pC += 8; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } + pC += 16; } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - } - - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); - vst1q_f32(p0 + out_hstep, _f2); - vst1q_f32(p0 + out_hstep + 4, _f3); - - pp += 16; - p0 += 8; } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) + + if (alpha != 1.f) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale, 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale, 1); + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + 8, _f2); + vst1q_f32(p0 + 12, _f3); - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - pC += 4; - } - if (broadcast_type_C == 4) - { - _c0 = vld1q_f32(pC); - _c0 = vmulq_n_f32(_c0, beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - pC += 4; - } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - } + pp += 16; + p0 += 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); - vst1q_f32(p0, _f0); - vst1q_f32(p0 + out_hstep, _f1); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - pp += 8; - p0 += 4; - } - for (; jj + 1 < max_jj; jj += 2) + if (pC) { - int32x4_t _sum0 = vld1q_s32(pp); - - float32x2x2_t _descale01 = vzip_f32(_descale, _descale); - float32x4_t _descale0011 = vcombine_f32(_descale01.val[0], _descale01.val[1]); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0011); - - if (pC) + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - float32x4_t _c0011 = vcombine_f32(vget_low_f32(_c0), vget_high_f32(_c1)); - _f0 = vaddq_f32(_f0, _c0011); - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - _c0 = vcombine_f32(vld1_f32(pC), vld1_f32(pC + c_hstep)); - _f0 = vmlaq_n_f32(_f0, _c0, beta); - pC += 2; - } - if (broadcast_type_C == 4) - { - float32x2_t _c = vld1_f32(pC); - _c0 = vcombine_f32(_c, _c); - _f0 = vmlaq_n_f32(_f0, 
_c0, beta); - pC += 2; - } + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); } - - _f0 = vmulq_n_f32(_f0, alpha); - - vst1_f32(p0, vget_low_f32(_f0)); - vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); - - pp += 4; - p0 += 2; - } -#endif // __ARM_NEON - for (; jj < max_jj; jj++) - { - float f0 = pp[0] * descale0; - float f1 = pp[1] * descale1; - - if (pC) + if (broadcast_type_C == 3 || broadcast_type_C == 4) { - if (broadcast_type_C == 0) - { - f0 += c0; - f1 += c0; - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + // out_elempack == 1 + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + if (beta == 1.f) { - f0 += c0; - f1 += c1; - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - f0 += pC[0] * beta; - f1 += pC[c_hstep] * beta; - pC += 1; + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); } - if (broadcast_type_C == 4) + else { - f0 += pC[0] * beta; - f1 += pC[0] * beta; - pC += 1; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } + pC += 8; } - - f0 *= alpha; - f1 *= alpha; - - p0[0] = f0; - p0[out_hstep] = f1; - - pp += 2; - p0++; } - } - } - for (; ii < max_ii; ii += 1) - { - // out_elempack == 1 - float* p0 = (float*)top_blob + (i + ii) * out_hstep + j; - - const float descale = descales[ii]; -#if __ARM_NEON - float32x4_t _descale = vdupq_n_f32(descale); -#endif - float c0; -#if __ARM_NEON - float32x4_t _c0; -#endif - if (pC) - { - if (broadcast_type_C == 0) - { - c0 = pC[0] * beta; -#if __ARM_NEON - _c0 = vdupq_n_f32(c0); -#endif - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - pC = (const float*)C + i + ii; - c0 = pC[0] * beta; -#if __ARM_NEON - _c0 = vdupq_n_f32(c0); -#endif - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - pC = (const float*)C + (i + ii) * c_hstep + j; - } - if (broadcast_type_C == 4) + if (alpha != 1.f) { - pC = (const float*)C + j; + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); } - } - // if (out_elempack == 1) - { - int jj = 0; -#if __ARM_NEON - for (; jj + 15 < max_jj; jj += 16) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + pp += 8; + p0 += 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // out_elempack == 1 - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + 8); - float32x4_t _c3 = vld1q_f32(pC + 12); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, 
_beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 16; - } + _f0 = vaddq_f32(_f0, _c0); } - - if (alpha != 1.f) + if (broadcast_type_C == 3 || broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); + // out_elempack == 1 + _c0 = vld1q_f32(pC); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 4; } + } - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); - vst1q_f32(p0 + 8, _f2); - vst1q_f32(p0 + 12, _f3); + _f0 = vmulq_n_f32(_f0, alpha); - pp += 16; - p0 += 16; - } - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); + vst1q_f32(p0, _f0); - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + pp += 4; + p0 += 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + float32x2_t _f0 = vmul_f32(vcvt_f32_s32(vld1_s32(pp)), vget_low_f32(_descale)); - if (pC) + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // out_elempack == 1 - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + 4); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - pC += 8; - } + _f0 = vadd_f32(_f0, vget_low_f32(_c0)); } - - if (alpha != 1.f) + if (broadcast_type_C == 3 || broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); + // out_elempack == 1 + float32x2_t _c = vld1_f32(pC); + _f0 = vmla_n_f32(_f0, _c, beta); + pC += 2; } - - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); - - pp += 8; - p0 += 8; } - for (; jj + 3 < max_jj; jj += 4) - { - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); - if (pC) - { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // out_elempack == 1 - _c0 = vld1q_f32(pC); - _f0 = vmlaq_n_f32(_f0, _c0, beta); - pC += 4; - } - } + _f0 = vmul_n_f32(_f0, alpha); - _f0 = vmulq_n_f32(_f0, alpha); + vst1_f32(p0, _f0); - vst1q_f32(p0, _f0); + pp += 2; + p0 += 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj++) + { + float f0 = pp[0] * descale; - pp += 4; - p0 += 4; - } - for (; jj + 1 < max_jj; jj += 2) + if (pC) { - float32x2_t _f0 = vmul_f32(vcvt_f32_s32(vld1_s32(pp)), vget_low_f32(_descale)); - - if (pC) + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vadd_f32(_f0, vget_low_f32(_c0)); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // out_elempack == 1 - float32x2_t _c = vld1_f32(pC); - _f0 = vmla_n_f32(_f0, _c, beta); - pC += 2; - } + f0 += c0; } - - _f0 = vmul_n_f32(_f0, alpha); - - vst1_f32(p0, _f0); - - pp += 2; - p0 += 2; - } -#endif // __ARM_NEON - for (; jj < max_jj; jj++) - { - float f0 = pp[0] * descale; - - if (pC) + if (broadcast_type_C == 3 || broadcast_type_C == 4) { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - f0 += c0; - } - 
if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // out_elempack == 1 - f0 += pC[0] * beta; - pC += 1; - } + // out_elempack == 1 + f0 += pC[0] * beta; + pC += 1; } + } - f0 *= alpha; + f0 *= alpha; - p0[0] = f0; + p0[0] = f0; - pp += 1; - p0++; - } + pp += 1; + p0++; } } } @@ -9210,1126 +7755,415 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma } } - if (out_elempack == 4) - { - int jj = 0; + int jj = 0; #if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); - int32x4_t _sum8 = vld1q_s32(pp + 32); - int32x4_t _sum9 = vld1q_s32(pp + 36); - int32x4_t _suma = vld1q_s32(pp + 40); - int32x4_t _sumb = vld1q_s32(pp + 44); - int32x4_t _sumc = vld1q_s32(pp + 48); - int32x4_t _sumd = vld1q_s32(pp + 52); - int32x4_t _sume = vld1q_s32(pp + 56); - int32x4_t _sumf = vld1q_s32(pp + 60); - -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 - // e0 e1 e2 e3 - // e4 e5 e6 e7 - // f0 f1 f2 f3 - // f4 f5 f6 f7 - // g0 g1 g2 g3 - // g4 g5 g6 g7 - // h0 h1 h2 h3 - // h4 h5 h6 h7 - { - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); - int32x4x2_t _t2 = vzipq_s32(_sum8, _sum9); - int32x4x2_t _t3 = vzipq_s32(_suma, _sumb); - int32x4x2_t _t4 = vzipq_s32(_sum4, _sum5); - int32x4x2_t _t5 = vzipq_s32(_sum6, _sum7); - int32x4x2_t _t6 = vzipq_s32(_sumc, _sumd); - int32x4x2_t _t7 = vzipq_s32(_sume, _sumf); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); - _sum6 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); - _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); - _sum9 = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); - _suma = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); - _sumb = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); - _sumc = vcombine_s32(vget_low_s32(_t4.val[1]), vget_low_s32(_t5.val[1])); - _sumd = vcombine_s32(vget_low_s32(_t6.val[1]), vget_low_s32(_t7.val[1])); - _sume = vcombine_s32(vget_high_s32(_t4.val[1]), vget_high_s32(_t5.val[1])); - _sumf = vcombine_s32(vget_high_s32(_t6.val[1]), vget_high_s32(_t7.val[1])); - } -#else // __ARM_FEATURE_DOTPROD - - // from - // a0 b1 c2 d3 - // e4 f5 g6 h7 - // e0 f1 g2 h3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // g4 h5 e6 f7 - // g0 h1 e2 f3 - // c4 d5 a6 b7 - // a3 b2 c1 d0 - // e7 f6 g5 h4 - // 
e3 f2 g1 h0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // g7 h6 e5 f4 - // g3 h2 e1 f0 - // c7 d6 a5 b4 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 - // e0 e1 e2 e3 - // e4 e5 e6 e7 - // f0 f1 f2 f3 - // f4 f5 f6 f7 - // g0 g1 g2 g3 - // g4 g5 g6 g7 - // h0 h1 h2 h3 - // h4 h5 h6 h7 - { - _sum4 = vextq_s32(_sum4, _sum4, 2); - _sum5 = vextq_s32(_sum5, _sum5, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - _sumc = vextq_s32(_sumc, _sumc, 2); - _sumd = vextq_s32(_sumd, _sumd, 2); - _sume = vextq_s32(_sume, _sume, 2); - _sumf = vextq_s32(_sumf, _sumf, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); - int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); - int32x4x2_t _t2 = vzipq_s32(_sum3, _sumf); - int32x4x2_t _t3 = vzipq_s32(_sum7, _sumb); - int32x4x2_t _t4 = vzipq_s32(_sum2, _sume); - int32x4x2_t _t5 = vzipq_s32(_sum6, _suma); - int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); - int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum6 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); - _sum9 = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); - _suma = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); - _sumb = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); - _sumc = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); - _sumd = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); - _sume = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); - _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); - _suma = vrev64q_s32(_suma); - _sumb = vrev64q_s32(_sumb); - _sume = vrev64q_s32(_sume); - _sumf = vrev64q_s32(_sumf); - } -#endif // __ARM_FEATURE_DOTPROD - - float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale0, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale0, 0); - float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale0, 1); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale0, 1); - float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale0, 2); - float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale0, 2); - float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale0, 3); - float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale0, 3); - float32x4_t _f8 = vmulq_laneq_f32(vcvtq_f32_s32(_sum8), _descale1, 0); - float32x4_t _f9 = vmulq_laneq_f32(vcvtq_f32_s32(_sum9), _descale1, 0); - float32x4_t _fa = vmulq_laneq_f32(vcvtq_f32_s32(_suma), _descale1, 1); - float32x4_t _fb = vmulq_laneq_f32(vcvtq_f32_s32(_sumb), _descale1, 1); - float32x4_t _fc = vmulq_laneq_f32(vcvtq_f32_s32(_sumc), _descale1, 2); - float32x4_t _fd = 
vmulq_laneq_f32(vcvtq_f32_s32(_sumd), _descale1, 2); - float32x4_t _fe = vmulq_laneq_f32(vcvtq_f32_s32(_sume), _descale1, 3); - float32x4_t _ff = vmulq_laneq_f32(vcvtq_f32_s32(_sumf), _descale1, 3); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c0); - _fa = vaddq_f32(_fa, _c0); - _fb = vaddq_f32(_fb, _c0); - _fc = vaddq_f32(_fc, _c0); - _fd = vaddq_f32(_fd, _c0); - _fe = vaddq_f32(_fe, _c0); - _ff = vaddq_f32(_ff, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); - _f0 = vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc0); - _f2 = vaddq_f32(_f2, _cc1); - _f3 = vaddq_f32(_f3, _cc1); - _f4 = vaddq_f32(_f4, _cc2); - _f5 = vaddq_f32(_f5, _cc2); - _f6 = vaddq_f32(_f6, _cc3); - _f7 = vaddq_f32(_f7, _cc3); - _cc0 = vdupq_laneq_f32(_c1, 0); - _cc1 = vdupq_laneq_f32(_c1, 1); - _cc2 = vdupq_laneq_f32(_c1, 2); - _cc3 = vdupq_laneq_f32(_c1, 3); - _f8 = vaddq_f32(_f8, _cc0); - _f9 = vaddq_f32(_f9, _cc0); - _fa = vaddq_f32(_fa, _cc1); - _fb = vaddq_f32(_fb, _cc1); - _fc = vaddq_f32(_fc, _cc2); - _fd = vaddq_f32(_fd, _cc2); - _fe = vaddq_f32(_fe, _cc3); - _ff = vaddq_f32(_ff, _cc3); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + c_hstep); - float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); - float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); - float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); - float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); - float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep * 4 + 4); - _c2 = vld1q_f32(pC + c_hstep * 5); - _c3 = vld1q_f32(pC + c_hstep * 5 + 4); - _c4 = vld1q_f32(pC + c_hstep * 6); - _c5 = vld1q_f32(pC + c_hstep * 6 + 4); - _c6 = vld1q_f32(pC + c_hstep * 7); - _c7 = vld1q_f32(pC + c_hstep * 7 + 4); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _c0, _beta); - _f9 = vmlaq_f32(_f9, _c1, _beta); - _fa = vmlaq_f32(_fa, _c2, _beta); - _fb = vmlaq_f32(_fb, _c3, _beta); - _fc = vmlaq_f32(_fc, _c4, _beta); - _fd = vmlaq_f32(_fd, _c5, _beta); - _fe = vmlaq_f32(_fe, _c6, _beta); - _ff = vmlaq_f32(_ff, _c7, _beta); - } - pC += 8; - } - else // if (c_elempack == 4) - { 
- float32x4x4_t _cc0 = vld4q_f32(pC); - float32x4x4_t _cc1 = vld4q_f32(pC + 16); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _cc0.val[0]); - _f1 = vaddq_f32(_f1, _cc1.val[0]); - _f2 = vaddq_f32(_f2, _cc0.val[1]); - _f3 = vaddq_f32(_f3, _cc1.val[1]); - _f4 = vaddq_f32(_f4, _cc0.val[2]); - _f5 = vaddq_f32(_f5, _cc1.val[2]); - _f6 = vaddq_f32(_f6, _cc0.val[3]); - _f7 = vaddq_f32(_f7, _cc1.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _cc0.val[0], _beta); - _f1 = vmlaq_f32(_f1, _cc1.val[0], _beta); - _f2 = vmlaq_f32(_f2, _cc0.val[1], _beta); - _f3 = vmlaq_f32(_f3, _cc1.val[1], _beta); - _f4 = vmlaq_f32(_f4, _cc0.val[2], _beta); - _f5 = vmlaq_f32(_f5, _cc1.val[2], _beta); - _f6 = vmlaq_f32(_f6, _cc0.val[3], _beta); - _f7 = vmlaq_f32(_f7, _cc1.val[3], _beta); - } - _cc0 = vld4q_f32(pC + c_hstep * 4); - _cc1 = vld4q_f32(pC + c_hstep * 4 + 16); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _cc0.val[0]); - _f9 = vaddq_f32(_f9, _cc1.val[0]); - _fa = vaddq_f32(_fa, _cc0.val[1]); - _fb = vaddq_f32(_fb, _cc1.val[1]); - _fc = vaddq_f32(_fc, _cc0.val[2]); - _fd = vaddq_f32(_fd, _cc1.val[2]); - _fe = vaddq_f32(_fe, _cc0.val[3]); - _ff = vaddq_f32(_ff, _cc1.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _cc0.val[0], _beta); - _f9 = vmlaq_f32(_f9, _cc1.val[0], _beta); - _fa = vmlaq_f32(_fa, _cc0.val[1], _beta); - _fb = vmlaq_f32(_fb, _cc1.val[1], _beta); - _fc = vmlaq_f32(_fc, _cc0.val[2], _beta); - _fd = vmlaq_f32(_fd, _cc1.val[2], _beta); - _fe = vmlaq_f32(_fe, _cc0.val[3], _beta); - _ff = vmlaq_f32(_ff, _cc1.val[3], _beta); - } - pC += 32; - } - } - if (broadcast_type_C == 4) + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + int32x4_t _sum8 = vld1q_s32(pp + 32); + int32x4_t _sum9 = vld1q_s32(pp + 36); + int32x4_t _suma = vld1q_s32(pp + 40); + int32x4_t _sumb = vld1q_s32(pp + 44); + int32x4_t _sumc = vld1q_s32(pp + 48); + int32x4_t _sumd = vld1q_s32(pp + 52); + int32x4_t _sume = vld1q_s32(pp + 56); + int32x4_t _sumf = vld1q_s32(pp + 60); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 +#else + // from + // a0 b1 c2 d3 + // e4 f5 g6 h7 + // e0 f1 g2 h3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // g4 h5 e6 f7 + // g0 h1 e2 f3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // e7 f6 g5 h4 + // e3 f2 g1 h0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // g7 h6 e5 f4 + // g3 h2 e1 f0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + { + _sum8 = vrev64q_s32(_sum8); + _sum9 = vrev64q_s32(_sum9); + _suma = vrev64q_s32(_suma); + _sumb = vrev64q_s32(_sumb); + _sumc = vrev64q_s32(_sumc); + _sumd = vrev64q_s32(_sumd); + _sume = vrev64q_s32(_sume); + _sumf = vrev64q_s32(_sumf); + _sum8 = vextq_s32(_sum8, _sum8, 2); + _sum9 = vextq_s32(_sum9, _sum9, 2); + 
_suma = vextq_s32(_suma, _suma, 2); + _sumb = vextq_s32(_sumb, _sumb, 2); + _sumc = vextq_s32(_sumc, _sumc, 2); + _sumd = vextq_s32(_sumd, _sumd, 2); + _sume = vextq_s32(_sume, _sume, 2); + _sumf = vextq_s32(_sumf, _sumf, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); + int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); + int32x4x2_t _t2 = vzipq_s32(_sum2, _sume); + int32x4x2_t _t3 = vzipq_s32(_sum6, _suma); + int32x4x2_t _t4 = vzipq_s32(_sum3, _sumf); + int32x4x2_t _t5 = vzipq_s32(_sum7, _sumb); + int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum9 = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); + _suma = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); + _sumb = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); + _sumc = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); + _sumd = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sume = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + _sum9 = vrev64q_s32(_sum9); + _sumb = vrev64q_s32(_sumb); + _sumd = vrev64q_s32(_sumd); + _sumf = vrev64q_s32(_sumf); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale0); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale0); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_suma), _descale0); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sumb), _descale0); + float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); + float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); + float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); + float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c0); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c0); + _fc = vaddq_f32(_fc, _c0); + _fd = 
vaddq_f32(_fd, _c0); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c1); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c1); + _fb = vaddq_f32(_fb, _c1); + _fc = vaddq_f32(_fc, _c1); + _fd = vaddq_f32(_fd, _c1); + _fe = vaddq_f32(_fe, _c1); + _ff = vaddq_f32(_ff, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 4) { _c0 = vld1q_f32(pC); _c1 = vld1q_f32(pC + 4); - if (beta != 1.f) + float32x4_t _c2 = vld1q_f32(pC + 8); + float32x4_t _c3 = vld1q_f32(pC + 12); + float32x4_t _c4 = vld1q_f32(pC + 16); + float32x4_t _c5 = vld1q_f32(pC + 20); + float32x4_t _c6 = vld1q_f32(pC + 24); + float32x4_t _c7 = vld1q_f32(pC + 28); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else { float32x4_t _beta = vdupq_n_f32(beta); - _c0 = vmulq_f32(_c0, _beta); - _c1 = vmulq_f32(_c1, _beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); } - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c1); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c1); - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c0); - _fb = vaddq_f32(_fb, _c1); - _fc = vaddq_f32(_fc, _c0); - _fd = vaddq_f32(_fd, _c1); - _fe = vaddq_f32(_fe, _c0); - _ff = vaddq_f32(_ff, _c1); - pC += 8; - } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - _f8 = vmulq_f32(_f8, _alpha); - _f9 = vmulq_f32(_f9, _alpha); - _fa = vmulq_f32(_fa, _alpha); - _fb = vmulq_f32(_fb, _alpha); - _fc = vmulq_f32(_fc, _alpha); - _fd = vmulq_f32(_fd, _alpha); - _fe = vmulq_f32(_fe, _alpha); - _ff = vmulq_f32(_ff, _alpha); - } - - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f2); - vst1q_f32(p0 + 8, _f4); - vst1q_f32(p0 + 12, _f6); - vst1q_f32(p0 + 16, _f8); - vst1q_f32(p0 + 20, _fa); - vst1q_f32(p0 + 24, _fc); - vst1q_f32(p0 + 28, _fe); - vst1q_f32(p0 + out_hstep * 4, _f1); - vst1q_f32(p0 + out_hstep * 4 + 4, _f3); - vst1q_f32(p0 + out_hstep * 4 + 8, _f5); - vst1q_f32(p0 + out_hstep * 4 + 12, _f7); - vst1q_f32(p0 + out_hstep * 4 + 16, _f9); - vst1q_f32(p0 + out_hstep * 4 + 20, _fb); - vst1q_f32(p0 + out_hstep * 4 + 24, _fd); - vst1q_f32(p0 + out_hstep * 4 + 28, _ff); - pp += 64; - p0 += out_hstep * 8; - } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t 
_sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); - -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 c1 c2 c3 - // d0 d1 d2 d3 - // e0 e1 e2 e3 - // f0 f1 f2 f3 - // g0 g1 g2 g3 - // h0 h1 h2 h3 - { - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); - int32x4x2_t _t2 = vzipq_s32(_sum4, _sum5); - int32x4x2_t _t3 = vzipq_s32(_sum6, _sum7); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); - } -#else // __ARM_FEATURE_DOTPROD - - // from - // a0 b1 c2 d3 - // e0 f1 g2 h3 - // c0 d1 a2 b3 - // g0 h1 e2 f3 - // a3 b2 c1 d0 - // e3 f2 g1 h0 - // c3 d2 a1 b0 - // g3 h2 e1 f0 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 c1 c2 c3 - // d0 d1 d2 d3 - // e0 e1 e2 e3 - // f0 f1 f2 f3 - // g0 g1 g2 g3 - // h0 h1 h2 h3 - { - _sum2 = vextq_s32(_sum2, _sum2, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); - } -#endif // __ARM_FEATURE_DOTPROD - -#if __aarch64__ - float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale0, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale0, 1); - float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale0, 2); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale0, 3); - float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale1, 0); - float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale1, 1); - float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale1, 2); - float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale1, 3); -#else - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), vget_low_f32(_descale0), 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), vget_low_f32(_descale0), 1); - float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), vget_high_f32(_descale0), 0); - float32x4_t _f3 
= vmulq_lane_f32(vcvtq_f32_s32(_sum3), vget_high_f32(_descale0), 1); - float32x4_t _f4 = vmulq_lane_f32(vcvtq_f32_s32(_sum4), vget_low_f32(_descale1), 0); - float32x4_t _f5 = vmulq_lane_f32(vcvtq_f32_s32(_sum5), vget_low_f32(_descale1), 1); - float32x4_t _f6 = vmulq_lane_f32(vcvtq_f32_s32(_sum6), vget_high_f32(_descale1), 0); - float32x4_t _f7 = vmulq_lane_f32(vcvtq_f32_s32(_sum7), vget_high_f32(_descale1), 1); -#endif - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { -#if __aarch64__ - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); -#else - float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); - float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); - float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); - float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); -#endif - _f0 = vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc1); - _f2 = vaddq_f32(_f2, _cc2); - _f3 = vaddq_f32(_f3, _cc3); -#if __aarch64__ - _cc0 = vdupq_laneq_f32(_c1, 0); - _cc1 = vdupq_laneq_f32(_c1, 1); - _cc2 = vdupq_laneq_f32(_c1, 2); - _cc3 = vdupq_laneq_f32(_c1, 3); -#else - _cc0 = vdupq_lane_f32(vget_low_f32(_c1), 0); - _cc1 = vdupq_lane_f32(vget_low_f32(_c1), 1); - _cc2 = vdupq_lane_f32(vget_high_f32(_c1), 0); - _cc3 = vdupq_lane_f32(vget_high_f32(_c1), 1); -#endif - _f4 = vaddq_f32(_f4, _cc0); - _f5 = vaddq_f32(_f5, _cc1); - _f6 = vaddq_f32(_f6, _cc2); - _f7 = vaddq_f32(_f7, _cc3); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 4 + 4); + _c2 = vld1q_f32(pC + c_hstep * 4 + 8); + _c3 = vld1q_f32(pC + c_hstep * 4 + 12); + _c4 = vld1q_f32(pC + c_hstep * 4 + 16); + _c5 = vld1q_f32(pC + c_hstep * 4 + 20); + _c6 = vld1q_f32(pC + c_hstep * 4 + 24); + _c7 = vld1q_f32(pC + c_hstep * 4 + 28); + if (beta == 1.f) { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep); - float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); - float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep * 5); - _c2 = vld1q_f32(pC + c_hstep * 6); - _c3 = vld1q_f32(pC + c_hstep * 7); - if (beta == 1.f) - { - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c2); - _f7 = vaddq_f32(_f7, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f4 = vmlaq_f32(_f4, _c0, _beta); - _f5 = vmlaq_f32(_f5, _c1, _beta); - _f6 = vmlaq_f32(_f6, _c2, _beta); - _f7 = vmlaq_f32(_f7, _c3, _beta); - } - pC += 4; + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); } - else // if (c_elempack == 4) + else { - float32x4x4_t _cc0 = vld4q_f32(pC); - if (beta == 
1.f) - { - _f0 = vaddq_f32(_f0, _cc0.val[0]); - _f1 = vaddq_f32(_f1, _cc0.val[1]); - _f2 = vaddq_f32(_f2, _cc0.val[2]); - _f3 = vaddq_f32(_f3, _cc0.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _cc0.val[0], _beta); - _f1 = vmlaq_f32(_f1, _cc0.val[1], _beta); - _f2 = vmlaq_f32(_f2, _cc0.val[2], _beta); - _f3 = vmlaq_f32(_f3, _cc0.val[3], _beta); - } - _cc0 = vld4q_f32(pC + c_hstep * 4); - if (beta == 1.f) - { - _f4 = vaddq_f32(_f4, _cc0.val[0]); - _f5 = vaddq_f32(_f5, _cc0.val[1]); - _f6 = vaddq_f32(_f6, _cc0.val[2]); - _f7 = vaddq_f32(_f7, _cc0.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f4 = vmlaq_f32(_f4, _cc0.val[0], _beta); - _f5 = vmlaq_f32(_f5, _cc0.val[1], _beta); - _f6 = vmlaq_f32(_f6, _cc0.val[2], _beta); - _f7 = vmlaq_f32(_f7, _cc0.val[3], _beta); - } - pC += 16; + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); } + pC += 32; } - if (broadcast_type_C == 4) + if (c_elempack == 1) { _c0 = vld1q_f32(pC); - _c0 = vmulq_n_f32(_c0, beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - pC += 4; - } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - } - - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); - vst1q_f32(p0 + 8, _f2); - vst1q_f32(p0 + 12, _f3); - vst1q_f32(p0 + 16, _f4); - vst1q_f32(p0 + 20, _f5); - vst1q_f32(p0 + 24, _f6); - vst1q_f32(p0 + 28, _f7); - pp += 32; - p0 += out_hstep * 4; - } - } - if (out_elempack == 1) - { - int jj = 0; -#if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); - int32x4_t _sum8 = vld1q_s32(pp + 32); - int32x4_t _sum9 = vld1q_s32(pp + 36); - int32x4_t _suma = vld1q_s32(pp + 40); - int32x4_t _sumb = vld1q_s32(pp + 44); - int32x4_t _sumc = vld1q_s32(pp + 48); - int32x4_t _sumd = vld1q_s32(pp + 52); - int32x4_t _sume = vld1q_s32(pp + 56); - int32x4_t _sumf = vld1q_s32(pp + 60); - -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = 
vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale0); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale0); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_suma), _descale0); - float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sumb), _descale0); - float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); - float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); - float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); - float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); - float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); - float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); - float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); - float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); -#else - // from - // a0 b1 c2 d3 - // e4 f5 g6 h7 - // e0 f1 g2 h3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // g4 h5 e6 f7 - // g0 h1 e2 f3 - // c4 d5 a6 b7 - // a3 b2 c1 d0 - // e7 f6 g5 h4 - // e3 f2 g1 h0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // g7 h6 e5 f4 - // g3 h2 e1 f0 - // c7 d6 a5 b4 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - { - _sum8 = vrev64q_s32(_sum8); - _sum9 = vrev64q_s32(_sum9); - _suma = vrev64q_s32(_suma); - _sumb = vrev64q_s32(_sumb); - _sumc = vrev64q_s32(_sumc); - _sumd = vrev64q_s32(_sumd); - _sume = vrev64q_s32(_sume); - _sumf = vrev64q_s32(_sumf); - _sum8 = vextq_s32(_sum8, _sum8, 2); - _sum9 = vextq_s32(_sum9, _sum9, 2); - _suma = vextq_s32(_suma, _suma, 2); - _sumb = vextq_s32(_sumb, _sumb, 2); - _sumc = vextq_s32(_sumc, _sumc, 2); - _sumd = vextq_s32(_sumd, _sumd, 2); - _sume = vextq_s32(_sume, _sume, 2); - _sumf = vextq_s32(_sumf, _sumf, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); - int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); - int32x4x2_t _t2 = vzipq_s32(_sum2, _sume); - int32x4x2_t _t3 = vzipq_s32(_sum6, _suma); - int32x4x2_t _t4 = vzipq_s32(_sum3, _sumf); - int32x4x2_t _t5 = vzipq_s32(_sum7, _sumb); - int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); - int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); - _sum8 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum9 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _suma = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sumb = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sumc = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); - _sumd = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); - _sume = vcombine_s32(vget_low_s32(_t7.val[1]), 
vget_low_s32(_t6.val[1])); - _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); - _sum9 = vrev64q_s32(_sum9); - _sumb = vrev64q_s32(_sumb); - _sumd = vrev64q_s32(_sumd); - _sumf = vrev64q_s32(_sumf); - } - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale0); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale0); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale0); - float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale0); - float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale1); - float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale1); - float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_suma), _descale1); - float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sumb), _descale1); - float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); - float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); - float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); - float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); -#endif // __ARM_FEATURE_DOTPROD - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c0); - _fa = vaddq_f32(_fa, _c0); - _fb = vaddq_f32(_fb, _c0); - _fc = vaddq_f32(_fc, _c0); - _fd = vaddq_f32(_fd, _c0); - _fe = vaddq_f32(_fe, _c0); - _ff = vaddq_f32(_ff, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - _f8 = vaddq_f32(_f8, _c1); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c1); - _fb = vaddq_f32(_fb, _c1); - _fc = vaddq_f32(_fc, _c1); - _fd = vaddq_f32(_fd, _c1); - _fe = vaddq_f32(_fe, _c1); - _ff = vaddq_f32(_ff, _c1); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep); + float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); + float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); + float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); + float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + if (beta == 1.f) { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + c_hstep); - float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); - float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); - float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); - float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); - float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); - transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = 
vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep * 4 + 4); - _c2 = vld1q_f32(pC + c_hstep * 5); - _c3 = vld1q_f32(pC + c_hstep * 5 + 4); - _c4 = vld1q_f32(pC + c_hstep * 6); - _c5 = vld1q_f32(pC + c_hstep * 6 + 4); - _c6 = vld1q_f32(pC + c_hstep * 7); - _c7 = vld1q_f32(pC + c_hstep * 7 + 4); - transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _c0, _beta); - _f9 = vmlaq_f32(_f9, _c1, _beta); - _fa = vmlaq_f32(_fa, _c2, _beta); - _fb = vmlaq_f32(_fb, _c3, _beta); - _fc = vmlaq_f32(_fc, _c4, _beta); - _fd = vmlaq_f32(_fd, _c5, _beta); - _fe = vmlaq_f32(_fe, _c6, _beta); - _ff = vmlaq_f32(_ff, _c7, _beta); - } - pC += 8; + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); } - else // if (c_elempack == 4) + else { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + 8); - float32x4_t _c3 = vld1q_f32(pC + 12); - float32x4_t _c4 = vld1q_f32(pC + 16); - float32x4_t _c5 = vld1q_f32(pC + 20); - float32x4_t _c6 = vld1q_f32(pC + 24); - float32x4_t _c7 = vld1q_f32(pC + 28); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep * 4 + 4); - _c2 = vld1q_f32(pC + c_hstep * 4 + 8); - _c3 = vld1q_f32(pC + c_hstep * 4 + 12); - _c4 = vld1q_f32(pC + c_hstep * 4 + 16); - _c5 = vld1q_f32(pC + c_hstep * 4 + 20); - _c6 = vld1q_f32(pC + c_hstep * 4 + 24); - _c7 = vld1q_f32(pC + c_hstep * 4 + 28); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _c0, _beta); - _f9 = vmlaq_f32(_f9, _c1, _beta); - _fa = vmlaq_f32(_fa, _c2, _beta); - _fb = vmlaq_f32(_fb, _c3, _beta); - _fc = vmlaq_f32(_fc, _c4, _beta); - _fd = vmlaq_f32(_fd, _c5, _beta); - _fe = vmlaq_f32(_fe, _c6, _beta); - _ff = vmlaq_f32(_ff, _c7, _beta); - } - pC += 32; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = 
vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); } - } - if (broadcast_type_C == 4) - { - float32x4_t _cc0 = vld1q_f32(pC); - float32x4_t _cc1 = vld1q_f32(pC + 4); - if (beta != 1.f) + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 4 + 4); + _c2 = vld1q_f32(pC + c_hstep * 5); + _c3 = vld1q_f32(pC + c_hstep * 5 + 4); + _c4 = vld1q_f32(pC + c_hstep * 6); + _c5 = vld1q_f32(pC + c_hstep * 6 + 4); + _c6 = vld1q_f32(pC + c_hstep * 7); + _c7 = vld1q_f32(pC + c_hstep * 7 + 4); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else { float32x4_t _beta = vdupq_n_f32(beta); - _cc0 = vmulq_f32(_cc0, _beta); - _cc1 = vmulq_f32(_cc1, _beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); } - _c0 = vdupq_laneq_f32(_cc0, 0); - _c1 = vdupq_laneq_f32(_cc0, 1); - float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); - float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); - float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); - float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); - float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); - float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); pC += 8; } } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - _f8 = vmulq_f32(_f8, _alpha); - _f9 = vmulq_f32(_f9, _alpha); - _fa = vmulq_f32(_fa, _alpha); - _fb = vmulq_f32(_fb, _alpha); - _fc = vmulq_f32(_fc, _alpha); - _fd = vmulq_f32(_fd, _alpha); - _fe = vmulq_f32(_fe, _alpha); - _ff = vmulq_f32(_ff, _alpha); - } - + if (broadcast_type_C == 4) + { + float32x4_t _cc0 = vld1q_f32(pC); + float32x4_t _cc1 = vld1q_f32(pC + 4); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + 
_f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + _f8 = vmulq_f32(_f8, _alpha); + _f9 = vmulq_f32(_f9, _alpha); + _fa = vmulq_f32(_fa, _alpha); + _fb = vmulq_f32(_fb, _alpha); + _fc = vmulq_f32(_fc, _alpha); + _fd = vmulq_f32(_fd, _alpha); + _fe = vmulq_f32(_fe, _alpha); + _ff = vmulq_f32(_ff, _alpha); + } + + if (out_elempack == 4) + { + float32x4x4_t _ffa; + float32x4x4_t _ffb; + float32x4x4_t _ffc; + float32x4x4_t _ffd; + _ffa.val[0] = _f0; + _ffa.val[1] = _f1; + _ffa.val[2] = _f2; + _ffa.val[3] = _f3; + _ffb.val[0] = _f4; + _ffb.val[1] = _f5; + _ffb.val[2] = _f6; + _ffb.val[3] = _f7; + _ffc.val[0] = _f8; + _ffc.val[1] = _f9; + _ffc.val[2] = _fa; + _ffc.val[3] = _fb; + _ffd.val[0] = _fc; + _ffd.val[1] = _fd; + _ffd.val[2] = _fe; + _ffd.val[3] = _ff; + vst4q_f32(p0, _ffa); + vst4q_f32(p0 + 16, _ffc); + vst4q_f32(p0 + out_hstep * 4, _ffb); + vst4q_f32(p0 + out_hstep * 4 + 16, _ffd); + } + if (out_elempack == 1) + { vst1q_f32(p0, _f0); vst1q_f32(p0 + 4, _f8); vst1q_f32(p0 + out_hstep, _f1); @@ -10346,1103 +8180,664 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma vst1q_f32(p0 + out_hstep * 6 + 4, _fe); vst1q_f32(p0 + out_hstep * 7, _f7); vst1q_f32(p0 + out_hstep * 7 + 4, _ff); - - pp += 64; - p0 += out_hstep * 8; } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); -#if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 + pp += 64; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 #else - // from - // a0 b1 c2 d3 - // e0 f1 g2 h3 - // c0 d1 a2 b3 - // g0 h1 e2 f3 - // a3 b2 c1 d0 - // e3 f2 g1 h0 - // c3 d2 a1 b0 - // g3 h2 e1 f0 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - - { - _sum4 = vrev64q_s32(_sum4); - _sum5 = vrev64q_s32(_sum5); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); - _sum4 = vextq_s32(_sum4, _sum4, 2); - _sum5 = vextq_s32(_sum5, _sum5, 2); - _sum6 = vextq_s32(_sum6, 
_sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); - } -#endif // __ARM_FEATURE_DOTPROD - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); - float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c1); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c1); - _f7 = vaddq_f32(_f7, _c1); - } - if (broadcast_type_C == 3) + // from + // a0 b1 c2 d3 + // e0 f1 g2 h3 + // c0 d1 a2 b3 + // g0 h1 e2 f3 + // a3 b2 c1 d0 + // e3 f2 g1 h0 + // c3 d2 a1 b0 + // g3 h2 e1 f0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + 
_sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c1); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c1); + _f7 = vaddq_f32(_f7, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 4) { - if (c_elempack == 1) + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 8); + float32x4_t _c3 = vld1q_f32(pC + 12); + if (beta == 1.f) { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep); - float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); - float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); - transpose4x4_ps(_c0, _c1, _c2, _c3); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep * 5); - _c2 = vld1q_f32(pC + c_hstep * 6); - _c3 = vld1q_f32(pC + c_hstep * 7); - transpose4x4_ps(_c0, _c1, _c2, _c3); - if (beta == 1.f) - { - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c2); - _f7 = vaddq_f32(_f7, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f4 = vmlaq_f32(_f4, _c0, _beta); - _f5 = vmlaq_f32(_f5, _c1, _beta); - _f6 = vmlaq_f32(_f6, _c2, _beta); - _f7 = vmlaq_f32(_f7, _c3, _beta); - } - pC += 4; + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } - else // if (c_elempack == 4) + else { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + 8); - float32x4_t _c3 = vld1q_f32(pC + 12); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep * 4 + 4); - _c2 = vld1q_f32(pC + c_hstep * 4 + 8); - _c3 = vld1q_f32(pC + c_hstep * 4 + 12); - if (beta == 1.f) - { - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c2); - _f7 = vaddq_f32(_f7, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f4 = vmlaq_f32(_f4, _c0, _beta); - _f5 = vmlaq_f32(_f5, _c1, _beta); 
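// --- illustrative aside (not part of the patch) ---------------------------
// The #else branch above repairs the lane order produced by the non-dotprod
// int8 kernels using the vrev64q / vzipq / vcombine idiom. A minimal
// standalone model of that idiom, shown for the simplest 2-row case
// documented later in this file ("a0 b1 c0 d1 / a1 b0 c1 d0" ->
// "a0 b0 c0 d0 / a1 b1 c1 d1"). The test values and main() are hypothetical;
// compile for an ARM target with NEON.
#include <arm_neon.h>
#include <cstdio>

int main()
{
    const int32_t in0[4] = {0, 11, 2, 13}; // a0 b1 c0 d1 (tens digit = row)
    const int32_t in1[4] = {10, 1, 12, 3}; // a1 b0 c1 d0
    int32x4_t _sum0 = vld1q_s32(in0);
    int32x4_t _sum1 = vld1q_s32(in1);
    _sum1 = vrev64q_s32(_sum1);                    // b0 a1 d0 c1
    int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1);     // {a0 b0 b1 a1} {c0 d0 d1 c1}
    _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1]));
    _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1]));
    _sum1 = vrev64q_s32(_sum1);                    // undo the remaining pair swap
    int32_t out[8];
    vst1q_s32(out, _sum0);     // expect 0 1 2 3     (a0 b0 c0 d0)
    vst1q_s32(out + 4, _sum1); // expect 10 11 12 13 (a1 b1 c1 d1)
    for (int i = 0; i < 8; i++) printf("%d ", out[i]);
    printf("\n");
    return 0;
}
// ---------------------------------------------------------------------------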
- _f6 = vmlaq_f32(_f6, _c2, _beta); - _f7 = vmlaq_f32(_f7, _c3, _beta); - } - pC += 16; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } - } - if (broadcast_type_C == 4) - { - float32x4_t _cc = vld1q_f32(pC); - _cc = vmulq_n_f32(_cc, beta); -#if __aarch64__ - _c0 = vdupq_laneq_f32(_cc, 0); - _c1 = vdupq_laneq_f32(_cc, 1); - float32x4_t _c2 = vdupq_laneq_f32(_cc, 2); - float32x4_t _c3 = vdupq_laneq_f32(_cc, 3); -#else - _c0 = vdupq_lane_f32(vget_low_f32(_cc), 0); - _c1 = vdupq_lane_f32(vget_low_f32(_cc), 1); - float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_cc), 0); - float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_cc), 1); -#endif - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c2); - _f7 = vaddq_f32(_f7, _c3); - pC += 4; - } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - } - - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f4); - vst1q_f32(p0 + out_hstep, _f1); - vst1q_f32(p0 + out_hstep + 4, _f5); - vst1q_f32(p0 + out_hstep * 2, _f2); - vst1q_f32(p0 + out_hstep * 2 + 4, _f6); - vst1q_f32(p0 + out_hstep * 3, _f3); - vst1q_f32(p0 + out_hstep * 3 + 4, _f7); - - pp += 32; - p0 += out_hstep * 4; - } - for (; jj + 1 < max_jj; jj += 2) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - -#if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // e0 f0 g0 h0 - // e1 f1 g1 h1 -#else - // from - // a0 b1 c0 d1 - // e0 f1 g0 h1 - // a1 b0 c1 d0 - // e1 f0 g1 h0 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - { - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum2); - int32x4x2_t _t1 = vzipq_s32(_sum1, _sum3); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t1.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t1.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - } -#endif // __ARM_FEATURE_DOTPROD - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale1); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale1); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c1); - _f3 = vaddq_f32(_f3, _c1); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 4 + 4); + _c2 = vld1q_f32(pC + c_hstep * 4 + 8); + _c3 = vld1q_f32(pC + c_hstep * 4 + 12); + if (beta == 1.f) { - 
float32x2_t _cc0 = vld1_f32(pC); - float32x2_t _cc1 = vld1_f32(pC + c_hstep); - float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); - float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); - float32x2_t _cc4 = vld1_f32(pC + c_hstep * 4); - float32x2_t _cc5 = vld1_f32(pC + c_hstep * 5); - float32x2_t _cc6 = vld1_f32(pC + c_hstep * 6); - float32x2_t _cc7 = vld1_f32(pC + c_hstep * 7); - float32x4_t _cc01 = vcombine_f32(_cc0, _cc1); - float32x4_t _cc23 = vcombine_f32(_cc2, _cc3); - float32x4_t _cc45 = vcombine_f32(_cc4, _cc5); - float32x4_t _cc67 = vcombine_f32(_cc6, _cc7); - float32x4x2_t _ccc0 = vuzpq_f32(_cc01, _cc23); - float32x4x2_t _ccc1 = vuzpq_f32(_cc45, _cc67); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _ccc0.val[0]); - _f1 = vaddq_f32(_f1, _ccc0.val[1]); - _f2 = vaddq_f32(_f2, _ccc1.val[0]); - _f3 = vaddq_f32(_f3, _ccc1.val[1]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _ccc0.val[0], _beta); - _f1 = vmlaq_f32(_f1, _ccc0.val[1], _beta); - _f2 = vmlaq_f32(_f2, _ccc1.val[0], _beta); - _f3 = vmlaq_f32(_f3, _ccc1.val[1], _beta); - } - pC += 2; + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); } - else // if (c_elempack == 4) + else { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + c_hstep * 4); - float32x4_t _c3 = vld1q_f32(pC + c_hstep * 4 + 4); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 8; + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); } + pC += 16; } - if (broadcast_type_C == 4) - { - float32x2_t _cc = vld1_f32(pC); - _cc = vmul_n_f32(_cc, beta); - _c0 = vdupq_lane_f32(_cc, 0); - _c1 = vdupq_lane_f32(_cc, 1); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c1); - pC += 2; - } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - } - - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f2); - vst1q_f32(p0 + out_hstep, _f1); - vst1q_f32(p0 + out_hstep + 4, _f3); - - pp += 16; - p0 += out_hstep * 2; - } - for (; jj < max_jj; jj += 1) - { - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp + 4)), _descale1); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - if (broadcast_type_C == 3) + if (c_elempack == 1) { - if (c_elempack == 1) - { - _c0 = vsetq_lane_f32(pC[0], _c0, 0); - _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); - _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); - _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); - _c1 = vsetq_lane_f32(pC[c_hstep * 4], _c1, 0); - _c1 = vsetq_lane_f32(pC[c_hstep * 5], _c1, 1); - _c1 = vsetq_lane_f32(pC[c_hstep * 6], _c1, 2); - _c1 = vsetq_lane_f32(pC[c_hstep * 7], _c1, 3); - pC += 1; - } - else // if (c_elempack == 4) - { - _c0 = 
vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep * 4); - pC += 4; - } + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep); + float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); + transpose4x4_ps(_c0, _c1, _c2, _c3); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } else { float32x4_t _beta = vdupq_n_f32(beta); _f0 = vmlaq_f32(_f0, _c0, _beta); _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } - } - if (broadcast_type_C == 4) - { - _c0 = vdupq_n_f32(pC[0] * beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - pC += 1; + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 5); + _c2 = vld1q_f32(pC + c_hstep * 6); + _c3 = vld1q_f32(pC + c_hstep * 7); + transpose4x4_ps(_c0, _c1, _c2, _c3); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } + pC += 4; } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - } - + float32x4_t _cc = vld1q_f32(pC); + _cc = vmulq_n_f32(_cc, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_cc, 0); + _c1 = vdupq_laneq_f32(_cc, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_cc), 0); + _c1 = vdupq_lane_f32(vget_low_f32(_cc), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_cc), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_cc), 1); +#endif + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + pC += 4; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + + if (out_elempack == 4) + { + float32x4x4_t _fa; + float32x4x4_t _fb; + _fa.val[0] = _f0; + _fa.val[1] = _f1; + _fa.val[2] = _f2; + _fa.val[3] = _f3; + _fb.val[0] = _f4; + _fb.val[1] = _f5; + _fb.val[2] = _f6; + _fb.val[3] = _f7; + vst4q_f32(p0, _fa); + vst4q_f32(p0 + 16, _fb); + } + if (out_elempack == 1) + { vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); - pp += 8; - p0 += out_hstep; + vst1q_f32(p0 + 4, _f4); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep + 4, _f5); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 2 + 4, _f6); + vst1q_f32(p0 + out_hstep * 3, _f3); + vst1q_f32(p0 + out_hstep * 3 + 4, _f7); } + + pp += 32; + p0 += out_hstep * 4; } - } - for (; ii + 3 < max_ii; ii += 4) - { - float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); - float32x4_t _descale = vld1q_f32((const float*)descales + ii); +#if 
__ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 +#else + // from + // a0 b1 c0 d1 + // e0 f1 g0 h1 + // a1 b0 c1 d0 + // e1 f0 g1 h0 - float32x4_t _c0; - if (pC) - { - if (broadcast_type_C == 0) + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 { - _c0 = vdupq_n_f32(pC[0] * beta); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - pC = (const float*)C + i + ii; - _c0 = vld1q_f32(pC); - _c0 = vmulq_n_f32(_c0, beta); + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum2); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t1.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); } - if (broadcast_type_C == 3) - { - pC = (const float*)C + (i + ii) * c_hstep + j * c_elempack; - } - if (broadcast_type_C == 4) - { - pC = (const float*)C + j; - } - } +#endif // __ARM_FEATURE_DOTPROD - if (out_elempack == 4) - { - int jj = 0; -#if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale1); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale1); -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 + if (pC) + { + if (broadcast_type_C == 0) { - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); - int32x4x2_t _t2 = vzipq_s32(_sum4, _sum5); - int32x4x2_t _t3 = vzipq_s32(_sum6, _sum7); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); - _sum6 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); } -#else // __ARM_FEATURE_DOTPROD - - // from - // a0 b1 c2 d3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // c4 d5 a6 b7 - // a3 b2 c1 d0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // c7 d6 a5 b4 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 + if (broadcast_type_C == 1 || broadcast_type_C 
== 2) { - _sum2 = vextq_s32(_sum2, _sum2, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum6 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); } -#endif // __ARM_FEATURE_DOTPROD - - float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale, 0); - float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale, 1); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale, 1); - float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale, 2); - float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale, 2); - float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale, 3); - float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale, 3); - - if (pC) + if (broadcast_type_C == 3) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + if (c_elempack == 1) { - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); - _f0 = vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc0); - _f2 = vaddq_f32(_f2, _cc1); - _f3 = vaddq_f32(_f3, _cc1); - _f4 = vaddq_f32(_f4, _cc2); - _f5 = vaddq_f32(_f5, _cc2); - _f6 = vaddq_f32(_f6, _cc3); - _f7 = vaddq_f32(_f7, _cc3); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) + float32x2_t _cc0 = vld1_f32(pC); + float32x2_t _cc1 = vld1_f32(pC + c_hstep); + float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); + float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); + float32x2_t _cc4 = vld1_f32(pC + c_hstep * 4); + float32x2_t _cc5 = vld1_f32(pC + c_hstep * 5); + float32x2_t _cc6 = vld1_f32(pC + c_hstep * 6); + float32x2_t _cc7 = vld1_f32(pC + c_hstep * 7); + float32x4_t _cc01 = vcombine_f32(_cc0, _cc1); + float32x4_t _cc23 = vcombine_f32(_cc2, _cc3); + float32x4_t _cc45 = vcombine_f32(_cc4, _cc5); + float32x4_t _cc67 = vcombine_f32(_cc6, _cc7); + float32x4x2_t _ccc0 = vuzpq_f32(_cc01, _cc23); + float32x4x2_t _ccc1 = vuzpq_f32(_cc45, _cc67); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _ccc0.val[0]); + _f1 = vaddq_f32(_f1, _ccc0.val[1]); + _f2 = vaddq_f32(_f2, _ccc1.val[0]); + _f3 = vaddq_f32(_f3, _ccc1.val[1]); + } + else { - _c0 = vld1q_f32(pC); 
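// --- illustrative aside (not part of the patch) ---------------------------
// Scalar model of the epilogue all of these vector paths implement: descale
// the int32 accumulators per output row, add beta * C under the broadcast
// selected by broadcast_type_C, then scale by alpha. Shapes and names here
// are assumptions for illustration; the real code tiles this and keeps C in
// its packed layout (c_elempack).
#include <cstdint>

static void epilogue_ref(const int32_t* sum, const float* descale,
                         const float* C, int broadcast_type_C, int c_hstep,
                         float alpha, float beta,
                         float* out, int rows, int cols)
{
    for (int i = 0; i < rows; i++)
    {
        for (int j = 0; j < cols; j++)
        {
            float f = (float)sum[i * cols + j] * descale[i];
            if (C)
            {
                float c = 0.f;
                if (broadcast_type_C == 0) c = C[0];                           // one scalar for the whole tile
                if (broadcast_type_C == 1 || broadcast_type_C == 2) c = C[i];  // per output row
                if (broadcast_type_C == 3) c = C[i * c_hstep + j];             // full matrix (pack1 view)
                if (broadcast_type_C == 4) c = C[j];                           // per output column
                f += c * beta;
            }
            out[i * cols + j] = f * alpha;
        }
    }
}
// ---------------------------------------------------------------------------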
- float32x4_t _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + c_hstep); - float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); - float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); - float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); - float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); - float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } - pC += 8; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _ccc0.val[0], _beta); + _f1 = vmlaq_f32(_f1, _ccc0.val[1], _beta); + _f2 = vmlaq_f32(_f2, _ccc1.val[0], _beta); + _f3 = vmlaq_f32(_f3, _ccc1.val[1], _beta); } - else // if (c_elempack == 4) + pC += 2; + } + else // if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep * 4); + float32x4_t _c3 = vld1q_f32(pC + c_hstep * 4 + 4); + if (beta == 1.f) { - float32x4x4_t _cc0 = vld4q_f32(pC); - float32x4x4_t _cc1 = vld4q_f32(pC + 16); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _cc0.val[0]); - _f1 = vaddq_f32(_f1, _cc1.val[0]); - _f2 = vaddq_f32(_f2, _cc0.val[1]); - _f3 = vaddq_f32(_f3, _cc1.val[1]); - _f4 = vaddq_f32(_f4, _cc0.val[2]); - _f5 = vaddq_f32(_f5, _cc1.val[2]); - _f6 = vaddq_f32(_f6, _cc0.val[3]); - _f7 = vaddq_f32(_f7, _cc1.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _cc0.val[0], _beta); - _f1 = vmlaq_f32(_f1, _cc1.val[0], _beta); - _f2 = vmlaq_f32(_f2, _cc0.val[1], _beta); - _f3 = vmlaq_f32(_f3, _cc1.val[1], _beta); - _f4 = vmlaq_f32(_f4, _cc0.val[2], _beta); - _f5 = vmlaq_f32(_f5, _cc1.val[2], _beta); - _f6 = vmlaq_f32(_f6, _cc0.val[3], _beta); - _f7 = vmlaq_f32(_f7, _cc1.val[3], _beta); - } - pC += 32; + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } - } - if (broadcast_type_C == 4) - { - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + 4); - if (beta != 1.f) + else { float32x4_t _beta = vdupq_n_f32(beta); - _c0 = vmulq_f32(_c0, _beta); - _c1 = vmulq_f32(_c1, _beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c1); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c1); pC += 8; } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); + float32x2_t _cc = vld1_f32(pC); + _cc = vmul_n_f32(_cc, beta); + _c0 = vdupq_lane_f32(_cc, 0); + _c1 = vdupq_lane_f32(_cc, 1); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, 
_c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 2; } - - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f2); - vst1q_f32(p0 + 8, _f4); - vst1q_f32(p0 + 12, _f6); - vst1q_f32(p0 + out_hstep * 4, _f1); - vst1q_f32(p0 + out_hstep * 4 + 4, _f3); - vst1q_f32(p0 + out_hstep * 4 + 8, _f5); - vst1q_f32(p0 + out_hstep * 4 + 12, _f7); - - pp += 32; - p0 += out_hstep * 8; } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) + + if (alpha != 1.f) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 c1 c2 c3 - // d0 d1 d2 d3 + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f2); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep + 4, _f3); + + pp += 16; + p0 += out_hstep * 2; + } + for (; jj < max_jj; jj += 1) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp + 4)), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) { - int32x4x2_t _r01 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _r23 = vzipq_s32(_sum2, _sum3); - _sum0 = vcombine_s32(vget_low_s32(_r01.val[0]), vget_low_s32(_r23.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_r01.val[0]), vget_high_s32(_r23.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_r01.val[1]), vget_low_s32(_r23.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_r01.val[1]), vget_high_s32(_r23.val[1])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); } -#else - // from - // a0 b1 c2 d3 - // c0 d1 a2 b3 - // a3 b2 c1 d0 - // c3 d2 a1 b0 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 c1 c2 c3 - // d0 d1 d2 d3 + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - _sum1 = vextq_s32(_sum1, _sum1, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); - int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); } -#endif // __ARM_FEATURE_DOTPROD - -#if __aarch64__ - float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale, 1); - float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale, 2); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale, 3); -#else - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), vget_low_f32(_descale), 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), vget_low_f32(_descale), 1); - float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), vget_high_f32(_descale), 0); - float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), vget_high_f32(_descale), 1); -#endif - - if (pC) + if (broadcast_type_C == 3) { - if (broadcast_type_C == 0) + if (c_elempack == 1) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); + _c0 = 
vsetq_lane_f32(pC[0], _c0, 0); + _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); + _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); + _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); + _c1 = vsetq_lane_f32(pC[c_hstep * 4], _c1, 0); + _c1 = vsetq_lane_f32(pC[c_hstep * 5], _c1, 1); + _c1 = vsetq_lane_f32(pC[c_hstep * 6], _c1, 2); + _c1 = vsetq_lane_f32(pC[c_hstep * 7], _c1, 3); + pC += 1; } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + else // if (c_elempack == 4) { -#if __aarch64__ - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); -#else - float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); - float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); - float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); - float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); -#endif - _f0 = vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc1); - _f2 = vaddq_f32(_f2, _cc2); - _f3 = vaddq_f32(_f3, _cc3); + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep * 4); + pC += 4; } - if (broadcast_type_C == 3) + if (beta == 1.f) { - if (c_elempack == 1) - { - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + c_hstep); - float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); - float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 4; - } - else // if (c_elempack == 4) - { - float32x4x4_t _c = vld4q_f32(pC); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c.val[0]); - _f1 = vaddq_f32(_f1, _c.val[1]); - _f2 = vaddq_f32(_f2, _c.val[2]); - _f3 = vaddq_f32(_f3, _c.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c.val[0], _beta); - _f1 = vmlaq_f32(_f1, _c.val[1], _beta); - _f2 = vmlaq_f32(_f2, _c.val[2], _beta); - _f3 = vmlaq_f32(_f3, _c.val[3], _beta); - } - pC += 16; - } + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); } - if (broadcast_type_C == 4) + else { - _c0 = vld1q_f32(pC); - _c0 = vmulq_n_f32(_c0, beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - pC += 4; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); + _c0 = vdupq_n_f32(pC[0] * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 1; } + } - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); - vst1q_f32(p0 + 8, _f2); - vst1q_f32(p0 + 12, _f3); - - pp += 16; - p0 += out_hstep * 4; + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + pp += 8; + p0 += out_hstep; } - if (out_elempack == 1) + } + for (; ii + 3 < max_ii; ii += 4) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; + + float32x4_t _descale = vld1q_f32((const float*)descales + ii); + + float32x4_t _c0; + if (pC) { - int jj = 0; -#if __aarch64__ - 
for (; jj + 7 < max_jj; jj += 8) + if (broadcast_type_C == 0) + { + _c0 = vdupq_n_f32(pC[0] * beta); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + _c0 = vld1q_f32(pC); + _c0 = vmulq_n_f32(_c0, beta); + } + if (broadcast_type_C == 3) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); + pC = (const float*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } -#if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 #else - // from - // a0 b1 c2 d3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // c4 d5 a6 b7 - // a3 b2 c1 d0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // c7 d6 a5 b4 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - { - _sum4 = vrev64q_s32(_sum4); - _sum5 = vrev64q_s32(_sum5); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); - _sum4 = vextq_s32(_sum4, _sum4, 2); - _sum5 = vextq_s32(_sum5, _sum5, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); - } -#endif // __ARM_FEATURE_DOTPROD - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale); - float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale); - - if (pC) - { - if (broadcast_type_C == 0) - { - 
_f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + // from + // a0 b1 c2 d3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 3) + { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; + float32x4_t _c4; + float32x4_t _c5; + float32x4_t _c6; + float32x4_t _c7; + if (c_elempack == 4) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c2 = vld1q_f32(pC + 8); + _c3 = vld1q_f32(pC + 12); + _c4 = vld1q_f32(pC + 16); + _c5 = vld1q_f32(pC + 20); + _c6 = vld1q_f32(pC + 24); + _c7 = vld1q_f32(pC + 28); + pC += 32; } - if 
(broadcast_type_C == 3) + if (c_elempack == 1) { - float32x4_t _c1; - float32x4_t _c2; - float32x4_t _c3; - float32x4_t _c4; - float32x4_t _c5; - float32x4_t _c6; - float32x4_t _c7; - if (c_elempack == 1) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - _c2 = vld1q_f32(pC + c_hstep); - _c3 = vld1q_f32(pC + c_hstep + 4); - _c4 = vld1q_f32(pC + c_hstep * 2); - _c5 = vld1q_f32(pC + c_hstep * 2 + 4); - _c6 = vld1q_f32(pC + c_hstep * 3); - _c7 = vld1q_f32(pC + c_hstep * 3 + 4); - transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); - pC += 8; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - _c2 = vld1q_f32(pC + 8); - _c3 = vld1q_f32(pC + 12); - _c4 = vld1q_f32(pC + 16); - _c5 = vld1q_f32(pC + 20); - _c6 = vld1q_f32(pC + 24); - _c7 = vld1q_f32(pC + 28); - pC += 32; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c2 = vld1q_f32(pC + c_hstep); + _c3 = vld1q_f32(pC + c_hstep + 4); + _c4 = vld1q_f32(pC + c_hstep * 2); + _c5 = vld1q_f32(pC + c_hstep * 2 + 4); + _c6 = vld1q_f32(pC + c_hstep * 3); + _c7 = vld1q_f32(pC + c_hstep * 3 + 4); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + pC += 8; } - if (broadcast_type_C == 4) + if (beta == 1.f) { - float32x4_t _cc0 = vld1q_f32(pC); - float32x4_t _cc1 = vld1q_f32(pC + 4); - if (beta != 1.f) - { - float32x4_t _beta = vdupq_n_f32(beta); - _cc0 = vmulq_f32(_cc0, _beta); - _cc1 = vmulq_f32(_cc1, _beta); - } - _c0 = vdupq_laneq_f32(_cc0, 0); - float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1); - float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); - float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); - float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); - float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); - float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); - float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); @@ -11451,23 +8846,80 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _f5 = vaddq_f32(_f5, _c5); _f6 = vaddq_f32(_f6, _c6); _f7 = vaddq_f32(_f7, _c7); - pC += 8; } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - } - + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + } + if (broadcast_type_C == 4) + { + float32x4_t _cc0 = vld1q_f32(pC); + float32x4_t _cc1 = vld1q_f32(pC + 4); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = 
vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + + if (out_elempack == 4) + { + float32x4x4_t _fa; + float32x4x4_t _fb; + _fa.val[0] = _f0; + _fa.val[1] = _f1; + _fa.val[2] = _f2; + _fa.val[3] = _f3; + _fb.val[0] = _f4; + _fb.val[1] = _f5; + _fb.val[2] = _f6; + _fb.val[3] = _f7; + vst4q_f32(p0, _fa); + vst4q_f32(p0 + out_hstep * 4, _fb); + } + if (out_elempack == 1) + { vst1q_f32(p0, _f0); vst1q_f32(p0 + out_hstep, _f1); vst1q_f32(p0 + out_hstep * 2, _f2); @@ -11476,296 +8928,308 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma vst1q_f32(p0 + out_hstep * 5, _f5); vst1q_f32(p0 + out_hstep * 6, _f6); vst1q_f32(p0 + out_hstep * 7, _f7); - - pp += 32; - p0 += out_hstep * 8; } + + pp += 32; + p0 += out_hstep * 8; + } #endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); #if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 #else - // from - // a0 b1 c2 d3 - // c0 d1 a2 b3 - // a3 b2 c1 d0 - // c3 d2 a1 b0 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - { - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - _sum2 = vextq_s32(_sum2, _sum2, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); - int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - } + // from + // a0 b1 c2 d3 + // c0 d1 a2 b3 + // a3 b2 c1 d0 + // c3 d2 a1 b0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = 
vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } #endif // __ARM_FEATURE_DOTPROD - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3) + { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; + if (c_elempack == 4) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c2 = vld1q_f32(pC + 8); + _c3 = vld1q_f32(pC + 12); + pC += 16; } - if (broadcast_type_C == 3) + if (c_elempack == 1) { - float32x4_t _c1; - float32x4_t _c2; - float32x4_t _c3; - if (c_elempack == 1) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep); - _c2 = vld1q_f32(pC + c_hstep * 2); - _c3 = vld1q_f32(pC + c_hstep * 3); - transpose4x4_ps(_c0, _c1, _c2, _c3); - pC += 4; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - _c2 = vld1q_f32(pC + 8); - _c3 = vld1q_f32(pC + 12); - pC += 16; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep); + _c2 = vld1q_f32(pC + c_hstep * 2); + _c3 = vld1q_f32(pC + c_hstep * 3); + transpose4x4_ps(_c0, _c1, _c2, _c3); + pC += 4; } - if (broadcast_type_C == 4) + if (beta == 1.f) { - float32x4_t _cc = vld1q_f32(pC); - _cc = vmulq_n_f32(_cc, beta); -#if __aarch64__ - _c0 = vdupq_laneq_f32(_cc, 0); - float32x4_t _c1 = vdupq_laneq_f32(_cc, 1); - float32x4_t _c2 = vdupq_laneq_f32(_cc, 2); - float32x4_t _c3 = vdupq_laneq_f32(_cc, 3); -#else - _c0 = vdupq_lane_f32(vget_low_f32(_cc), 0); - float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_cc), 1); - float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_cc), 0); - float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_cc), 1); -#endif _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); _f3 = vaddq_f32(_f3, _c3); - pC += 4; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, 
_beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); + float32x4_t _cc = vld1q_f32(pC); + _cc = vmulq_n_f32(_cc, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_cc, 0); + float32x4_t _c1 = vdupq_laneq_f32(_cc, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_cc), 0); + float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_cc), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_cc), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_cc), 1); +#endif + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 4; } + } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + if (out_elempack == 4) + { + float32x4x4_t _f; + _f.val[0] = _f0; + _f.val[1] = _f1; + _f.val[2] = _f2; + _f.val[3] = _f3; + vst4q_f32(p0, _f); + } + if (out_elempack == 1) + { vst1q_f32(p0, _f0); vst1q_f32(p0 + out_hstep, _f1); vst1q_f32(p0 + out_hstep * 2, _f2); vst1q_f32(p0 + out_hstep * 3, _f3); - - pp += 16; - p0 += out_hstep * 4; } - for (; jj + 1 < max_jj; jj += 2) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); + + pp += 16; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); #if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 #else - // from - // a0 b1 c0 d1 - // a1 b0 c1 d0 + // from + // a0 b1 c0 d1 + // a1 b0 c1 d0 - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - { - _sum1 = vrev64q_s32(_sum1); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); - _sum1 = vrev64q_s32(_sum1); - } + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + { + _sum1 = vrev64q_s32(_sum1); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + } #endif // __ARM_FEATURE_DOTPROD - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 3) + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3) + { + float32x4_t _c1; + if (c_elempack == 1) { - float32x4_t _c1; - if (c_elempack == 1) - { - float32x2_t _cc0 = vld1_f32(pC); - float32x2_t _cc1 = vld1_f32(pC + c_hstep); - float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); - 
float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); - float32x4_t _cc01 = vcombine_f32(_cc0, _cc1); - float32x4_t _cc23 = vcombine_f32(_cc2, _cc3); - float32x4x2_t _cc = vuzpq_f32(_cc01, _cc23); - _c0 = _cc.val[0]; - _c1 = _cc.val[1]; - pC += 2; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - pC += 8; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } + float32x2_t _cc0 = vld1_f32(pC); + float32x2_t _cc1 = vld1_f32(pC + c_hstep); + float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); + float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); + float32x4_t _cc01 = vcombine_f32(_cc0, _cc1); + float32x4_t _cc23 = vcombine_f32(_cc2, _cc3); + float32x4x2_t _cc = vuzpq_f32(_cc01, _cc23); + _c0 = _cc.val[0]; + _c1 = _cc.val[1]; + pC += 2; + } + else // if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + pC += 8; } - if (broadcast_type_C == 4) + if (beta == 1.f) { - float32x2_t _c = vld1_f32(pC); - _c = vmul_n_f32(_c, beta); - _c0 = vdupq_lane_f32(_c, 0); - float32x4_t _c1 = vdupq_lane_f32(_c, 1); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); - pC += 2; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); + float32x2_t _c = vld1_f32(pC); + _c = vmul_n_f32(_c, beta); + _c0 = vdupq_lane_f32(_c, 0); + float32x4_t _c1 = vdupq_lane_f32(_c, 1); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 2; } - - vst1q_f32(p0, _f0); - vst1q_f32(p0 + out_hstep, _f1); - - pp += 8; - p0 += out_hstep * 2; } - for (; jj < max_jj; jj += 1) + + if (alpha != 1.f) { - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep, _f1); + + pp += 8; + p0 += out_hstep * 2; + } + for (; jj < max_jj; jj += 1) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 3) + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) { - if (c_elempack == 1) - { - _c0 = vsetq_lane_f32(pC[0], _c0, 0); - _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); - _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); - _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); - pC += 1; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - pC += 4; - } - _f0 = vmlaq_n_f32(_f0, _c0, beta); + _c0 = vsetq_lane_f32(pC[0], _c0, 0); + _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); + _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); + _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); + pC += 1; } - if (broadcast_type_C == 4) + else // if (c_elempack == 4) { - _c0 = vdupq_n_f32(pC[0] * beta); - _f0 = vaddq_f32(_f0, _c0); - pC += 1; + _c0 = vld1q_f32(pC); + pC += 4; } + _f0 = vmlaq_n_f32(_f0, _c0, beta); + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0] * beta); + 
_f0 = vaddq_f32(_f0, _c0); + pC += 1; } + } - _f0 = vmulq_n_f32(_f0, alpha); + _f0 = vmulq_n_f32(_f0, alpha); - vst1q_f32(p0, _f0); - pp += 4; - p0 += out_hstep; - } + vst1q_f32(p0, _f0); + pp += 4; + p0 += out_hstep; } } #endif // __ARM_NEON @@ -11815,440 +9279,277 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma } } + int jj = 0; #if __ARM_NEON - if (out_elempack == 4) - { - int jj = 0; #if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale01, 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale01, 0); - float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), _descale01, 1); - float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), _descale01, 1); + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale01, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale01, 0); + float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), _descale01, 1); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), _descale01, 1); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep); + float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c1); - _f3 = vaddq_f32(_f3, _c1); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } - if (broadcast_type_C == 3) + else { - // c_elempack == 1 - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + c_hstep); - float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 8; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } - if (broadcast_type_C == 4) + pC += 8; + } + if (broadcast_type_C == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + if (beta != 1.f) { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - if (beta != 1.f) - { - float32x4_t _beta = vdupq_n_f32(beta); - _c0 = vmulq_f32(_c0, _beta); - _c1 = vmulq_f32(_c1, _beta); - } - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - 
_f3 = vaddq_f32(_f3, _c1); - pC += 8; + float32x4_t _beta = vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); } + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 8; } + } - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + if (out_elempack == 4) + { vst1q_f32(p0, _f0); vst1q_f32(p0 + 4, _f2); vst1q_f32(p0 + out_hstep * 4, _f1); vst1q_f32(p0 + out_hstep * 4 + 4, _f3); - - pp += 16; - p0 += out_hstep * 8; } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) + if (out_elempack == 1) { - // a0 a1 a2 a3 - // b0 b1 b2 b3 + float32x4x2_t _f02 = vzipq_f32(_f0, _f2); + float32x4x2_t _f13 = vzipq_f32(_f1, _f3); + vst1_f32(p0, vget_low_f32(_f02.val[0])); + vst1_f32(p0 + out_hstep, vget_high_f32(_f02.val[0])); + vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f02.val[1])); + vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f02.val[1])); + vst1_f32(p0 + out_hstep * 4, vget_low_f32(_f13.val[0])); + vst1_f32(p0 + out_hstep * 5, vget_high_f32(_f13.val[0])); + vst1_f32(p0 + out_hstep * 6, vget_low_f32(_f13.val[1])); + vst1_f32(p0 + out_hstep * 7, vget_high_f32(_f13.val[1])); + } + + pp += 16; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + // a0 a1 a2 a3 + // b0 b1 b2 b3 - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale01, 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale01, 1); + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale01, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale01, 1); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep); + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - pC += 4; - } - if (broadcast_type_C == 4) + else { - _c0 = vld1q_f32(pC); - _c0 = vmulq_n_f32(_c0, beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - pC += 4; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } + pC += 4; } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); + _c0 = vld1q_f32(pC); + _c0 = vmulq_n_f32(_c0, beta); + _f0 = vaddq_f32(_f0, _c0); + 
_f1 = vaddq_f32(_f1, _c0); + pC += 4; } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + if (out_elempack == 4) + { vst1q_f32(p0, _f0); vst1q_f32(p0 + 4, _f1); - - pp += 8; - p0 += out_hstep * 4; } - } -#endif // __ARM_NEON - if (out_elempack == 1) - { - int jj = 0; -#if __ARM_NEON -#if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) + if (out_elempack == 1) { - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + float32x4x2_t _f01 = vzipq_f32(_f0, _f1); + vst1_f32(p0, vget_low_f32(_f01.val[0])); + vst1_f32(p0 + out_hstep, vget_high_f32(_f01.val[0])); + vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f01.val[1])); + vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f01.val[1])); + } - int32x4x2_t _sum02 = vzipq_s32(_sum0, _sum2); - int32x4x2_t _sum13 = vzipq_s32(_sum1, _sum3); + pp += 8; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + // a0 a1 b0 b1 + int32x2x2_t _sum0 = vld2_s32(pp); - float32x4_t _descale = vcombine_f32(_descale01, _descale01); + float32x4_t _descale = vcombine_f32(_descale01, _descale01); - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum02.val[0]), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum02.val[1]), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum13.val[0]), _descale); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum13.val[1]), _descale); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vcombine_s32(_sum0.val[0], _sum0.val[1])), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - float32x4_t _cc = vzipq_f32(_c0, _c1).val[0]; - _f0 = vaddq_f32(_f0, _cc); - _f1 = vaddq_f32(_f1, _cc); - _f2 = vaddq_f32(_f2, _cc); - _f3 = vaddq_f32(_f3, _cc); - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + c_hstep); - float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); - float32x4x2_t _c02 = vzipq_f32(_c0, _c2); - float32x4x2_t _c13 = vzipq_f32(_c1, _c3); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c02.val[0]); - _f1 = vaddq_f32(_f1, _c02.val[1]); - _f2 = vaddq_f32(_f2, _c13.val[0]); - _f3 = vaddq_f32(_f3, _c13.val[1]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c02.val[0], _beta); - _f1 = vmlaq_f32(_f1, _c02.val[1], _beta); - _f2 = vmlaq_f32(_f2, _c13.val[0], _beta); - _f3 = vmlaq_f32(_f3, _c13.val[1], _beta); - } - pC += 8; - } - if (broadcast_type_C == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - if (beta != 1.f) - { - float32x4_t _beta = vdupq_n_f32(beta); - _c0 = vmulq_f32(_c0, _beta); - _c1 = vmulq_f32(_c1, _beta); - } - float32x4x2_t _cc0 = vzipq_f32(_c0, _c0); - float32x4x2_t _cc1 = vzipq_f32(_c1, _c1); - _f0 = vaddq_f32(_f0, _cc0.val[0]); - _f1 = vaddq_f32(_f1, _cc0.val[1]); - _f2 = vaddq_f32(_f2, _cc1.val[0]); - _f3 = vaddq_f32(_f3, _cc1.val[1]); - pC += 8; - } + _f0 = vaddq_f32(_f0, _c0); } - - if (alpha != 1.f) + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = 
vmulq_f32(_f3, _alpha); + float32x4_t _cc = vzipq_f32(_c0, _c1).val[0]; + _f0 = vaddq_f32(_f0, _cc); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + float32x2_t _cc0 = vld1_f32(pC); + float32x2_t _cc1 = vld1_f32(pC + c_hstep); + float32x2x2_t _c01 = vzip_f32(_cc0, _cc1); + _c0 = vcombine_f32(_c01.val[0], _c01.val[1]); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; + } + if (broadcast_type_C == 4) + { + float32x2_t _cc = vld1_f32(pC); + float32x2x2_t _c01 = vzip_f32(_cc, _cc); + _c0 = vcombine_f32(_c01.val[0], _c01.val[1]); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; } - - vst1_f32(p0, vget_low_f32(_f0)); - vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); - vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f1)); - vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f1)); - vst1_f32(p0 + out_hstep * 4, vget_low_f32(_f2)); - vst1_f32(p0 + out_hstep * 5, vget_high_f32(_f2)); - vst1_f32(p0 + out_hstep * 6, vget_low_f32(_f3)); - vst1_f32(p0 + out_hstep * 7, vget_high_f32(_f3)); - - pp += 16; - p0 += out_hstep * 8; } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) - { - // a0 a1 a2 a3 - // b0 b1 b2 b3 - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); + _f0 = vmulq_n_f32(_f0, alpha); - int32x4x2_t _sum01 = vzipq_s32(_sum0, _sum1); + vst1_f32(p0, vget_low_f32(_f0)); + vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); - float32x4_t _descale = vcombine_f32(_descale01, _descale01); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum01.val[0]), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum01.val[1]), _descale); + pp += 4; + p0 += out_hstep * 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj += 1) + { + float f0 = pp[0] * descale0; + float f1 = pp[1] * descale1; - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - float32x4_t _cc = vzipq_f32(_c0, _c1).val[0]; - _f0 = vaddq_f32(_f0, _cc); - _f1 = vaddq_f32(_f1, _cc); - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep); - float32x4x2_t _c01 = vzipq_f32(_c0, _c1); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c01.val[0]); - _f1 = vaddq_f32(_f1, _c01.val[1]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c01.val[0], _beta); - _f1 = vmlaq_f32(_f1, _c01.val[1], _beta); - } - pC += 4; - } - if (broadcast_type_C == 4) - { - _c0 = vld1q_f32(pC); - _c0 = vmulq_n_f32(_c0, beta); - float32x4x2_t _cc = vzipq_f32(_c0, _c0); - _f0 = vaddq_f32(_f0, _cc.val[0]); - _f1 = vaddq_f32(_f1, _cc.val[1]); - pC += 4; - } + f0 += c0; + f1 += c0; } - - if (alpha != 1.f) + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); + f0 += c0; + f1 += c1; } - - vst1_f32(p0, vget_low_f32(_f0)); - vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); - vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f1)); - vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f1)); - - pp += 8; - p0 += out_hstep * 4; - } - for (; jj + 1 < max_jj; jj += 2) - { - // a0 a1 b0 b1 - int32x2x2_t _sum0 = vld2_s32(pp); - - float32x4_t _descale = vcombine_f32(_descale01, _descale01); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vcombine_s32(_sum0.val[0], _sum0.val[1])), _descale); - - if (pC) + if (broadcast_type_C == 3) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 
2) - { - float32x4_t _cc = vzipq_f32(_c0, _c1).val[0]; - _f0 = vaddq_f32(_f0, _cc); - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - float32x2_t _cc0 = vld1_f32(pC); - float32x2_t _cc1 = vld1_f32(pC + c_hstep); - float32x2x2_t _c01 = vzip_f32(_cc0, _cc1); - _c0 = vcombine_f32(_c01.val[0], _c01.val[1]); - _f0 = vmlaq_n_f32(_f0, _c0, beta); - pC += 2; - } - if (broadcast_type_C == 4) - { - float32x2_t _cc = vld1_f32(pC); - float32x2x2_t _c01 = vzip_f32(_cc, _cc); - _c0 = vcombine_f32(_c01.val[0], _c01.val[1]); - _f0 = vmlaq_n_f32(_f0, _c0, beta); - pC += 2; - } + // c_elempack == 1 + f0 += pC[0] * beta; + f1 += pC[c_hstep] * beta; + pC += 1; } - - _f0 = vmulq_n_f32(_f0, alpha); - - vst1_f32(p0, vget_low_f32(_f0)); - vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); - - pp += 4; - p0 += out_hstep * 2; - } -#endif // __ARM_NEON - for (; jj < max_jj; jj += 1) - { - float f0 = pp[0] * descale0; - float f1 = pp[1] * descale1; - - if (pC) + if (broadcast_type_C == 4) { - if (broadcast_type_C == 0) - { - f0 += c0; - f1 += c0; - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - f0 += c0; - f1 += c1; - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - f0 += pC[0] * beta; - f1 += pC[c_hstep] * beta; - pC += 1; - } - if (broadcast_type_C == 4) - { - f0 += pC[0] * beta; - f1 += pC[0] * beta; - pC += 1; - } + f0 += pC[0] * beta; + f1 += pC[0] * beta; + pC += 1; } + } - f0 *= alpha; - f1 *= alpha; + f0 *= alpha; + f1 *= alpha; - p0[0] = f0; - p0[1] = f1; + p0[0] = f0; + p0[1] = f1; - pp += 2; - p0 += out_hstep; - } + pp += 2; + p0 += out_hstep; } } for (; ii < max_ii; ii += 1) @@ -12292,235 +9593,81 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma } } + int jj = 0; #if __ARM_NEON - if (out_elempack == 4) + for (; jj + 15 < max_jj; jj += 16) { - int jj = 0; - for (; jj + 15 < max_jj; jj += 16) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - - if (pC) - { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // c_elempack == 1 - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + 8); - float32x4_t _c3 = vld1q_f32(pC + 12); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 16; - } - } + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - } - - if (out_hstep == 1) - { - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); - vst1q_f32(p0 + 8, _f2); - vst1q_f32(p0 + 
12, _f3); - } - else - { - vst1q_f32(p0, _f0); - vst1q_f32(p0 + out_hstep * 4, _f1); - vst1q_f32(p0 + out_hstep * 8, _f2); - vst1q_f32(p0 + out_hstep * 12, _f3); - } + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - pp += 16; - p0 += out_hstep * 16; - } - for (; jj + 7 < max_jj; jj += 8) + if (pC) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - - if (pC) - { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // c_elempack == 1 - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + 4); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - pC += 8; - } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - } - - if (out_hstep == 1) - { - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); - } - else + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - vst1q_f32(p0, _f0); - vst1q_f32(p0 + out_hstep * 4, _f1); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); } - - pp += 8; - p0 += out_hstep * 8; - } - for (; jj + 3 < max_jj; jj += 4) - { - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); - - if (pC) + if (broadcast_type_C == 3 || broadcast_type_C == 4) { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + // c_elempack == 1 + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 8); + float32x4_t _c3 = vld1q_f32(pC + 12); + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } - if (broadcast_type_C == 3 || broadcast_type_C == 4) + else { - // c_elempack == 1 - _c0 = vld1q_f32(pC); - _f0 = vmlaq_n_f32(_f0, _c0, beta); - pC += 4; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } + pC += 16; } + } - _f0 = vmulq_n_f32(_f0, alpha); + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + if (out_hstep == 1) + { vst1q_f32(p0, _f0); - pp += 4; - p0 += out_hstep * 4; + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + 8, _f2); + vst1q_f32(p0 + 12, _f3); } - } -#endif // __ARM_NEON - if (out_elempack == 1) - { - int jj = 0; -#if __ARM_NEON - for (; jj + 15 < max_jj; jj += 16) + else { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); - 
float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - - if (pC) - { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // c_elempack == 1 - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + 8); - float32x4_t _c3 = vld1q_f32(pC + 12); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 16; - } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - } - - if (out_hstep == 1) + if (out_elempack == 4) { vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); - vst1q_f32(p0 + 8, _f2); - vst1q_f32(p0 + 12, _f3); + vst1q_f32(p0 + out_hstep * 4, _f1); + vst1q_f32(p0 + out_hstep * 8, _f2); + vst1q_f32(p0 + out_hstep * 12, _f3); } - else + if (out_elempack == 1) { p0[0] = vgetq_lane_f32(_f0, 0); p0[out_hstep] = vgetq_lane_f32(_f0, 1); @@ -12539,58 +9686,66 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma p0[out_hstep * 14] = vgetq_lane_f32(_f3, 2); p0[out_hstep * 15] = vgetq_lane_f32(_f3, 3); } - - pp += 16; - p0 += out_hstep * 16; } - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + pp += 16; + p0 += out_hstep * 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + // c_elempack == 1 + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); + _f1 = vaddq_f32(_f1, _c1); } - if (broadcast_type_C == 3 || broadcast_type_C == 4) + else { - // out_elempack == 1 - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + 4); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - pC += 8; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } + pC += 8; } + } - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } - if (out_hstep == 1) + if (out_hstep == 1) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, 
_f1); + } + else + { + if (out_elempack == 4) { vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + out_hstep * 4, _f1); } - else + if (out_elempack == 1) { p0[0] = vgetq_lane_f32(_f0, 0); p0[out_hstep] = vgetq_lane_f32(_f0, 1); @@ -12601,106 +9756,113 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma p0[out_hstep * 6] = vgetq_lane_f32(_f1, 2); p0[out_hstep * 7] = vgetq_lane_f32(_f1, 3); } - - pp += 8; - p0 += out_hstep * 8; } - for (; jj + 3 < max_jj; jj += 4) - { - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); - if (pC) + pp += 8; + p0 += out_hstep * 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // out_elempack == 1 - _c0 = vld1q_f32(pC); - _f0 = vmlaq_n_f32(_f0, _c0, beta); - pC += 4; - } + _f0 = vaddq_f32(_f0, _c0); } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 4; + } + } - _f0 = vmulq_n_f32(_f0, alpha); + _f0 = vmulq_n_f32(_f0, alpha); - if (out_hstep == 1) + if (out_hstep == 1) + { + vst1q_f32(p0, _f0); + } + else + { + if (out_elempack == 4) { vst1q_f32(p0, _f0); } - else + if (out_elempack == 1) { p0[0] = vgetq_lane_f32(_f0, 0); p0[out_hstep] = vgetq_lane_f32(_f0, 1); p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); } - - pp += 4; - p0 += out_hstep * 4; } - for (; jj + 1 < max_jj; jj += 2) - { - float32x2_t _f0 = vmul_f32(vcvt_f32_s32(vld1_s32(pp)), vget_low_f32(_descale)); - - if (pC) - { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vadd_f32(_f0, vget_low_f32(_c0)); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // c_elempack == 1 - float32x2_t _c = vld1_f32(pC); - _f0 = vmla_n_f32(_f0, _c, beta); - pC += 2; - } - } - _f0 = vmul_n_f32(_f0, alpha); + pp += 4; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + float32x2_t _f0 = vmul_f32(vcvt_f32_s32(vld1_s32(pp)), vget_low_f32(_descale)); - if (out_hstep == 1) + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - vst1_f32(p0, _f0); + _f0 = vadd_f32(_f0, vget_low_f32(_c0)); } - else + if (broadcast_type_C == 3 || broadcast_type_C == 4) { - p0[0] = vget_lane_f32(_f0, 0); - p0[out_hstep] = vget_lane_f32(_f0, 1); + // c_elempack == 1 + float32x2_t _c = vld1_f32(pC); + _f0 = vmla_n_f32(_f0, _c, beta); + pC += 2; } + } - pp += 2; - p0 += out_hstep * 2; + _f0 = vmul_n_f32(_f0, alpha); + + if (out_hstep == 1) + { + vst1_f32(p0, _f0); } -#endif // __ARM_NEON - for (; jj < max_jj; jj += 1) + else { - float f0 = pp[0] * descale; + p0[0] = vget_lane_f32(_f0, 0); + p0[out_hstep] = vget_lane_f32(_f0, 1); + } + + pp += 2; + p0 += out_hstep * 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj += 1) + { + float f0 = pp[0] * descale; - if (pC) + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - f0 += c0; - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // c_elempack == 1 - f0 += pC[0] * beta; - pC += 1; - } + f0 += c0; + } + if 
(broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + f0 += pC[0] * beta; + pC += 1; } + } - f0 *= alpha; + f0 *= alpha; - p0[0] = f0; + p0[0] = f0; - pp += 1; - p0 += out_hstep; - } + pp += 1; + p0 += out_hstep; } } }
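
The from/to lane diagrams in the non-dotprod paths above can be sanity-checked with a scalar model of the three permutations involved: vrev64q_s32 swaps lanes within each 64-bit pair, vextq_s32(v, v, 2) rotates a vector by two lanes, and vzipq_s32 interleaves two vectors. Below is a minimal self-contained sketch (not ncnn code) that replays the 4x4 sequence from the jj += 4 loop; lane values encode column * 10 + row so the permutation is easy to read. The jj += 8 loop applies the same primitives pairwise to eight vectors.

#include <cassert>
#include <cstdio>

struct v4 { int l[4]; };

// scalar models of the NEON shuffles used above
static v4 rev64(v4 v) { return v4{{v.l[1], v.l[0], v.l[3], v.l[2]}}; }  // vrev64q_s32
static v4 ext2(v4 v) { return v4{{v.l[2], v.l[3], v.l[0], v.l[1]}}; }   // vextq_s32(v, v, 2)
static v4 zip_lo(v4 a, v4 b) { return v4{{a.l[0], b.l[0], a.l[1], b.l[1]}}; } // vzipq .val[0]
static v4 zip_hi(v4 a, v4 b) { return v4{{a.l[2], b.l[2], a.l[3], b.l[3]}}; } // vzipq .val[1]
// vcombine of the chosen 64-bit half of a (0 = low, 1 = high) and of b
static v4 comb(v4 a, int ah, v4 b, int bh)
{
    return v4{{a.l[2 * ah], a.l[2 * ah + 1], b.l[2 * bh], b.l[2 * bh + 1]}};
}

int main()
{
    // "from": a0 b1 c2 d3 / c0 d1 a2 b3 / a3 b2 c1 d0 / c3 d2 a1 b0
    v4 s0 = {{0, 11, 22, 33}};
    v4 s1 = {{20, 31, 2, 13}};
    v4 s2 = {{3, 12, 21, 30}};
    v4 s3 = {{23, 32, 1, 10}};

    s2 = ext2(rev64(s2));
    s3 = ext2(rev64(s3));
    v4 t0lo = zip_lo(s0, s3), t0hi = zip_hi(s0, s3);
    v4 t1lo = zip_lo(s1, s2), t1hi = zip_hi(s1, s2);
    s0 = comb(t0lo, 0, t1lo, 0);
    s1 = rev64(comb(t0lo, 1, t1lo, 1));
    s2 = comb(t1hi, 0, t0hi, 0);
    s3 = rev64(comb(t1hi, 1, t0hi, 1));

    // "to": row j must now hold a_j b_j c_j d_j
    const v4* rows[4] = {&s0, &s1, &s2, &s3};
    for (int j = 0; j < 4; j++)
        for (int col = 0; col < 4; col++)
            assert(rows[j]->l[col] == col * 10 + j);
    printf("shuffle verified\n");
    return 0;
}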
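
Every tile above implements the same per-element epilogue, out = alpha * (sum * descale + beta * c), where c is selected by broadcast_type_C. A scalar sketch of that selection, assuming the plain row-major c_elempack == 1 layout (the transpose variants walk the same data column-wise); select_c and dequant_epilogue are illustrative names, not ncnn API.

#include <cstdio>

static float select_c(const float* C, int broadcast_type_C, int i, int j, int c_hstep)
{
    if (broadcast_type_C == 0) return C[0];                           // single scalar bias
    if (broadcast_type_C == 1 || broadcast_type_C == 2) return C[i];  // one value per output row
    if (broadcast_type_C == 3) return C[i * c_hstep + j];             // full bias matrix
    return C[j];                                                      // type 4: one value per column
}

static float dequant_epilogue(int sum, float descale, const float* C,
                              int broadcast_type_C, int i, int j, int c_hstep,
                              float alpha, float beta)
{
    float f = sum * descale;  // int32 accumulator -> fp32
    if (C) f += select_c(C, broadcast_type_C, i, j, c_hstep) * beta;
    return f * alpha;         // final output scale
}

int main()
{
    const float C[] = {2.f, 4.f};
    // sum=100, descale=0.05 -> 5.0; plus C[1] * 0.5 = 2.0 -> 7.0; times 2 -> 14.0
    printf("%.1f\n", dequant_epilogue(100, 0.05f, C, 4, 0, 1, 0, 2.f, 0.5f));
    return 0;
}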
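
In the two-column tail with c_elempack == 1, the code loads float pairs from consecutive rows of C and de-interleaves them with vuzpq_f32, so that _c0 ends up holding column j across the four rows and _c1 column j + 1. A scalar model of the even/odd split vuzp performs:

#include <cassert>

// vuzpq_f32: val[0] = even lanes of (a, b), val[1] = odd lanes of (a, b)
static void uzp(const float a[4], const float b[4], float even[4], float odd[4])
{
    even[0] = a[0]; even[1] = a[2]; even[2] = b[0]; even[3] = b[2];
    odd[0] = a[1]; odd[1] = a[3]; odd[2] = b[1]; odd[3] = b[3];
}

int main()
{
    // rows r0..r3 of C, two columns each: value r for column j, r + 0.5 for column j + 1
    float cc01[4] = {0.f, 0.5f, 1.f, 1.5f}; // vcombine_f32(_cc0, _cc1)
    float cc23[4] = {2.f, 2.5f, 3.f, 3.5f}; // vcombine_f32(_cc2, _cc3)
    float c0[4], c1[4];
    uzp(cc01, cc23, c0, c1);
    for (int r = 0; r < 4; r++)
    {
        assert(c0[r] == r);        // column j: one value per C row
        assert(c1[r] == r + 0.5f); // column j + 1
    }
    return 0;
}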
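
The out_elempack == 4 stores rely on vst4q_f32, whose interleaved write is itself a 4x4 transpose: lane k of each of the four registers lands in the k-th group of four consecutive output floats, which is exactly the pack-4 layout. A scalar equivalent of that store, illustrative only:

#include <cstdio>

// scalar model of vst4q_f32(p, {_f0, _f1, _f2, _f3}); f[r][k] is lane k of register r
static void st4_scalar(float* p, const float f[4][4])
{
    for (int k = 0; k < 4; k++)      // lane index
        for (int r = 0; r < 4; r++)  // register index
            p[4 * k + r] = f[r][k];
}

int main()
{
    float f[4][4], out[16];
    for (int r = 0; r < 4; r++)
        for (int k = 0; k < 4; k++)
            f[r][k] = r * 10 + k;
    st4_scalar(out, f);
    // pack-4 layout: out = {f0[0], f1[0], f2[0], f3[0], f0[1], f1[1], ...}
    for (int i = 0; i < 16; i++)
        printf("%g%c", out[i], i % 4 == 3 ? '\n' : ' ');
    return 0;
}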