From 4b1b2b355e41545cd0e477fe1bdebe2638022e90 Mon Sep 17 00:00:00 2001
From: nihuini
Date: Wed, 25 Sep 2024 17:55:50 +0800
Subject: [PATCH] stash

---
 src/layer/arm/gemm_arm.cpp         |   16 +-
 src/layer/arm/gemm_arm_asimddp.cpp |    8 +-
 src/layer/arm/gemm_int8.h          |  600 ++++----
 src/layer/arm/gemm_int8_bf16s.h    | 2056 +++++++++++++++++++++-------
 4 files changed, 1790 insertions(+), 890 deletions(-)

diff --git a/src/layer/arm/gemm_arm.cpp b/src/layer/arm/gemm_arm.cpp
index 62efb5f6cac..522ffd7d704 100644
--- a/src/layer/arm/gemm_arm.cpp
+++ b/src/layer/arm/gemm_arm.cpp
@@ -5530,14 +5530,14 @@ static int gemm_arm_int8(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob
             if (output_transpose)
             {
                 if (top_blob.elembits() == 16)
-                    transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
+                    transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
                 else
                     transpose_unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
             }
             else
             {
                 if (top_blob.elembits() == 16)
-                    unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
+                    unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
                 else
                     unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
             }
@@ -5728,14 +5728,14 @@ static int gemm_AT_arm_int8(const Mat& AT, const Mat& A_int8_scales, const Mat&
             if (output_transpose)
             {
                 if (top_blob.elembits() == 16)
-                    transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
+                    transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
                 else
                     transpose_unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
             }
             else
             {
                 if (top_blob.elembits() == 16)
-                    unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
+                    unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
                 else
                     unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
             }
@@ -5849,14 +5849,14 @@ static int gemm_BT_arm_int8(const Mat& A, const Mat& BT, float B_int8_scale, con
             if (output_transpose)
             {
                 if (top_blob.elembits() == 16)
-                    transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
+                    transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
                 else
                     transpose_unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
            }
            else
            {
                if (top_blob.elembits() == 16)
-                    unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
+                    unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
                else
                    unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
            }
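These call-site hunks (and the one below) make the bf16 output path honor the GEMM scaling constants alpha and beta, which the fp32 path already received. For reference, a minimal scalar sketch of the epilogue contract these unpack tiles implement; the names here are illustrative, not the patch's internals, and the real tiles work on NEON int32 accumulators:

    // top = alpha * (acc * descale + beta * C), per output element
    static void unpack_epilogue_ref(const int* acc, const float* C, float* top, int n,
                                    float descale, float alpha, float beta)
    {
        for (int i = 0; i < n; i++)
        {
            float f = acc[i] * descale; // dequantize the int32 accumulator
            if (C)
                f += beta * C[i];       // optional C / bias term
            top[i] = alpha * f;         // final output scale
        }
    }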
@@ -5923,14 +5923,14 @@ static int gemm_AT_BT_arm_int8(const Mat& AT, const Mat& A_int8_scales, const Ma
             if (output_transpose)
             {
                 if (top_blob.elembits() == 16)
-                    transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
+                    transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
                 else
                     transpose_unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
             }
             else
             {
                 if (top_blob.elembits() == 16)
-                    unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales);
+                    unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
                 else
                     unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta);
             }
diff --git a/src/layer/arm/gemm_arm_asimddp.cpp b/src/layer/arm/gemm_arm_asimddp.cpp
index 821e3d2812e..518fbd46c76 100644
--- a/src/layer/arm/gemm_arm_asimddp.cpp
+++ b/src/layer/arm/gemm_arm_asimddp.cpp
@@ -100,14 +100,14 @@ void transpose_pack_B_tile_bf16_to_int8_asimddp(const Mat& B, Mat& BT, int j, in
     transpose_pack_B_tile_bf16_to_int8(B, BT, j, max_jj, k, max_kk, scale);
 }
 
-void unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales)
+void unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta)
 {
-    unpack_output_tile_int32_to_bf16(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales);
+    unpack_output_tile_int32_to_bf16(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta);
 }
 
-void transpose_unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales)
+void transpose_unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta)
 {
-    transpose_unpack_output_tile_int32_to_bf16(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales);
+    transpose_unpack_output_tile_int32_to_bf16(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta);
 }
 
 #endif // NCNN_BF16
diff --git a/src/layer/arm/gemm_int8.h b/src/layer/arm/gemm_int8.h
index 42779200c34..10045d2f722 100644
--- a/src/layer/arm/gemm_int8.h
+++ b/src/layer/arm/gemm_int8.h
@@ -6007,7 +6007,6 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                 _cc0 = vmulq_f32(_cc0, _beta);
                 _cc1 = vmulq_f32(_cc1, _beta);
             }
-#if __aarch64__
             _c0 = vdupq_laneq_f32(_cc0, 0);
             _c1 = vdupq_laneq_f32(_cc0, 1);
             float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2);
@@ -6016,16 +6015,6 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
             float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1);
             float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2);
             float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3);
-#else
-            _c0 = vdupq_lane_f32(vget_low_f32(_cc0), 0);
-            _c1 = vdupq_lane_f32(vget_low_f32(_cc0), 1);
-            float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_cc0), 0);
-            float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_cc0), 1);
-            float32x4_t _c4 = vdupq_lane_f32(vget_low_f32(_cc1), 0);
-            float32x4_t _c5 = vdupq_lane_f32(vget_low_f32(_cc1), 1);
-            float32x4_t _c6 = vdupq_lane_f32(vget_high_f32(_cc1), 0);
-            float32x4_t _c7 = vdupq_lane_f32(vget_high_f32(_cc1), 1);
-#endif
             _f0 = vaddq_f32(_f0, _c0);
             _f1 = vaddq_f32(_f1, _c1);
             _f2 = vaddq_f32(_f2, _c2);
@@ -6194,23 +6183,13 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
            {
                float32x4_t _c2;
                float32x4_t _c3;
-                float32x4_t _c4;
-                float32x4_t _c5;
-                float32x4_t _c6;
-                float32x4_t _c7;
                if (c_elempack == 1)
                {
                    _c0 = vld1q_f32(pC);
                    _c1 = vld1q_f32(pC + c_hstep);
                    _c2 = vld1q_f32(pC + c_hstep * 2);
                    _c3 = vld1q_f32(pC + c_hstep * 3);
-                    _c4 = vld1q_f32(pC + c_hstep * 4);
-                    _c5 = vld1q_f32(pC + c_hstep * 5);
-                    _c6 = vld1q_f32(pC + c_hstep * 6);
-                    _c7 = vld1q_f32(pC + c_hstep * 7);
                    transpose4x4_ps(_c0, _c1, _c2, _c3);
-                    transpose4x4_ps(_c4, _c5, _c6, _c7);
-                    pC += 4;
                }
                else // if (c_elempack == 4)
                {
@@ -6218,11 +6197,6 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                    _c1 = vld1q_f32(pC + 4);
                    _c2 = vld1q_f32(pC + 8);
                    _c3 = vld1q_f32(pC + 12);
-                    _c4 = vld1q_f32(pC + c_hstep * 4);
-                    _c5 = vld1q_f32(pC + c_hstep * 4 + 4);
-                    _c6 = vld1q_f32(pC + c_hstep * 4 + 8);
-                    _c7 = vld1q_f32(pC + c_hstep * 4 + 12);
-                    pC += 16;
                }
                if (beta == 1.f)
                {
@@ -6230,10 +6204,6 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                    _f1 = vaddq_f32(_f1, _c1);
                    _f2 = vaddq_f32(_f2, _c2);
                    _f3 = vaddq_f32(_f3, _c3);
-                    _f4 = vaddq_f32(_f4, _c4);
-                    _f5 = vaddq_f32(_f5, _c5);
-                    _f6 = vaddq_f32(_f6, _c6);
-                    _f7 = vaddq_f32(_f7, _c7);
                }
                else
                {
@@ -6242,19 +6212,44 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                    _f1 = vmlaq_f32(_f1, _c1, _beta);
                    _f2 = vmlaq_f32(_f2, _c2, _beta);
                    _f3 = vmlaq_f32(_f3, _c3, _beta);
-                    _f4 = vmlaq_f32(_f4, _c4, _beta);
-                    _f5 = vmlaq_f32(_f5, _c5, _beta);
-                    _f6 = vmlaq_f32(_f6, _c6, _beta);
-                    _f7 = vmlaq_f32(_f7, _c7, _beta);
+                }
+                if (c_elempack == 1)
+                {
+                    _c0 = vld1q_f32(pC + c_hstep * 4);
+                    _c1 = vld1q_f32(pC + c_hstep * 5);
+                    _c2 = vld1q_f32(pC + c_hstep * 6);
+                    _c3 = vld1q_f32(pC + c_hstep * 7);
+                    transpose4x4_ps(_c0, _c1, _c2, _c3);
+                    pC += 4;
+                }
+                else // if (c_elempack == 4)
+                {
+                    _c0 = vld1q_f32(pC + c_hstep * 4);
+                    _c1 = vld1q_f32(pC + c_hstep * 4 + 4);
+                    _c2 = vld1q_f32(pC + c_hstep * 4 + 8);
+                    _c3 = vld1q_f32(pC + c_hstep * 4 + 12);
+                    pC += 16;
+                }
+                if (beta == 1.f)
+                {
+                    _f4 = vaddq_f32(_f4, _c0);
+                    _f5 = vaddq_f32(_f5, _c1);
+                    _f6 = vaddq_f32(_f6, _c2);
+                    _f7 = vaddq_f32(_f7, _c3);
+                }
+                else
+                {
+                    float32x4_t _beta = vdupq_n_f32(beta);
+                    _f4 = vmlaq_f32(_f4, _c0, _beta);
+                    _f5 = vmlaq_f32(_f5, _c1, _beta);
+                    _f6 = vmlaq_f32(_f6, _c2, _beta);
+                    _f7 = vmlaq_f32(_f7, _c3, _beta);
                }
            }
            if (broadcast_type_C == 4)
            {
                float32x4_t _c = vld1q_f32(pC);
-                if (beta != 1.f)
-                {
-                    _c = vmulq_n_f32(_c, beta);
-                }
+                _c = vmulq_n_f32(_c, beta);
 #if __aarch64__
                _c0 = vdupq_laneq_f32(_c, 0);
                _c1 = vdupq_laneq_f32(_c, 1);
@@ -6416,10 +6411,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
            if (broadcast_type_C == 4)
            {
                float32x2_t _c = vld1_f32(pC);
-                if (beta != 1.f)
-                {
-                    _c = vmul_n_f32(_c, beta);
-                }
+                _c = vmul_n_f32(_c, beta);
                _c0 = vdupq_lane_f32(_c, 0);
                _c1 = vdupq_lane_f32(_c, 1);
                _f0 = vaddq_f32(_f0, _c0);
@@ -6481,7 +6473,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                    _c1 = vsetq_lane_f32(pC[c_hstep * 7], _c1, 3);
                    pC += 1;
                }
-                if (c_elempack == 4)
+                else // if (c_elempack == 4)
                {
                    _c0 = vld1q_f32(pC);
                    _c1 = vld1q_f32(pC + c_hstep * 4);
@@ -6727,7 +6719,6 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
            }
            if (broadcast_type_C == 1 || broadcast_type_C == 2)
            {
-#if __aarch64__
                float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0);
                float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1);
                float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2);
@@ -6736,16 +6727,6 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1);
                float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2);
                float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3);
-#else
-                float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0);
-                float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1);
-                float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0);
-                float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1);
-                float32x4_t _cc4 = vdupq_lane_f32(vget_low_f32(_c1), 0);
-                float32x4_t _cc5 = vdupq_lane_f32(vget_low_f32(_c1), 1);
-                float32x4_t _cc6 = vdupq_lane_f32(vget_high_f32(_c1), 0);
-                float32x4_t _cc7 = vdupq_lane_f32(vget_high_f32(_c1), 1);
-#endif
                _f0 = vaddq_f32(_f0, _cc0);
                _f1 = vaddq_f32(_f1, _cc0);
                _f2 = vaddq_f32(_f2, _cc1);
@@ -7087,28 +7068,31 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1);
                float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2);
                float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3);
-                float32x4_t _cc4 = vdupq_laneq_f32(_c1, 0);
-                float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1);
-                float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2);
-                float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3);
 #else
                float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0);
                float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1);
                float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0);
                float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1);
-                float32x4_t _cc4 = vdupq_lane_f32(vget_low_f32(_c1), 0);
-                float32x4_t _cc5 = vdupq_lane_f32(vget_low_f32(_c1), 1);
-                float32x4_t _cc6 = vdupq_lane_f32(vget_high_f32(_c1), 0);
-                float32x4_t _cc7 = vdupq_lane_f32(vget_high_f32(_c1), 1);
 #endif
                _f0 = vaddq_f32(_f0, _cc0);
                _f1 = vaddq_f32(_f1, _cc1);
                _f2 = vaddq_f32(_f2, _cc2);
                _f3 = vaddq_f32(_f3, _cc3);
-                _f4 = vaddq_f32(_f4, _cc4);
-                _f5 = vaddq_f32(_f5, _cc5);
-                _f6 = vaddq_f32(_f6, _cc6);
-                _f7 = vaddq_f32(_f7, _cc7);
+#if __aarch64__
+                _cc0 = vdupq_laneq_f32(_c1, 0);
+                _cc1 = vdupq_laneq_f32(_c1, 1);
+                _cc2 = vdupq_laneq_f32(_c1, 2);
+                _cc3 = vdupq_laneq_f32(_c1, 3);
+#else
+                _cc0 = vdupq_lane_f32(vget_low_f32(_c1), 0);
+                _cc1 = vdupq_lane_f32(vget_low_f32(_c1), 1);
+                _cc2 = vdupq_lane_f32(vget_high_f32(_c1), 0);
+                _cc3 = vdupq_lane_f32(vget_high_f32(_c1), 1);
+#endif
+                _f4 = vaddq_f32(_f4, _cc0);
+                _f5 = vaddq_f32(_f5, _cc1);
+                _f6 = vaddq_f32(_f6, _cc2);
+                _f7 = vaddq_f32(_f7, _cc3);
            }
            if (broadcast_type_C == 3)
            {
@@ -7118,20 +7102,12 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                    _c1 = vld1q_f32(pC + c_hstep);
                    float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2);
                    float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3);
-                    float32x4_t _c4 = vld1q_f32(pC + c_hstep * 4);
-                    float32x4_t _c5 = vld1q_f32(pC + c_hstep * 5);
-                    float32x4_t _c6 = vld1q_f32(pC + c_hstep * 6);
-                    float32x4_t _c7 = vld1q_f32(pC + c_hstep * 7);
                    if (beta == 1.f)
                    {
                        _f0 = vaddq_f32(_f0, _c0);
                        _f1 = vaddq_f32(_f1, _c1);
                        _f2 = vaddq_f32(_f2, _c2);
                        _f3 = vaddq_f32(_f3, _c3);
-                        _f4 = vaddq_f32(_f4, _c4);
-                        _f5 = vaddq_f32(_f5, _c5);
-                        _f6 = vaddq_f32(_f6, _c6);
-                        _f7 = vaddq_f32(_f7, _c7);
                    }
                    else
                    {
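The recurring restructure in these hunks: instead of keeping eight C vectors (_c0.._c7, or a second vld4q result) live across the whole accumulation, the tile now loads and consumes rows 0-3, then reuses the same four registers for rows 4-7. A sketch of that pattern under the same beta semantics (hypothetical helper, not part of the patch):

    #include <arm_neon.h>

    // Accumulate two groups of four C rows into f[0..7] without holding
    // eight C vectors live at once; branch once per group on beta == 1.
    static inline void accumulate_c_two_pass(float32x4_t* f, const float* pC,
                                             size_t rowstep, float beta)
    {
        for (int pass = 0; pass < 2; pass++)
        {
            float32x4_t c0 = vld1q_f32(pC);
            float32x4_t c1 = vld1q_f32(pC + rowstep);
            float32x4_t c2 = vld1q_f32(pC + rowstep * 2);
            float32x4_t c3 = vld1q_f32(pC + rowstep * 3);
            if (beta == 1.f)
            {
                f[0] = vaddq_f32(f[0], c0);
                f[1] = vaddq_f32(f[1], c1);
                f[2] = vaddq_f32(f[2], c2);
                f[3] = vaddq_f32(f[3], c3);
            }
            else
            {
                float32x4_t b = vdupq_n_f32(beta); // splat beta once per group
                f[0] = vmlaq_f32(f[0], c0, b);
                f[1] = vmlaq_f32(f[1], c1, b);
                f[2] = vmlaq_f32(f[2], c2, b);
                f[3] = vmlaq_f32(f[3], c3, b);
            }
            f += 4;
            pC += rowstep * 4;
        }
    }

Halving the peak number of live q-registers in the epilogue matters most on AArch32 (16 q-registers) and reduces spills around the larger tiles.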
@@ -7140,27 +7116,37 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                        _f1 = vmlaq_f32(_f1, _c1, _beta);
                        _f2 = vmlaq_f32(_f2, _c2, _beta);
                        _f3 = vmlaq_f32(_f3, _c3, _beta);
-                        _f4 = vmlaq_f32(_f4, _c4, _beta);
-                        _f5 = vmlaq_f32(_f5, _c5, _beta);
-                        _f6 = vmlaq_f32(_f6, _c6, _beta);
-                        _f7 = vmlaq_f32(_f7, _c7, _beta);
+                    }
+                    _c0 = vld1q_f32(pC + c_hstep * 4);
+                    _c1 = vld1q_f32(pC + c_hstep * 5);
+                    _c2 = vld1q_f32(pC + c_hstep * 6);
+                    _c3 = vld1q_f32(pC + c_hstep * 7);
+                    if (beta == 1.f)
+                    {
+                        _f4 = vaddq_f32(_f4, _c0);
+                        _f5 = vaddq_f32(_f5, _c1);
+                        _f6 = vaddq_f32(_f6, _c2);
+                        _f7 = vaddq_f32(_f7, _c3);
+                    }
+                    else
+                    {
+                        float32x4_t _beta = vdupq_n_f32(beta);
+                        _f4 = vmlaq_f32(_f4, _c0, _beta);
+                        _f5 = vmlaq_f32(_f5, _c1, _beta);
+                        _f6 = vmlaq_f32(_f6, _c2, _beta);
+                        _f7 = vmlaq_f32(_f7, _c3, _beta);
                    }
                    pC += 4;
                }
                else // if (c_elempack == 4)
                {
                    float32x4x4_t _cc0 = vld4q_f32(pC);
-                    float32x4x4_t _cc1 = vld4q_f32(pC + c_hstep * 4);
                    if (beta == 1.f)
                    {
                        _f0 = vaddq_f32(_f0, _cc0.val[0]);
                        _f1 = vaddq_f32(_f1, _cc0.val[1]);
                        _f2 = vaddq_f32(_f2, _cc0.val[2]);
                        _f3 = vaddq_f32(_f3, _cc0.val[3]);
-                        _f4 = vaddq_f32(_f4, _cc1.val[0]);
-                        _f5 = vaddq_f32(_f5, _cc1.val[1]);
-                        _f6 = vaddq_f32(_f6, _cc1.val[2]);
-                        _f7 = vaddq_f32(_f7, _cc1.val[3]);
                    }
                    else
                    {
@@ -7169,10 +7155,22 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                        _f1 = vmlaq_f32(_f1, _cc0.val[1], _beta);
                        _f2 = vmlaq_f32(_f2, _cc0.val[2], _beta);
                        _f3 = vmlaq_f32(_f3, _cc0.val[3], _beta);
-                        _f4 = vmlaq_f32(_f4, _cc1.val[0], _beta);
-                        _f5 = vmlaq_f32(_f5, _cc1.val[1], _beta);
-                        _f6 = vmlaq_f32(_f6, _cc1.val[2], _beta);
-                        _f7 = vmlaq_f32(_f7, _cc1.val[3], _beta);
+                    }
+                    _cc0 = vld4q_f32(pC + c_hstep * 4);
+                    if (beta == 1.f)
+                    {
+                        _f4 = vaddq_f32(_f4, _cc0.val[0]);
+                        _f5 = vaddq_f32(_f5, _cc0.val[1]);
+                        _f6 = vaddq_f32(_f6, _cc0.val[2]);
+                        _f7 = vaddq_f32(_f7, _cc0.val[3]);
+                    }
+                    else
+                    {
+                        float32x4_t _beta = vdupq_n_f32(beta);
+                        _f4 = vmlaq_f32(_f4, _cc0.val[0], _beta);
+                        _f5 = vmlaq_f32(_f5, _cc0.val[1], _beta);
+                        _f6 = vmlaq_f32(_f6, _cc0.val[2], _beta);
+                        _f7 = vmlaq_f32(_f7, _cc0.val[3], _beta);
                    }
                    pC += 16;
                }
@@ -7303,12 +7301,12 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                    float32x2_t _cc1 = vld1_f32(pC + c_hstep);
                    float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2);
                    float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3);
+                    _c0 = vcombine_f32(_cc0, _cc1);
+                    _c1 = vcombine_f32(_cc2, _cc3);
                    float32x2_t _cc4 = vld1_f32(pC + c_hstep * 4);
                    float32x2_t _cc5 = vld1_f32(pC + c_hstep * 5);
                    float32x2_t _cc6 = vld1_f32(pC + c_hstep * 6);
                    float32x2_t _cc7 = vld1_f32(pC + c_hstep * 7);
-                    _c0 = vcombine_f32(_cc0, _cc1);
-                    _c1 = vcombine_f32(_cc2, _cc3);
                    float32x4_t _c2 = vcombine_f32(_cc4, _cc5);
                    float32x4_t _c3 = vcombine_f32(_cc6, _cc7);
                    if (beta == 1.f)
@@ -7673,7 +7671,6 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                _cc0 = vmulq_f32(_cc0, _beta);
                _cc1 = vmulq_f32(_cc1, _beta);
            }
-#if __aarch64__
            _c0 = vdupq_laneq_f32(_cc0, 0);
            float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1);
            float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2);
@@ -7682,16 +7679,6 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
            float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1);
            float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2);
            float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3);
-#else
-            _c0 = vdupq_lane_f32(vget_low_f32(_cc0), 0);
-            float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_cc0), 1);
-            float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_cc0), 0);
-            float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_cc0), 1);
-            float32x4_t _c4 = vdupq_lane_f32(vget_low_f32(_cc1), 0);
-            float32x4_t _c5 = vdupq_lane_f32(vget_low_f32(_cc1), 1);
-            float32x4_t _c6 = vdupq_lane_f32(vget_high_f32(_cc1), 0);
-            float32x4_t _c7 = vdupq_lane_f32(vget_high_f32(_cc1), 1);
-#endif
            _f0 = vaddq_f32(_f0, _c0);
            _f1 = vaddq_f32(_f1, _c1);
            _f2 = vaddq_f32(_f2, _c2);
@@ -7833,10 +7820,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
            if (broadcast_type_C == 4)
            {
                float32x4_t _c = vld1q_f32(pC);
-                if (beta != 1.f)
-                {
-                    _c = vmulq_n_f32(_c, beta);
-                }
+                _c = vmulq_n_f32(_c, beta);
 #if __aarch64__
                _c0 = vdupq_laneq_f32(_c, 0);
                float32x4_t _c1 = vdupq_laneq_f32(_c, 1);
@@ -7951,10 +7935,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
            if (broadcast_type_C == 4)
            {
                float32x2_t _c = vld1_f32(pC);
-                if (beta != 1.f)
-                {
-                    _c = vmul_n_f32(_c, beta);
-                }
+                _c = vmul_n_f32(_c, beta);
                _c0 = vdupq_lane_f32(_c, 0);
                float32x4_t _c1 = vdupq_lane_f32(_c, 1);
                _f0 = vaddq_f32(_f0, _c0);
@@ -8000,19 +7981,12 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                    _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3);
                    pC += 1;
                }
-                if (c_elempack == 4)
+                else // if (c_elempack == 4)
                {
                    _c0 = vld1q_f32(pC);
                    pC += 4;
                }
-                if (beta == 1.f)
-                {
-                    _f0 = vaddq_f32(_f0, _c0);
-                }
-                else
-                {
-                    _f0 = vmlaq_n_f32(_f0, _c0, beta);
-                }
+                _f0 = vmlaq_n_f32(_f0, _c0, beta);
            }
            if (broadcast_type_C == 4)
            {
@@ -8022,10 +7996,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                }
            }
 
-            if (alpha != 1.f)
-            {
-                _f0 = vmulq_n_f32(_f0, alpha);
-            }
+            _f0 = vmulq_n_f32(_f0, alpha);
 
            vst1q_f32(p0, _f0);
 
@@ -8150,17 +8121,10 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
            }
            if (broadcast_type_C == 1 || broadcast_type_C == 2)
            {
-#if __aarch64__
                float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0);
                float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1);
                float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2);
                float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3);
-#else
-                float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0);
-                float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1);
-                float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0);
-                float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1);
-#endif
                _f0 = vaddq_f32(_f0, _cc0);
                _f1 = vaddq_f32(_f1, _cc0);
                _f2 = vaddq_f32(_f2, _cc1);
@@ -8528,7 +8492,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                    }
                    pC += 2;
                }
-                if (c_elempack == 4)
+                else // if (c_elempack == 4)
                {
                    _c0 = vld1q_f32(pC);
                    float32x4_t _c1 = vld1q_f32(pC + 4);
@@ -8597,19 +8561,12 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                    _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3);
                    pC += 1;
                }
-                if (c_elempack == 4)
+                else // if (c_elempack == 4)
                {
                    _c0 = vld1q_f32(pC);
                    pC += 4;
                }
-                if (beta == 1.f)
-                {
-                    _f0 = vaddq_f32(_f0, _c0);
-                }
-                else
-                {
-                    _f0 = vmlaq_n_f32(_f0, _c0, beta);
-                }
+                _f0 = vmlaq_n_f32(_f0, _c0, beta);
            }
            if (broadcast_type_C == 4)
            {
@@ -8619,10 +8576,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                }
            }
 
-            if (alpha != 1.f)
-            {
-                _f0 = vmulq_n_f32(_f0, alpha);
-            }
+            _f0 = vmulq_n_f32(_f0, alpha);
 
            p0[0] = vgetq_lane_f32(_f0, 0);
            p0[out_hstep] = vgetq_lane_f32(_f0, 1);
@@ -8936,39 +8890,20 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                if (broadcast_type_C == 3)
                {
                    // c_elempack == 1
-                    if (beta == 1.f)
-                    {
-                        f0 += pC[0];
-                        f1 += pC[c_hstep];
-                    }
-                    else
-                    {
-                        f0 += pC[0] * beta;
-                        f1 += pC[c_hstep] * beta;
-                    }
+                    f0 += pC[0] * beta;
+                    f1 += pC[c_hstep] * beta;
                    pC += 1;
                }
                if (broadcast_type_C == 4)
                {
-                    if (beta == 1.f)
-                    {
-                        f0 += pC[0];
-                        f1 += pC[0];
-                    }
-                    else
-                    {
-                        f0 += pC[0] * beta;
-                        f1 += pC[0] * beta;
-                    }
+                    f0 += pC[0] * beta;
+                    f1 += pC[0] * beta;
                    pC += 1;
                }
            }
 
-            if (alpha != 1.f)
-            {
-                f0 *= alpha;
-                f1 *= alpha;
-            }
+            f0 *= alpha;
+            f1 *= alpha;
 
            p0[0] = f0;
            p0[out_hstep] = f1;
@@ -9150,22 +9085,12 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                {
                    // out_elempack == 1
                    _c0 = vld1q_f32(pC);
-                    if (beta == 1.f)
-                    {
-                        _f0 = vaddq_f32(_f0, _c0);
-                    }
-                    else
-                    {
-                        _f0 = vmlaq_n_f32(_f0, _c0, beta);
-                    }
+                    _f0 = vmlaq_n_f32(_f0, _c0, beta);
                    pC += 4;
                }
            }
 
-            if (alpha != 1.f)
-            {
-                _f0 = vmulq_n_f32(_f0, alpha);
-            }
+            _f0 = vmulq_n_f32(_f0, alpha);
 
            vst1q_f32(p0, _f0);
 
@@ -9186,22 +9111,12 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
                {
                    // out_elempack == 1
                    float32x2_t _c = vld1_f32(pC);
-                    if (beta == 1.f)
-                    {
-                        _f0 = vadd_f32(_f0, _c);
-                    }
-                    else
-                    {
-                        _f0 = vmla_n_f32(_f0, _c, beta);
-                    }
+                    _f0 = vmla_n_f32(_f0, _c, beta);
                    pC += 2;
                }
            }
 
-            if (alpha != 1.f)
-            {
-                _f0 = vmul_n_f32(_f0, alpha);
-            }
+            _f0 = vmul_n_f32(_f0, alpha);
 
            vst1_f32(p0, _f0);
 
@@ -9597,8 +9512,6 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
                {
                    float32x4x4_t _cc0 = vld4q_f32(pC);
                    float32x4x4_t _cc1 = vld4q_f32(pC + 16);
-                    float32x4x4_t _cc2 = vld4q_f32(pC + c_hstep * 4);
-                    float32x4x4_t _cc3 = vld4q_f32(pC + c_hstep * 4 + 16);
                    if (beta == 1.f)
                    {
                        _f0 = vaddq_f32(_f0, _cc0.val[0]);
@@ -9609,14 +9522,6 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
                        _f5 = vaddq_f32(_f5, _cc1.val[2]);
                        _f6 = vaddq_f32(_f6, _cc0.val[3]);
                        _f7 = vaddq_f32(_f7, _cc1.val[3]);
-                        _f8 = vaddq_f32(_f8, _cc2.val[0]);
-                        _f9 = vaddq_f32(_f9, _cc3.val[0]);
-                        _fa = vaddq_f32(_fa, _cc2.val[1]);
-                        _fb = vaddq_f32(_fb, _cc3.val[1]);
-                        _fc = vaddq_f32(_fc, _cc2.val[2]);
-                        _fd = vaddq_f32(_fd, _cc3.val[2]);
-                        _fe = vaddq_f32(_fe, _cc2.val[3]);
-                        _ff = vaddq_f32(_ff, _cc3.val[3]);
                    }
                    else
                    {
@@ -9629,14 +9534,31 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
                        _f5 = vmlaq_f32(_f5, _cc1.val[2], _beta);
                        _f6 = vmlaq_f32(_f6, _cc0.val[3], _beta);
                        _f7 = vmlaq_f32(_f7, _cc1.val[3], _beta);
-                        _f8 = vmlaq_f32(_f8, _cc2.val[0], _beta);
-                        _f9 = vmlaq_f32(_f9, _cc3.val[0], _beta);
-                        _fa = vmlaq_f32(_fa, _cc2.val[1], _beta);
-                        _fb = vmlaq_f32(_fb, _cc3.val[1], _beta);
-                        _fc = vmlaq_f32(_fc, _cc2.val[2], _beta);
-                        _fd = vmlaq_f32(_fd, _cc3.val[2], _beta);
-                        _fe = vmlaq_f32(_fe, _cc2.val[3], _beta);
-                        _ff = vmlaq_f32(_ff, _cc3.val[3], _beta);
+                    }
+                    _cc0 = vld4q_f32(pC + c_hstep * 4);
+                    _cc1 = vld4q_f32(pC + c_hstep * 4 + 16);
+                    if (beta == 1.f)
+                    {
+                        _f8 = vaddq_f32(_f8, _cc0.val[0]);
+                        _f9 = vaddq_f32(_f9, _cc1.val[0]);
+                        _fa = vaddq_f32(_fa, _cc0.val[1]);
+                        _fb = vaddq_f32(_fb, _cc1.val[1]);
+                        _fc = vaddq_f32(_fc, _cc0.val[2]);
+                        _fd = vaddq_f32(_fd, _cc1.val[2]);
+                        _fe = vaddq_f32(_fe, _cc0.val[3]);
+                        _ff = vaddq_f32(_ff, _cc1.val[3]);
+                    }
+                    else
+                    {
+                        float32x4_t _beta = vdupq_n_f32(beta);
+                        _f8 = vmlaq_f32(_f8, _cc0.val[0], _beta);
+                        _f9 = vmlaq_f32(_f9, _cc1.val[0], _beta);
+                        _fa = vmlaq_f32(_fa, _cc0.val[1], _beta);
+                        _fb = vmlaq_f32(_fb, _cc1.val[1], _beta);
+                        _fc = vmlaq_f32(_fc, _cc0.val[2], _beta);
+                        _fd = vmlaq_f32(_fd, _cc1.val[2], _beta);
+                        _fe = vmlaq_f32(_fe, _cc0.val[3], _beta);
+                        _ff = vmlaq_f32(_ff, _cc1.val[3], _beta);
                    }
                    pC += 32;
                }
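Note the asymmetry the patch settles on: the scalar and narrow tails above drop their beta == 1.f and alpha != 1.f special cases and always multiply, e.g.

    // before                                // after
    if (beta == 1.f) f0 += pC[0];            f0 += pC[0] * beta;
    else             f0 += pC[0] * beta;
    if (alpha != 1.f) f0 *= alpha;           f0 *= alpha;

For a one- or two-element tail the unconditional form costs at worst one extra fmul and removes a branch; the wide vector paths keep the beta == 1.f branch because there it saves a vdupq_n_f32 plus a multiply per live register.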
@@ -9842,28 +9764,31 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
                float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1);
                float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2);
                float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3);
-                float32x4_t _cc4 = vdupq_laneq_f32(_c1, 0);
-                float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1);
-                float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2);
-                float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3);
 #else
                float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0);
                float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1);
                float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0);
                float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1);
-                float32x4_t _cc4 = vdupq_lane_f32(vget_low_f32(_c1), 0);
-                float32x4_t _cc5 = vdupq_lane_f32(vget_low_f32(_c1), 1);
-                float32x4_t _cc6 = vdupq_lane_f32(vget_high_f32(_c1), 0);
-                float32x4_t _cc7 = vdupq_lane_f32(vget_high_f32(_c1), 1);
 #endif
                _f0 = vaddq_f32(_f0, _cc0);
                _f1 = vaddq_f32(_f1, _cc1);
                _f2 = vaddq_f32(_f2, _cc2);
                _f3 = vaddq_f32(_f3, _cc3);
-                _f4 = vaddq_f32(_f4, _cc4);
-                _f5 = vaddq_f32(_f5, _cc5);
-                _f6 = vaddq_f32(_f6, _cc6);
-                _f7 = vaddq_f32(_f7, _cc7);
+#if __aarch64__
+                _cc0 = vdupq_laneq_f32(_c1, 0);
+                _cc1 = vdupq_laneq_f32(_c1, 1);
+                _cc2 = vdupq_laneq_f32(_c1, 2);
+                _cc3 = vdupq_laneq_f32(_c1, 3);
+#else
+                _cc0 = vdupq_lane_f32(vget_low_f32(_c1), 0);
+                _cc1 = vdupq_lane_f32(vget_low_f32(_c1), 1);
+                _cc2 = vdupq_lane_f32(vget_high_f32(_c1), 0);
+                _cc3 = vdupq_lane_f32(vget_high_f32(_c1), 1);
+#endif
+                _f4 = vaddq_f32(_f4, _cc0);
+                _f5 = vaddq_f32(_f5, _cc1);
+                _f6 = vaddq_f32(_f6, _cc2);
+                _f7 = vaddq_f32(_f7, _cc3);
            }
            if (broadcast_type_C == 3)
            {
@@ -9873,20 +9798,12 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
                    _c1 = vld1q_f32(pC + c_hstep);
                    float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2);
                    float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3);
-                    float32x4_t _c4 = vld1q_f32(pC + c_hstep * 4);
-                    float32x4_t _c5 = vld1q_f32(pC + c_hstep * 5);
-                    float32x4_t _c6 = vld1q_f32(pC + c_hstep * 6);
-                    float32x4_t _c7 = vld1q_f32(pC + c_hstep * 7);
                    if (beta == 1.f)
                    {
                        _f0 = vaddq_f32(_f0, _c0);
                        _f1 = vaddq_f32(_f1, _c1);
                        _f2 = vaddq_f32(_f2, _c2);
                        _f3 = vaddq_f32(_f3, _c3);
-                        _f4 = vaddq_f32(_f4, _c4);
-                        _f5 = vaddq_f32(_f5, _c5);
-                        _f6 = vaddq_f32(_f6, _c6);
-                        _f7 = vaddq_f32(_f7, _c7);
                    }
                    else
                    {
@@ -9895,27 +9812,37 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
                        _f1 = vmlaq_f32(_f1, _c1, _beta);
                        _f2 = vmlaq_f32(_f2, _c2, _beta);
                        _f3 = vmlaq_f32(_f3, _c3, _beta);
-                        _f4 = vmlaq_f32(_f4, _c4, _beta);
-                        _f5 = vmlaq_f32(_f5, _c5, _beta);
-                        _f6 = vmlaq_f32(_f6, _c6, _beta);
-                        _f7 = vmlaq_f32(_f7, _c7, _beta);
+                    }
+                    _c0 = vld1q_f32(pC + c_hstep * 4);
+                    _c1 = vld1q_f32(pC + c_hstep * 5);
+                    _c2 = vld1q_f32(pC + c_hstep * 6);
+                    _c3 = vld1q_f32(pC + c_hstep * 7);
+                    if (beta == 1.f)
+                    {
+                        _f4 = vaddq_f32(_f4, _c0);
+                        _f5 = vaddq_f32(_f5, _c1);
+                        _f6 = vaddq_f32(_f6, _c2);
+                        _f7 = vaddq_f32(_f7, _c3);
+                    }
+                    else
+                    {
+                        float32x4_t _beta = vdupq_n_f32(beta);
+                        _f4 = vmlaq_f32(_f4, _c0, _beta);
+                        _f5 = vmlaq_f32(_f5, _c1, _beta);
+                        _f6 = vmlaq_f32(_f6, _c2, _beta);
+                        _f7 = vmlaq_f32(_f7, _c3, _beta);
                    }
                    pC += 4;
                }
                else // if (c_elempack == 4)
                {
                    float32x4x4_t _cc0 = vld4q_f32(pC);
-                    float32x4x4_t _cc1 = vld4q_f32(pC + c_hstep * 4);
                    if (beta == 1.f)
                    {
                        _f0 = vaddq_f32(_f0, _cc0.val[0]);
                        _f1 = vaddq_f32(_f1, _cc0.val[1]);
                        _f2 = vaddq_f32(_f2, _cc0.val[2]);
                        _f3 = vaddq_f32(_f3, _cc0.val[3]);
-                        _f4 = vaddq_f32(_f4, _cc1.val[0]);
-                        _f5 = vaddq_f32(_f5, _cc1.val[1]);
-                        _f6 = vaddq_f32(_f6, _cc1.val[2]);
-                        _f7 = vaddq_f32(_f7, _cc1.val[3]);
                    }
                    else
                    {
@@ -9924,10 +9851,22 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
                        _f1 = vmlaq_f32(_f1, _cc0.val[1], _beta);
                        _f2 = vmlaq_f32(_f2, _cc0.val[2], _beta);
                        _f3 = vmlaq_f32(_f3, _cc0.val[3], _beta);
-                        _f4 = vmlaq_f32(_f4, _cc1.val[0], _beta);
-                        _f5 = vmlaq_f32(_f5, _cc1.val[1], _beta);
-                        _f6 = vmlaq_f32(_f6, _cc1.val[2], _beta);
-                        _f7 = vmlaq_f32(_f7, _cc1.val[3], _beta);
+                    }
+                    _cc0 = vld4q_f32(pC + c_hstep * 4);
+                    if (beta == 1.f)
+                    {
+                        _f4 = vaddq_f32(_f4, _cc0.val[0]);
+                        _f5 = vaddq_f32(_f5, _cc0.val[1]);
+                        _f6 = vaddq_f32(_f6, _cc0.val[2]);
+                        _f7 = vaddq_f32(_f7, _cc0.val[3]);
+                    }
+                    else
+                    {
+                        float32x4_t _beta = vdupq_n_f32(beta);
+                        _f4 = vmlaq_f32(_f4, _cc0.val[0], _beta);
+                        _f5 = vmlaq_f32(_f5, _cc0.val[1], _beta);
+                        _f6 = vmlaq_f32(_f6, _cc0.val[2], _beta);
+                        _f7 = vmlaq_f32(_f7, _cc0.val[3], _beta);
                    }
                    pC += 16;
                }
@@ -9935,10 +9874,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
            if (broadcast_type_C == 4)
            {
                _c0 = vld1q_f32(pC);
-                if (beta != 1.f)
-                {
-                    _c0 = vmulq_n_f32(_c0, beta);
-                }
+                _c0 = vmulq_n_f32(_c0, beta);
                _f0 = vaddq_f32(_f0, _c0);
                _f1 = vaddq_f32(_f1, _c0);
                _f2 = vaddq_f32(_f2, _c0);
@@ -10201,6 +10137,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
            {
                if (c_elempack == 1)
                {
+                    // TODO decompose 8x8 to 8x4 and 8x4
                    _c0 = vld1q_f32(pC);
                    _c1 = vld1q_f32(pC + 4);
                    float32x4_t _c2 = vld1q_f32(pC + c_hstep);
@@ -10336,7 +10273,6 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
                _cc0 = vmulq_f32(_cc0, _beta);
                _cc1 = vmulq_f32(_cc1, _beta);
            }
-#if __aarch64__
            _c0 = vdupq_laneq_f32(_cc0, 0);
            _c1 = vdupq_laneq_f32(_cc0, 1);
            float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2);
@@ -10345,16 +10281,6 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
            float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1);
            float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2);
            float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3);
-#else
-            _c0 = vdupq_lane_f32(vget_low_f32(_cc0), 0);
-            _c1 = vdupq_lane_f32(vget_low_f32(_cc0), 1);
-            float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_cc0), 0);
-            float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_cc0), 1);
-            float32x4_t _c4 = vdupq_lane_f32(vget_low_f32(_cc1), 0);
-            float32x4_t _c5 = vdupq_lane_f32(vget_low_f32(_cc1), 1);
-            float32x4_t _c6 = vdupq_lane_f32(vget_high_f32(_cc1), 0);
-            float32x4_t _c7 = vdupq_lane_f32(vget_high_f32(_cc1), 1);
-#endif
            _f0 = vaddq_f32(_f0, _c0);
            _f1 = vaddq_f32(_f1, _c1);
            _f2 = vaddq_f32(_f2, _c2);
@@ -10525,6 +10451,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
            {
                if (c_elempack == 1)
                {
+                    // TODO decompose 4x8 to 4x4 and 4x4
                    _c0 = vld1q_f32(pC);
                    _c1 = vld1q_f32(pC + c_hstep);
                    float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2);
@@ -10565,20 +10492,12 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
                    _c1 = vld1q_f32(pC + 4);
                    float32x4_t _c2 = vld1q_f32(pC + 8);
                    float32x4_t _c3 = vld1q_f32(pC + 12);
-                    float32x4_t _c4 = vld1q_f32(pC + c_hstep * 4);
-                    float32x4_t _c5 = vld1q_f32(pC + c_hstep * 4 + 4);
-                    float32x4_t _c6 = vld1q_f32(pC + c_hstep * 4 + 8);
-                    float32x4_t _c7 = vld1q_f32(pC + c_hstep * 4 + 12);
                    if (beta == 1.f)
                    {
                        _f0 = vaddq_f32(_f0, _c0);
                        _f1 = vaddq_f32(_f1, _c1);
                        _f2 = vaddq_f32(_f2, _c2);
                        _f3 = vaddq_f32(_f3, _c3);
-                        _f4 = vaddq_f32(_f4, _c4);
-                        _f5 = vaddq_f32(_f5, _c5);
-                        _f6 = vaddq_f32(_f6, _c6);
-                        _f7 = vaddq_f32(_f7, _c7);
                    }
                    else
                    {
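A side note on the c_elempack == 4 branches above: vld4q_f32 de-interleaves as it loads, so the packed pack-4 C block arrives already "transposed" into per-channel rows and no explicit shuffle is needed. A sketch of that behavior (standard NEON semantics, not patch-specific code):

    // vld4q_f32 over 16 consecutive floats a0 b0 c0 d0 a1 b1 c1 d1 ...
    // yields val[0] = {a0,a1,a2,a3}, val[1] = {b0,b1,b2,b3}, etc.
    float32x4x4_t c = vld4q_f32(pC);
    // c.val[k] is channel k of the pack-4 block across four columns,
    // ready to be accumulated straight into _f0.._f3.

The new "TODO decompose ..." comments flag the remaining c_elempack == 1 paths that still load all eight C rows up front; presumably they are to be split into two four-row halves like the paths already rewritten in this patch.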
@@ -10587,10 +10506,25 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
                        _f1 = vmlaq_f32(_f1, _c1, _beta);
                        _f2 = vmlaq_f32(_f2, _c2, _beta);
                        _f3 = vmlaq_f32(_f3, _c3, _beta);
-                        _f4 = vmlaq_f32(_f4, _c4, _beta);
-                        _f5 = vmlaq_f32(_f5, _c5, _beta);
-                        _f6 = vmlaq_f32(_f6, _c6, _beta);
-                        _f7 = vmlaq_f32(_f7, _c7, _beta);
+                    }
+                    _c0 = vld1q_f32(pC + c_hstep * 4);
+                    _c1 = vld1q_f32(pC + c_hstep * 4 + 4);
+                    _c2 = vld1q_f32(pC + c_hstep * 4 + 8);
+                    _c3 = vld1q_f32(pC + c_hstep * 4 + 12);
+                    if (beta == 1.f)
+                    {
+                        _f4 = vaddq_f32(_f4, _c0);
+                        _f5 = vaddq_f32(_f5, _c1);
+                        _f6 = vaddq_f32(_f6, _c2);
+                        _f7 = vaddq_f32(_f7, _c3);
+                    }
+                    else
+                    {
+                        float32x4_t _beta = vdupq_n_f32(beta);
+                        _f4 = vmlaq_f32(_f4, _c0, _beta);
+                        _f5 = vmlaq_f32(_f5, _c1, _beta);
+                        _f6 = vmlaq_f32(_f6, _c2, _beta);
+                        _f7 = vmlaq_f32(_f7, _c3, _beta);
                    }
                    pC += 16;
                }
@@ -10598,10 +10532,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
            if (broadcast_type_C == 4)
            {
                float32x4_t _cc = vld1q_f32(pC);
-                if (beta != 1.f)
-                {
-                    _cc = vmulq_n_f32(_cc, beta);
-                }
+                _cc = vmulq_n_f32(_cc, beta);
 #if __aarch64__
                _c0 = vdupq_laneq_f32(_cc, 0);
                _c1 = vdupq_laneq_f32(_cc, 1);
@@ -10715,7 +10646,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
                if (c_elempack == 1)
                {
                    float32x2_t _cc0 = vld1_f32(pC);
-                    float32x2_t _cc1 = vld1_f32(pC + c_hstep * 1);
+                    float32x2_t _cc1 = vld1_f32(pC + c_hstep);
                    float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2);
                    float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3);
                    float32x2_t _cc4 = vld1_f32(pC + c_hstep * 4);
@@ -10772,10 +10703,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
            if (broadcast_type_C == 4)
            {
                float32x2_t _cc = vld1_f32(pC);
-                if (beta != 1.f)
-                {
-                    _cc = vmul_n_f32(_cc, beta);
-                }
+                _cc = vmul_n_f32(_cc, beta);
                _c0 = vdupq_lane_f32(_cc, 0);
                _c1 = vdupq_lane_f32(_cc, 1);
                _f0 = vaddq_f32(_f0, _c0);
@@ -11291,10 +11219,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
            if (broadcast_type_C == 4)
            {
                _c0 = vld1q_f32(pC);
-                if (beta != 1.f)
-                {
-                    _c0 = vmulq_n_f32(_c0, beta);
-                }
+                _c0 = vmulq_n_f32(_c0, beta);
                _f0 = vaddq_f32(_f0, _c0);
                _f1 = vaddq_f32(_f1, _c0);
                _f2 = vaddq_f32(_f2, _c0);
@@ -11489,7 +11414,12 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
            {
                float32x4_t _cc0 = vld1q_f32(pC);
                float32x4_t _cc1 = vld1q_f32(pC + 4);
-#if __aarch64__
+                if (beta != 1.f)
+                {
+                    float32x4_t _beta = vdupq_n_f32(beta);
+                    _cc0 = vmulq_f32(_cc0, _beta);
+                    _cc1 = vmulq_f32(_cc1, _beta);
+                }
                _c0 = vdupq_laneq_f32(_cc0, 0);
                float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1);
                float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2);
@@ -11498,16 +11428,6 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
                float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1);
                float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2);
                float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3);
-#else
-                _c0 = vdupq_lane_f32(vget_low_f32(_cc0), 0);
-                float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_cc0), 1);
-                float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_cc0), 0);
-                float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_cc0), 1);
-                float32x4_t _c4 = vdupq_lane_f32(vget_low_f32(_cc1), 0);
-                float32x4_t _c5 = vdupq_lane_f32(vget_low_f32(_cc1), 1);
-                float32x4_t _c6 = vdupq_lane_f32(vget_high_f32(_cc1), 0);
-                float32x4_t _c7 = vdupq_lane_f32(vget_high_f32(_cc1), 1);
-#endif
                _f0 = vaddq_f32(_f0, _c0);
                _f1 = vaddq_f32(_f1, _c1);
                _f2 = vaddq_f32(_f2, _c2);
@@ -11649,10 +11569,7 @@ static void
transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
            if (broadcast_type_C == 4)
            {
                float32x4_t _cc = vld1q_f32(pC);
-                if (beta != 1.f)
-                {
-                    _cc = vmulq_n_f32(_cc, beta);
-                }
+                _cc = vmulq_n_f32(_cc, beta);
 #if __aarch64__
                _c0 = vdupq_laneq_f32(_cc, 0);
                float32x4_t _c1 = vdupq_laneq_f32(_cc, 1);
@@ -11766,8 +11683,10 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
            }
            if (broadcast_type_C == 4)
            {
-                _c0 = vdupq_n_f32(pC[0]);
-                float32x4_t _c1 = vdupq_n_f32(pC[1]);
+                float32x2_t _c = vld1_f32(pC);
+                _c = vmul_n_f32(_c, beta);
+                _c0 = vdupq_lane_f32(_c, 0);
+                float32x4_t _c1 = vdupq_lane_f32(_c, 1);
                _f0 = vaddq_f32(_f0, _c0);
                _f1 = vaddq_f32(_f1, _c1);
                pC += 2;
@@ -11816,14 +11735,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
                    _c0 = vld1q_f32(pC);
                    pC += 4;
                }
-                if (beta == 1.f)
-                {
-                    _f0 = vaddq_f32(_f0, _c0);
-                }
-                else
-                {
-                    _f0 = vmlaq_n_f32(_f0, _c0, beta);
-                }
+                _f0 = vmlaq_n_f32(_f0, _c0, beta);
            }
            if (broadcast_type_C == 4)
            {
@@ -11833,10 +11745,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
                }
            }
 
-            if (alpha != 1.f)
-            {
-                _f0 = vmulq_n_f32(_f0, alpha);
-            }
+            _f0 = vmulq_n_f32(_f0, alpha);
 
            vst1q_f32(p0, _f0);
            pp += 4;
@@ -11954,8 +11863,9 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
                _c1 = vld1q_f32(pC + 4);
                if (beta != 1.f)
                {
-                    _c0 = vmulq_n_f32(_c0, beta);
-                    _c1 = vmulq_n_f32(_c1, beta);
+                    float32x4_t _beta = vdupq_n_f32(beta);
+                    _c0 = vmulq_f32(_c0, _beta);
+                    _c1 = vmulq_f32(_c1, _beta);
                }
                _f0 = vaddq_f32(_f0, _c0);
                _f1 = vaddq_f32(_f1, _c1);
@@ -12027,10 +11937,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
            if (broadcast_type_C == 4)
            {
                _c0 = vld1q_f32(pC);
-                if (beta != 1.f)
-                {
-                    _c0 = vmulq_n_f32(_c0, beta);
-                }
+                _c0 = vmulq_n_f32(_c0, beta);
                _f0 = vaddq_f32(_f0, _c0);
                _f1 = vaddq_f32(_f1, _c0);
                pC += 4;
@@ -12128,8 +12035,9 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
                _c1 = vld1q_f32(pC + 4);
                if (beta != 1.f)
                {
-                    _c0 = vmulq_n_f32(_c0, beta);
-                    _c1 = vmulq_n_f32(_c1, beta);
+                    float32x4_t _beta = vdupq_n_f32(beta);
+                    _c0 = vmulq_f32(_c0, _beta);
+                    _c1 = vmulq_f32(_c1, _beta);
                }
                float32x4x2_t _cc0 = vzipq_f32(_c0, _c0);
                float32x4x2_t _cc1 = vzipq_f32(_c1, _c1);
@@ -12213,10 +12121,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
            if (broadcast_type_C == 4)
            {
                _c0 = vld1q_f32(pC);
-                if (beta != 1.f)
-                {
-                    _c0 = vmulq_n_f32(_c0, beta);
-                }
+                _c0 = vmulq_n_f32(_c0, beta);
                float32x4x2_t _cc = vzipq_f32(_c0, _c0);
                _f0 = vaddq_f32(_f0, _cc.val[0]);
                _f1 = vaddq_f32(_f1, _cc.val[1]);
@@ -12333,11 +12238,8 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
                }
            }
 
-            if (alpha != 1.f)
-            {
-                f0 *= alpha;
-                f1 *= alpha;
-            }
+            f0 *= alpha;
+            f1 *= alpha;
 
            p0[0] = f0;
            p0[1] = f1;
@@ -12516,22 +12418,12 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
                {
                    // c_elempack == 1
                    _c0 = vld1q_f32(pC);
-                    if (beta == 1.f)
-                    {
-                        _f0 = vaddq_f32(_f0, _c0);
-                    }
-                    else
-                    {
-                        _f0 = vmlaq_n_f32(_f0, _c0, beta);
-                    }
+                    _f0 = vmlaq_n_f32(_f0, _c0, beta);
                    pC += 4;
                }
            }
 
-            if (alpha != 1.f)
-            {
-                _f0 = vmulq_n_f32(_f0, alpha);
-            }
+            _f0 = vmulq_n_f32(_f0, alpha);
 
            vst1q_f32(p0, _f0);
            pp += 4;
@@ -12687,22 +12579,12 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
                {
                    // out_elempack == 1
                    _c0 = vld1q_f32(pC);
-                    if (beta == 1.f)
-                    {
-                        _f0 = vaddq_f32(_f0, _c0);
-                    }
-                    else
-                    {
-                        _f0 = vmlaq_n_f32(_f0, _c0, beta);
-                    }
+                    _f0 = vmlaq_n_f32(_f0, _c0, beta);
                    pC += 4;
                }
            }
 
-            if (alpha != 1.f)
-            {
-                _f0 = vmulq_n_f32(_f0, alpha);
-            }
+            _f0 = vmulq_n_f32(_f0, alpha);
 
            p0[0] = vgetq_lane_f32(_f0, 0);
            p0[out_hstep] = vgetq_lane_f32(_f0, 1);
@@ -12726,22 +12608,12 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
                {
                    // c_elempack == 1
                    float32x2_t _c = vld1_f32(pC);
-                    if (beta == 1.f)
-                    {
-                        _f0 = vadd_f32(_f0, _c);
-                    }
-                    else
-                    {
-                        _f0 = vmla_n_f32(_f0, _c, beta);
-                    }
+                    _f0 = vmla_n_f32(_f0, _c, beta);
                    pC += 2;
                }
            }
 
-            if (alpha != 1.f)
-            {
-                _f0 = vmul_n_f32(_f0, alpha);
-            }
+            _f0 = vmul_n_f32(_f0, alpha);
 
            p0[0] = vget_lane_f32(_f0, 0);
            p0[out_hstep] = vget_lane_f32(_f0, 1);
diff --git a/src/layer/arm/gemm_int8_bf16s.h b/src/layer/arm/gemm_int8_bf16s.h
index 9cecdc89298..e1a6c0c1499 100644
--- a/src/layer/arm/gemm_int8_bf16s.h
+++ b/src/layer/arm/gemm_int8_bf16s.h
@@ -24,8 +24,8 @@ void pack_A_tile_bf16_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii,
 void transpose_pack_A_tile_bf16_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales);
 void pack_B_tile_bf16_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale);
 void transpose_pack_B_tile_bf16_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale);
-void unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales);
-void transpose_unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales);
+void unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta);
+void transpose_unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta);
 #endif
 
 static void compute_A_tile_bf16_int8_scales(const Mat& A, Mat& scales, float B_scale, Mat& out_descales, int i, int max_ii)
@@ -4097,12 +4097,12 @@ static void transpose_pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int
     }
 }
 
-static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales)
+static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta)
 {
 #if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8
     if (ncnn::cpu_support_arm_asimddp())
     {
-        unpack_output_tile_int32_to_bf16_asimddp(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales);
+        unpack_output_tile_int32_to_bf16_asimddp(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta);
         return;
     }
 #endif
@@ -4381,12 +4381,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat&
                    uint16x8_t _c23 = vld1q_u16(pC + c_hstep);
                    uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2);
                    uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3);
-                    uint16x8_t _c89 = vld1q_u16(pC + c_hstep * 4);
-                    uint16x8_t _cab = vld1q_u16(pC + c_hstep * 5);
-                    uint16x8_t _ccd = vld1q_u16(pC + c_hstep * 6);
-                    uint16x8_t _cef = vld1q_u16(pC + c_hstep * 7);
                    transpose8x4_u16(_c01, _c23, _c45, _c67);
-                    transpose8x4_u16(_c89, _cab, _ccd, _cef);
                    _c0 = bfloat2float(vget_low_u16(_c01));
                    _c1 = bfloat2float(vget_high_u16(_c01));
                    float32x4_t _c2 = bfloat2float(vget_low_u16(_c23));
@@ -4395,42 +4390,73 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat&
                    float32x4_t _c5 = bfloat2float(vget_high_u16(_c45));
                    float32x4_t _c6 = bfloat2float(vget_low_u16(_c67));
                    float32x4_t _c7 = bfloat2float(vget_high_u16(_c67));
-                    float32x4_t _c8 = bfloat2float(vget_low_u16(_c89));
-                    float32x4_t _c9 = bfloat2float(vget_high_u16(_c89));
-                    float32x4_t _ca = bfloat2float(vget_low_u16(_cab));
-                    float32x4_t _cb = bfloat2float(vget_high_u16(_cab));
-                    float32x4_t _cc = bfloat2float(vget_low_u16(_ccd));
-                    float32x4_t _cd = bfloat2float(vget_high_u16(_ccd));
-                    float32x4_t _ce = bfloat2float(vget_low_u16(_cef));
-                    float32x4_t _cf = bfloat2float(vget_high_u16(_cef));
-                    _f0 = vaddq_f32(_f0, _c0);
-                    _f1 = vaddq_f32(_f1, _c1);
-                    _f2 = vaddq_f32(_f2, _c2);
-                    _f3 = vaddq_f32(_f3, _c3);
-                    _f4 = vaddq_f32(_f4, _c4);
-                    _f5 = vaddq_f32(_f5, _c5);
-                    _f6 = vaddq_f32(_f6, _c6);
-                    _f7 = vaddq_f32(_f7, _c7);
-                    _f8 = vaddq_f32(_f8, _c8);
-                    _f9 = vaddq_f32(_f9, _c9);
-                    _fa = vaddq_f32(_fa, _ca);
-                    _fb = vaddq_f32(_fb, _cb);
-                    _fc = vaddq_f32(_fc, _cc);
-                    _fd = vaddq_f32(_fd, _cd);
-                    _fe = vaddq_f32(_fe, _ce);
-                    _ff = vaddq_f32(_ff, _cf);
+                    if (beta == 1.f)
+                    {
+                        _f0 = vaddq_f32(_f0, _c0);
+                        _f1 = vaddq_f32(_f1, _c1);
+                        _f2 = vaddq_f32(_f2, _c2);
+                        _f3 = vaddq_f32(_f3, _c3);
+                        _f4 = vaddq_f32(_f4, _c4);
+                        _f5 = vaddq_f32(_f5, _c5);
+                        _f6 = vaddq_f32(_f6, _c6);
+                        _f7 = vaddq_f32(_f7, _c7);
+                    }
+                    else
+                    {
+                        float32x4_t _beta = vdupq_n_f32(beta);
+                        _f0 = vmlaq_f32(_f0, _c0, _beta);
+                        _f1 = vmlaq_f32(_f1, _c1, _beta);
+                        _f2 = vmlaq_f32(_f2, _c2, _beta);
+                        _f3 = vmlaq_f32(_f3, _c3, _beta);
+                        _f4 = vmlaq_f32(_f4, _c4, _beta);
+                        _f5 = vmlaq_f32(_f5, _c5, _beta);
+                        _f6 = vmlaq_f32(_f6, _c6, _beta);
+                        _f7 = vmlaq_f32(_f7, _c7, _beta);
+                    }
+                    _c01 = vld1q_u16(pC + c_hstep * 4);
+                    _c23 = vld1q_u16(pC + c_hstep * 5);
+                    _c45 = vld1q_u16(pC + c_hstep * 6);
+                    _c67 = vld1q_u16(pC + c_hstep * 7);
+                    transpose8x4_u16(_c01, _c23, _c45, _c67);
+                    _c0 = bfloat2float(vget_low_u16(_c01));
+                    _c1 = bfloat2float(vget_high_u16(_c01));
+                    _c2 = bfloat2float(vget_low_u16(_c23));
+                    _c3 = bfloat2float(vget_high_u16(_c23));
+                    _c4 = bfloat2float(vget_low_u16(_c45));
+                    _c5 = bfloat2float(vget_high_u16(_c45));
+                    _c6 = bfloat2float(vget_low_u16(_c67));
+                    _c7 = bfloat2float(vget_high_u16(_c67));
+                    if (beta == 1.f)
+                    {
+                        _f8 = vaddq_f32(_f8, _c0);
+                        _f9 = vaddq_f32(_f9, _c1);
+                        _fa = vaddq_f32(_fa, _c2);
+                        _fb = vaddq_f32(_fb, _c3);
+                        _fc = vaddq_f32(_fc, _c4);
+                        _fd = vaddq_f32(_fd, _c5);
+                        _fe = vaddq_f32(_fe, _c6);
+                        _ff = vaddq_f32(_ff, _c7);
+                    }
+                    else
+                    {
+                        float32x4_t _beta = vdupq_n_f32(beta);
+                        _f8 = vmlaq_f32(_f8, _c0, _beta);
+                        _f9 = vmlaq_f32(_f9, _c1, _beta);
+                        _fa = vmlaq_f32(_fa, _c2, _beta);
+                        _fb = vmlaq_f32(_fb, _c3, _beta);
+                        _fc = vmlaq_f32(_fc, _c4, _beta);
+                        _fd = vmlaq_f32(_fd, _c5, _beta);
+                        _fe = vmlaq_f32(_fe, _c6, _beta);
+                        _ff = vmlaq_f32(_ff, _c7, _beta);
+                    }
                    pC += 8;
                }
-                if (c_elempack == 4)
+                else // if (c_elempack == 4)
                {
                    uint16x8_t _c01 = vld1q_u16(pC);
                    uint16x8_t _c23 = vld1q_u16(pC + 8);
                    uint16x8_t _c45 = vld1q_u16(pC + 16);
                    uint16x8_t _c67 = vld1q_u16(pC + 24);
-                    uint16x8_t _c89 = vld1q_u16(pC + c_hstep * 4);
-                    uint16x8_t _cab = vld1q_u16(pC + c_hstep * 4 + 8);
-                    uint16x8_t _ccd = vld1q_u16(pC + c_hstep * 4 + 16);
-                    uint16x8_t _cef = vld1q_u16(pC + c_hstep * 4 + 24);
                    _c0 = bfloat2float(vget_low_u16(_c01));
                    _c1 = bfloat2float(vget_high_u16(_c01));
                    float32x4_t _c2 = bfloat2float(vget_low_u16(_c23));
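Everything in this file round-trips C through bfloat16: loads go through bfloat2float before the fp32 beta/alpha arithmetic, and stores go back through float2bfloat. For reference, a scalar sketch of what such helpers do; bfloat16 is the high half of an IEEE-754 fp32, ncnn's NEON helpers apply the same idea per uint16x4_t lane, and the rounding mode on the store side is the library's choice (truncation shown here for brevity):

    static inline float bf16_to_f32(unsigned short v)
    {
        union { unsigned int u; float f; } b;
        b.u = (unsigned int)v << 16; // low mantissa bits become zero
        return b.f;
    }

    static inline unsigned short f32_to_bf16(float f)
    {
        union { unsigned int u; float f; } b;
        b.f = f;
        return (unsigned short)(b.u >> 16); // truncating sketch
    }

Doing the beta scaling in fp32 after conversion, rather than on raw bf16 values, keeps this path numerically in line with the fp32 kernel.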
@@ -4439,43 +4465,86 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat&
                    float32x4_t _c5 = bfloat2float(vget_high_u16(_c45));
                    float32x4_t _c6 = bfloat2float(vget_low_u16(_c67));
                    float32x4_t _c7 = bfloat2float(vget_high_u16(_c67));
-                    float32x4_t _c8 = bfloat2float(vget_low_u16(_c89));
-                    float32x4_t _c9 = bfloat2float(vget_high_u16(_c89));
-                    float32x4_t _ca = bfloat2float(vget_low_u16(_cab));
-                    float32x4_t _cb = bfloat2float(vget_high_u16(_cab));
-                    float32x4_t _cc = bfloat2float(vget_low_u16(_ccd));
-                    float32x4_t _cd = bfloat2float(vget_high_u16(_ccd));
-                    float32x4_t _ce = bfloat2float(vget_low_u16(_cef));
-                    float32x4_t _cf = bfloat2float(vget_high_u16(_cef));
-                    _f0 = vaddq_f32(_f0, _c0);
-                    _f1 = vaddq_f32(_f1, _c1);
-                    _f2 = vaddq_f32(_f2, _c2);
-                    _f3 = vaddq_f32(_f3, _c3);
-                    _f4 = vaddq_f32(_f4, _c4);
-                    _f5 = vaddq_f32(_f5, _c5);
-                    _f6 = vaddq_f32(_f6, _c6);
-                    _f7 = vaddq_f32(_f7, _c7);
-                    _f8 = vaddq_f32(_f8, _c8);
-                    _f9 = vaddq_f32(_f9, _c9);
-                    _fa = vaddq_f32(_fa, _ca);
-                    _fb = vaddq_f32(_fb, _cb);
-                    _fc = vaddq_f32(_fc, _cc);
-                    _fd = vaddq_f32(_fd, _cd);
-                    _fe = vaddq_f32(_fe, _ce);
-                    _ff = vaddq_f32(_ff, _cf);
+                    if (beta == 1.f)
+                    {
+                        _f0 = vaddq_f32(_f0, _c0);
+                        _f1 = vaddq_f32(_f1, _c1);
+                        _f2 = vaddq_f32(_f2, _c2);
+                        _f3 = vaddq_f32(_f3, _c3);
+                        _f4 = vaddq_f32(_f4, _c4);
+                        _f5 = vaddq_f32(_f5, _c5);
+                        _f6 = vaddq_f32(_f6, _c6);
+                        _f7 = vaddq_f32(_f7, _c7);
+                    }
+                    else
+                    {
+                        float32x4_t _beta = vdupq_n_f32(beta);
+                        _f0 = vmlaq_f32(_f0, _c0, _beta);
+                        _f1 = vmlaq_f32(_f1, _c1, _beta);
+                        _f2 = vmlaq_f32(_f2, _c2, _beta);
+                        _f3 = vmlaq_f32(_f3, _c3, _beta);
+                        _f4 = vmlaq_f32(_f4, _c4, _beta);
+                        _f5 = vmlaq_f32(_f5, _c5, _beta);
+                        _f6 = vmlaq_f32(_f6, _c6, _beta);
+                        _f7 = vmlaq_f32(_f7, _c7, _beta);
+                    }
+                    _c01 = vld1q_u16(pC + c_hstep * 4);
+                    _c23 = vld1q_u16(pC + c_hstep * 4 + 8);
+                    _c45 = vld1q_u16(pC + c_hstep * 4 + 16);
+                    _c67 = vld1q_u16(pC + c_hstep * 4 + 24);
+                    _c0 = bfloat2float(vget_low_u16(_c01));
+                    _c1 = bfloat2float(vget_high_u16(_c01));
+                    _c2 = bfloat2float(vget_low_u16(_c23));
+                    _c3 = bfloat2float(vget_high_u16(_c23));
+                    _c4 = bfloat2float(vget_low_u16(_c45));
+                    _c5 = bfloat2float(vget_high_u16(_c45));
+                    _c6 = bfloat2float(vget_low_u16(_c67));
+                    _c7 = bfloat2float(vget_high_u16(_c67));
+                    if (beta == 1.f)
+                    {
+                        _f8 = vaddq_f32(_f8, _c0);
+                        _f9 = vaddq_f32(_f9, _c1);
+                        _fa = vaddq_f32(_fa, _c2);
+                        _fb = vaddq_f32(_fb, _c3);
+                        _fc = vaddq_f32(_fc, _c4);
+                        _fd = vaddq_f32(_fd, _c5);
+                        _fe = vaddq_f32(_fe, _c6);
+                        _ff = vaddq_f32(_ff, _c7);
+                    }
+                    else
+                    {
+                        float32x4_t _beta = vdupq_n_f32(beta);
+                        _f8 = vmlaq_f32(_f8, _c0, _beta);
+                        _f9 = vmlaq_f32(_f9, _c1, _beta);
+                        _fa = vmlaq_f32(_fa, _c2, _beta);
+                        _fb = vmlaq_f32(_fb, _c3, _beta);
+                        _fc = vmlaq_f32(_fc, _c4, _beta);
+                        _fd = vmlaq_f32(_fd, _c5, _beta);
+                        _fe = vmlaq_f32(_fe, _c6, _beta);
+                        _ff = vmlaq_f32(_ff, _c7, _beta);
+                    }
                    pC += 32;
                }
            }
            if (broadcast_type_C == 4)
            {
-                _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]));
-                _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]));
-                float32x4_t _c2 = vdupq_n_f32(bfloat16_to_float32(pC[2]));
-                float32x4_t _c3 = vdupq_n_f32(bfloat16_to_float32(pC[3]));
-                float32x4_t _c4 = vdupq_n_f32(bfloat16_to_float32(pC[4]));
-                float32x4_t _c5 = vdupq_n_f32(bfloat16_to_float32(pC[5]));
-                float32x4_t _c6 = vdupq_n_f32(bfloat16_to_float32(pC[6]));
-                float32x4_t _c7 = vdupq_n_f32(bfloat16_to_float32(pC[7]));
+                uint16x8_t _cc = vld1q_u16(pC);
+                float32x4_t _cc0 = bfloat2float(vget_low_u16(_cc));
+                float32x4_t _cc1 = bfloat2float(vget_high_u16(_cc));
+                if (beta != 1.f)
+                {
+                    float32x4_t _beta = vdupq_n_f32(beta);
+                    _cc0 = vmulq_f32(_cc0, _beta);
+                    _cc1 = vmulq_f32(_cc1, _beta);
+                }
+                _c0 = vdupq_laneq_f32(_cc0, 0);
+                _c1 = vdupq_laneq_f32(_cc0, 1);
+                float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2);
+                float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3);
+                float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0);
+                float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1);
+                float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2);
+                float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3);
                _f0 = vaddq_f32(_f0, _c0);
                _f1 = vaddq_f32(_f1, _c1);
                _f2 = vaddq_f32(_f2, _c2);
@@ -4496,6 +4565,27 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat&
            }
        }
 
+        if (alpha != 1.f)
+        {
+            float32x4_t _alpha = vdupq_n_f32(alpha);
+            _f0 = vmulq_f32(_f0, _alpha);
+            _f1 = vmulq_f32(_f1, _alpha);
+            _f2 = vmulq_f32(_f2, _alpha);
+            _f3 = vmulq_f32(_f3, _alpha);
+            _f4 = vmulq_f32(_f4, _alpha);
+            _f5 = vmulq_f32(_f5, _alpha);
+            _f6 = vmulq_f32(_f6, _alpha);
+            _f7 = vmulq_f32(_f7, _alpha);
+            _f8 = vmulq_f32(_f8, _alpha);
+            _f9 = vmulq_f32(_f9, _alpha);
+            _fa = vmulq_f32(_fa, _alpha);
+            _fb = vmulq_f32(_fb, _alpha);
+            _fc = vmulq_f32(_fc, _alpha);
+            _fd = vmulq_f32(_fd, _alpha);
+            _fe = vmulq_f32(_fe, _alpha);
+            _ff = vmulq_f32(_ff, _alpha);
+        }
+
        vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1)));
        vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3)));
        vst1q_u16(p0 + 16, vcombine_u16(float2bfloat(_f4), float2bfloat(_f5)));
@@ -4619,53 +4709,114 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat&
                    uint16x4_t _cc1 = vld1_u16(pC + c_hstep);
                    uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2);
                    uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3);
-                    uint16x4_t _cc4 = vld1_u16(pC + c_hstep * 4);
-                    uint16x4_t _cc5 = vld1_u16(pC + c_hstep * 5);
-                    uint16x4_t _cc6 = vld1_u16(pC + c_hstep * 6);
-                    uint16x4_t _cc7 = vld1_u16(pC + c_hstep * 7);
                    transpose4x4_u16(_cc0, _cc1, _cc2, _cc3);
-                    transpose4x4_u16(_cc4, _cc5, _cc6, _cc7);
-                    _f0 = vaddq_f32(_f0, bfloat2float(_cc0));
-                    _f1 = vaddq_f32(_f1, bfloat2float(_cc1));
-                    _f2 = vaddq_f32(_f2, bfloat2float(_cc2));
-                    _f3 = vaddq_f32(_f3, bfloat2float(_cc3));
-                    _f4 = vaddq_f32(_f4, bfloat2float(_cc4));
-                    _f5 = vaddq_f32(_f5, bfloat2float(_cc5));
-                    _f6 = vaddq_f32(_f6, bfloat2float(_cc6));
-                    _f7 = vaddq_f32(_f7, bfloat2float(_cc7));
+                    _c0 = bfloat2float(_cc0);
+                    _c1 = bfloat2float(_cc1);
+                    float32x4_t _c2 = bfloat2float(_cc2);
+                    float32x4_t _c3 = bfloat2float(_cc3);
+                    if (beta == 1.f)
+                    {
+                        _f0 = vaddq_f32(_f0, _c0);
+                        _f1 = vaddq_f32(_f1, _c1);
+                        _f2 = vaddq_f32(_f2, _c2);
+                        _f3 = vaddq_f32(_f3, _c3);
+                    }
+                    else
+                    {
+                        float32x4_t _beta = vdupq_n_f32(beta);
+                        _f0 = vmlaq_f32(_f0, _c0, _beta);
+                        _f1 = vmlaq_f32(_f1, _c1, _beta);
+                        _f2 = vmlaq_f32(_f2, _c2, _beta);
+                        _f3 = vmlaq_f32(_f3, _c3, _beta);
+                    }
+                    _cc0 = vld1_u16(pC + c_hstep * 4);
+                    _cc1 = vld1_u16(pC + c_hstep * 5);
+                    _cc2 = vld1_u16(pC + c_hstep * 6);
+                    _cc3 = vld1_u16(pC + c_hstep * 7);
+                    transpose4x4_u16(_cc0, _cc1, _cc2, _cc3);
+                    _c0 = bfloat2float(_cc0);
+                    _c1 = bfloat2float(_cc1);
+                    _c2 = bfloat2float(_cc2);
+                    _c3 = bfloat2float(_cc3);
+                    if (beta == 1.f)
+                    {
+                        _f4 = vaddq_f32(_f4, _c0);
+                        _f5 = vaddq_f32(_f5, _c1);
+                        _f6 = vaddq_f32(_f6, _c2);
+                        _f7 = vaddq_f32(_f7, _c3);
+                    }
+                    else
+                    {
+                        float32x4_t _beta = vdupq_n_f32(beta);
+                        _f4 = vmlaq_f32(_f4, _c0, _beta);
+                        _f5 = vmlaq_f32(_f5, _c1, _beta);
+                        _f6 = vmlaq_f32(_f6, _c2, _beta);
+                        _f7 = vmlaq_f32(_f7, _c3, _beta);
+                    }
                    pC += 4;
                }
-                if (c_elempack == 4)
+                else // if (c_elempack == 4)
                {
                    uint16x8_t _c01 = vld1q_u16(pC);
                    uint16x8_t _c23 = vld1q_u16(pC + 8);
-                    uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 4);
-                    uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 4 + 8);
                    _c0 = bfloat2float(vget_low_u16(_c01));
                    _c1 = bfloat2float(vget_high_u16(_c01));
                    float32x4_t _c2 = bfloat2float(vget_low_u16(_c23));
                    float32x4_t _c3 = bfloat2float(vget_high_u16(_c23));
-                    float32x4_t _c4 = bfloat2float(vget_low_u16(_c45));
-                    float32x4_t _c5 = bfloat2float(vget_high_u16(_c45));
-                    float32x4_t _c6 = bfloat2float(vget_low_u16(_c67));
-                    float32x4_t _c7 = bfloat2float(vget_high_u16(_c67));
-                    _f0 = vaddq_f32(_f0, _c0);
-                    _f1 = vaddq_f32(_f1, _c1);
-                    _f2 = vaddq_f32(_f2, _c2);
-                    _f3 = vaddq_f32(_f3, _c3);
-                    _f4 = vaddq_f32(_f4, _c4);
-                    _f5 = vaddq_f32(_f5, _c5);
-                    _f6 = vaddq_f32(_f6, _c6);
-                    _f7 = vaddq_f32(_f7, _c7);
+                    if (beta == 1.f)
+                    {
+                        _f0 = vaddq_f32(_f0, _c0);
+                        _f1 = vaddq_f32(_f1, _c1);
+                        _f2 = vaddq_f32(_f2, _c2);
+                        _f3 = vaddq_f32(_f3, _c3);
+                    }
+                    else
+                    {
+                        float32x4_t _beta = vdupq_n_f32(beta);
+                        _f0 = vmlaq_f32(_f0, _c0, _beta);
+                        _f1 = vmlaq_f32(_f1, _c1, _beta);
+                        _f2 = vmlaq_f32(_f2, _c2, _beta);
+                        _f3 = vmlaq_f32(_f3, _c3, _beta);
+                    }
+                    _c01 = vld1q_u16(pC + c_hstep * 4);
+                    _c23 = vld1q_u16(pC + c_hstep * 4 + 8);
+                    _c0 = bfloat2float(vget_low_u16(_c01));
+                    _c1 = bfloat2float(vget_high_u16(_c01));
+                    _c2 = bfloat2float(vget_low_u16(_c23));
+                    _c3 = bfloat2float(vget_high_u16(_c23));
+                    if (beta == 1.f)
+                    {
+                        _f4 = vaddq_f32(_f4, _c0);
+                        _f5 = vaddq_f32(_f5, _c1);
+                        _f6 = vaddq_f32(_f6, _c2);
+                        _f7 = vaddq_f32(_f7, _c3);
+                    }
+                    else
+                    {
+                        float32x4_t _beta = vdupq_n_f32(beta);
+                        _f4 = vmlaq_f32(_f4, _c0, _beta);
+                        _f5 = vmlaq_f32(_f5, _c1, _beta);
+                        _f6 = vmlaq_f32(_f6, _c2, _beta);
+                        _f7 = vmlaq_f32(_f7, _c3, _beta);
+                    }
                    pC += 16;
                }
            }
            if (broadcast_type_C == 4)
            {
-                _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]));
-                _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]));
-                float32x4_t _c2 = vdupq_n_f32(bfloat16_to_float32(pC[2]));
-                float32x4_t _c3 = vdupq_n_f32(bfloat16_to_float32(pC[3]));
+                float32x4_t _c = bfloat2float(vld1_u16(pC));
+                _c = vmulq_n_f32(_c, beta);
+#if __aarch64__
+                _c0 = vdupq_laneq_f32(_c, 0);
+                _c1 = vdupq_laneq_f32(_c, 1);
+                float32x4_t _c2 = vdupq_laneq_f32(_c, 2);
+                float32x4_t _c3 = vdupq_laneq_f32(_c, 3);
+#else
+                _c0 = vdupq_lane_f32(vget_low_f32(_c), 0);
+                _c1 = vdupq_lane_f32(vget_low_f32(_c), 1);
+                float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0);
+                float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1);
+#endif
                _f0 = vaddq_f32(_f0, _c0);
                _f1 = vaddq_f32(_f1, _c1);
                _f2 = vaddq_f32(_f2, _c2);
@@ -4678,6 +4829,19 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat&
            }
        }
 
+        if (alpha != 1.f)
+        {
+            float32x4_t _alpha = vdupq_n_f32(alpha);
+            _f0 = vmulq_f32(_f0, _alpha);
+            _f1 = vmulq_f32(_f1, _alpha);
+            _f2 = vmulq_f32(_f2, _alpha);
+            _f3 = vmulq_f32(_f3, _alpha);
+            _f4 = vmulq_f32(_f4, _alpha);
+            _f5 = vmulq_f32(_f5, _alpha);
+            _f6 = vmulq_f32(_f6, _alpha);
+            _f7 = vmulq_f32(_f7, _alpha);
+        }
+
        vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1)));
        vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3)));
        vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f4), float2bfloat(_f5)));
@@ -4748,6 +4912,8 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat&
            }
            if (broadcast_type_C == 3)
            {
+                float32x4_t _c2;
+                float32x4_t _c3;
                if (c_elempack == 1)
                {
                    uint16x8_t _c01 = uint16x8_t();
@@ -4770,33 +4936,40 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat&
                    _c23 = vsetq_lane_u16(pC[c_hstep * 7 + 1], _c23, 7);
                    _c0 = bfloat2float(vget_low_u16(_c01));
                    _c1 = bfloat2float(vget_high_u16(_c01));
-                    float32x4_t _c2 = bfloat2float(vget_low_u16(_c23));
-                    float32x4_t _c3 = bfloat2float(vget_high_u16(_c23));
-                    _f0 = vaddq_f32(_f0, _c0);
-                    _f1 = vaddq_f32(_f1, _c1);
-                    _f2 = vaddq_f32(_f2, _c2);
-                    _f3 = vaddq_f32(_f3, _c3);
+                    _c2 = bfloat2float(vget_low_u16(_c23));
+                    _c3 = bfloat2float(vget_high_u16(_c23));
                    pC += 2;
                }
-                if (c_elempack == 4)
+                else // if (c_elempack == 4)
                {
                    uint16x8_t _c01 = vld1q_u16(pC);
                    uint16x8_t _c23 = vld1q_u16(pC + c_hstep * 4);
                    _c0 = bfloat2float(vget_low_u16(_c01));
                    _c1 = bfloat2float(vget_high_u16(_c01));
-                    float32x4_t _c2 = bfloat2float(vget_low_u16(_c23));
-                    float32x4_t _c3 = bfloat2float(vget_high_u16(_c23));
+                    _c2 = bfloat2float(vget_low_u16(_c23));
+                    _c3 = bfloat2float(vget_high_u16(_c23));
+                    pC += 8;
+                }
+                if (beta == 1.f)
+                {
                    _f0 = vaddq_f32(_f0, _c0);
                    _f1 = vaddq_f32(_f1, _c1);
                    _f2 = vaddq_f32(_f2, _c2);
                    _f3 = vaddq_f32(_f3, _c3);
-                    pC += 8;
+                }
+                else
+                {
+                    float32x4_t _beta = vdupq_n_f32(beta);
+                    _f0 = vmlaq_f32(_f0, _c0, _beta);
+                    _f1 = vmlaq_f32(_f1, _c1, _beta);
+                    _f2 = vmlaq_f32(_f2, _c2, _beta);
+                    _f3 = vmlaq_f32(_f3, _c3, _beta);
                }
            }
            if (broadcast_type_C == 4)
            {
-                _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]));
-                _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]));
+                _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta);
+                _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]) * beta);
                _f0 = vaddq_f32(_f0, _c0);
                _f1 = vaddq_f32(_f1, _c1);
                _f2 = vaddq_f32(_f2, _c0);
@@ -4805,6 +4978,15 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat&
            }
        }
 
+        if (alpha != 1.f)
+        {
+            float32x4_t _alpha = vdupq_n_f32(alpha);
+            _f0 = vmulq_f32(_f0, _alpha);
+            _f1 = vmulq_f32(_f1, _alpha);
+            _f2 = vmulq_f32(_f2, _alpha);
+            _f3 = vmulq_f32(_f3, _alpha);
+        }
+
        vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1)));
        vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3)));
 
@@ -4846,28 +5028,42 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat&
                    _c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7);
                    _c0 = bfloat2float(vget_low_u16(_c01));
                    _c1 = bfloat2float(vget_high_u16(_c01));
-                    _f0 = vaddq_f32(_f0, _c0);
-                    _f1 = vaddq_f32(_f1, _c1);
                    pC += 1;
                }
-                if (c_elempack == 4)
+                else // if (c_elempack == 4)
                {
                    _c0 = bfloat2float(vld1_u16(pC));
                    _c1 = bfloat2float(vld1_u16(pC + c_hstep * 4));
+                    pC += 4;
+                }
+                if (beta == 1.f)
+                {
                    _f0 = vaddq_f32(_f0, _c0);
                    _f1 = vaddq_f32(_f1, _c1);
-                    pC += 4;
+                }
+                else
+                {
+                    float32x4_t _beta = vdupq_n_f32(beta);
+                    _f0 = vmlaq_f32(_f0, _c0, _beta);
+                    _f1 = vmlaq_f32(_f1, _c1, _beta);
                }
            }
            if (broadcast_type_C == 4)
            {
-                _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]));
+                _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta);
                _f0 = vaddq_f32(_f0, _c0);
                _f1 = vaddq_f32(_f1, _c0);
                pC += 1;
            }
        }
 
+        if (alpha != 1.f)
+        {
+            float32x4_t _alpha = vdupq_n_f32(alpha);
+            _f0 = vmulq_f32(_f0, _alpha);
+            _f1 = vmulq_f32(_f1, _alpha);
+        }
+
        vst1_u16(p0, float2bfloat(_f0));
        vst1_u16(p0 + out_hstep * 4, float2bfloat(_f1));
 
@@ -5080,7 +5276,6 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat&
            }
            if (broadcast_type_C == 1 || broadcast_type_C == 2)
            {
-#if __aarch64__
                float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0);
                float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1);
                float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2);
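On the #else blocks this patch keeps deleting (here and in the fp32 hunks earlier): vdupq_laneq_f32 exists only on AArch64, and on 32-bit NEON the same lane broadcast is spelled with a vget_low_f32/vget_high_f32 plus vdupq_lane_f32 pair. The removals appear limited to blocks that are only compiled on __aarch64__ anyway, so the fallbacks were dead code; paths that both architectures still build keep the #if, along the lines of:

    // Broadcast lane 2 of a q-register, portable spelling (illustrative helper)
    static inline float32x4_t dup_lane2(float32x4_t v)
    {
    #if __aarch64__
        return vdupq_laneq_f32(v, 2);
    #else
        return vdupq_lane_f32(vget_high_f32(v), 0); // lane 2 = high half, lane 0
    #endif
    }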
vdupq_laneq_f32(_c0, 2); @@ -5089,16 +5284,6 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1); float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2); float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3); -#else - float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); - float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); - float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); - float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); - float32x4_t _cc4 = vdupq_lane_f32(vget_low_f32(_c1), 0); - float32x4_t _cc5 = vdupq_lane_f32(vget_low_f32(_c1), 1); - float32x4_t _cc6 = vdupq_lane_f32(vget_high_f32(_c1), 0); - float32x4_t _cc7 = vdupq_lane_f32(vget_high_f32(_c1), 1); -#endif _f0 = vaddq_f32(_f0, _cc0); _f1 = vaddq_f32(_f1, _cc0); _f2 = vaddq_f32(_f2, _cc1); @@ -5124,10 +5309,6 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& uint16x8_t _c23 = vld1q_u16(pC + c_hstep); uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); - uint16x8_t _c89 = vld1q_u16(pC + c_hstep * 4); - uint16x8_t _cab = vld1q_u16(pC + c_hstep * 5); - uint16x8_t _ccd = vld1q_u16(pC + c_hstep * 6); - uint16x8_t _cef = vld1q_u16(pC + c_hstep * 7); _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); @@ -5136,36 +5317,69 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); - float32x4_t _c8 = bfloat2float(vget_low_u16(_c89)); - float32x4_t _c9 = bfloat2float(vget_high_u16(_c89)); - float32x4_t _ca = bfloat2float(vget_low_u16(_cab)); - float32x4_t _cb = bfloat2float(vget_high_u16(_cab)); - float32x4_t _cc = bfloat2float(vget_low_u16(_ccd)); - float32x4_t _cd = bfloat2float(vget_high_u16(_ccd)); - float32x4_t _ce = bfloat2float(vget_low_u16(_cef)); - float32x4_t _cf = bfloat2float(vget_high_u16(_cef)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - _f8 = vaddq_f32(_f8, _c8); - _f9 = vaddq_f32(_f9, _c9); - _fa = vaddq_f32(_fa, _ca); - _fb = vaddq_f32(_fb, _cb); - _fc = vaddq_f32(_fc, _cc); - _fd = vaddq_f32(_fd, _cd); - _fe = vaddq_f32(_fe, _ce); - _ff = vaddq_f32(_ff, _cf); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 5); + _c45 = vld1q_u16(pC + c_hstep * 6); + _c67 = vld1q_u16(pC + c_hstep * 7); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + _c4 = bfloat2float(vget_low_u16(_c45)); + _c5 = 
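The #else branches removed above are the armv7 spelling of the same lane broadcast: vdupq_laneq_f32 is an A64-only intrinsic, so 32-bit NEON has to split the q-register and use vdupq_lane_f32 on a half. Dropping the guard is only valid because this tile path is compiled for __aarch64__ alone; where both targets build, the pattern stays guarded, e.g.:

    #include <arm_neon.h>

    // Broadcast lane 2 of a q-register on both A64 and 32-bit NEON.
    static inline float32x4_t dup_lane2(float32x4_t v)
    {
    #if __aarch64__
        return vdupq_laneq_f32(v, 2);
    #else
        return vdupq_lane_f32(vget_high_f32(v), 0); // lane 2 = high half, lane 0
    #endif
    }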
bfloat2float(vget_high_u16(_c45)); + _c6 = bfloat2float(vget_low_u16(_c67)); + _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } pC += 8; } - if (c_elempack == 4) + else // if (c_elempack == 4) { uint16x8x4_t _cc0 = vld4q_u16(pC); - uint16x8x4_t _cc1 = vld4q_u16(pC + c_hstep * 4); _f0 = vaddq_f32(_f0, bfloat2float(vget_low_u16(_cc0.val[0]))); _f1 = vaddq_f32(_f1, bfloat2float(vget_high_u16(_cc0.val[0]))); _f2 = vaddq_f32(_f2, bfloat2float(vget_low_u16(_cc0.val[1]))); @@ -5174,14 +5388,15 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _f5 = vaddq_f32(_f5, bfloat2float(vget_high_u16(_cc0.val[2]))); _f6 = vaddq_f32(_f6, bfloat2float(vget_low_u16(_cc0.val[3]))); _f7 = vaddq_f32(_f7, bfloat2float(vget_high_u16(_cc0.val[3]))); - _f8 = vaddq_f32(_f8, bfloat2float(vget_low_u16(_cc1.val[0]))); - _f9 = vaddq_f32(_f9, bfloat2float(vget_high_u16(_cc1.val[0]))); - _fa = vaddq_f32(_fa, bfloat2float(vget_low_u16(_cc1.val[1]))); - _fb = vaddq_f32(_fb, bfloat2float(vget_high_u16(_cc1.val[1]))); - _fc = vaddq_f32(_fc, bfloat2float(vget_low_u16(_cc1.val[2]))); - _fd = vaddq_f32(_fd, bfloat2float(vget_high_u16(_cc1.val[2]))); - _fe = vaddq_f32(_fe, bfloat2float(vget_low_u16(_cc1.val[3]))); - _ff = vaddq_f32(_ff, bfloat2float(vget_high_u16(_cc1.val[3]))); + _cc0 = vld4q_u16(pC + c_hstep * 4); + _f8 = vaddq_f32(_f8, bfloat2float(vget_low_u16(_cc0.val[0]))); + _f9 = vaddq_f32(_f9, bfloat2float(vget_high_u16(_cc0.val[0]))); + _fa = vaddq_f32(_fa, bfloat2float(vget_low_u16(_cc0.val[1]))); + _fb = vaddq_f32(_fb, bfloat2float(vget_high_u16(_cc0.val[1]))); + _fc = vaddq_f32(_fc, bfloat2float(vget_low_u16(_cc0.val[2]))); + _fd = vaddq_f32(_fd, bfloat2float(vget_high_u16(_cc0.val[2]))); + _fe = vaddq_f32(_fe, bfloat2float(vget_low_u16(_cc0.val[3]))); + _ff = vaddq_f32(_ff, bfloat2float(vget_high_u16(_cc0.val[3]))); pC += 32; } } @@ -5190,6 +5405,12 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& uint16x8_t _c = vld1q_u16(pC); _c0 = bfloat2float(vget_low_u16(_c)); _c1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); + } _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c0); @@ -5210,6 +5431,27 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + _f8 = vmulq_f32(_f8, _alpha); + _f9 = vmulq_f32(_f9, _alpha); + _fa = vmulq_f32(_fa, _alpha); + _fb = vmulq_f32(_fb, _alpha); + _fc = vmulq_f32(_fc, _alpha); + _fd = vmulq_f32(_fd, _alpha); + _fe = vmulq_f32(_fe, _alpha); + _ff = 
vmulq_f32(_ff, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); vst1q_u16(p0 + out_hstep * 2, vcombine_u16(float2bfloat(_f4), float2bfloat(_f5))); @@ -5383,38 +5625,61 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _c1 = bfloat2float(vld1_u16(pC + c_hstep * 1)); float32x4_t _c2 = bfloat2float(vld1_u16(pC + c_hstep * 2)); float32x4_t _c3 = bfloat2float(vld1_u16(pC + c_hstep * 3)); - float32x4_t _c4 = bfloat2float(vld1_u16(pC + c_hstep * 4)); - float32x4_t _c5 = bfloat2float(vld1_u16(pC + c_hstep * 5)); - float32x4_t _c6 = bfloat2float(vld1_u16(pC + c_hstep * 6)); - float32x4_t _c7 = bfloat2float(vld1_u16(pC + c_hstep * 7)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _c0 = bfloat2float(vld1_u16(pC + c_hstep * 4)); + _c1 = bfloat2float(vld1_u16(pC + c_hstep * 5)); + _c2 = bfloat2float(vld1_u16(pC + c_hstep * 6)); + _c3 = bfloat2float(vld1_u16(pC + c_hstep * 7)); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } pC += 4; } - if (c_elempack == 4) + else // if (c_elempack == 4) { uint16x4x4_t _cc0 = vld4_u16(pC); - uint16x4x4_t _cc1 = vld4_u16(pC + c_hstep * 4); _f0 = vaddq_f32(_f0, bfloat2float(_cc0.val[0])); _f1 = vaddq_f32(_f1, bfloat2float(_cc0.val[1])); _f2 = vaddq_f32(_f2, bfloat2float(_cc0.val[2])); _f3 = vaddq_f32(_f3, bfloat2float(_cc0.val[3])); - _f4 = vaddq_f32(_f4, bfloat2float(_cc1.val[0])); - _f5 = vaddq_f32(_f5, bfloat2float(_cc1.val[1])); - _f6 = vaddq_f32(_f6, bfloat2float(_cc1.val[2])); - _f7 = vaddq_f32(_f7, bfloat2float(_cc1.val[3])); + _cc0 = vld4_u16(pC + c_hstep * 4); + _f4 = vaddq_f32(_f4, bfloat2float(_cc0.val[0])); + _f5 = vaddq_f32(_f5, bfloat2float(_cc0.val[1])); + _f6 = vaddq_f32(_f6, bfloat2float(_cc0.val[2])); + _f7 = vaddq_f32(_f7, bfloat2float(_cc0.val[3])); pC += 16; } } if (broadcast_type_C == 4) { _c0 = bfloat2float(vld1_u16(pC)); + _c0 = vmulq_n_f32(_c0, beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c0); _f2 = vaddq_f32(_f2, _c0); @@ -5427,6 +5692,19 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + vst1_u16(p0, float2bfloat(_f0)); vst1_u16(p0 + out_hstep, float2bfloat(_f1)); vst1_u16(p0 + out_hstep * 2, float2bfloat(_f2)); @@ -5518,6 +5796,8 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if 
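The reload-into-the-same-variables rewrites above are not just cosmetic: with all sixteen _f* accumulators live, keeping another sixteen _c* vectors would overflow the 32 NEON registers and spill, so the C tile is folded in as two batches of eight that reuse the same temporaries. A schematic sketch of the idea (addressing simplified; the real code steps rows by c_hstep):

    #include <arm_neon.h>
    #include <stdint.h>

    static void fold_c_two_pass(float32x4_t f[16], const uint16_t* pC, float beta)
    {
        float32x4_t _beta = vdupq_n_f32(beta);
        for (int pass = 0; pass < 2; pass++) // two batches of 8 keep the
        {                                    // live set inside 32 registers
            for (int k = 0; k < 8; k++)
            {
                uint16x4_t c16 = vld1_u16(pC + (pass * 8 + k) * 4);
                float32x4_t c = vreinterpretq_f32_u32(vshll_n_u16(c16, 16));
                f[pass * 8 + k] = vmlaq_f32(f[pass * 8 + k], c, _beta);
            }
        }
    }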
(broadcast_type_C == 3) { + float32x4_t _c2; + float32x4_t _c3; if (c_elempack == 1) { uint16x8_t _cc0 = uint16x8_t(); @@ -5538,13 +5818,13 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _cc1 = vsetq_lane_u16(pC[c_hstep * 6 + 1], _cc1, 5); _cc1 = vsetq_lane_u16(pC[c_hstep * 7], _cc1, 6); _cc1 = vsetq_lane_u16(pC[c_hstep * 7 + 1], _cc1, 7); - _f0 = vaddq_f32(_f0, bfloat2float(vget_low_u16(_cc0))); - _f1 = vaddq_f32(_f1, bfloat2float(vget_high_u16(_cc0))); - _f2 = vaddq_f32(_f2, bfloat2float(vget_low_u16(_cc1))); - _f3 = vaddq_f32(_f3, bfloat2float(vget_high_u16(_cc1))); + _c0 = bfloat2float(vget_low_u16(_cc0)); + _c1 = bfloat2float(vget_high_u16(_cc0)); + _c2 = bfloat2float(vget_low_u16(_cc1)); + _c3 = bfloat2float(vget_high_u16(_cc1)); pC += 2; } - if (c_elempack == 4) + else // if (c_elempack == 4) { // TODO optimize uint16x8_t _cc0 = vld1q_u16(pC); @@ -5555,12 +5835,27 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& float32x4_t _c3 = bfloat2float(vget_high_u16(_cc1)); float32x4x2_t _c01 = vzipq_f32(_c0, _c1); float32x4x2_t _c23 = vzipq_f32(_c2, _c3); - _f0 = vaddq_f32(_f0, vcombine_f32(vget_low_f32(_c01.val[0]), vget_low_f32(_c01.val[1]))); - _f1 = vaddq_f32(_f1, vcombine_f32(vget_high_f32(_c01.val[0]), vget_high_f32(_c01.val[1]))); - _f2 = vaddq_f32(_f2, vcombine_f32(vget_low_f32(_c23.val[0]), vget_low_f32(_c23.val[1]))); - _f3 = vaddq_f32(_f3, vcombine_f32(vget_high_f32(_c23.val[0]), vget_high_f32(_c23.val[1]))); + _c0 = vcombine_f32(vget_low_f32(_c01.val[0]), vget_low_f32(_c01.val[1])); + _c1 = vcombine_f32(vget_high_f32(_c01.val[0]), vget_high_f32(_c01.val[1])); + _c2 = vcombine_f32(vget_low_f32(_c23.val[0]), vget_low_f32(_c23.val[1])); + _c3 = vcombine_f32(vget_high_f32(_c23.val[0]), vget_high_f32(_c23.val[1])); pC += 8; } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } } if (broadcast_type_C == 4) { @@ -5570,6 +5865,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _c = vset_lane_u16(pC[0], _c, 2); _c = vset_lane_u16(pC[1], _c, 3); _c0 = bfloat2float(_c); + _c0 = vmulq_n_f32(_c0, beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c0); _f2 = vaddq_f32(_f2, _c0); @@ -5578,6 +5874,15 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + uint16x4_t _fb0 = float2bfloat(_f0); uint16x4_t _fb1 = float2bfloat(_f1); uint16x4_t _fb2 = float2bfloat(_f2); @@ -5636,28 +5941,44 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _c = vsetq_lane_u16(pC[c_hstep * 5], _c, 5); _c = vsetq_lane_u16(pC[c_hstep * 6], _c, 6); _c = vsetq_lane_u16(pC[c_hstep * 7], _c, 7); - _f0 = vaddq_f32(_f0, bfloat2float(vget_low_u16(_c))); - _f1 = vaddq_f32(_f1, bfloat2float(vget_high_u16(_c))); + _c0 = bfloat2float(vget_low_u16(_c)); + _c1 = bfloat2float(vget_high_u16(_c)); pC += 1; } - if (c_elempack == 4) + else // if (c_elempack == 4) { _c0 = bfloat2float(vld1_u16(pC)); _c1 = bfloat2float(vld1_u16(pC + c_hstep * 4)); + pC += 4; + } + if (beta == 1.f) + 
{ _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); - pC += 4; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c0); pC += 1; } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + uint16x4_t _fb0 = float2bfloat(_f0); uint16x4_t _fb1 = float2bfloat(_f1); @@ -5811,45 +6132,37 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { + uint16x8_t _c01; + uint16x8_t _c23; + uint16x8_t _c45; + uint16x8_t _c67; if (c_elempack == 1) { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + c_hstep); - uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); - uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + c_hstep); + _c45 = vld1q_u16(pC + c_hstep * 2); + _c67 = vld1q_u16(pC + c_hstep * 3); transpose8x4_u16(_c01, _c23, _c45, _c67); - _c0 = bfloat2float(vget_low_u16(_c01)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); - float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); - float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); - float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); pC += 8; } - if (c_elempack == 4) + else // if (c_elempack == 4) + { + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + 8); + _c45 = vld1q_u16(pC + 16); + _c67 = vld1q_u16(pC + 24); + pC += 32; + } + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + 8); - uint16x8_t _c45 = vld1q_u16(pC + 16); - uint16x8_t _c67 = vld1q_u16(pC + 24); - _c0 = bfloat2float(vget_low_u16(_c01)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); - float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); - float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); - float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); @@ -5858,19 +6171,39 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _f5 = vaddq_f32(_f5, _c5); _f6 = vaddq_f32(_f6, _c6); _f7 = vaddq_f32(_f7, _c7); - pC += 32; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + 
_f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); } } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); - float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); - float32x4_t _c2 = vdupq_n_f32(bfloat16_to_float32(pC[2])); - float32x4_t _c3 = vdupq_n_f32(bfloat16_to_float32(pC[3])); - float32x4_t _c4 = vdupq_n_f32(bfloat16_to_float32(pC[4])); - float32x4_t _c5 = vdupq_n_f32(bfloat16_to_float32(pC[5])); - float32x4_t _c6 = vdupq_n_f32(bfloat16_to_float32(pC[6])); - float32x4_t _c7 = vdupq_n_f32(bfloat16_to_float32(pC[7])); + uint16x8_t _c = vld1q_u16(pC); + float32x4_t _cc0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _cc1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); @@ -5883,6 +6216,19 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); vst1q_u16(p0 + 16, vcombine_u16(float2bfloat(_f4), float2bfloat(_f5))); @@ -5956,6 +6302,9 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; if (c_elempack == 1) { uint16x4_t _cc0 = vld1_u16(pC); @@ -5963,33 +6312,53 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); - _f0 = vaddq_f32(_f0, bfloat2float(_cc0)); - _f1 = vaddq_f32(_f1, bfloat2float(_cc1)); - _f2 = vaddq_f32(_f2, bfloat2float(_cc2)); - _f3 = vaddq_f32(_f3, bfloat2float(_cc3)); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc1); + _c2 = bfloat2float(_cc2); + _c3 = bfloat2float(_cc3); pC += 4; } - if (c_elempack == 4) + else // if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); uint16x8_t _c23 = vld1q_u16(pC + 8); _c0 = bfloat2float(vget_low_u16(_c01)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + pC += 16; + } + if (beta == 1.f) + { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); _f3 = vaddq_f32(_f3, _c3); - pC += 16; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = 
vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); - float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); - float32x4_t _c2 = vdupq_n_f32(bfloat16_to_float32(pC[2])); - float32x4_t _c3 = vdupq_n_f32(bfloat16_to_float32(pC[3])); + float32x4_t _c = bfloat2float(vld1_u16(pC)); + _c = vmulq_n_f32(_c, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_c, 0); + float32x4_t _c1 = vdupq_laneq_f32(_c, 1); + float32x4_t _c2 = vdupq_laneq_f32(_c, 2); + float32x4_t _c3 = vdupq_laneq_f32(_c, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); + float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); +#endif _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); @@ -5998,6 +6367,15 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); @@ -6047,9 +6425,10 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { + uint16x8_t _c; if (c_elempack == 1) { - uint16x8_t _c = uint16x8_t(); + _c = uint16x8_t(); _c = vsetq_lane_u16(pC[0], _c, 0); _c = vsetq_lane_u16(pC[c_hstep], _c, 1); _c = vsetq_lane_u16(pC[c_hstep * 2], _c, 2); @@ -6058,32 +6437,44 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _c = vsetq_lane_u16(pC[c_hstep + 1], _c, 5); _c = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c, 6); _c = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c, 7); - _c0 = bfloat2float(vget_low_u16(_c)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); pC += 2; } - if (c_elempack == 4) + else // if (c_elempack == 4) + { + _c = vld1q_u16(pC); + pC += 8; + } + _c0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); + if (beta == 1.f) { - uint16x8_t _c01 = vld1q_u16(pC); - _c0 = bfloat2float(vget_low_u16(_c01)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); - pC += 8; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); - float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]) * beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); pC += 2; } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); pp += 8; @@ -6105,31 +6496,34 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { + uint16x4_t _c; if (c_elempack == 1) { - uint16x4_t _c = uint16x4_t(); + _c = uint16x4_t(); _c = vset_lane_u16(pC[0], _c, 0); _c = vset_lane_u16(pC[c_hstep], _c, 1); _c = vset_lane_u16(pC[c_hstep * 2], _c, 2); _c = 
vset_lane_u16(pC[c_hstep * 3], _c, 3); - _f0 = vaddq_f32(_f0, bfloat2float(_c)); pC += 1; } - if (c_elempack == 4) + else // if (c_elempack == 4) { - _c0 = bfloat2float(vld1_u16(pC)); - _f0 = vaddq_f32(_f0, _c0); + _c = vld1_u16(pC); pC += 4; } + _c0 = bfloat2float(_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); _f0 = vaddq_f32(_f0, _c0); pC += 1; } } + _f0 = vmulq_n_f32(_f0, alpha); + vst1_u16(p0, float2bfloat(_f0)); pp += 4; @@ -6253,17 +6647,10 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 1 || broadcast_type_C == 2) { -#if __aarch64__ float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); -#else - float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); - float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); - float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); - float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); -#endif _f0 = vaddq_f32(_f0, _cc0); _f1 = vaddq_f32(_f1, _cc0); _f2 = vaddq_f32(_f2, _cc1); @@ -6275,6 +6662,13 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; + float32x4_t _c4; + float32x4_t _c5; + float32x4_t _c6; + float32x4_t _c7; if (c_elempack == 1) { uint16x8_t _c01 = vld1q_u16(pC); @@ -6282,13 +6676,30 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); _c0 = bfloat2float(vget_low_u16(_c01)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); - float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); - float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); - float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + _c4 = bfloat2float(vget_low_u16(_c45)); + _c5 = bfloat2float(vget_high_u16(_c45)); + _c6 = bfloat2float(vget_low_u16(_c67)); + _c7 = bfloat2float(vget_high_u16(_c67)); + pC += 8; + } + else // if (c_elempack == 4) + { + uint16x8x4_t _cc = vld4q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_cc.val[0])); + _c1 = bfloat2float(vget_high_u16(_cc.val[0])); + _c2 = bfloat2float(vget_low_u16(_cc.val[1])); + _c3 = bfloat2float(vget_high_u16(_cc.val[1])); + _c4 = bfloat2float(vget_low_u16(_cc.val[2])); + _c5 = bfloat2float(vget_high_u16(_cc.val[2])); + _c6 = bfloat2float(vget_low_u16(_cc.val[3])); + _c7 = bfloat2float(vget_high_u16(_cc.val[3])); + pC += 32; + } + if (beta == 1.f) + { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); @@ -6297,20 +6708,18 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _f5 = vaddq_f32(_f5, _c5); _f6 = vaddq_f32(_f6, _c6); _f7 = vaddq_f32(_f7, _c7); - pC += 8; } - if (c_elempack == 4) + else { - uint16x8x4_t _cc = vld4q_u16(pC); - _f0 = vaddq_f32(_f0, bfloat2float(vget_low_u16(_cc.val[0]))); - _f1 = vaddq_f32(_f1, bfloat2float(vget_high_u16(_cc.val[0]))); - _f2 = vaddq_f32(_f2, bfloat2float(vget_low_u16(_cc.val[1]))); - _f3 
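The 1-wide tails above drop the beta/alpha branching entirely: vmlaq_n_f32 and vmulq_n_f32 take the scalar operand directly, and multiplying by 1.0f is exact in IEEE arithmetic, so a single instruction covers every case. Sketch:

    #include <arm_neon.h>

    // out = (f + beta * c) * alpha; beta == 1.f and alpha == 1.f need no branch.
    static inline float32x4_t tail4(float32x4_t f, float32x4_t c,
                                    float alpha, float beta)
    {
        f = vmlaq_n_f32(f, c, beta);
        return vmulq_n_f32(f, alpha);
    }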
= vaddq_f32(_f3, bfloat2float(vget_high_u16(_cc.val[1]))); - _f4 = vaddq_f32(_f4, bfloat2float(vget_low_u16(_cc.val[2]))); - _f5 = vaddq_f32(_f5, bfloat2float(vget_high_u16(_cc.val[2]))); - _f6 = vaddq_f32(_f6, bfloat2float(vget_low_u16(_cc.val[3]))); - _f7 = vaddq_f32(_f7, bfloat2float(vget_high_u16(_cc.val[3]))); - pC += 32; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); } } if (broadcast_type_C == 4) @@ -6318,6 +6727,12 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& uint16x8_t _c = vld1q_u16(pC); _c0 = bfloat2float(vget_low_u16(_c)); float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); + } _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c0); @@ -6330,6 +6745,19 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); vst1q_u16(p0 + out_hstep * 2, vcombine_u16(float2bfloat(_f4), float2bfloat(_f5))); @@ -6433,31 +6861,46 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; if (c_elempack == 1) { _c0 = bfloat2float(vld1_u16(pC)); - float32x4_t _c1 = bfloat2float(vld1_u16(pC + c_hstep * 1)); - float32x4_t _c2 = bfloat2float(vld1_u16(pC + c_hstep * 2)); - float32x4_t _c3 = bfloat2float(vld1_u16(pC + c_hstep * 3)); + _c1 = bfloat2float(vld1_u16(pC + c_hstep)); + _c2 = bfloat2float(vld1_u16(pC + c_hstep * 2)); + _c3 = bfloat2float(vld1_u16(pC + c_hstep * 3)); + pC += 4; + } + else // if (c_elempack == 4) + { + uint16x4x4_t _c = vld4_u16(pC); + _c0 = bfloat2float(_c.val[0]); + _c1 = bfloat2float(_c.val[1]); + _c2 = bfloat2float(_c.val[2]); + _c3 = bfloat2float(_c.val[3]); + pC += 16; + } + if (beta == 1.f) + { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); _f3 = vaddq_f32(_f3, _c3); - pC += 4; } - if (c_elempack == 4) + else { - uint16x4x4_t _c = vld4_u16(pC); - _f0 = vaddq_f32(_f0, bfloat2float(_c.val[0])); - _f1 = vaddq_f32(_f1, bfloat2float(_c.val[1])); - _f2 = vaddq_f32(_f2, bfloat2float(_c.val[2])); - _f3 = vaddq_f32(_f3, bfloat2float(_c.val[3])); - pC += 16; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } } if (broadcast_type_C == 4) { _c0 = bfloat2float(vld1_u16(pC)); + _c0 = vmulq_n_f32(_c0, beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c0); _f2 = vaddq_f32(_f2, _c0); @@ -6466,6 +6909,15 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } } + if (alpha != 1.f) + { + float32x4_t _alpha = 
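On the pack-4 C layout the four rows of each column sit interleaved in memory, which is why these branches reach for vld4_u16/vld4q_u16: the structure load de-interleaves one row per vector on the way in, with no shuffles afterwards. A sketch of the 4-column case (helper name is mine):

    #include <arm_neon.h>
    #include <stdint.h>

    // pack4 stores r0 r1 r2 r3 | r0 r1 r2 r3 | ...; vld4_u16 splits the
    // stream back into one vector per row.
    static void load_c_pack4(const uint16_t* pC, float32x4_t c[4])
    {
        uint16x4x4_t t = vld4_u16(pC); // t.val[r] = row r, 4 columns
        for (int r = 0; r < 4; r++)
            c[r] = vreinterpretq_f32_u32(vshll_n_u16(t.val[r], 16));
    }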
vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + vst1_u16(p0, float2bfloat(_f0)); vst1_u16(p0 + out_hstep, float2bfloat(_f1)); vst1_u16(p0 + out_hstep * 2, float2bfloat(_f2)); @@ -6529,6 +6981,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { + float32x4_t _c1; if (c_elempack == 1) { uint16x8_t _c = uint16x8_t(); @@ -6540,20 +6993,29 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _c = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c, 5); _c = vsetq_lane_u16(pC[c_hstep * 3], _c, 6); _c = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c, 7); - _f0 = vaddq_f32(_f0, bfloat2float(vget_low_u16(_c))); - _f1 = vaddq_f32(_f1, bfloat2float(vget_high_u16(_c))); + _c0 = bfloat2float(vget_low_u16(_c)); + _c1 = bfloat2float(vget_high_u16(_c)); pC += 2; } - if (c_elempack == 4) + else // if (c_elempack == 4) { - uint16x8_t _cc = vld1q_u16(pC); - _c0 = bfloat2float(vget_low_u16(_cc)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_cc)); - float32x4x2_t _c01 = vzipq_f32(_c0, _c1); - _f0 = vaddq_f32(_f0, _c01.val[0]); - _f1 = vaddq_f32(_f1, _c01.val[1]); + uint16x8_t _c = vld1q_u16(pC); + uint16x4x2_t _c01 = vzip_u16(vget_low_u16(_c), vget_high_u16(_c)); + _c0 = bfloat2float(_c01.val[0]); + _c1 = bfloat2float(_c01.val[1]); pC += 8; } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } } if (broadcast_type_C == 4) { @@ -6563,12 +7025,20 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _c = vset_lane_u16(pC[0], _c, 2); _c = vset_lane_u16(pC[1], _c, 3); _c0 = bfloat2float(_c); + _c0 = vmulq_n_f32(_c0, beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c0); pC += 2; } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + uint16x4_t _fb0 = float2bfloat(_f0); uint16x4_t _fb1 = float2bfloat(_f1); @@ -6581,11 +7051,6 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& p0[out_hstep * 3] = vget_lane_u16(_fb1, 2); p0[out_hstep * 3 + 1] = vget_lane_u16(_fb1, 3); - // vst1_f32(p0, vget_low_f32(_f0)); - // vst1_f32(p1, vget_high_f32(_f0)); - // vst1_f32(p2, vget_low_f32(_f1)); - // vst1_f32(p3, vget_high_f32(_f1)); - pp += 8; p0 += 2; } @@ -6605,31 +7070,34 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { + uint16x4_t _c; if (c_elempack == 1) { - uint16x4_t _c = uint16x4_t(); + _c = uint16x4_t(); _c = vset_lane_u16(pC[0], _c, 0); _c = vset_lane_u16(pC[c_hstep], _c, 1); _c = vset_lane_u16(pC[c_hstep * 2], _c, 2); _c = vset_lane_u16(pC[c_hstep * 3], _c, 3); - _f0 = vaddq_f32(_f0, bfloat2float(_c)); pC += 1; } - if (c_elempack == 4) + else // if (c_elempack == 4) { - _c0 = bfloat2float(vld1_u16(pC)); - _f0 = vaddq_f32(_f0, _c0); + _c = vld1_u16(pC); pC += 4; } + _c0 = bfloat2float(_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); _f0 = vaddq_f32(_f0, _c0); pC += 1; } } + _f0 = vmulq_n_f32(_f0, alpha); + uint16x4_t _fb0 = float2bfloat(_f0); p0[0] = vget_lane_u16(_fb0, 0); @@ -6732,10 +7200,21 @@ static void 
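One genuine micro-optimization in the hunk above: the old code widened both halves to fp32 and then interleaved with vzipq_f32; the new code interleaves first with vzip_u16 while the data is still 16-bit and widens afterwards, so the permute runs at half the width. Equivalent sketch:

    #include <arm_neon.h>

    static void zip_then_widen(uint16x8_t c, float32x4_t* c0, float32x4_t* c1)
    {
        // interleave in u16 first: one d-register vzip instead of a q-wide vzip
        uint16x4x2_t z = vzip_u16(vget_low_u16(c), vget_high_u16(c));
        *c0 = vreinterpretq_f32_u32(vshll_n_u16(z.val[0], 16));
        *c1 = vreinterpretq_f32_u32(vshll_n_u16(z.val[1], 16));
    }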
unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } pC += 8; } if (broadcast_type_C == 4) @@ -6743,6 +7222,12 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& uint16x8_t _c = vld1q_u16(pC); _c0 = bfloat2float(vget_low_u16(_c)); float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); + } _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c0); @@ -6751,6 +7236,15 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); @@ -6783,19 +7277,36 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& // c_elempack == 1 _c0 = bfloat2float(vld1_u16(pC)); float32x4_t _c1 = bfloat2float(vld1_u16(pC + c_hstep)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } pC += 4; } if (broadcast_type_C == 4) { _c0 = bfloat2float(vld1_u16(pC)); + _c0 = vmulq_n_f32(_c0, beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c0); pC += 4; } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + vst1_u16(p0, float2bfloat(_f0)); vst1_u16(p0 + out_hstep, float2bfloat(_f1)); @@ -6829,22 +7340,30 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& if (broadcast_type_C == 3) { // c_elempack == 1 - f00 += bfloat16_to_float32(pC[0]); - f01 += bfloat16_to_float32(pC[1]); - f10 += bfloat16_to_float32(pC[c_hstep]); - f11 += bfloat16_to_float32(pC[c_hstep + 1]); + f00 += bfloat16_to_float32(pC[0]) * beta; + f01 += bfloat16_to_float32(pC[1]) * beta; + f10 += bfloat16_to_float32(pC[c_hstep]) * beta; + f11 += bfloat16_to_float32(pC[c_hstep + 1]) * beta; pC += 2; } if (broadcast_type_C == 4) { - f00 += bfloat16_to_float32(pC[0]); - f01 += bfloat16_to_float32(pC[1]); - f10 += bfloat16_to_float32(pC[0]); - f11 += bfloat16_to_float32(pC[1]); + f00 += bfloat16_to_float32(pC[0]) * beta; + f01 += bfloat16_to_float32(pC[1]) * beta; + f10 += bfloat16_to_float32(pC[0]) * beta; + f11 += bfloat16_to_float32(pC[1]) * beta; pC += 2; } } + if (alpha != 1.f) + { + f00 *= alpha; + f01 *= alpha; + f10 *= alpha; + f11 *= alpha; + } + p0[0] = float32_to_bfloat16(f00); p0[1] = float32_to_bfloat16(f01); p0[out_hstep] = 
float32_to_bfloat16(f10); @@ -6874,18 +7393,24 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& if (broadcast_type_C == 3) { // c_elempack == 1 - f0 += bfloat16_to_float32(pC[0]); - f1 += bfloat16_to_float32(pC[c_hstep]); + f0 += bfloat16_to_float32(pC[0]) * beta; + f1 += bfloat16_to_float32(pC[c_hstep]) * beta; pC += 1; } if (broadcast_type_C == 4) { - f0 += bfloat16_to_float32(pC[0]); - f1 += bfloat16_to_float32(pC[0]); + f0 += bfloat16_to_float32(pC[0]) * beta; + f1 += bfloat16_to_float32(pC[0]) * beta; pC += 1; } } + if (alpha != 1.f) + { + f0 *= alpha; + f1 *= alpha; + } + p0[0] = float32_to_bfloat16(f0); p0[out_hstep] = float32_to_bfloat16(f1); @@ -6970,14 +7495,34 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } pC += 16; } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); @@ -7005,12 +7550,28 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& uint16x8_t _c01 = vld1q_u16(pC); _c0 = bfloat2float(vget_low_u16(_c01)); float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } pC += 8; } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); pp += 8; @@ -7030,11 +7591,13 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& { // c_elempack == 1 _c0 = bfloat2float(vld1_u16(pC)); - _f0 = vaddq_f32(_f0, _c0); + _f0 = vmlaq_n_f32(_f0, _c0, beta); pC += 4; } } + _f0 = vmulq_n_f32(_f0, alpha); + vst1_u16(p0, float2bfloat(_f0)); pp += 4; @@ -7056,11 +7619,13 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& float32x2_t _cc = float32x2_t(); _cc = vset_lane_f32(bfloat16_to_float32(pC[0]), _cc, 0); _cc = vset_lane_f32(bfloat16_to_float32(pC[1]), _cc, 1); - _f0 = vadd_f32(_f0, _cc); + _f0 = vmla_n_f32(_f0, _cc, beta); pC += 2; } } + _f0 = vmul_n_f32(_f0, alpha); + p0[0] = float32_to_bfloat16(vget_lane_f32(_f0, 0)); p0[1] = float32_to_bfloat16(vget_lane_f32(_f0, 1)); @@ -7081,11 +7646,13 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& if (broadcast_type_C == 3 || broadcast_type_C == 4) { // c_elempack == 1 - f0 += bfloat16_to_float32(pC[0]); + f0 += bfloat16_to_float32(pC[0]) * beta; pC += 1; } } + f0 
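The scalar tails above lean on the element-wise helpers; what they amount to, again assuming truncation on the narrowing side, is:

    #include <stdint.h>

    static inline float bf16_to_f32_scalar(uint16_t v)
    {
        union { uint32_t u; float f; } x = { (uint32_t)v << 16 };
        return x.f;
    }

    static inline uint16_t f32_to_bf16_scalar(float f)
    {
        union { float f; uint32_t u; } x = { f };
        return (uint16_t)(x.u >> 16);
    }

    // one output element: out = (acc + beta * c) * alpha
    static inline uint16_t tail1(float acc, uint16_t c, float alpha, float beta)
    {
        return f32_to_bf16_scalar((acc + bf16_to_f32_scalar(c) * beta) * alpha);
    }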
*= alpha; + p0[0] = float32_to_bfloat16(f0); pp += 1; @@ -7095,12 +7662,12 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } } -static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales) +static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta) { #if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 if (ncnn::cpu_support_arm_asimddp()) { - transpose_unpack_output_tile_int32_to_bf16_asimddp(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales); + transpose_unpack_output_tile_int32_to_bf16_asimddp(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta); return; } #endif @@ -7393,52 +7960,128 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma float32x4_t _c5 = bfloat2float(vld1_u16(pC + c_hstep * 2 + 4)); float32x4_t _c6 = bfloat2float(vld1_u16(pC + c_hstep * 3)); float32x4_t _c7 = bfloat2float(vld1_u16(pC + c_hstep * 3 + 4)); - float32x4_t _c8 = bfloat2float(vld1_u16(pC + c_hstep * 4)); - float32x4_t _c9 = bfloat2float(vld1_u16(pC + c_hstep * 4 + 4)); - float32x4_t _ca = bfloat2float(vld1_u16(pC + c_hstep * 5)); - float32x4_t _cb = bfloat2float(vld1_u16(pC + c_hstep * 5 + 4)); - float32x4_t _cc = bfloat2float(vld1_u16(pC + c_hstep * 6)); - float32x4_t _cd = bfloat2float(vld1_u16(pC + c_hstep * 6 + 4)); - float32x4_t _ce = bfloat2float(vld1_u16(pC + c_hstep * 7)); - float32x4_t _cf = bfloat2float(vld1_u16(pC + c_hstep * 7 + 4)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - _f8 = vaddq_f32(_f8, _c8); - _f9 = vaddq_f32(_f9, _c9); - _fa = vaddq_f32(_fa, _ca); - _fb = vaddq_f32(_fb, _cb); - _fc = vaddq_f32(_fc, _cc); - _fd = vaddq_f32(_fd, _cd); - _fe = vaddq_f32(_fe, _ce); - _ff = vaddq_f32(_ff, _cf); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c0 = bfloat2float(vld1_u16(pC + c_hstep * 4)); + _c1 = bfloat2float(vld1_u16(pC + c_hstep * 4 + 4)); + _c2 = bfloat2float(vld1_u16(pC + c_hstep * 5)); + _c3 = bfloat2float(vld1_u16(pC + c_hstep * 5 + 4)); + _c4 = bfloat2float(vld1_u16(pC + c_hstep * 6)); + _c5 = bfloat2float(vld1_u16(pC + c_hstep * 6 + 4)); + _c6 = bfloat2float(vld1_u16(pC + c_hstep * 7)); + _c7 = bfloat2float(vld1_u16(pC + c_hstep * 7 + 4)); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t 
_beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } pC += 8; } - if (c_elempack == 4) + else // if (c_elempack == 4) { uint16x8x4_t _cc0 = vld4q_u16(pC); - uint16x8x4_t _cc1 = vld4q_u16(pC + c_hstep * 4); - _f0 = vaddq_f32(_f0, bfloat2float(vget_low_u16(_cc0.val[0]))); - _f1 = vaddq_f32(_f1, bfloat2float(vget_high_u16(_cc0.val[0]))); - _f2 = vaddq_f32(_f2, bfloat2float(vget_low_u16(_cc0.val[1]))); - _f3 = vaddq_f32(_f3, bfloat2float(vget_high_u16(_cc0.val[1]))); - _f4 = vaddq_f32(_f4, bfloat2float(vget_low_u16(_cc0.val[2]))); - _f5 = vaddq_f32(_f5, bfloat2float(vget_high_u16(_cc0.val[2]))); - _f6 = vaddq_f32(_f6, bfloat2float(vget_low_u16(_cc0.val[3]))); - _f7 = vaddq_f32(_f7, bfloat2float(vget_high_u16(_cc0.val[3]))); - _f8 = vaddq_f32(_f8, bfloat2float(vget_low_u16(_cc1.val[0]))); - _f9 = vaddq_f32(_f9, bfloat2float(vget_high_u16(_cc1.val[0]))); - _fa = vaddq_f32(_fa, bfloat2float(vget_low_u16(_cc1.val[1]))); - _fb = vaddq_f32(_fb, bfloat2float(vget_high_u16(_cc1.val[1]))); - _fc = vaddq_f32(_fc, bfloat2float(vget_low_u16(_cc1.val[2]))); - _fd = vaddq_f32(_fd, bfloat2float(vget_high_u16(_cc1.val[2]))); - _fe = vaddq_f32(_fe, bfloat2float(vget_low_u16(_cc1.val[3]))); - _ff = vaddq_f32(_ff, bfloat2float(vget_high_u16(_cc1.val[3]))); + _c0 = bfloat2float(vget_low_u16(_cc0.val[0])); + _c1 = bfloat2float(vget_high_u16(_cc0.val[0])); + float32x4_t _c2 = bfloat2float(vget_low_u16(_cc0.val[1])); + float32x4_t _c3 = bfloat2float(vget_high_u16(_cc0.val[1])); + float32x4_t _c4 = bfloat2float(vget_low_u16(_cc0.val[2])); + float32x4_t _c5 = bfloat2float(vget_high_u16(_cc0.val[2])); + float32x4_t _c6 = bfloat2float(vget_low_u16(_cc0.val[3])); + float32x4_t _c7 = bfloat2float(vget_high_u16(_cc0.val[3])); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _cc0 = vld4q_u16(pC + c_hstep * 4); + _c0 = bfloat2float(vget_low_u16(_cc0.val[0])); + _c1 = bfloat2float(vget_high_u16(_cc0.val[0])); + _c2 = bfloat2float(vget_low_u16(_cc0.val[1])); + _c3 = bfloat2float(vget_high_u16(_cc0.val[1])); + _c4 = bfloat2float(vget_low_u16(_cc0.val[2])); + _c5 = bfloat2float(vget_high_u16(_cc0.val[2])); + _c6 = bfloat2float(vget_low_u16(_cc0.val[3])); + _c7 = bfloat2float(vget_high_u16(_cc0.val[3])); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + 
_fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } pC += 32; } } @@ -7447,6 +8090,12 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma uint16x8_t _c = vld1q_u16(pC); _c0 = bfloat2float(vget_low_u16(_c)); _c1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); + } _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c0); @@ -7467,6 +8116,27 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + _f8 = vmulq_f32(_f8, _alpha); + _f9 = vmulq_f32(_f9, _alpha); + _fa = vmulq_f32(_fa, _alpha); + _fb = vmulq_f32(_fb, _alpha); + _fc = vmulq_f32(_fc, _alpha); + _fd = vmulq_f32(_fd, _alpha); + _fe = vmulq_f32(_fe, _alpha); + _ff = vmulq_f32(_ff, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f2))); vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f4), float2bfloat(_f6))); vst1q_u16(p0 + 16, vcombine_u16(float2bfloat(_f8), float2bfloat(_fa))); @@ -7609,28 +8279,31 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); - float32x4_t _cc4 = vdupq_laneq_f32(_c1, 0); - float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1); - float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2); - float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3); #else float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); - float32x4_t _cc4 = vdupq_lane_f32(vget_low_f32(_c1), 0); - float32x4_t _cc5 = vdupq_lane_f32(vget_low_f32(_c1), 1); - float32x4_t _cc6 = vdupq_lane_f32(vget_high_f32(_c1), 0); - float32x4_t _cc7 = vdupq_lane_f32(vget_high_f32(_c1), 1); #endif _f0 = vaddq_f32(_f0, _cc0); _f1 = vaddq_f32(_f1, _cc1); _f2 = vaddq_f32(_f2, _cc2); _f3 = vaddq_f32(_f3, _cc3); - _f4 = vaddq_f32(_f4, _cc4); - _f5 = vaddq_f32(_f5, _cc5); - _f6 = vaddq_f32(_f6, _cc6); - _f7 = vaddq_f32(_f7, _cc7); +#if __aarch64__ + _cc0 = vdupq_laneq_f32(_c1, 0); + _cc1 = vdupq_laneq_f32(_c1, 1); + _cc2 = vdupq_laneq_f32(_c1, 2); + _cc3 = vdupq_laneq_f32(_c1, 3); +#else + _cc0 = vdupq_lane_f32(vget_low_f32(_c1), 0); + _cc1 = vdupq_lane_f32(vget_low_f32(_c1), 1); + _cc2 = vdupq_lane_f32(vget_high_f32(_c1), 0); + _cc3 = vdupq_lane_f32(vget_high_f32(_c1), 1); +#endif + _f4 = vaddq_f32(_f4, _cc0); + _f5 = vaddq_f32(_f5, _cc1); + _f6 = vaddq_f32(_f6, _cc2); + _f7 = vaddq_f32(_f7, _cc3); } if (broadcast_type_C == 3) { @@ -7640,38 +8313,91 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _c1 = bfloat2float(vld1_u16(pC + c_hstep)); float32x4_t _c2 = bfloat2float(vld1_u16(pC + c_hstep * 2)); float32x4_t _c3 = bfloat2float(vld1_u16(pC + c_hstep * 3)); - float32x4_t _c4 = bfloat2float(vld1_u16(pC + c_hstep * 4)); - float32x4_t _c5 = bfloat2float(vld1_u16(pC + c_hstep * 5)); - float32x4_t _c6 = bfloat2float(vld1_u16(pC + c_hstep * 6)); - float32x4_t _c7 = 
bfloat2float(vld1_u16(pC + c_hstep * 7)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _c0 = bfloat2float(vld1_u16(pC + c_hstep * 4)); + _c1 = bfloat2float(vld1_u16(pC + c_hstep * 5)); + _c2 = bfloat2float(vld1_u16(pC + c_hstep * 6)); + _c3 = bfloat2float(vld1_u16(pC + c_hstep * 7)); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } pC += 4; } - if (c_elempack == 4) + else // if (c_elempack == 4) { uint16x4x4_t _cc0 = vld4_u16(pC); - uint16x4x4_t _cc1 = vld4_u16(pC + c_hstep * 4); - _f0 = vaddq_f32(_f0, bfloat2float(_cc0.val[0])); - _f1 = vaddq_f32(_f1, bfloat2float(_cc0.val[1])); - _f2 = vaddq_f32(_f2, bfloat2float(_cc0.val[2])); - _f3 = vaddq_f32(_f3, bfloat2float(_cc0.val[3])); - _f4 = vaddq_f32(_f4, bfloat2float(_cc1.val[0])); - _f5 = vaddq_f32(_f5, bfloat2float(_cc1.val[1])); - _f6 = vaddq_f32(_f6, bfloat2float(_cc1.val[2])); - _f7 = vaddq_f32(_f7, bfloat2float(_cc1.val[3])); + _c0 = bfloat2float(_cc0.val[0]); + _c1 = bfloat2float(_cc0.val[1]); + float32x4_t _c2 = bfloat2float(_cc0.val[2]); + float32x4_t _c3 = bfloat2float(_cc0.val[3]); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _cc0 = vld4_u16(pC + c_hstep * 4); + _c0 = bfloat2float(_cc0.val[0]); + _c1 = bfloat2float(_cc0.val[1]); + _c2 = bfloat2float(_cc0.val[2]); + _c3 = bfloat2float(_cc0.val[3]); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } pC += 16; } } if (broadcast_type_C == 4) { _c0 = bfloat2float(vld1_u16(pC)); + _c0 = vmulq_n_f32(_c0, beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c0); _f2 = vaddq_f32(_f2, _c0); @@ -7684,6 +8410,19 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); vst1q_u16(p0 + 16, 
vcombine_u16(float2bfloat(_f4), float2bfloat(_f5))); @@ -7917,6 +8656,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma { if (c_elempack == 1) { + // TODO decompose 8x8 to 8x4 and 8x4 uint16x8_t _c01 = vld1q_u16(pC); uint16x8_t _c23 = vld1q_u16(pC + c_hstep); uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); @@ -7960,16 +8700,12 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _ff = vaddq_f32(_ff, _cf); pC += 8; } - if (c_elempack == 4) + else // if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); uint16x8_t _c23 = vld1q_u16(pC + 8); uint16x8_t _c45 = vld1q_u16(pC + 16); uint16x8_t _c67 = vld1q_u16(pC + 24); - uint16x8_t _c89 = vld1q_u16(pC + c_hstep * 4); - uint16x8_t _cab = vld1q_u16(pC + c_hstep * 4 + 8); - uint16x8_t _ccd = vld1q_u16(pC + c_hstep * 4 + 16); - uint16x8_t _cef = vld1q_u16(pC + c_hstep * 4 + 24); _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); @@ -7978,43 +8714,86 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); - float32x4_t _c8 = bfloat2float(vget_low_u16(_c89)); - float32x4_t _c9 = bfloat2float(vget_high_u16(_c89)); - float32x4_t _ca = bfloat2float(vget_low_u16(_cab)); - float32x4_t _cb = bfloat2float(vget_high_u16(_cab)); - float32x4_t _cc = bfloat2float(vget_low_u16(_ccd)); - float32x4_t _cd = bfloat2float(vget_high_u16(_ccd)); - float32x4_t _ce = bfloat2float(vget_low_u16(_cef)); - float32x4_t _cf = bfloat2float(vget_high_u16(_cef)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - _f8 = vaddq_f32(_f8, _c8); - _f9 = vaddq_f32(_f9, _c9); - _fa = vaddq_f32(_fa, _ca); - _fb = vaddq_f32(_fb, _cb); - _fc = vaddq_f32(_fc, _cc); - _fd = vaddq_f32(_fd, _cd); - _fe = vaddq_f32(_fe, _ce); - _ff = vaddq_f32(_ff, _cf); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 4 + 8); + _c45 = vld1q_u16(pC + c_hstep * 4 + 16); + _c67 = vld1q_u16(pC + c_hstep * 4 + 24); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + _c4 = bfloat2float(vget_low_u16(_c45)); + _c5 = bfloat2float(vget_high_u16(_c45)); + _c6 = bfloat2float(vget_low_u16(_c67)); + _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, 
@@ -8035,6 +8814,27 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
             }
         }

+        if (alpha != 1.f)
+        {
+            float32x4_t _alpha = vdupq_n_f32(alpha);
+            _f0 = vmulq_f32(_f0, _alpha);
+            _f1 = vmulq_f32(_f1, _alpha);
+            _f2 = vmulq_f32(_f2, _alpha);
+            _f3 = vmulq_f32(_f3, _alpha);
+            _f4 = vmulq_f32(_f4, _alpha);
+            _f5 = vmulq_f32(_f5, _alpha);
+            _f6 = vmulq_f32(_f6, _alpha);
+            _f7 = vmulq_f32(_f7, _alpha);
+            _f8 = vmulq_f32(_f8, _alpha);
+            _f9 = vmulq_f32(_f9, _alpha);
+            _fa = vmulq_f32(_fa, _alpha);
+            _fb = vmulq_f32(_fb, _alpha);
+            _fc = vmulq_f32(_fc, _alpha);
+            _fd = vmulq_f32(_fd, _alpha);
+            _fe = vmulq_f32(_fe, _alpha);
+            _ff = vmulq_f32(_ff, _alpha);
+        }
+
         vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f8)));
         vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f1), float2bfloat(_f9)));
         vst1q_u16(p0 + out_hstep * 2, vcombine_u16(float2bfloat(_f2), float2bfloat(_fa)));
@@ -8156,6 +8956,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
         {
             if (c_elempack == 1)
             {
+                // TODO decompose 4x8 to 4x4 and 4x4
                 uint16x4_t _cc0 = vld1_u16(pC);
                 uint16x4_t _cc1 = vld1_u16(pC + c_hstep);
                 uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2);
@@ -8175,37 +8976,68 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
                 _f7 = vaddq_f32(_f7, bfloat2float(_cc7));
                 pC += 4;
             }
-            if (c_elempack == 4)
+            else // if (c_elempack == 4)
            {
                 uint16x8_t _c01 = vld1q_u16(pC);
                 uint16x8_t _c23 = vld1q_u16(pC + 8);
-                uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 4);
-                uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 4 + 8);
                 _c0 = bfloat2float(vget_low_u16(_c01));
                 _c1 = bfloat2float(vget_high_u16(_c01));
                 float32x4_t _c2 = bfloat2float(vget_low_u16(_c23));
                 float32x4_t _c3 = bfloat2float(vget_high_u16(_c23));
-                float32x4_t _c4 = bfloat2float(vget_low_u16(_c45));
-                float32x4_t _c5 = bfloat2float(vget_high_u16(_c45));
-                float32x4_t _c6 = bfloat2float(vget_low_u16(_c67));
-                float32x4_t _c7 = bfloat2float(vget_high_u16(_c67));
-                _f0 = vaddq_f32(_f0, _c0);
-                _f1 = vaddq_f32(_f1, _c1);
-                _f2 = vaddq_f32(_f2, _c2);
-                _f3 = vaddq_f32(_f3, _c3);
-                _f4 = vaddq_f32(_f4, _c4);
-                _f5 = vaddq_f32(_f5, _c5);
-                _f6 = vaddq_f32(_f6, _c6);
-                _f7 = vaddq_f32(_f7, _c7);
+                if (beta == 1.f)
+                {
+                    _f0 = vaddq_f32(_f0, _c0);
+                    _f1 = vaddq_f32(_f1, _c1);
+                    _f2 = vaddq_f32(_f2, _c2);
+                    _f3 = vaddq_f32(_f3, _c3);
+                }
+                else
+                {
+                    float32x4_t _beta = vdupq_n_f32(beta);
+                    _f0 = vmlaq_f32(_f0, _c0, _beta);
+                    _f1 = vmlaq_f32(_f1, _c1, _beta);
+                    _f2 = vmlaq_f32(_f2, _c2, _beta);
+                    _f3 = vmlaq_f32(_f3, _c3, _beta);
+                }
+                _c01 = vld1q_u16(pC + c_hstep * 4);
+                _c23 = vld1q_u16(pC + c_hstep * 4 + 8);
+                _c0 = bfloat2float(vget_low_u16(_c01));
+                _c1 = bfloat2float(vget_high_u16(_c01));
+                _c2 = bfloat2float(vget_low_u16(_c23));
+                _c3 = bfloat2float(vget_high_u16(_c23));
+                if (beta == 1.f)
+                {
+                    _f4 = vaddq_f32(_f4, _c0);
+                    _f5 = vaddq_f32(_f5, _c1);
+                    _f6 = vaddq_f32(_f6, _c2);
+                    _f7 = vaddq_f32(_f7, _c3);
+                }
+                else
+                {
+                    float32x4_t _beta = vdupq_n_f32(beta);
+                    _f4 = vmlaq_f32(_f4, _c0, _beta);
+                    _f5 = vmlaq_f32(_f5, _c1, _beta);
+                    _f6 = vmlaq_f32(_f6, _c2, _beta);
+                    _f7 = vmlaq_f32(_f7, _c3, _beta);
+                }
                 pC += 16;
             }
         }
         if (broadcast_type_C == 4)
         {
-            _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]));
-            _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]));
-            float32x4_t _c2 = vdupq_n_f32(bfloat16_to_float32(pC[2]));
-            float32x4_t _c3 = vdupq_n_f32(bfloat16_to_float32(pC[3]));
+            float32x4_t _c = bfloat2float(vld1_u16(pC));
+            _c = vmulq_n_f32(_c, beta);
+#if __aarch64__
+            _c0 = vdupq_laneq_f32(_c, 0);
+            _c1 = vdupq_laneq_f32(_c, 1);
+            float32x4_t _c2 = vdupq_laneq_f32(_c, 2);
+            float32x4_t _c3 = vdupq_laneq_f32(_c, 3);
+#else
+            _c0 = vdupq_lane_f32(vget_low_f32(_c), 0);
+            _c1 = vdupq_lane_f32(vget_low_f32(_c), 1);
+            float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0);
+            float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1);
+#endif
             _f0 = vaddq_f32(_f0, _c0);
             _f1 = vaddq_f32(_f1, _c1);
             _f2 = vaddq_f32(_f2, _c2);
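In the broadcast_type_C == 4 paths above, beta is now applied once to the packed row of C before the per-column lane broadcasts, rather than once per duplicated register. A minimal sketch of the idea (the helper name is mine, not the patch's):

    #include <arm_neon.h>

    // Scale four C values by beta once, then broadcast each lane; this saves
    // three multiplies per group of four columns. vdupq_laneq_f32 is AArch64-only.
    static inline void bias_columns(float32x4_t _c, float beta,
                                    float32x4_t* _c0, float32x4_t* _c1,
                                    float32x4_t* _c2, float32x4_t* _c3)
    {
        _c = vmulq_n_f32(_c, beta);
        *_c0 = vdupq_laneq_f32(_c, 0);
        *_c1 = vdupq_laneq_f32(_c, 1);
        *_c2 = vdupq_laneq_f32(_c, 2);
        *_c3 = vdupq_laneq_f32(_c, 3);
    }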
@@ -8218,6 +9050,19 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
             }
         }

+        if (alpha != 1.f)
+        {
+            float32x4_t _alpha = vdupq_n_f32(alpha);
+            _f0 = vmulq_f32(_f0, _alpha);
+            _f1 = vmulq_f32(_f1, _alpha);
+            _f2 = vmulq_f32(_f2, _alpha);
+            _f3 = vmulq_f32(_f3, _alpha);
+            _f4 = vmulq_f32(_f4, _alpha);
+            _f5 = vmulq_f32(_f5, _alpha);
+            _f6 = vmulq_f32(_f6, _alpha);
+            _f7 = vmulq_f32(_f7, _alpha);
+        }
+
         vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f4)));
         vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f1), float2bfloat(_f5)));
         vst1q_u16(p0 + out_hstep * 2, vcombine_u16(float2bfloat(_f2), float2bfloat(_f6)));
@@ -8288,6 +9133,8 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
         }
         if (broadcast_type_C == 3)
         {
+            float32x4_t _c2;
+            float32x4_t _c3;
             if (c_elempack == 1)
             {
                 uint16x8_t _c01 = uint16x8_t();
@@ -8312,33 +9159,40 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
                 _c0 = bfloat2float(vget_low_u16(_c01));
                 _c1 = bfloat2float(vget_low_u16(_c23));
-                float32x4_t _c2 = bfloat2float(vget_high_u16(_c01));
-                float32x4_t _c3 = bfloat2float(vget_high_u16(_c23));
-                _f0 = vaddq_f32(_f0, _c0);
-                _f1 = vaddq_f32(_f1, _c1);
-                _f2 = vaddq_f32(_f2, _c2);
-                _f3 = vaddq_f32(_f3, _c3);
+                _c2 = bfloat2float(vget_high_u16(_c01));
+                _c3 = bfloat2float(vget_high_u16(_c23));
                 pC += 2;
             }
-            if (c_elempack == 4)
+            else // if (c_elempack == 4)
             {
                 uint16x8_t _c01 = vld1q_u16(pC);
                 uint16x8_t _c23 = vld1q_u16(pC + c_hstep * 4);
                 _c0 = bfloat2float(vget_low_u16(_c01));
                 _c1 = bfloat2float(vget_high_u16(_c01));
-                float32x4_t _c2 = bfloat2float(vget_low_u16(_c23));
-                float32x4_t _c3 = bfloat2float(vget_high_u16(_c23));
+                _c2 = bfloat2float(vget_low_u16(_c23));
+                _c3 = bfloat2float(vget_high_u16(_c23));
+                pC += 8;
+            }
+            if (beta == 1.f)
+            {
                 _f0 = vaddq_f32(_f0, _c0);
                 _f1 = vaddq_f32(_f1, _c1);
                 _f2 = vaddq_f32(_f2, _c2);
                 _f3 = vaddq_f32(_f3, _c3);
-                pC += 8;
+            }
+            else
+            {
+                float32x4_t _beta = vdupq_n_f32(beta);
+                _f0 = vmlaq_f32(_f0, _c0, _beta);
+                _f1 = vmlaq_f32(_f1, _c1, _beta);
+                _f2 = vmlaq_f32(_f2, _c2, _beta);
+                _f3 = vmlaq_f32(_f3, _c3, _beta);
            }
         }
         if (broadcast_type_C == 4)
         {
-            _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]));
-            _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]));
+            _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta);
+            _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]) * beta);
             _f0 = vaddq_f32(_f0, _c0);
             _f1 = vaddq_f32(_f1, _c1);
             _f2 = vaddq_f32(_f2, _c0);
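The #if __aarch64__ split above exists because vdupq_laneq_f32 takes a 128-bit source and is an A64 instruction; ARMv7 NEON only has vdupq_lane_f32 on a 64-bit half. The same guard in isolation:

    #include <arm_neon.h>

    // Broadcast lane 2 of a float32x4_t portably.
    static inline float32x4_t dup_lane2(float32x4_t _c)
    {
    #if __aarch64__
        return vdupq_laneq_f32(_c, 2);               // A64: full 128-bit source
    #else
        return vdupq_lane_f32(vget_high_f32(_c), 0); // ARMv7: split, then dup
    #endif
    }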
@@ -8347,6 +9201,15 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
         }
     }

+    if (alpha != 1.f)
+    {
+        float32x4_t _alpha = vdupq_n_f32(alpha);
+        _f0 = vmulq_f32(_f0, _alpha);
+        _f1 = vmulq_f32(_f1, _alpha);
+        _f2 = vmulq_f32(_f2, _alpha);
+        _f3 = vmulq_f32(_f3, _alpha);
+    }
+
     vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f2)));
     vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f1), float2bfloat(_f3)));

@@ -8385,28 +9248,42 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
             _c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7);
             _c0 = bfloat2float(vget_low_u16(_c01));
             _c1 = bfloat2float(vget_high_u16(_c01));
-            _f0 = vaddq_f32(_f0, _c0);
-            _f1 = vaddq_f32(_f1, _c1);
             pC += 1;
         }
-        if (c_elempack == 4)
+        else // if (c_elempack == 4)
         {
             _c0 = bfloat2float(vld1_u16(pC));
             _c1 = bfloat2float(vld1_u16(pC + c_hstep * 4));
+            pC += 4;
+        }
+        if (beta == 1.f)
+        {
             _f0 = vaddq_f32(_f0, _c0);
             _f1 = vaddq_f32(_f1, _c1);
-            pC += 4;
+        }
+        else
+        {
+            float32x4_t _beta = vdupq_n_f32(beta);
+            _f0 = vmlaq_f32(_f0, _c0, _beta);
+            _f1 = vmlaq_f32(_f1, _c1, _beta);
         }
     }
     if (broadcast_type_C == 4)
     {
-        _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]));
+        _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta);
         _f0 = vaddq_f32(_f0, _c0);
         _f1 = vaddq_f32(_f1, _c0);
         pC += 1;
     }
 }

+    if (alpha != 1.f)
+    {
+        float32x4_t _alpha = vdupq_n_f32(alpha);
+        _f0 = vmulq_f32(_f0, _alpha);
+        _f1 = vmulq_f32(_f1, _alpha);
+    }
+
     vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1)));
     pp += 8;
     p0 += out_hstep;
@@ -8598,7 +9475,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
             _f7 = vaddq_f32(_f7, _c7);
             pC += 8;
         }
-        if (c_elempack == 4)
+        else // if (c_elempack == 4)
         {
             uint16x8x4_t _c = vld4q_u16(pC);
             _f0 = vaddq_f32(_f0, bfloat2float(vget_low_u16(_c.val[0])));
@@ -8629,6 +9506,19 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
         }
     }

+    if (alpha != 1.f)
+    {
+        float32x4_t _alpha = vdupq_n_f32(alpha);
+        _f0 = vmulq_f32(_f0, _alpha);
+        _f1 = vmulq_f32(_f1, _alpha);
+        _f2 = vmulq_f32(_f2, _alpha);
+        _f3 = vmulq_f32(_f3, _alpha);
+        _f4 = vmulq_f32(_f4, _alpha);
+        _f5 = vmulq_f32(_f5, _alpha);
+        _f6 = vmulq_f32(_f6, _alpha);
+        _f7 = vmulq_f32(_f7, _alpha);
+    }
+
     vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f2)));
     vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f4), float2bfloat(_f6)));
     vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f1), float2bfloat(_f3)));
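Every `if (c_elempack == 4)` in these paths becomes `else // if (c_elempack == 4)`: c_elempack can only be 1 or 4 here, so the branches are mutually exclusive and the second compare is dead, while the comment keeps the invariant visible. A schematic of the shape (c_elempack being 1 or 4 is the assumed precondition):

    // The two packing layouts are exclusive; one test selects both.
    static inline int load_width(int c_elempack)
    {
        if (c_elempack == 1)
            return 1; // per-row bf16 loads
        else // if (c_elempack == 4)
            return 4; // packed-4 bf16 loads
    }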
@@ -8744,7 +9634,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
             _f3 = vaddq_f32(_f3, _c3);
             pC += 4;
         }
-        if (c_elempack == 4)
+        else // if (c_elempack == 4)
         {
             uint16x4x4_t _c = vld4_u16(pC);
             _f0 = vaddq_f32(_f0, bfloat2float(_c.val[0]));
@@ -8765,6 +9655,15 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
         }
     }

+    if (alpha != 1.f)
+    {
+        float32x4_t _alpha = vdupq_n_f32(alpha);
+        _f0 = vmulq_f32(_f0, _alpha);
+        _f1 = vmulq_f32(_f1, _alpha);
+        _f2 = vmulq_f32(_f2, _alpha);
+        _f3 = vmulq_f32(_f3, _alpha);
+    }
+
     vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1)));
     vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3)));

@@ -8905,7 +9804,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
             _f7 = vaddq_f32(_f7, _c7);
             pC += 8;
         }
-        if (c_elempack == 4)
+        else // if (c_elempack == 4)
        {
             uint16x8_t _c01 = vld1q_u16(pC);
             uint16x8_t _c23 = vld1q_u16(pC + 8);
@@ -8952,6 +9851,19 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
         }
     }

+    if (alpha != 1.f)
+    {
+        float32x4_t _alpha = vdupq_n_f32(alpha);
+        _f0 = vmulq_f32(_f0, _alpha);
+        _f1 = vmulq_f32(_f1, _alpha);
+        _f2 = vmulq_f32(_f2, _alpha);
+        _f3 = vmulq_f32(_f3, _alpha);
+        _f4 = vmulq_f32(_f4, _alpha);
+        _f5 = vmulq_f32(_f5, _alpha);
+        _f6 = vmulq_f32(_f6, _alpha);
+        _f7 = vmulq_f32(_f7, _alpha);
+    }
+
     vst1_u16(p0, float2bfloat(_f0));
     vst1_u16(p0 + out_hstep, float2bfloat(_f1));
     vst1_u16(p0 + out_hstep * 2, float2bfloat(_f2));
@@ -9042,7 +9954,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
             _f3 = vaddq_f32(_f3, bfloat2float(_cc3));
             pC += 4;
         }
-        if (c_elempack == 4)
+        else // if (c_elempack == 4)
         {
             uint16x8_t _c01 = vld1q_u16(pC);
             uint16x8_t _c23 = vld1q_u16(pC + 8);
@@ -9071,6 +9983,15 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
         }
     }

+    if (alpha != 1.f)
+    {
+        float32x4_t _alpha = vdupq_n_f32(alpha);
+        _f0 = vmulq_f32(_f0, _alpha);
+        _f1 = vmulq_f32(_f1, _alpha);
+        _f2 = vmulq_f32(_f2, _alpha);
+        _f3 = vmulq_f32(_f3, _alpha);
+    }
+
     vst1_u16(p0, float2bfloat(_f0));
     vst1_u16(p0 + out_hstep, float2bfloat(_f1));
     vst1_u16(p0 + out_hstep * 2, float2bfloat(_f2));
@@ -9139,7 +10060,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
             _f1 = vaddq_f32(_f1, _c1);
             pC += 2;
         }
-        if (c_elempack == 4)
+        else // if (c_elempack == 4)
         {
             uint16x8_t _c01 = vld1q_u16(pC);
             _c0 = bfloat2float(vget_low_u16(_c01));
@@ -9159,6 +10080,13 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
         }
     }

+    if (alpha != 1.f)
+    {
+        float32x4_t _alpha = vdupq_n_f32(alpha);
+        _f0 = vmulq_f32(_f0, _alpha);
+        _f1 = vmulq_f32(_f1, _alpha);
+    }
+
     vst1_u16(p0, float2bfloat(_f0));
     vst1_u16(p0 + out_hstep, float2bfloat(_f1));

@@ -9191,7 +10119,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
             _f0 = vaddq_f32(_f0, bfloat2float(_c));
             pC += 1;
         }
-        if (c_elempack == 4)
+        else // if (c_elempack == 4)
         {
             _c0 = bfloat2float(vld1_u16(pC));
             _f0 = vaddq_f32(_f0, _c0);
@@ -9206,6 +10134,11 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
         }
     }

+    if (alpha != 1.f)
+    {
+        _f0 = vmulq_n_f32(_f0, alpha);
+    }
+
     vst1_u16(p0, float2bfloat(_f0));
     pp += 4;
     p0 += out_hstep;
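Every store path above now ends with the same guarded alpha scale; the branch keeps the common alpha == 1.f case free of the extra multiplies. A reduced sketch of the pattern (the helper and its array form are assumptions for illustration):

    #include <arm_neon.h>

    // Apply the final output scale only when it does something.
    static inline void scale_tile(float32x4_t* f, int n, float alpha)
    {
        if (alpha != 1.f)
        {
            float32x4_t _alpha = vdupq_n_f32(alpha);
            for (int i = 0; i < n; i++)
                f[i] = vmulq_f32(f[i], _alpha);
        }
    }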
@@ -9320,6 +10253,15 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
         }
     }

+    if (alpha != 1.f)
+    {
+        float32x4_t _alpha = vdupq_n_f32(alpha);
+        _f0 = vmulq_f32(_f0, _alpha);
+        _f1 = vmulq_f32(_f1, _alpha);
+        _f2 = vmulq_f32(_f2, _alpha);
+        _f3 = vmulq_f32(_f3, _alpha);
+    }
+
     vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f2)));
     vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f1), float2bfloat(_f3)));

@@ -9368,6 +10310,13 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
         }
     }

+    if (alpha != 1.f)
+    {
+        float32x4_t _alpha = vdupq_n_f32(alpha);
+        _f0 = vmulq_f32(_f0, _alpha);
+        _f1 = vmulq_f32(_f1, _alpha);
+    }
+
     vst1_u16(p0, float2bfloat(_f0));
     vst1_u16(p0 + 4, float2bfloat(_f1));

@@ -9446,6 +10395,15 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
         }
     }

+    if (alpha != 1.f)
+    {
+        float32x4_t _alpha = vdupq_n_f32(alpha);
+        _f0 = vmulq_f32(_f0, _alpha);
+        _f1 = vmulq_f32(_f1, _alpha);
+        _f2 = vmulq_f32(_f2, _alpha);
+        _f3 = vmulq_f32(_f3, _alpha);
+    }
+
     uint16x4_t _bf0 = float2bfloat(_f0);
     uint16x4_t _bf1 = float2bfloat(_f1);
     uint16x4_t _bf2 = float2bfloat(_f2);
@@ -9520,6 +10478,13 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
         }
     }

+    if (alpha != 1.f)
+    {
+        float32x4_t _alpha = vdupq_n_f32(alpha);
+        _f0 = vmulq_f32(_f0, _alpha);
+        _f1 = vmulq_f32(_f1, _alpha);
+    }
+
     uint16x4_t _bf0 = float2bfloat(_f0);
     uint16x4_t _bf1 = float2bfloat(_f1);

@@ -9582,6 +10547,14 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
         }
     }

+    if (alpha != 1.f)
+    {
+        f00 *= alpha;
+        f01 *= alpha;
+        f10 *= alpha;
+        f11 *= alpha;
+    }
+
     p0[0] = float32_to_bfloat16(f00);
     p0[1] = float32_to_bfloat16(f01);
     p0[out_hstep] = float32_to_bfloat16(f10);
@@ -9624,6 +10597,12 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
         }
     }

+    if (alpha != 1.f)
+    {
+        f0 *= alpha;
+        f1 *= alpha;
+    }
+
     p0[0] = float32_to_bfloat16(f0);
     p0[1] = float32_to_bfloat16(f1);
     pp += 2;
@@ -9714,6 +10693,15 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
         }
     }

+    if (alpha != 1.f)
+    {
+        float32x4_t _alpha = vdupq_n_f32(alpha);
+        _f0 = vmulq_f32(_f0, _alpha);
+        _f1 = vmulq_f32(_f1, _alpha);
+        _f2 = vmulq_f32(_f2, _alpha);
+        _f3 = vmulq_f32(_f3, _alpha);
+    }
+
     vst1_u16(p0, float2bfloat(_f0));
     vst1_u16(p0 + out_hstep * 4, float2bfloat(_f1));
     vst1_u16(p0 + out_hstep * 8, float2bfloat(_f2));
@@ -9748,6 +10736,13 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
         }
     }

+    if (alpha != 1.f)
+    {
+        float32x4_t _alpha = vdupq_n_f32(alpha);
+        _f0 = vmulq_f32(_f0, _alpha);
+        _f1 = vmulq_f32(_f1, _alpha);
+    }
+
     vst1_u16(p0, float2bfloat(_f0));
     vst1_u16(p0 + out_hstep * 4, float2bfloat(_f1));
     pp += 8;
@@ -9771,6 +10766,11 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
         }
     }

+    if (alpha != 1.f)
+    {
+        _f0 = vmulq_n_f32(_f0, alpha);
+    }
+
     vst1_u16(p0, float2bfloat(_f0));
     pp += 4;
     p0 += out_hstep * 4;
@@ -9819,6 +10819,15 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
         }
     }

+    if (alpha != 1.f)
+    {
+        float32x4_t _alpha = vdupq_n_f32(alpha);
+        _f0 = vmulq_f32(_f0, _alpha);
+        _f1 = vmulq_f32(_f1, _alpha);
+        _f2 = vmulq_f32(_f2, _alpha);
+        _f3 = vmulq_f32(_f3, _alpha);
+    }
+
     uint16x4_t _bf0 = float2bfloat(_f0);
     uint16x4_t _bf1 = float2bfloat(_f1);
     uint16x4_t _bf2 = float2bfloat(_f2);
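The scalar remainders above (f00..f11, f0/f1) end in float32_to_bfloat16 stores. A plausible scalar mirror of that conversion, assuming ncnn's usual truncating bf16 representation (the top 16 bits of the fp32 bit pattern); treat the rounding behavior as an assumption, not a statement about this file:

    #include <stdint.h>
    #include <string.h>

    // Truncating fp32 -> bf16: keep the upper half of the IEEE-754 pattern.
    static inline uint16_t f32_to_bf16_trunc(float f)
    {
        uint32_t u;
        memcpy(&u, &f, sizeof(u));
        return (uint16_t)(u >> 16);
    }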
@@ -9871,6 +10880,13 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
         }
     }

+    if (alpha != 1.f)
+    {
+        float32x4_t _alpha = vdupq_n_f32(alpha);
+        _f0 = vmulq_f32(_f0, _alpha);
+        _f1 = vmulq_f32(_f1, _alpha);
+    }
+
     uint16x4_t _bf0 = float2bfloat(_f0);
     uint16x4_t _bf1 = float2bfloat(_f1);
@@ -9905,6 +10921,11 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
         }
     }

+    if (alpha != 1.f)
+    {
+        _f0 = vmulq_n_f32(_f0, alpha);
+    }
+
     uint16x4_t _bf0 = float2bfloat(_f0);

     p0[0] = vget_lane_u16(_bf0, 0);
@@ -9936,6 +10957,11 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
         }
     }

+    if (alpha != 1.f)
+    {
+        _f0 = vmul_n_f32(_f0, alpha);
+    }
+
     p0[0] = float32_to_bfloat16(vget_lane_f32(_f0, 0));
     p0[out_hstep] = float32_to_bfloat16(vget_lane_f32(_f0, 1));

@@ -9961,6 +10987,8 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
         }
     }

+    f0 *= alpha;
+
     p0[0] = float32_to_bfloat16(f0);
     pp += 1;
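The final 1x1 tail applies `f0 *= alpha;` unconditionally, unlike the vector paths: for a single scalar, a branch saves at most one multiply and is unlikely to pay for itself. An end-to-end sketch of that scalar tail (descale, c0 and has_c are illustrative stand-ins, not names from the patch):

    #include <stdint.h>

    // One scalar output: dequantize, add the optional C term, scale.
    static inline float tail_one(int32_t acc, float descale, float c0,
                                 float alpha, float beta, int has_c)
    {
        float f0 = (float)acc * descale;
        if (has_c)
            f0 += beta * c0;
        f0 *= alpha; // unconditional, matching the final hunk
        return f0;
    }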