From 502d58be1d52be1db5ca0168db39c2d0263fb9b1 Mon Sep 17 00:00:00 2001 From: nihui Date: Mon, 4 Nov 2024 16:44:17 +0800 Subject: [PATCH] w --- src/layer/gemm.cpp | 14 +- src/layer/x86/gemm_int8.h | 7137 +++++++++++++++++++++++++++++++++++-- 2 files changed, 6827 insertions(+), 324 deletions(-) diff --git a/src/layer/gemm.cpp b/src/layer/gemm.cpp index 0ebe5974d0b7..0b8c88bfa8b2 100644 --- a/src/layer/gemm.cpp +++ b/src/layer/gemm.cpp @@ -241,10 +241,18 @@ static void gemm_transB_int8(const Mat& A_int8, const Mat& BT_int8, const Mat& A int sum = 0; for (int k = 0; k < K; k++) { - // NCNN_LOGE("ptrA[%d] %d", k, ptrA[k]); + // if (M==4 && N==7) + // { + // NCNN_LOGE("ptrA[%d] %d %d", k, ptrA[k], ptrBT[k]); + // } sum += ptrA[k] * ptrBT[k]; } + // if (M==4 && N==7) + // { + // NCNN_LOGE("sum %d", sum); + // } + float sum_fp32 = sum * descale; if (ptrC) @@ -501,11 +509,11 @@ int Gemm::forward_int8(const std::vector& bottom_blobs, std::vector& t absmax = std::max(absmax, (float)fabs(ptr[k])); } - // NCNN_LOGE("A[%d] absmax %f", i, absmax); - float A_int8_scale = absmax == 0.f ? 
1.f : 127.f / absmax; A_int8_scales[i] = A_int8_scale; + // NCNN_LOGE("A[%d] absmax %.9f %.9f", i, absmax, A_int8_scale); + signed char* ptrAi = A_int8.row(i); for (int k = 0; k < A_int8.w; k++) diff --git a/src/layer/x86/gemm_int8.h b/src/layer/x86/gemm_int8.h index fd3b1f715182..27f29dab681a 100644 --- a/src/layer/x86/gemm_int8.h +++ b/src/layer/x86/gemm_int8.h @@ -25,6 +25,19 @@ void gemm_transB_packed_tile_int8_avx2(const Mat& AT_tile, const Mat& BT_tile, M void gemm_transB_packed_tile_int8_xop(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk); #endif +#if __AVX512F__ +static void print(__m512 x) +{ + float a[16]; + _mm512_storeu_ps(a, x); + for (int i = 0; i < 16; i++) + { + fprintf(stderr, "%.0f ", a[i]); + } + fprintf(stderr, "\n"); +} +#endif + static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) { #if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ @@ -44,6 +57,119 @@ static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, in int ii = 0; #if __SSE2__ #if __AVX2__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { + const signed char* p0 = A.row(i + ii) + k; + const signed char* p1 = A.row(i + ii + 1) + k; + const signed char* p2 = A.row(i + ii + 2) + k; + const signed char* p3 = A.row(i + ii + 3) + k; + const signed char* p4 = A.row(i + ii + 4) + k; + const signed char* p5 = A.row(i + ii + 5) + k; + const signed char* p6 = A.row(i + ii + 6) + k; + const signed char* p7 = A.row(i + ii + 7) + k; + const signed char* p8 = A.row(i + ii + 8) + k; + const signed char* p9 = A.row(i + ii + 9) + k; + const signed char* pa = A.row(i + ii + 10) + k; + const signed char* pb = A.row(i + ii + 11) + k; + const signed char* pc = A.row(i + ii + 12) + k; + const signed char* pd = A.row(i + ii + 13) + k; + const signed char* pe = A.row(i + ii + 14) + k; + const signed char* pf = A.row(i + ii + 15) + k; + + 
int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p2[0]; + pp[5] = p2[1]; + pp[6] = p3[0]; + pp[7] = p3[1]; + pp[8] = p4[0]; + pp[9] = p4[1]; + pp[10] = p5[0]; + pp[11] = p5[1]; + pp[12] = p6[0]; + pp[13] = p6[1]; + pp[14] = p7[0]; + pp[15] = p7[1]; + + pp[16 + 0] = p8[0]; + pp[16 + 1] = p8[1]; + pp[16 + 2] = p9[0]; + pp[16 + 3] = p9[1]; + pp[16 + 4] = pa[0]; + pp[16 + 5] = pa[1]; + pp[16 + 6] = pb[0]; + pp[16 + 7] = pb[1]; + pp[16 + 8] = pc[0]; + pp[16 + 9] = pc[1]; + pp[16 + 10] = pd[0]; + pp[16 + 11] = pd[1]; + pp[16 + 12] = pe[0]; + pp[16 + 13] = pe[1]; + pp[16 + 14] = pf[0]; + pp[16 + 15] = pf[1]; + + pp += 32; + p0 += 2; + p1 += 2; + p2 += 2; + p3 += 2; + p4 += 2; + p5 += 2; + p6 += 2; + p7 += 2; + p8 += 2; + p9 += 2; + pa += 2; + pb += 2; + pc += 2; + pd += 2; + pe += 2; + pf += 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp[4] = p4[0]; + pp[5] = p5[0]; + pp[6] = p6[0]; + pp[7] = p7[0]; + pp[8] = p8[0]; + pp[9] = p9[0]; + pp[10] = pa[0]; + pp[11] = pb[0]; + pp[12] = pc[0]; + pp[13] = pd[0]; + pp[14] = pe[0]; + pp[15] = pf[0]; + pp += 16; + p0++; + p1++; + p2++; + p3++; + p4++; + p5++; + p6++; + p7++; + p8++; + p9++; + pa++; + pb++; + pc++; + pd++; + pe++; + pf++; + } + } +#endif // __AVX512F__ for (; ii + 7 < max_ii; ii += 8) { const signed char* p0 = A.row(i + ii) + k; @@ -206,6 +332,73 @@ static void transpose_pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int ii = 0; #if __SSE2__ #if __AVX2__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { + const signed char* p0 = A.row(k) + (i + ii); + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[1]; + pp[3] = p0[A_hstep + 1]; + pp[4] = p0[2]; + pp[5] = p0[A_hstep + 2]; + pp[6] = p0[3]; + pp[7] = p0[A_hstep + 3]; + pp[8] = p0[4]; + pp[9] = p0[A_hstep + 4]; + pp[10] = p0[5]; + pp[11] = p0[A_hstep + 
5]; + pp[12] = p0[6]; + pp[13] = p0[A_hstep + 6]; + pp[14] = p0[7]; + pp[15] = p0[A_hstep + 7]; + + pp[16 + 0] = p0[8]; + pp[16 + 1] = p0[A_hstep + 8]; + pp[16 + 2] = p0[9]; + pp[16 + 3] = p0[A_hstep + 9]; + pp[16 + 4] = p0[10]; + pp[16 + 5] = p0[A_hstep + 10]; + pp[16 + 6] = p0[11]; + pp[16 + 7] = p0[A_hstep + 11]; + pp[16 + 8] = p0[12]; + pp[16 + 9] = p0[A_hstep + 12]; + pp[16 + 10] = p0[13]; + pp[16 + 11] = p0[A_hstep + 13]; + pp[16 + 12] = p0[14]; + pp[16 + 13] = p0[A_hstep + 14]; + pp[16 + 14] = p0[15]; + pp[16 + 15] = p0[A_hstep + 15]; + pp += 32; + p0 += A_hstep * 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp[4] = p0[4]; + pp[5] = p0[5]; + pp[6] = p0[6]; + pp[7] = p0[7]; + pp[8] = p0[8]; + pp[9] = p0[9]; + pp[10] = p0[10]; + pp[11] = p0[11]; + pp[12] = p0[12]; + pp[13] = p0[13]; + pp[14] = p0[14]; + pp[15] = p0[15]; + pp += 16; + p0 += A_hstep; + } + } +#endif // __AVX512F__ for (; ii + 7 < max_ii; ii += 8) { const signed char* p0 = A.row(k) + (i + ii); @@ -325,6 +518,119 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in int jj = 0; #if __SSE2__ #if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const signed char* p0 = B.row(j + jj) + k; + const signed char* p1 = B.row(j + jj + 1) + k; + const signed char* p2 = B.row(j + jj + 2) + k; + const signed char* p3 = B.row(j + jj + 3) + k; + const signed char* p4 = B.row(j + jj + 4) + k; + const signed char* p5 = B.row(j + jj + 5) + k; + const signed char* p6 = B.row(j + jj + 6) + k; + const signed char* p7 = B.row(j + jj + 7) + k; + const signed char* p8 = B.row(j + jj + 8) + k; + const signed char* p9 = B.row(j + jj + 9) + k; + const signed char* pa = B.row(j + jj + 10) + k; + const signed char* pb = B.row(j + jj + 11) + k; + const signed char* pc = B.row(j + jj + 12) + k; + const signed char* pd = B.row(j + jj + 13) + k; + const signed char* pe = B.row(j + jj + 
14) + k; + const signed char* pf = B.row(j + jj + 15) + k; + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p2[0]; + pp[5] = p2[1]; + pp[6] = p3[0]; + pp[7] = p3[1]; + pp[8] = p4[0]; + pp[9] = p4[1]; + pp[10] = p5[0]; + pp[11] = p5[1]; + pp[12] = p6[0]; + pp[13] = p6[1]; + pp[14] = p7[0]; + pp[15] = p7[1]; + + pp[16 + 0] = p8[0]; + pp[16 + 1] = p8[1]; + pp[16 + 2] = p9[0]; + pp[16 + 3] = p9[1]; + pp[16 + 4] = pa[0]; + pp[16 + 5] = pa[1]; + pp[16 + 6] = pb[0]; + pp[16 + 7] = pb[1]; + pp[16 + 8] = pc[0]; + pp[16 + 9] = pc[1]; + pp[16 + 10] = pd[0]; + pp[16 + 11] = pd[1]; + pp[16 + 12] = pe[0]; + pp[16 + 13] = pe[1]; + pp[16 + 14] = pf[0]; + pp[16 + 15] = pf[1]; + + pp += 32; + p0 += 2; + p1 += 2; + p2 += 2; + p3 += 2; + p4 += 2; + p5 += 2; + p6 += 2; + p7 += 2; + p8 += 2; + p9 += 2; + pa += 2; + pb += 2; + pc += 2; + pd += 2; + pe += 2; + pf += 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp[4] = p4[0]; + pp[5] = p5[0]; + pp[6] = p6[0]; + pp[7] = p7[0]; + pp[8] = p8[0]; + pp[9] = p9[0]; + pp[10] = pa[0]; + pp[11] = pb[0]; + pp[12] = pc[0]; + pp[13] = pd[0]; + pp[14] = pe[0]; + pp[15] = pf[0]; + pp += 16; + p0++; + p1++; + p2++; + p3++; + p4++; + p5++; + p6++; + p7++; + p8++; + p9++; + pa++; + pb++; + pc++; + pd++; + pe++; + pf++; + } + } +#endif // __AVX512F__ for (; jj + 7 < max_jj; jj += 8) { const signed char* p0 = B.row(j + jj) + k; @@ -479,6 +785,73 @@ static void transpose_pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int jj = 0; #if __SSE2__ #if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const signed char* p0 = B.row(k) + (j + jj); + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[B_hstep]; + pp[2] = p0[1]; + pp[3] = p0[B_hstep + 1]; + pp[4] = p0[2]; + pp[5] = p0[B_hstep + 2]; + pp[6] = p0[3]; + pp[7] = 
p0[B_hstep + 3]; + pp[8] = p0[4]; + pp[9] = p0[B_hstep + 4]; + pp[10] = p0[5]; + pp[11] = p0[B_hstep + 5]; + pp[12] = p0[6]; + pp[13] = p0[B_hstep + 6]; + pp[14] = p0[7]; + pp[15] = p0[B_hstep + 7]; + + pp[16 + 0] = p0[8]; + pp[16 + 1] = p0[B_hstep + 8]; + pp[16 + 2] = p0[9]; + pp[16 + 3] = p0[B_hstep + 9]; + pp[16 + 4] = p0[10]; + pp[16 + 5] = p0[B_hstep + 10]; + pp[16 + 6] = p0[11]; + pp[16 + 7] = p0[B_hstep + 11]; + pp[16 + 8] = p0[12]; + pp[16 + 9] = p0[B_hstep + 12]; + pp[16 + 10] = p0[13]; + pp[16 + 11] = p0[B_hstep + 13]; + pp[16 + 12] = p0[14]; + pp[16 + 13] = p0[B_hstep + 14]; + pp[16 + 14] = p0[15]; + pp[16 + 15] = p0[B_hstep + 15]; + pp += 32; + p0 += B_hstep * 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp[4] = p0[4]; + pp[5] = p0[5]; + pp[6] = p0[6]; + pp[7] = p0[7]; + pp[8] = p0[8]; + pp[9] = p0[9]; + pp[10] = p0[10]; + pp[11] = p0[11]; + pp[12] = p0[12]; + pp[13] = p0[13]; + pp[14] = p0[14]; + pp[15] = p0[15]; + pp += 16; + p0 += B_hstep; + } + } +#endif // __AVX512F__ for (; jj + 7 < max_jj; jj += 8) { const signed char* p0 = B.row(k) + (j + jj); @@ -602,10 +975,75 @@ static void compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_s #if __SSE2__ #if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + // __m512 _v127 = _mm512_set1_ps(127.f); + // __m512 _v127_B_scale = _mm512_set1_ps(v127_B_scale); + for (int ii = 0; ii + 15 < max_ii; ii += 16) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep; + + __m512 _absmax0 = _mm512_setzero_ps(); + int kk = 0; + for (; kk < K; kk++) + { + __m512 _p = _mm512_loadu_ps(p0); + _absmax0 = _mm512_max_ps(_absmax0, abs512_ps(_p)); + p0 += 16; + } + + // __m512 _scale = _mm512_div_ps(_v127, _absmax0); + // __m512 _out_descale = _mm512_div_ps(_absmax0, _v127_B_scale); + + // _mm512_store_ps(ps, _scale); + // _mm512_store_ps(pods, _out_descale); + + float absmax[16]; + _mm512_storeu_ps(absmax, _absmax0); + + ps[0] = 127.f / 
absmax[0]; + ps[1] = 127.f / absmax[1]; + ps[2] = 127.f / absmax[2]; + ps[3] = 127.f / absmax[3]; + ps[4] = 127.f / absmax[4]; + ps[5] = 127.f / absmax[5]; + ps[6] = 127.f / absmax[6]; + ps[7] = 127.f / absmax[7]; + ps[8] = 127.f / absmax[8]; + ps[9] = 127.f / absmax[9]; + ps[10] = 127.f / absmax[10]; + ps[11] = 127.f / absmax[11]; + ps[12] = 127.f / absmax[12]; + ps[13] = 127.f / absmax[13]; + ps[14] = 127.f / absmax[14]; + ps[15] = 127.f / absmax[15]; + pods[0] = absmax[0] / v127_B_scale; + pods[1] = absmax[1] / v127_B_scale; + pods[2] = absmax[2] / v127_B_scale; + pods[3] = absmax[3] / v127_B_scale; + pods[4] = absmax[4] / v127_B_scale; + pods[5] = absmax[5] / v127_B_scale; + pods[6] = absmax[6] / v127_B_scale; + pods[7] = absmax[7] / v127_B_scale; + pods[8] = absmax[8] / v127_B_scale; + pods[9] = absmax[9] / v127_B_scale; + pods[10] = absmax[10] / v127_B_scale; + pods[11] = absmax[11] / v127_B_scale; + pods[12] = absmax[12] / v127_B_scale; + pods[13] = absmax[13] / v127_B_scale; + pods[14] = absmax[14] / v127_B_scale; + pods[15] = absmax[15] / v127_B_scale; + + ps += 16; + pods += 16; + } + } +#endif // __AVX512F__ if (elempack == 8) { - __m256 _v127 = _mm256_set1_ps(127.f); - __m256 _v127_B_scale = _mm256_set1_ps(v127_B_scale); + // __m256 _v127 = _mm256_set1_ps(127.f); + // __m256 _v127_B_scale = _mm256_set1_ps(v127_B_scale); for (int ii = 0; ii + 7 < max_ii; ii += 8) { const float* p0 = (const float*)A + (i + ii) * A_hstep; @@ -619,11 +1057,31 @@ static void compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_s p0 += 8; } - __m256 _scale = _mm256_div_ps(_v127, _absmax0); - __m256 _out_descale = _mm256_div_ps(_absmax0, _v127_B_scale); - - _mm256_store_ps(ps, _scale); - _mm256_store_ps(pods, _out_descale); + // __m256 _scale = _mm256_div_ps(_v127, _absmax0); + // __m256 _out_descale = _mm256_div_ps(_absmax0, _v127_B_scale); + + // _mm256_store_ps(ps, _scale); + // _mm256_store_ps(pods, _out_descale); + + float absmax[8]; + 
_mm256_storeu_ps(absmax, _absmax0); + + ps[0] = 127.f / absmax[0]; + ps[1] = 127.f / absmax[1]; + ps[2] = 127.f / absmax[2]; + ps[3] = 127.f / absmax[3]; + ps[4] = 127.f / absmax[4]; + ps[5] = 127.f / absmax[5]; + ps[6] = 127.f / absmax[6]; + ps[7] = 127.f / absmax[7]; + pods[0] = absmax[0] / v127_B_scale; + pods[1] = absmax[1] / v127_B_scale; + pods[2] = absmax[2] / v127_B_scale; + pods[3] = absmax[3] / v127_B_scale; + pods[4] = absmax[4] / v127_B_scale; + pods[5] = absmax[5] / v127_B_scale; + pods[6] = absmax[6] / v127_B_scale; + pods[7] = absmax[7] / v127_B_scale; ps += 8; pods += 8; @@ -632,8 +1090,8 @@ static void compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_s #endif // __AVX__ if (elempack == 4) { - __m128 _v127 = _mm_set1_ps(127.f); - __m128 _v127_B_scale = _mm_set1_ps(v127_B_scale); + // __m128 _v127 = _mm_set1_ps(127.f); + // __m128 _v127_B_scale = _mm_set1_ps(v127_B_scale); for (int ii = 0; ii + 3 < max_ii; ii += 4) { const float* p0 = (const float*)A + (i + ii) * A_hstep; @@ -647,11 +1105,23 @@ static void compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_s p0 += 4; } - __m128 _scale = _mm_div_ps(_v127, _absmax0); - __m128 _out_descale = _mm_div_ps(_absmax0, _v127_B_scale); + // __m128 _scale = _mm_div_ps(_v127, _absmax0); + // __m128 _out_descale = _mm_div_ps(_absmax0, _v127_B_scale); + + // _mm_store_ps(ps, _scale); + // _mm_store_ps(pods, _out_descale); - _mm_store_ps(ps, _scale); - _mm_store_ps(pods, _out_descale); + float absmax[4]; + _mm_storeu_ps(absmax, _absmax0); + + ps[0] = 127.f / absmax[0]; + ps[1] = 127.f / absmax[1]; + ps[2] = 127.f / absmax[2]; + ps[3] = 127.f / absmax[3]; + pods[0] = absmax[0] / v127_B_scale; + pods[1] = absmax[1] / v127_B_scale; + pods[2] = absmax[2] / v127_B_scale; + pods[3] = absmax[3] / v127_B_scale; ps += 4; pods += 4; @@ -698,16 +1168,10 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i int ii = 0; #if __SSE2__ #if __AVX__ - for (; ii + 7 < 
max_ii; ii += 8) +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) { -#if __AVX2__ - signed char* pp = (signed char*)AT + ii * max_kk; -#else signed char* pp = (signed char*)AT + ii * max_kk; - signed char* pp1 = (signed char*)AT + (ii + 4) * max_kk; - // NCNN_LOGE("pp0 %p", pp); - // NCNN_LOGE("pp1 %p", pp1); -#endif const float* p0 = (const float*)A + (i + ii) * A_hstep + k * elempack; @@ -719,33 +1183,318 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i const float scale5 = scales[i + ii + 5]; const float scale6 = scales[i + ii + 6]; const float scale7 = scales[i + ii + 7]; - - if (elempack == 8) + const float scale8 = scales[i + ii + 8]; + const float scale9 = scales[i + ii + 9]; + const float scalea = scales[i + ii + 10]; + const float scaleb = scales[i + ii + 11]; + const float scalec = scales[i + ii + 12]; + const float scaled = scales[i + ii + 13]; + const float scalee = scales[i + ii + 14]; + const float scalef = scales[i + ii + 15]; + + if (elempack == 16) { int kk = 0; for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[8] * scale0); + pp[1] = float2int8(p0[16] * scale0); pp[2] = float2int8(p0[1] * scale1); - pp[3] = float2int8(p0[9] * scale1); + pp[3] = float2int8(p0[17] * scale1); pp[4] = float2int8(p0[2] * scale2); - pp[5] = float2int8(p0[10] * scale2); + pp[5] = float2int8(p0[18] * scale2); pp[6] = float2int8(p0[3] * scale3); - pp[7] = float2int8(p0[11] * scale3); -#if __AVX2__ + pp[7] = float2int8(p0[19] * scale3); pp[8] = float2int8(p0[4] * scale4); - pp[9] = float2int8(p0[12] * scale4); + pp[9] = float2int8(p0[20] * scale4); pp[10] = float2int8(p0[5] * scale5); - pp[11] = float2int8(p0[13] * scale5); + pp[11] = float2int8(p0[21] * scale5); pp[12] = float2int8(p0[6] * scale6); - pp[13] = float2int8(p0[14] * scale6); + pp[13] = float2int8(p0[22] * scale6); pp[14] = float2int8(p0[7] * scale7); - pp[15] = float2int8(p0[15] * scale7); - pp += 16; -#else - pp1[0] = 
float2int8(p0[4] * scale4); - pp1[1] = float2int8(p0[12] * scale4); + pp[15] = float2int8(p0[23]* scale7); + pp[16 + 0] = float2int8(p0[8] * scale8); + pp[16 + 1] = float2int8(p0[24] * scale8); + pp[16 + 2] = float2int8(p0[9] * scale9); + pp[16 + 3] = float2int8(p0[25] * scale9); + pp[16 + 4] = float2int8(p0[10] * scalea); + pp[16 + 5] = float2int8(p0[26] * scalea); + pp[16 + 6] = float2int8(p0[11] * scaleb); + pp[16 + 7] = float2int8(p0[27] * scaleb); + pp[16 + 8] = float2int8(p0[12] * scalec); + pp[16 + 9] = float2int8(p0[28] * scalec); + pp[16 + 10] = float2int8(p0[13] * scaled); + pp[16 + 11] = float2int8(p0[29] * scaled); + pp[16 + 12] = float2int8(p0[14] * scalee); + pp[16 + 13] = float2int8(p0[30] * scalee); + pp[16 + 14] = float2int8(p0[15] * scalef); + pp[16 + 15] = float2int8(p0[31]* scalef); + pp += 32; + p0 += 32; + } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale1); + pp[2] = float2int8(p0[2] * scale2); + pp[3] = float2int8(p0[3] * scale3); + pp[4] = float2int8(p0[4] * scale4); + pp[5] = float2int8(p0[5] * scale5); + pp[6] = float2int8(p0[6] * scale6); + pp[7] = float2int8(p0[7] * scale7); + pp[8] = float2int8(p0[8] * scale8); + pp[9] = float2int8(p0[9] * scale9); + pp[10] = float2int8(p0[10] * scalea); + pp[11] = float2int8(p0[11] * scaleb); + pp[12] = float2int8(p0[12] * scalec); + pp[13] = float2int8(p0[13] * scaled); + pp[14] = float2int8(p0[14] * scalee); + pp[15] = float2int8(p0[15] * scalef); + pp += 16; + p0 += 16; + } + } + if (elempack == 8) + { + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[8] * scale0); + pp[2] = float2int8(p0[1] * scale1); + pp[3] = float2int8(p0[9] * scale1); + pp[4] = float2int8(p0[2] * scale2); + pp[5] = float2int8(p0[10] * scale2); + pp[6] = float2int8(p0[3] * scale3); + pp[7] = float2int8(p0[11] * scale3); + pp[8] = float2int8(p0[4] * scale4); + pp[9] = float2int8(p0[12] * scale4); + pp[10] = 
float2int8(p0[5] * scale5); + pp[11] = float2int8(p0[13] * scale5); + pp[12] = float2int8(p0[6] * scale6); + pp[13] = float2int8(p0[14] * scale6); + pp[14] = float2int8(p0[7] * scale7); + pp[15] = float2int8(p0[15]* scale7); + + pp[16 + 0] = float2int8(p0[A_hstep * 8 + 0] * scale8); + pp[16 + 1] = float2int8(p0[A_hstep * 8 + 8] * scale8); + pp[16 + 2] = float2int8(p0[A_hstep * 8 + 1] * scale9); + pp[16 + 3] = float2int8(p0[A_hstep * 8 + 9] * scale9); + pp[16 + 4] = float2int8(p0[A_hstep * 8 + 2] * scalea); + pp[16 + 5] = float2int8(p0[A_hstep * 8 + 10] * scalea); + pp[16 + 6] = float2int8(p0[A_hstep * 8 + 3] * scaleb); + pp[16 + 7] = float2int8(p0[A_hstep * 8 + 11] * scaleb); + pp[16 + 8] = float2int8(p0[A_hstep * 8 + 4] * scalec); + pp[16 + 9] = float2int8(p0[A_hstep * 8 + 12] * scalec); + pp[16 + 10] = float2int8(p0[A_hstep * 8 + 5] * scaled); + pp[16 + 11] = float2int8(p0[A_hstep * 8 + 13] * scaled); + pp[16 + 12] = float2int8(p0[A_hstep * 8 + 6] * scalee); + pp[16 + 13] = float2int8(p0[A_hstep * 8 + 14] * scalee); + pp[16 + 14] = float2int8(p0[A_hstep * 8 + 7] * scalef); + pp[16 + 15] = float2int8(p0[A_hstep * 8 + 15]* scalef); + pp += 32; + p0 += 16; + } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale1); + pp[2] = float2int8(p0[2] * scale2); + pp[3] = float2int8(p0[3] * scale3); + pp[4] = float2int8(p0[4] * scale4); + pp[5] = float2int8(p0[5] * scale5); + pp[6] = float2int8(p0[6] * scale6); + pp[7] = float2int8(p0[7] * scale7); + pp[8] = float2int8(p0[A_hstep * 8 + 0] * scale8); + pp[9] = float2int8(p0[A_hstep * 8 + 1] * scale9); + pp[10] = float2int8(p0[A_hstep * 8 + 2] * scalea); + pp[11] = float2int8(p0[A_hstep * 8 + 3] * scaleb); + pp[12] = float2int8(p0[A_hstep * 8 + 4] * scalec); + pp[13] = float2int8(p0[A_hstep * 8 + 5] * scaled); + pp[14] = float2int8(p0[A_hstep * 8 + 6] * scalee); + pp[15] = float2int8(p0[A_hstep * 8 + 7] * scalef); + pp += 16; + p0 += 8; + } + } + if (elempack == 4) + { + int 
kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[4] * scale0); + pp[2] = float2int8(p0[1] * scale1); + pp[3] = float2int8(p0[5] * scale1); + pp[4] = float2int8(p0[2] * scale2); + pp[5] = float2int8(p0[6] * scale2); + pp[6] = float2int8(p0[3] * scale3); + pp[7] = float2int8(p0[7] * scale3); + pp[8] = float2int8(p0[A_hstep * 4 + 0] * scale4); + pp[9] = float2int8(p0[A_hstep * 4 + 4] * scale4); + pp[10] = float2int8(p0[A_hstep * 4 + 1] * scale5); + pp[11] = float2int8(p0[A_hstep * 4 + 5] * scale5); + pp[12] = float2int8(p0[A_hstep * 4 + 2] * scale6); + pp[13] = float2int8(p0[A_hstep * 4 + 6] * scale6); + pp[14] = float2int8(p0[A_hstep * 4 + 3] * scale7); + pp[15] = float2int8(p0[A_hstep * 4 + 7] * scale7); + + pp[16 + 0] = float2int8(p0[A_hstep * 8 + 0] * scale8); + pp[16 + 1] = float2int8(p0[A_hstep * 8 + 4] * scale8); + pp[16 + 2] = float2int8(p0[A_hstep * 8 + 1] * scale9); + pp[16 + 3] = float2int8(p0[A_hstep * 8 + 5] * scale9); + pp[16 + 4] = float2int8(p0[A_hstep * 8 + 2] * scalea); + pp[16 + 5] = float2int8(p0[A_hstep * 8 + 6] * scalea); + pp[16 + 6] = float2int8(p0[A_hstep * 8 + 3] * scaleb); + pp[16 + 7] = float2int8(p0[A_hstep * 8 + 7] * scaleb); + + pp[16 + 8] = float2int8(p0[A_hstep * 12 + 0] * scalec); + pp[16 + 9] = float2int8(p0[A_hstep * 12 + 4] * scalec); + pp[16 + 10] = float2int8(p0[A_hstep * 12 + 1] * scaled); + pp[16 + 11] = float2int8(p0[A_hstep * 12 + 5] * scaled); + pp[16 + 12] = float2int8(p0[A_hstep * 12 + 2] * scalee); + pp[16 + 13] = float2int8(p0[A_hstep * 12 + 6] * scalee); + pp[16 + 14] = float2int8(p0[A_hstep * 12 + 3] * scalef); + pp[16 + 15] = float2int8(p0[A_hstep * 12 + 7] * scalef); + + pp += 32; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale1); + pp[2] = float2int8(p0[2] * scale2); + pp[3] = float2int8(p0[3] * scale3); + pp[4] = float2int8(p0[A_hstep * 4] * scale4); + pp[5] = float2int8(p0[A_hstep * 4 
+ 1] * scale5); + pp[6] = float2int8(p0[A_hstep * 4 + 2] * scale6); + pp[7] = float2int8(p0[A_hstep * 4 + 3] * scale7); + pp[8] = float2int8(p0[A_hstep * 8] * scale8); + pp[9] = float2int8(p0[A_hstep * 8 + 1] * scale9); + pp[10] = float2int8(p0[A_hstep * 8 + 2] * scalea); + pp[11] = float2int8(p0[A_hstep * 8 + 3] * scaleb); + pp[12] = float2int8(p0[A_hstep * 12] * scalec); + pp[13] = float2int8(p0[A_hstep * 12 + 1] * scaled); + pp[14] = float2int8(p0[A_hstep * 12 + 2] * scalee); + pp[15] = float2int8(p0[A_hstep * 12 + 3] * scalef); + pp += 16; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[A_hstep] * scale1); + pp[3] = float2int8(p0[A_hstep + 1] * scale1); + pp[4] = float2int8(p0[A_hstep * 2] * scale2); + pp[5] = float2int8(p0[A_hstep * 2 + 1] * scale2); + pp[6] = float2int8(p0[A_hstep * 3] * scale3); + pp[7] = float2int8(p0[A_hstep * 3 + 1] * scale3); + pp[8] = float2int8(p0[A_hstep * 4] * scale4); + pp[9] = float2int8(p0[A_hstep * 4 + 1] * scale4); + pp[10] = float2int8(p0[A_hstep * 5] * scale5); + pp[11] = float2int8(p0[A_hstep * 5 + 1] * scale5); + pp[12] = float2int8(p0[A_hstep * 6] * scale6); + pp[13] = float2int8(p0[A_hstep * 6 + 1] * scale6); + pp[14] = float2int8(p0[A_hstep * 7] * scale7); + pp[15] = float2int8(p0[A_hstep * 7 + 1] * scale7); + + pp[16 + 0] = float2int8(p0[A_hstep * 8] * scale8); + pp[16 + 1] = float2int8(p0[A_hstep * 8 + 1] * scale8); + pp[16 + 2] = float2int8(p0[A_hstep * 9] * scale9); + pp[16 + 3] = float2int8(p0[A_hstep * 9 + 1] * scale9); + pp[16 + 4] = float2int8(p0[A_hstep * 10] * scalea); + pp[16 + 5] = float2int8(p0[A_hstep * 10 + 1] * scalea); + pp[16 + 6] = float2int8(p0[A_hstep * 11] * scaleb); + pp[16 + 7] = float2int8(p0[A_hstep * 11 + 1] * scaleb); + pp[16 + 8] = float2int8(p0[A_hstep * 12] * scalec); + pp[16 + 9] = float2int8(p0[A_hstep * 12 + 1] * scalec); + pp[16 + 10] = 
float2int8(p0[A_hstep * 13] * scaled); + pp[16 + 11] = float2int8(p0[A_hstep * 13 + 1] * scaled); + pp[16 + 12] = float2int8(p0[A_hstep * 14] * scalee); + pp[16 + 13] = float2int8(p0[A_hstep * 14 + 1] * scalee); + pp[16 + 14] = float2int8(p0[A_hstep * 15] * scalef); + pp[16 + 15] = float2int8(p0[A_hstep * 15 + 1] * scalef); + pp += 32; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[A_hstep] * scale1); + pp[2] = float2int8(p0[A_hstep * 2] * scale2); + pp[3] = float2int8(p0[A_hstep * 3] * scale3); + pp[4] = float2int8(p0[A_hstep * 4] * scale4); + pp[5] = float2int8(p0[A_hstep * 5] * scale5); + pp[6] = float2int8(p0[A_hstep * 6] * scale6); + pp[7] = float2int8(p0[A_hstep * 7] * scale7); + pp[8] = float2int8(p0[A_hstep * 8] * scale8); + pp[9] = float2int8(p0[A_hstep * 9] * scale9); + pp[10] = float2int8(p0[A_hstep * 10] * scalea); + pp[11] = float2int8(p0[A_hstep * 11] * scaleb); + pp[12] = float2int8(p0[A_hstep * 12] * scalec); + pp[13] = float2int8(p0[A_hstep * 13] * scaled); + pp[14] = float2int8(p0[A_hstep * 14] * scalee); + pp[15] = float2int8(p0[A_hstep * 15] * scalef); + pp += 16; + p0++; + } + } + } +#endif // __AVX512F__ + for (; ii + 7 < max_ii; ii += 8) + { +#if __AVX2__ + signed char* pp = (signed char*)AT + ii * max_kk; +#else + signed char* pp = (signed char*)AT + ii * max_kk; + signed char* pp1 = (signed char*)AT + (ii + 4) * max_kk; + // NCNN_LOGE("pp0 %p", pp); + // NCNN_LOGE("pp1 %p", pp1); +#endif + + const float* p0 = (const float*)A + (i + ii) * A_hstep + k * elempack; + + const float scale0 = scales[i + ii]; + const float scale1 = scales[i + ii + 1]; + const float scale2 = scales[i + ii + 2]; + const float scale3 = scales[i + ii + 3]; + const float scale4 = scales[i + ii + 4]; + const float scale5 = scales[i + ii + 5]; + const float scale6 = scales[i + ii + 6]; + const float scale7 = scales[i + ii + 7]; + + if (elempack == 8) + { + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + 
pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[8] * scale0); + pp[2] = float2int8(p0[1] * scale1); + pp[3] = float2int8(p0[9] * scale1); + pp[4] = float2int8(p0[2] * scale2); + pp[5] = float2int8(p0[10] * scale2); + pp[6] = float2int8(p0[3] * scale3); + pp[7] = float2int8(p0[11] * scale3); +#if __AVX2__ + pp[8] = float2int8(p0[4] * scale4); + pp[9] = float2int8(p0[12] * scale4); + pp[10] = float2int8(p0[5] * scale5); + pp[11] = float2int8(p0[13] * scale5); + pp[12] = float2int8(p0[6] * scale6); + pp[13] = float2int8(p0[14] * scale6); + pp[14] = float2int8(p0[7] * scale7); + pp[15] = float2int8(p0[15] * scale7); + pp += 16; +#else + pp1[0] = float2int8(p0[4] * scale4); + pp1[1] = float2int8(p0[12] * scale4); pp1[2] = float2int8(p0[5] * scale5); pp1[3] = float2int8(p0[13] * scale5); pp1[4] = float2int8(p0[6] * scale6); @@ -916,6 +1665,8 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i const float scale2 = scales[i + ii + 2]; const float scale3 = scales[i + ii + 3]; + // NCNN_LOGE("scale %f %f %f %f", scale0, scale1, scale2, scale3); + if (elempack == 4) { int kk = 0; @@ -1040,6 +1791,31 @@ static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, #if __SSE2__ #if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + int ii = 0; + for (; ii < max_ii; ii++) + { + const float* p0 = (const float*)A + (i + ii) * 16; + + __m512 _absmax0 = _mm512_setzero_ps(); + int kk = 0; + for (; kk < K; kk++) + { + __m512 _p = _mm512_loadu_ps(p0); + _absmax0 = _mm512_max_ps(_absmax0, abs512_ps(_p)); + p0 += A_hstep * 8; + } + float absmax = _mm512_reduce_max_ps(_absmax0); + + ps[0] = 127.f / absmax; + pods[0] = absmax / v127_B_scale; + ps++; + pods++; + } + } +#endif // __AVX512F__ if (elempack == 8) { int ii = 0; @@ -1128,14 +1904,10 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int int ii = 0; #if __SSE2__ #if __AVX__ - for (; ii + 7 < max_ii; ii += 8) +#if __AVX512F__ + for (; ii + 
15 < max_ii; ii += 16) { -#if __AVX2__ - signed char* pp = (signed char*)AT + ii * max_kk; -#else signed char* pp = (signed char*)AT + ii * max_kk; - signed char* pp1 = (signed char*)AT + (ii + 4) * max_kk; -#endif const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; @@ -1147,7 +1919,296 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int const float scale5 = scales[i + ii + 5]; const float scale6 = scales[i + ii + 6]; const float scale7 = scales[i + ii + 7]; - + const float scale8 = scales[i + ii + 8]; + const float scale9 = scales[i + ii + 9]; + const float scalea = scales[i + ii + 10]; + const float scaleb = scales[i + ii + 11]; + const float scalec = scales[i + ii + 12]; + const float scaled = scales[i + ii + 13]; + const float scalee = scales[i + ii + 14]; + const float scalef = scales[i + ii + 15]; + + if (elempack == 16) + { + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[16] * scale1); + pp[3] = float2int8(p0[17] * scale1); + pp[4] = float2int8(p0[32] * scale2); + pp[5] = float2int8(p0[33] * scale2); + pp[6] = float2int8(p0[48] * scale3); + pp[7] = float2int8(p0[49] * scale3); + pp[8] = float2int8(p0[64] * scale4); + pp[9] = float2int8(p0[65] * scale4); + pp[10] = float2int8(p0[80] * scale5); + pp[11] = float2int8(p0[81] * scale5); + pp[12] = float2int8(p0[96] * scale6); + pp[13] = float2int8(p0[97] * scale6); + pp[14] = float2int8(p0[112] * scale7); + pp[15] = float2int8(p0[113] * scale7); + + pp[16 + 0] = float2int8(p0[128 + 0] * scale8); + pp[16 + 1] = float2int8(p0[128 + 1] * scale8); + pp[16 + 2] = float2int8(p0[128 + 16] * scale9); + pp[16 + 3] = float2int8(p0[128 + 17] * scale9); + pp[16 + 4] = float2int8(p0[128 + 32] * scalea); + pp[16 + 5] = float2int8(p0[128 + 33] * scalea); + pp[16 + 6] = float2int8(p0[128 + 48] * scaleb); + pp[16 + 7] = float2int8(p0[128 + 49] * scaleb); + pp[16 + 8] = 
float2int8(p0[128 + 64] * scalec); + pp[16 + 9] = float2int8(p0[128 + 65] * scalec); + pp[16 + 10] = float2int8(p0[128 + 80] * scaled); + pp[16 + 11] = float2int8(p0[128 + 81] * scaled); + pp[16 + 12] = float2int8(p0[128 + 96] * scalee); + pp[16 + 13] = float2int8(p0[128 + 97] * scalee); + pp[16 + 14] = float2int8(p0[128 + 112] * scalef); + pp[16 + 15] = float2int8(p0[128 + 113] * scalef); + + pp[32 + 0] = float2int8(p0[2 + 0] * scale0); + pp[32 + 1] = float2int8(p0[2 + 1] * scale0); + pp[32 + 2] = float2int8(p0[2 + 16] * scale1); + pp[32 + 3] = float2int8(p0[2 + 17] * scale1); + pp[32 + 4] = float2int8(p0[2 + 32] * scale2); + pp[32 + 5] = float2int8(p0[2 + 33] * scale2); + pp[32 + 6] = float2int8(p0[2 + 48] * scale3); + pp[32 + 7] = float2int8(p0[2 + 49] * scale3); + pp[32 + 8] = float2int8(p0[2 + 64] * scale4); + pp[32 + 9] = float2int8(p0[2 + 65] * scale4); + pp[32 + 10] = float2int8(p0[2 + 80] * scale5); + pp[32 + 11] = float2int8(p0[2 + 81] * scale5); + pp[32 + 12] = float2int8(p0[2 + 96] * scale6); + pp[32 + 13] = float2int8(p0[2 + 97] * scale6); + pp[32 + 14] = float2int8(p0[2 + 112] * scale7); + pp[32 + 15] = float2int8(p0[2 + 113] * scale7); + + pp[48 + 0] = float2int8(p0[2 + 128 + 0] * scale8); + pp[48 + 1] = float2int8(p0[2 + 128 + 1] * scale8); + pp[48 + 2] = float2int8(p0[2 + 128 + 16] * scale9); + pp[48 + 3] = float2int8(p0[2 + 128 + 17] * scale9); + pp[48 + 4] = float2int8(p0[2 + 128 + 32] * scalea); + pp[48 + 5] = float2int8(p0[2 + 128 + 33] * scalea); + pp[48 + 6] = float2int8(p0[2 + 128 + 48] * scaleb); + pp[48 + 7] = float2int8(p0[2 + 128 + 49] * scaleb); + pp[48 + 8] = float2int8(p0[2 + 128 + 64] * scalec); + pp[48 + 9] = float2int8(p0[2 + 128 + 65] * scalec); + pp[48 + 10] = float2int8(p0[2 + 128 + 80] * scaled); + pp[48 + 11] = float2int8(p0[2 + 128 + 81] * scaled); + pp[48 + 12] = float2int8(p0[2 + 128 + 96] * scalee); + pp[48 + 13] = float2int8(p0[2 + 128 + 97] * scalee); + pp[48 + 14] = float2int8(p0[2 + 128 + 112] * scalef); + pp[48 + 15] 
= float2int8(p0[2 + 128 + 113] * scalef); + + pp[64 + 0] = float2int8(p0[4 + 0] * scale0); + pp[64 + 1] = float2int8(p0[4 + 1] * scale0); + pp[64 + 2] = float2int8(p0[4 + 16] * scale1); + pp[64 + 3] = float2int8(p0[4 + 17] * scale1); + pp[64 + 4] = float2int8(p0[4 + 32] * scale2); + pp[64 + 5] = float2int8(p0[4 + 33] * scale2); + pp[64 + 6] = float2int8(p0[4 + 48] * scale3); + pp[64 + 7] = float2int8(p0[4 + 49] * scale3); + pp[64 + 8] = float2int8(p0[4 + 64] * scale4); + pp[64 + 9] = float2int8(p0[4 + 65] * scale4); + pp[64 + 10] = float2int8(p0[4 + 80] * scale5); + pp[64 + 11] = float2int8(p0[4 + 81] * scale5); + pp[64 + 12] = float2int8(p0[4 + 96] * scale6); + pp[64 + 13] = float2int8(p0[4 + 97] * scale6); + pp[64 + 14] = float2int8(p0[4 + 112] * scale7); + pp[64 + 15] = float2int8(p0[4 + 113] * scale7); + + pp[80 + 0] = float2int8(p0[4 + 128 + 0] * scale8); + pp[80 + 1] = float2int8(p0[4 + 128 + 1] * scale8); + pp[80 + 2] = float2int8(p0[4 + 128 + 16] * scale9); + pp[80 + 3] = float2int8(p0[4 + 128 + 17] * scale9); + pp[80 + 4] = float2int8(p0[4 + 128 + 32] * scalea); + pp[80 + 5] = float2int8(p0[4 + 128 + 33] * scalea); + pp[80 + 6] = float2int8(p0[4 + 128 + 48] * scaleb); + pp[80 + 7] = float2int8(p0[4 + 128 + 49] * scaleb); + pp[80 + 8] = float2int8(p0[4 + 128 + 64] * scalec); + pp[80 + 9] = float2int8(p0[4 + 128 + 65] * scalec); + pp[80 + 10] = float2int8(p0[4 + 128 + 80] * scaled); + pp[80 + 11] = float2int8(p0[4 + 128 + 81] * scaled); + pp[80 + 12] = float2int8(p0[4 + 128 + 96] * scalee); + pp[80 + 13] = float2int8(p0[4 + 128 + 97] * scalee); + pp[80 + 14] = float2int8(p0[4 + 128 + 112] * scalef); + pp[80 + 15] = float2int8(p0[4 + 128 + 113] * scalef); + + pp[96 + 0] = float2int8(p0[6 + 0] * scale0); + pp[96 + 1] = float2int8(p0[6 + 1] * scale0); + pp[96 + 2] = float2int8(p0[6 + 16] * scale1); + pp[96 + 3] = float2int8(p0[6 + 17] * scale1); + pp[96 + 4] = float2int8(p0[6 + 32] * scale2); + pp[96 + 5] = float2int8(p0[6 + 33] * scale2); + pp[96 + 6] = 
float2int8(p0[6 + 48] * scale3); + pp[96 + 7] = float2int8(p0[6 + 49] * scale3); + pp[96 + 8] = float2int8(p0[6 + 64] * scale4); + pp[96 + 9] = float2int8(p0[6 + 65] * scale4); + pp[96 + 10] = float2int8(p0[6 + 80] * scale5); + pp[96 + 11] = float2int8(p0[6 + 81] * scale5); + pp[96 + 12] = float2int8(p0[6 + 96] * scale6); + pp[96 + 13] = float2int8(p0[6 + 97] * scale6); + pp[96 + 14] = float2int8(p0[6 + 112] * scale7); + pp[96 + 15] = float2int8(p0[6 + 113] * scale7); + + pp[112 + 0] = float2int8(p0[6 + 128 + 0] * scale8); + pp[112 + 1] = float2int8(p0[6 + 128 + 1] * scale8); + pp[112 + 2] = float2int8(p0[6 + 128 + 16] * scale9); + pp[112 + 3] = float2int8(p0[6 + 128 + 17] * scale9); + pp[112 + 4] = float2int8(p0[6 + 128 + 32] * scalea); + pp[112 + 5] = float2int8(p0[6 + 128 + 33] * scalea); + pp[112 + 6] = float2int8(p0[6 + 128 + 48] * scaleb); + pp[112 + 7] = float2int8(p0[6 + 128 + 49] * scaleb); + pp[112 + 8] = float2int8(p0[6 + 128 + 64] * scalec); + pp[112 + 9] = float2int8(p0[6 + 128 + 65] * scalec); + pp[112 + 10] = float2int8(p0[6 + 128 + 80] * scaled); + pp[112 + 11] = float2int8(p0[6 + 128 + 81] * scaled); + pp[112 + 12] = float2int8(p0[6 + 128 + 96] * scalee); + pp[112 + 13] = float2int8(p0[6 + 128 + 97] * scalee); + pp[112 + 14] = float2int8(p0[6 + 128 + 112] * scalef); + pp[112 + 15] = float2int8(p0[6 + 128 + 113] * scalef); + + pp[128 + 0] = float2int8(p0[8 + 0] * scale0); + pp[128 + 1] = float2int8(p0[8 + 1] * scale0); + pp[128 + 2] = float2int8(p0[8 + 16] * scale1); + pp[128 + 3] = float2int8(p0[8 + 17] * scale1); + pp[128 + 4] = float2int8(p0[8 + 32] * scale2); + pp[128 + 5] = float2int8(p0[8 + 33] * scale2); + pp[128 + 6] = float2int8(p0[8 + 48] * scale3); + pp[128 + 7] = float2int8(p0[8 + 49] * scale3); + pp[128 + 8] = float2int8(p0[8 + 64] * scale4); + pp[128 + 9] = float2int8(p0[8 + 65] * scale4); + pp[128 + 10] = float2int8(p0[8 + 80] * scale5); + pp[128 + 11] = float2int8(p0[8 + 81] * scale5); + pp[128 + 12] = float2int8(p0[8 + 96] * 
scale6); + pp[128 + 13] = float2int8(p0[8 + 97] * scale6); + pp[128 + 14] = float2int8(p0[8 + 112] * scale7); + pp[128 + 15] = float2int8(p0[8 + 113] * scale7); + + pp[16 + 128 + 0] = float2int8(p0[8 + 128 + 0] * scale8); + pp[16 + 128 + 1] = float2int8(p0[8 + 128 + 1] * scale8); + pp[16 + 128 + 2] = float2int8(p0[8 + 128 + 16] * scale9); + pp[16 + 128 + 3] = float2int8(p0[8 + 128 + 17] * scale9); + pp[16 + 128 + 4] = float2int8(p0[8 + 128 + 32] * scalea); + pp[16 + 128 + 5] = float2int8(p0[8 + 128 + 33] * scalea); + pp[16 + 128 + 6] = float2int8(p0[8 + 128 + 48] * scaleb); + pp[16 + 128 + 7] = float2int8(p0[8 + 128 + 49] * scaleb); + pp[16 + 128 + 8] = float2int8(p0[8 + 128 + 64] * scalec); + pp[16 + 128 + 9] = float2int8(p0[8 + 128 + 65] * scalec); + pp[16 + 128 + 10] = float2int8(p0[8 + 128 + 80] * scaled); + pp[16 + 128 + 11] = float2int8(p0[8 + 128 + 81] * scaled); + pp[16 + 128 + 12] = float2int8(p0[8 + 128 + 96] * scalee); + pp[16 + 128 + 13] = float2int8(p0[8 + 128 + 97] * scalee); + pp[16 + 128 + 14] = float2int8(p0[8 + 128 + 112] * scalef); + pp[16 + 128 + 15] = float2int8(p0[8 + 128 + 113] * scalef); + + pp[32 + 128 + 0] = float2int8(p0[10 + 0] * scale0); + pp[32 + 128 + 1] = float2int8(p0[10 + 1] * scale0); + pp[32 + 128 + 2] = float2int8(p0[10 + 16] * scale1); + pp[32 + 128 + 3] = float2int8(p0[10 + 17] * scale1); + pp[32 + 128 + 4] = float2int8(p0[10 + 32] * scale2); + pp[32 + 128 + 5] = float2int8(p0[10 + 33] * scale2); + pp[32 + 128 + 6] = float2int8(p0[10 + 48] * scale3); + pp[32 + 128 + 7] = float2int8(p0[10 + 49] * scale3); + pp[32 + 128 + 8] = float2int8(p0[10 + 64] * scale4); + pp[32 + 128 + 9] = float2int8(p0[10 + 65] * scale4); + pp[32 + 128 + 10] = float2int8(p0[10 + 80] * scale5); + pp[32 + 128 + 11] = float2int8(p0[10 + 81] * scale5); + pp[32 + 128 + 12] = float2int8(p0[10 + 96] * scale6); + pp[32 + 128 + 13] = float2int8(p0[10 + 97] * scale6); + pp[32 + 128 + 14] = float2int8(p0[10 + 112] * scale7); + pp[32 + 128 + 15] = float2int8(p0[10 
+ 113] * scale7); + + pp[48 + 128 + 0] = float2int8(p0[10 + 128 + 0] * scale8); + pp[48 + 128 + 1] = float2int8(p0[10 + 128 + 1] * scale8); + pp[48 + 128 + 2] = float2int8(p0[10 + 128 + 16] * scale9); + pp[48 + 128 + 3] = float2int8(p0[10 + 128 + 17] * scale9); + pp[48 + 128 + 4] = float2int8(p0[10 + 128 + 32] * scalea); + pp[48 + 128 + 5] = float2int8(p0[10 + 128 + 33] * scalea); + pp[48 + 128 + 6] = float2int8(p0[10 + 128 + 48] * scaleb); + pp[48 + 128 + 7] = float2int8(p0[10 + 128 + 49] * scaleb); + pp[48 + 128 + 8] = float2int8(p0[10 + 128 + 64] * scalec); + pp[48 + 128 + 9] = float2int8(p0[10 + 128 + 65] * scalec); + pp[48 + 128 + 10] = float2int8(p0[10 + 128 + 80] * scaled); + pp[48 + 128 + 11] = float2int8(p0[10 + 128 + 81] * scaled); + pp[48 + 128 + 12] = float2int8(p0[10 + 128 + 96] * scalee); + pp[48 + 128 + 13] = float2int8(p0[10 + 128 + 97] * scalee); + pp[48 + 128 + 14] = float2int8(p0[10 + 128 + 112] * scalef); + pp[48 + 128 + 15] = float2int8(p0[10 + 128 + 113] * scalef); + + pp[64 + 128 + 0] = float2int8(p0[12 + 0] * scale0); + pp[64 + 128 + 1] = float2int8(p0[12 + 1] * scale0); + pp[64 + 128 + 2] = float2int8(p0[12 + 16] * scale1); + pp[64 + 128 + 3] = float2int8(p0[12 + 17] * scale1); + pp[64 + 128 + 4] = float2int8(p0[12 + 32] * scale2); + pp[64 + 128 + 5] = float2int8(p0[12 + 33] * scale2); + pp[64 + 128 + 6] = float2int8(p0[12 + 48] * scale3); + pp[64 + 128 + 7] = float2int8(p0[12 + 49] * scale3); + pp[64 + 128 + 8] = float2int8(p0[12 + 64] * scale4); + pp[64 + 128 + 9] = float2int8(p0[12 + 65] * scale4); + pp[64 + 128 + 10] = float2int8(p0[12 + 80] * scale5); + pp[64 + 128 + 11] = float2int8(p0[12 + 81] * scale5); + pp[64 + 128 + 12] = float2int8(p0[12 + 96] * scale6); + pp[64 + 128 + 13] = float2int8(p0[12 + 97] * scale6); + pp[64 + 128 + 14] = float2int8(p0[12 + 112] * scale7); + pp[64 + 128 + 15] = float2int8(p0[12 + 113] * scale7); + + pp[80 + 128 + 0] = float2int8(p0[12 + 128 + 0] * scale8); + pp[80 + 128 + 1] = float2int8(p0[12 + 128 + 
1] * scale8); + pp[80 + 128 + 2] = float2int8(p0[12 + 128 + 16] * scale9); + pp[80 + 128 + 3] = float2int8(p0[12 + 128 + 17] * scale9); + pp[80 + 128 + 4] = float2int8(p0[12 + 128 + 32] * scalea); + pp[80 + 128 + 5] = float2int8(p0[12 + 128 + 33] * scalea); + pp[80 + 128 + 6] = float2int8(p0[12 + 128 + 48] * scaleb); + pp[80 + 128 + 7] = float2int8(p0[12 + 128 + 49] * scaleb); + pp[80 + 128 + 8] = float2int8(p0[12 + 128 + 64] * scalec); + pp[80 + 128 + 9] = float2int8(p0[12 + 128 + 65] * scalec); + pp[80 + 128 + 10] = float2int8(p0[12 + 128 + 80] * scaled); + pp[80 + 128 + 11] = float2int8(p0[12 + 128 + 81] * scaled); + pp[80 + 128 + 12] = float2int8(p0[12 + 128 + 96] * scalee); + pp[80 + 128 + 13] = float2int8(p0[12 + 128 + 97] * scalee); + pp[80 + 128 + 14] = float2int8(p0[12 + 128 + 112] * scalef); + pp[80 + 128 + 15] = float2int8(p0[12 + 128 + 113] * scalef); + + pp[96 + 128 + 0] = float2int8(p0[14 + 0] * scale0); + pp[96 + 128 + 1] = float2int8(p0[14 + 1] * scale0); + pp[96 + 128 + 2] = float2int8(p0[14 + 16] * scale1); + pp[96 + 128 + 3] = float2int8(p0[14 + 17] * scale1); + pp[96 + 128 + 4] = float2int8(p0[14 + 32] * scale2); + pp[96 + 128 + 5] = float2int8(p0[14 + 33] * scale2); + pp[96 + 128 + 6] = float2int8(p0[14 + 48] * scale3); + pp[96 + 128 + 7] = float2int8(p0[14 + 49] * scale3); + pp[96 + 128 + 8] = float2int8(p0[14 + 64] * scale4); + pp[96 + 128 + 9] = float2int8(p0[14 + 65] * scale4); + pp[96 + 128 + 10] = float2int8(p0[14 + 80] * scale5); + pp[96 + 128 + 11] = float2int8(p0[14 + 81] * scale5); + pp[96 + 128 + 12] = float2int8(p0[14 + 96] * scale6); + pp[96 + 128 + 13] = float2int8(p0[14 + 97] * scale6); + pp[96 + 128 + 14] = float2int8(p0[14 + 112] * scale7); + pp[96 + 128 + 15] = float2int8(p0[14 + 113] * scale7); + + pp[112 + 128 + 0] = float2int8(p0[14 + 128 + 0] * scale8); + pp[112 + 128 + 1] = float2int8(p0[14 + 128 + 1] * scale8); + pp[112 + 128 + 2] = float2int8(p0[14 + 128 + 16] * scale9); + pp[112 + 128 + 3] = float2int8(p0[14 + 128 + 
17] * scale9); + pp[112 + 128 + 4] = float2int8(p0[14 + 128 + 32] * scalea); + pp[112 + 128 + 5] = float2int8(p0[14 + 128 + 33] * scalea); + pp[112 + 128 + 6] = float2int8(p0[14 + 128 + 48] * scaleb); + pp[112 + 128 + 7] = float2int8(p0[14 + 128 + 49] * scaleb); + pp[112 + 128 + 8] = float2int8(p0[14 + 128 + 64] * scalec); + pp[112 + 128 + 9] = float2int8(p0[14 + 128 + 65] * scalec); + pp[112 + 128 + 10] = float2int8(p0[14 + 128 + 80] * scaled); + pp[112 + 128 + 11] = float2int8(p0[14 + 128 + 81] * scaled); + pp[112 + 128 + 12] = float2int8(p0[14 + 128 + 96] * scalee); + pp[112 + 128 + 13] = float2int8(p0[14 + 128 + 97] * scalee); + pp[112 + 128 + 14] = float2int8(p0[14 + 128 + 112] * scalef); + pp[112 + 128 + 15] = float2int8(p0[14 + 128 + 113] * scalef); + + pp += 256; + p0 += A_hstep * 16; + } + } if (elempack == 8) { int kk = 0; @@ -1161,7 +2222,6 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int pp[5] = float2int8(p0[17] * scale2); pp[6] = float2int8(p0[24] * scale3); pp[7] = float2int8(p0[25] * scale3); -#if __AVX2__ pp[8] = float2int8(p0[32] * scale4); pp[9] = float2int8(p0[33] * scale4); pp[10] = float2int8(p0[40] * scale5); @@ -1170,47 +2230,502 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int pp[13] = float2int8(p0[49] * scale6); pp[14] = float2int8(p0[56] * scale7); pp[15] = float2int8(p0[57] * scale7); - pp += 16; -#else - pp1[0] = float2int8(p0[32] * scale4); - pp1[1] = float2int8(p0[33] * scale4); - pp1[2] = float2int8(p0[40] * scale5); - pp1[3] = float2int8(p0[41] * scale5); - pp1[4] = float2int8(p0[48] * scale6); - pp1[5] = float2int8(p0[49] * scale6); - pp1[6] = float2int8(p0[56] * scale7); - pp1[7] = float2int8(p0[57] * scale7); - pp += 8; - pp1 += 8; -#endif - pp[0] = float2int8(p0[2] * scale0); - pp[1] = float2int8(p0[3] * scale0); - pp[2] = float2int8(p0[10] * scale1); - pp[3] = float2int8(p0[11] * scale1); - pp[4] = float2int8(p0[18] * scale2); - pp[5] = float2int8(p0[19] * 
scale2); - pp[6] = float2int8(p0[26] * scale3); - pp[7] = float2int8(p0[27] * scale3); -#if __AVX2__ - pp[8] = float2int8(p0[34] * scale4); - pp[9] = float2int8(p0[35] * scale4); - pp[10] = float2int8(p0[42] * scale5); - pp[11] = float2int8(p0[43] * scale5); - pp[12] = float2int8(p0[50] * scale6); - pp[13] = float2int8(p0[51] * scale6); - pp[14] = float2int8(p0[58] * scale7); - pp[15] = float2int8(p0[59] * scale7); - pp += 16; -#else - pp1[0] = float2int8(p0[34] * scale4); - pp1[1] = float2int8(p0[35] * scale4); - pp1[2] = float2int8(p0[42] * scale5); - pp1[3] = float2int8(p0[43] * scale5); - pp1[4] = float2int8(p0[50] * scale6); - pp1[5] = float2int8(p0[51] * scale6); - pp1[6] = float2int8(p0[58] * scale7); - pp1[7] = float2int8(p0[59] * scale7); + pp[16 + 0] = float2int8(p0[64 + 0] * scale8); + pp[16 + 1] = float2int8(p0[64 + 1] * scale8); + pp[16 + 2] = float2int8(p0[64 + 8] * scale9); + pp[16 + 3] = float2int8(p0[64 + 9] * scale9); + pp[16 + 4] = float2int8(p0[64 + 16] * scalea); + pp[16 + 5] = float2int8(p0[64 + 17] * scalea); + pp[16 + 6] = float2int8(p0[64 + 24] * scaleb); + pp[16 + 7] = float2int8(p0[64 + 25] * scaleb); + pp[16 + 8] = float2int8(p0[64 + 32] * scalec); + pp[16 + 9] = float2int8(p0[64 + 33] * scalec); + pp[16 + 10] = float2int8(p0[64 + 40] * scaled); + pp[16 + 11] = float2int8(p0[64 + 41] * scaled); + pp[16 + 12] = float2int8(p0[64 + 48] * scalee); + pp[16 + 13] = float2int8(p0[64 + 49] * scalee); + pp[16 + 14] = float2int8(p0[64 + 56] * scalef); + pp[16 + 15] = float2int8(p0[64 + 57] * scalef); + + pp[32 + 0] = float2int8(p0[2] * scale0); + pp[32 + 1] = float2int8(p0[3] * scale0); + pp[32 + 2] = float2int8(p0[10] * scale1); + pp[32 + 3] = float2int8(p0[11] * scale1); + pp[32 + 4] = float2int8(p0[18] * scale2); + pp[32 + 5] = float2int8(p0[19] * scale2); + pp[32 + 6] = float2int8(p0[26] * scale3); + pp[32 + 7] = float2int8(p0[27] * scale3); + pp[32 + 8] = float2int8(p0[34] * scale4); + pp[32 + 9] = float2int8(p0[35] * scale4); + pp[32 + 10] = 
float2int8(p0[42] * scale5); + pp[32 + 11] = float2int8(p0[43] * scale5); + pp[32 + 12] = float2int8(p0[50] * scale6); + pp[32 + 13] = float2int8(p0[51] * scale6); + pp[32 + 14] = float2int8(p0[58] * scale7); + pp[32 + 15] = float2int8(p0[59] * scale7); + + pp[48 + 0] = float2int8(p0[64 + 2] * scale8); + pp[48 + 1] = float2int8(p0[64 + 3] * scale8); + pp[48 + 2] = float2int8(p0[64 + 10] * scale9); + pp[48 + 3] = float2int8(p0[64 + 11] * scale9); + pp[48 + 4] = float2int8(p0[64 + 18] * scalea); + pp[48 + 5] = float2int8(p0[64 + 19] * scalea); + pp[48 + 6] = float2int8(p0[64 + 26] * scaleb); + pp[48 + 7] = float2int8(p0[64 + 27] * scaleb); + pp[48 + 8] = float2int8(p0[64 + 34] * scalec); + pp[48 + 9] = float2int8(p0[64 + 35] * scalec); + pp[48 + 10] = float2int8(p0[64 + 42] * scaled); + pp[48 + 11] = float2int8(p0[64 + 43] * scaled); + pp[48 + 12] = float2int8(p0[64 + 50] * scalee); + pp[48 + 13] = float2int8(p0[64 + 51] * scalee); + pp[48 + 14] = float2int8(p0[64 + 58] * scalef); + pp[48 + 15] = float2int8(p0[64 + 59] * scalef); + + pp[64 + 0] = float2int8(p0[4] * scale0); + pp[64 + 1] = float2int8(p0[5] * scale0); + pp[64 + 2] = float2int8(p0[12] * scale1); + pp[64 + 3] = float2int8(p0[13] * scale1); + pp[64 + 4] = float2int8(p0[20] * scale2); + pp[64 + 5] = float2int8(p0[21] * scale2); + pp[64 + 6] = float2int8(p0[28] * scale3); + pp[64 + 7] = float2int8(p0[29] * scale3); + pp[64 + 8] = float2int8(p0[36] * scale4); + pp[64 + 9] = float2int8(p0[37] * scale4); + pp[64 + 10] = float2int8(p0[44] * scale5); + pp[64 + 11] = float2int8(p0[45] * scale5); + pp[64 + 12] = float2int8(p0[52] * scale6); + pp[64 + 13] = float2int8(p0[53] * scale6); + pp[64 + 14] = float2int8(p0[60] * scale7); + pp[64 + 15] = float2int8(p0[61] * scale7); + + pp[80 + 0] = float2int8(p0[64 + 4] * scale8); + pp[80 + 1] = float2int8(p0[64 + 5] * scale8); + pp[80 + 2] = float2int8(p0[64 + 12] * scale9); + pp[80 + 3] = float2int8(p0[64 + 13] * scale9); + pp[80 + 4] = float2int8(p0[64 + 20] * scalea); 
+ pp[80 + 5] = float2int8(p0[64 + 21] * scalea); + pp[80 + 6] = float2int8(p0[64 + 28] * scaleb); + pp[80 + 7] = float2int8(p0[64 + 29] * scaleb); + pp[80 + 8] = float2int8(p0[64 + 36] * scalec); + pp[80 + 9] = float2int8(p0[64 + 37] * scalec); + pp[80 + 10] = float2int8(p0[64 + 44] * scaled); + pp[80 + 11] = float2int8(p0[64 + 45] * scaled); + pp[80 + 12] = float2int8(p0[64 + 52] * scalee); + pp[80 + 13] = float2int8(p0[64 + 53] * scalee); + pp[80 + 14] = float2int8(p0[64 + 60] * scalef); + pp[80 + 15] = float2int8(p0[64 + 61] * scalef); + + pp[96 + 0] = float2int8(p0[6] * scale0); + pp[96 + 1] = float2int8(p0[7] * scale0); + pp[96 + 2] = float2int8(p0[14] * scale1); + pp[96 + 3] = float2int8(p0[15] * scale1); + pp[96 + 4] = float2int8(p0[22] * scale2); + pp[96 + 5] = float2int8(p0[23] * scale2); + pp[96 + 6] = float2int8(p0[30] * scale3); + pp[96 + 7] = float2int8(p0[31] * scale3); + pp[96 + 8] = float2int8(p0[38] * scale4); + pp[96 + 9] = float2int8(p0[39] * scale4); + pp[96 + 10] = float2int8(p0[46] * scale5); + pp[96 + 11] = float2int8(p0[47] * scale5); + pp[96 + 12] = float2int8(p0[54] * scale6); + pp[96 + 13] = float2int8(p0[55] * scale6); + pp[96 + 14] = float2int8(p0[62] * scale7); + pp[96 + 15] = float2int8(p0[63] * scale7); + + pp[112 + 0] = float2int8(p0[64 + 6] * scale8); + pp[112 + 1] = float2int8(p0[64 + 7] * scale8); + pp[112 + 2] = float2int8(p0[64 + 14] * scale9); + pp[112 + 3] = float2int8(p0[64 + 15] * scale9); + pp[112 + 4] = float2int8(p0[64 + 22] * scalea); + pp[112 + 5] = float2int8(p0[64 + 23] * scalea); + pp[112 + 6] = float2int8(p0[64 + 30] * scaleb); + pp[112 + 7] = float2int8(p0[64 + 31] * scaleb); + pp[112 + 8] = float2int8(p0[64 + 38] * scalec); + pp[112 + 9] = float2int8(p0[64 + 39] * scalec); + pp[112 + 10] = float2int8(p0[64 + 46] * scaled); + pp[112 + 11] = float2int8(p0[64 + 47] * scaled); + pp[112 + 12] = float2int8(p0[64 + 54] * scalee); + pp[112 + 13] = float2int8(p0[64 + 55] * scalee); + pp[112 + 14] = float2int8(p0[64 + 62] 
* scalef); + pp[112 + 15] = float2int8(p0[64 + 63] * scalef); + + pp += 128; + p0 += A_hstep * 8; + } + } + if (elempack == 4) + { + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[4] * scale1); + pp[3] = float2int8(p0[5] * scale1); + pp[4] = float2int8(p0[8] * scale2); + pp[5] = float2int8(p0[9] * scale2); + pp[6] = float2int8(p0[12] * scale3); + pp[7] = float2int8(p0[13] * scale3); + pp[8] = float2int8(p0[16] * scale4); + pp[9] = float2int8(p0[17] * scale4); + pp[10] = float2int8(p0[20] * scale5); + pp[11] = float2int8(p0[21] * scale5); + pp[12] = float2int8(p0[24] * scale6); + pp[13] = float2int8(p0[25] * scale6); + pp[14] = float2int8(p0[28] * scale7); + pp[15] = float2int8(p0[29] * scale7); + + pp[16 + 0] = float2int8(p0[32 + 0] * scale8); + pp[16 + 1] = float2int8(p0[32 + 1] * scale8); + pp[16 + 2] = float2int8(p0[32 + 4] * scale9); + pp[16 + 3] = float2int8(p0[32 + 5] * scale9); + pp[16 + 4] = float2int8(p0[32 + 8] * scalea); + pp[16 + 5] = float2int8(p0[32 + 9] * scalea); + pp[16 + 6] = float2int8(p0[32 + 12] * scaleb); + pp[16 + 7] = float2int8(p0[32 + 13] * scaleb); + pp[16 + 8] = float2int8(p0[32 + 16] * scalec); + pp[16 + 9] = float2int8(p0[32 + 17] * scalec); + pp[16 + 10] = float2int8(p0[32 + 20] * scaled); + pp[16 + 11] = float2int8(p0[32 + 21] * scaled); + pp[16 + 12] = float2int8(p0[32 + 24] * scalee); + pp[16 + 13] = float2int8(p0[32 + 25] * scalee); + pp[16 + 14] = float2int8(p0[32 + 28] * scalef); + pp[16 + 15] = float2int8(p0[32 + 29] * scalef); + + pp[32 + 0] = float2int8(p0[2] * scale0); + pp[32 + 1] = float2int8(p0[3] * scale0); + pp[32 + 2] = float2int8(p0[6] * scale1); + pp[32 + 3] = float2int8(p0[7] * scale1); + pp[32 + 4] = float2int8(p0[10] * scale2); + pp[32 + 5] = float2int8(p0[11] * scale2); + pp[32 + 6] = float2int8(p0[14] * scale3); + pp[32 + 7] = float2int8(p0[15] * scale3); + pp[32 + 8] = float2int8(p0[18] * scale4); + pp[32 + 
9] = float2int8(p0[19] * scale4); + pp[32 + 10] = float2int8(p0[22] * scale5); + pp[32 + 11] = float2int8(p0[23] * scale5); + pp[32 + 12] = float2int8(p0[26] * scale6); + pp[32 + 13] = float2int8(p0[27] * scale6); + pp[32 + 14] = float2int8(p0[30] * scale7); + pp[32 + 15] = float2int8(p0[31] * scale7); + + pp[48 + 0] = float2int8(p0[32 + 2] * scale8); + pp[48 + 1] = float2int8(p0[32 + 3] * scale8); + pp[48 + 2] = float2int8(p0[32 + 6] * scale9); + pp[48 + 3] = float2int8(p0[32 + 7] * scale9); + pp[48 + 4] = float2int8(p0[32 + 10] * scalea); + pp[48 + 5] = float2int8(p0[32 + 11] * scalea); + pp[48 + 6] = float2int8(p0[32 + 14] * scaleb); + pp[48 + 7] = float2int8(p0[32 + 15] * scaleb); + pp[48 + 8] = float2int8(p0[32 + 18] * scalec); + pp[48 + 9] = float2int8(p0[32 + 19] * scalec); + pp[48 + 10] = float2int8(p0[32 + 22] * scaled); + pp[48 + 11] = float2int8(p0[32 + 23] * scaled); + pp[48 + 12] = float2int8(p0[32 + 26] * scalee); + pp[48 + 13] = float2int8(p0[32 + 27] * scalee); + pp[48 + 14] = float2int8(p0[32 + 30] * scalef); + pp[48 + 15] = float2int8(p0[32 + 31] * scalef); + + pp += 64; + p0 += A_hstep * 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[A_hstep] * scale0); + pp[2] = float2int8(p0[1] * scale1); + pp[3] = float2int8(p0[A_hstep + 1] * scale1); + pp[4] = float2int8(p0[2] * scale2); + pp[5] = float2int8(p0[A_hstep + 2] * scale2); + pp[6] = float2int8(p0[3] * scale3); + pp[7] = float2int8(p0[A_hstep + 3] * scale3); + pp[8] = float2int8(p0[4] * scale4); + pp[9] = float2int8(p0[A_hstep + 4] * scale4); + pp[10] = float2int8(p0[5] * scale5); + pp[11] = float2int8(p0[A_hstep + 5] * scale5); + pp[12] = float2int8(p0[6] * scale6); + pp[13] = float2int8(p0[A_hstep + 6] * scale6); + pp[14] = float2int8(p0[7] * scale7); + pp[15] = float2int8(p0[A_hstep + 7] * scale7); + + pp[16 + 0] = float2int8(p0[8] * scale8); + pp[16 + 1] = float2int8(p0[A_hstep + 8] * 
scale8); + pp[16 + 2] = float2int8(p0[9] * scale9); + pp[16 + 3] = float2int8(p0[A_hstep + 9] * scale9); + pp[16 + 4] = float2int8(p0[10] * scalea); + pp[16 + 5] = float2int8(p0[A_hstep + 10] * scalea); + pp[16 + 6] = float2int8(p0[11] * scaleb); + pp[16 + 7] = float2int8(p0[A_hstep + 11] * scaleb); + pp[16 + 8] = float2int8(p0[12] * scalec); + pp[16 + 9] = float2int8(p0[A_hstep + 12] * scalec); + pp[16 + 10] = float2int8(p0[13] * scaled); + pp[16 + 11] = float2int8(p0[A_hstep + 13] * scaled); + pp[16 + 12] = float2int8(p0[14] * scalee); + pp[16 + 13] = float2int8(p0[A_hstep + 14] * scalee); + pp[16 + 14] = float2int8(p0[15] * scalef); + pp[16 + 15] = float2int8(p0[A_hstep + 15] * scalef); + pp += 32; + p0 += A_hstep * 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale1); + pp[2] = float2int8(p0[2] * scale2); + pp[3] = float2int8(p0[3] * scale3); + pp[4] = float2int8(p0[4] * scale4); + pp[5] = float2int8(p0[5] * scale5); + pp[6] = float2int8(p0[6] * scale6); + pp[7] = float2int8(p0[7] * scale7); + pp[8] = float2int8(p0[8] * scale8); + pp[9] = float2int8(p0[9] * scale9); + pp[10] = float2int8(p0[10] * scalea); + pp[11] = float2int8(p0[11] * scaleb); + pp[12] = float2int8(p0[12] * scalec); + pp[13] = float2int8(p0[13] * scaled); + pp[14] = float2int8(p0[14] * scalee); + pp[15] = float2int8(p0[15] * scalef); + pp += 16; + p0 += A_hstep; + } + } + } +#endif // __AVX512F__ + for (; ii + 7 < max_ii; ii += 8) + { +#if __AVX2__ + signed char* pp = (signed char*)AT + ii * max_kk; +#else + signed char* pp = (signed char*)AT + ii * max_kk; + signed char* pp1 = (signed char*)AT + (ii + 4) * max_kk; +#endif + + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; + + const float scale0 = scales[i + ii]; + const float scale1 = scales[i + ii + 1]; + const float scale2 = scales[i + ii + 2]; + const float scale3 = scales[i + ii + 3]; + const float scale4 = scales[i + ii + 4]; + const float scale5 = 
scales[i + ii + 5]; + const float scale6 = scales[i + ii + 6]; + const float scale7 = scales[i + ii + 7]; + +#if __AVX512F__ + if (elempack == 16) + { + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[16] * scale1); + pp[3] = float2int8(p0[17] * scale1); + pp[4] = float2int8(p0[32] * scale2); + pp[5] = float2int8(p0[33] * scale2); + pp[6] = float2int8(p0[48] * scale3); + pp[7] = float2int8(p0[49] * scale3); + pp[8] = float2int8(p0[64] * scale4); + pp[9] = float2int8(p0[65] * scale4); + pp[10] = float2int8(p0[80] * scale5); + pp[11] = float2int8(p0[81] * scale5); + pp[12] = float2int8(p0[96] * scale6); + pp[13] = float2int8(p0[97] * scale6); + pp[14] = float2int8(p0[112] * scale7); + pp[15] = float2int8(p0[113] * scale7); + + pp[16 + 0] = float2int8(p0[2 + 0] * scale0); + pp[16 + 1] = float2int8(p0[2 + 1] * scale0); + pp[16 + 2] = float2int8(p0[2 + 16] * scale1); + pp[16 + 3] = float2int8(p0[2 + 17] * scale1); + pp[16 + 4] = float2int8(p0[2 + 32] * scale2); + pp[16 + 5] = float2int8(p0[2 + 33] * scale2); + pp[16 + 6] = float2int8(p0[2 + 48] * scale3); + pp[16 + 7] = float2int8(p0[2 + 49] * scale3); + pp[16 + 8] = float2int8(p0[2 + 64] * scale4); + pp[16 + 9] = float2int8(p0[2 + 65] * scale4); + pp[16 + 10] = float2int8(p0[2 + 80] * scale5); + pp[16 + 11] = float2int8(p0[2 + 81] * scale5); + pp[16 + 12] = float2int8(p0[2 + 96] * scale6); + pp[16 + 13] = float2int8(p0[2 + 97] * scale6); + pp[16 + 14] = float2int8(p0[2 + 112] * scale7); + pp[16 + 15] = float2int8(p0[2 + 113] * scale7); + + pp[32 + 0] = float2int8(p0[4 + 0] * scale0); + pp[32 + 1] = float2int8(p0[4 + 1] * scale0); + pp[32 + 2] = float2int8(p0[4 + 16] * scale1); + pp[32 + 3] = float2int8(p0[4 + 17] * scale1); + pp[32 + 4] = float2int8(p0[4 + 32] * scale2); + pp[32 + 5] = float2int8(p0[4 + 33] * scale2); + pp[32 + 6] = float2int8(p0[4 + 48] * scale3); + pp[32 + 7] = float2int8(p0[4 + 49] * scale3); + 
pp[32 + 8] = float2int8(p0[4 + 64] * scale4); + pp[32 + 9] = float2int8(p0[4 + 65] * scale4); + pp[32 + 10] = float2int8(p0[4 + 80] * scale5); + pp[32 + 11] = float2int8(p0[4 + 81] * scale5); + pp[32 + 12] = float2int8(p0[4 + 96] * scale6); + pp[32 + 13] = float2int8(p0[4 + 97] * scale6); + pp[32 + 14] = float2int8(p0[4 + 112] * scale7); + pp[32 + 15] = float2int8(p0[4 + 113] * scale7); + + pp[48 + 0] = float2int8(p0[6 + 0] * scale0); + pp[48 + 1] = float2int8(p0[6 + 1] * scale0); + pp[48 + 2] = float2int8(p0[6 + 16] * scale1); + pp[48 + 3] = float2int8(p0[6 + 17] * scale1); + pp[48 + 4] = float2int8(p0[6 + 32] * scale2); + pp[48 + 5] = float2int8(p0[6 + 33] * scale2); + pp[48 + 6] = float2int8(p0[6 + 48] * scale3); + pp[48 + 7] = float2int8(p0[6 + 49] * scale3); + pp[48 + 8] = float2int8(p0[6 + 64] * scale4); + pp[48 + 9] = float2int8(p0[6 + 65] * scale4); + pp[48 + 10] = float2int8(p0[6 + 80] * scale5); + pp[48 + 11] = float2int8(p0[6 + 81] * scale5); + pp[48 + 12] = float2int8(p0[6 + 96] * scale6); + pp[48 + 13] = float2int8(p0[6 + 97] * scale6); + pp[48 + 14] = float2int8(p0[6 + 112] * scale7); + pp[48 + 15] = float2int8(p0[6 + 113] * scale7); + + pp[64 + 0] = float2int8(p0[8 + 0] * scale0); + pp[64 + 1] = float2int8(p0[8 + 1] * scale0); + pp[64 + 2] = float2int8(p0[8 + 16] * scale1); + pp[64 + 3] = float2int8(p0[8 + 17] * scale1); + pp[64 + 4] = float2int8(p0[8 + 32] * scale2); + pp[64 + 5] = float2int8(p0[8 + 33] * scale2); + pp[64 + 6] = float2int8(p0[8 + 48] * scale3); + pp[64 + 7] = float2int8(p0[8 + 49] * scale3); + pp[64 + 8] = float2int8(p0[8 + 64] * scale4); + pp[64 + 9] = float2int8(p0[8 + 65] * scale4); + pp[64 + 10] = float2int8(p0[8 + 80] * scale5); + pp[64 + 11] = float2int8(p0[8 + 81] * scale5); + pp[64 + 12] = float2int8(p0[8 + 96] * scale6); + pp[64 + 13] = float2int8(p0[8 + 97] * scale6); + pp[64 + 14] = float2int8(p0[8 + 112] * scale7); + pp[64 + 15] = float2int8(p0[8 + 113] * scale7); + + pp[80 + 0] = float2int8(p0[10 + 0] * scale0); + pp[80 
+ 1] = float2int8(p0[10 + 1] * scale0); + pp[80 + 2] = float2int8(p0[10 + 16] * scale1); + pp[80 + 3] = float2int8(p0[10 + 17] * scale1); + pp[80 + 4] = float2int8(p0[10 + 32] * scale2); + pp[80 + 5] = float2int8(p0[10 + 33] * scale2); + pp[80 + 6] = float2int8(p0[10 + 48] * scale3); + pp[80 + 7] = float2int8(p0[10 + 49] * scale3); + pp[80 + 8] = float2int8(p0[10 + 64] * scale4); + pp[80 + 9] = float2int8(p0[10 + 65] * scale4); + pp[80 + 10] = float2int8(p0[10 + 80] * scale5); + pp[80 + 11] = float2int8(p0[10 + 81] * scale5); + pp[80 + 12] = float2int8(p0[10 + 96] * scale6); + pp[80 + 13] = float2int8(p0[10 + 97] * scale6); + pp[80 + 14] = float2int8(p0[10 + 112] * scale7); + pp[80 + 15] = float2int8(p0[10 + 113] * scale7); + + pp[96 + 0] = float2int8(p0[12 + 0] * scale0); + pp[96 + 1] = float2int8(p0[12 + 1] * scale0); + pp[96 + 2] = float2int8(p0[12 + 16] * scale1); + pp[96 + 3] = float2int8(p0[12 + 17] * scale1); + pp[96 + 4] = float2int8(p0[12 + 32] * scale2); + pp[96 + 5] = float2int8(p0[12 + 33] * scale2); + pp[96 + 6] = float2int8(p0[12 + 48] * scale3); + pp[96 + 7] = float2int8(p0[12 + 49] * scale3); + pp[96 + 8] = float2int8(p0[12 + 64] * scale4); + pp[96 + 9] = float2int8(p0[12 + 65] * scale4); + pp[96 + 10] = float2int8(p0[12 + 80] * scale5); + pp[96 + 11] = float2int8(p0[12 + 81] * scale5); + pp[96 + 12] = float2int8(p0[12 + 96] * scale6); + pp[96 + 13] = float2int8(p0[12 + 97] * scale6); + pp[96 + 14] = float2int8(p0[12 + 112] * scale7); + pp[96 + 15] = float2int8(p0[12 + 113] * scale7); + + pp[112 + 0] = float2int8(p0[14 + 0] * scale0); + pp[112 + 1] = float2int8(p0[14 + 1] * scale0); + pp[112 + 2] = float2int8(p0[14 + 16] * scale1); + pp[112 + 3] = float2int8(p0[14 + 17] * scale1); + pp[112 + 4] = float2int8(p0[14 + 32] * scale2); + pp[112 + 5] = float2int8(p0[14 + 33] * scale2); + pp[112 + 6] = float2int8(p0[14 + 48] * scale3); + pp[112 + 7] = float2int8(p0[14 + 49] * scale3); + pp[112 + 8] = float2int8(p0[14 + 64] * scale4); + pp[112 + 9] = 
float2int8(p0[14 + 65] * scale4); + pp[112 + 10] = float2int8(p0[14 + 80] * scale5); + pp[112 + 11] = float2int8(p0[14 + 81] * scale5); + pp[112 + 12] = float2int8(p0[14 + 96] * scale6); + pp[112 + 13] = float2int8(p0[14 + 97] * scale6); + pp[112 + 14] = float2int8(p0[14 + 112] * scale7); + pp[112 + 15] = float2int8(p0[14 + 113] * scale7); + + pp += 128; + p0 += A_hstep * 16; + } + } +#endif // __AVX512F__ + if (elempack == 8) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[8] * scale1); + pp[3] = float2int8(p0[9] * scale1); + pp[4] = float2int8(p0[16] * scale2); + pp[5] = float2int8(p0[17] * scale2); + pp[6] = float2int8(p0[24] * scale3); + pp[7] = float2int8(p0[25] * scale3); +#if __AVX2__ + pp[8] = float2int8(p0[32] * scale4); + pp[9] = float2int8(p0[33] * scale4); + pp[10] = float2int8(p0[40] * scale5); + pp[11] = float2int8(p0[41] * scale5); + pp[12] = float2int8(p0[48] * scale6); + pp[13] = float2int8(p0[49] * scale6); + pp[14] = float2int8(p0[56] * scale7); + pp[15] = float2int8(p0[57] * scale7); + pp += 16; +#else + pp1[0] = float2int8(p0[32] * scale4); + pp1[1] = float2int8(p0[33] * scale4); + pp1[2] = float2int8(p0[40] * scale5); + pp1[3] = float2int8(p0[41] * scale5); + pp1[4] = float2int8(p0[48] * scale6); + pp1[5] = float2int8(p0[49] * scale6); + pp1[6] = float2int8(p0[56] * scale7); + pp1[7] = float2int8(p0[57] * scale7); + pp += 8; + pp1 += 8; +#endif + + pp[0] = float2int8(p0[2] * scale0); + pp[1] = float2int8(p0[3] * scale0); + pp[2] = float2int8(p0[10] * scale1); + pp[3] = float2int8(p0[11] * scale1); + pp[4] = float2int8(p0[18] * scale2); + pp[5] = float2int8(p0[19] * scale2); + pp[6] = float2int8(p0[26] * scale3); + pp[7] = float2int8(p0[27] * scale3); +#if __AVX2__ + pp[8] = float2int8(p0[34] * scale4); + pp[9] = float2int8(p0[35] * scale4); + pp[10] = float2int8(p0[42] * scale5); + pp[11] = float2int8(p0[43] * scale5); + pp[12] = 
float2int8(p0[50] * scale6); + pp[13] = float2int8(p0[51] * scale6); + pp[14] = float2int8(p0[58] * scale7); + pp[15] = float2int8(p0[59] * scale7); + pp += 16; +#else + pp1[0] = float2int8(p0[34] * scale4); + pp1[1] = float2int8(p0[35] * scale4); + pp1[2] = float2int8(p0[42] * scale5); + pp1[3] = float2int8(p0[43] * scale5); + pp1[4] = float2int8(p0[50] * scale6); + pp1[5] = float2int8(p0[51] * scale6); + pp1[6] = float2int8(p0[58] * scale7); + pp1[7] = float2int8(p0[59] * scale7); pp += 8; pp1 += 8; #endif @@ -1424,6 +2939,89 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int const float scale3 = scales[i + ii + 3]; #if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[16] * scale1); + pp[3] = float2int8(p0[17] * scale1); + pp[4] = float2int8(p0[32] * scale2); + pp[5] = float2int8(p0[33] * scale2); + pp[6] = float2int8(p0[48] * scale3); + pp[7] = float2int8(p0[49] * scale3); + + pp[8] = float2int8(p0[2 + 0] * scale0); + pp[9] = float2int8(p0[2 + 1] * scale0); + pp[10] = float2int8(p0[2 + 16] * scale1); + pp[11] = float2int8(p0[2 + 17] * scale1); + pp[12] = float2int8(p0[2 + 32] * scale2); + pp[13] = float2int8(p0[2 + 33] * scale2); + pp[14] = float2int8(p0[2 + 48] * scale3); + pp[15] = float2int8(p0[2 + 49] * scale3); + + pp[16 + 0] = float2int8(p0[4 + 0] * scale0); + pp[16 + 1] = float2int8(p0[4 + 1] * scale0); + pp[16 + 2] = float2int8(p0[4 + 16] * scale1); + pp[16 + 3] = float2int8(p0[4 + 17] * scale1); + pp[16 + 4] = float2int8(p0[4 + 32] * scale2); + pp[16 + 5] = float2int8(p0[4 + 33] * scale2); + pp[16 + 6] = float2int8(p0[4 + 48] * scale3); + pp[16 + 7] = float2int8(p0[4 + 49] * scale3); + + pp[16 + 8] = float2int8(p0[6 + 0] * scale0); + pp[16 + 9] = float2int8(p0[6 + 1] * scale0); + pp[16 + 10] = float2int8(p0[6 + 16] * scale1); + pp[16 + 11] = float2int8(p0[6 + 17] * 
scale1); + pp[16 + 12] = float2int8(p0[6 + 32] * scale2); + pp[16 + 13] = float2int8(p0[6 + 33] * scale2); + pp[16 + 14] = float2int8(p0[6 + 48] * scale3); + pp[16 + 15] = float2int8(p0[6 + 49] * scale3); + + pp[32 + 0] = float2int8(p0[8 + 0] * scale0); + pp[32 + 1] = float2int8(p0[8 + 1] * scale0); + pp[32 + 2] = float2int8(p0[8 + 16] * scale1); + pp[32 + 3] = float2int8(p0[8 + 17] * scale1); + pp[32 + 4] = float2int8(p0[8 + 32] * scale2); + pp[32 + 5] = float2int8(p0[8 + 33] * scale2); + pp[32 + 6] = float2int8(p0[8 + 48] * scale3); + pp[32 + 7] = float2int8(p0[8 + 49] * scale3); + + pp[32 + 8] = float2int8(p0[10 + 0] * scale0); + pp[32 + 9] = float2int8(p0[10 + 1] * scale0); + pp[32 + 10] = float2int8(p0[10 + 16] * scale1); + pp[32 + 11] = float2int8(p0[10 + 17] * scale1); + pp[32 + 12] = float2int8(p0[10 + 32] * scale2); + pp[32 + 13] = float2int8(p0[10 + 33] * scale2); + pp[32 + 14] = float2int8(p0[10 + 48] * scale3); + pp[32 + 15] = float2int8(p0[10 + 49] * scale3); + + pp[48 + 0] = float2int8(p0[12 + 0] * scale0); + pp[48 + 1] = float2int8(p0[12 + 1] * scale0); + pp[48 + 2] = float2int8(p0[12 + 16] * scale1); + pp[48 + 3] = float2int8(p0[12 + 17] * scale1); + pp[48 + 4] = float2int8(p0[12 + 32] * scale2); + pp[48 + 5] = float2int8(p0[12 + 33] * scale2); + pp[48 + 6] = float2int8(p0[12 + 48] * scale3); + pp[48 + 7] = float2int8(p0[12 + 49] * scale3); + + pp[48 + 8] = float2int8(p0[14 + 0] * scale0); + pp[48 + 9] = float2int8(p0[14 + 1] * scale0); + pp[48 + 10] = float2int8(p0[14 + 16] * scale1); + pp[48 + 11] = float2int8(p0[14 + 17] * scale1); + pp[48 + 12] = float2int8(p0[14 + 32] * scale2); + pp[48 + 13] = float2int8(p0[14 + 33] * scale2); + pp[48 + 14] = float2int8(p0[14 + 48] * scale3); + pp[48 + 15] = float2int8(p0[14 + 49] * scale3); + + pp += 64; + p0 += A_hstep * 16; + } + } +#endif // __AVX512F__ if (elempack == 8) { int kk = 0; @@ -1537,6 +3135,57 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int #if __SSE2__ #if 
__AVX__ +#if __AVX512F__ + if (elempack == 16) + { + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[16] * scale1); + pp[3] = float2int8(p0[17] * scale1); + + pp[4] = float2int8(p0[2] * scale0); + pp[5] = float2int8(p0[3] * scale0); + pp[6] = float2int8(p0[18] * scale1); + pp[7] = float2int8(p0[19] * scale1); + + pp[8] = float2int8(p0[4] * scale0); + pp[9] = float2int8(p0[5] * scale0); + pp[10] = float2int8(p0[20] * scale1); + pp[11] = float2int8(p0[21] * scale1); + + pp[12] = float2int8(p0[6] * scale0); + pp[13] = float2int8(p0[7] * scale0); + pp[14] = float2int8(p0[22] * scale1); + pp[15] = float2int8(p0[23] * scale1); + + pp[16 + 0] = float2int8(p0[8] * scale0); + pp[16 + 1] = float2int8(p0[9] * scale0); + pp[16 + 2] = float2int8(p0[24] * scale1); + pp[16 + 3] = float2int8(p0[25] * scale1); + + pp[16 + 4] = float2int8(p0[10] * scale0); + pp[16 + 5] = float2int8(p0[11] * scale0); + pp[16 + 6] = float2int8(p0[26] * scale1); + pp[16 + 7] = float2int8(p0[27] * scale1); + + pp[16 + 8] = float2int8(p0[12] * scale0); + pp[16 + 9] = float2int8(p0[13] * scale0); + pp[16 + 10] = float2int8(p0[28] * scale1); + pp[16 + 11] = float2int8(p0[29] * scale1); + + pp[16 + 12] = float2int8(p0[14] * scale0); + pp[16 + 13] = float2int8(p0[15] * scale0); + pp[16 + 14] = float2int8(p0[30] * scale1); + pp[16 + 15] = float2int8(p0[31] * scale1); + + pp += 32; + p0 += A_hstep * 16; + } + } +#endif // __AVX512F__ if (elempack == 8) { int kk = 0; @@ -1616,6 +3265,33 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int #if __SSE2__ #if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp[2] = float2int8(p0[2] * scale); + pp[3] = float2int8(p0[3] * scale); + pp[4] = float2int8(p0[4] * scale); + pp[5] = float2int8(p0[5] * scale); 
+ pp[6] = float2int8(p0[6] * scale); + pp[7] = float2int8(p0[7] * scale); + pp[8] = float2int8(p0[8] * scale); + pp[9] = float2int8(p0[9] * scale); + pp[10] = float2int8(p0[10] * scale); + pp[11] = float2int8(p0[11] * scale); + pp[12] = float2int8(p0[12] * scale); + pp[13] = float2int8(p0[13] * scale); + pp[14] = float2int8(p0[14] * scale); + pp[15] = float2int8(p0[15] * scale); + pp += 16; + p0 += A_hstep * 16; + } + } +#endif // __AVX512F__ if (elempack == 8) { int kk = 0; @@ -1668,6 +3344,9 @@ static void compute_B_fp32_int8_scale(const Mat& B, float& scale) float absmax = 0.f; #if __SSE2__ #if __AVX__ +#if __AVX512F__ + __m512 _absmax_avx512 = _mm512_setzero_ps(); +#endif // __AVX512F__ __m256 _absmax_avx = _mm256_setzero_ps(); #endif // __AVX__ __m128 _absmax = _mm_setzero_ps(); @@ -1682,6 +3361,14 @@ static void compute_B_fp32_int8_scale(const Mat& B, float& scale) int j = 0; #if __SSE2__ #if __AVX__ +#if __AVX512F__ + for (; j + 15 < size; j += 16) + { + __m512 _p = _mm512_loadu_ps(ptr); + _absmax_avx512 = _mm512_max_ps(_absmax_avx512, abs512_ps(_p)); + ptr += 16; + } +#endif // __AVX512F__ for (; j + 7 < size; j += 8) { __m256 _p = _mm256_loadu_ps(ptr); @@ -1704,6 +3391,9 @@ static void compute_B_fp32_int8_scale(const Mat& B, float& scale) } #if __SSE2__ #if __AVX__ +#if __AVX512F__ + absmax = std::max(absmax, _mm512_comp_reduce_max_ps(_absmax_avx512)); +#endif // __AVX512F__ absmax = std::max(absmax, _mm256_reduce_max_ps(_absmax_avx)); #endif // __AVX__ absmax = std::max(absmax, _mm_reduce_max_ps(_absmax)); @@ -1724,35 +3414,50 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i int jj = 0; #if __SSE2__ #if defined(__x86_64__) || defined(_M_X64) - for (; jj + 7 < max_jj; jj += 8) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) { const float* p0 = (const float*)B + (j + jj) * B_hstep + k * elempack; -#if __AVX__ - if (elempack == 8) + if (elempack == 16) { int kk = 0; for (; kk + 1 < max_kk; kk += 2) { pp[0] = 
float2int8(p0[0] * scale); - pp[1] = float2int8(p0[8] * scale); + pp[1] = float2int8(p0[16] * scale); pp[2] = float2int8(p0[1] * scale); - pp[3] = float2int8(p0[9] * scale); + pp[3] = float2int8(p0[17] * scale); pp[4] = float2int8(p0[2] * scale); - pp[5] = float2int8(p0[10] * scale); + pp[5] = float2int8(p0[18] * scale); pp[6] = float2int8(p0[3] * scale); - pp[7] = float2int8(p0[11] * scale); + pp[7] = float2int8(p0[19] * scale); pp[8] = float2int8(p0[4] * scale); - pp[9] = float2int8(p0[12] * scale); + pp[9] = float2int8(p0[20] * scale); pp[10] = float2int8(p0[5] * scale); - pp[11] = float2int8(p0[13] * scale); + pp[11] = float2int8(p0[21] * scale); pp[12] = float2int8(p0[6] * scale); - pp[13] = float2int8(p0[14] * scale); + pp[13] = float2int8(p0[22] * scale); pp[14] = float2int8(p0[7] * scale); - pp[15] = float2int8(p0[15] * scale); - - pp += 16; - p0 += 16; + pp[15] = float2int8(p0[23]* scale); + pp[16 + 0] = float2int8(p0[8] * scale); + pp[16 + 1] = float2int8(p0[24] * scale); + pp[16 + 2] = float2int8(p0[9] * scale); + pp[16 + 3] = float2int8(p0[25] * scale); + pp[16 + 4] = float2int8(p0[10] * scale); + pp[16 + 5] = float2int8(p0[26] * scale); + pp[16 + 6] = float2int8(p0[11] * scale); + pp[16 + 7] = float2int8(p0[27] * scale); + pp[16 + 8] = float2int8(p0[12] * scale); + pp[16 + 9] = float2int8(p0[28] * scale); + pp[16 + 10] = float2int8(p0[13] * scale); + pp[16 + 11] = float2int8(p0[29] * scale); + pp[16 + 12] = float2int8(p0[14] * scale); + pp[16 + 13] = float2int8(p0[30] * scale); + pp[16 + 14] = float2int8(p0[15] * scale); + pp[16 + 15] = float2int8(p0[31]* scale); + pp += 32; + p0 += 32; } for (; kk < max_kk; kk++) { @@ -1764,30 +3469,275 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i pp[5] = float2int8(p0[5] * scale); pp[6] = float2int8(p0[6] * scale); pp[7] = float2int8(p0[7] * scale); - - pp += 8; - p0 += 8; + pp[8] = float2int8(p0[8] * scale); + pp[9] = float2int8(p0[9] * scale); + pp[10] = float2int8(p0[10] * 
scale); + pp[11] = float2int8(p0[11] * scale); + pp[12] = float2int8(p0[12] * scale); + pp[13] = float2int8(p0[13] * scale); + pp[14] = float2int8(p0[14] * scale); + pp[15] = float2int8(p0[15] * scale); + pp += 16; + p0 += 16; } } -#endif // __AVX__ - if (elempack == 4) + if (elempack == 8) { int kk = 0; for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[4] * scale); + pp[1] = float2int8(p0[8] * scale); pp[2] = float2int8(p0[1] * scale); - pp[3] = float2int8(p0[5] * scale); + pp[3] = float2int8(p0[9] * scale); pp[4] = float2int8(p0[2] * scale); - pp[5] = float2int8(p0[6] * scale); + pp[5] = float2int8(p0[10] * scale); pp[6] = float2int8(p0[3] * scale); - pp[7] = float2int8(p0[7] * scale); - pp[8] = float2int8(p0[B_hstep * 4] * scale); - pp[9] = float2int8(p0[B_hstep * 4 + 4] * scale); - pp[10] = float2int8(p0[B_hstep * 4 + 1] * scale); - pp[11] = float2int8(p0[B_hstep * 4 + 5] * scale); - pp[12] = float2int8(p0[B_hstep * 4 + 2] * scale); + pp[7] = float2int8(p0[11] * scale); + pp[8] = float2int8(p0[4] * scale); + pp[9] = float2int8(p0[12] * scale); + pp[10] = float2int8(p0[5] * scale); + pp[11] = float2int8(p0[13] * scale); + pp[12] = float2int8(p0[6] * scale); + pp[13] = float2int8(p0[14] * scale); + pp[14] = float2int8(p0[7] * scale); + pp[15] = float2int8(p0[15]* scale); + + pp[16 + 0] = float2int8(p0[B_hstep * 8 + 0] * scale); + pp[16 + 1] = float2int8(p0[B_hstep * 8 + 8] * scale); + pp[16 + 2] = float2int8(p0[B_hstep * 8 + 1] * scale); + pp[16 + 3] = float2int8(p0[B_hstep * 8 + 9] * scale); + pp[16 + 4] = float2int8(p0[B_hstep * 8 + 2] * scale); + pp[16 + 5] = float2int8(p0[B_hstep * 8 + 10] * scale); + pp[16 + 6] = float2int8(p0[B_hstep * 8 + 3] * scale); + pp[16 + 7] = float2int8(p0[B_hstep * 8 + 11] * scale); + pp[16 + 8] = float2int8(p0[B_hstep * 8 + 4] * scale); + pp[16 + 9] = float2int8(p0[B_hstep * 8 + 12] * scale); + pp[16 + 10] = float2int8(p0[B_hstep * 8 + 5] * scale); + pp[16 + 11] = float2int8(p0[B_hstep * 
8 + 13] * scale); + pp[16 + 12] = float2int8(p0[B_hstep * 8 + 6] * scale); + pp[16 + 13] = float2int8(p0[B_hstep * 8 + 14] * scale); + pp[16 + 14] = float2int8(p0[B_hstep * 8 + 7] * scale); + pp[16 + 15] = float2int8(p0[B_hstep * 8 + 15]* scale); + pp += 32; + p0 += 16; + } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp[2] = float2int8(p0[2] * scale); + pp[3] = float2int8(p0[3] * scale); + pp[4] = float2int8(p0[4] * scale); + pp[5] = float2int8(p0[5] * scale); + pp[6] = float2int8(p0[6] * scale); + pp[7] = float2int8(p0[7] * scale); + pp[8] = float2int8(p0[B_hstep * 8 + 0] * scale); + pp[9] = float2int8(p0[B_hstep * 8 + 1] * scale); + pp[10] = float2int8(p0[B_hstep * 8 + 2] * scale); + pp[11] = float2int8(p0[B_hstep * 8 + 3] * scale); + pp[12] = float2int8(p0[B_hstep * 8 + 4] * scale); + pp[13] = float2int8(p0[B_hstep * 8 + 5] * scale); + pp[14] = float2int8(p0[B_hstep * 8 + 6] * scale); + pp[15] = float2int8(p0[B_hstep * 8 + 7] * scale); + pp += 16; + p0 += 8; + } + } + if (elempack == 4) + { + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[4] * scale); + pp[2] = float2int8(p0[1] * scale); + pp[3] = float2int8(p0[5] * scale); + pp[4] = float2int8(p0[2] * scale); + pp[5] = float2int8(p0[6] * scale); + pp[6] = float2int8(p0[3] * scale); + pp[7] = float2int8(p0[7] * scale); + pp[8] = float2int8(p0[B_hstep * 4 + 0] * scale); + pp[9] = float2int8(p0[B_hstep * 4 + 4] * scale); + pp[10] = float2int8(p0[B_hstep * 4 + 1] * scale); + pp[11] = float2int8(p0[B_hstep * 4 + 5] * scale); + pp[12] = float2int8(p0[B_hstep * 4 + 2] * scale); + pp[13] = float2int8(p0[B_hstep * 4 + 6] * scale); + pp[14] = float2int8(p0[B_hstep * 4 + 3] * scale); + pp[15] = float2int8(p0[B_hstep * 4 + 7] * scale); + + pp[16 + 0] = float2int8(p0[B_hstep * 8 + 0] * scale); + pp[16 + 1] = float2int8(p0[B_hstep * 8 + 4] * scale); + pp[16 + 2] = float2int8(p0[B_hstep * 8 + 1] * 
scale); + pp[16 + 3] = float2int8(p0[B_hstep * 8 + 5] * scale); + pp[16 + 4] = float2int8(p0[B_hstep * 8 + 2] * scale); + pp[16 + 5] = float2int8(p0[B_hstep * 8 + 6] * scale); + pp[16 + 6] = float2int8(p0[B_hstep * 8 + 3] * scale); + pp[16 + 7] = float2int8(p0[B_hstep * 8 + 7] * scale); + + pp[16 + 8] = float2int8(p0[B_hstep * 12 + 0] * scale); + pp[16 + 9] = float2int8(p0[B_hstep * 12 + 4] * scale); + pp[16 + 10] = float2int8(p0[B_hstep * 12 + 1] * scale); + pp[16 + 11] = float2int8(p0[B_hstep * 12 + 5] * scale); + pp[16 + 12] = float2int8(p0[B_hstep * 12 + 2] * scale); + pp[16 + 13] = float2int8(p0[B_hstep * 12 + 6] * scale); + pp[16 + 14] = float2int8(p0[B_hstep * 12 + 3] * scale); + pp[16 + 15] = float2int8(p0[B_hstep * 12 + 7] * scale); + + pp += 32; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp[2] = float2int8(p0[2] * scale); + pp[3] = float2int8(p0[3] * scale); + pp[4] = float2int8(p0[B_hstep * 4] * scale); + pp[5] = float2int8(p0[B_hstep * 4 + 1] * scale); + pp[6] = float2int8(p0[B_hstep * 4 + 2] * scale); + pp[7] = float2int8(p0[B_hstep * 4 + 3] * scale); + pp[8] = float2int8(p0[B_hstep * 8] * scale); + pp[9] = float2int8(p0[B_hstep * 8 + 1] * scale); + pp[10] = float2int8(p0[B_hstep * 8 + 2] * scale); + pp[11] = float2int8(p0[B_hstep * 8 + 3] * scale); + pp[12] = float2int8(p0[B_hstep * 12] * scale); + pp[13] = float2int8(p0[B_hstep * 12 + 1] * scale); + pp[14] = float2int8(p0[B_hstep * 12 + 2] * scale); + pp[15] = float2int8(p0[B_hstep * 12 + 3] * scale); + pp += 16; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp[2] = float2int8(p0[B_hstep] * scale); + pp[3] = float2int8(p0[B_hstep + 1] * scale); + pp[4] = float2int8(p0[B_hstep * 2] * scale); + pp[5] = float2int8(p0[B_hstep * 2 + 1] * scale); + pp[6] = float2int8(p0[B_hstep * 3] * scale); + pp[7] = 
float2int8(p0[B_hstep * 3 + 1] * scale); + pp[8] = float2int8(p0[B_hstep * 4] * scale); + pp[9] = float2int8(p0[B_hstep * 4 + 1] * scale); + pp[10] = float2int8(p0[B_hstep * 5] * scale); + pp[11] = float2int8(p0[B_hstep * 5 + 1] * scale); + pp[12] = float2int8(p0[B_hstep * 6] * scale); + pp[13] = float2int8(p0[B_hstep * 6 + 1] * scale); + pp[14] = float2int8(p0[B_hstep * 7] * scale); + pp[15] = float2int8(p0[B_hstep * 7 + 1] * scale); + + pp[16 + 0] = float2int8(p0[B_hstep * 8] * scale); + pp[16 + 1] = float2int8(p0[B_hstep * 8 + 1] * scale); + pp[16 + 2] = float2int8(p0[B_hstep * 9] * scale); + pp[16 + 3] = float2int8(p0[B_hstep * 9 + 1] * scale); + pp[16 + 4] = float2int8(p0[B_hstep * 10] * scale); + pp[16 + 5] = float2int8(p0[B_hstep * 10 + 1] * scale); + pp[16 + 6] = float2int8(p0[B_hstep * 11] * scale); + pp[16 + 7] = float2int8(p0[B_hstep * 11 + 1] * scale); + pp[16 + 8] = float2int8(p0[B_hstep * 12] * scale); + pp[16 + 9] = float2int8(p0[B_hstep * 12 + 1] * scale); + pp[16 + 10] = float2int8(p0[B_hstep * 13] * scale); + pp[16 + 11] = float2int8(p0[B_hstep * 13 + 1] * scale); + pp[16 + 12] = float2int8(p0[B_hstep * 14] * scale); + pp[16 + 13] = float2int8(p0[B_hstep * 14 + 1] * scale); + pp[16 + 14] = float2int8(p0[B_hstep * 15] * scale); + pp[16 + 15] = float2int8(p0[B_hstep * 15 + 1] * scale); + pp += 32; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[B_hstep] * scale); + pp[2] = float2int8(p0[B_hstep * 2] * scale); + pp[3] = float2int8(p0[B_hstep * 3] * scale); + pp[4] = float2int8(p0[B_hstep * 4] * scale); + pp[5] = float2int8(p0[B_hstep * 5] * scale); + pp[6] = float2int8(p0[B_hstep * 6] * scale); + pp[7] = float2int8(p0[B_hstep * 7] * scale); + pp[8] = float2int8(p0[B_hstep * 8] * scale); + pp[9] = float2int8(p0[B_hstep * 9] * scale); + pp[10] = float2int8(p0[B_hstep * 10] * scale); + pp[11] = float2int8(p0[B_hstep * 11] * scale); + pp[12] = float2int8(p0[B_hstep * 12] * scale); + pp[13] = 
float2int8(p0[B_hstep * 13] * scale); + pp[14] = float2int8(p0[B_hstep * 14] * scale); + pp[15] = float2int8(p0[B_hstep * 15] * scale); + pp += 16; + p0++; + } + } + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k * elempack; + +#if __AVX__ + if (elempack == 8) + { + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[8] * scale); + pp[2] = float2int8(p0[1] * scale); + pp[3] = float2int8(p0[9] * scale); + pp[4] = float2int8(p0[2] * scale); + pp[5] = float2int8(p0[10] * scale); + pp[6] = float2int8(p0[3] * scale); + pp[7] = float2int8(p0[11] * scale); + pp[8] = float2int8(p0[4] * scale); + pp[9] = float2int8(p0[12] * scale); + pp[10] = float2int8(p0[5] * scale); + pp[11] = float2int8(p0[13] * scale); + pp[12] = float2int8(p0[6] * scale); + pp[13] = float2int8(p0[14] * scale); + pp[14] = float2int8(p0[7] * scale); + pp[15] = float2int8(p0[15] * scale); + + pp += 16; + p0 += 16; + } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp[2] = float2int8(p0[2] * scale); + pp[3] = float2int8(p0[3] * scale); + pp[4] = float2int8(p0[4] * scale); + pp[5] = float2int8(p0[5] * scale); + pp[6] = float2int8(p0[6] * scale); + pp[7] = float2int8(p0[7] * scale); + + pp += 8; + p0 += 8; + } + } +#endif // __AVX__ + if (elempack == 4) + { + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[4] * scale); + pp[2] = float2int8(p0[1] * scale); + pp[3] = float2int8(p0[5] * scale); + pp[4] = float2int8(p0[2] * scale); + pp[5] = float2int8(p0[6] * scale); + pp[6] = float2int8(p0[3] * scale); + pp[7] = float2int8(p0[7] * scale); + pp[8] = float2int8(p0[B_hstep * 4] * scale); + pp[9] = float2int8(p0[B_hstep * 4 + 4] * scale); + pp[10] = float2int8(p0[B_hstep * 4 + 1] * scale); + pp[11] = float2int8(p0[B_hstep * 4 + 5] * scale); + pp[12] = 
float2int8(p0[B_hstep * 4 + 2] * scale); pp[13] = float2int8(p0[B_hstep * 4 + 6] * scale); pp[14] = float2int8(p0[B_hstep * 4 + 3] * scale); pp[15] = float2int8(p0[B_hstep * 4 + 7] * scale); @@ -1970,11 +3920,292 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int int jj = 0; #if __SSE2__ #if defined(__x86_64__) || defined(_M_X64) - for (; jj + 7 < max_jj; jj += 8) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) { const float* p0 = (const float*)B + k * B_hstep + (j + jj) * elempack; -#if __AVX__ + if (elempack == 16) + { + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp[2] = float2int8(p0[16] * scale); + pp[3] = float2int8(p0[17] * scale); + pp[4] = float2int8(p0[32] * scale); + pp[5] = float2int8(p0[33] * scale); + pp[6] = float2int8(p0[48] * scale); + pp[7] = float2int8(p0[49] * scale); + pp[8] = float2int8(p0[64] * scale); + pp[9] = float2int8(p0[65] * scale); + pp[10] = float2int8(p0[80] * scale); + pp[11] = float2int8(p0[81] * scale); + pp[12] = float2int8(p0[96] * scale); + pp[13] = float2int8(p0[97] * scale); + pp[14] = float2int8(p0[112] * scale); + pp[15] = float2int8(p0[113] * scale); + + pp[16 + 0] = float2int8(p0[128 + 0] * scale); + pp[16 + 1] = float2int8(p0[128 + 1] * scale); + pp[16 + 2] = float2int8(p0[128 + 16] * scale); + pp[16 + 3] = float2int8(p0[128 + 17] * scale); + pp[16 + 4] = float2int8(p0[128 + 32] * scale); + pp[16 + 5] = float2int8(p0[128 + 33] * scale); + pp[16 + 6] = float2int8(p0[128 + 48] * scale); + pp[16 + 7] = float2int8(p0[128 + 49] * scale); + pp[16 + 8] = float2int8(p0[128 + 64] * scale); + pp[16 + 9] = float2int8(p0[128 + 65] * scale); + pp[16 + 10] = float2int8(p0[128 + 80] * scale); + pp[16 + 11] = float2int8(p0[128 + 81] * scale); + pp[16 + 12] = float2int8(p0[128 + 96] * scale); + pp[16 + 13] = float2int8(p0[128 + 97] * scale); + pp[16 + 14] = float2int8(p0[128 + 112] * scale); + pp[16 + 15] = 
float2int8(p0[128 + 113] * scale); + + pp[32 + 0] = float2int8(p0[2 + 0] * scale); + pp[32 + 1] = float2int8(p0[2 + 1] * scale); + pp[32 + 2] = float2int8(p0[2 + 16] * scale); + pp[32 + 3] = float2int8(p0[2 + 17] * scale); + pp[32 + 4] = float2int8(p0[2 + 32] * scale); + pp[32 + 5] = float2int8(p0[2 + 33] * scale); + pp[32 + 6] = float2int8(p0[2 + 48] * scale); + pp[32 + 7] = float2int8(p0[2 + 49] * scale); + pp[32 + 8] = float2int8(p0[2 + 64] * scale); + pp[32 + 9] = float2int8(p0[2 + 65] * scale); + pp[32 + 10] = float2int8(p0[2 + 80] * scale); + pp[32 + 11] = float2int8(p0[2 + 81] * scale); + pp[32 + 12] = float2int8(p0[2 + 96] * scale); + pp[32 + 13] = float2int8(p0[2 + 97] * scale); + pp[32 + 14] = float2int8(p0[2 + 112] * scale); + pp[32 + 15] = float2int8(p0[2 + 113] * scale); + + pp[48 + 0] = float2int8(p0[2 + 128 + 0] * scale); + pp[48 + 1] = float2int8(p0[2 + 128 + 1] * scale); + pp[48 + 2] = float2int8(p0[2 + 128 + 16] * scale); + pp[48 + 3] = float2int8(p0[2 + 128 + 17] * scale); + pp[48 + 4] = float2int8(p0[2 + 128 + 32] * scale); + pp[48 + 5] = float2int8(p0[2 + 128 + 33] * scale); + pp[48 + 6] = float2int8(p0[2 + 128 + 48] * scale); + pp[48 + 7] = float2int8(p0[2 + 128 + 49] * scale); + pp[48 + 8] = float2int8(p0[2 + 128 + 64] * scale); + pp[48 + 9] = float2int8(p0[2 + 128 + 65] * scale); + pp[48 + 10] = float2int8(p0[2 + 128 + 80] * scale); + pp[48 + 11] = float2int8(p0[2 + 128 + 81] * scale); + pp[48 + 12] = float2int8(p0[2 + 128 + 96] * scale); + pp[48 + 13] = float2int8(p0[2 + 128 + 97] * scale); + pp[48 + 14] = float2int8(p0[2 + 128 + 112] * scale); + pp[48 + 15] = float2int8(p0[2 + 128 + 113] * scale); + + pp[64 + 0] = float2int8(p0[4 + 0] * scale); + pp[64 + 1] = float2int8(p0[4 + 1] * scale); + pp[64 + 2] = float2int8(p0[4 + 16] * scale); + pp[64 + 3] = float2int8(p0[4 + 17] * scale); + pp[64 + 4] = float2int8(p0[4 + 32] * scale); + pp[64 + 5] = float2int8(p0[4 + 33] * scale); + pp[64 + 6] = float2int8(p0[4 + 48] * scale); + pp[64 + 7] = 
float2int8(p0[4 + 49] * scale); + pp[64 + 8] = float2int8(p0[4 + 64] * scale); + pp[64 + 9] = float2int8(p0[4 + 65] * scale); + pp[64 + 10] = float2int8(p0[4 + 80] * scale); + pp[64 + 11] = float2int8(p0[4 + 81] * scale); + pp[64 + 12] = float2int8(p0[4 + 96] * scale); + pp[64 + 13] = float2int8(p0[4 + 97] * scale); + pp[64 + 14] = float2int8(p0[4 + 112] * scale); + pp[64 + 15] = float2int8(p0[4 + 113] * scale); + + pp[80 + 0] = float2int8(p0[4 + 128 + 0] * scale); + pp[80 + 1] = float2int8(p0[4 + 128 + 1] * scale); + pp[80 + 2] = float2int8(p0[4 + 128 + 16] * scale); + pp[80 + 3] = float2int8(p0[4 + 128 + 17] * scale); + pp[80 + 4] = float2int8(p0[4 + 128 + 32] * scale); + pp[80 + 5] = float2int8(p0[4 + 128 + 33] * scale); + pp[80 + 6] = float2int8(p0[4 + 128 + 48] * scale); + pp[80 + 7] = float2int8(p0[4 + 128 + 49] * scale); + pp[80 + 8] = float2int8(p0[4 + 128 + 64] * scale); + pp[80 + 9] = float2int8(p0[4 + 128 + 65] * scale); + pp[80 + 10] = float2int8(p0[4 + 128 + 80] * scale); + pp[80 + 11] = float2int8(p0[4 + 128 + 81] * scale); + pp[80 + 12] = float2int8(p0[4 + 128 + 96] * scale); + pp[80 + 13] = float2int8(p0[4 + 128 + 97] * scale); + pp[80 + 14] = float2int8(p0[4 + 128 + 112] * scale); + pp[80 + 15] = float2int8(p0[4 + 128 + 113] * scale); + + pp[96 + 0] = float2int8(p0[6 + 0] * scale); + pp[96 + 1] = float2int8(p0[6 + 1] * scale); + pp[96 + 2] = float2int8(p0[6 + 16] * scale); + pp[96 + 3] = float2int8(p0[6 + 17] * scale); + pp[96 + 4] = float2int8(p0[6 + 32] * scale); + pp[96 + 5] = float2int8(p0[6 + 33] * scale); + pp[96 + 6] = float2int8(p0[6 + 48] * scale); + pp[96 + 7] = float2int8(p0[6 + 49] * scale); + pp[96 + 8] = float2int8(p0[6 + 64] * scale); + pp[96 + 9] = float2int8(p0[6 + 65] * scale); + pp[96 + 10] = float2int8(p0[6 + 80] * scale); + pp[96 + 11] = float2int8(p0[6 + 81] * scale); + pp[96 + 12] = float2int8(p0[6 + 96] * scale); + pp[96 + 13] = float2int8(p0[6 + 97] * scale); + pp[96 + 14] = float2int8(p0[6 + 112] * scale); + pp[96 + 15] = 
float2int8(p0[6 + 113] * scale); + + pp[112 + 0] = float2int8(p0[6 + 128 + 0] * scale); + pp[112 + 1] = float2int8(p0[6 + 128 + 1] * scale); + pp[112 + 2] = float2int8(p0[6 + 128 + 16] * scale); + pp[112 + 3] = float2int8(p0[6 + 128 + 17] * scale); + pp[112 + 4] = float2int8(p0[6 + 128 + 32] * scale); + pp[112 + 5] = float2int8(p0[6 + 128 + 33] * scale); + pp[112 + 6] = float2int8(p0[6 + 128 + 48] * scale); + pp[112 + 7] = float2int8(p0[6 + 128 + 49] * scale); + pp[112 + 8] = float2int8(p0[6 + 128 + 64] * scale); + pp[112 + 9] = float2int8(p0[6 + 128 + 65] * scale); + pp[112 + 10] = float2int8(p0[6 + 128 + 80] * scale); + pp[112 + 11] = float2int8(p0[6 + 128 + 81] * scale); + pp[112 + 12] = float2int8(p0[6 + 128 + 96] * scale); + pp[112 + 13] = float2int8(p0[6 + 128 + 97] * scale); + pp[112 + 14] = float2int8(p0[6 + 128 + 112] * scale); + pp[112 + 15] = float2int8(p0[6 + 128 + 113] * scale); + + pp[128 + 0] = float2int8(p0[8 + 0] * scale); + pp[128 + 1] = float2int8(p0[8 + 1] * scale); + pp[128 + 2] = float2int8(p0[8 + 16] * scale); + pp[128 + 3] = float2int8(p0[8 + 17] * scale); + pp[128 + 4] = float2int8(p0[8 + 32] * scale); + pp[128 + 5] = float2int8(p0[8 + 33] * scale); + pp[128 + 6] = float2int8(p0[8 + 48] * scale); + pp[128 + 7] = float2int8(p0[8 + 49] * scale); + pp[128 + 8] = float2int8(p0[8 + 64] * scale); + pp[128 + 9] = float2int8(p0[8 + 65] * scale); + pp[128 + 10] = float2int8(p0[8 + 80] * scale); + pp[128 + 11] = float2int8(p0[8 + 81] * scale); + pp[128 + 12] = float2int8(p0[8 + 96] * scale); + pp[128 + 13] = float2int8(p0[8 + 97] * scale); + pp[128 + 14] = float2int8(p0[8 + 112] * scale); + pp[128 + 15] = float2int8(p0[8 + 113] * scale); + + pp[16 + 128 + 0] = float2int8(p0[8 + 128 + 0] * scale); + pp[16 + 128 + 1] = float2int8(p0[8 + 128 + 1] * scale); + pp[16 + 128 + 2] = float2int8(p0[8 + 128 + 16] * scale); + pp[16 + 128 + 3] = float2int8(p0[8 + 128 + 17] * scale); + pp[16 + 128 + 4] = float2int8(p0[8 + 128 + 32] * scale); + pp[16 + 128 + 5] = 
float2int8(p0[8 + 128 + 33] * scale); + pp[16 + 128 + 6] = float2int8(p0[8 + 128 + 48] * scale); + pp[16 + 128 + 7] = float2int8(p0[8 + 128 + 49] * scale); + pp[16 + 128 + 8] = float2int8(p0[8 + 128 + 64] * scale); + pp[16 + 128 + 9] = float2int8(p0[8 + 128 + 65] * scale); + pp[16 + 128 + 10] = float2int8(p0[8 + 128 + 80] * scale); + pp[16 + 128 + 11] = float2int8(p0[8 + 128 + 81] * scale); + pp[16 + 128 + 12] = float2int8(p0[8 + 128 + 96] * scale); + pp[16 + 128 + 13] = float2int8(p0[8 + 128 + 97] * scale); + pp[16 + 128 + 14] = float2int8(p0[8 + 128 + 112] * scale); + pp[16 + 128 + 15] = float2int8(p0[8 + 128 + 113] * scale); + + pp[32 + 128 + 0] = float2int8(p0[10 + 0] * scale); + pp[32 + 128 + 1] = float2int8(p0[10 + 1] * scale); + pp[32 + 128 + 2] = float2int8(p0[10 + 16] * scale); + pp[32 + 128 + 3] = float2int8(p0[10 + 17] * scale); + pp[32 + 128 + 4] = float2int8(p0[10 + 32] * scale); + pp[32 + 128 + 5] = float2int8(p0[10 + 33] * scale); + pp[32 + 128 + 6] = float2int8(p0[10 + 48] * scale); + pp[32 + 128 + 7] = float2int8(p0[10 + 49] * scale); + pp[32 + 128 + 8] = float2int8(p0[10 + 64] * scale); + pp[32 + 128 + 9] = float2int8(p0[10 + 65] * scale); + pp[32 + 128 + 10] = float2int8(p0[10 + 80] * scale); + pp[32 + 128 + 11] = float2int8(p0[10 + 81] * scale); + pp[32 + 128 + 12] = float2int8(p0[10 + 96] * scale); + pp[32 + 128 + 13] = float2int8(p0[10 + 97] * scale); + pp[32 + 128 + 14] = float2int8(p0[10 + 112] * scale); + pp[32 + 128 + 15] = float2int8(p0[10 + 113] * scale); + + pp[48 + 128 + 0] = float2int8(p0[10 + 128 + 0] * scale); + pp[48 + 128 + 1] = float2int8(p0[10 + 128 + 1] * scale); + pp[48 + 128 + 2] = float2int8(p0[10 + 128 + 16] * scale); + pp[48 + 128 + 3] = float2int8(p0[10 + 128 + 17] * scale); + pp[48 + 128 + 4] = float2int8(p0[10 + 128 + 32] * scale); + pp[48 + 128 + 5] = float2int8(p0[10 + 128 + 33] * scale); + pp[48 + 128 + 6] = float2int8(p0[10 + 128 + 48] * scale); + pp[48 + 128 + 7] = float2int8(p0[10 + 128 + 49] * scale); + pp[48 + 
128 + 8] = float2int8(p0[10 + 128 + 64] * scale); + pp[48 + 128 + 9] = float2int8(p0[10 + 128 + 65] * scale); + pp[48 + 128 + 10] = float2int8(p0[10 + 128 + 80] * scale); + pp[48 + 128 + 11] = float2int8(p0[10 + 128 + 81] * scale); + pp[48 + 128 + 12] = float2int8(p0[10 + 128 + 96] * scale); + pp[48 + 128 + 13] = float2int8(p0[10 + 128 + 97] * scale); + pp[48 + 128 + 14] = float2int8(p0[10 + 128 + 112] * scale); + pp[48 + 128 + 15] = float2int8(p0[10 + 128 + 113] * scale); + + pp[64 + 128 + 0] = float2int8(p0[12 + 0] * scale); + pp[64 + 128 + 1] = float2int8(p0[12 + 1] * scale); + pp[64 + 128 + 2] = float2int8(p0[12 + 16] * scale); + pp[64 + 128 + 3] = float2int8(p0[12 + 17] * scale); + pp[64 + 128 + 4] = float2int8(p0[12 + 32] * scale); + pp[64 + 128 + 5] = float2int8(p0[12 + 33] * scale); + pp[64 + 128 + 6] = float2int8(p0[12 + 48] * scale); + pp[64 + 128 + 7] = float2int8(p0[12 + 49] * scale); + pp[64 + 128 + 8] = float2int8(p0[12 + 64] * scale); + pp[64 + 128 + 9] = float2int8(p0[12 + 65] * scale); + pp[64 + 128 + 10] = float2int8(p0[12 + 80] * scale); + pp[64 + 128 + 11] = float2int8(p0[12 + 81] * scale); + pp[64 + 128 + 12] = float2int8(p0[12 + 96] * scale); + pp[64 + 128 + 13] = float2int8(p0[12 + 97] * scale); + pp[64 + 128 + 14] = float2int8(p0[12 + 112] * scale); + pp[64 + 128 + 15] = float2int8(p0[12 + 113] * scale); + + pp[80 + 128 + 0] = float2int8(p0[12 + 128 + 0] * scale); + pp[80 + 128 + 1] = float2int8(p0[12 + 128 + 1] * scale); + pp[80 + 128 + 2] = float2int8(p0[12 + 128 + 16] * scale); + pp[80 + 128 + 3] = float2int8(p0[12 + 128 + 17] * scale); + pp[80 + 128 + 4] = float2int8(p0[12 + 128 + 32] * scale); + pp[80 + 128 + 5] = float2int8(p0[12 + 128 + 33] * scale); + pp[80 + 128 + 6] = float2int8(p0[12 + 128 + 48] * scale); + pp[80 + 128 + 7] = float2int8(p0[12 + 128 + 49] * scale); + pp[80 + 128 + 8] = float2int8(p0[12 + 128 + 64] * scale); + pp[80 + 128 + 9] = float2int8(p0[12 + 128 + 65] * scale); + pp[80 + 128 + 10] = float2int8(p0[12 + 128 + 
80] * scale); + pp[80 + 128 + 11] = float2int8(p0[12 + 128 + 81] * scale); + pp[80 + 128 + 12] = float2int8(p0[12 + 128 + 96] * scale); + pp[80 + 128 + 13] = float2int8(p0[12 + 128 + 97] * scale); + pp[80 + 128 + 14] = float2int8(p0[12 + 128 + 112] * scale); + pp[80 + 128 + 15] = float2int8(p0[12 + 128 + 113] * scale); + + pp[96 + 128 + 0] = float2int8(p0[14 + 0] * scale); + pp[96 + 128 + 1] = float2int8(p0[14 + 1] * scale); + pp[96 + 128 + 2] = float2int8(p0[14 + 16] * scale); + pp[96 + 128 + 3] = float2int8(p0[14 + 17] * scale); + pp[96 + 128 + 4] = float2int8(p0[14 + 32] * scale); + pp[96 + 128 + 5] = float2int8(p0[14 + 33] * scale); + pp[96 + 128 + 6] = float2int8(p0[14 + 48] * scale); + pp[96 + 128 + 7] = float2int8(p0[14 + 49] * scale); + pp[96 + 128 + 8] = float2int8(p0[14 + 64] * scale); + pp[96 + 128 + 9] = float2int8(p0[14 + 65] * scale); + pp[96 + 128 + 10] = float2int8(p0[14 + 80] * scale); + pp[96 + 128 + 11] = float2int8(p0[14 + 81] * scale); + pp[96 + 128 + 12] = float2int8(p0[14 + 96] * scale); + pp[96 + 128 + 13] = float2int8(p0[14 + 97] * scale); + pp[96 + 128 + 14] = float2int8(p0[14 + 112] * scale); + pp[96 + 128 + 15] = float2int8(p0[14 + 113] * scale); + + pp[112 + 128 + 0] = float2int8(p0[14 + 128 + 0] * scale); + pp[112 + 128 + 1] = float2int8(p0[14 + 128 + 1] * scale); + pp[112 + 128 + 2] = float2int8(p0[14 + 128 + 16] * scale); + pp[112 + 128 + 3] = float2int8(p0[14 + 128 + 17] * scale); + pp[112 + 128 + 4] = float2int8(p0[14 + 128 + 32] * scale); + pp[112 + 128 + 5] = float2int8(p0[14 + 128 + 33] * scale); + pp[112 + 128 + 6] = float2int8(p0[14 + 128 + 48] * scale); + pp[112 + 128 + 7] = float2int8(p0[14 + 128 + 49] * scale); + pp[112 + 128 + 8] = float2int8(p0[14 + 128 + 64] * scale); + pp[112 + 128 + 9] = float2int8(p0[14 + 128 + 65] * scale); + pp[112 + 128 + 10] = float2int8(p0[14 + 128 + 80] * scale); + pp[112 + 128 + 11] = float2int8(p0[14 + 128 + 81] * scale); + pp[112 + 128 + 12] = float2int8(p0[14 + 128 + 96] * scale); + pp[112 + 
128 + 13] = float2int8(p0[14 + 128 + 97] * scale); + pp[112 + 128 + 14] = float2int8(p0[14 + 128 + 112] * scale); + pp[112 + 128 + 15] = float2int8(p0[14 + 128 + 113] * scale); + + pp += 256; + p0 += B_hstep * 16; + } + } if (elempack == 8) { int kk = 0; @@ -1996,66 +4227,130 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int pp[13] = float2int8(p0[49] * scale); pp[14] = float2int8(p0[56] * scale); pp[15] = float2int8(p0[57] * scale); - pp += 16; - - pp[0] = float2int8(p0[2] * scale); - pp[1] = float2int8(p0[3] * scale); - pp[2] = float2int8(p0[10] * scale); - pp[3] = float2int8(p0[11] * scale); - pp[4] = float2int8(p0[18] * scale); - pp[5] = float2int8(p0[19] * scale); - pp[6] = float2int8(p0[26] * scale); - pp[7] = float2int8(p0[27] * scale); - pp[8] = float2int8(p0[34] * scale); - pp[9] = float2int8(p0[35] * scale); - pp[10] = float2int8(p0[42] * scale); - pp[11] = float2int8(p0[43] * scale); - pp[12] = float2int8(p0[50] * scale); - pp[13] = float2int8(p0[51] * scale); - pp[14] = float2int8(p0[58] * scale); - pp[15] = float2int8(p0[59] * scale); - pp += 16; - - pp[0] = float2int8(p0[4] * scale); - pp[1] = float2int8(p0[5] * scale); - pp[2] = float2int8(p0[12] * scale); - pp[3] = float2int8(p0[13] * scale); - pp[4] = float2int8(p0[20] * scale); - pp[5] = float2int8(p0[21] * scale); - pp[6] = float2int8(p0[28] * scale); - pp[7] = float2int8(p0[29] * scale); - pp[8] = float2int8(p0[36] * scale); - pp[9] = float2int8(p0[37] * scale); - pp[10] = float2int8(p0[44] * scale); - pp[11] = float2int8(p0[45] * scale); - pp[12] = float2int8(p0[52] * scale); - pp[13] = float2int8(p0[53] * scale); - pp[14] = float2int8(p0[60] * scale); - pp[15] = float2int8(p0[61] * scale); - pp += 16; - - pp[0] = float2int8(p0[6] * scale); - pp[1] = float2int8(p0[7] * scale); - pp[2] = float2int8(p0[14] * scale); - pp[3] = float2int8(p0[15] * scale); - pp[4] = float2int8(p0[22] * scale); - pp[5] = float2int8(p0[23] * scale); - pp[6] = float2int8(p0[30] * 
scale); - pp[7] = float2int8(p0[31] * scale); - pp[8] = float2int8(p0[38] * scale); - pp[9] = float2int8(p0[39] * scale); - pp[10] = float2int8(p0[46] * scale); - pp[11] = float2int8(p0[47] * scale); - pp[12] = float2int8(p0[54] * scale); - pp[13] = float2int8(p0[55] * scale); - pp[14] = float2int8(p0[62] * scale); - pp[15] = float2int8(p0[63] * scale); - pp += 16; + pp[16 + 0] = float2int8(p0[64 + 0] * scale); + pp[16 + 1] = float2int8(p0[64 + 1] * scale); + pp[16 + 2] = float2int8(p0[64 + 8] * scale); + pp[16 + 3] = float2int8(p0[64 + 9] * scale); + pp[16 + 4] = float2int8(p0[64 + 16] * scale); + pp[16 + 5] = float2int8(p0[64 + 17] * scale); + pp[16 + 6] = float2int8(p0[64 + 24] * scale); + pp[16 + 7] = float2int8(p0[64 + 25] * scale); + pp[16 + 8] = float2int8(p0[64 + 32] * scale); + pp[16 + 9] = float2int8(p0[64 + 33] * scale); + pp[16 + 10] = float2int8(p0[64 + 40] * scale); + pp[16 + 11] = float2int8(p0[64 + 41] * scale); + pp[16 + 12] = float2int8(p0[64 + 48] * scale); + pp[16 + 13] = float2int8(p0[64 + 49] * scale); + pp[16 + 14] = float2int8(p0[64 + 56] * scale); + pp[16 + 15] = float2int8(p0[64 + 57] * scale); + + pp[32 + 0] = float2int8(p0[2] * scale); + pp[32 + 1] = float2int8(p0[3] * scale); + pp[32 + 2] = float2int8(p0[10] * scale); + pp[32 + 3] = float2int8(p0[11] * scale); + pp[32 + 4] = float2int8(p0[18] * scale); + pp[32 + 5] = float2int8(p0[19] * scale); + pp[32 + 6] = float2int8(p0[26] * scale); + pp[32 + 7] = float2int8(p0[27] * scale); + pp[32 + 8] = float2int8(p0[34] * scale); + pp[32 + 9] = float2int8(p0[35] * scale); + pp[32 + 10] = float2int8(p0[42] * scale); + pp[32 + 11] = float2int8(p0[43] * scale); + pp[32 + 12] = float2int8(p0[50] * scale); + pp[32 + 13] = float2int8(p0[51] * scale); + pp[32 + 14] = float2int8(p0[58] * scale); + pp[32 + 15] = float2int8(p0[59] * scale); + + pp[48 + 0] = float2int8(p0[64 + 2] * scale); + pp[48 + 1] = float2int8(p0[64 + 3] * scale); + pp[48 + 2] = float2int8(p0[64 + 10] * scale); + pp[48 + 3] = 
float2int8(p0[64 + 11] * scale); + pp[48 + 4] = float2int8(p0[64 + 18] * scale); + pp[48 + 5] = float2int8(p0[64 + 19] * scale); + pp[48 + 6] = float2int8(p0[64 + 26] * scale); + pp[48 + 7] = float2int8(p0[64 + 27] * scale); + pp[48 + 8] = float2int8(p0[64 + 34] * scale); + pp[48 + 9] = float2int8(p0[64 + 35] * scale); + pp[48 + 10] = float2int8(p0[64 + 42] * scale); + pp[48 + 11] = float2int8(p0[64 + 43] * scale); + pp[48 + 12] = float2int8(p0[64 + 50] * scale); + pp[48 + 13] = float2int8(p0[64 + 51] * scale); + pp[48 + 14] = float2int8(p0[64 + 58] * scale); + pp[48 + 15] = float2int8(p0[64 + 59] * scale); + + pp[64 + 0] = float2int8(p0[4] * scale); + pp[64 + 1] = float2int8(p0[5] * scale); + pp[64 + 2] = float2int8(p0[12] * scale); + pp[64 + 3] = float2int8(p0[13] * scale); + pp[64 + 4] = float2int8(p0[20] * scale); + pp[64 + 5] = float2int8(p0[21] * scale); + pp[64 + 6] = float2int8(p0[28] * scale); + pp[64 + 7] = float2int8(p0[29] * scale); + pp[64 + 8] = float2int8(p0[36] * scale); + pp[64 + 9] = float2int8(p0[37] * scale); + pp[64 + 10] = float2int8(p0[44] * scale); + pp[64 + 11] = float2int8(p0[45] * scale); + pp[64 + 12] = float2int8(p0[52] * scale); + pp[64 + 13] = float2int8(p0[53] * scale); + pp[64 + 14] = float2int8(p0[60] * scale); + pp[64 + 15] = float2int8(p0[61] * scale); + + pp[80 + 0] = float2int8(p0[64 + 4] * scale); + pp[80 + 1] = float2int8(p0[64 + 5] * scale); + pp[80 + 2] = float2int8(p0[64 + 12] * scale); + pp[80 + 3] = float2int8(p0[64 + 13] * scale); + pp[80 + 4] = float2int8(p0[64 + 20] * scale); + pp[80 + 5] = float2int8(p0[64 + 21] * scale); + pp[80 + 6] = float2int8(p0[64 + 28] * scale); + pp[80 + 7] = float2int8(p0[64 + 29] * scale); + pp[80 + 8] = float2int8(p0[64 + 36] * scale); + pp[80 + 9] = float2int8(p0[64 + 37] * scale); + pp[80 + 10] = float2int8(p0[64 + 44] * scale); + pp[80 + 11] = float2int8(p0[64 + 45] * scale); + pp[80 + 12] = float2int8(p0[64 + 52] * scale); + pp[80 + 13] = float2int8(p0[64 + 53] * scale); + pp[80 + 14] 
= float2int8(p0[64 + 60] * scale); + pp[80 + 15] = float2int8(p0[64 + 61] * scale); + + pp[96 + 0] = float2int8(p0[6] * scale); + pp[96 + 1] = float2int8(p0[7] * scale); + pp[96 + 2] = float2int8(p0[14] * scale); + pp[96 + 3] = float2int8(p0[15] * scale); + pp[96 + 4] = float2int8(p0[22] * scale); + pp[96 + 5] = float2int8(p0[23] * scale); + pp[96 + 6] = float2int8(p0[30] * scale); + pp[96 + 7] = float2int8(p0[31] * scale); + pp[96 + 8] = float2int8(p0[38] * scale); + pp[96 + 9] = float2int8(p0[39] * scale); + pp[96 + 10] = float2int8(p0[46] * scale); + pp[96 + 11] = float2int8(p0[47] * scale); + pp[96 + 12] = float2int8(p0[54] * scale); + pp[96 + 13] = float2int8(p0[55] * scale); + pp[96 + 14] = float2int8(p0[62] * scale); + pp[96 + 15] = float2int8(p0[63] * scale); + + pp[112 + 0] = float2int8(p0[64 + 6] * scale); + pp[112 + 1] = float2int8(p0[64 + 7] * scale); + pp[112 + 2] = float2int8(p0[64 + 14] * scale); + pp[112 + 3] = float2int8(p0[64 + 15] * scale); + pp[112 + 4] = float2int8(p0[64 + 22] * scale); + pp[112 + 5] = float2int8(p0[64 + 23] * scale); + pp[112 + 6] = float2int8(p0[64 + 30] * scale); + pp[112 + 7] = float2int8(p0[64 + 31] * scale); + pp[112 + 8] = float2int8(p0[64 + 38] * scale); + pp[112 + 9] = float2int8(p0[64 + 39] * scale); + pp[112 + 10] = float2int8(p0[64 + 46] * scale); + pp[112 + 11] = float2int8(p0[64 + 47] * scale); + pp[112 + 12] = float2int8(p0[64 + 54] * scale); + pp[112 + 13] = float2int8(p0[64 + 55] * scale); + pp[112 + 14] = float2int8(p0[64 + 62] * scale); + pp[112 + 15] = float2int8(p0[64 + 63] * scale); + + pp += 128; p0 += B_hstep * 8; } } -#endif // __AVX__ if (elempack == 4) { int kk = 0; @@ -2078,24 +4373,58 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int pp[14] = float2int8(p0[28] * scale); pp[15] = float2int8(p0[29] * scale); - pp[16 + 0] = float2int8(p0[2] * scale); - pp[16 + 1] = float2int8(p0[3] * scale); - pp[16 + 2] = float2int8(p0[6] * scale); - pp[16 + 3] = float2int8(p0[7] * 
scale); - pp[16 + 4] = float2int8(p0[10] * scale); - pp[16 + 5] = float2int8(p0[11] * scale); - pp[16 + 6] = float2int8(p0[14] * scale); - pp[16 + 7] = float2int8(p0[15] * scale); - pp[16 + 8] = float2int8(p0[18] * scale); - pp[16 + 9] = float2int8(p0[19] * scale); - pp[16 + 10] = float2int8(p0[22] * scale); - pp[16 + 11] = float2int8(p0[23] * scale); - pp[16 + 12] = float2int8(p0[26] * scale); - pp[16 + 13] = float2int8(p0[27] * scale); - pp[16 + 14] = float2int8(p0[30] * scale); - pp[16 + 15] = float2int8(p0[31] * scale); - - pp += 32; + pp[16 + 0] = float2int8(p0[32 + 0] * scale); + pp[16 + 1] = float2int8(p0[32 + 1] * scale); + pp[16 + 2] = float2int8(p0[32 + 4] * scale); + pp[16 + 3] = float2int8(p0[32 + 5] * scale); + pp[16 + 4] = float2int8(p0[32 + 8] * scale); + pp[16 + 5] = float2int8(p0[32 + 9] * scale); + pp[16 + 6] = float2int8(p0[32 + 12] * scale); + pp[16 + 7] = float2int8(p0[32 + 13] * scale); + pp[16 + 8] = float2int8(p0[32 + 16] * scale); + pp[16 + 9] = float2int8(p0[32 + 17] * scale); + pp[16 + 10] = float2int8(p0[32 + 20] * scale); + pp[16 + 11] = float2int8(p0[32 + 21] * scale); + pp[16 + 12] = float2int8(p0[32 + 24] * scale); + pp[16 + 13] = float2int8(p0[32 + 25] * scale); + pp[16 + 14] = float2int8(p0[32 + 28] * scale); + pp[16 + 15] = float2int8(p0[32 + 29] * scale); + + pp[32 + 0] = float2int8(p0[2] * scale); + pp[32 + 1] = float2int8(p0[3] * scale); + pp[32 + 2] = float2int8(p0[6] * scale); + pp[32 + 3] = float2int8(p0[7] * scale); + pp[32 + 4] = float2int8(p0[10] * scale); + pp[32 + 5] = float2int8(p0[11] * scale); + pp[32 + 6] = float2int8(p0[14] * scale); + pp[32 + 7] = float2int8(p0[15] * scale); + pp[32 + 8] = float2int8(p0[18] * scale); + pp[32 + 9] = float2int8(p0[19] * scale); + pp[32 + 10] = float2int8(p0[22] * scale); + pp[32 + 11] = float2int8(p0[23] * scale); + pp[32 + 12] = float2int8(p0[26] * scale); + pp[32 + 13] = float2int8(p0[27] * scale); + pp[32 + 14] = float2int8(p0[30] * scale); + pp[32 + 15] = float2int8(p0[31] * 
scale); + + pp[48 + 0] = float2int8(p0[32 + 2] * scale); + pp[48 + 1] = float2int8(p0[32 + 3] * scale); + pp[48 + 2] = float2int8(p0[32 + 6] * scale); + pp[48 + 3] = float2int8(p0[32 + 7] * scale); + pp[48 + 4] = float2int8(p0[32 + 10] * scale); + pp[48 + 5] = float2int8(p0[32 + 11] * scale); + pp[48 + 6] = float2int8(p0[32 + 14] * scale); + pp[48 + 7] = float2int8(p0[32 + 15] * scale); + pp[48 + 8] = float2int8(p0[32 + 18] * scale); + pp[48 + 9] = float2int8(p0[32 + 19] * scale); + pp[48 + 10] = float2int8(p0[32 + 22] * scale); + pp[48 + 11] = float2int8(p0[32 + 23] * scale); + pp[48 + 12] = float2int8(p0[32 + 26] * scale); + pp[48 + 13] = float2int8(p0[32 + 27] * scale); + pp[48 + 14] = float2int8(p0[32 + 30] * scale); + pp[48 + 15] = float2int8(p0[32 + 31] * scale); + + pp += 64; p0 += B_hstep * 4; } } @@ -2121,7 +4450,23 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int pp[14] = float2int8(p0[7] * scale); pp[15] = float2int8(p0[B_hstep + 7] * scale); - pp += 16; + pp[16 + 0] = float2int8(p0[8] * scale); + pp[16 + 1] = float2int8(p0[B_hstep + 8] * scale); + pp[16 + 2] = float2int8(p0[9] * scale); + pp[16 + 3] = float2int8(p0[B_hstep + 9] * scale); + pp[16 + 4] = float2int8(p0[10] * scale); + pp[16 + 5] = float2int8(p0[B_hstep + 10] * scale); + pp[16 + 6] = float2int8(p0[11] * scale); + pp[16 + 7] = float2int8(p0[B_hstep + 11] * scale); + pp[16 + 8] = float2int8(p0[12] * scale); + pp[16 + 9] = float2int8(p0[B_hstep + 12] * scale); + pp[16 + 10] = float2int8(p0[13] * scale); + pp[16 + 11] = float2int8(p0[B_hstep + 13] * scale); + pp[16 + 12] = float2int8(p0[14] * scale); + pp[16 + 13] = float2int8(p0[B_hstep + 14] * scale); + pp[16 + 14] = float2int8(p0[15] * scale); + pp[16 + 15] = float2int8(p0[B_hstep + 15] * scale); + pp += 32; p0 += B_hstep * 2; } for (; kk < max_kk; kk++) @@ -2134,7 +4479,332 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int pp[5] = float2int8(p0[5] * scale); pp[6] = 
float2int8(p0[6] * scale); pp[7] = float2int8(p0[7] * scale); - pp += 8; + pp[8] = float2int8(p0[8] * scale); + pp[9] = float2int8(p0[9] * scale); + pp[10] = float2int8(p0[10] * scale); + pp[11] = float2int8(p0[11] * scale); + pp[12] = float2int8(p0[12] * scale); + pp[13] = float2int8(p0[13] * scale); + pp[14] = float2int8(p0[14] * scale); + pp[15] = float2int8(p0[15] * scale); + pp += 16; + p0 += B_hstep; + } + } + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * elempack; + +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp[2] = float2int8(p0[16] * scale); + pp[3] = float2int8(p0[17] * scale); + pp[4] = float2int8(p0[32] * scale); + pp[5] = float2int8(p0[33] * scale); + pp[6] = float2int8(p0[48] * scale); + pp[7] = float2int8(p0[49] * scale); + pp[8] = float2int8(p0[64] * scale); + pp[9] = float2int8(p0[65] * scale); + pp[10] = float2int8(p0[80] * scale); + pp[11] = float2int8(p0[81] * scale); + pp[12] = float2int8(p0[96] * scale); + pp[13] = float2int8(p0[97] * scale); + pp[14] = float2int8(p0[112] * scale); + pp[15] = float2int8(p0[113] * scale); + + pp[16 + 0] = float2int8(p0[2 + 0] * scale); + pp[16 + 1] = float2int8(p0[2 + 1] * scale); + pp[16 + 2] = float2int8(p0[2 + 16] * scale); + pp[16 + 3] = float2int8(p0[2 + 17] * scale); + pp[16 + 4] = float2int8(p0[2 + 32] * scale); + pp[16 + 5] = float2int8(p0[2 + 33] * scale); + pp[16 + 6] = float2int8(p0[2 + 48] * scale); + pp[16 + 7] = float2int8(p0[2 + 49] * scale); + pp[16 + 8] = float2int8(p0[2 + 64] * scale); + pp[16 + 9] = float2int8(p0[2 + 65] * scale); + pp[16 + 10] = float2int8(p0[2 + 80] * scale); + pp[16 + 11] = float2int8(p0[2 + 81] * scale); + pp[16 + 12] = float2int8(p0[2 + 96] * scale); + pp[16 + 13] = float2int8(p0[2 + 97] * scale); + pp[16 + 14] = float2int8(p0[2 + 112] * scale); + 
pp[16 + 15] = float2int8(p0[2 + 113] * scale); + + pp[32 + 0] = float2int8(p0[4 + 0] * scale); + pp[32 + 1] = float2int8(p0[4 + 1] * scale); + pp[32 + 2] = float2int8(p0[4 + 16] * scale); + pp[32 + 3] = float2int8(p0[4 + 17] * scale); + pp[32 + 4] = float2int8(p0[4 + 32] * scale); + pp[32 + 5] = float2int8(p0[4 + 33] * scale); + pp[32 + 6] = float2int8(p0[4 + 48] * scale); + pp[32 + 7] = float2int8(p0[4 + 49] * scale); + pp[32 + 8] = float2int8(p0[4 + 64] * scale); + pp[32 + 9] = float2int8(p0[4 + 65] * scale); + pp[32 + 10] = float2int8(p0[4 + 80] * scale); + pp[32 + 11] = float2int8(p0[4 + 81] * scale); + pp[32 + 12] = float2int8(p0[4 + 96] * scale); + pp[32 + 13] = float2int8(p0[4 + 97] * scale); + pp[32 + 14] = float2int8(p0[4 + 112] * scale); + pp[32 + 15] = float2int8(p0[4 + 113] * scale); + + pp[48 + 0] = float2int8(p0[6 + 0] * scale); + pp[48 + 1] = float2int8(p0[6 + 1] * scale); + pp[48 + 2] = float2int8(p0[6 + 16] * scale); + pp[48 + 3] = float2int8(p0[6 + 17] * scale); + pp[48 + 4] = float2int8(p0[6 + 32] * scale); + pp[48 + 5] = float2int8(p0[6 + 33] * scale); + pp[48 + 6] = float2int8(p0[6 + 48] * scale); + pp[48 + 7] = float2int8(p0[6 + 49] * scale); + pp[48 + 8] = float2int8(p0[6 + 64] * scale); + pp[48 + 9] = float2int8(p0[6 + 65] * scale); + pp[48 + 10] = float2int8(p0[6 + 80] * scale); + pp[48 + 11] = float2int8(p0[6 + 81] * scale); + pp[48 + 12] = float2int8(p0[6 + 96] * scale); + pp[48 + 13] = float2int8(p0[6 + 97] * scale); + pp[48 + 14] = float2int8(p0[6 + 112] * scale); + pp[48 + 15] = float2int8(p0[6 + 113] * scale); + + pp[64 + 0] = float2int8(p0[8 + 0] * scale); + pp[64 + 1] = float2int8(p0[8 + 1] * scale); + pp[64 + 2] = float2int8(p0[8 + 16] * scale); + pp[64 + 3] = float2int8(p0[8 + 17] * scale); + pp[64 + 4] = float2int8(p0[8 + 32] * scale); + pp[64 + 5] = float2int8(p0[8 + 33] * scale); + pp[64 + 6] = float2int8(p0[8 + 48] * scale); + pp[64 + 7] = float2int8(p0[8 + 49] * scale); + pp[64 + 8] = float2int8(p0[8 + 64] * scale); + pp[64 + 
9] = float2int8(p0[8 + 65] * scale); + pp[64 + 10] = float2int8(p0[8 + 80] * scale); + pp[64 + 11] = float2int8(p0[8 + 81] * scale); + pp[64 + 12] = float2int8(p0[8 + 96] * scale); + pp[64 + 13] = float2int8(p0[8 + 97] * scale); + pp[64 + 14] = float2int8(p0[8 + 112] * scale); + pp[64 + 15] = float2int8(p0[8 + 113] * scale); + + pp[80 + 0] = float2int8(p0[10 + 0] * scale); + pp[80 + 1] = float2int8(p0[10 + 1] * scale); + pp[80 + 2] = float2int8(p0[10 + 16] * scale); + pp[80 + 3] = float2int8(p0[10 + 17] * scale); + pp[80 + 4] = float2int8(p0[10 + 32] * scale); + pp[80 + 5] = float2int8(p0[10 + 33] * scale); + pp[80 + 6] = float2int8(p0[10 + 48] * scale); + pp[80 + 7] = float2int8(p0[10 + 49] * scale); + pp[80 + 8] = float2int8(p0[10 + 64] * scale); + pp[80 + 9] = float2int8(p0[10 + 65] * scale); + pp[80 + 10] = float2int8(p0[10 + 80] * scale); + pp[80 + 11] = float2int8(p0[10 + 81] * scale); + pp[80 + 12] = float2int8(p0[10 + 96] * scale); + pp[80 + 13] = float2int8(p0[10 + 97] * scale); + pp[80 + 14] = float2int8(p0[10 + 112] * scale); + pp[80 + 15] = float2int8(p0[10 + 113] * scale); + + pp[96 + 0] = float2int8(p0[12 + 0] * scale); + pp[96 + 1] = float2int8(p0[12 + 1] * scale); + pp[96 + 2] = float2int8(p0[12 + 16] * scale); + pp[96 + 3] = float2int8(p0[12 + 17] * scale); + pp[96 + 4] = float2int8(p0[12 + 32] * scale); + pp[96 + 5] = float2int8(p0[12 + 33] * scale); + pp[96 + 6] = float2int8(p0[12 + 48] * scale); + pp[96 + 7] = float2int8(p0[12 + 49] * scale); + pp[96 + 8] = float2int8(p0[12 + 64] * scale); + pp[96 + 9] = float2int8(p0[12 + 65] * scale); + pp[96 + 10] = float2int8(p0[12 + 80] * scale); + pp[96 + 11] = float2int8(p0[12 + 81] * scale); + pp[96 + 12] = float2int8(p0[12 + 96] * scale); + pp[96 + 13] = float2int8(p0[12 + 97] * scale); + pp[96 + 14] = float2int8(p0[12 + 112] * scale); + pp[96 + 15] = float2int8(p0[12 + 113] * scale); + + pp[112 + 0] = float2int8(p0[14 + 0] * scale); + pp[112 + 1] = float2int8(p0[14 + 1] * scale); + pp[112 + 2] = 
float2int8(p0[14 + 16] * scale); + pp[112 + 3] = float2int8(p0[14 + 17] * scale); + pp[112 + 4] = float2int8(p0[14 + 32] * scale); + pp[112 + 5] = float2int8(p0[14 + 33] * scale); + pp[112 + 6] = float2int8(p0[14 + 48] * scale); + pp[112 + 7] = float2int8(p0[14 + 49] * scale); + pp[112 + 8] = float2int8(p0[14 + 64] * scale); + pp[112 + 9] = float2int8(p0[14 + 65] * scale); + pp[112 + 10] = float2int8(p0[14 + 80] * scale); + pp[112 + 11] = float2int8(p0[14 + 81] * scale); + pp[112 + 12] = float2int8(p0[14 + 96] * scale); + pp[112 + 13] = float2int8(p0[14 + 97] * scale); + pp[112 + 14] = float2int8(p0[14 + 112] * scale); + pp[112 + 15] = float2int8(p0[14 + 113] * scale); + + pp += 128; + p0 += B_hstep * 16; + } + } +#endif // __AVX512F__ + if (elempack == 8) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp[2] = float2int8(p0[8] * scale); + pp[3] = float2int8(p0[9] * scale); + pp[4] = float2int8(p0[16] * scale); + pp[5] = float2int8(p0[17] * scale); + pp[6] = float2int8(p0[24] * scale); + pp[7] = float2int8(p0[25] * scale); + pp[8] = float2int8(p0[32] * scale); + pp[9] = float2int8(p0[33] * scale); + pp[10] = float2int8(p0[40] * scale); + pp[11] = float2int8(p0[41] * scale); + pp[12] = float2int8(p0[48] * scale); + pp[13] = float2int8(p0[49] * scale); + pp[14] = float2int8(p0[56] * scale); + pp[15] = float2int8(p0[57] * scale); + pp += 16; + + pp[0] = float2int8(p0[2] * scale); + pp[1] = float2int8(p0[3] * scale); + pp[2] = float2int8(p0[10] * scale); + pp[3] = float2int8(p0[11] * scale); + pp[4] = float2int8(p0[18] * scale); + pp[5] = float2int8(p0[19] * scale); + pp[6] = float2int8(p0[26] * scale); + pp[7] = float2int8(p0[27] * scale); + pp[8] = float2int8(p0[34] * scale); + pp[9] = float2int8(p0[35] * scale); + pp[10] = float2int8(p0[42] * scale); + pp[11] = float2int8(p0[43] * scale); + pp[12] = float2int8(p0[50] * scale); + pp[13] = float2int8(p0[51] * scale); + pp[14] = 
float2int8(p0[58] * scale); + pp[15] = float2int8(p0[59] * scale); + pp += 16; + + pp[0] = float2int8(p0[4] * scale); + pp[1] = float2int8(p0[5] * scale); + pp[2] = float2int8(p0[12] * scale); + pp[3] = float2int8(p0[13] * scale); + pp[4] = float2int8(p0[20] * scale); + pp[5] = float2int8(p0[21] * scale); + pp[6] = float2int8(p0[28] * scale); + pp[7] = float2int8(p0[29] * scale); + pp[8] = float2int8(p0[36] * scale); + pp[9] = float2int8(p0[37] * scale); + pp[10] = float2int8(p0[44] * scale); + pp[11] = float2int8(p0[45] * scale); + pp[12] = float2int8(p0[52] * scale); + pp[13] = float2int8(p0[53] * scale); + pp[14] = float2int8(p0[60] * scale); + pp[15] = float2int8(p0[61] * scale); + pp += 16; + + pp[0] = float2int8(p0[6] * scale); + pp[1] = float2int8(p0[7] * scale); + pp[2] = float2int8(p0[14] * scale); + pp[3] = float2int8(p0[15] * scale); + pp[4] = float2int8(p0[22] * scale); + pp[5] = float2int8(p0[23] * scale); + pp[6] = float2int8(p0[30] * scale); + pp[7] = float2int8(p0[31] * scale); + pp[8] = float2int8(p0[38] * scale); + pp[9] = float2int8(p0[39] * scale); + pp[10] = float2int8(p0[46] * scale); + pp[11] = float2int8(p0[47] * scale); + pp[12] = float2int8(p0[54] * scale); + pp[13] = float2int8(p0[55] * scale); + pp[14] = float2int8(p0[62] * scale); + pp[15] = float2int8(p0[63] * scale); + pp += 16; + + p0 += B_hstep * 8; + } + } +#endif // __AVX__ + if (elempack == 4) + { + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp[2] = float2int8(p0[4] * scale); + pp[3] = float2int8(p0[5] * scale); + pp[4] = float2int8(p0[8] * scale); + pp[5] = float2int8(p0[9] * scale); + pp[6] = float2int8(p0[12] * scale); + pp[7] = float2int8(p0[13] * scale); + pp[8] = float2int8(p0[16] * scale); + pp[9] = float2int8(p0[17] * scale); + pp[10] = float2int8(p0[20] * scale); + pp[11] = float2int8(p0[21] * scale); + pp[12] = float2int8(p0[24] * scale); + pp[13] = float2int8(p0[25] * scale); + pp[14] = 
float2int8(p0[28] * scale); + pp[15] = float2int8(p0[29] * scale); + + pp[16 + 0] = float2int8(p0[2] * scale); + pp[16 + 1] = float2int8(p0[3] * scale); + pp[16 + 2] = float2int8(p0[6] * scale); + pp[16 + 3] = float2int8(p0[7] * scale); + pp[16 + 4] = float2int8(p0[10] * scale); + pp[16 + 5] = float2int8(p0[11] * scale); + pp[16 + 6] = float2int8(p0[14] * scale); + pp[16 + 7] = float2int8(p0[15] * scale); + pp[16 + 8] = float2int8(p0[18] * scale); + pp[16 + 9] = float2int8(p0[19] * scale); + pp[16 + 10] = float2int8(p0[22] * scale); + pp[16 + 11] = float2int8(p0[23] * scale); + pp[16 + 12] = float2int8(p0[26] * scale); + pp[16 + 13] = float2int8(p0[27] * scale); + pp[16 + 14] = float2int8(p0[30] * scale); + pp[16 + 15] = float2int8(p0[31] * scale); + + pp += 32; + p0 += B_hstep * 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[B_hstep] * scale); + pp[2] = float2int8(p0[1] * scale); + pp[3] = float2int8(p0[B_hstep + 1] * scale); + pp[4] = float2int8(p0[2] * scale); + pp[5] = float2int8(p0[B_hstep + 2] * scale); + pp[6] = float2int8(p0[3] * scale); + pp[7] = float2int8(p0[B_hstep + 3] * scale); + pp[8] = float2int8(p0[4] * scale); + pp[9] = float2int8(p0[B_hstep + 4] * scale); + pp[10] = float2int8(p0[5] * scale); + pp[11] = float2int8(p0[B_hstep + 5] * scale); + pp[12] = float2int8(p0[6] * scale); + pp[13] = float2int8(p0[B_hstep + 6] * scale); + pp[14] = float2int8(p0[7] * scale); + pp[15] = float2int8(p0[B_hstep + 7] * scale); + + pp += 16; + p0 += B_hstep * 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp[2] = float2int8(p0[2] * scale); + pp[3] = float2int8(p0[3] * scale); + pp[4] = float2int8(p0[4] * scale); + pp[5] = float2int8(p0[5] * scale); + pp[6] = float2int8(p0[6] * scale); + pp[7] = float2int8(p0[7] * scale); + pp += 8; p0 += B_hstep; } } @@ -2145,6 +4815,89 @@ static void 
transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int const float* p0 = (const float*)B + k * B_hstep + (j + jj) * elempack; #if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp[2] = float2int8(p0[16] * scale); + pp[3] = float2int8(p0[17] * scale); + pp[4] = float2int8(p0[32] * scale); + pp[5] = float2int8(p0[33] * scale); + pp[6] = float2int8(p0[48] * scale); + pp[7] = float2int8(p0[49] * scale); + + pp[8] = float2int8(p0[2 + 0] * scale); + pp[9] = float2int8(p0[2 + 1] * scale); + pp[10] = float2int8(p0[2 + 16] * scale); + pp[11] = float2int8(p0[2 + 17] * scale); + pp[12] = float2int8(p0[2 + 32] * scale); + pp[13] = float2int8(p0[2 + 33] * scale); + pp[14] = float2int8(p0[2 + 48] * scale); + pp[15] = float2int8(p0[2 + 49] * scale); + + pp[16 + 0] = float2int8(p0[4 + 0] * scale); + pp[16 + 1] = float2int8(p0[4 + 1] * scale); + pp[16 + 2] = float2int8(p0[4 + 16] * scale); + pp[16 + 3] = float2int8(p0[4 + 17] * scale); + pp[16 + 4] = float2int8(p0[4 + 32] * scale); + pp[16 + 5] = float2int8(p0[4 + 33] * scale); + pp[16 + 6] = float2int8(p0[4 + 48] * scale); + pp[16 + 7] = float2int8(p0[4 + 49] * scale); + + pp[16 + 8] = float2int8(p0[6 + 0] * scale); + pp[16 + 9] = float2int8(p0[6 + 1] * scale); + pp[16 + 10] = float2int8(p0[6 + 16] * scale); + pp[16 + 11] = float2int8(p0[6 + 17] * scale); + pp[16 + 12] = float2int8(p0[6 + 32] * scale); + pp[16 + 13] = float2int8(p0[6 + 33] * scale); + pp[16 + 14] = float2int8(p0[6 + 48] * scale); + pp[16 + 15] = float2int8(p0[6 + 49] * scale); + + pp[32 + 0] = float2int8(p0[8 + 0] * scale); + pp[32 + 1] = float2int8(p0[8 + 1] * scale); + pp[32 + 2] = float2int8(p0[8 + 16] * scale); + pp[32 + 3] = float2int8(p0[8 + 17] * scale); + pp[32 + 4] = float2int8(p0[8 + 32] * scale); + pp[32 + 5] = float2int8(p0[8 + 33] * scale); + pp[32 + 6] = float2int8(p0[8 + 48] * scale); + pp[32 + 7] = 
float2int8(p0[8 + 49] * scale); + + pp[32 + 8] = float2int8(p0[10 + 0] * scale); + pp[32 + 9] = float2int8(p0[10 + 1] * scale); + pp[32 + 10] = float2int8(p0[10 + 16] * scale); + pp[32 + 11] = float2int8(p0[10 + 17] * scale); + pp[32 + 12] = float2int8(p0[10 + 32] * scale); + pp[32 + 13] = float2int8(p0[10 + 33] * scale); + pp[32 + 14] = float2int8(p0[10 + 48] * scale); + pp[32 + 15] = float2int8(p0[10 + 49] * scale); + + pp[48 + 0] = float2int8(p0[12 + 0] * scale); + pp[48 + 1] = float2int8(p0[12 + 1] * scale); + pp[48 + 2] = float2int8(p0[12 + 16] * scale); + pp[48 + 3] = float2int8(p0[12 + 17] * scale); + pp[48 + 4] = float2int8(p0[12 + 32] * scale); + pp[48 + 5] = float2int8(p0[12 + 33] * scale); + pp[48 + 6] = float2int8(p0[12 + 48] * scale); + pp[48 + 7] = float2int8(p0[12 + 49] * scale); + + pp[48 + 8] = float2int8(p0[14 + 0] * scale); + pp[48 + 9] = float2int8(p0[14 + 1] * scale); + pp[48 + 10] = float2int8(p0[14 + 16] * scale); + pp[48 + 11] = float2int8(p0[14 + 17] * scale); + pp[48 + 12] = float2int8(p0[14 + 32] * scale); + pp[48 + 13] = float2int8(p0[14 + 33] * scale); + pp[48 + 14] = float2int8(p0[14 + 48] * scale); + pp[48 + 15] = float2int8(p0[14 + 49] * scale); + + pp += 64; + p0 += B_hstep * 16; + } + } +#endif // __AVX512F__ if (elempack == 8) { int kk = 0; @@ -2252,6 +5005,57 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int #if __SSE2__ #if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp[2] = float2int8(p0[16] * scale); + pp[3] = float2int8(p0[17] * scale); + + pp[4] = float2int8(p0[2] * scale); + pp[5] = float2int8(p0[3] * scale); + pp[6] = float2int8(p0[18] * scale); + pp[7] = float2int8(p0[19] * scale); + + pp[8] = float2int8(p0[4] * scale); + pp[9] = float2int8(p0[5] * scale); + pp[10] = float2int8(p0[20] * scale); + pp[11] = float2int8(p0[21] * scale); + + pp[12] = 
float2int8(p0[6] * scale); + pp[13] = float2int8(p0[7] * scale); + pp[14] = float2int8(p0[22] * scale); + pp[15] = float2int8(p0[23] * scale); + + pp[16 + 0] = float2int8(p0[8] * scale); + pp[16 + 1] = float2int8(p0[9] * scale); + pp[16 + 2] = float2int8(p0[24] * scale); + pp[16 + 3] = float2int8(p0[25] * scale); + + pp[16 + 4] = float2int8(p0[10] * scale); + pp[16 + 5] = float2int8(p0[11] * scale); + pp[16 + 6] = float2int8(p0[26] * scale); + pp[16 + 7] = float2int8(p0[27] * scale); + + pp[16 + 8] = float2int8(p0[12] * scale); + pp[16 + 9] = float2int8(p0[13] * scale); + pp[16 + 10] = float2int8(p0[28] * scale); + pp[16 + 11] = float2int8(p0[29] * scale); + + pp[16 + 12] = float2int8(p0[14] * scale); + pp[16 + 13] = float2int8(p0[15] * scale); + pp[16 + 14] = float2int8(p0[30] * scale); + pp[16 + 15] = float2int8(p0[31] * scale); + + pp += 32; + p0 += B_hstep * 16; + } + } +#endif // __AVX512F__ if (elempack == 8) { int kk = 0; @@ -2327,6 +5131,33 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int #if __SSE2__ #if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp[2] = float2int8(p0[2] * scale); + pp[3] = float2int8(p0[3] * scale); + pp[4] = float2int8(p0[4] * scale); + pp[5] = float2int8(p0[5] * scale); + pp[6] = float2int8(p0[6] * scale); + pp[7] = float2int8(p0[7] * scale); + pp[8] = float2int8(p0[8] * scale); + pp[9] = float2int8(p0[9] * scale); + pp[10] = float2int8(p0[10] * scale); + pp[11] = float2int8(p0[11] * scale); + pp[12] = float2int8(p0[12] * scale); + pp[13] = float2int8(p0[13] * scale); + pp[14] = float2int8(p0[14] * scale); + pp[15] = float2int8(p0[15] * scale); + pp += 16; + p0 += B_hstep * 16; + } + } +#endif // __AVX512F__ if (elempack == 8) { int kk = 0; @@ -2343,59 +5174,1865 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int pp += 8; p0 += 
B_hstep * 8; } - } -#endif // __AVX__ - if (elempack == 4) - { - int kk = 0; - for (; kk + 3 < max_kk; kk += 4) + } +#endif // __AVX__ + if (elempack == 4) + { + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp[2] = float2int8(p0[2] * scale); + pp[3] = float2int8(p0[3] * scale); + pp += 4; + p0 += B_hstep * 4; + } + } +#endif // __SSE2__ + if (elempack == 1) + { + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp += 1; + p0 += B_hstep; + } + } + } +} + +static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta, int output_transpose) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx2()) + { + unpack_output_tile_int32_to_fp32_avx2(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta, output_transpose); + return; + } +#endif + + const int out_elempack = top_blob.elempack; + const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; + + const int c_hstep = C.dims == 3 ? 
(int)C.cstep : C.w; + const int c_elempack = C.elempack; + const float* pC = C; + + NCNN_LOGE("unpack_output_tile_int32_to_fp32 %d %d %d %d %d %d %d %d", i, max_ii, j, max_jj, out_elempack, broadcast_type_C, c_elempack, output_transpose); + + // const int* pp = topT; + + int ii = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { + const int* pp = (const int*)topT + ii * max_jj; + + float* p0; + if (output_transpose) + { + p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; + } + else + { + p0 = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack; + } + + __m512 _descale = _mm512_loadu_ps((const float*)descales + i + ii); + + __m512 _c0; + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = _mm512_set1_ps(pC[0] * beta); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + _c0 = _mm512_loadu_ps(pC); + _c0 = _mm512_mul_ps(_c0, _mm512_set1_ps(beta)); + } + if (broadcast_type_C == 3) + { + pC = (const float*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) + for (; jj + 15 < max_jj; jj += 16) + { + __m512 _f0 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)pp)); + __m512 _f1 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 16))); + __m512 _f2 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 32))); + __m512 _f3 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 48))); + __m512 _f4 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 64))); + __m512 _f5 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 80))); + __m512 _f6 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 96))); + __m512 _f7 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 112))); + __m512 _f8 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 128))); + __m512 _f9 = 
_mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 128 + 16))); + __m512 _fa = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 128 + 32))); + __m512 _fb = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 128 + 48))); + __m512 _fc = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 128 + 64))); + __m512 _fd = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 128 + 80))); + __m512 _fe = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 128 + 96))); + __m512 _ff = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 128 + 112))); + pp += 256; + + // from + + // 00 11 22 33 44 55 66 77 88 99 aa bb cc dd ee ff + // 01 12 23 30 45 56 67 74 89 9a ab b8 cd de ef fc + // 20 31 02 13 64 75 46 57 a8 b9 8a 9b ec fd ce df + // 21 32 03 10 65 76 47 54 a9 ba 8b 98 ed fe cf dc + // 08 19 2a 3b 4c 5d 6e 7f 80 91 a2 b3 c4 d5 e6 f7 + // 09 1a 2b 38 4d 5e 6f 7c 81 92 a3 b0 c5 d6 e7 f4 + // 28 39 0a 1b 6c 7d 4e 5f a0 b1 82 93 e4 f5 c6 d7 + // 29 3a 0b 18 6d 7e 4f 5c a1 b2 83 90 e5 f6 c7 d4 + // 40 51 62 73 04 15 26 37 c8 d9 ea fb 8c 9d ae bf + // 41 52 63 70 05 16 27 34 c9 da eb f8 8d 9e af bc + // 60 71 42 53 24 35 06 17 e8 f9 ca db ac bd 8e 9f + // 61 72 43 50 25 36 07 14 e9 fa cb d8 ad be 8f 9c + // 48 59 6a 7b 0c 1d 2e 3f c0 d1 e2 f3 84 95 a6 b7 + // 49 5a 6b 78 0d 1e 2f 3c c1 d2 e3 f0 85 96 a7 b4 + // 68 79 4a 5b 2c 3d 0e 1f e0 f1 c2 d3 a4 b5 86 97 + // 69 7a 4b 58 2d 3e 0f 1c e1 f2 c3 d0 a5 b6 87 94 + + // _f0 = _mm512_setr_ps(0x00,0x11,0x22,0x33,0x44,0x55,0x66,0x77,0x88,0x99,0xaa,0xbb,0xcc,0xdd,0xee,0xff); + // _f1 = _mm512_setr_ps(0x01,0x12,0x23,0x30,0x45,0x56,0x67,0x74,0x89,0x9a,0xab,0xb8,0xcd,0xde,0xef,0xfc); + // _f2 = _mm512_setr_ps(0x20,0x31,0x02,0x13,0x64,0x75,0x46,0x57,0xa8,0xb9,0x8a,0x9b,0xec,0xfd,0xce,0xdf); + // _f3 = _mm512_setr_ps(0x21,0x32,0x03,0x10,0x65,0x76,0x47,0x54,0xa9,0xba,0x8b,0x98,0xed,0xfe,0xcf,0xdc); + // _f4 = 
_mm512_setr_ps(0x08,0x19,0x2a,0x3b,0x4c,0x5d,0x6e,0x7f,0x80,0x91,0xa2,0xb3,0xc4,0xd5,0xe6,0xf7); + // _f5 = _mm512_setr_ps(0x09,0x1a,0x2b,0x38,0x4d,0x5e,0x6f,0x7c,0x81,0x92,0xa3,0xb0,0xc5,0xd6,0xe7,0xf4); + // _f6 = _mm512_setr_ps(0x28,0x39,0x0a,0x1b,0x6c,0x7d,0x4e,0x5f,0xa0,0xb1,0x82,0x93,0xe4,0xf5,0xc6,0xd7); + // _f7 = _mm512_setr_ps(0x29,0x3a,0x0b,0x18,0x6d,0x7e,0x4f,0x5c,0xa1,0xb2,0x83,0x90,0xe5,0xf6,0xc7,0xd4); + // _f8 = _mm512_setr_ps(0x40,0x51,0x62,0x73,0x04,0x15,0x26,0x37,0xc8,0xd9,0xea,0xfb,0x8c,0x9d,0xae,0xbf); + // _f9 = _mm512_setr_ps(0x41,0x52,0x63,0x70,0x05,0x16,0x27,0x34,0xc9,0xda,0xeb,0xf8,0x8d,0x9e,0xaf,0xbc); + // _fa = _mm512_setr_ps(0x60,0x71,0x42,0x53,0x24,0x35,0x06,0x17,0xe8,0xf9,0xca,0xdb,0xac,0xbd,0x8e,0x9f); + // _fb = _mm512_setr_ps(0x61,0x72,0x43,0x50,0x25,0x36,0x07,0x14,0xe9,0xfa,0xcb,0xd8,0xad,0xbe,0x8f,0x9c); + // _fc = _mm512_setr_ps(0x48,0x59,0x6a,0x7b,0x0c,0x1d,0x2e,0x3f,0xc0,0xd1,0xe2,0xf3,0x84,0x95,0xa6,0xb7); + // _fd = _mm512_setr_ps(0x49,0x5a,0x6b,0x78,0x0d,0x1e,0x2f,0x3c,0xc1,0xd2,0xe3,0xf0,0x85,0x96,0xa7,0xb4); + // _fe = _mm512_setr_ps(0x68,0x79,0x4a,0x5b,0x2c,0x3d,0x0e,0x1f,0xe0,0xf1,0xc2,0xd3,0xa4,0xb5,0x86,0x97); + // _ff = _mm512_setr_ps(0x69,0x7a,0x4b,0x58,0x2d,0x3e,0x0f,0x1c,0xe1,0xf2,0xc3,0xd0,0xa5,0xb6,0x87,0x94); + + // print(_f0); + // print(_f1); + // print(_f2); + // print(_f3); + // print(_f4); + // print(_f5); + // print(_f6); + // print(_f7); + // print(_f8); + // print(_f9); + // print(_fa); + // print(_fb); + // print(_fc); + // print(_fd); + // print(_fe); + // print(_ff); + + // to + + // 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 + // 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 + // 02 12 22 32 42 52 62 72 82 92 a2 b2 c2 d2 e2 f2 + // 03 13 23 33 43 53 63 73 83 93 a3 b3 c3 d3 e3 f3 + // 04 14 24 34 44 54 64 74 84 94 a4 b4 c4 d4 e4 f4 + // 05 15 25 35 45 55 65 75 85 95 a5 b5 c5 d5 e5 f5 + // 06 16 26 36 46 56 66 76 86 96 a6 b6 c6 d6 e6 f6 + // 07 17 27 37 47 57 67 77 87 97 a7 b7 c7 d7 e7 f7 + // 
08 18 28 38 48 58 68 78 88 98 a8 b8 c8 d8 e8 f8 + // 09 19 29 39 49 59 69 79 89 99 a9 b9 c9 d9 e9 f9 + // 0a 1a 2a 3a 4a 5a 6a 7a 8a 9a aa ba ca da ea fa + // 0b 1b 2b 3b 4b 5b 6b 7b 8b 9b ab bb cb db eb fb + // 0c 1c 2c 3c 4c 5c 6c 7c 8c 9c ac bc cc dc ec fc + // 0d 1d 2d 3d 4d 5d 6d 7d 8d 9d ad bd cd dd ed fd + // 0e 1e 2e 3e 4e 5e 6e 7e 8e 9e ae be ce de ee fe + // 0f 1f 2f 3f 4f 5f 6f 7f 8f 9f af bf cf df ef ff + + { + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + _f5 = _mm512_permute_ps(_f5, _MM_SHUFFLE(2, 1, 0, 3)); + _f7 = _mm512_permute_ps(_f7, _MM_SHUFFLE(2, 1, 0, 3)); + _f9 = _mm512_permute_ps(_f9, _MM_SHUFFLE(2, 1, 0, 3)); + _fb = _mm512_permute_ps(_fb, _MM_SHUFFLE(2, 1, 0, 3)); + _fd = _mm512_permute_ps(_fd, _MM_SHUFFLE(2, 1, 0, 3)); + _ff = _mm512_permute_ps(_ff, _MM_SHUFFLE(2, 1, 0, 3)); + + // 00 11 22 33 44 55 66 77 88 99 aa bb cc dd ee ff + // 30 01 12 23 74 45 56 67 b8 89 9a ab fc cd de ef + // 20 31 02 13 64 75 46 57 a8 b9 8a 9b ec fd ce df + // 10 21 32 03 54 65 76 47 98 a9 ba 8b dc ed fe cf + + // 08 19 2a 3b 4c 5d 6e 7f 80 91 a2 b3 c4 d5 e6 f7 + // 38 09 1a 2b 7c 4d 5e 6f b0 81 92 a3 f4 c5 d6 e7 + // 28 39 0a 1b 6c 7d 4e 5f a0 b1 82 93 e4 f5 c6 d7 + // 18 29 3a 0b 5c 6d 7e 4f 90 a1 b2 83 d4 e5 f6 c7 + + // 40 51 62 73 04 15 26 37 c8 d9 ea fb 8c 9d ae bf + // 70 41 52 63 34 05 16 27 f8 c9 da eb bc 8d 9e af + // 60 71 42 53 24 35 06 17 e8 f9 ca db ac bd 8e 9f + // 50 61 72 43 14 25 36 07 d8 e9 fa cb 9c ad be 8f + + // 48 59 6a 7b 0c 1d 2e 3f c0 d1 e2 f3 84 95 a6 b7 + // 78 49 5a 6b 3c 0d 1e 2f f0 c1 d2 e3 b4 85 96 a7 + // 68 79 4a 5b 2c 3d 0e 1f e0 f1 c2 d3 a4 b5 86 97 + // 58 69 7a 4b 1c 2d 3e 0f d0 e1 f2 c3 94 a5 b6 87 + + __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f3); + __m512 _tmp1 = _mm512_unpackhi_ps(_f0, _f3); + __m512 _tmp2 = _mm512_unpacklo_ps(_f2, _f1); + __m512 _tmp3 = _mm512_unpackhi_ps(_f2, _f1); + __m512 _tmp4 = _mm512_unpacklo_ps(_f4, _f7); + __m512 _tmp5 = 
_mm512_unpackhi_ps(_f4, _f7); + __m512 _tmp6 = _mm512_unpacklo_ps(_f6, _f5); + __m512 _tmp7 = _mm512_unpackhi_ps(_f6, _f5); + __m512 _tmp8 = _mm512_unpacklo_ps(_f8, _fb); + __m512 _tmp9 = _mm512_unpackhi_ps(_f8, _fb); + __m512 _tmpa = _mm512_unpacklo_ps(_fa, _f9); + __m512 _tmpb = _mm512_unpackhi_ps(_fa, _f9); + __m512 _tmpc = _mm512_unpacklo_ps(_fc, _ff); + __m512 _tmpd = _mm512_unpackhi_ps(_fc, _ff); + __m512 _tmpe = _mm512_unpacklo_ps(_fe, _fd); + __m512 _tmpf = _mm512_unpackhi_ps(_fe, _fd); + + // 00 10 11 21 44 54 55 65 88 98 99 a9 cc dc dd ed + // 22 32 33 03 66 76 77 47 aa ba bb 8b ee fe ff cf + // 20 30 31 01 64 74 75 45 a8 b8 b9 89 ec fc fd cd + // 02 12 13 23 46 56 57 67 8a 9a 9b ab ce de df ef + + _f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f1 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f2 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f3 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f4 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp6))); + _f5 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp6))); + _f6 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp7), _mm512_castps_pd(_tmp5))); + _f7 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp7), _mm512_castps_pd(_tmp5))); + _f8 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp8), _mm512_castps_pd(_tmpa))); + _f9 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp8), _mm512_castps_pd(_tmpa))); + _fa = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmpb), _mm512_castps_pd(_tmp9))); + _fb = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmpb), _mm512_castps_pd(_tmp9))); + _fc = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmpc), _mm512_castps_pd(_tmpe))); + _fd = 
_mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmpc), _mm512_castps_pd(_tmpe))); + _fe = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmpf), _mm512_castps_pd(_tmpd))); + _ff = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmpf), _mm512_castps_pd(_tmpd))); + + // 00 10 20 30 44 54 64 74 88 98 a8 b8 cc dc ec fc + // 11 21 31 01 55 65 75 45 99 a9 b9 89 dd ed fd cd + // 02 12 22 32 46 56 66 76 8a 9a aa ba ce de ee fe + // 13 23 33 03 57 67 77 47 9b ab bb 8b df ef ff cf + + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + _f5 = _mm512_permute_ps(_f5, _MM_SHUFFLE(2, 1, 0, 3)); + _f7 = _mm512_permute_ps(_f7, _MM_SHUFFLE(2, 1, 0, 3)); + _f9 = _mm512_permute_ps(_f9, _MM_SHUFFLE(2, 1, 0, 3)); + _fb = _mm512_permute_ps(_fb, _MM_SHUFFLE(2, 1, 0, 3)); + _fd = _mm512_permute_ps(_fd, _MM_SHUFFLE(2, 1, 0, 3)); + _ff = _mm512_permute_ps(_ff, _MM_SHUFFLE(2, 1, 0, 3)); + + // 00 10 20 30 44 54 64 74 88 98 a8 b8 cc dc ec fc + // 01 11 21 31 45 55 65 75 89 99 a9 b9 cd dd ed fd + // 02 12 22 32 46 56 66 76 8a 9a aa ba ce de ee fe + // 03 13 23 33 47 57 67 77 8b 9b ab bb cf df ef ff + + // 08 18 28 38 4c 5c 6c 7c 80 90 a0 b0 c4 d4 e4 f4 + // 09 19 29 39 4d 5d 6d 7d 81 91 a1 b1 c5 d5 e5 f5 + // 0a 1a 2a 3a 4e 5e 6e 7e 82 92 a2 b2 c6 d6 e6 f6 + // 0b 1b 2b 3b 4f 5f 6f 7f 83 93 a3 b3 c7 d7 e7 f7 + + // 40 50 60 70 04 14 24 34 c8 d8 e8 f8 8c 9c ac bc + // 41 51 61 71 05 15 25 35 c9 d9 e9 f9 8d 9d ad bd + // 42 52 62 72 06 16 26 36 ca da ea fa 8e 9e ae be + // 43 53 63 73 07 17 27 37 cb db eb fb 8f 9f af bf + + // 48 58 68 78 0c 1c 2c 3c c0 d0 e0 f0 84 94 a4 b4 + // 49 59 69 79 0d 1d 2d 3d c1 d1 e1 f1 85 95 a5 b5 + // 4a 5a 6a 7a 0e 1e 2e 3e c2 d2 e2 f2 86 96 a6 b6 + // 4b 5b 6b 7b 0f 1f 2f 3f c3 d3 e3 f3 87 97 a7 b7 + + // NCNN_LOGE("--------"); + // print(_f0); + // print(_f1); + // print(_f2); + // print(_f3); + // print(_f4); + // print(_f5); + // print(_f6); + // print(_f7); + // print(_f8); + // 
print(_f9); + // print(_fa); + // print(_fb); + // print(_fc); + // print(_fd); + // print(_fe); + // print(_ff); + + _tmp0 = _mm512_shuffle_f32x4(_f0, _f8, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp1 = _mm512_shuffle_f32x4(_f1, _f9, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp2 = _mm512_shuffle_f32x4(_f2, _fa, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp3 = _mm512_shuffle_f32x4(_f3, _fb, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp4 = _mm512_shuffle_f32x4(_f8, _f0, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp5 = _mm512_shuffle_f32x4(_f9, _f1, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp6 = _mm512_shuffle_f32x4(_fa, _f2, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp7 = _mm512_shuffle_f32x4(_fb, _f3, _MM_SHUFFLE(3, 1, 3, 1)); + _tmp8 = _mm512_shuffle_f32x4(_f4, _fc, _MM_SHUFFLE(2, 0, 2, 0)); + _tmp9 = _mm512_shuffle_f32x4(_f5, _fd, _MM_SHUFFLE(2, 0, 2, 0)); + _tmpa = _mm512_shuffle_f32x4(_f6, _fe, _MM_SHUFFLE(2, 0, 2, 0)); + _tmpb = _mm512_shuffle_f32x4(_f7, _ff, _MM_SHUFFLE(2, 0, 2, 0)); + _tmpc = _mm512_shuffle_f32x4(_fc, _f4, _MM_SHUFFLE(3, 1, 3, 1)); + _tmpd = _mm512_shuffle_f32x4(_fd, _f5, _MM_SHUFFLE(3, 1, 3, 1)); + _tmpe = _mm512_shuffle_f32x4(_fe, _f6, _MM_SHUFFLE(3, 1, 3, 1)); + _tmpf = _mm512_shuffle_f32x4(_ff, _f7, _MM_SHUFFLE(3, 1, 3, 1)); + + // 00 10 20 30 88 98 a8 b8 40 50 60 70 c8 d8 e8 f8 + // 01 11 21 31 89 99 a9 b9 41 51 61 71 c9 d9 e9 f9 + // 02 12 22 32 8a 9a aa ba 42 52 62 72 ca da ea fa + // 03 13 23 33 8b 9b ab bb 43 53 63 73 cb db eb fb + // 04 14 24 34 8c 9c ac bc 44 54 64 74 cc dc ec fc + // 05 15 25 35 8d 9d ad bd 45 55 65 75 cd dd ed fd + // 06 16 26 36 8e 9e ae be 46 56 66 76 ce de ee fe + // 07 17 27 37 8f 9f af bf 47 57 67 77 cf df ef ff + + // 08 18 28 38 80 90 a0 b0 48 58 68 78 c0 d0 e0 f0 + // 09 19 29 39 81 91 a1 b1 49 59 69 79 c1 d1 e1 f1 + // 0a 1a 2a 3a 82 92 a2 b2 4a 5a 6a 7a c2 d2 e2 f2 + // 0b 1b 2b 3b 83 93 a3 b3 4b 5b 6b 7b c3 d3 e3 f3 + // 0c 1c 2c 3c 84 94 a4 b4 4c 5c 6c 7c c4 d4 e4 f4 + // 0d 1d 2d 3d 85 95 a5 b5 4d 5d 6d 7d c5 d5 e5 f5 + // 0e 1e 2e 3e 86 96 a6 b6 4e 5e 6e 7e c6 d6 e6 f6 + // 0f 1f 
2f 3f 87 97 a7 b7 4f 5f 6f 7f c7 d7 e7 f7 + + // NCNN_LOGE("--------"); + // print(_tmp0); + // print(_tmp1); + // print(_tmp2); + // print(_tmp3); + // print(_tmp4); + // print(_tmp5); + // print(_tmp6); + // print(_tmp7); + // print(_tmp8); + // print(_tmp9); + // print(_tmpa); + // print(_tmpb); + // print(_tmpc); + // print(_tmpd); + // print(_tmpe); + // print(_tmpf); + + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp8, _MM_SHUFFLE(3, 1, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp1, _tmp9, _MM_SHUFFLE(3, 1, 2, 0)); + _f2 = _mm512_shuffle_f32x4(_tmp2, _tmpa, _MM_SHUFFLE(3, 1, 2, 0)); + _f3 = _mm512_shuffle_f32x4(_tmp3, _tmpb, _MM_SHUFFLE(3, 1, 2, 0)); + _f4 = _mm512_shuffle_f32x4(_tmp4, _tmpc, _MM_SHUFFLE(3, 1, 2, 0)); + _f5 = _mm512_shuffle_f32x4(_tmp5, _tmpd, _MM_SHUFFLE(3, 1, 2, 0)); + _f6 = _mm512_shuffle_f32x4(_tmp6, _tmpe, _MM_SHUFFLE(3, 1, 2, 0)); + _f7 = _mm512_shuffle_f32x4(_tmp7, _tmpf, _MM_SHUFFLE(3, 1, 2, 0)); + _f8 = _mm512_shuffle_f32x4(_tmp8, _tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _f9 = _mm512_shuffle_f32x4(_tmp9, _tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + _fa = _mm512_shuffle_f32x4(_tmpa, _tmp2, _MM_SHUFFLE(3, 1, 2, 0)); + _fb = _mm512_shuffle_f32x4(_tmpb, _tmp3, _MM_SHUFFLE(3, 1, 2, 0)); + _fc = _mm512_shuffle_f32x4(_tmpc, _tmp4, _MM_SHUFFLE(3, 1, 2, 0)); + _fd = _mm512_shuffle_f32x4(_tmpd, _tmp5, _MM_SHUFFLE(3, 1, 2, 0)); + _fe = _mm512_shuffle_f32x4(_tmpe, _tmp6, _MM_SHUFFLE(3, 1, 2, 0)); + _ff = _mm512_shuffle_f32x4(_tmpf, _tmp7, _MM_SHUFFLE(3, 1, 2, 0)); + } + + // NCNN_LOGE("--------"); + // print(_f0); + // print(_f1); + // print(_f2); + // print(_f3); + // print(_f4); + // print(_f5); + // print(_f6); + // print(_f7); + // print(_f8); + // print(_f9); + // print(_fa); + // print(_fb); + // print(_fc); + // print(_fd); + // print(_fe); + // print(_ff); + + + + _f0 = _mm512_mul_ps(_f0, _descale); + _f1 = _mm512_mul_ps(_f1, _descale); + _f2 = _mm512_mul_ps(_f2, _descale); + _f3 = _mm512_mul_ps(_f3, _descale); + _f4 = _mm512_mul_ps(_f4, _descale); + _f5 = 
_mm512_mul_ps(_f5, _descale); + _f6 = _mm512_mul_ps(_f6, _descale); + _f7 = _mm512_mul_ps(_f7, _descale); + _f8 = _mm512_mul_ps(_f8, _descale); + _f9 = _mm512_mul_ps(_f9, _descale); + _fa = _mm512_mul_ps(_fa, _descale); + _fb = _mm512_mul_ps(_fb, _descale); + _fc = _mm512_mul_ps(_fc, _descale); + _fd = _mm512_mul_ps(_fd, _descale); + _fe = _mm512_mul_ps(_fe, _descale); + _ff = _mm512_mul_ps(_ff, _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + _f2 = _mm512_add_ps(_f2, _c0); + _f3 = _mm512_add_ps(_f3, _c0); + _f4 = _mm512_add_ps(_f4, _c0); + _f5 = _mm512_add_ps(_f5, _c0); + _f6 = _mm512_add_ps(_f6, _c0); + _f7 = _mm512_add_ps(_f7, _c0); + _f8 = _mm512_add_ps(_f8, _c0); + _f9 = _mm512_add_ps(_f9, _c0); + _fa = _mm512_add_ps(_fa, _c0); + _fb = _mm512_add_ps(_fb, _c0); + _fc = _mm512_add_ps(_fc, _c0); + _fd = _mm512_add_ps(_fd, _c0); + _fe = _mm512_add_ps(_fe, _c0); + _ff = _mm512_add_ps(_ff, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + _f2 = _mm512_add_ps(_f2, _c0); + _f3 = _mm512_add_ps(_f3, _c0); + _f4 = _mm512_add_ps(_f4, _c0); + _f5 = _mm512_add_ps(_f5, _c0); + _f6 = _mm512_add_ps(_f6, _c0); + _f7 = _mm512_add_ps(_f7, _c0); + _f8 = _mm512_add_ps(_f8, _c0); + _f9 = _mm512_add_ps(_f9, _c0); + _fa = _mm512_add_ps(_fa, _c0); + _fb = _mm512_add_ps(_fb, _c0); + _fc = _mm512_add_ps(_fc, _c0); + _fd = _mm512_add_ps(_fd, _c0); + _fe = _mm512_add_ps(_fe, _c0); + _ff = _mm512_add_ps(_ff, _c0); + } + if (broadcast_type_C == 3) + { + // TODO + // __m512 _c1; + // __m512 _c2; + // __m512 _c3; + // __m512 _c4; + // __m512 _c5; + // __m512 _c6; + // __m512 _c7; + // if (c_elempack == 8) + // { + // _c0 = _mm256_loadu_ps(pC); + // _c1 = _mm256_loadu_ps(pC + 8); + // _c2 = _mm256_loadu_ps(pC + 16); + // _c3 = _mm256_loadu_ps(pC + 24); + // _c4 = _mm256_loadu_ps(pC + 32); + // _c5 = _mm256_loadu_ps(pC + 40); + // 
_c6 = _mm256_loadu_ps(pC + 48); + // _c7 = _mm256_loadu_ps(pC + 56); + // pC += 64; + // } + // if (c_elempack == 4) + // { + // __m256 _tmp0 = _mm256_loadu_ps(pC); + // __m256 _tmp1 = _mm256_loadu_ps(pC + 8); + // __m256 _tmp2 = _mm256_loadu_ps(pC + 16); + // __m256 _tmp3 = _mm256_loadu_ps(pC + 24); + // __m256 _tmp4 = _mm256_loadu_ps(pC + c_hstep * 4); + // __m256 _tmp5 = _mm256_loadu_ps(pC + c_hstep * 4 + 8); + // __m256 _tmp6 = _mm256_loadu_ps(pC + c_hstep * 4 + 16); + // __m256 _tmp7 = _mm256_loadu_ps(pC + c_hstep * 4 + 24); + // _c0 = _mm256_permute2f128_ps(_tmp0, _tmp4, _MM_SHUFFLE(0, 2, 0, 0)); + // _c1 = _mm256_permute2f128_ps(_tmp0, _tmp4, _MM_SHUFFLE(0, 3, 0, 1)); + // _c2 = _mm256_permute2f128_ps(_tmp1, _tmp5, _MM_SHUFFLE(0, 2, 0, 0)); + // _c3 = _mm256_permute2f128_ps(_tmp1, _tmp5, _MM_SHUFFLE(0, 3, 0, 1)); + // _c4 = _mm256_permute2f128_ps(_tmp2, _tmp6, _MM_SHUFFLE(0, 2, 0, 0)); + // _c5 = _mm256_permute2f128_ps(_tmp2, _tmp6, _MM_SHUFFLE(0, 3, 0, 1)); + // _c6 = _mm256_permute2f128_ps(_tmp3, _tmp7, _MM_SHUFFLE(0, 2, 0, 0)); + // _c7 = _mm256_permute2f128_ps(_tmp3, _tmp7, _MM_SHUFFLE(0, 3, 0, 1)); + // pC += 32; + // } + // if (c_elempack == 1) + // { + // _c0 = _mm256_loadu_ps(pC); + // _c1 = _mm256_loadu_ps(pC + c_hstep); + // _c2 = _mm256_loadu_ps(pC + c_hstep * 2); + // _c3 = _mm256_loadu_ps(pC + c_hstep * 3); + // _c4 = _mm256_loadu_ps(pC + c_hstep * 4); + // _c5 = _mm256_loadu_ps(pC + c_hstep * 5); + // _c6 = _mm256_loadu_ps(pC + c_hstep * 6); + // _c7 = _mm256_loadu_ps(pC + c_hstep * 7); + // transpose8x8_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + // pC += 8; + // } + // if (beta == 1.f) + // { + // _f0 = _mm256_add_ps(_f0, _c0); + // _f1 = _mm256_add_ps(_f1, _c1); + // _f2 = _mm256_add_ps(_f2, _c2); + // _f3 = _mm256_add_ps(_f3, _c3); + // _f4 = _mm256_add_ps(_f4, _c4); + // _f5 = _mm256_add_ps(_f5, _c5); + // _f6 = _mm256_add_ps(_f6, _c6); + // _f7 = _mm256_add_ps(_f7, _c7); + // } + // else + // { + // __m256 _beta = _mm256_set1_ps(beta); 
+ // _f0 = _mm256_comp_fmadd_ps(_c0, _beta, _f0); + // _f1 = _mm256_comp_fmadd_ps(_c1, _beta, _f1); + // _f2 = _mm256_comp_fmadd_ps(_c2, _beta, _f2); + // _f3 = _mm256_comp_fmadd_ps(_c3, _beta, _f3); + // _f4 = _mm256_comp_fmadd_ps(_c4, _beta, _f4); + // _f5 = _mm256_comp_fmadd_ps(_c5, _beta, _f5); + // _f6 = _mm256_comp_fmadd_ps(_c6, _beta, _f6); + // _f7 = _mm256_comp_fmadd_ps(_c7, _beta, _f7); + // } + } + if (broadcast_type_C == 4) + { + // TODO + // _c0 = _mm256_set1_ps(pC[0] * beta); + // __m256 _c1 = _mm256_set1_ps(pC[1] * beta); + // __m256 _c2 = _mm256_set1_ps(pC[2] * beta); + // __m256 _c3 = _mm256_set1_ps(pC[3] * beta); + // + // _f0 = _mm256_add_ps(_f0, _c0); + // _f1 = _mm256_add_ps(_f1, _c1); + // _f2 = _mm256_add_ps(_f2, _c2); + // _f3 = _mm256_add_ps(_f3, _c3); + // + // _c0 = _mm256_set1_ps(pC[4] * beta); + // _c1 = _mm256_set1_ps(pC[5] * beta); + // _c2 = _mm256_set1_ps(pC[6] * beta); + // _c3 = _mm256_set1_ps(pC[7] * beta); + // + // _f4 = _mm256_add_ps(_f4, _c0); + // _f5 = _mm256_add_ps(_f5, _c1); + // _f6 = _mm256_add_ps(_f6, _c2); + // _f7 = _mm256_add_ps(_f7, _c3); + // pC += 8; + } + } + + if (alpha != 1.f) + { + __m512 _alpha = _mm512_set1_ps(alpha); + _f0 = _mm512_mul_ps(_f0, _alpha); + _f1 = _mm512_mul_ps(_f1, _alpha); + _f2 = _mm512_mul_ps(_f2, _alpha); + _f3 = _mm512_mul_ps(_f3, _alpha); + _f4 = _mm512_mul_ps(_f4, _alpha); + _f5 = _mm512_mul_ps(_f5, _alpha); + _f6 = _mm512_mul_ps(_f6, _alpha); + _f7 = _mm512_mul_ps(_f7, _alpha); + _f8 = _mm512_mul_ps(_f8, _alpha); + _f9 = _mm512_mul_ps(_f9, _alpha); + _fa = _mm512_mul_ps(_fa, _alpha); + _fb = _mm512_mul_ps(_fb, _alpha); + _fc = _mm512_mul_ps(_fc, _alpha); + _fd = _mm512_mul_ps(_fd, _alpha); + _fe = _mm512_mul_ps(_fe, _alpha); + _ff = _mm512_mul_ps(_ff, _alpha); + } + + if (output_transpose) + { + // 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 + // 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 + // 02 12 22 32 42 52 62 72 82 92 a2 b2 c2 d2 e2 f2 + // 03 13 23 33 43 53 63 73 83 93 
a3 b3 c3 d3 e3 f3 + // 04 14 24 34 44 54 64 74 84 94 a4 b4 c4 d4 e4 f4 + // 05 15 25 35 45 55 65 75 85 95 a5 b5 c5 d5 e5 f5 + // 06 16 26 36 46 56 66 76 86 96 a6 b6 c6 d6 e6 f6 + // 07 17 27 37 47 57 67 77 87 97 a7 b7 c7 d7 e7 f7 + // 08 18 28 38 48 58 68 78 88 98 a8 b8 c8 d8 e8 f8 + // 09 19 29 39 49 59 69 79 89 99 a9 b9 c9 d9 e9 f9 + // 0a 1a 2a 3a 4a 5a 6a 7a 8a 9a aa ba ca da ea fa + // 0b 1b 2b 3b 4b 5b 6b 7b 8b 9b ab bb cb db eb fb + // 0c 1c 2c 3c 4c 5c 6c 7c 8c 9c ac bc cc dc ec fc + // 0d 1d 2d 3d 4d 5d 6d 7d 8d 9d ad bd cd dd ed fd + // 0e 1e 2e 3e 4e 5e 6e 7e 8e 9e ae be ce de ee fe + // 0f 1f 2f 3f 4f 5f 6f 7f 8f 9f af bf cf df ef ff + + if (out_elempack == 16) + { + transpose16x16_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7, _f8, _f9, _fa, _fb, _fc, _fd, _fe, _ff); + + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + 16 * 2, _f2); + _mm512_storeu_ps(p0 + 16 * 3, _f3); + _mm512_storeu_ps(p0 + 16 * 4, _f4); + _mm512_storeu_ps(p0 + 16 * 5, _f5); + _mm512_storeu_ps(p0 + 16 * 6, _f6); + _mm512_storeu_ps(p0 + 16 * 7, _f7); + _mm512_storeu_ps(p0 + 16 * 8, _f8); + _mm512_storeu_ps(p0 + 16 * 9, _f9); + _mm512_storeu_ps(p0 + 16 * 10, _fa); + _mm512_storeu_ps(p0 + 16 * 11, _fb); + _mm512_storeu_ps(p0 + 16 * 12, _fc); + _mm512_storeu_ps(p0 + 16 * 13, _fd); + _mm512_storeu_ps(p0 + 16 * 14, _fe); + _mm512_storeu_ps(p0 + 16 * 15, _ff); + } + if (out_elempack == 8) + { + transpose16x8_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7); + transpose16x8_ps(_f8, _f9, _fa, _fb, _fc, _fd, _fe, _ff); + + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + 16 * 2, _f2); + _mm512_storeu_ps(p0 + 16 * 3, _f3); + _mm512_storeu_ps(p0 + 16 * 4, _f4); + _mm512_storeu_ps(p0 + 16 * 5, _f5); + _mm512_storeu_ps(p0 + 16 * 6, _f6); + _mm512_storeu_ps(p0 + 16 * 7, _f7); + _mm512_storeu_ps(p0 + out_hstep * 8, _f8); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16, _f9); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16 * 2, _fa); + 
_mm512_storeu_ps(p0 + out_hstep * 8 + 16 * 3, _fb); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16 * 4, _fc); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16 * 5, _fd); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16 * 6, _fe); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16 * 7, _ff); + } + if (out_elempack == 4) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + transpose16x4_ps(_f4, _f5, _f6, _f7); + transpose16x4_ps(_f8, _f9, _fa, _fb); + transpose16x4_ps(_fc, _fd, _fe, _ff); + + // 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 + // 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 + // 02 12 22 32 42 52 62 72 82 92 a2 b2 c2 d2 e2 f2 + // 03 13 23 33 43 53 63 73 83 93 a3 b3 c3 d3 e3 f3 + + // 04 14 24 34 44 54 64 74 84 94 a4 b4 c4 d4 e4 f4 + // 05 15 25 35 45 55 65 75 85 95 a5 b5 c5 d5 e5 f5 + // 06 16 26 36 46 56 66 76 86 96 a6 b6 c6 d6 e6 f6 + // 07 17 27 37 47 57 67 77 87 97 a7 b7 c7 d7 e7 f7 + + // 08 18 28 38 48 58 68 78 88 98 a8 b8 c8 d8 e8 f8 + // 09 19 29 39 49 59 69 79 89 99 a9 b9 c9 d9 e9 f9 + // 0a 1a 2a 3a 4a 5a 6a 7a 8a 9a aa ba ca da ea fa + // 0b 1b 2b 3b 4b 5b 6b 7b 8b 9b ab bb cb db eb fb + + // 0c 1c 2c 3c 4c 5c 6c 7c 8c 9c ac bc cc dc ec fc + // 0d 1d 2d 3d 4d 5d 6d 7d 8d 9d ad bd cd dd ed fd + // 0e 1e 2e 3e 4e 5e 6e 7e 8e 9e ae be ce de ee fe + // 0f 1f 2f 3f 4f 5f 6f 7f 8f 9f af bf cf df ef ff + + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + 32, _f2); + _mm512_storeu_ps(p0 + 48, _f3); + _mm512_storeu_ps(p0 + out_hstep * 4, _f4); + _mm512_storeu_ps(p0 + out_hstep * 4 + 16, _f5); + _mm512_storeu_ps(p0 + out_hstep * 4 + 32, _f6); + _mm512_storeu_ps(p0 + out_hstep * 4 + 48, _f7); + _mm512_storeu_ps(p0 + out_hstep * 8, _f8); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16, _f9); + _mm512_storeu_ps(p0 + out_hstep * 8 + 32, _fa); + _mm512_storeu_ps(p0 + out_hstep * 8 + 48, _fb); + _mm512_storeu_ps(p0 + out_hstep * 12, _fc); + _mm512_storeu_ps(p0 + out_hstep * 12 + 16, _fd); + _mm512_storeu_ps(p0 + out_hstep * 12 + 32, _fe); + 
_mm512_storeu_ps(p0 + out_hstep * 12 + 48, _ff); + } + if (out_elempack == 1) + { + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + out_hstep, _f1); + _mm512_storeu_ps(p0 + out_hstep * 2, _f2); + _mm512_storeu_ps(p0 + out_hstep * 3, _f3); + _mm512_storeu_ps(p0 + out_hstep * 4, _f4); + _mm512_storeu_ps(p0 + out_hstep * 5, _f5); + _mm512_storeu_ps(p0 + out_hstep * 6, _f6); + _mm512_storeu_ps(p0 + out_hstep * 7, _f7); + _mm512_storeu_ps(p0 + out_hstep * 8, _f8); + _mm512_storeu_ps(p0 + out_hstep * 9, _f9); + _mm512_storeu_ps(p0 + out_hstep * 10, _fa); + _mm512_storeu_ps(p0 + out_hstep * 11, _fb); + _mm512_storeu_ps(p0 + out_hstep * 12, _fc); + _mm512_storeu_ps(p0 + out_hstep * 13, _fd); + _mm512_storeu_ps(p0 + out_hstep * 14, _fe); + _mm512_storeu_ps(p0 + out_hstep * 15, _ff); + } + p0 += out_hstep * 16; + } + else + { + if (out_elempack == 16) + { + _mm512_store_ps(p0, _f0); + _mm512_store_ps(p0 + 16, _f1); + _mm512_store_ps(p0 + 32, _f2); + _mm512_store_ps(p0 + 48, _f3); + _mm512_store_ps(p0 + 64, _f4); + _mm512_store_ps(p0 + 80, _f5); + _mm512_store_ps(p0 + 96, _f6); + _mm512_store_ps(p0 + 112, _f7); + _mm512_store_ps(p0 + 128, _f8); + _mm512_store_ps(p0 + 128 + 16, _f9); + _mm512_store_ps(p0 + 128 + 32, _fa); + _mm512_store_ps(p0 + 128 + 48, _fb); + _mm512_store_ps(p0 + 128 + 64, _fc); + _mm512_store_ps(p0 + 128 + 80, _fd); + _mm512_store_ps(p0 + 128 + 96, _fe); + _mm512_store_ps(p0 + 128 + 112, _ff); + p0 += 256; + } + if (out_elempack == 8) + { + // 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 + // 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 + // 02 12 22 32 42 52 62 72 82 92 a2 b2 c2 d2 e2 f2 + // 03 13 23 33 43 53 63 73 83 93 a3 b3 c3 d3 e3 f3 + // 04 14 24 34 44 54 64 74 84 94 a4 b4 c4 d4 e4 f4 + // 05 15 25 35 45 55 65 75 85 95 a5 b5 c5 d5 e5 f5 + // 06 16 26 36 46 56 66 76 86 96 a6 b6 c6 d6 e6 f6 + // 07 17 27 37 47 57 67 77 87 97 a7 b7 c7 d7 e7 f7 + // 08 18 28 38 48 58 68 78 88 98 a8 b8 c8 d8 e8 f8 + // 09 19 29 39 49 59 69 79 89 99 a9 b9 c9 d9 
e9 f9 + // 0a 1a 2a 3a 4a 5a 6a 7a 8a 9a aa ba ca da ea fa + // 0b 1b 2b 3b 4b 5b 6b 7b 8b 9b ab bb cb db eb fb + // 0c 1c 2c 3c 4c 5c 6c 7c 8c 9c ac bc cc dc ec fc + // 0d 1d 2d 3d 4d 5d 6d 7d 8d 9d ad bd cd dd ed fd + // 0e 1e 2e 3e 4e 5e 6e 7e 8e 9e ae be ce de ee fe + // 0f 1f 2f 3f 4f 5f 6f 7f 8f 9f af bf cf df ef ff + + _mm256_storeu_ps(p0, _mm512_extractf32x8_ps(_f0, 0)); + _mm256_storeu_ps(p0 + 8, _mm512_extractf32x8_ps(_f1, 0)); + _mm256_storeu_ps(p0 + 16, _mm512_extractf32x8_ps(_f2, 0)); + _mm256_storeu_ps(p0 + 24, _mm512_extractf32x8_ps(_f3, 0)); + _mm256_storeu_ps(p0 + 32, _mm512_extractf32x8_ps(_f4, 0)); + _mm256_storeu_ps(p0 + 40, _mm512_extractf32x8_ps(_f5, 0)); + _mm256_storeu_ps(p0 + 48, _mm512_extractf32x8_ps(_f6, 0)); + _mm256_storeu_ps(p0 + 56, _mm512_extractf32x8_ps(_f7, 0)); + _mm256_storeu_ps(p0 + 64, _mm512_extractf32x8_ps(_f8, 0)); + _mm256_storeu_ps(p0 + 64 + 8, _mm512_extractf32x8_ps(_f9, 0)); + _mm256_storeu_ps(p0 + 64 + 16, _mm512_extractf32x8_ps(_fa, 0)); + _mm256_storeu_ps(p0 + 64 + 24, _mm512_extractf32x8_ps(_fb, 0)); + _mm256_storeu_ps(p0 + 64 + 32, _mm512_extractf32x8_ps(_fc, 0)); + _mm256_storeu_ps(p0 + 64 + 40, _mm512_extractf32x8_ps(_fd, 0)); + _mm256_storeu_ps(p0 + 64 + 48, _mm512_extractf32x8_ps(_fe, 0)); + _mm256_storeu_ps(p0 + 64 + 56, _mm512_extractf32x8_ps(_ff, 0)); + _mm256_storeu_ps(p0 + out_hstep * 8, _mm512_extractf32x8_ps(_f0, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 8, _mm512_extractf32x8_ps(_f1, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 16, _mm512_extractf32x8_ps(_f2, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 24, _mm512_extractf32x8_ps(_f3, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 32, _mm512_extractf32x8_ps(_f4, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 40, _mm512_extractf32x8_ps(_f5, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 48, _mm512_extractf32x8_ps(_f6, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 56, _mm512_extractf32x8_ps(_f7, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 64, 
_mm512_extractf32x8_ps(_f8, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 64 + 8, _mm512_extractf32x8_ps(_f9, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 64 + 16, _mm512_extractf32x8_ps(_fa, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 64 + 24, _mm512_extractf32x8_ps(_fb, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 64 + 32, _mm512_extractf32x8_ps(_fc, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 64 + 40, _mm512_extractf32x8_ps(_fd, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 64 + 48, _mm512_extractf32x8_ps(_fe, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 64 + 56, _mm512_extractf32x8_ps(_ff, 1)); + p0 += 128; + } + if (out_elempack == 4) + { + // 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 + // 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 + // 02 12 22 32 42 52 62 72 82 92 a2 b2 c2 d2 e2 f2 + // 03 13 23 33 43 53 63 73 83 93 a3 b3 c3 d3 e3 f3 + // 04 14 24 34 44 54 64 74 84 94 a4 b4 c4 d4 e4 f4 + // 05 15 25 35 45 55 65 75 85 95 a5 b5 c5 d5 e5 f5 + // 06 16 26 36 46 56 66 76 86 96 a6 b6 c6 d6 e6 f6 + // 07 17 27 37 47 57 67 77 87 97 a7 b7 c7 d7 e7 f7 + // 08 18 28 38 48 58 68 78 88 98 a8 b8 c8 d8 e8 f8 + // 09 19 29 39 49 59 69 79 89 99 a9 b9 c9 d9 e9 f9 + // 0a 1a 2a 3a 4a 5a 6a 7a 8a 9a aa ba ca da ea fa + // 0b 1b 2b 3b 4b 5b 6b 7b 8b 9b ab bb cb db eb fb + // 0c 1c 2c 3c 4c 5c 6c 7c 8c 9c ac bc cc dc ec fc + // 0d 1d 2d 3d 4d 5d 6d 7d 8d 9d ad bd cd dd ed fd + // 0e 1e 2e 3e 4e 5e 6e 7e 8e 9e ae be ce de ee fe + // 0f 1f 2f 3f 4f 5f 6f 7f 8f 9f af bf cf df ef ff + + __m512 _tmp0 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_f8, _f9, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_fa, _fb, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_fc, _fd, 
_MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_fe, _ff, _MM_SHUFFLE(2, 0, 2, 0)); + + __m512 _tmp8 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp9 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmpa = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmpb = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmpc = _mm512_shuffle_f32x4(_f8, _f9, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmpd = _mm512_shuffle_f32x4(_fa, _fb, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmpe = _mm512_shuffle_f32x4(_fc, _fd, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmpf = _mm512_shuffle_f32x4(_fe, _ff, _MM_SHUFFLE(3, 1, 3, 1)); + + // 00 80 01 81 + // 02 82 03 83 + // 04 84 05 85 + // 06 86 07 87 + // 08 88 09 89 + // 0a 8a 0b 8b + // 0c 8c 0d 8d + // 0e 8e 0f 8f + + // 40 c0 41 c1 + // 42 c2 43 c3 + // 44 c4 45 c5 + // 46 c6 47 c7 + // 48 c8 49 c9 + // 4a ca 4b cb + // 4c cc 4d cd + // 4e ce 4f cf + + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _f2 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _f3 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _f4 = _mm512_shuffle_f32x4(_tmp8, _tmp9, _MM_SHUFFLE(2, 0, 2, 0)); + _f5 = _mm512_shuffle_f32x4(_tmpa, _tmpb, _MM_SHUFFLE(2, 0, 2, 0)); + _f6 = _mm512_shuffle_f32x4(_tmpc, _tmpd, _MM_SHUFFLE(2, 0, 2, 0)); + _f7 = _mm512_shuffle_f32x4(_tmpe, _tmpf, _MM_SHUFFLE(2, 0, 2, 0)); + + _f8 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _f9 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _fa = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _fb = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + _fc = _mm512_shuffle_f32x4(_tmp8, _tmp9, _MM_SHUFFLE(3, 1, 3, 1)); + _fd = _mm512_shuffle_f32x4(_tmpa, _tmpb, _MM_SHUFFLE(3, 1, 3, 1)); + _fe = _mm512_shuffle_f32x4(_tmpc, _tmpd, _MM_SHUFFLE(3, 
1, 3, 1)); + _ff = _mm512_shuffle_f32x4(_tmpe, _tmpf, _MM_SHUFFLE(3, 1, 3, 1)); + + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + 32, _f2); + _mm512_storeu_ps(p0 + 48, _f3); + _mm512_storeu_ps(p0 + out_hstep * 4, _f4); + _mm512_storeu_ps(p0 + out_hstep * 4 + 16, _f5); + _mm512_storeu_ps(p0 + out_hstep * 4 + 32, _f6); + _mm512_storeu_ps(p0 + out_hstep * 4 + 48, _f7); + _mm512_storeu_ps(p0 + out_hstep * 8, _f8); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16, _f9); + _mm512_storeu_ps(p0 + out_hstep * 8 + 32, _fa); + _mm512_storeu_ps(p0 + out_hstep * 8 + 48, _fb); + _mm512_storeu_ps(p0 + out_hstep * 12, _fc); + _mm512_storeu_ps(p0 + out_hstep * 12 + 16, _fd); + _mm512_storeu_ps(p0 + out_hstep * 12 + 32, _fe); + _mm512_storeu_ps(p0 + out_hstep * 12 + 48, _ff); + p0 += 64; + } + if (out_elempack == 1) + { + transpose16x16_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7, _f8, _f9, _fa, _fb, _fc, _fd, _fe, _ff); + + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + out_hstep, _f1); + _mm512_storeu_ps(p0 + out_hstep * 2, _f2); + _mm512_storeu_ps(p0 + out_hstep * 3, _f3); + _mm512_storeu_ps(p0 + out_hstep * 4, _f4); + _mm512_storeu_ps(p0 + out_hstep * 5, _f5); + _mm512_storeu_ps(p0 + out_hstep * 6, _f6); + _mm512_storeu_ps(p0 + out_hstep * 7, _f7); + _mm512_storeu_ps(p0 + out_hstep * 8, _f8); + _mm512_storeu_ps(p0 + out_hstep * 9, _f9); + _mm512_storeu_ps(p0 + out_hstep * 10, _fa); + _mm512_storeu_ps(p0 + out_hstep * 11, _fb); + _mm512_storeu_ps(p0 + out_hstep * 12, _fc); + _mm512_storeu_ps(p0 + out_hstep * 13, _fd); + _mm512_storeu_ps(p0 + out_hstep * 14, _fe); + _mm512_storeu_ps(p0 + out_hstep * 15, _ff); + p0 += 16; + } + } + } + for (; jj + 7 < max_jj; jj += 8) + { + __m512 _f0 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)pp)); + __m512 _f1 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 16))); + __m512 _f2 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 32))); + __m512 _f3 = 
_mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 48))); + __m512 _f4 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 64))); + __m512 _f5 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 80))); + __m512 _f6 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 96))); + __m512 _f7 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 112))); + pp += 128; + + + // from + // 00 11 22 33 44 55 66 77 80 91 a2 b3 c4 d5 e6 f7 + // 01 12 23 30 45 56 67 74 81 92 a3 b0 c5 d6 e7 f4 + // 20 31 02 13 64 75 46 57 a0 b1 82 93 e4 f5 c6 d7 + // 21 32 03 10 65 76 47 54 a1 b2 83 90 e5 f6 c7 d4 + // 04 15 26 37 40 51 62 73 84 95 a6 b7 c0 d1 e2 f3 + // 05 16 27 34 41 52 63 70 85 96 a7 b4 c1 d2 e3 f0 + // 24 35 06 17 60 71 42 53 a4 b5 86 97 e0 f1 c2 d3 + // 25 36 07 14 61 72 43 50 a5 b6 87 94 e1 f2 c3 d0 + // + // to + // 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 + // 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 + // 02 12 22 32 42 52 62 72 82 92 a2 b2 c2 d2 e2 f2 + // 03 13 23 33 43 53 63 73 83 93 a3 b3 c3 d3 e3 f3 + // 04 14 24 34 44 54 64 74 84 94 a4 b4 c4 d4 e4 f4 + // 05 15 25 35 45 55 65 75 85 95 a5 b5 c5 d5 e5 f5 + // 06 16 26 36 46 56 66 76 86 96 a6 b6 c6 d6 e6 f6 + // 07 17 27 37 47 57 67 77 87 97 a7 b7 c7 d7 e7 f7 + + // _f0 = _mm512_setr_ps(00,11,22,33,44,55,66,77,0x80,0x91,0xa2,0xb3,0xc4,0xd5,0xe6,0xf7); + // _f1 = _mm512_setr_ps(01,12,23,30,45,56,67,74,0x81,0x92,0xa3,0xb0,0xc5,0xd6,0xe7,0xf4); + // _f2 = _mm512_setr_ps(20,31,02,13,64,75,46,57,0xa0,0xb1,0x82,0x93,0xe4,0xf5,0xc6,0xd7); + // _f3 = _mm512_setr_ps(21,32,03,10,65,76,47,54,0xa1,0xb2,0x83,0x90,0xe5,0xf6,0xc7,0xd4); + // _f4 = _mm512_setr_ps(04,15,26,37,40,51,62,73,0x84,0x95,0xa6,0xb7,0xc0,0xd1,0xe2,0xf3); + // _f5 = _mm512_setr_ps(05,16,27,34,41,52,63,70,0x85,0x96,0xa7,0xb4,0xc1,0xd2,0xe3,0xf0); + // _f6 = _mm512_setr_ps(24,35,06,17,60,71,42,53,0xa4,0xb5,0x86,0x97,0xe0,0xf1,0xc2,0xd3); + // _f7 = 
_mm512_setr_ps(25,36,07,14,61,72,43,50,0xa5,0xb6,0x87,0x94,0xe1,0xf2,0xc3,0xd0); + + // print(_f0); + // print(_f1); + // print(_f2); + // print(_f3); + // print(_f4); + // print(_f5); + // print(_f6); + // print(_f7); + + { + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + _f5 = _mm512_permute_ps(_f5, _MM_SHUFFLE(2, 1, 0, 3)); + _f7 = _mm512_permute_ps(_f7, _MM_SHUFFLE(2, 1, 0, 3)); + + // 00 11 22 33 44 55 66 77 80 91 a2 b3 c4 d5 e6 f7 + // 30 01 12 23 74 45 56 67 b0 81 92 a3 f4 c5 d6 e7 + // 20 31 02 13 64 75 46 57 a0 b1 82 93 e4 f5 c6 d7 + // 10 21 32 03 54 65 76 47 90 a1 b2 83 d4 e5 f6 c7 + + // 04 15 26 37 40 51 62 73 84 95 a6 b7 c0 d1 e2 f3 + // 34 05 16 27 70 41 52 63 b4 85 96 a7 f0 c1 d2 e3 + // 24 35 06 17 60 71 42 53 a4 b5 86 97 e0 f1 c2 d3 + // 14 25 36 07 50 61 72 43 94 a5 b6 87 d0 e1 f2 c3 + + __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f3); + __m512 _tmp1 = _mm512_unpackhi_ps(_f0, _f3); + __m512 _tmp2 = _mm512_unpacklo_ps(_f2, _f1); + __m512 _tmp3 = _mm512_unpackhi_ps(_f2, _f1); + __m512 _tmp4 = _mm512_unpacklo_ps(_f4, _f7); + __m512 _tmp5 = _mm512_unpackhi_ps(_f4, _f7); + __m512 _tmp6 = _mm512_unpacklo_ps(_f6, _f5); + __m512 _tmp7 = _mm512_unpackhi_ps(_f6, _f5); + + // 00 10 11 21 44 54 55 65 80 90 91 a1 c4 d4 d5 e5 + // 22 32 33 03 66 76 77 47 a2 b2 b3 83 e6 f6 f7 c7 + // 20 30 31 01 64 74 75 45 a0 b0 b1 81 e4 f4 f5 c5 + // 02 12 13 23 46 56 57 67 82 92 93 a3 c6 d6 d7 e7 + + // 04 14 15 25 40 50 51 61 + // 26 36 37 07 62 72 73 43 + // 24 34 35 05 60 70 71 41 + // 06 16 17 27 42 52 53 63 + + + _f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f1 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f2 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f3 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f4 = 
_mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp6))); + _f5 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp6))); + _f6 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp7), _mm512_castps_pd(_tmp5))); + _f7 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp7), _mm512_castps_pd(_tmp5))); + + // 00 10 20 30 44 54 64 74 80 90 a0 b0 c4 d4 e4 f4 + // 11 21 31 01 55 65 75 45 91 a1 b1 81 d5 e5 f5 c5 + // 02 12 22 32 46 56 66 76 82 92 a2 b2 c6 d6 e6 f6 + // 13 23 33 03 57 67 77 47 93 a3 b3 83 d7 e7 f7 c7 + + // 04 14 24 34 40 50 60 70 + // 15 25 35 05 51 61 71 41 + // 06 16 26 36 42 52 62 72 + // 17 27 37 07 53 63 73 43 + + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + _f5 = _mm512_permute_ps(_f5, _MM_SHUFFLE(2, 1, 0, 3)); + _f7 = _mm512_permute_ps(_f7, _MM_SHUFFLE(2, 1, 0, 3)); + + // 00 10 20 30 44 54 64 74 80 90 a0 b0 c4 d4 e4 f4 + // 01 11 21 31 45 55 65 75 81 91 a1 b1 c5 d5 e5 f5 + // 02 12 22 32 46 56 66 76 82 92 a2 b2 c6 d6 e6 f6 + // 03 13 23 33 47 57 67 77 83 93 a3 b3 c7 d7 e7 f7 + + // 04 14 24 34 40 50 60 70 84 94 a4 b4 c0 d0 e0 f0 + // 05 15 25 35 41 51 61 71 85 95 a5 b5 c1 d1 e1 f1 + // 06 16 26 36 42 52 62 72 86 96 a6 b6 c2 d2 e2 f2 + // 07 17 27 37 43 53 63 73 87 97 a7 b7 c3 d3 e3 f3 + + _tmp0 = _mm512_shuffle_f32x4(_f0, _f4, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp1 = _mm512_shuffle_f32x4(_f1, _f5, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp2 = _mm512_shuffle_f32x4(_f2, _f6, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp3 = _mm512_shuffle_f32x4(_f3, _f7, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp4 = _mm512_shuffle_f32x4(_f0, _f4, _MM_SHUFFLE(2, 3, 3, 2)); + _tmp5 = _mm512_shuffle_f32x4(_f1, _f5, _MM_SHUFFLE(2, 3, 3, 2)); + _tmp6 = _mm512_shuffle_f32x4(_f2, _f6, _MM_SHUFFLE(2, 3, 3, 2)); + _tmp7 = _mm512_shuffle_f32x4(_f3, _f7, _MM_SHUFFLE(2, 3, 3, 2)); + + // 00 10 20 30 44 54 64 74 40 50 60 70 04 14 24 34 + // 01 11 21 31 45 55 65 75 41 51 61 71 
05 15 25 35 + // 02 12 22 32 46 56 66 76 42 52 62 72 06 16 26 36 + // 03 13 23 33 47 57 67 77 43 53 63 73 07 17 27 37 + + // 80 90 a0 b0 c4 d4 e4 f4 c0 d0 e0 f0 84 94 a4 b4 + // 81 91 a1 b1 c5 d5 e5 f5 c1 d1 e1 f1 85 95 a5 b5 + // 82 92 a2 b2 c6 d6 e6 f6 c2 d2 e2 f2 86 96 a6 b6 + // 83 93 a3 b3 c7 d7 e7 f7 c3 d3 e3 f3 87 97 a7 b7 + + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp4, _MM_SHUFFLE(2, 0, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp1, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _f2 = _mm512_shuffle_f32x4(_tmp2, _tmp6, _MM_SHUFFLE(2, 0, 2, 0)); + _f3 = _mm512_shuffle_f32x4(_tmp3, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _f4 = _mm512_shuffle_f32x4(_tmp0, _tmp4, _MM_SHUFFLE(1, 3, 1, 3)); + _f5 = _mm512_shuffle_f32x4(_tmp1, _tmp5, _MM_SHUFFLE(1, 3, 1, 3)); + _f6 = _mm512_shuffle_f32x4(_tmp2, _tmp6, _MM_SHUFFLE(1, 3, 1, 3)); + _f7 = _mm512_shuffle_f32x4(_tmp3, _tmp7, _MM_SHUFFLE(1, 3, 1, 3)); + } + + // NCNN_LOGE("-------"); + // + // print(_f0); + // print(_f1); + // print(_f2); + // print(_f3); + // print(_f4); + // print(_f5); + // print(_f6); + // print(_f7); + + _f0 = _mm512_mul_ps(_f0, _descale); + _f1 = _mm512_mul_ps(_f1, _descale); + _f2 = _mm512_mul_ps(_f2, _descale); + _f3 = _mm512_mul_ps(_f3, _descale); + _f4 = _mm512_mul_ps(_f4, _descale); + _f5 = _mm512_mul_ps(_f5, _descale); + _f6 = _mm512_mul_ps(_f6, _descale); + _f7 = _mm512_mul_ps(_f7, _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + _f2 = _mm512_add_ps(_f2, _c0); + _f3 = _mm512_add_ps(_f3, _c0); + _f4 = _mm512_add_ps(_f4, _c0); + _f5 = _mm512_add_ps(_f5, _c0); + _f6 = _mm512_add_ps(_f6, _c0); + _f7 = _mm512_add_ps(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + _f2 = _mm512_add_ps(_f2, _c0); + _f3 = _mm512_add_ps(_f3, _c0); + _f4 = _mm512_add_ps(_f4, _c0); + _f5 = _mm512_add_ps(_f5, _c0); + _f6 = _mm512_add_ps(_f6, _c0); + _f7 = _mm512_add_ps(_f7, 
_c0); + } + if (broadcast_type_C == 3) + { + __m512 _c1; + __m512 _c2; + __m512 _c3; + __m512 _c4; + __m512 _c5; + __m512 _c6; + __m512 _c7; + if (c_elempack == 16) + { + _c0 = _mm512_loadu_ps(pC); + _c1 = _mm512_loadu_ps(pC + 16); + _c2 = _mm512_loadu_ps(pC + 32); + _c3 = _mm512_loadu_ps(pC + 48); + _c4 = _mm512_loadu_ps(pC + 64); + _c5 = _mm512_loadu_ps(pC + 80); + _c6 = _mm512_loadu_ps(pC + 96); + _c7 = _mm512_loadu_ps(pC + 112); + pC += 128; + } + // if (c_elempack == 8) + // { + // _c0 = _mm256_loadu_ps(pC); + // _c1 = _mm256_loadu_ps(pC + 8); + // _c2 = _mm256_loadu_ps(pC + 16); + // _c3 = _mm256_loadu_ps(pC + 24); + // _c4 = _mm256_loadu_ps(pC + 32); + // _c5 = _mm256_loadu_ps(pC + 40); + // _c6 = _mm256_loadu_ps(pC + 48); + // _c7 = _mm256_loadu_ps(pC + 56); + // pC += 64; + // } + // if (c_elempack == 4) + // { + // __m256 _tmp0 = _mm256_loadu_ps(pC); + // __m256 _tmp1 = _mm256_loadu_ps(pC + 8); + // __m256 _tmp2 = _mm256_loadu_ps(pC + 16); + // __m256 _tmp3 = _mm256_loadu_ps(pC + 24); + // __m256 _tmp4 = _mm256_loadu_ps(pC + c_hstep * 4); + // __m256 _tmp5 = _mm256_loadu_ps(pC + c_hstep * 4 + 8); + // __m256 _tmp6 = _mm256_loadu_ps(pC + c_hstep * 4 + 16); + // __m256 _tmp7 = _mm256_loadu_ps(pC + c_hstep * 4 + 24); + // _c0 = _mm256_permute2f128_ps(_tmp0, _tmp4, _MM_SHUFFLE(0, 2, 0, 0)); + // _c1 = _mm256_permute2f128_ps(_tmp0, _tmp4, _MM_SHUFFLE(0, 3, 0, 1)); + // _c2 = _mm256_permute2f128_ps(_tmp1, _tmp5, _MM_SHUFFLE(0, 2, 0, 0)); + // _c3 = _mm256_permute2f128_ps(_tmp1, _tmp5, _MM_SHUFFLE(0, 3, 0, 1)); + // _c4 = _mm256_permute2f128_ps(_tmp2, _tmp6, _MM_SHUFFLE(0, 2, 0, 0)); + // _c5 = _mm256_permute2f128_ps(_tmp2, _tmp6, _MM_SHUFFLE(0, 3, 0, 1)); + // _c6 = _mm256_permute2f128_ps(_tmp3, _tmp7, _MM_SHUFFLE(0, 2, 0, 0)); + // _c7 = _mm256_permute2f128_ps(_tmp3, _tmp7, _MM_SHUFFLE(0, 3, 0, 1)); + // pC += 32; + // } + // if (c_elempack == 1) + // { + // _c0 = _mm256_loadu_ps(pC); + // _c1 = _mm256_loadu_ps(pC + c_hstep); + // _c2 = _mm256_loadu_ps(pC + 
c_hstep * 2); + // _c3 = _mm256_loadu_ps(pC + c_hstep * 3); + // _c4 = _mm256_loadu_ps(pC + c_hstep * 4); + // _c5 = _mm256_loadu_ps(pC + c_hstep * 5); + // _c6 = _mm256_loadu_ps(pC + c_hstep * 6); + // _c7 = _mm256_loadu_ps(pC + c_hstep * 7); + // transpose8x8_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + // pC += 8; + // } + if (beta == 1.f) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c1); + _f2 = _mm512_add_ps(_f2, _c2); + _f3 = _mm512_add_ps(_f3, _c3); + _f4 = _mm512_add_ps(_f4, _c4); + _f5 = _mm512_add_ps(_f5, _c5); + _f6 = _mm512_add_ps(_f6, _c6); + _f7 = _mm512_add_ps(_f7, _c7); + } + else + { + __m512 _beta = _mm512_set1_ps(beta); + _f0 = _mm512_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm512_fmadd_ps(_c1, _beta, _f1); + _f2 = _mm512_fmadd_ps(_c2, _beta, _f2); + _f3 = _mm512_fmadd_ps(_c3, _beta, _f3); + _f4 = _mm512_fmadd_ps(_c4, _beta, _f4); + _f5 = _mm512_fmadd_ps(_c5, _beta, _f5); + _f6 = _mm512_fmadd_ps(_c6, _beta, _f6); + _f7 = _mm512_fmadd_ps(_c7, _beta, _f7); + } + } + if (broadcast_type_C == 4) + { + _c0 = _mm512_set1_ps(pC[0] * beta); + __m512 _c1 = _mm512_set1_ps(pC[1] * beta); + __m512 _c2 = _mm512_set1_ps(pC[2] * beta); + __m512 _c3 = _mm512_set1_ps(pC[3] * beta); + + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c1); + _f2 = _mm512_add_ps(_f2, _c2); + _f3 = _mm512_add_ps(_f3, _c3); + + _c0 = _mm512_set1_ps(pC[4] * beta); + _c1 = _mm512_set1_ps(pC[5] * beta); + _c2 = _mm512_set1_ps(pC[6] * beta); + _c3 = _mm512_set1_ps(pC[7] * beta); + + _f4 = _mm512_add_ps(_f4, _c0); + _f5 = _mm512_add_ps(_f5, _c1); + _f6 = _mm512_add_ps(_f6, _c2); + _f7 = _mm512_add_ps(_f7, _c3); + pC += 8; + } + } + + if (alpha != 1.f) + { + __m512 _alpha = _mm512_set1_ps(alpha); + _f0 = _mm512_mul_ps(_f0, _alpha); + _f1 = _mm512_mul_ps(_f1, _alpha); + _f2 = _mm512_mul_ps(_f2, _alpha); + _f3 = _mm512_mul_ps(_f3, _alpha); + _f4 = _mm512_mul_ps(_f4, _alpha); + _f5 = _mm512_mul_ps(_f5, _alpha); + _f6 = _mm512_mul_ps(_f6, _alpha); + _f7 = 
_mm512_mul_ps(_f7, _alpha); + } + + if (output_transpose) + { + // 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 + // 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 + // 02 12 22 32 42 52 62 72 82 92 a2 b2 c2 d2 e2 f2 + // 03 13 23 33 43 53 63 73 83 93 a3 b3 c3 d3 e3 f3 + // 04 14 24 34 44 54 64 74 84 94 a4 b4 c4 d4 e4 f4 + // 05 15 25 35 45 55 65 75 85 95 a5 b5 c5 d5 e5 f5 + // 06 16 26 36 46 56 66 76 86 96 a6 b6 c6 d6 e6 f6 + // 07 17 27 37 47 57 67 77 87 97 a7 b7 c7 d7 e7 f7 + + if (out_elempack == 8) + { + transpose16x8_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7); + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + 32, _f2); + _mm512_storeu_ps(p0 + 48, _f3); + _mm512_storeu_ps(p0 + 64, _f4); + _mm512_storeu_ps(p0 + 80, _f5); + _mm512_storeu_ps(p0 + 96, _f6); + _mm512_storeu_ps(p0 + 112, _f7); + } + if (out_elempack == 4) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + transpose16x4_ps(_f4, _f5, _f6, _f7); + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + 32, _f2); + _mm512_storeu_ps(p0 + 48, _f3); + _mm512_storeu_ps(p0 + out_hstep * 4, _f4); + _mm512_storeu_ps(p0 + out_hstep * 4 + 16, _f5); + _mm512_storeu_ps(p0 + out_hstep * 4 + 32, _f6); + _mm512_storeu_ps(p0 + out_hstep * 4 + 48, _f7); + } + if (out_elempack == 1) + { + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + out_hstep, _f1); + _mm512_storeu_ps(p0 + out_hstep * 2, _f2); + _mm512_storeu_ps(p0 + out_hstep * 3, _f3); + _mm512_storeu_ps(p0 + out_hstep * 4, _f4); + _mm512_storeu_ps(p0 + out_hstep * 5, _f5); + _mm512_storeu_ps(p0 + out_hstep * 6, _f6); + _mm512_storeu_ps(p0 + out_hstep * 7, _f7); + } + p0 += out_hstep * 8; + } + else + { + if (out_elempack == 16) + { + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + 32, _f2); + _mm512_storeu_ps(p0 + 48, _f3); + _mm512_storeu_ps(p0 + 64, _f4); + _mm512_storeu_ps(p0 + 80, _f5); + _mm512_storeu_ps(p0 + 96, _f6); + _mm512_storeu_ps(p0 + 112, _f7); + 
p0 += 128; + } + if (out_elempack == 8) + { + _mm256_storeu_ps(p0, _mm512_extractf32x8_ps(_f0, 0)); + _mm256_storeu_ps(p0 + 8, _mm512_extractf32x8_ps(_f1, 0)); + _mm256_storeu_ps(p0 + 16, _mm512_extractf32x8_ps(_f2, 0)); + _mm256_storeu_ps(p0 + 24, _mm512_extractf32x8_ps(_f3, 0)); + _mm256_storeu_ps(p0 + 32, _mm512_extractf32x8_ps(_f4, 0)); + _mm256_storeu_ps(p0 + 40, _mm512_extractf32x8_ps(_f5, 0)); + _mm256_storeu_ps(p0 + 48, _mm512_extractf32x8_ps(_f6, 0)); + _mm256_storeu_ps(p0 + 56, _mm512_extractf32x8_ps(_f7, 0)); + _mm256_storeu_ps(p0 + out_hstep * 8, _mm512_extractf32x8_ps(_f0, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 8, _mm512_extractf32x8_ps(_f1, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 16, _mm512_extractf32x8_ps(_f2, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 24, _mm512_extractf32x8_ps(_f3, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 32, _mm512_extractf32x8_ps(_f4, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 40, _mm512_extractf32x8_ps(_f5, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 48, _mm512_extractf32x8_ps(_f6, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 56, _mm512_extractf32x8_ps(_f7, 1)); + p0 += 64; + } + if (out_elempack == 4) + { + // 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 + // 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 + // 02 12 22 32 42 52 62 72 82 92 a2 b2 c2 d2 e2 f2 + // 03 13 23 33 43 53 63 73 83 93 a3 b3 c3 d3 e3 f3 + // 04 14 24 34 44 54 64 74 84 94 a4 b4 c4 d4 e4 f4 + // 05 15 25 35 45 55 65 75 85 95 a5 b5 c5 d5 e5 f5 + // 06 16 26 36 46 56 66 76 86 96 a6 b6 c6 d6 e6 f6 + // 07 17 27 37 47 57 67 77 87 97 a7 b7 c7 d7 e7 f7 + + // 00 40 80 c0 + // 01 41 81 c1 + // 02 42 82 c2 + // 03 43 83 c3 + // 04 44 84 c4 + // 05 45 85 c5 + // 06 46 86 c6 + // 07 47 87 c7 + __m512 _tmp0 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp3 = 
_mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(3, 1, 3, 1)); + + // 00 80 01 81 + // 02 82 03 83 + // 04 84 05 85 + // 06 86 07 87 + // 40 c0 41 c1 + // 42 c2 43 c3 + // 44 c4 45 c5 + // 46 c6 47 c7 + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _f2 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _f3 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _f4 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _f5 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _f6 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _f7 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + + // 00 01 02 03 + // 04 05 06 07 + // 40 41 42 43 + // 44 45 46 47 + // 80 81 82 83 + // 84 85 86 87 + // c0 c1 c2 c3 + // c4 c5 c6 c7 + + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + out_hstep * 4, _f2); + _mm512_storeu_ps(p0 + out_hstep * 4 + 16, _f3); + _mm512_storeu_ps(p0 + out_hstep * 8, _f4); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16, _f5); + _mm512_storeu_ps(p0 + out_hstep * 12, _f6); + _mm512_storeu_ps(p0 + out_hstep * 12 + 16, _f7); + p0 += 32; + } + if (out_elempack == 1) + { + transpose16x8_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7); + _mm256_storeu_ps(p0, _mm512_extractf32x8_ps(_f0, 0)); + _mm256_storeu_ps(p0 + out_hstep, _mm512_extractf32x8_ps(_f0, 1)); + _mm256_storeu_ps(p0 + out_hstep * 2, _mm512_extractf32x8_ps(_f1, 0)); + _mm256_storeu_ps(p0 + out_hstep * 3, _mm512_extractf32x8_ps(_f1, 1)); + _mm256_storeu_ps(p0 + out_hstep * 4, _mm512_extractf32x8_ps(_f2, 0)); + _mm256_storeu_ps(p0 
+ out_hstep * 5, _mm512_extractf32x8_ps(_f2, 1)); + _mm256_storeu_ps(p0 + out_hstep * 6, _mm512_extractf32x8_ps(_f3, 0)); + _mm256_storeu_ps(p0 + out_hstep * 7, _mm512_extractf32x8_ps(_f3, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8, _mm512_extractf32x8_ps(_f4, 0)); + _mm256_storeu_ps(p0 + out_hstep * 9, _mm512_extractf32x8_ps(_f4, 1)); + _mm256_storeu_ps(p0 + out_hstep * 10, _mm512_extractf32x8_ps(_f5, 0)); + _mm256_storeu_ps(p0 + out_hstep * 11, _mm512_extractf32x8_ps(_f5, 1)); + _mm256_storeu_ps(p0 + out_hstep * 12, _mm512_extractf32x8_ps(_f6, 0)); + _mm256_storeu_ps(p0 + out_hstep * 13, _mm512_extractf32x8_ps(_f6, 1)); + _mm256_storeu_ps(p0 + out_hstep * 14, _mm512_extractf32x8_ps(_f7, 0)); + _mm256_storeu_ps(p0 + out_hstep * 15, _mm512_extractf32x8_ps(_f7, 1)); + p0 += 8; + } + } + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + __m512 _f0 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)pp)); + __m512 _f1 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 16))); + __m512 _f2 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 32))); + __m512 _f3 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 48))); + pp += 64; + + // from + // 00 11 22 33 40 51 62 73 80 91 a2 b3 c0 d1 e2 f3 + // 01 12 23 30 41 52 63 70 81 92 a3 b0 c1 d2 e3 f0 + // 20 31 02 13 60 71 42 53 a0 b1 82 93 e0 f1 c2 d3 + // 21 32 03 10 61 72 43 50 a1 b2 83 90 e1 f2 c3 d0 + // to + // 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 + // 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 + // 02 12 22 32 42 52 62 72 82 92 a2 b2 c2 d2 e2 f2 + // 03 13 23 33 43 53 63 73 83 93 a3 b3 c3 d3 e3 f3 + { + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f3); + __m512 _tmp1 = _mm512_unpackhi_ps(_f0, _f3); + __m512 _tmp2 = _mm512_unpacklo_ps(_f2, _f1); + __m512 _tmp3 = _mm512_unpackhi_ps(_f2, _f1); + _f0 = 
_mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f1 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f2 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f3 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + } + + _f0 = _mm512_mul_ps(_f0, _descale); + _f1 = _mm512_mul_ps(_f1, _descale); + _f2 = _mm512_mul_ps(_f2, _descale); + _f3 = _mm512_mul_ps(_f3, _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + _f2 = _mm512_add_ps(_f2, _c0); + _f3 = _mm512_add_ps(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + _f2 = _mm512_add_ps(_f2, _c0); + _f3 = _mm512_add_ps(_f3, _c0); + } + if (broadcast_type_C == 3) + { + __m512 _c1; + __m512 _c2; + __m512 _c3; + if (c_elempack == 16) + { + _c0 = _mm512_loadu_ps(pC); + _c1 = _mm512_loadu_ps(pC + 16); + _c2 = _mm512_loadu_ps(pC + 32); + _c3 = _mm512_loadu_ps(pC + 48); + pC += 64; + } + // if (c_elempack == 8) + // { + // _c0 = _mm256_loadu_ps(pC); + // _c1 = _mm256_loadu_ps(pC + 8); + // _c2 = _mm256_loadu_ps(pC + 16); + // _c3 = _mm256_loadu_ps(pC + 24); + // pC += 32; + // } + // if (c_elempack == 4) + // { + // __m256 _cc0 = _mm256_loadu_ps(pC); + // __m256 _cc1 = _mm256_loadu_ps(pC + 8); + // __m256 _cc2 = _mm256_loadu_ps(pC + c_hstep * 4); + // __m256 _cc3 = _mm256_loadu_ps(pC + c_hstep * 4 + 8); + // _c0 = _mm256_permute2f128_ps(_cc0, _cc2, _MM_SHUFFLE(0, 2, 0, 0)); + // _c1 = _mm256_permute2f128_ps(_cc0, _cc2, _MM_SHUFFLE(0, 3, 0, 1)); + // _c2 = _mm256_permute2f128_ps(_cc1, _cc3, _MM_SHUFFLE(0, 2, 0, 0)); + // _c3 = _mm256_permute2f128_ps(_cc1, _cc3, _MM_SHUFFLE(0, 3, 0, 1)); + // pC 
+= 16; + // } + // if (c_elempack == 1) + // { + // // __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + // // _c0 = _mm256_i32gather_ps(pC, _vindex, c_hstep * sizeof(float)); + // // _c1 = _mm256_i32gather_ps(pC + 1, _vindex, c_hstep * sizeof(float)); + // // _c2 = _mm256_i32gather_ps(pC + 2, _vindex, c_hstep * sizeof(float)); + // // _c3 = _mm256_i32gather_ps(pC + 3, _vindex, c_hstep * sizeof(float)); + // + // __m128 _cc0 = _mm_loadu_ps(pC); + // __m128 _cc1 = _mm_loadu_ps(pC + c_hstep); + // __m128 _cc2 = _mm_loadu_ps(pC + c_hstep * 2); + // __m128 _cc3 = _mm_loadu_ps(pC + c_hstep * 3); + // __m128 _cc4 = _mm_loadu_ps(pC + c_hstep * 4); + // __m128 _cc5 = _mm_loadu_ps(pC + c_hstep * 5); + // __m128 _cc6 = _mm_loadu_ps(pC + c_hstep * 6); + // __m128 _cc7 = _mm_loadu_ps(pC + c_hstep * 7); + // _MM_TRANSPOSE4_PS(_cc0, _cc1, _cc2, _cc3); + // _MM_TRANSPOSE4_PS(_cc4, _cc5, _cc6, _cc7); + // + // _c0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc0), _cc4, 1); + // _c1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc1), _cc5, 1); + // _c2 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc2), _cc6, 1); + // _c3 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc3), _cc7, 1); + // + // pC += 4; + // } + if (beta == 1.f) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c1); + _f2 = _mm512_add_ps(_f2, _c2); + _f3 = _mm512_add_ps(_f3, _c3); + } + else + { + __m512 _beta = _mm512_set1_ps(beta); + _f0 = _mm512_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm512_fmadd_ps(_c1, _beta, _f1); + _f2 = _mm512_fmadd_ps(_c2, _beta, _f2); + _f3 = _mm512_fmadd_ps(_c3, _beta, _f3); + } + } + if (broadcast_type_C == 4) + { + _c0 = _mm512_set1_ps(pC[0] * beta); + __m512 _c1 = _mm512_set1_ps(pC[1] * beta); + __m512 _c2 = _mm512_set1_ps(pC[2] * beta); + __m512 _c3 = _mm512_set1_ps(pC[3] * beta); + + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c1); + _f2 = _mm512_add_ps(_f2, _c2); + _f3 = _mm512_add_ps(_f3, _c3); + pC += 4; + } + } + + if (alpha != 1.f) + { 
+ __m512 _alpha = _mm512_set1_ps(alpha); + _f0 = _mm512_mul_ps(_f0, _alpha); + _f1 = _mm512_mul_ps(_f1, _alpha); + _f2 = _mm512_mul_ps(_f2, _alpha); + _f3 = _mm512_mul_ps(_f3, _alpha); + } + + if (output_transpose) + { + // 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 + // 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 + // 02 12 22 32 42 52 62 72 82 92 a2 b2 c2 d2 e2 f2 + // 03 13 23 33 43 53 63 73 83 93 a3 b3 c3 d3 e3 f3 + if (out_elempack == 4) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + 32, _f2); + _mm512_storeu_ps(p0 + 48, _f3); + } + if (out_elempack == 1) + { + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + out_hstep, _f1); + _mm512_storeu_ps(p0 + out_hstep * 2, _f2); + _mm512_storeu_ps(p0 + out_hstep * 3, _f3); + } + p0 += out_hstep * 4; + } + else + { + if (out_elempack == 16) + { + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + 32, _f2); + _mm512_storeu_ps(p0 + 48, _f3); + p0 += 64; + } + if (out_elempack == 8) + { + __m512 _tmp0 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(3, 2, 3, 2)); + + _mm512_storeu_ps(p0, _tmp0); + _mm512_storeu_ps(p0 + 16, _tmp1); + _mm512_storeu_ps(p0 + out_hstep * 8, _tmp2); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16, _tmp3); + p0 += 32; + } + if (out_elempack == 4) + { + __m512 _tmp0 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(3, 2, 3, 2)); + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp0, _tmp1, 
_MM_SHUFFLE(3, 1, 3, 1)); + _f2 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _f3 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + out_hstep * 4, _f1); + _mm512_storeu_ps(p0 + out_hstep * 8, _f2); + _mm512_storeu_ps(p0 + out_hstep * 12, _f3); + p0 += 16; + } + if (out_elempack == 1) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + + _mm_storeu_ps(p0, _mm512_extractf32x4_ps(_f0, 0)); + _mm_storeu_ps(p0 + out_hstep, _mm512_extractf32x4_ps(_f0, 1)); + _mm_storeu_ps(p0 + out_hstep * 2, _mm512_extractf32x4_ps(_f0, 2)); + _mm_storeu_ps(p0 + out_hstep * 3, _mm512_extractf32x4_ps(_f0, 3)); + _mm_storeu_ps(p0 + out_hstep * 4, _mm512_extractf32x4_ps(_f1, 0)); + _mm_storeu_ps(p0 + out_hstep * 5, _mm512_extractf32x4_ps(_f1, 1)); + _mm_storeu_ps(p0 + out_hstep * 6, _mm512_extractf32x4_ps(_f1, 2)); + _mm_storeu_ps(p0 + out_hstep * 7, _mm512_extractf32x4_ps(_f1, 3)); + _mm_storeu_ps(p0 + out_hstep * 8, _mm512_extractf32x4_ps(_f2, 0)); + _mm_storeu_ps(p0 + out_hstep * 9, _mm512_extractf32x4_ps(_f2, 1)); + _mm_storeu_ps(p0 + out_hstep * 10, _mm512_extractf32x4_ps(_f2, 2)); + _mm_storeu_ps(p0 + out_hstep * 11, _mm512_extractf32x4_ps(_f2, 3)); + _mm_storeu_ps(p0 + out_hstep * 12, _mm512_extractf32x4_ps(_f3, 0)); + _mm_storeu_ps(p0 + out_hstep * 13, _mm512_extractf32x4_ps(_f3, 1)); + _mm_storeu_ps(p0 + out_hstep * 14, _mm512_extractf32x4_ps(_f3, 2)); + _mm_storeu_ps(p0 + out_hstep * 15, _mm512_extractf32x4_ps(_f3, 3)); + p0 += 4; + } + } + } + for (; jj + 1 < max_jj; jj += 2) + { + __m512 _f0 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)pp)); + __m512 _f1 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)(pp + 16))); + pp += 32; + + // from + // 00 11 20 31 40 51 60 71 80 91 a0 b1 c0 d1 e0 f1 + // 01 10 21 30 41 50 61 70 81 90 a1 b0 c1 d0 e1 f0 + // to + // 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 + // 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 + { + __m512 _tmp0 = 
_mm512_permute_ps(_f0, _MM_SHUFFLE(3, 1, 2, 0)); + __m512 _tmp1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(0, 2, 3, 1)); + _f0 = _mm512_unpacklo_ps(_tmp0, _tmp1); + _f1 = _mm512_unpackhi_ps(_tmp0, _tmp1); + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + } + + _f0 = _mm512_mul_ps(_f0, _descale); + _f1 = _mm512_mul_ps(_f1, _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c0); + } + if (broadcast_type_C == 3) + { + __m512 _c1; + if (c_elempack == 16) + { + _c0 = _mm512_loadu_ps(pC); + _c1 = _mm512_loadu_ps(pC + 16); + pC += 32; + } + // if (c_elempack == 8) + // { + // _c0 = _mm256_loadu_ps(pC); + // _c1 = _mm256_loadu_ps(pC + 8); + // pC += 16; + // } + // if (c_elempack == 4) + // { + // __m256 _cc0 = _mm256_loadu_ps(pC); + // __m256 _cc1 = _mm256_loadu_ps(pC + c_hstep * 4); + // _c0 = _mm256_permute2f128_ps(_cc0, _cc1, _MM_SHUFFLE(0, 2, 0, 0)); + // _c1 = _mm256_permute2f128_ps(_cc0, _cc1, _MM_SHUFFLE(0, 3, 0, 1)); + // pC += 8; + // } + // if (c_elempack == 1) + // { + // __m256i _vindex = _mm256_mullo_epi32(_mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7), _mm256_set1_epi32(c_hstep)); + // _c0 = _mm256_i32gather_ps(pC, _vindex, sizeof(float)); + // _c1 = _mm256_i32gather_ps(pC + 1, _vindex, sizeof(float)); + // pC += 2; + // } + if (beta == 1.f) + { + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c1); + } + else + { + __m512 _beta = _mm512_set1_ps(beta); + _f0 = _mm512_fmadd_ps(_c0, _beta, _f0); + _f1 = _mm512_fmadd_ps(_c1, _beta, _f1); + } + } + if (broadcast_type_C == 4) + { + _c0 = _mm512_set1_ps(pC[0] * beta); + __m512 _c1 = _mm512_set1_ps(pC[1] * beta); + _f0 = _mm512_add_ps(_f0, _c0); + _f1 = _mm512_add_ps(_f1, _c1); + pC += 2; + } + } + + if (alpha != 1.f) + { + __m512 _alpha = _mm512_set1_ps(alpha); + _f0 = _mm512_mul_ps(_f0, _alpha); + _f1 = 
_mm512_mul_ps(_f1, _alpha); + } + + if (output_transpose) + { + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + out_hstep, _f1); + p0 += out_hstep * 2; + } + else + { + if (out_elempack == 16) + { + _mm512_store_ps(p0, _f0); + _mm512_store_ps(p0 + 16, _f1); + p0 += 32; + } + if (out_elempack == 8) + { + _mm256_storeu_ps(p0, _mm512_extractf32x8_ps(_f0, 0)); + _mm256_storeu_ps(p0 + 8, _mm512_extractf32x8_ps(_f1, 0)); + _mm256_storeu_ps(p0 + out_hstep * 8, _mm512_extractf32x8_ps(_f0, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 8, _mm512_extractf32x8_ps(_f1, 1)); + p0 += 16; + } + if (out_elempack == 4) + { + _mm_storeu_ps(p0, _mm512_extractf32x4_ps(_f0, 0)); + _mm_storeu_ps(p0 + 4, _mm512_extractf32x4_ps(_f1, 0)); + _mm_storeu_ps(p0 + out_hstep * 4, _mm512_extractf32x4_ps(_f0, 1)); + _mm_storeu_ps(p0 + out_hstep * 4 + 4, _mm512_extractf32x4_ps(_f1, 1)); + _mm_storeu_ps(p0 + out_hstep * 8, _mm512_extractf32x4_ps(_f0, 2)); + _mm_storeu_ps(p0 + out_hstep * 8 + 4, _mm512_extractf32x4_ps(_f1, 2)); + _mm_storeu_ps(p0 + out_hstep * 12, _mm512_extractf32x4_ps(_f0, 3)); + _mm_storeu_ps(p0 + out_hstep * 12 + 4, _mm512_extractf32x4_ps(_f1, 3)); + p0 += 8; + } + if (out_elempack == 1) + { + __m512i _vindex = _mm512_mullo_epi32(_mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), _mm512_set1_epi32(out_hstep)); + _mm512_i32scatter_ps(p0, _vindex, _f0, sizeof(float)); + _mm512_i32scatter_ps(p0 + 1, _vindex, _f1, sizeof(float)); + p0 += 2; + } + } + } + for (; jj < max_jj; jj++) + { + __m512 _f0 = _mm512_cvtepi32_ps(_mm512_load_si512((const __m512i*)pp)); + pp += 16; + + _f0 = _mm512_mul_ps(_f0, _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm512_add_ps(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 16) + { + _c0 = _mm512_loadu_ps(pC); + pC += 16; + } + // if (c_elempack == 8) + // { + // _c0 = _mm256_loadu_ps(pC); + // pC 
+= 8; + // } + // if (c_elempack == 4) + // { + // __m128 _cc0 = _mm_loadu_ps(pC); + // __m128 _cc1 = _mm_loadu_ps(pC + c_hstep * 4); + // _c0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc0), _cc1, 1); + // pC += 4; + // } + // if (c_elempack == 1) + // { + // __m256i _vindex = _mm256_mullo_epi32(_mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7), _mm256_set1_epi32(c_hstep)); + // _c0 = _mm256_i32gather_ps(pC, _vindex, sizeof(float)); + // pC += 1; + // } + _f0 = _mm512_fmadd_ps(_c0, _mm512_set1_ps(beta), _f0); + } + if (broadcast_type_C == 4) + { + _c0 = _mm512_set1_ps(pC[0] * beta); + _f0 = _mm512_add_ps(_f0, _c0); + pC += 1; + } + } + + _f0 = _mm512_mul_ps(_f0, _mm512_set1_ps(alpha)); + + if (output_transpose) { - pp[0] = float2int8(p0[0] * scale); - pp[1] = float2int8(p0[1] * scale); - pp[2] = float2int8(p0[2] * scale); - pp[3] = float2int8(p0[3] * scale); - pp += 4; - p0 += B_hstep * 4; + _mm512_storeu_ps(p0, _f0); + p0 += out_hstep; } - } -#endif // __SSE2__ - if (elempack == 1) - { - int kk = 0; - for (; kk < max_kk; kk++) + else { - pp[0] = float2int8(p0[0] * scale); - pp += 1; - p0 += B_hstep; + if (out_elempack == 16) + { + _mm512_storeu_ps(p0, _f0); + p0 += 16; + } + if (out_elempack == 8) + { + _mm256_storeu_ps(p0, _mm512_extractf32x8_ps(_f0, 0)); + _mm256_storeu_ps(p0 + out_hstep * 8, _mm512_extractf32x8_ps(_f0, 1)); + p0 += 8; + } + if (out_elempack == 4) + { + _mm_storeu_ps(p0, _mm512_extractf32x4_ps(_f0, 0)); + _mm_storeu_ps(p0 + out_hstep * 4, _mm512_extractf32x4_ps(_f0, 1)); + _mm_storeu_ps(p0 + out_hstep * 8, _mm512_extractf32x4_ps(_f0, 2)); + _mm_storeu_ps(p0 + out_hstep * 12, _mm512_extractf32x4_ps(_f0, 3)); + p0 += 4; + } + if (out_elempack == 1) + { + __m512i _vindex = _mm512_mullo_epi32(_mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), _mm512_set1_epi32(out_hstep)); + _mm512_i32scatter_ps(p0, _vindex, _f0, sizeof(float)); + p0++; + } } } } -} - -static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, 
Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta, int output_transpose) -{ -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ - if (ncnn::cpu_support_x86_avx2()) - { - unpack_output_tile_int32_to_fp32_avx2(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta, output_transpose); - return; - } -#endif - - const int out_elempack = top_blob.elempack; - const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; - - const int c_hstep = C.dims == 3 ? (int)C.cstep : C.w; - const int c_elempack = C.elempack; - const float* pC = C; - - NCNN_LOGE("unpack_output_tile_int32_to_fp32 %d %d %d %d %d %d %d %d", i, max_ii, j, max_jj, out_elempack, broadcast_type_C, c_elempack, output_transpose); - - // const int* pp = topT; - - int ii = 0; -#if __SSE2__ -#if __AVX__ +#endif // __AVX512F__ for (; ii + 7 < max_ii; ii += 8) { #if __AVX2__ @@ -2416,19 +7053,31 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& } __m256 _descale = _mm256_loadu_ps((const float*)descales + i + ii); +#if __AVX512F__ + __m512 _descale_avx512 = _mm512_broadcast_f32x8(_descale); +#endif __m256 _c0; +#if __AVX512F__ + __m512 _c0_avx512; +#endif if (pC) { if (broadcast_type_C == 0) { _c0 = _mm256_set1_ps(pC[0] * beta); +#if __AVX512F__ + _c0_avx512 = _mm512_set1_ps(pC[0] * beta); +#endif } if (broadcast_type_C == 1 || broadcast_type_C == 2) { pC = (const float*)C + i + ii; _c0 = _mm256_loadu_ps(pC); _c0 = _mm256_mul_ps(_c0, _mm256_set1_ps(beta)); +#if __AVX512F__ + _c0_avx512 = _mm512_broadcast_f32x8(_c0); +#endif } if (broadcast_type_C == 3) { @@ -2442,6 +7091,515 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& int jj = 0; #if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + __m512 _f0 = _mm512_cvtepi32_ps(_mm512_loadu_si512((const 
__m512i*)pp)); + __m512 _f1 = _mm512_cvtepi32_ps(_mm512_loadu_si512((const __m512i*)(pp + 16))); + __m512 _f2 = _mm512_cvtepi32_ps(_mm512_loadu_si512((const __m512i*)(pp + 32))); + __m512 _f3 = _mm512_cvtepi32_ps(_mm512_loadu_si512((const __m512i*)(pp + 48))); + __m512 _f4 = _mm512_cvtepi32_ps(_mm512_loadu_si512((const __m512i*)(pp + 64))); + __m512 _f5 = _mm512_cvtepi32_ps(_mm512_loadu_si512((const __m512i*)(pp + 80))); + __m512 _f6 = _mm512_cvtepi32_ps(_mm512_loadu_si512((const __m512i*)(pp + 96))); + __m512 _f7 = _mm512_cvtepi32_ps(_mm512_loadu_si512((const __m512i*)(pp + 112))); + pp += 128; + + + // _f0 = _mm512_setr_ps(00,11,22,33,44,55,66,77,0x08,0x19,0x2a,0x3b,0x4c,0x5d,0x6e,0x7f); + // _f1 = _mm512_setr_ps(01,12,23,30,45,56,67,74,0x09,0x1a,0x2b,0x38,0x4d,0x5e,0x6f,0x7c); + // _f2 = _mm512_setr_ps(20,31,02,13,64,75,46,57,0x28,0x39,0x0a,0x1b,0x6c,0x7d,0x4e,0x5f); + // _f3 = _mm512_setr_ps(21,32,03,10,65,76,47,54,0x29,0x3a,0x0b,0x18,0x6d,0x7e,0x4f,0x5c); + // _f4 = _mm512_setr_ps(04,15,26,37,40,51,62,73,0x0c,0x1d,0x2e,0x3f,0x48,0x59,0x6a,0x7b); + // _f5 = _mm512_setr_ps(05,16,27,34,41,52,63,70,0x0d,0x1e,0x2f,0x3c,0x49,0x5a,0x6b,0x78); + // _f6 = _mm512_setr_ps(24,35,06,17,60,71,42,53,0x2c,0x3d,0x0e,0x1f,0x68,0x79,0x4a,0x5b); + // _f7 = _mm512_setr_ps(25,36,07,14,61,72,43,50,0x2d,0x3e,0x0f,0x1c,0x69,0x7a,0x4b,0x58); + // + + // print(_f0); + // print(_f1); + // print(_f2); + // print(_f3); + // print(_f4); + // print(_f5); + // print(_f6); + // print(_f7); + + // from + // 00 11 22 33 44 55 66 77 08 19 2a 3b 4c 5d 6e 7f + // 01 12 23 30 45 56 67 74 09 1a 2b 38 4d 5e 6f 7c + // 20 31 02 13 64 75 46 57 28 39 0a 1b 6c 7d 4e 5f + // 21 32 03 10 65 76 47 54 29 3a 0b 18 6d 7e 4f 5c + // 04 15 26 37 40 51 62 73 0c 1d 2e 3f 48 59 6a 7b + // 05 16 27 34 41 52 63 70 0d 1e 2f 3c 49 5a 6b 78 + // 24 35 06 17 60 71 42 53 2c 3d 0e 1f 68 79 4a 5b + // 25 36 07 14 61 72 43 50 2d 3e 0f 1c 69 7a 4b 58 + + + + // 00 10 20 30 40 50 60 70 08 18 28 38 48 58 68 78 + // 01 11 21 31 
41 51 61 71 09 19 29 39 49 59 69 79 + // 02 12 22 32 42 52 62 72 0a 1a 2a 3a 4a 5a 6a 7a + // 03 13 23 33 43 53 63 73 0b 1b 2b 3b 4b 5b 6b 7b + // 04 14 24 34 44 54 64 74 0c 1c 2c 3c 4c 5c 6c 7c + // 05 15 25 35 45 55 65 75 0d 1d 2d 3d 4d 5d 6d 7d + // 06 16 26 36 46 56 66 76 0e 1e 2e 3e 4e 5e 6e 7e + // 07 17 27 37 47 57 67 77 0f 1f 2f 3f 4f 5f 6f 7f + + // to + // 00 10 20 30 44 54 64 74 08 18 28 38 4c 5c 6c 7c + // 01 11 21 31 45 55 65 75 09 19 29 39 4d 5d 6d 7d + // 02 12 22 32 46 56 66 76 0a 1a 2a 3a 4e 5e 6e 7e + // 03 13 23 33 47 57 67 77 0b 1b 2b 3b 4f 5f 6f 7f + // 04 14 24 34 40 50 60 70 0c 1c 2c 3c 48 58 68 78 + // 05 15 25 35 41 51 61 71 0d 1d 2d 3d 49 59 69 79 + // 06 16 26 36 42 52 62 72 0e 1e 2e 3e 4a 5a 6a 7a + // 07 17 27 37 43 53 63 73 0f 1f 2f 3f 4b 5b 6b 7b + { + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + _f5 = _mm512_permute_ps(_f5, _MM_SHUFFLE(2, 1, 0, 3)); + _f7 = _mm512_permute_ps(_f7, _MM_SHUFFLE(2, 1, 0, 3)); + + // 00 11 22 33 44 55 66 77 08 19 2a 3b 4c 5d 6e 7f + // 30 01 12 23 74 45 56 67 38 09 1a 2b 7c 4d 5e 6f + // 20 31 02 13 64 75 46 57 28 39 0a 1b 6c 7d 4e 5f + // 10 21 32 03 54 65 76 47 18 29 3a 0b 5c 6d 7e 4f + + // 04 15 26 37 40 51 62 73 0c 1d 2e 3f 48 59 6a 7b + // 34 05 16 27 70 41 52 63 3c 0d 1e 2f 78 49 5a 6b + // 24 35 06 17 60 71 42 53 2c 3d 0e 1f 68 79 4a 5b + // 14 25 36 07 50 61 72 43 1c 2d 3e 0f 58 69 7a 4b + + __m512 _tmp0 = _mm512_unpacklo_ps(_f0, _f3); + __m512 _tmp1 = _mm512_unpackhi_ps(_f0, _f3); + __m512 _tmp2 = _mm512_unpacklo_ps(_f2, _f1); + __m512 _tmp3 = _mm512_unpackhi_ps(_f2, _f1); + __m512 _tmp4 = _mm512_unpacklo_ps(_f4, _f7); + __m512 _tmp5 = _mm512_unpackhi_ps(_f4, _f7); + __m512 _tmp6 = _mm512_unpacklo_ps(_f6, _f5); + __m512 _tmp7 = _mm512_unpackhi_ps(_f6, _f5); + + // 00 10 11 21 44 54 55 65 08 18 19 29 4c 5c 5d 6d + // 22 32 33 03 66 76 77 47 2a 3a 3b 0b 6e 7e 7f 4f + // 20 30 31 01 64 74 75 45 28 38 39 09 6c 7c 7d 4d + // 02 12 
13 23 46 56 57 67 0a 1a 1b 2b 4e 5e 5f 6f + + // 04 14 15 25 40 50 51 61 0c 1c 1d 2d 48 58 59 69 + // 26 36 37 07 62 72 73 43 2e 3e 3f 0f 6a 7a 7b 4b + // 24 34 35 05 60 70 71 41 2c 3c 3d 0d 68 78 79 49 + // 06 16 17 27 42 52 53 63 0e 1e 1f 2f 4a 5a 5b 6b + + _f0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f1 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp0), _mm512_castps_pd(_tmp2))); + _f2 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f3 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp3), _mm512_castps_pd(_tmp1))); + _f4 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp6))); + _f5 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp4), _mm512_castps_pd(_tmp6))); + _f6 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_tmp7), _mm512_castps_pd(_tmp5))); + _f7 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_tmp7), _mm512_castps_pd(_tmp5))); + + // 00 10 20 30 44 54 64 74 08 18 28 38 4c 5c 6c 7c + // 11 21 31 01 55 65 75 45 19 29 39 09 5d 6d 7d 4d + // 02 12 22 32 46 56 66 76 0a 1a 2a 3a 4e 5e 6e 7e + // 13 23 33 03 57 67 77 47 1b 2b 3b 0b 5f 6f 7f 4f + + // 04 14 24 34 40 50 60 70 0c 1c 2c 3c 48 58 68 78 + // 15 25 35 05 51 61 71 41 1d 2d 3d 0d 59 69 79 49 + // 06 16 26 36 42 52 62 72 0e 1e 2e 3e 4a 5a 6a 7a + // 17 27 37 07 53 63 73 43 1f 2f 3f 0f 5b 6b 7b 4b + + _f1 = _mm512_permute_ps(_f1, _MM_SHUFFLE(2, 1, 0, 3)); + _f3 = _mm512_permute_ps(_f3, _MM_SHUFFLE(2, 1, 0, 3)); + _f5 = _mm512_permute_ps(_f5, _MM_SHUFFLE(2, 1, 0, 3)); + _f7 = _mm512_permute_ps(_f7, _MM_SHUFFLE(2, 1, 0, 3)); + + // 00 10 20 30 44 54 64 74 08 18 28 38 4c 5c 6c 7c + // 01 11 21 31 45 55 65 75 09 19 29 39 4d 5d 6d 7d + // 02 12 22 32 46 56 66 76 0a 1a 2a 3a 4e 5e 6e 7e + // 03 13 23 33 47 57 67 77 0b 1b 2b 3b 4f 5f 6f 7f + // 04 14 24 34 40 50 60 70 0c 1c 2c 3c 48 58 68 78 + // 05 15 25 35 41 51 61 71 0d 1d 2d 3d 49 59 
69 79 + // 06 16 26 36 42 52 62 72 0e 1e 2e 3e 4a 5a 6a 7a + // 07 17 27 37 43 53 63 73 0f 1f 2f 3f 4b 5b 6b 7b + + + _tmp0 = _mm512_shuffle_f32x4(_f0, _f4, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp1 = _mm512_shuffle_f32x4(_f0, _f4, _MM_SHUFFLE(2, 3, 3, 2)); + _tmp2 = _mm512_shuffle_f32x4(_f1, _f5, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp3 = _mm512_shuffle_f32x4(_f1, _f5, _MM_SHUFFLE(2, 3, 3, 2)); + _tmp4 = _mm512_shuffle_f32x4(_f2, _f6, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp5 = _mm512_shuffle_f32x4(_f2, _f6, _MM_SHUFFLE(2, 3, 3, 2)); + _tmp6 = _mm512_shuffle_f32x4(_f3, _f7, _MM_SHUFFLE(0, 1, 1, 0)); + _tmp7 = _mm512_shuffle_f32x4(_f3, _f7, _MM_SHUFFLE(2, 3, 3, 2)); + + // 00 10 20 30 44 54 64 74 40 50 60 70 04 14 24 34 + // 08 18 28 38 4c 5c 6c 7c 48 58 68 78 0c 1c 2c 3c + // 01 11 21 31 45 55 65 75 41 51 61 71 05 15 25 35 + + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _f2 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _f3 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _f4 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(1, 3, 1, 3)); + _f5 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(1, 3, 1, 3)); + _f6 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(1, 3, 1, 3)); + _f7 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(1, 3, 1, 3)); + + // 00 10 20 30 40 50 60 70 08 18 28 38 48 58 68 78 + // 01 11 21 31 41 51 61 71 09 19 29 39 49 59 69 79 + // 02 12 22 32 42 52 62 72 0a 1a 2a 3a 4a 5a 6a 7a + // 03 13 23 33 43 53 63 73 0b 1b 2b 3b 4b 5b 6b 7b + // 04 14 24 34 44 54 64 74 0c 1c 2c 3c 4c 5c 6c 7c + // 05 15 25 35 45 55 65 75 0d 1d 2d 3d 4d 5d 6d 7d + // 06 16 26 36 46 56 66 76 0e 1e 2e 3e 4e 5e 6e 7e + // 07 17 27 37 47 57 67 77 0f 1f 2f 3f 4f 5f 6f 7f + + } + + // NCNN_LOGE("-----"); + + // print(_f0); + // print(_f1); + // print(_f2); + // print(_f3); + // print(_f4); + // print(_f5); + // print(_f6); + // print(_f7); + + _f0 = _mm512_mul_ps(_f0, 
_descale_avx512); + _f1 = _mm512_mul_ps(_f1, _descale_avx512); + _f2 = _mm512_mul_ps(_f2, _descale_avx512); + _f3 = _mm512_mul_ps(_f3, _descale_avx512); + _f4 = _mm512_mul_ps(_f4, _descale_avx512); + _f5 = _mm512_mul_ps(_f5, _descale_avx512); + _f6 = _mm512_mul_ps(_f6, _descale_avx512); + _f7 = _mm512_mul_ps(_f7, _descale_avx512); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c0_avx512); + _f2 = _mm512_add_ps(_f2, _c0_avx512); + _f3 = _mm512_add_ps(_f3, _c0_avx512); + _f4 = _mm512_add_ps(_f4, _c0_avx512); + _f5 = _mm512_add_ps(_f5, _c0_avx512); + _f6 = _mm512_add_ps(_f6, _c0_avx512); + _f7 = _mm512_add_ps(_f7, _c0_avx512); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c0_avx512); + _f2 = _mm512_add_ps(_f2, _c0_avx512); + _f3 = _mm512_add_ps(_f3, _c0_avx512); + _f4 = _mm512_add_ps(_f4, _c0_avx512); + _f5 = _mm512_add_ps(_f5, _c0_avx512); + _f6 = _mm512_add_ps(_f6, _c0_avx512); + _f7 = _mm512_add_ps(_f7, _c0_avx512); + } + if (broadcast_type_C == 3) + { + // TODO + // __m256 _c1; + // __m256 _c2; + // __m256 _c3; + // __m256 _c4; + // __m256 _c5; + // __m256 _c6; + // __m256 _c7; + // if (c_elempack == 8) + // { + // _c0 = _mm256_loadu_ps(pC); + // _c1 = _mm256_loadu_ps(pC + 8); + // _c2 = _mm256_loadu_ps(pC + 16); + // _c3 = _mm256_loadu_ps(pC + 24); + // _c4 = _mm256_loadu_ps(pC + 32); + // _c5 = _mm256_loadu_ps(pC + 40); + // _c6 = _mm256_loadu_ps(pC + 48); + // _c7 = _mm256_loadu_ps(pC + 56); + // pC += 64; + // } + // if (c_elempack == 4) + // { + // __m256 _tmp0 = _mm256_loadu_ps(pC); + // __m256 _tmp1 = _mm256_loadu_ps(pC + 8); + // __m256 _tmp2 = _mm256_loadu_ps(pC + 16); + // __m256 _tmp3 = _mm256_loadu_ps(pC + 24); + // __m256 _tmp4 = _mm256_loadu_ps(pC + c_hstep * 4); + // __m256 _tmp5 = _mm256_loadu_ps(pC + c_hstep * 4 + 8); + // __m256 _tmp6 = _mm256_loadu_ps(pC + c_hstep * 4 + 16); + // __m256 _tmp7 = 
_mm256_loadu_ps(pC + c_hstep * 4 + 24); + // _c0 = _mm256_permute2f128_ps(_tmp0, _tmp4, _MM_SHUFFLE(0, 2, 0, 0)); + // _c1 = _mm256_permute2f128_ps(_tmp0, _tmp4, _MM_SHUFFLE(0, 3, 0, 1)); + // _c2 = _mm256_permute2f128_ps(_tmp1, _tmp5, _MM_SHUFFLE(0, 2, 0, 0)); + // _c3 = _mm256_permute2f128_ps(_tmp1, _tmp5, _MM_SHUFFLE(0, 3, 0, 1)); + // _c4 = _mm256_permute2f128_ps(_tmp2, _tmp6, _MM_SHUFFLE(0, 2, 0, 0)); + // _c5 = _mm256_permute2f128_ps(_tmp2, _tmp6, _MM_SHUFFLE(0, 3, 0, 1)); + // _c6 = _mm256_permute2f128_ps(_tmp3, _tmp7, _MM_SHUFFLE(0, 2, 0, 0)); + // _c7 = _mm256_permute2f128_ps(_tmp3, _tmp7, _MM_SHUFFLE(0, 3, 0, 1)); + // pC += 32; + // } + // if (c_elempack == 1) + // { + // _c0 = _mm256_loadu_ps(pC); + // _c1 = _mm256_loadu_ps(pC + c_hstep); + // _c2 = _mm256_loadu_ps(pC + c_hstep * 2); + // _c3 = _mm256_loadu_ps(pC + c_hstep * 3); + // _c4 = _mm256_loadu_ps(pC + c_hstep * 4); + // _c5 = _mm256_loadu_ps(pC + c_hstep * 5); + // _c6 = _mm256_loadu_ps(pC + c_hstep * 6); + // _c7 = _mm256_loadu_ps(pC + c_hstep * 7); + // transpose8x8_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + // pC += 8; + // } + // if (beta == 1.f) + // { + // _f0 = _mm256_add_ps(_f0, _c0); + // _f1 = _mm256_add_ps(_f1, _c1); + // _f2 = _mm256_add_ps(_f2, _c2); + // _f3 = _mm256_add_ps(_f3, _c3); + // _f4 = _mm256_add_ps(_f4, _c4); + // _f5 = _mm256_add_ps(_f5, _c5); + // _f6 = _mm256_add_ps(_f6, _c6); + // _f7 = _mm256_add_ps(_f7, _c7); + // } + // else + // { + // __m256 _beta = _mm256_set1_ps(beta); + // _f0 = _mm256_comp_fmadd_ps(_c0, _beta, _f0); + // _f1 = _mm256_comp_fmadd_ps(_c1, _beta, _f1); + // _f2 = _mm256_comp_fmadd_ps(_c2, _beta, _f2); + // _f3 = _mm256_comp_fmadd_ps(_c3, _beta, _f3); + // _f4 = _mm256_comp_fmadd_ps(_c4, _beta, _f4); + // _f5 = _mm256_comp_fmadd_ps(_c5, _beta, _f5); + // _f6 = _mm256_comp_fmadd_ps(_c6, _beta, _f6); + // _f7 = _mm256_comp_fmadd_ps(_c7, _beta, _f7); + // } + } + if (broadcast_type_C == 4) + { + // TODO + // _c0 = _mm256_set1_ps(pC[0] * beta); + 
// __m256 _c1 = _mm256_set1_ps(pC[1] * beta); + // __m256 _c2 = _mm256_set1_ps(pC[2] * beta); + // __m256 _c3 = _mm256_set1_ps(pC[3] * beta); + // + // _f0 = _mm256_add_ps(_f0, _c0); + // _f1 = _mm256_add_ps(_f1, _c1); + // _f2 = _mm256_add_ps(_f2, _c2); + // _f3 = _mm256_add_ps(_f3, _c3); + // + // _c0 = _mm256_set1_ps(pC[4] * beta); + // _c1 = _mm256_set1_ps(pC[5] * beta); + // _c2 = _mm256_set1_ps(pC[6] * beta); + // _c3 = _mm256_set1_ps(pC[7] * beta); + // + // _f4 = _mm256_add_ps(_f4, _c0); + // _f5 = _mm256_add_ps(_f5, _c1); + // _f6 = _mm256_add_ps(_f6, _c2); + // _f7 = _mm256_add_ps(_f7, _c3); + // pC += 8; + } + } + + if (alpha != 1.f) + { + __m512 _alpha = _mm512_set1_ps(alpha); + _f0 = _mm512_mul_ps(_f0, _alpha); + _f1 = _mm512_mul_ps(_f1, _alpha); + _f2 = _mm512_mul_ps(_f2, _alpha); + _f3 = _mm512_mul_ps(_f3, _alpha); + _f4 = _mm512_mul_ps(_f4, _alpha); + _f5 = _mm512_mul_ps(_f5, _alpha); + _f6 = _mm512_mul_ps(_f6, _alpha); + _f7 = _mm512_mul_ps(_f7, _alpha); + } + + if (output_transpose) + { + // 00 10 20 30 40 50 60 70 08 18 28 38 48 58 68 78 + // 01 11 21 31 41 51 61 71 09 19 29 39 49 59 69 79 + // 02 12 22 32 42 52 62 72 0a 1a 2a 3a 4a 5a 6a 7a + // 03 13 23 33 43 53 63 73 0b 1b 2b 3b 4b 5b 6b 7b + // 04 14 24 34 44 54 64 74 0c 1c 2c 3c 4c 5c 6c 7c + // 05 15 25 35 45 55 65 75 0d 1d 2d 3d 4d 5d 6d 7d + // 06 16 26 36 46 56 66 76 0e 1e 2e 3e 4e 5e 6e 7e + // 07 17 27 37 47 57 67 77 0f 1f 2f 3f 4f 5f 6f 7f + + if (out_elempack == 16) + { + transpose16x8_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7); + + _mm256_storeu_ps(p0, _mm512_extractf32x8_ps(_f0, 0)); + _mm256_storeu_ps(p0 + 8, _mm512_extractf32x8_ps(_f4, 0)); + _mm256_storeu_ps(p0 + 16, _mm512_extractf32x8_ps(_f0, 1)); + _mm256_storeu_ps(p0 + 16 + 8, _mm512_extractf32x8_ps(_f4, 1)); + _mm256_storeu_ps(p0 + 16 * 2, _mm512_extractf32x8_ps(_f1, 0)); + _mm256_storeu_ps(p0 + 16 * 2 + 8, _mm512_extractf32x8_ps(_f5, 0)); + _mm256_storeu_ps(p0 + 16 * 3, _mm512_extractf32x8_ps(_f1, 1)); + 
_mm256_storeu_ps(p0 + 16 * 3 + 8, _mm512_extractf32x8_ps(_f5, 1)); + _mm256_storeu_ps(p0 + 16 * 4, _mm512_extractf32x8_ps(_f2, 0)); + _mm256_storeu_ps(p0 + 16 * 4 + 8, _mm512_extractf32x8_ps(_f6, 0)); + _mm256_storeu_ps(p0 + 16 * 5, _mm512_extractf32x8_ps(_f2, 1)); + _mm256_storeu_ps(p0 + 16 * 5 + 8, _mm512_extractf32x8_ps(_f6, 1)); + _mm256_storeu_ps(p0 + 16 * 6, _mm512_extractf32x8_ps(_f3, 0)); + _mm256_storeu_ps(p0 + 16 * 6 + 8, _mm512_extractf32x8_ps(_f7, 0)); + _mm256_storeu_ps(p0 + 16 * 7, _mm512_extractf32x8_ps(_f3, 1)); + _mm256_storeu_ps(p0 + 16 * 7 + 8, _mm512_extractf32x8_ps(_f7, 1)); + } + if (out_elempack == 8) + { + transpose16x8_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7); + + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + 16 * 2, _f2); + _mm512_storeu_ps(p0 + 16 * 3, _f3); + _mm512_storeu_ps(p0 + out_hstep * 8, _f4); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16, _f5); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16 * 2, _f6); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16 * 3, _f7); + } + if (out_elempack == 4) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + transpose16x4_ps(_f4, _f5, _f6, _f7); + + // 00 10 20 30 40 50 60 70 08 18 28 38 48 58 68 78 + // 01 11 21 31 41 51 61 71 09 19 29 39 49 59 69 79 + // 02 12 22 32 42 52 62 72 0a 1a 2a 3a 4a 5a 6a 7a + // 03 13 23 33 43 53 63 73 0b 1b 2b 3b 4b 5b 6b 7b + + // 04 14 24 34 44 54 64 74 0c 1c 2c 3c 4c 5c 6c 7c + // 05 15 25 35 45 55 65 75 0d 1d 2d 3d 4d 5d 6d 7d + // 06 16 26 36 46 56 66 76 0e 1e 2e 3e 4e 5e 6e 7e + // 07 17 27 37 47 57 67 77 0f 1f 2f 3f 4f 5f 6f 7f + + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + out_hstep * 4, _f4); + _mm512_storeu_ps(p0 + out_hstep * 4 + 16, _f5); + _mm512_storeu_ps(p0 + out_hstep * 8, _f2); + _mm512_storeu_ps(p0 + out_hstep * 8 + 16, _f3); + _mm512_storeu_ps(p0 + out_hstep * 12, _f6); + _mm512_storeu_ps(p0 + out_hstep * 12 + 16, _f7); + } + if (out_elempack == 1) + { + _mm256_storeu_ps(p0, 
_mm512_extractf32x8_ps(_f0, 0)); + _mm256_storeu_ps(p0 + out_hstep, _mm512_extractf32x8_ps(_f1, 0)); + _mm256_storeu_ps(p0 + out_hstep * 2, _mm512_extractf32x8_ps(_f2, 0)); + _mm256_storeu_ps(p0 + out_hstep * 3, _mm512_extractf32x8_ps(_f3, 0)); + _mm256_storeu_ps(p0 + out_hstep * 4, _mm512_extractf32x8_ps(_f4, 0)); + _mm256_storeu_ps(p0 + out_hstep * 5, _mm512_extractf32x8_ps(_f5, 0)); + _mm256_storeu_ps(p0 + out_hstep * 6, _mm512_extractf32x8_ps(_f6, 0)); + _mm256_storeu_ps(p0 + out_hstep * 7, _mm512_extractf32x8_ps(_f7, 0)); + _mm256_storeu_ps(p0 + out_hstep * 8, _mm512_extractf32x8_ps(_f0, 1)); + _mm256_storeu_ps(p0 + out_hstep * 9, _mm512_extractf32x8_ps(_f1, 1)); + _mm256_storeu_ps(p0 + out_hstep * 10, _mm512_extractf32x8_ps(_f2, 1)); + _mm256_storeu_ps(p0 + out_hstep * 11, _mm512_extractf32x8_ps(_f3, 1)); + _mm256_storeu_ps(p0 + out_hstep * 12, _mm512_extractf32x8_ps(_f4, 1)); + _mm256_storeu_ps(p0 + out_hstep * 13, _mm512_extractf32x8_ps(_f5, 1)); + _mm256_storeu_ps(p0 + out_hstep * 14, _mm512_extractf32x8_ps(_f6, 1)); + _mm256_storeu_ps(p0 + out_hstep * 15, _mm512_extractf32x8_ps(_f7, 1)); + } + p0 += out_hstep * 16; + } + else + { + if (out_elempack == 8) + { + _mm256_storeu_ps(p0, _mm512_extractf32x8_ps(_f0, 0)); + _mm256_storeu_ps(p0 + 8, _mm512_extractf32x8_ps(_f1, 0)); + _mm256_storeu_ps(p0 + 16, _mm512_extractf32x8_ps(_f2, 0)); + _mm256_storeu_ps(p0 + 24, _mm512_extractf32x8_ps(_f3, 0)); + _mm256_storeu_ps(p0 + 32, _mm512_extractf32x8_ps(_f4, 0)); + _mm256_storeu_ps(p0 + 40, _mm512_extractf32x8_ps(_f5, 0)); + _mm256_storeu_ps(p0 + 48, _mm512_extractf32x8_ps(_f6, 0)); + _mm256_storeu_ps(p0 + 56, _mm512_extractf32x8_ps(_f7, 0)); + _mm256_storeu_ps(p0 + 64, _mm512_extractf32x8_ps(_f0, 1)); + _mm256_storeu_ps(p0 + 64 + 8, _mm512_extractf32x8_ps(_f1, 1)); + _mm256_storeu_ps(p0 + 64 + 16, _mm512_extractf32x8_ps(_f2, 1)); + _mm256_storeu_ps(p0 + 64 + 24, _mm512_extractf32x8_ps(_f3, 1)); + _mm256_storeu_ps(p0 + 64 + 32, _mm512_extractf32x8_ps(_f4, 1)); + 
_mm256_storeu_ps(p0 + 64 + 40, _mm512_extractf32x8_ps(_f5, 1)); + _mm256_storeu_ps(p0 + 64 + 48, _mm512_extractf32x8_ps(_f6, 1)); + _mm256_storeu_ps(p0 + 64 + 56, _mm512_extractf32x8_ps(_f7, 1)); + p0 += 128; + } + if (out_elempack == 4) + { + // 00 10 20 30 40 50 60 70 08 18 28 38 48 58 68 78 + // 01 11 21 31 41 51 61 71 09 19 29 39 49 59 69 79 + // 02 12 22 32 42 52 62 72 0a 1a 2a 3a 4a 5a 6a 7a + // 03 13 23 33 43 53 63 73 0b 1b 2b 3b 4b 5b 6b 7b + // 04 14 24 34 44 54 64 74 0c 1c 2c 3c 4c 5c 6c 7c + // 05 15 25 35 45 55 65 75 0d 1d 2d 3d 4d 5d 6d 7d + // 06 16 26 36 46 56 66 76 0e 1e 2e 3e 4e 5e 6e 7e + // 07 17 27 37 47 57 67 77 0f 1f 2f 3f 4f 5f 6f 7f + + __m512 _tmp0 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _tmp4 = _mm512_shuffle_f32x4(_f0, _f1, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp5 = _mm512_shuffle_f32x4(_f2, _f3, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp6 = _mm512_shuffle_f32x4(_f4, _f5, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _tmp7 = _mm512_shuffle_f32x4(_f6, _f7, _MM_SHUFFLE(3, 1, 3, 1)); + + // 00 08 01 09 + // 02 0a 03 0b + // 04 0c 05 0d + // 06 0e 07 0f + // 40 48 41 49 + + _f0 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _f1 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _f2 = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _f3 = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _f4 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _f5 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _f6 = _mm512_shuffle_f32x4(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _f7 = _mm512_shuffle_f32x4(_tmp6, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + _mm512_storeu_ps(p0 + 32, _f2); 
+ _mm512_storeu_ps(p0 + 48, _f3); + _mm512_storeu_ps(p0 + out_hstep * 4, _f4); + _mm512_storeu_ps(p0 + out_hstep * 4 + 16, _f5); + _mm512_storeu_ps(p0 + out_hstep * 4 + 32, _f6); + _mm512_storeu_ps(p0 + out_hstep * 4 + 48, _f7); + p0 += 64; + } + if (out_elempack == 1) + { + transpose16x8_ps(_f0, _f1, _f2, _f3, _f4, _f5, _f6, _f7); + + _mm256_storeu_ps(p0, _mm512_extractf32x8_ps(_f0, 0)); + _mm256_storeu_ps(p0 + 8, _mm512_extractf32x8_ps(_f4, 0)); + _mm256_storeu_ps(p0 + out_hstep, _mm512_extractf32x8_ps(_f0, 1)); + _mm256_storeu_ps(p0 + out_hstep + 8, _mm512_extractf32x8_ps(_f4, 1)); + _mm256_storeu_ps(p0 + out_hstep * 2, _mm512_extractf32x8_ps(_f1, 0)); + _mm256_storeu_ps(p0 + out_hstep * 2 + 8, _mm512_extractf32x8_ps(_f5, 0)); + _mm256_storeu_ps(p0 + out_hstep * 3, _mm512_extractf32x8_ps(_f1, 1)); + _mm256_storeu_ps(p0 + out_hstep * 3 + 8, _mm512_extractf32x8_ps(_f5, 1)); + _mm256_storeu_ps(p0 + out_hstep * 4, _mm512_extractf32x8_ps(_f2, 0)); + _mm256_storeu_ps(p0 + out_hstep * 4 + 8, _mm512_extractf32x8_ps(_f6, 0)); + _mm256_storeu_ps(p0 + out_hstep * 5, _mm512_extractf32x8_ps(_f2, 1)); + _mm256_storeu_ps(p0 + out_hstep * 5 + 8, _mm512_extractf32x8_ps(_f6, 1)); + _mm256_storeu_ps(p0 + out_hstep * 6, _mm512_extractf32x8_ps(_f3, 0)); + _mm256_storeu_ps(p0 + out_hstep * 6 + 8, _mm512_extractf32x8_ps(_f7, 0)); + _mm256_storeu_ps(p0 + out_hstep * 7, _mm512_extractf32x8_ps(_f3, 1)); + _mm256_storeu_ps(p0 + out_hstep * 7 + 8, _mm512_extractf32x8_ps(_f7, 1)); + p0 += 16; + } + } + } +#endif // __AVX512F__ for (; jj + 7 < max_jj; jj += 8) { #if __AVX2__ @@ -3389,19 +8547,31 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& } __m128 _descale = _mm_loadu_ps((const float*)descales + i + ii); +#if __AVX512F__ + __m512 _descale_avx512 = _mm512_broadcast_f32x4(_descale); +#endif __m128 _c0; +#if __AVX512F__ + __m512 _c0_avx512; +#endif if (pC) { if (broadcast_type_C == 0) { _c0 = _mm_set1_ps(pC[0] * beta); +#if __AVX512F__ + _c0_avx512 = 
_mm512_set1_ps(pC[0] * beta); +#endif } if (broadcast_type_C == 1 || broadcast_type_C == 2) { pC = (const float*)C + i + ii; _c0 = _mm_loadu_ps(pC); _c0 = _mm_mul_ps(_c0, _mm_set1_ps(beta)); +#if __AVX512F__ + _c0_avx512 = _mm512_broadcast_f32x4(_c0); +#endif } if (broadcast_type_C == 3) { @@ -3415,6 +8585,281 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& int jj = 0; #if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + __m512i _sum0 = _mm512_loadu_si512((const __m512i*)pp); + __m512i _sum1 = _mm512_loadu_si512((const __m512i*)(pp + 16)); + __m512i _sum2 = _mm512_loadu_si512((const __m512i*)(pp + 32)); + __m512i _sum3 = _mm512_loadu_si512((const __m512i*)(pp + 48)); + + // from + // 00 11 22 33 04 15 26 37 08 19 2a 3b 0c 1d 2e 3f + // 01 12 23 30 05 16 27 34 09 1a 2b 38 0d 1e 2f 3c + // 20 31 02 13 24 35 06 17 28 39 0a 1b 2c 3d 0e 1f + // 21 32 03 10 25 36 07 14 29 3a 0b 18 2d 3e 0f 1c + // to + // 00 10 20 30 04 14 24 34 08 18 28 38 0c 1c 2c 3c + // 01 11 21 31 05 15 25 35 09 19 29 39 0d 1d 2d 3d + // 02 12 22 32 06 16 26 36 0a 1a 2a 3a 0e 1e 2e 3e + // 03 13 23 33 07 17 27 37 0b 1b 2b 3b 0f 1f 2f 3f + { + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum3); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum3); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum1); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, _sum1); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm512_unpackhi_epi64(_tmp3, _tmp1); + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + } + + __m512 _f0 = _mm512_mul_ps(_mm512_cvtepi32_ps(_sum0), _descale_avx512); + __m512 _f1 = _mm512_mul_ps(_mm512_cvtepi32_ps(_sum1), _descale_avx512); + __m512 
_f2 = _mm512_mul_ps(_mm512_cvtepi32_ps(_sum2), _descale_avx512); + __m512 _f3 = _mm512_mul_ps(_mm512_cvtepi32_ps(_sum3), _descale_avx512); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c0_avx512); + _f2 = _mm512_add_ps(_f2, _c0_avx512); + _f3 = _mm512_add_ps(_f3, _c0_avx512); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c0_avx512); + _f2 = _mm512_add_ps(_f2, _c0_avx512); + _f3 = _mm512_add_ps(_f3, _c0_avx512); + } + if (broadcast_type_C == 3) + { + // __m128 _c1; + // __m128 _c2; + // __m128 _c3; + // if (c_elempack == 4) + // { + // _c0 = _mm_loadu_ps(pC); + // _c1 = _mm_loadu_ps(pC + 4); + // _c2 = _mm_loadu_ps(pC + 8); + // _c3 = _mm_loadu_ps(pC + 12); + // } + // if (c_elempack == 1) + // { + // _c0 = _mm_loadu_ps(pC); + // _c1 = _mm_loadu_ps(pC + c_hstep); + // _c2 = _mm_loadu_ps(pC + c_hstep * 2); + // _c3 = _mm_loadu_ps(pC + c_hstep * 3); + // _MM_TRANSPOSE4_PS(_c0, _c1, _c2, _c3); + // } + // if (beta == 1.f) + // { + // _f0 = _mm_add_ps(_f0, _c0); + // _f1 = _mm_add_ps(_f1, _c1); + // _f2 = _mm_add_ps(_f2, _c2); + // _f3 = _mm_add_ps(_f3, _c3); + // } + // else + // { + // __m128 _beta = _mm_set1_ps(beta); + // _f0 = _mm_comp_fmadd_ps(_c0, _beta, _f0); + // _f1 = _mm_comp_fmadd_ps(_c1, _beta, _f1); + // _f2 = _mm_comp_fmadd_ps(_c2, _beta, _f2); + // _f3 = _mm_comp_fmadd_ps(_c3, _beta, _f3); + // } + // if (c_elempack == 4) + // { + // _c0 = _mm_loadu_ps(pC + 16); + // _c1 = _mm_loadu_ps(pC + 20); + // _c2 = _mm_loadu_ps(pC + 24); + // _c3 = _mm_loadu_ps(pC + 28); + // pC += 64; + // } + // if (c_elempack == 1) + // { + // _c0 = _mm_loadu_ps(pC + 4); + // _c1 = _mm_loadu_ps(pC + c_hstep + 4); + // _c2 = _mm_loadu_ps(pC + c_hstep * 2 + 4); + // _c3 = _mm_loadu_ps(pC + c_hstep * 3 + 4); + // _MM_TRANSPOSE4_PS(_c0, _c1, _c2, _c3); + // pC += 16; + // } + // if (beta == 1.f) + // { + // _f4 = _mm_add_ps(_f4, 
_c0); + // _f5 = _mm_add_ps(_f5, _c1); + // _f6 = _mm_add_ps(_f6, _c2); + // _f7 = _mm_add_ps(_f7, _c3); + // } + // else + // { + // __m128 _beta = _mm_set1_ps(beta); + // _f4 = _mm_comp_fmadd_ps(_c0, _beta, _f4); + // _f5 = _mm_comp_fmadd_ps(_c1, _beta, _f5); + // _f6 = _mm_comp_fmadd_ps(_c2, _beta, _f6); + // _f7 = _mm_comp_fmadd_ps(_c3, _beta, _f7); + // } + } + if (broadcast_type_C == 4) + { + // _c0 = _mm_set1_ps(pC[0] * beta); + // __m128 _c1 = _mm_set1_ps(pC[1] * beta); + // __m128 _c2 = _mm_set1_ps(pC[2] * beta); + // __m128 _c3 = _mm_set1_ps(pC[3] * beta); + // + // _f0 = _mm_add_ps(_f0, _c0); + // _f1 = _mm_add_ps(_f1, _c1); + // _f2 = _mm_add_ps(_f2, _c2); + // _f3 = _mm_add_ps(_f3, _c3); + // + // _c0 = _mm_set1_ps(pC[4] * beta); + // _c1 = _mm_set1_ps(pC[5] * beta); + // _c2 = _mm_set1_ps(pC[6] * beta); + // _c3 = _mm_set1_ps(pC[7] * beta); + // + // _f4 = _mm_add_ps(_f4, _c0); + // _f5 = _mm_add_ps(_f5, _c1); + // _f6 = _mm_add_ps(_f6, _c2); + // _f7 = _mm_add_ps(_f7, _c3); + // pC += 16; + } + } + + if (alpha != 1.f) + { + __m512 _alpha = _mm512_set1_ps(alpha); + _f0 = _mm512_mul_ps(_f0, _alpha); + _f1 = _mm512_mul_ps(_f1, _alpha); + _f2 = _mm512_mul_ps(_f2, _alpha); + _f3 = _mm512_mul_ps(_f3, _alpha); + } + + if (output_transpose) + { + if (out_elempack == 16) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + _mm_storeu_ps(p0, _mm512_extractf32x4_ps(_f0, 0)); + _mm_storeu_ps(p0 + 4, _mm512_extractf32x4_ps(_f1, 0)); + _mm_storeu_ps(p0 + 8, _mm512_extractf32x4_ps(_f2, 0)); + _mm_storeu_ps(p0 + 12, _mm512_extractf32x4_ps(_f3, 0)); + _mm_storeu_ps(p0 + 16, _mm512_extractf32x4_ps(_f0, 1)); + _mm_storeu_ps(p0 + 16 + 4, _mm512_extractf32x4_ps(_f1, 1)); + _mm_storeu_ps(p0 + 16 + 8, _mm512_extractf32x4_ps(_f2, 1)); + _mm_storeu_ps(p0 + 16 + 12, _mm512_extractf32x4_ps(_f3, 1)); + _mm_storeu_ps(p0 + 32, _mm512_extractf32x4_ps(_f0, 2)); + _mm_storeu_ps(p0 + 32 + 4, _mm512_extractf32x4_ps(_f1, 2)); + _mm_storeu_ps(p0 + 32 + 8, _mm512_extractf32x4_ps(_f2, 2)); + 
_mm_storeu_ps(p0 + 32 + 12, _mm512_extractf32x4_ps(_f3, 2)); + _mm_storeu_ps(p0 + 48, _mm512_extractf32x4_ps(_f0, 3)); + _mm_storeu_ps(p0 + 48 + 4, _mm512_extractf32x4_ps(_f1, 3)); + _mm_storeu_ps(p0 + 48 + 8, _mm512_extractf32x4_ps(_f2, 3)); + _mm_storeu_ps(p0 + 48 + 12, _mm512_extractf32x4_ps(_f3, 3)); + } + if (out_elempack == 8) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + _mm_storeu_ps(p0, _mm512_extractf32x4_ps(_f0, 0)); + _mm_storeu_ps(p0 + 4, _mm512_extractf32x4_ps(_f1, 0)); + _mm_storeu_ps(p0 + 8, _mm512_extractf32x4_ps(_f0, 1)); + _mm_storeu_ps(p0 + 12, _mm512_extractf32x4_ps(_f1, 1)); + _mm_storeu_ps(p0 + 16, _mm512_extractf32x4_ps(_f0, 2)); + _mm_storeu_ps(p0 + 16 + 4, _mm512_extractf32x4_ps(_f1, 2)); + _mm_storeu_ps(p0 + 16 + 8, _mm512_extractf32x4_ps(_f0, 3)); + _mm_storeu_ps(p0 + 16 + 12, _mm512_extractf32x4_ps(_f1, 3)); + _mm_storeu_ps(p0 + out_hstep * 8, _mm512_extractf32x4_ps(_f2, 0)); + _mm_storeu_ps(p0 + out_hstep * 8 + 4, _mm512_extractf32x4_ps(_f3, 0)); + _mm_storeu_ps(p0 + out_hstep * 8 + 8, _mm512_extractf32x4_ps(_f2, 1)); + _mm_storeu_ps(p0 + out_hstep * 8 + 12, _mm512_extractf32x4_ps(_f3, 1)); + _mm_storeu_ps(p0 + out_hstep * 8 + 16, _mm512_extractf32x4_ps(_f2, 2)); + _mm_storeu_ps(p0 + out_hstep * 8 + 16 + 4, _mm512_extractf32x4_ps(_f3, 2)); + _mm_storeu_ps(p0 + out_hstep * 8 + 16 + 8, _mm512_extractf32x4_ps(_f2, 3)); + _mm_storeu_ps(p0 + out_hstep * 8 + 16 + 12, _mm512_extractf32x4_ps(_f3, 3)); + } + if (out_elempack == 4) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + out_hstep * 4, _f1); + _mm512_storeu_ps(p0 + out_hstep * 8, _f2); + _mm512_storeu_ps(p0 + out_hstep * 12, _f3); + } + if (out_elempack == 1) + { + _mm_storeu_ps(p0, _mm512_extractf32x4_ps(_f0, 0)); + _mm_storeu_ps(p0 + out_hstep, _mm512_extractf32x4_ps(_f1, 0)); + _mm_storeu_ps(p0 + out_hstep * 2, _mm512_extractf32x4_ps(_f2, 0)); + _mm_storeu_ps(p0 + out_hstep * 3, _mm512_extractf32x4_ps(_f3, 0)); + _mm_storeu_ps(p0 + 
out_hstep * 4, _mm512_extractf32x4_ps(_f0, 1)); + _mm_storeu_ps(p0 + out_hstep * 5, _mm512_extractf32x4_ps(_f1, 1)); + _mm_storeu_ps(p0 + out_hstep * 6, _mm512_extractf32x4_ps(_f2, 1)); + _mm_storeu_ps(p0 + out_hstep * 7, _mm512_extractf32x4_ps(_f3, 1)); + _mm_storeu_ps(p0 + out_hstep * 8, _mm512_extractf32x4_ps(_f0, 2)); + _mm_storeu_ps(p0 + out_hstep * 9, _mm512_extractf32x4_ps(_f1, 2)); + _mm_storeu_ps(p0 + out_hstep * 10, _mm512_extractf32x4_ps(_f2, 2)); + _mm_storeu_ps(p0 + out_hstep * 11, _mm512_extractf32x4_ps(_f3, 2)); + _mm_storeu_ps(p0 + out_hstep * 12, _mm512_extractf32x4_ps(_f0, 3)); + _mm_storeu_ps(p0 + out_hstep * 13, _mm512_extractf32x4_ps(_f1, 3)); + _mm_storeu_ps(p0 + out_hstep * 14, _mm512_extractf32x4_ps(_f2, 3)); + _mm_storeu_ps(p0 + out_hstep * 15, _mm512_extractf32x4_ps(_f3, 3)); + } + p0 += out_hstep * 16; + } + else + { + if (out_elempack == 4) + { + _mm_storeu_ps(p0, _mm512_extractf32x4_ps(_f0, 0)); + _mm_storeu_ps(p0 + 4, _mm512_extractf32x4_ps(_f1, 0)); + _mm_storeu_ps(p0 + 8, _mm512_extractf32x4_ps(_f2, 0)); + _mm_storeu_ps(p0 + 12, _mm512_extractf32x4_ps(_f3, 0)); + _mm_storeu_ps(p0 + 16, _mm512_extractf32x4_ps(_f0, 1)); + _mm_storeu_ps(p0 + 20, _mm512_extractf32x4_ps(_f1, 1)); + _mm_storeu_ps(p0 + 24, _mm512_extractf32x4_ps(_f2, 1)); + _mm_storeu_ps(p0 + 28, _mm512_extractf32x4_ps(_f3, 1)); + _mm_storeu_ps(p0 + 32, _mm512_extractf32x4_ps(_f0, 2)); + _mm_storeu_ps(p0 + 36, _mm512_extractf32x4_ps(_f1, 2)); + _mm_storeu_ps(p0 + 40, _mm512_extractf32x4_ps(_f2, 2)); + _mm_storeu_ps(p0 + 44, _mm512_extractf32x4_ps(_f3, 2)); + _mm_storeu_ps(p0 + 48, _mm512_extractf32x4_ps(_f0, 3)); + _mm_storeu_ps(p0 + 52, _mm512_extractf32x4_ps(_f1, 3)); + _mm_storeu_ps(p0 + 56, _mm512_extractf32x4_ps(_f2, 3)); + _mm_storeu_ps(p0 + 60, _mm512_extractf32x4_ps(_f3, 3)); + p0 += 64; + } + if (out_elempack == 1) + { + transpose16x4_ps(_f0, _f1, _f2, _f3); + _mm_storeu_ps(p0, _mm512_extractf32x4_ps(_f0, 0)); + _mm_storeu_ps(p0 + 4, _mm512_extractf32x4_ps(_f1, 
0)); + _mm_storeu_ps(p0 + 8, _mm512_extractf32x4_ps(_f2, 0)); + _mm_storeu_ps(p0 + 12, _mm512_extractf32x4_ps(_f3, 0)); + _mm_storeu_ps(p0 + out_hstep, _mm512_extractf32x4_ps(_f0, 1)); + _mm_storeu_ps(p0 + out_hstep + 4, _mm512_extractf32x4_ps(_f1, 1)); + _mm_storeu_ps(p0 + out_hstep + 8, _mm512_extractf32x4_ps(_f2, 1)); + _mm_storeu_ps(p0 + out_hstep + 12, _mm512_extractf32x4_ps(_f3, 1)); + _mm_storeu_ps(p0 + out_hstep * 2, _mm512_extractf32x4_ps(_f0, 2)); + _mm_storeu_ps(p0 + out_hstep * 2 + 4, _mm512_extractf32x4_ps(_f1, 2)); + _mm_storeu_ps(p0 + out_hstep * 2 + 8, _mm512_extractf32x4_ps(_f2, 2)); + _mm_storeu_ps(p0 + out_hstep * 2 + 12, _mm512_extractf32x4_ps(_f3, 2)); + _mm_storeu_ps(p0 + out_hstep * 3, _mm512_extractf32x4_ps(_f0, 3)); + _mm_storeu_ps(p0 + out_hstep * 3 + 4, _mm512_extractf32x4_ps(_f1, 3)); + _mm_storeu_ps(p0 + out_hstep * 3 + 8, _mm512_extractf32x4_ps(_f2, 3)); + _mm_storeu_ps(p0 + out_hstep * 3 + 12, _mm512_extractf32x4_ps(_f3, 3)); + p0 += 16; + } + } + + pp += 64; + } +#endif // __AVX512F__ for (; jj + 7 < max_jj; jj += 8) { __m128i _sum0 = _mm_loadu_si128((const __m128i*)pp); @@ -4043,6 +9488,10 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& #if __SSE2__ __m128 _descale0 = _mm_set1_ps(descale0); __m128 _descale1 = _mm_set1_ps(descale1); +#if __AVX512F__ + __m512 _descale0_avx512 = _mm512_set1_ps(descale0); + __m512 _descale1_avx512 = _mm512_set1_ps(descale1); +#endif // __AVX512F__ #endif float c0; @@ -4050,6 +9499,10 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& #if __SSE2__ __m128 _c0; __m128 _c1; +#if __AVX512F__ + __m512 _c0_avx512; + __m512 _c1_avx512; +#endif // __AVX512F__ #endif if (pC) { @@ -4058,6 +9511,9 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& c0 = pC[0] * beta; #if __SSE2__ _c0 = _mm_set1_ps(c0); +#if __AVX512F__ + _c0_avx512 = _mm512_set1_ps(c0); +#endif // __AVX512F__ #endif } if (broadcast_type_C == 1 || 
broadcast_type_C == 2) @@ -4068,6 +9524,10 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& #if __SSE2__ _c0 = _mm_set1_ps(c0); _c1 = _mm_set1_ps(c1); +#if __AVX512F__ + _c0_avx512 = _mm512_set1_ps(c0); + _c1_avx512 = _mm512_set1_ps(c1); +#endif // __AVX512F__ #endif } if (broadcast_type_C == 3) @@ -4075,15 +9535,125 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& // c_elempack == 1 pC = (const float*)C + (i + ii) * c_hstep + j; } - if (broadcast_type_C == 4) + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + __m512i _sum0 = _mm512_loadu_si512((const __m512i*)pp); + __m512i _sum1 = _mm512_loadu_si512((const __m512i*)(pp + 16)); + + // 00 11 02 13 04 15 06 17 08 19 0a 1b 0c 1d 0e 1f + // 01 12 03 10 05 16 07 14 09 1a 0b 18 0d 1e 0f 1c + + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum1); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum1); + + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp1); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp1); + + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + + __m512 _f0 = _mm512_mul_ps(_mm512_cvtepi32_ps(_sum0), _descale0_avx512); + __m512 _f1 = _mm512_mul_ps(_mm512_cvtepi32_ps(_sum1), _descale1_avx512); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c0_avx512); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c1_avx512); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0_avx512 = _mm512_loadu_ps(pC); + _c1_avx512 = _mm512_loadu_ps(pC + c_hstep); + if (beta == 1.f) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c1_avx512); + } + else + { + __m512 _beta = _mm512_set1_ps(beta); + _f0 = _mm512_fmadd_ps(_c0_avx512, _beta, _f0); 
+ _f1 = _mm512_fmadd_ps(_c1_avx512, _beta, _f1); + } + pC += 16; // each row consumed 16 floats via the 512-bit loads above + } + if (broadcast_type_C == 4) + { + _c0_avx512 = _mm512_loadu_ps(pC); + _c0_avx512 = _mm512_mul_ps(_c0_avx512, _mm512_set1_ps(beta)); + _f0 = _mm512_add_ps(_f0, _c0_avx512); + _f1 = _mm512_add_ps(_f1, _c0_avx512); + pC += 16; + } + } + + if (alpha != 1.f) + { + __m512 _alpha = _mm512_set1_ps(alpha); + _f0 = _mm512_mul_ps(_f0, _alpha); + _f1 = _mm512_mul_ps(_f1, _alpha); + } + + if (output_transpose) + { + if (out_elempack == 16) + { + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + 16, _f1); + } + if (out_elempack == 8) + { + _mm256_storeu_ps(p0, _mm512_extractf32x8_ps(_f0, 0)); + _mm256_storeu_ps(p0 + 8, _mm512_extractf32x8_ps(_f1, 0)); + _mm256_storeu_ps(p0 + out_hstep * 8, _mm512_extractf32x8_ps(_f0, 1)); + _mm256_storeu_ps(p0 + out_hstep * 8 + 8, _mm512_extractf32x8_ps(_f1, 1)); + } + if (out_elempack == 4) + { + _mm_storeu_ps(p0, _mm512_extractf32x4_ps(_f0, 0)); + _mm_storeu_ps(p0 + 4, _mm512_extractf32x4_ps(_f1, 0)); + _mm_storeu_ps(p0 + out_hstep * 4, _mm512_extractf32x4_ps(_f0, 1)); + _mm_storeu_ps(p0 + out_hstep * 4 + 4, _mm512_extractf32x4_ps(_f1, 1)); + _mm_storeu_ps(p0 + out_hstep * 8, _mm512_extractf32x4_ps(_f0, 2)); + _mm_storeu_ps(p0 + out_hstep * 8 + 4, _mm512_extractf32x4_ps(_f1, 2)); + _mm_storeu_ps(p0 + out_hstep * 12, _mm512_extractf32x4_ps(_f0, 3)); + _mm_storeu_ps(p0 + out_hstep * 12 + 4, _mm512_extractf32x4_ps(_f1, 3)); + } + if (out_elempack == 1) + { + __m512i _vindex = _mm512_mullo_epi32(_mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), _mm512_set1_epi32(out_hstep)); + _mm512_i32scatter_ps(p0, _vindex, _f0, sizeof(float)); + _mm512_i32scatter_ps(p0 + 1, _vindex, _f1, sizeof(float)); + } + p0 += out_hstep * 16; + } + else { - pC = (const float*)C + j; + _mm512_storeu_ps(p0, _f0); + _mm512_storeu_ps(p0 + out_hstep, _f1); + p0 += 16; } - } - int jj = 0; -#if __SSE2__ -#if defined(__x86_64__) || defined(_M_X64) + pp += 32; + } +#endif // __AVX512F__
for (; jj + 7 < max_jj; jj += 8) { __m128i _sum0 = _mm_loadu_si128((const __m128i*)pp); @@ -4493,11 +10063,17 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& const float descale = descales[i + ii]; #if __SSE2__ __m128 _descale = _mm_set1_ps(descale); +#if __AVX512F__ + __m512 _descale_avx512 = _mm512_set1_ps(descale); +#endif // __AVX512F__ #endif float c0; #if __SSE2__ __m128 _c0; +#if __AVX512F__ + __m512 _c0_avx512; +#endif // __AVX512F__ #endif if (pC) { @@ -4506,6 +10082,9 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& c0 = pC[0] * beta; #if __SSE2__ _c0 = _mm_set1_ps(c0); +#if __AVX512F__ + _c0_avx512 = _mm512_set1_ps(c0); +#endif // __AVX512F__ #endif } if (broadcast_type_C == 1 || broadcast_type_C == 2) @@ -4514,6 +10093,9 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& c0 = pC[0] * beta; #if __SSE2__ _c0 = _mm_set1_ps(c0); +#if __AVX512F__ + _c0_avx512 = _mm512_set1_ps(c0); +#endif // __AVX512F__ #endif } if (broadcast_type_C == 3) @@ -4530,6 +10112,72 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& int jj = 0; #if __SSE2__ #if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + __m512 _f0 = _mm512_mul_ps(_mm512_cvtepi32_ps(_mm512_loadu_si512((const __m512i*)pp)), _descale_avx512); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = _mm512_add_ps(_f0, _c0_avx512); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + _c0_avx512 = _mm512_loadu_ps(pC); + _f0 = _mm512_fmadd_ps(_c0_avx512, _mm512_set1_ps(beta), _f0); + pC += 16; + } + } + + if (alpha != 1.f) + { + _f0 = _mm512_mul_ps(_f0, _mm512_set1_ps(alpha)); + } + + if (output_transpose) + { + if (out_hstep == 1) + { + _mm512_storeu_ps(p0, _f0); + } + else + { + if (out_elempack == 16) + { + _mm512_storeu_ps(p0, _f0); + } + if 
(out_elempack == 8) + { + _mm256_storeu_ps(p0, _mm512_extractf32x8_ps(_f0, 0)); + _mm256_storeu_ps(p0 + out_hstep * 8, _mm512_extractf32x8_ps(_f0, 1)); + } + if (out_elempack == 4) + { + _mm_storeu_ps(p0, _mm512_extractf32x4_ps(_f0, 0)); + _mm_storeu_ps(p0 + out_hstep * 4, _mm512_extractf32x4_ps(_f0, 1)); + _mm_storeu_ps(p0 + out_hstep * 8, _mm512_extractf32x4_ps(_f0, 2)); + _mm_storeu_ps(p0 + out_hstep * 12, _mm512_extractf32x4_ps(_f0, 3)); + } + if (out_elempack == 1) + { + __m512i _vindex = _mm512_mullo_epi32(_mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), _mm512_set1_epi32(out_hstep)); + _mm512_i32scatter_ps(p0, _vindex, _f0, sizeof(float)); + } + } + p0 += out_hstep * 16; + } + else + { + _mm512_storeu_ps(p0, _f0); + p0 += 16; + } + + pp += 16; + } +#endif // __AVX512F__ for (; jj + 7 < max_jj; jj += 8) { __m128 _f0 = _mm_mul_ps(_mm_cvtepi32_ps(_mm_loadu_si128((const __m128i*)pp)), _descale); @@ -4689,91 +10337,755 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& f0 *= alpha; f1 *= alpha; - if (output_transpose) + if (output_transpose) + { + p0[0] = f0; + p0[out_hstep] = f1; + p0 += out_hstep * 2; + } + else + { + p0[0] = f0; + p0[1] = f1; + p0 += 2; + } + + pp += 2; + } + for (; jj < max_jj; jj++) + { + float f0 = pp[0] * descale; + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + f0 += pC[0] * beta; + pC += 1; + } + } + + f0 *= alpha; + + p0[0] = f0; + + if (output_transpose) + { + p0 += out_hstep; + } + else + { + p0++; + } + + pp += 1; + } + } +} + +static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx2()) + { + 
gemm_transB_packed_tile_int8_avx2(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_xop()) + { + gemm_transB_packed_tile_int8_xop(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); + return; + } +#endif + + NCNN_LOGE("gemm_transB_packed_tile_int8 %d %d %d %d %d %d", i, max_ii, j, max_jj, k, max_kk); + + const signed char* pAT = AT_tile; + const signed char* pBT = BT_tile; + + int* outptr = topT_tile; + + int ii = 0; +#if __SSE2__ +#if __AVX2__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { + const signed char* pB = pBT; + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) + for (; jj + 15 < max_jj; jj += 16) + { + const signed char* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + __m512i _sum4; + __m512i _sum5; + __m512i _sum6; + __m512i _sum7; + __m512i _sum8; + __m512i _sum9; + __m512i _suma; + __m512i _sumb; + __m512i _sumc; + __m512i _sumd; + __m512i _sume; + __m512i _sumf; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + _sum4 = _mm512_setzero_si512(); + _sum5 = _mm512_setzero_si512(); + _sum6 = _mm512_setzero_si512(); + _sum7 = _mm512_setzero_si512(); + _sum8 = _mm512_setzero_si512(); + _sum9 = _mm512_setzero_si512(); + _suma = _mm512_setzero_si512(); + _sumb = _mm512_setzero_si512(); + _sumc = _mm512_setzero_si512(); + _sumd = _mm512_setzero_si512(); + _sume = _mm512_setzero_si512(); + _sumf = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + _sum1 = _mm512_load_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_load_si512((const __m512i*)(outptr + 32)); + _sum3 = _mm512_load_si512((const __m512i*)(outptr + 48)); + _sum4 = _mm512_load_si512((const __m512i*)(outptr + 64)); + 
_sum5 = _mm512_load_si512((const __m512i*)(outptr + 80)); + _sum6 = _mm512_load_si512((const __m512i*)(outptr + 96)); + _sum7 = _mm512_load_si512((const __m512i*)(outptr + 112)); + _sum8 = _mm512_load_si512((const __m512i*)(outptr + 128)); + _sum9 = _mm512_load_si512((const __m512i*)(outptr + 128 + 16)); + _suma = _mm512_load_si512((const __m512i*)(outptr + 128 + 32)); + _sumb = _mm512_load_si512((const __m512i*)(outptr + 128 + 48)); + _sumc = _mm512_load_si512((const __m512i*)(outptr + 128 + 64)); + _sumd = _mm512_load_si512((const __m512i*)(outptr + 128 + 80)); + _sume = _mm512_load_si512((const __m512i*)(outptr + 128 + 96)); + _sumf = _mm512_load_si512((const __m512i*)(outptr + 128 + 112)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + + __m512i _pA0 = _mm512_cvtepi8_epi16(_pA); + __m512i _pB0 = _mm512_cvtepi8_epi16(_pB); + + // 0123 4567 89ab cdef + // 2301 6745 ab89 efcd + // 4567 0123 cdef 89ab + // 6745 2301 efcd ab89 + __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + __m512i _pA2 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512i _pA3 = _mm512_shuffle_epi32(_pA2, _MM_PERM_BADC); + + // 0123 4567 89ab cdef + // 1230 5674 9ab8 defc + // 89ab cdef 0123 4567 + // 9ab8 defc 1230 5674 + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_shuffle_i32x4(_pB0, _pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB3 = _mm512_shuffle_epi32(_pB2, _MM_PERM_ADCB); + + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); + _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA1, _pB0)); + _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA1, _pB1)); + _sum4 = _mm512_add_epi32(_sum4, _mm512_madd_epi16(_pA0, _pB2)); + _sum5 = _mm512_add_epi32(_sum5, _mm512_madd_epi16(_pA0, _pB3)); + _sum6 = _mm512_add_epi32(_sum6, 
_mm512_madd_epi16(_pA1, _pB2)); + _sum7 = _mm512_add_epi32(_sum7, _mm512_madd_epi16(_pA1, _pB3)); + _sum8 = _mm512_add_epi32(_sum8, _mm512_madd_epi16(_pA2, _pB0)); + _sum9 = _mm512_add_epi32(_sum9, _mm512_madd_epi16(_pA2, _pB1)); + _suma = _mm512_add_epi32(_suma, _mm512_madd_epi16(_pA3, _pB0)); + _sumb = _mm512_add_epi32(_sumb, _mm512_madd_epi16(_pA3, _pB1)); + _sumc = _mm512_add_epi32(_sumc, _mm512_madd_epi16(_pA2, _pB2)); + _sumd = _mm512_add_epi32(_sumd, _mm512_madd_epi16(_pA2, _pB3)); + _sume = _mm512_add_epi32(_sume, _mm512_madd_epi16(_pA3, _pB2)); + _sumf = _mm512_add_epi32(_sumf, _mm512_madd_epi16(_pA3, _pB3)); + + pA += 32; + pB += 32; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_load_si128((const __m128i*)pA); + __m128i _pB = _mm_load_si128((const __m128i*)pB); + + __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); + __m256i _pB0 = _mm256_cvtepi8_epi16(_pB); + + // 01234567 89abcdef + // 23016745 ab89efcd + // 45670123 cdef89ab + // 67452301 efcdab89 + __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + __m256i _pA2 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + __m256i _pA3 = _mm256_shuffle_epi32(_pA2, _MM_SHUFFLE(2, 3, 0, 1)); + + // 01234567 89abcdef + // 12305674 9ab8defc + // 89abcdef 01234567 + // 9ab8defc 12305674 + __m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _pB2 = _mm256_permute4x64_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB3 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB2, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + + __m512i _s0 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0)); + __m512i _s1 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1)); + __m512i _s2 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB0)); + __m512i _s3 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB1)); + __m512i _s4 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB2)); + __m512i _s5 = 
_mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB3)); + __m512i _s6 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB2)); + __m512i _s7 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB3)); + __m512i _s8 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA2, _pB0)); + __m512i _s9 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA2, _pB1)); + __m512i _sa = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA3, _pB0)); + __m512i _sb = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA3, _pB1)); + __m512i _sc = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA2, _pB2)); + __m512i _sd = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA2, _pB3)); + __m512i _se = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA3, _pB2)); + __m512i _sf = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA3, _pB3)); + + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum2 = _mm512_add_epi32(_sum2, _s2); + _sum3 = _mm512_add_epi32(_sum3, _s3); + _sum4 = _mm512_add_epi32(_sum4, _s4); + _sum5 = _mm512_add_epi32(_sum5, _s5); + _sum6 = _mm512_add_epi32(_sum6, _s6); + _sum7 = _mm512_add_epi32(_sum7, _s7); + _sum8 = _mm512_add_epi32(_sum8, _s8); + _sum9 = _mm512_add_epi32(_sum9, _s9); + _suma = _mm512_add_epi32(_suma, _sa); + _sumb = _mm512_add_epi32(_sumb, _sb); + _sumc = _mm512_add_epi32(_sumc, _sc); + _sumd = _mm512_add_epi32(_sumd, _sd); + _sume = _mm512_add_epi32(_sume, _se); + _sumf = _mm512_add_epi32(_sumf, _sf); + + pA += 16; + pB += 16; + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + _mm512_store_si512((__m512i*)(outptr + 16), _sum1); + _mm512_store_si512((__m512i*)(outptr + 32), _sum2); + _mm512_store_si512((__m512i*)(outptr + 48), _sum3); + _mm512_store_si512((__m512i*)(outptr + 64), _sum4); + _mm512_store_si512((__m512i*)(outptr + 80), _sum5); + _mm512_store_si512((__m512i*)(outptr + 96), _sum6); + _mm512_store_si512((__m512i*)(outptr + 112), _sum7); + _mm512_store_si512((__m512i*)(outptr + 128), _sum8); + _mm512_store_si512((__m512i*)(outptr + 128 + 16), _sum9); + 
_mm512_store_si512((__m512i*)(outptr + 128 + 32), _suma); + _mm512_store_si512((__m512i*)(outptr + 128 + 48), _sumb); + _mm512_store_si512((__m512i*)(outptr + 128 + 64), _sumc); + _mm512_store_si512((__m512i*)(outptr + 128 + 80), _sumd); + _mm512_store_si512((__m512i*)(outptr + 128 + 96), _sume); + _mm512_store_si512((__m512i*)(outptr + 128 + 112), _sumf); + outptr += 256; + } + for (; jj + 7 < max_jj; jj += 8) + { + const signed char* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + __m512i _sum4; + __m512i _sum5; + __m512i _sum6; + __m512i _sum7; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + _sum4 = _mm512_setzero_si512(); + _sum5 = _mm512_setzero_si512(); + _sum6 = _mm512_setzero_si512(); + _sum7 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + _sum1 = _mm512_load_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_load_si512((const __m512i*)(outptr + 32)); + _sum3 = _mm512_load_si512((const __m512i*)(outptr + 48)); + _sum4 = _mm512_load_si512((const __m512i*)(outptr + 64)); + _sum5 = _mm512_load_si512((const __m512i*)(outptr + 80)); + _sum6 = _mm512_load_si512((const __m512i*)(outptr + 96)); + _sum7 = _mm512_load_si512((const __m512i*)(outptr + 112)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); + + __m512i _pA0 = _mm512_cvtepi8_epi16(_pA); + __m256i _pBB = _mm256_cvtepi8_epi16(_pB); + + // 0123 4567 89ab cdef + // 2301 6745 ab89 efcd + __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + + // 0123 4567 0123 4567 + // 1230 5674 1230 5674 + // 4567 0123 4567 0123 + // 5674 1230 5674 1230 + __m512i _pB0 = _mm512_inserti32x8(_mm512_castsi256_si512(_pBB), _pBB, 1); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = 
_mm512_permutex_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB3 = _mm512_shuffle_epi32(_pB2, _MM_PERM_ADCB); + + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); + _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA1, _pB0)); + _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA1, _pB1)); + _sum4 = _mm512_add_epi32(_sum4, _mm512_madd_epi16(_pA0, _pB2)); + _sum5 = _mm512_add_epi32(_sum5, _mm512_madd_epi16(_pA0, _pB3)); + _sum6 = _mm512_add_epi32(_sum6, _mm512_madd_epi16(_pA1, _pB2)); + _sum7 = _mm512_add_epi32(_sum7, _mm512_madd_epi16(_pA1, _pB3)); + + pA += 32; + pB += 16; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_load_si128((const __m128i*)pA); + __m128i _pB = _mm_loadl_epi64((const __m128i*)pB); + + __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); + _pB = _mm_cvtepi8_epi16(_pB); + + // 01234567 89abcdef + // 23016745 ab89efcd + __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + + // 01234567 01234567 + // 12305674 12305674 + // 45670123 45670123 + // 56741230 56741230 + __m256i _pB0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_pB), _pB, 1); + __m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB3 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB2, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + + __m512i _s0 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0)); + __m512i _s1 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1)); + __m512i _s2 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB0)); + __m512i _s3 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB1)); + __m512i _s4 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB2)); + __m512i _s5 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB3)); + __m512i _s6 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB2)); + 
__m512i _s7 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB3)); + + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum2 = _mm512_add_epi32(_sum2, _s2); + _sum3 = _mm512_add_epi32(_sum3, _s3); + _sum4 = _mm512_add_epi32(_sum4, _s4); + _sum5 = _mm512_add_epi32(_sum5, _s5); + _sum6 = _mm512_add_epi32(_sum6, _s6); + _sum7 = _mm512_add_epi32(_sum7, _s7); + + pA += 16; + pB += 8; + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + _mm512_store_si512((__m512i*)(outptr + 16), _sum1); + _mm512_store_si512((__m512i*)(outptr + 32), _sum2); + _mm512_store_si512((__m512i*)(outptr + 48), _sum3); + _mm512_store_si512((__m512i*)(outptr + 64), _sum4); + _mm512_store_si512((__m512i*)(outptr + 80), _sum5); + _mm512_store_si512((__m512i*)(outptr + 96), _sum6); + _mm512_store_si512((__m512i*)(outptr + 112), _sum7); + outptr += 128; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const signed char* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + _sum1 = _mm512_load_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_load_si512((const __m512i*)(outptr + 32)); + _sum3 = _mm512_load_si512((const __m512i*)(outptr + 48)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); + __m256i _pB = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)pB)); + + __m512i _pA0 = _mm512_cvtepi8_epi16(_pA); + __m512i _pB0 = _mm512_cvtepi8_epi16(_pB); + + // 0123 4567 89ab cdef + // 2301 6745 ab89 efcd + __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + + // 0123 0123 0123 0123 + // 1230 1230 1230 1230 + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + + _sum0 = 
_mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); + _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA1, _pB0)); + _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA1, _pB1)); + + pA += 32; + pB += 8; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_load_si128((const __m128i*)pA); + __m128i _pB = _mm_castps_si128(_mm_load1_ps((const float*)pB)); + + __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); + __m256i _pB0 = _mm256_cvtepi8_epi16(_pB); + + // 01234567 89abcdef + // 23016745 ab89efcd + __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + + // 01230123 01230123 + // 12301230 12301230 + __m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + + __m512i _s0 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0)); + __m512i _s1 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1)); + __m512i _s2 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB0)); + __m512i _s3 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB1)); + + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum2 = _mm512_add_epi32(_sum2, _s2); + _sum3 = _mm512_add_epi32(_sum3, _s3); + + pA += 16; + pB += 4; + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + _mm512_store_si512((__m512i*)(outptr + 16), _sum1); + _mm512_store_si512((__m512i*)(outptr + 32), _sum2); + _mm512_store_si512((__m512i*)(outptr + 48), _sum3); + outptr += 64; + } + for (; jj + 1 < max_jj; jj += 2) + { + const signed char* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + + if (k == 0) { - p0[0] = f0; - p0[out_hstep] = f1; - p0 += out_hstep * 2; + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); } else { - p0[0] = f0; - p0[1] = f1; - p0 += 2; + _sum0 = _mm512_load_si512((const __m512i*)outptr); + _sum1 = _mm512_load_si512((const __m512i*)(outptr + 16)); } - pp += 2; - } - for (; jj < max_jj; jj++) - { - float f0 = 
pp[0] * descale; + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); + __m256i _pB = _mm256_castps_si256(_mm256_broadcast_ss((const float*)pB)); - if (pC) + __m512i _pA0 = _mm512_cvtepi8_epi16(_pA); + __m512i _pB0 = _mm512_cvtepi8_epi16(_pB); + + // 0123 4567 89ab cdef + + // 0101 0101 0101 0101 + // 1010 1010 1010 1010 + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CDAB); + + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); + + pA += 32; + pB += 4; + } + for (; kk < max_kk; kk += 1) { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - f0 += c0; - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // out_elempack == 1 - f0 += pC[0] * beta; - pC += 1; - } + __m128i _pA = _mm_load_si128((const __m128i*)pA); + __m128i _pB = _mm_set1_epi16(((const short*)pB)[0]); + + __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); + __m256i _pB0 = _mm256_cvtepi8_epi16(_pB); + + // 01234567 89abcdef + + // 01010101 01010101 + // 10101010 10101010 + __m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 1, 0, 1)), _MM_SHUFFLE(0, 1, 0, 1)); + + __m512i _s0 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0)); + __m512i _s1 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1)); + + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + + pA += 16; + pB += 2; } - f0 *= alpha; + _mm512_store_si512((__m512i*)outptr, _sum0); + _mm512_store_si512((__m512i*)(outptr + 16), _sum1); + outptr += 32; + } + for (; jj < max_jj; jj += 1) + { + const signed char* pA = pAT; - p0[0] = f0; + __m512i _sum0; - if (output_transpose) + if (k == 0) { - p0 += out_hstep; + _sum0 = _mm512_setzero_si512(); } else { - p0++; + _sum0 = _mm512_load_si512((const __m512i*)outptr); } - pp += 1; - } - } -} + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = 
_mm256_loadu_si256((const __m256i*)pA); + __m256i _pB = _mm256_castps_si256(_mm256_broadcast_ss((const float*)pB)); -static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk) -{ -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ - if (ncnn::cpu_support_x86_avx2()) - { - gemm_transB_packed_tile_int8_avx2(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); - return; - } -#endif + __m512i _pA0 = _mm512_cvtepi8_epi16(_pA); + __m512i _pBBBB = _mm512_cvtepi8_epi16(_pB); -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ - if (ncnn::cpu_support_x86_xop()) - { - gemm_transB_packed_tile_int8_xop(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); - return; - } -#endif + // 0xxx0xxx0xxx0xxx -> 00000000... + __m512i _pB0 = _mm512_shuffle_epi32(_pBBBB, _MM_PERM_AAAA); - NCNN_LOGE("gemm_transB_packed_tile_int8 %d %d %d %d %d %d", i, max_ii, j, max_jj, k, max_kk); + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); - const signed char* pAT = AT_tile; - const signed char* pBT = BT_tile; + pA += 32; + pB += 2; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_load_si128((const __m128i*)pA); + __m256i _pB = _mm256_set1_epi16(pB[0]); - int* outptr = topT_tile; + __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); - int ii = 0; -#if __SSE2__ -#if __AVX2__ + __m512i _s0 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB)); + + _sum0 = _mm512_add_epi32(_sum0, _s0); + + pA += 16; + pB += 1; + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + outptr += 16; + } + + pAT += max_kk * 16; + } +#endif // __AVX512F__ for (; ii + 7 < max_ii; ii += 8) { const signed char* pB = pBT; int jj = 0; #if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const signed char* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + 
__m512i _sum2; + __m512i _sum3; + __m512i _sum4; + __m512i _sum5; + __m512i _sum6; + __m512i _sum7; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + _sum4 = _mm512_setzero_si512(); + _sum5 = _mm512_setzero_si512(); + _sum6 = _mm512_setzero_si512(); + _sum7 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + _sum1 = _mm512_load_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_load_si512((const __m512i*)(outptr + 32)); + _sum3 = _mm512_load_si512((const __m512i*)(outptr + 48)); + _sum4 = _mm512_load_si512((const __m512i*)(outptr + 64)); + _sum5 = _mm512_load_si512((const __m512i*)(outptr + 80)); + _sum6 = _mm512_load_si512((const __m512i*)(outptr + 96)); + _sum7 = _mm512_load_si512((const __m512i*)(outptr + 112)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_loadu_si128((const __m128i*)pA); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + + __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); + __m512i _pB0 = _mm512_cvtepi8_epi16(_pB); + + // 0123 4567 0123 4567 + // 2301 6745 2301 6745 + __m512i _pA00 = _mm512_inserti32x8(_mm512_castsi256_si512(_pA0), _pA0, 1); + __m512i _pA11 = _mm512_shuffle_epi32(_pA00, _MM_PERM_BADC); + + // 0123 4567 89ab cdef + // 1230 5674 9ab8 defc + // 4567 0123 cdef 89ab + // 5674 1230 defc 9ab8 + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB3 = _mm512_shuffle_epi32(_pB2, _MM_PERM_ADCB); + + // 4567 0123 4567 0123 + // __m512i _pA11 = _mm512_permutex_epi64(_pA00, _MM_SHUFFLE(1, 0, 3, 2)); + + // 2301 6745 ab89 efcd + // 3012 7456 b89a fcde + // __m512i _pB2 = _mm512_shuffle_epi32(_pB0, _MM_PERM_BADC); + // __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); + + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA00, _pB0)); + _sum1 = 
_mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA00, _pB1)); + _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA11, _pB0)); + _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA11, _pB1)); + _sum4 = _mm512_add_epi32(_sum4, _mm512_madd_epi16(_pA00, _pB2)); + _sum5 = _mm512_add_epi32(_sum5, _mm512_madd_epi16(_pA00, _pB3)); + _sum6 = _mm512_add_epi32(_sum6, _mm512_madd_epi16(_pA11, _pB2)); + _sum7 = _mm512_add_epi32(_sum7, _mm512_madd_epi16(_pA11, _pB3)); + + pA += 16; + pB += 32; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_loadl_epi64((const __m128i*)pA); + __m128i _pB = _mm_load_si128((const __m128i*)pB); + + _pA = _mm_cvtepi8_epi16(_pA); + __m256i _pB0 = _mm256_cvtepi8_epi16(_pB); + + // 01234567 01234567 + // 23016745 23016745 + __m256i _pA00 = _mm256_inserti128_si256(_mm256_castsi128_si256(_pA), _pA, 1); + __m256i _pA11 = _mm256_shuffle_epi32(_pA00, _MM_SHUFFLE(2, 3, 0, 1)); + + // 01234567 89abcdef + // 12305674 9ab8defc + // 45670123 cdef89ab + // 56741230 defc9ab8 + __m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB3 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB2, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + + // 45670123 45670123 + // __m256i _pA11 = _mm256_permute4x64_epi64(_pA00, _MM_SHUFFLE(2, 3, 0, 1)); + + // 23016745 ab89efcd + // 30127456 b89afcde + // __m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); + // __m256i _pB3 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(2, 1, 0, 3)), _MM_SHUFFLE(2, 1, 0, 3)); + + __m512i _s0 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA00, _pB0)); + __m512i _s1 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA00, _pB1)); + __m512i _s2 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA11, _pB0)); + __m512i _s3 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA11, _pB1)); + __m512i _s4 = 
_mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA00, _pB2)); + __m512i _s5 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA00, _pB3)); + __m512i _s6 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA11, _pB2)); + __m512i _s7 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA11, _pB3)); + + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum2 = _mm512_add_epi32(_sum2, _s2); + _sum3 = _mm512_add_epi32(_sum3, _s3); + _sum4 = _mm512_add_epi32(_sum4, _s4); + _sum5 = _mm512_add_epi32(_sum5, _s5); + _sum6 = _mm512_add_epi32(_sum6, _s6); + _sum7 = _mm512_add_epi32(_sum7, _s7); + + pA += 8; + pB += 16; + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + _mm512_store_si512((__m512i*)(outptr + 16), _sum1); + _mm512_store_si512((__m512i*)(outptr + 32), _sum2); + _mm512_store_si512((__m512i*)(outptr + 48), _sum3); + _mm512_store_si512((__m512i*)(outptr + 64), _sum4); + _mm512_store_si512((__m512i*)(outptr + 80), _sum5); + _mm512_store_si512((__m512i*)(outptr + 96), _sum6); + _mm512_store_si512((__m512i*)(outptr + 112), _sum7); + + outptr += 128; + } +#endif // __AVX512F__ for (; jj + 7 < max_jj; jj += 8) { const signed char* pA = pAT; @@ -5126,15 +11438,97 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, int jj = 0; #if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const signed char* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_loadu_si512((const __m512i*)outptr); + _sum1 = _mm512_loadu_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_loadu_si512((const __m512i*)(outptr + 32)); + _sum3 = _mm512_loadu_si512((const __m512i*)(outptr + 48)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = 
_mm256_castpd_si256(_mm256_broadcast_sd((const double*)pA)); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + + __m512i _pA0 = _mm512_cvtepi8_epi16(_pA); + __m512i _pB0 = _mm512_cvtepi8_epi16(_pB); + + // 0123 0123 0123 0123 + // 2301 2301 2301 2301 + __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + + // 0123 4567 89ab cdef + // 1230 5674 9ab8 defc + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); + _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA1, _pB0)); + _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA1, _pB1)); + + pA += 8; + pB += 32; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pB = _mm_load_si128((const __m128i*)pB); + + __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); + __m256i _pB0 = _mm256_cvtepi8_epi16(_pB); + + // 01230123 01230123 + // 23012301 23012301 + __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + + // 01234567 89abcdef + // 12305674 9ab8defc + __m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + + __m512i _s0 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0)); + __m512i _s1 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1)); + __m512i _s2 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB0)); + __m512i _s3 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB1)); + + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum2 = _mm512_add_epi32(_sum2, _s2); + _sum3 = _mm512_add_epi32(_sum3, _s3); + + pA += 4; + pB += 16; + } + + _mm512_storeu_si512((__m512i*)outptr, _sum0); + _mm512_storeu_si512((__m512i*)(outptr + 16), _sum1); + _mm512_storeu_si512((__m512i*)(outptr + 32), _sum2); + _mm512_storeu_si512((__m512i*)(outptr + 48), _sum3); + + outptr += 64; + } +#endif // __AVX512F__ for (; 
jj + 7 < max_jj; jj += 8) { const signed char* pA = pAT; - // NCNN_LOGE("pA %p", pA); - // if (max_jj == 12) - // { - // NCNN_LOGE("%d %d %d %d %d %d %d %d", pA[0], pA[1], pA[2], pA[3], pA[4], pA[5], pA[6], pA[7]); - // NCNN_LOGE("%d %d %d %d %d %d %d %d", pB[0], pB[1], pB[2], pB[3], pB[4], pB[5], pB[6], pB[7]); - // } __m128i _sum0; __m128i _sum1; @@ -5217,12 +11611,6 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, _sum7 = _mm_add_epi32(_sum7, _mm_madd_epi16(_pA1, _pB3)); #endif - // if (max_jj == 12) - // { - // NCNN_LOGE("%d %d %d %d %d %d %d %d", pA[0], pA[1], pA[2], pA[3], pA[4], pA[5], pA[6], pA[7]); - // NCNN_LOGE("%d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d", pB[0], pB[1], pB[2], pB[3], pB[4], pB[5], pB[6], pB[7], pB[8], pB[9], pB[10], pB[11], pB[12], pB[13], pB[14], pB[15]); - // } - pA += 8; pB += 16; } @@ -5299,12 +11687,6 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, _sum7 = _mm_add_epi32(_sum7, _s7); #endif - // if (max_jj == 12) - // { - // NCNN_LOGE("%d %d %d %d", pA[0], pA[1], pA[2], pA[3]); - // NCNN_LOGE("%d %d %d %d %d %d %d %d", pB[0], pB[1], pB[2], pB[3], pB[4], pB[5], pB[6], pB[7]); - // } - pA += 4; pB += 8; } @@ -5318,11 +11700,6 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, _mm_store_si128((__m128i*)(outptr + 24), _sum6); _mm_store_si128((__m128i*)(outptr + 28), _sum7); - if (max_jj == 12) - { - NCNN_LOGE("outptr %d %d %d %d %d %d %d %d", outptr[0], outptr[1], outptr[2], outptr[3], outptr[4], outptr[5], outptr[6], outptr[7]); - } - outptr += 32; } #endif // defined(__x86_64__) || defined(_M_X64) @@ -5629,6 +12006,75 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, int jj = 0; #if __SSE2__ #if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + __m512i _sum0; + __m512i _sum1; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = 
_mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_loadu_si512((const __m512i*)outptr); + _sum1 = _mm512_loadu_si512((const __m512i*)(outptr + 16)); + } + + const signed char* pA = pAT; + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = _mm256_castps_si256(_mm256_broadcast_ss((const float*)pA)); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + + __m512i _pA0 = _mm512_cvtepi8_epi16(_pA); + __m512i _pB0 = _mm512_cvtepi8_epi16(_pB); + + // 0101 0101 0101 0101 + + // 0123 4567 89ab cdef + // 1230 5674 9ab8 defc + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); + + pA += 4; + pB += 32; + } + for (; kk < max_kk; kk += 1) + { + __m128i _pA = _mm_set1_epi16(((const short*)pA)[0]); + __m128i _pB = _mm_load_si128((const __m128i*)pB); + + __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); + __m256i _pB0 = _mm256_cvtepi8_epi16(_pB); + + // 01010101 01010101 + + // 01234567 89abcdef + // 12305674 9ab8defc + __m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + + __m512i _s0 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0)); + __m512i _s1 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1)); + + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + + pA += 2; + pB += 16; + } + + _mm512_storeu_si512((__m512i*)outptr, _sum0); + _mm512_storeu_si512((__m512i*)(outptr + 16), _sum1); + + outptr += 32; + } +#endif // __AVX512F__ for (; jj + 7 < max_jj; jj += 8) { __m128i _sum0; @@ -5927,6 +12373,55 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, int jj = 0; #if __SSE2__ #if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + __m512i _sum0; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + } + else + { + _sum0 = 
_mm512_loadu_si512((const __m512i*)outptr); + } + + const signed char* pA = pAT; + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = _mm256_set1_epi16(((const short*)pA)[0]); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + + __m512i _pA0 = _mm512_cvtepi8_epi16(_pA); + __m512i _pB0 = _mm512_cvtepi8_epi16(_pB); + + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + + pA += 2; + pB += 32; + } + for (; kk < max_kk; kk += 1) + { + __m256i _pA = _mm256_set1_epi16(pA[0]); + __m128i _pB = _mm_load_si128((const __m128i*)pB); + + __m256i _pB0 = _mm256_cvtepi8_epi16(_pB); + + __m512i _s0 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA, _pB0)); + + _sum0 = _mm512_add_epi32(_sum0, _s0); + + pA += 1; + pB += 16; + } + + _mm512_storeu_si512((__m512i*)outptr, _sum0); + + outptr += 16; + } +#endif // __AVX512F__ for (; jj + 7 < max_jj; jj += 8) { __m128i _sum0;