diff --git a/src/layer/riscv/gemm_bf16s_fp16s.h b/src/layer/riscv/gemm_bf16s_fp16s.h new file mode 100644 index 000000000000..12b9e2e58951 --- /dev/null +++ b/src/layer/riscv/gemm_bf16s_fp16s.h @@ -0,0 +1,1542 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +static void pack_A_tile_bf16_fp16(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + int vl; + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; + + unsigned short* pp = AT; + + int ii = 0; +#if __riscv_vector + for (; ii + 7 < max_ii; ii += 8) + { + if (elempack == 8) + { + vl = 8; + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * 8; + + for (int kk = 0; kk < max_kk; kk++) + { + vse16_v_u16m1(pp, vle16_v_u16m1(p0, vl), vl); + // vst1q_u16(pp, vle16_v_u16m1(p0), vl); + pp += 8; + p0 += 8; + } + } + if (elempack == 4) + { + vl = 4; + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * 4; + const unsigned short* p1 = (const unsigned short*)A + (i + ii + 4) * A_hstep + k * 4; + + for (int kk = 0; kk < max_kk; kk++) + { + vuint16m1_t _r00 = vle16_v_u16m1(p0, vl); + vuint16m1_t _r01 = vle16_v_u16m1(p1, vl); + vse16_v_u16m1(pp, _r00, vl); + vse16_v_u16m1(pp + 4, _r01, vl); + pp += 8; + p0 += 4; + p1 += 4; + } + } + if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k; + const unsigned short* p1 = (const unsigned short*)A + (i + ii + 1) * A_hstep + k; + const unsigned short* p2 = (const unsigned short*)A + (i + ii + 2) * A_hstep + k; + const unsigned short* p3 = (const unsigned short*)A + (i + ii + 3) * A_hstep + k; + const unsigned short* p4 = (const unsigned short*)A + (i + ii + 4) * A_hstep + k; + const unsigned short* p5 = (const unsigned short*)A + (i + ii + 5) * A_hstep + k; + const unsigned short* p6 = (const unsigned short*)A + (i + ii + 6) * A_hstep + k; + const unsigned short* p7 = (const unsigned short*)A + (i + ii + 7) * A_hstep + k; + + int kk = 0; + + int n = max_kk; + while (n > 0) + { + vl = vsetvl_e16m1(n); + vuint16m1_t _r0 = vle16_v_u16m1(p0, vl); + vuint16m1_t _r1 = vle16_v_u16m1(p1, vl); + vuint16m1_t _r2 = vle16_v_u16m1(p2, vl); + vuint16m1_t _r3 = vle16_v_u16m1(p3, vl); + vuint16m1_t _r4 = vle16_v_u16m1(p4, vl); + vuint16m1_t _r5 = vle16_v_u16m1(p5, vl); + vuint16m1_t _r6 = vle16_v_u16m1(p6, vl); + vuint16m1_t _r7 = vle16_v_u16m1(p7, vl); + + vsseg8e16_v_u16m1(pp, _r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, vl); + + pp += 8 * vl; + p0 += vl; + p1 += vl; + p2 += vl; + p3 += vl; + p4 += vl; + p5 += vl; + p6 += vl; + p7 += vl; + n -= vl; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + if (elempack == 4) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * 4; + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + vl = 8; + vse16_v_u16m1(pp, vle16_v_u16m1(p0, vl), 
vl); + pp += 8; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + vl = 4; + vse16_v_u16m1(pp, vle16_v_u16m1(p0, vl), vl); + pp += 4; + p0 += 4; + } + } + if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k; + const unsigned short* p1 = (const unsigned short*)A + (i + ii + 1) * A_hstep + k; + const unsigned short* p2 = (const unsigned short*)A + (i + ii + 2) * A_hstep + k; + const unsigned short* p3 = (const unsigned short*)A + (i + ii + 3) * A_hstep + k; + + int kk = 0; + + int n = max_kk; + while (n > 0) + { + vl = vsetvl_e16m1(n); + vuint16m1_t _r0 = vle16_v_u16m1(p0, vl); + vuint16m1_t _r1 = vle16_v_u16m1(p1, vl); + vuint16m1_t _r2 = vle16_v_u16m1(p2, vl); + vuint16m1_t _r3 = vle16_v_u16m1(p3, vl); + vsseg4e16_v_u16m1(pp, _r0, _r1, _r2, _r3, vl); + pp += 4 * vl; + p0 += vl; + p1 += vl; + p2 += vl; + p3 += vl; + n -= vl; + } + } + } +#endif // __riscv_vector + for (; ii + 1 < max_ii; ii += 2) + { + // if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k; + const unsigned short* p1 = (const unsigned short*)A + (i + ii + 1) * A_hstep + k; + + int kk = 0; +#if __riscv_vector + int n = max_kk; + while (n > 0) + { + vl = vsetvl_e16m1(n); + vuint16m1_t _r0 = vle16_v_u16m1(p0, vl); + vuint16m1_t _r1 = vle16_v_u16m1(p1, vl); + vsseg2e16_v_u16m1(pp, _r0, _r1, vl); + pp += 2 * vl; + p0 += vl; + p1 += vl; + n -= vl; + } +#else + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp += 2; + p0 += 1; + p1 += 1; + } +#endif // __riscv_vector + } + } + for (; ii < max_ii; ii += 1) + { + // if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k; + + int kk = 0; +#if __riscv_vector + int n = max_kk; + while (n > 0) + { + vl = vsetvl_e16m1(n); + vuint16m1_t _r0 = vle16_v_u16m1(p0, vl); + vse16_v_u16m1(pp, _r0, vl); + pp += 1 * vl; + p0 += vl; + n -= vl; + } +#else + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0 += 1; + } +#endif + } + } +} + +static void transpose_pack_A_tile_bf16_fp16(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + int vl; + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? 
(int)A.cstep : A.w; + + unsigned short* pp = AT; + + int ii = 0; +#if __riscv_vector + for (; ii + 7 < max_ii; ii += 8) + { + if (elempack == 8) + { + vl = 8; + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * 8; + + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + vuint16m1_t _r0 = vle16_v_u16m1(p0, vl); + vuint16m1_t _r1 = vle16_v_u16m1(p0 + 8, vl); + vuint16m1_t _r2 = vle16_v_u16m1(p0 + 16, vl); + vuint16m1_t _r3 = vle16_v_u16m1(p0 + 24, vl); + vuint16m1_t _r4 = vle16_v_u16m1(p0 + 32, vl); + vuint16m1_t _r5 = vle16_v_u16m1(p0 + 40, vl); + vuint16m1_t _r6 = vle16_v_u16m1(p0 + 48, vl); + vuint16m1_t _r7 = vle16_v_u16m1(p0 + 56, vl); + vsseg8e16_v_u16m1(pp, _r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, vl); + // vse16_v_u16m1(pp, _r0, vl); + // vse16_v_u16m1(pp + 8, _r1, vl); + // vse16_v_u16m1(pp + 16, _r2, vl); + // vse16_v_u16m1(pp + 24, _r3, vl); + // vse16_v_u16m1(pp + 32, _r4, vl); + // vse16_v_u16m1(pp + 40, _r5, vl); + // vse16_v_u16m1(pp + 48, _r6, vl); + // vse16_v_u16m1(pp + 56, _r7, vl); + pp += 64; + p0 += A_hstep * 8; + } + } + if (elempack == 4) + { + vl = 8; + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vuint16m1_t _r0; + vuint16m1_t _r1; + vuint16m1_t _r2; + vuint16m1_t _r3; + vlseg4e16_v_u16m1(&_r0, &_r1, &_r2, &_r3, p0, vl); + vse16_v_u16m1(pp, _r0, vl); + vse16_v_u16m1(pp + 8, _r1, vl); + vse16_v_u16m1(pp + 16, _r2, vl); + vse16_v_u16m1(pp + 24, _r3, vl); + pp += 32; + p0 += A_hstep * 4; + } + } + if (elempack == 1) + { + vl = 8; + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vse16_v_u16m1(pp, vle16_v_u16m1(p0, vl), vl); + pp += 8; + p0 += A_hstep; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + if (elempack == 8) + { + vl = 8; + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * 8; + + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + vuint16m1_t _r0 = vle16_v_u16m1(p0, vl); + vuint16m1_t _r1 = vle16_v_u16m1(p0 + 8, vl); + vuint16m1_t _r2 = vle16_v_u16m1(p0 + 16, vl); + vuint16m1_t _r3 = vle16_v_u16m1(p0 + 24, vl); + vsseg4e16_v_u16m1(pp, _r0, _r1, _r2, _r3, vl); + pp += 32; + p0 += A_hstep * 8; + } + } + if (elempack == 4) + { + vl = 4; + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vuint16m1_t _r0; + vuint16m1_t _r1; + vuint16m1_t _r2; + vuint16m1_t _r3; + vlseg4e16_v_u16m1(&_r0, &_r1, &_r2, &_r3, p0, vl); + vse16_v_u16m1(pp, _r0, vl); + vse16_v_u16m1(pp + 4, _r1, vl); + vse16_v_u16m1(pp + 8, _r2, vl); + vse16_v_u16m1(pp + 12, _r3, vl); + pp += 16; + p0 += A_hstep * 4; + } + } + if (elempack == 1) + { + vl = 4; + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vse16_v_u16m1(pp, vle16_v_u16m1(p0, vl), vl); + pp += 4; + p0 += A_hstep; + } + } + } +#endif // __riscv_vector + for (; ii + 1 < max_ii; ii += 2) + { +#if __riscv_vector + if (elempack == 8) + { + vl = 8; + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * 8; + + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + vuint16m1_t _r0 = vle16_v_u16m1(p0, vl); + vuint16m1_t _r1 = vle16_v_u16m1(p0 + 8, vl); + vsseg2e16_v_u16m1(pp, _r0, _r1, vl); + pp += 16; + p0 += A_hstep * 8; + } + } + if (elempack == 4) + { + vl = 4; + const unsigned short* p0 = (const unsigned short*)A + k * 
A_hstep + (i + ii) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vuint16m1_t _r0 = vle16_v_u16m1(p0, vl); + vuint16m1_t _r1 = vle16_v_u16m1(p0 + 4, vl); + vsseg2e16_v_u16m1(pp, _r0, _r1, vl); + + pp += 8; + p0 += A_hstep * 4; + } + } +#endif // __riscv_vector + if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii); + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp += 2; + p0 += A_hstep; + } + } + } + for (; ii < max_ii; ii += 1) + { +#if __riscv_vector + if (elempack == 8) + { + vl = 8; + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * 8; + + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + vse16_v_u16m1(pp, vle16_v_u16m1(p0, vl), vl); + pp += 8; + p0 += A_hstep * 8; + } + } + if (elempack == 4) + { + vl = 4; + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vse16_v_u16m1(pp, vle16_v_u16m1(p0, vl), vl); + pp += 4; + p0 += A_hstep * 4; + } + } +#endif // __riscv_vector + if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii); + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0 += A_hstep; + } + } + } +} + +static void pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + int vl; + const int elempack = B.elempack; + const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; + + unsigned short* pp = BT; + + int jj = 0; +#if __riscv_vector + for (; jj + 11 < max_jj; jj += 12) + { + if (elempack == 8) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) / 8 * 8 * B_hstep + k * 8; + const unsigned short* p1 = (const unsigned short*)B + (j + jj + 8) / 8 * 8 * B_hstep + k * 8; + + if ((j + jj) % 8 == 0) + { + for (int kk = 0; kk < max_kk; kk++) + { + vl = 8; + vse16_v_u16m1(pp, vle16_v_u16m1(p0, vl), vl); + vl = 4; + vse16_v_u16m1(pp + 8, vle16_v_u16m1(p1, vl), vl); + pp += 12; + p0 += 8; + p1 += 8; + } + } + if ((j + jj) % 8 == 4) + { + for (int kk = 0; kk < max_kk; kk++) + { + vl = 4; + vse16_v_u16m1(pp, vle16_v_u16m1(p0 + 4, vl), vl); + vl = 8; + vse16_v_u16m1(pp + 4, vle16_v_u16m1(p1, vl), vl); + pp += 12; + p0 += 8; + p1 += 8; + } + } + } + if (elempack == 4) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * 4; + const unsigned short* p1 = (const unsigned short*)B + (j + jj + 4) * B_hstep + k * 4; + const unsigned short* p2 = (const unsigned short*)B + (j + jj + 8) * B_hstep + k * 4; + + for (int kk = 0; kk < max_kk; kk++) + { + vl = 4; + vse16_v_u16m1(pp, vle16_v_u16m1(p0, vl), vl); + vse16_v_u16m1(pp + 4, vle16_v_u16m1(p1, vl), vl); + vse16_v_u16m1(pp + 8, vle16_v_u16m1(p2, vl), vl); + pp += 12; + p0 += 4; + p1 += 4; + p2 += 4; + } + } + if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; + const unsigned short* p1 = (const unsigned short*)B + (j + jj + 1) * B_hstep + k; + const unsigned short* p2 = (const unsigned short*)B + (j + jj + 2) * B_hstep + k; + const unsigned short* p3 = (const unsigned short*)B + (j + jj + 3) * B_hstep + k; + const unsigned short* p4 = (const unsigned short*)B + (j + jj + 4) * B_hstep + k; + const unsigned short* p5 = (const unsigned short*)B + (j + jj + 5) * B_hstep + k; + const unsigned short* p6 = (const unsigned short*)B + (j + jj + 6) * B_hstep + k; + const unsigned short* p7 = (const unsigned short*)B + (j + jj + 7) * B_hstep + k; 
+ const unsigned short* p8 = (const unsigned short*)B + (j + jj + 8) * B_hstep + k; + const unsigned short* p9 = (const unsigned short*)B + (j + jj + 9) * B_hstep + k; + const unsigned short* pa = (const unsigned short*)B + (j + jj + 10) * B_hstep + k; + const unsigned short* pb = (const unsigned short*)B + (j + jj + 11) * B_hstep + k; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vl = 4; + vuint16m1_t _r0 = vle16_v_u16m1(p0, vl); + vuint16m1_t _r1 = vle16_v_u16m1(p1, vl); + vuint16m1_t _r2 = vle16_v_u16m1(p2, vl); + vuint16m1_t _r3 = vle16_v_u16m1(p3, vl); + vuint16m1_t _r4 = vle16_v_u16m1(p4, vl); + vuint16m1_t _r5 = vle16_v_u16m1(p5, vl); + vuint16m1_t _r6 = vle16_v_u16m1(p6, vl); + vuint16m1_t _r7 = vle16_v_u16m1(p7, vl); + vuint16m1_t _r8 = vle16_v_u16m1(p8, vl); + vuint16m1_t _r9 = vle16_v_u16m1(p9, vl); + vuint16m1_t _ra = vle16_v_u16m1(pa, vl); + vuint16m1_t _rb = vle16_v_u16m1(pb, vl); + + vsse16_v_u16m1(pp, 12 * sizeof(unsigned short), _r0, vl); + vsse16_v_u16m1(pp + 1, 12 * sizeof(unsigned short), _r1, vl); + vsse16_v_u16m1(pp + 2, 12 * sizeof(unsigned short), _r2, vl); + vsse16_v_u16m1(pp + 3, 12 * sizeof(unsigned short), _r3, vl); + vsse16_v_u16m1(pp + 4, 12 * sizeof(unsigned short), _r4, vl); + vsse16_v_u16m1(pp + 5, 12 * sizeof(unsigned short), _r5, vl); + vsse16_v_u16m1(pp + 6, 12 * sizeof(unsigned short), _r6, vl); + vsse16_v_u16m1(pp + 7, 12 * sizeof(unsigned short), _r7, vl); + vsse16_v_u16m1(pp + 8, 12 * sizeof(unsigned short), _r8, vl); + vsse16_v_u16m1(pp + 9, 12 * sizeof(unsigned short), _r9, vl); + vsse16_v_u16m1(pp + 10, 12 * sizeof(unsigned short), _ra, vl); + vsse16_v_u16m1(pp + 11, 12 * sizeof(unsigned short), _rb, vl); + + pp += 48; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + p4 += 4; + p5 += 4; + p6 += 4; + p7 += 4; + p8 += 4; + p9 += 4; + pa += 4; + pb += 4; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp[4] = p4[0]; + pp[5] = p5[0]; + pp[6] = p6[0]; + pp[7] = p7[0]; + pp[8] = p8[0]; + pp[9] = p9[0]; + pp[10] = pa[0]; + pp[11] = pb[0]; + pp += 12; + p0++; + p1++; + p2++; + p3++; + p4++; + p5++; + p6++; + p7++; + p8++; + p9++; + pa++; + pb++; + } + } + } + for (; jj + 7 < max_jj; jj += 8) + { + if (elempack == 8) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) / 8 * 8 * B_hstep + k * 8; + const unsigned short* p1 = (const unsigned short*)B + (j + jj + 8) / 8 * 8 * B_hstep + k * 8; + + if ((j + jj) % 8 == 0) + { + for (int kk = 0; kk < max_kk; kk++) + { + vl = 8; + vse16_v_u16m1(pp, vle16_v_u16m1(p0, vl), vl); + pp += 8; + p0 += 8; + } + } + if ((j + jj) % 8 == 4) + { + for (int kk = 0; kk < max_kk; kk++) + { + vl = 4; + vse16_v_u16m1(pp, vle16_v_u16m1(p0 + 4, vl), vl); + vse16_v_u16m1(pp + 4, vle16_v_u16m1(p1, vl), vl); + pp += 8; + p0 += 8; + p1 += 8; + } + } + } + if (elempack == 4) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * 4; + const unsigned short* p1 = (const unsigned short*)B + (j + jj + 4) * B_hstep + k * 4; + + for (int kk = 0; kk < max_kk; kk++) + { + vl = 4; + vse16_v_u16m1(pp, vle16_v_u16m1(p0, vl), vl); + vse16_v_u16m1(pp + 4, vle16_v_u16m1(p1, vl), vl); + pp += 8; + p0 += 4; + p1 += 4; + } + } + if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; + const unsigned short* p1 = (const unsigned short*)B + (j + jj + 1) * B_hstep + k; + const unsigned short* p2 = (const unsigned short*)B + (j + jj + 2) * B_hstep + k; + const unsigned short* p3 = (const unsigned 
short*)B + (j + jj + 3) * B_hstep + k; + const unsigned short* p4 = (const unsigned short*)B + (j + jj + 4) * B_hstep + k; + const unsigned short* p5 = (const unsigned short*)B + (j + jj + 5) * B_hstep + k; + const unsigned short* p6 = (const unsigned short*)B + (j + jj + 6) * B_hstep + k; + const unsigned short* p7 = (const unsigned short*)B + (j + jj + 7) * B_hstep + k; + + int kk = 0; + + int n = max_kk; + while (n > 0) + { + vl = vsetvl_e16m1(n); + vuint16m1_t _r0 = vle16_v_u16m1(p0, vl); + vuint16m1_t _r1 = vle16_v_u16m1(p1, vl); + vuint16m1_t _r2 = vle16_v_u16m1(p2, vl); + vuint16m1_t _r3 = vle16_v_u16m1(p3, vl); + vuint16m1_t _r4 = vle16_v_u16m1(p4, vl); + vuint16m1_t _r5 = vle16_v_u16m1(p5, vl); + vuint16m1_t _r6 = vle16_v_u16m1(p6, vl); + vuint16m1_t _r7 = vle16_v_u16m1(p7, vl); + vsseg8e16_v_u16m1(pp, _r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, vl); + + pp += 8 * vl; + p0 += vl; + p1 += vl; + p2 += vl; + p3 += vl; + p4 += vl; + p5 += vl; + p6 += vl; + p7 += vl; + + n -= vl; + } + } + } + for (; jj + 3 < max_jj; jj += 4) + { + if (elempack == 8) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) / 8 * 8 * B_hstep + k * 8; + + if ((j + jj) % 8 == 0) + { + for (int kk = 0; kk < max_kk; kk++) + { + vl = 4; + vse16_v_u16m1(pp, vle16_v_u16m1(p0, vl), vl); + pp += 4; + p0 += 8; + } + } + if ((j + jj) % 8 == 4) + { + for (int kk = 0; kk < max_kk; kk++) + { + vl = 4; + vse16_v_u16m1(pp, vle16_v_u16m1(p0 + 4, vl), vl); + pp += 4; + p0 += 8; + } + } + } + if (elempack == 4) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * 4; + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + vl = 8; + vse16_v_u16m1(pp, vle16_v_u16m1(p0, vl), vl); + pp += 8; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + vl = 4; + vse16_v_u16m1(pp, vle16_v_u16m1(p0, vl), vl); + pp += 4; + p0 += 4; + } + } + if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; + const unsigned short* p1 = (const unsigned short*)B + (j + jj + 1) * B_hstep + k; + const unsigned short* p2 = (const unsigned short*)B + (j + jj + 2) * B_hstep + k; + const unsigned short* p3 = (const unsigned short*)B + (j + jj + 3) * B_hstep + k; + + int kk = 0; + + int n = max_kk; + while (n > 0) + { + vl = vsetvl_e16m1(n); + vuint16m1_t _r0 = vle16_v_u16m1(p0, vl); + vuint16m1_t _r1 = vle16_v_u16m1(p1, vl); + vuint16m1_t _r2 = vle16_v_u16m1(p2, vl); + vuint16m1_t _r3 = vle16_v_u16m1(p3, vl); + vsseg4e16_v_u16m1(pp, _r0, _r1, _r2, _r3, vl); + pp += 4 * vl; + p0 += vl; + p1 += vl; + p2 += vl; + p3 += vl; + n -= vl; + } + } + } +#endif // __riscv_vector + for (; jj + 1 < max_jj; jj += 2) + { + // if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; + const unsigned short* p1 = (const unsigned short*)B + (j + jj + 1) * B_hstep + k; + + int kk = 0; +#if __riscv_vector + + int n = max_kk; + while (n > 0) + { + vl = vsetvl_e16m1(n); + vuint16m1_t _r0 = vle16_v_u16m1(p0, vl); + vuint16m1_t _r1 = vle16_v_u16m1(p1, vl); + vsseg2e16_v_u16m1(pp, _r0, _r1, vl); + pp += 2 * vl; + p0 += vl; + p1 += vl; + n -= vl; + } +#else + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp += 2; + p0++; + p1++; + } +#endif // __riscv_vector + } + } + for (; jj < max_jj; jj += 1) + { + // if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; + + int kk = 0; +#if __riscv_vector + + int n = max_kk; + while (n > 0) + { + vl = vsetvl_e16m1(n); + vuint16m1_t _r0 = 
vle16_v_u16m1(p0, vl); + vse16_v_u16m1(pp, _r0, vl); + pp += vl; + p0 += vl; + n -= vl; + } +#else + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0++; + } +#endif // __riscv_vector + } + } +} + +static void transpose_pack_B_tile_bf16_fp16(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + int vl; + const int elempack = B.elempack; + const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; + + unsigned short* pp = BT; + + int jj = 0; +#if __riscv_vector + for (; jj + 11 < max_jj; jj += 12) + { + if (elempack == 8) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 8; + + int kk = 0; + + for (; kk + 7 < max_kk; kk += 8) + { + vl = 8; + vuint16m1_t _r0 = vle16_v_u16m1(p0, vl); + vuint16m1_t _r1 = vle16_v_u16m1(p0 + 8, vl); + vuint16m1_t _r2 = vle16_v_u16m1(p0 + 16, vl); + vuint16m1_t _r3 = vle16_v_u16m1(p0 + 24, vl); + vuint16m1_t _r4 = vle16_v_u16m1(p0 + 32, vl); + vuint16m1_t _r5 = vle16_v_u16m1(p0 + 40, vl); + vuint16m1_t _r6 = vle16_v_u16m1(p0 + 48, vl); + vuint16m1_t _r7 = vle16_v_u16m1(p0 + 56, vl); + vuint16m1_t _r8 = vle16_v_u16m1(p0 + 64, vl); + vuint16m1_t _r9 = vle16_v_u16m1(p0 + 72, vl); + vuint16m1_t _ra = vle16_v_u16m1(p0 + 80, vl); + vuint16m1_t _rb = vle16_v_u16m1(p0 + 88, vl); + + vsse16_v_u16m1(pp, 12 * sizeof(unsigned short), _r0, vl); + vsse16_v_u16m1(pp + 1, 12 * sizeof(unsigned short), _r1, vl); + vsse16_v_u16m1(pp + 2, 12 * sizeof(unsigned short), _r2, vl); + vsse16_v_u16m1(pp + 3, 12 * sizeof(unsigned short), _r3, vl); + vsse16_v_u16m1(pp + 4, 12 * sizeof(unsigned short), _r4, vl); + vsse16_v_u16m1(pp + 5, 12 * sizeof(unsigned short), _r5, vl); + vsse16_v_u16m1(pp + 6, 12 * sizeof(unsigned short), _r6, vl); + vsse16_v_u16m1(pp + 7, 12 * sizeof(unsigned short), _r7, vl); + vsse16_v_u16m1(pp + 8, 12 * sizeof(unsigned short), _r8, vl); + vsse16_v_u16m1(pp + 9, 12 * sizeof(unsigned short), _r9, vl); + vsse16_v_u16m1(pp + 10, 12 * sizeof(unsigned short), _ra, vl); + vsse16_v_u16m1(pp + 11, 12 * sizeof(unsigned short), _rb, vl); + + pp += 96; + p0 += B_hstep * 8; + } + } + if (elempack == 4) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vuint16m1_t _r0; + vuint16m1_t _r1; + vuint16m1_t _r2; + vuint16m1_t _r3; + vl = 8; + vlseg4e16_v_u16m1(&_r0, &_r1, &_r2, &_r3, p0, vl); + vuint16m1_t _r8; + vuint16m1_t _r9; + vuint16m1_t _ra; + vuint16m1_t _rb; + vl = 4; + vlseg4e16_v_u16m1(&_r8, &_r9, &_ra, &_rb, p0 + 32, vl); + vl = 8; + vse16_v_u16m1(pp, _r0, vl); + vl = 4; + vse16_v_u16m1(pp + 8, _r8, vl); + vl = 8; + vse16_v_u16m1(pp + 12, _r1, vl); + vl = 4; + vse16_v_u16m1(pp + 20, _r9, vl); + vl = 8; + vse16_v_u16m1(pp + 24, _r2, vl); + vl = 4; + vse16_v_u16m1(pp + 32, _ra, vl); + vl = 8; + vse16_v_u16m1(pp + 36, _r3, vl); + vl = 4; + vse16_v_u16m1(pp + 44, _rb, vl); + + pp += 48; + p0 += B_hstep * 4; + } + } + if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vl = 12; + vse16_v_u16m2(pp, vle16_v_u16m2(p0, vl), vl); + pp += 12; + p0 += B_hstep; + } + } + } + for (; jj + 7 < max_jj; jj += 8) + { + if (elempack == 8) + { + vl = 8; + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 8; + + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + vuint16m1_t _r0 = vle16_v_u16m1(p0, vl); + vuint16m1_t _r1 = vle16_v_u16m1(p0 + 8, vl); + vuint16m1_t _r2 = vle16_v_u16m1(p0 + 16, vl); + vuint16m1_t _r3 = 
vle16_v_u16m1(p0 + 24, vl); + vuint16m1_t _r4 = vle16_v_u16m1(p0 + 32, vl); + vuint16m1_t _r5 = vle16_v_u16m1(p0 + 40, vl); + vuint16m1_t _r6 = vle16_v_u16m1(p0 + 48, vl); + vuint16m1_t _r7 = vle16_v_u16m1(p0 + 56, vl); + vsseg8e16_v_u16m1(pp, _r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, vl); + pp += 64; + p0 += B_hstep * 8; + } + } + if (elempack == 4) + { + vl = 8; + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vuint16m1_t _r0; + vuint16m1_t _r1; + vuint16m1_t _r2; + vuint16m1_t _r3; + + vlseg4e16_v_u16m1(&_r0, &_r1, &_r2, &_r3, p0, vl); + + vse16_v_u16m1(pp, _r0, vl); + vse16_v_u16m1(pp + 8, _r1, vl); + vse16_v_u16m1(pp + 16, _r2, vl); + vse16_v_u16m1(pp + 24, _r3, vl); + + pp += 32; + p0 += B_hstep * 4; + } + } + if (elempack == 1) + { + vl = 8; + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vse16_v_u16m1(pp, vle16_v_u16m1(p0, vl), vl); + pp += 8; + p0 += B_hstep; + } + } + } + for (; jj + 3 < max_jj; jj += 4) + { + if (elempack == 8) + { + vl = 8; + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 8; + + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + vuint16m1_t _r0 = vle16_v_u16m1(p0, vl); + vuint16m1_t _r1 = vle16_v_u16m1(p0 + 8, vl); + vuint16m1_t _r2 = vle16_v_u16m1(p0 + 16, vl); + vuint16m1_t _r3 = vle16_v_u16m1(p0 + 24, vl); + + vsseg4e16_v_u16m1(pp, _r0, _r1, _r2, _r3, vl); + pp += 32; + p0 += B_hstep * 8; + } + } + if (elempack == 4) + { + vl = 4; + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vuint16m1_t _r0; + vuint16m1_t _r1; + vuint16m1_t _r2; + vuint16m1_t _r3; + + vlseg4e16_v_u16m1(&_r0, &_r1, &_r2, &_r3, p0, vl); + + vse16_v_u16m1(pp, _r0, vl); + vse16_v_u16m1(pp + 4, _r1, vl); + vse16_v_u16m1(pp + 8, _r2, vl); + vse16_v_u16m1(pp + 12, _r3, vl); + pp += 16; + p0 += B_hstep * 4; + } + } + if (elempack == 1) + { + vl = 4; + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vse16_v_u16m1(pp, vle16_v_u16m1(p0, vl), vl); + pp += 4; + p0 += B_hstep; + } + } + } +#endif // __riscv_vector + for (; jj + 1 < max_jj; jj += 2) + { +#if __riscv_vector + if (elempack == 8) + { + vl = 8; + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 8; + + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + vuint16m1_t _r0 = vle16_v_u16m1(p0, vl); + vuint16m1_t _r1 = vle16_v_u16m1(p0 + 8, vl); + vsseg2e16_v_u16m1(pp, _r0, _r1, vl); + pp += 16; + p0 += B_hstep * 8; + } + } + if (elempack == 4) + { + vl = 4; + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vuint16m1_t _r0 = vle16_v_u16m1(p0, vl); + vuint16m1_t _r1 = vle16_v_u16m1(p0 + 4, vl); + vsseg2e16_v_u16m1(pp, _r0, _r1, vl); + pp += 8; + p0 += B_hstep * 4; + } + } +#endif // __riscv_vector + if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp += 2; + p0 += B_hstep; + } + } + } + for (; jj < max_jj; jj += 1) + { +#if __riscv_vector + if (elempack == 8) + { + vl = 8; + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 8; + + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + 
vse16_v_u16m1(pp, vle16_v_u16m1(p0, vl), vl); + pp += 8; + p0 += B_hstep * 8; + } + } + if (elempack == 4) + { + vl = 4; + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vse16_v_u16m1(pp, vle16_v_u16m1(p0, vl), vl); + pp += 4; + p0 += B_hstep * 4; + } + } +#endif // __riscv_vector + if (elempack == 1) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0 += B_hstep; + } + } + } +} + +static void transpose_unpack_output_tile_bf16_fp16(const Mat& topT, Mat& top_blob, int i, int max_ii, int j, int max_jj) +{ + int vl; + const int out_elempack = top_blob.elempack; + const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; + + const unsigned short* pp = topT; + + int ii = 0; +#if __riscv_vector + for (; ii + 7 < max_ii; ii += 8) + { + if (out_elempack == 8) + { + unsigned short* p0 = (unsigned short*)top_blob + (j / 8 * 8) * out_hstep + (i + ii) * 8; + + int jj = 0; + if (j % 8 == 4) + { + vl = 4; + vuint16m1_t _r0 = vle16_v_u16m1(pp, vl); + vuint16m1_t _r1 = vle16_v_u16m1(pp + 4, vl); + vuint16m1_t _r2 = vle16_v_u16m1(pp + 8, vl); + vuint16m1_t _r3 = vle16_v_u16m1(pp + 12, vl); + vuint16m1_t _r4 = vle16_v_u16m1(pp + 16, vl); + vuint16m1_t _r5 = vle16_v_u16m1(pp + 20, vl); + vuint16m1_t _r6 = vle16_v_u16m1(pp + 24, vl); + vuint16m1_t _r7 = vle16_v_u16m1(pp + 28, vl); + + vsseg8e16_v_u16m1(p0 + 4, _r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, vl); + + pp += 32; + p0 += out_hstep * 8; + jj += 4; + } + for (; jj + 7 < max_jj; jj += 8) + { + vl = 8; + vuint16m1_t _r0 = vle16_v_u16m1(pp, vl); + vuint16m1_t _r1 = vle16_v_u16m1(pp + 8, vl); + vuint16m1_t _r2 = vle16_v_u16m1(pp + 8 * 2, vl); + vuint16m1_t _r3 = vle16_v_u16m1(pp + 8 * 3, vl); + vuint16m1_t _r4 = vle16_v_u16m1(pp + 8 * 4, vl); + vuint16m1_t _r5 = vle16_v_u16m1(pp + 8 * 5, vl); + vuint16m1_t _r6 = vle16_v_u16m1(pp + 8 * 6, vl); + vuint16m1_t _r7 = vle16_v_u16m1(pp + 8 * 7, vl); + + vsseg8e16_v_u16m1(p0, _r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, vl); + + pp += 64; + p0 += out_hstep * 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + vl = 4; + vuint16m1_t _r0 = vle16_v_u16m1(pp, vl); + vuint16m1_t _r1 = vle16_v_u16m1(pp + 4, vl); + vuint16m1_t _r2 = vle16_v_u16m1(pp + 8, vl); + vuint16m1_t _r3 = vle16_v_u16m1(pp + 12, vl); + vuint16m1_t _r4 = vle16_v_u16m1(pp + 16, vl); + vuint16m1_t _r5 = vle16_v_u16m1(pp + 20, vl); + vuint16m1_t _r6 = vle16_v_u16m1(pp + 24, vl); + vuint16m1_t _r7 = vle16_v_u16m1(pp + 28, vl); + + vsseg8e16_v_u16m1(p0, _r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, vl); + + pp += 32; + p0 += out_hstep * 8; + } + } + if (out_elempack == 4) + { + unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * 4; + + for (int jj = 0; jj + 3 < max_jj; jj += 4) + { + vl = 8; + vuint16m1_t _r0 = vle16_v_u16m1(pp, vl); + vuint16m1_t _r1 = vle16_v_u16m1(pp + 8, vl); + vuint16m1_t _r2 = vle16_v_u16m1(pp + 8 * 2, vl); + vuint16m1_t _r3 = vle16_v_u16m1(pp + 8 * 3, vl); + + vsseg4e16_v_u16m1(p0, _r0, _r1, _r2, _r3, vl); + + pp += 32; + p0 += out_hstep * 4; + } + } + if (out_elempack == 1) + { + unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii); + + for (int jj = 0; jj < max_jj; jj += 1) + { + vl = 8; + vuint16m1_t _r0 = vle16_v_u16m1(pp, vl); + vse16_v_u16m1(p0, _r0, vl); + pp += 8; + p0 += out_hstep; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + if (out_elempack == 8) + { + unsigned short* 
p0 = (unsigned short*)top_blob + (j / 8 * 8) * out_hstep + (i + ii) * 8; + + int jj = 0; + if (j % 8 == 4) + { + vl = 4; + vuint16m1_t _r0; + vuint16m1_t _r1; + vuint16m1_t _r2; + vuint16m1_t _r3; + + vlseg4e16_v_u16m1(&_r0, &_r1, &_r2, &_r3, pp, vl); + vse16_v_u16m1(p0 + 4, _r0, vl); + vse16_v_u16m1(p0 + 8 + 4, _r1, vl); + vse16_v_u16m1(p0 + 16 + 4, _r2, vl); + vse16_v_u16m1(p0 + 24 + 4, _r3, vl); + + pp += 16; + p0 += out_hstep * 8; + jj += 4; + } + for (; jj + 7 < max_jj; jj += 8) + { + vl = 8; + vuint16m1_t _r0; + vuint16m1_t _r1; + vuint16m1_t _r2; + vuint16m1_t _r3; + + vlseg4e16_v_u16m1(&_r0, &_r1, &_r2, &_r3, pp, vl); + vse16_v_u16m1(p0, _r0, vl); + vse16_v_u16m1(p0 + 8, _r1, vl); + vse16_v_u16m1(p0 + 8 * 2, _r2, vl); + vse16_v_u16m1(p0 + 8 * 3, _r3, vl); + + pp += 32; + p0 += out_hstep * 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + vl = 4; + vuint16m1_t _r0; + vuint16m1_t _r1; + vuint16m1_t _r2; + vuint16m1_t _r3; + + vlseg4e16_v_u16m1(&_r0, &_r1, &_r2, &_r3, pp, vl); + vse16_v_u16m1(p0, _r0, vl); + vse16_v_u16m1(p0 + 8, _r1, vl); + vse16_v_u16m1(p0 + 16, _r2, vl); + vse16_v_u16m1(p0 + 24, _r3, vl); + + pp += 16; + p0 += out_hstep * 8; + } + } + if (out_elempack == 4) + { + unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * 4; + + for (int jj = 0; jj + 3 < max_jj; jj += 4) + { + vl = 4; + vuint16m1_t _r0 = vle16_v_u16m1(pp, vl); + vuint16m1_t _r1 = vle16_v_u16m1(pp + 4, vl); + vuint16m1_t _r2 = vle16_v_u16m1(pp + 8, vl); + vuint16m1_t _r3 = vle16_v_u16m1(pp + 12, vl); + + vsseg4e16_v_u16m1(p0, _r0, _r1, _r2, _r3, vl); + + pp += 16; + p0 += out_hstep * 4; + } + } + if (out_elempack == 1) + { + unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii); + + for (int jj = 0; jj < max_jj; jj += 1) + { + vl = 4; + vuint16m1_t _r0 = vle16_v_u16m1(pp, vl); + vse16_v_u16m1(p0, _r0, vl); + + pp += 4; + p0 += out_hstep; + } + } + } +#endif // __riscv_vector + for (; ii + 1 < max_ii; ii += 2) + { +#if __riscv_vector + if (out_elempack == 8) + { + unsigned short* p0 = (unsigned short*)top_blob + (j / 8 * 8) * out_hstep + (i + ii) * 8; + + int jj = 0; + if (j % 8 == 4) + { + p0[0 + 4] = pp[0]; + p0[1 + 4] = pp[2]; + p0[2 + 4] = pp[4]; + p0[3 + 4] = pp[6]; + p0[8 + 4] = pp[1]; + p0[9 + 4] = pp[3]; + p0[10 + 4] = pp[5]; + p0[11 + 4] = pp[7]; + pp += 8; + p0 += out_hstep * 8; + jj += 4; + } + for (; jj + 7 < max_jj; jj += 8) + { + p0[0] = pp[0]; + p0[1] = pp[2]; + p0[2] = pp[4]; + p0[3] = pp[6]; + p0[4] = pp[8]; + p0[5] = pp[10]; + p0[6] = pp[12]; + p0[7] = pp[14]; + p0[8] = pp[1]; + p0[9] = pp[3]; + p0[10] = pp[5]; + p0[11] = pp[7]; + p0[12] = pp[9]; + p0[13] = pp[11]; + p0[14] = pp[13]; + p0[15] = pp[15]; + pp += 16; + p0 += out_hstep * 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + p0[0] = pp[0]; + p0[1] = pp[2]; + p0[2] = pp[4]; + p0[3] = pp[6]; + p0[8] = pp[1]; + p0[9] = pp[3]; + p0[10] = pp[5]; + p0[11] = pp[7]; + pp += 8; + p0 += out_hstep * 8; + } + } + if (out_elempack == 4) + { + unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * 4; + + for (int jj = 0; jj + 3 < max_jj; jj += 4) + { + p0[0] = pp[0]; + p0[1] = pp[2]; + p0[2] = pp[4]; + p0[3] = pp[6]; + p0[4] = pp[1]; + p0[5] = pp[3]; + p0[6] = pp[5]; + p0[7] = pp[7]; + pp += 8; + p0 += out_hstep * 4; + } + } +#endif // __riscv_vector + if (out_elempack == 1) + { + unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii); + + for (int jj = 0; jj < max_jj; jj += 1) + { + p0[0] = pp[0]; + p0[1] = pp[1]; + pp += 2; + p0 += out_hstep; + } + } + } + for (; ii 
< max_ii; ii += 1) + { +#if __riscv_vector + if (out_elempack == 8) + { + unsigned short* p0 = (unsigned short*)top_blob + (j / 8 * 8) * out_hstep + (i + ii) * 8; + + int jj = 0; + if (j % 8 == 4) + { + vl = 4; + vuint16m1_t _r0 = vle16_v_u16m1(pp, vl); + vse16_v_u16m1(p0 + 4, _r0, vl); + + pp += 4; + p0 += out_hstep * 8; + jj += 4; + } + for (; jj + 7 < max_jj; jj += 8) + { + vl = 8; + vuint16m1_t _r0 = vle16_v_u16m1(pp, vl); + vse16_v_u16m1(p0, _r0, vl); + pp += 8; + p0 += out_hstep * 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + vl = 4; + vuint16m1_t _r0 = vle16_v_u16m1(pp, vl); + vse16_v_u16m1(p0, _r0, vl); + + pp += 4; + p0 += out_hstep * 8; + } + } + if (out_elempack == 4) + { + unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * 4; + + for (int jj = 0; jj + 3 < max_jj; jj += 4) + { + vl = 4; + vuint16m1_t _r0 = vle16_v_u16m1(pp, vl); + vse16_v_u16m1(p0, _r0, vl); + + pp += 4; + p0 += out_hstep * 4; + } + } +#endif // __riscv_vector + if (out_elempack == 1) + { + unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii); + + for (int jj = 0; jj < max_jj; jj += 1) + { + p0[0] = pp[0]; + pp += 1; + p0 += out_hstep; + } + } + } +} + +static void get_optimal_tile_mnk_bf16s_fp16s(int M, int N, int K, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int& TILE_M, int& TILE_N, int& TILE_K, int nT) +{ + // resolve optimal tile size from cache size + const size_t l2_cache_size = get_cpu_level2_cache_size(); + + if (nT == 0) + nT = get_physical_big_cpu_count(); + + int tile_size = (int)sqrtf((float)l2_cache_size / (2 * sizeof(unsigned short) + sizeof(float))); + + TILE_M = std::max(8, tile_size / 8 * 8); + TILE_N = std::max(4, tile_size / 4 * 4); + TILE_K = std::max(8, tile_size / 8 * 8); + + if (K > 0) + { + int nn_K = (K + TILE_K - 1) / TILE_K; + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 7) / 8 * 8); + + if (nn_K == 1) + { + tile_size = (int)((float)l2_cache_size / 2 / sizeof(unsigned short) / TILE_K); + + TILE_M = std::max(8, tile_size / 8 * 8); + TILE_N = std::max(4, tile_size / 4 * 4); + } + } + + TILE_M *= std::min(nT, get_physical_cpu_count()); + + if (M > 0) + { + int nn_M = (M + TILE_M - 1) / TILE_M; + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 7) / 8 * 8); + } + + if (N > 0) + { + int nn_N = (N + TILE_N - 1) / TILE_N; + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); + } + + if (nT > 1) + { + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 7) / 8 * 8); + } + + // always take constant TILE_M/N/K value when provided + if (constant_TILE_M > 0) + { + TILE_M = (constant_TILE_M + 7) / 8 * 8; + } + + if (constant_TILE_N > 0) + { + TILE_N = (constant_TILE_N + 3) / 4 * 4; + } + + if (constant_TILE_K > 0) + { + TILE_K = (constant_TILE_K + 7) / 8 * 8; + } +} diff --git a/src/layer/riscv/gemm_fp16s.h b/src/layer/riscv/gemm_fp16s.h new file mode 100644 index 000000000000..ae561012de69 --- /dev/null +++ b/src/layer/riscv/gemm_fp16s.h @@ -0,0 +1,3049 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +static void print_f32_m2(vfloat32m2_t _val, size_t l) +{ + float* ptr = (float*)malloc(l * sizeof(float)); + vse32_v_f32m2(ptr, _val, l); + for (int i = 0; i < l; i++) + { + fprintf(stderr, "%f ", ptr[i]); + } + fprintf(stderr, "\n"); + free(ptr); +} + +static void print_f16_m1(vfloat16m1_t _val, size_t l) +{ + __fp16* ptr = (__fp16*)malloc(l * sizeof(__fp16)); + vse16_v_f16m1(ptr, _val, l); + for (int i = 0; i < l; i++) + { + fprintf(stderr, "%f ", (float)ptr[i]); + } + fprintf(stderr, "\n"); + free(ptr); +} + +static void pack_A_tile_fp32_to_fp16(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + int vl; + const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; + + __fp16* pp = AT; + + int ii = 0; +#if __riscv_vector + for (; ii + 7 < max_ii; ii += 8) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; + const float* p1 = (const float*)A + (i + ii + 1) * A_hstep + k; + const float* p2 = (const float*)A + (i + ii + 2) * A_hstep + k; + const float* p3 = (const float*)A + (i + ii + 3) * A_hstep + k; + const float* p4 = (const float*)A + (i + ii + 4) * A_hstep + k; + const float* p5 = (const float*)A + (i + ii + 5) * A_hstep + k; + const float* p6 = (const float*)A + (i + ii + 6) * A_hstep + k; + const float* p7 = (const float*)A + (i + ii + 7) * A_hstep + k; + + int kk = 0; + + int n = max_kk; + while (n > 0) + { + vl = vsetvl_e32m2(n); + vfloat16m1_t _r0 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p0, vl), vl); + vfloat16m1_t _r1 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p1, vl), vl); + vfloat16m1_t _r2 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p2, vl), vl); + vfloat16m1_t _r3 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p3, vl), vl); + vfloat16m1_t _r4 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p4, vl), vl); + vfloat16m1_t _r5 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p5, vl), vl); + vfloat16m1_t _r6 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p6, vl), vl); + vfloat16m1_t _r7 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p7, vl), vl); + vsseg8e16_v_f16m1(pp, _r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, vl); + pp += 8 * vl; + p0 += vl; + p1 += vl; + p2 += vl; + p3 += vl; + p4 += vl; + p5 += vl; + p6 += vl; + p7 += vl; + n -= vl; + } + } + for (; ii + 3 < max_ii; ii += 4) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; + const float* p1 = (const float*)A + (i + ii + 1) * A_hstep + k; + const float* p2 = (const float*)A + (i + ii + 2) * A_hstep + k; + const float* p3 = (const float*)A + (i + ii + 3) * A_hstep + k; + + int kk = 0; + + int n = max_kk; + while (n > 0) + { + vl = vsetvl_e32m2(n); + vfloat16m1_t _r0 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p0, vl), vl); + vfloat16m1_t _r1 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p1, vl), vl); + vfloat16m1_t _r2 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p2, vl), vl); + vfloat16m1_t _r3 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p3, vl), vl); + vsseg4e16_v_f16m1(pp, _r0, _r1, _r2, _r3, vl); + pp += 4 * vl; + p0 += vl; + p1 += vl; + p2 += vl; + p3 += vl; + n -= vl; + } + } + for (; ii + 1 < max_ii; ii += 2) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; + const float* p1 = (const float*)A + (i + ii + 1) * A_hstep + k; + + int kk = 0; + int n = max_kk; + while (n > 0) + { 
+ vl = vsetvl_e32m2(n); + vfloat16m1_t _r0 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p0, vl), vl); + vfloat16m1_t _r1 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p1, vl), vl); + vsseg2e16_v_f16m1(pp, _r0, _r1, vl); + pp += 2 * vl; + p0 += vl; + p1 += vl; + n -= vl; + } + } + for (; ii < max_ii; ii += 1) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; + + int kk = 0; + int n = max_kk; + + while (n > 0) + { + vl = vsetvl_e32m2(n); + vfloat16m1_t _r0 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p0, vl), vl); + vse16_v_f16m1(pp, _r0, vl); + pp += vl; + p0 += vl; + n -= vl; + } + } +#endif // __riscv_vector +} + +static void transpose_pack_A_tile_fp32_to_fp16(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + int vl; + const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; + + __fp16* pp = AT; + + int ii = 0; +#if __riscv_vector + for (; ii + 7 < max_ii; ii += 8) + { + vl = 8; + const float* p0 = (const float*)A + k * A_hstep + (i + ii); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vfloat16m1_t _r0 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p0, vl), vl); + vse16_v_f16m1(pp, _r0, vl); + pp += vl; + p0 += A_hstep; + } + } + for (; ii + 3 < max_ii; ii += 4) + { + vl = 4; + const float* p0 = (const float*)A + k * A_hstep + (i + ii); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vfloat16m1_t _r0 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p0, vl), vl); + vse16_v_f16m1(pp, _r0, vl); + pp += vl; + p0 += A_hstep; + } + } + + for (; ii + 1 < max_ii; ii += 2) + { + vl = 2; + const float* p0 = (const float*)A + k * A_hstep + (i + ii); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vfloat16m1_t _r0 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p0, vl), vl); + vse16_v_f16m1(pp, _r0, vl); + pp += vl; + p0 += A_hstep; + } + } + + for (; ii < max_ii; ii += 1) + { + vl = 1; + const float* p0 = (const float*)A + k * A_hstep + (i + ii); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vfloat16m1_t _r0 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p0, vl), vl); + vse16_v_f16m1(pp, _r0, vl); + pp += vl; + p0 += A_hstep; + } + } +#endif // __riscv_vector +} + +static void pack_B_tile_fp32_to_fp16(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + int vl; + const int B_hstep = B.dims == 3 ? 
(int)B.cstep : B.w; + + __fp16* pp = BT; + + int jj = 0; +#if __riscv_vector + for (; jj + 11 < max_jj; jj += 12) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; + const float* p1 = (const float*)B + (j + jj + 1) * B_hstep + k; + const float* p2 = (const float*)B + (j + jj + 2) * B_hstep + k; + const float* p3 = (const float*)B + (j + jj + 3) * B_hstep + k; + const float* p4 = (const float*)B + (j + jj + 4) * B_hstep + k; + const float* p5 = (const float*)B + (j + jj + 5) * B_hstep + k; + const float* p6 = (const float*)B + (j + jj + 6) * B_hstep + k; + const float* p7 = (const float*)B + (j + jj + 7) * B_hstep + k; + const float* p8 = (const float*)B + (j + jj + 8) * B_hstep + k; + const float* p9 = (const float*)B + (j + jj + 9) * B_hstep + k; + const float* pa = (const float*)B + (j + jj + 10) * B_hstep + k; + const float* pb = (const float*)B + (j + jj + 11) * B_hstep + k; + + int kk = 0; + + int n = max_kk; + while (n > 0) + { + vl = vsetvl_e32m2(n); + vfloat16m1_t _r0 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p0, vl), vl); + vfloat16m1_t _r1 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p1, vl), vl); + vfloat16m1_t _r2 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p2, vl), vl); + vfloat16m1_t _r3 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p3, vl), vl); + vfloat16m1_t _r4 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p4, vl), vl); + vfloat16m1_t _r5 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p5, vl), vl); + vfloat16m1_t _r6 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p6, vl), vl); + vfloat16m1_t _r7 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p7, vl), vl); + vfloat16m1_t _r8 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p8, vl), vl); + vfloat16m1_t _r9 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p9, vl), vl); + vfloat16m1_t _ra = vfncvt_f_f_w_f16m1(vle32_v_f32m2(pa, vl), vl); + vfloat16m1_t _rb = vfncvt_f_f_w_f16m1(vle32_v_f32m2(pb, vl), vl); + + vsse16_v_f16m1(pp + 0, 12 * sizeof(__fp16), _r0, vl); + vsse16_v_f16m1(pp + 1, 12 * sizeof(__fp16), _r1, vl); + vsse16_v_f16m1(pp + 2, 12 * sizeof(__fp16), _r2, vl); + vsse16_v_f16m1(pp + 3, 12 * sizeof(__fp16), _r3, vl); + vsse16_v_f16m1(pp + 4, 12 * sizeof(__fp16), _r4, vl); + vsse16_v_f16m1(pp + 5, 12 * sizeof(__fp16), _r5, vl); + vsse16_v_f16m1(pp + 6, 12 * sizeof(__fp16), _r6, vl); + vsse16_v_f16m1(pp + 7, 12 * sizeof(__fp16), _r7, vl); + vsse16_v_f16m1(pp + 8, 12 * sizeof(__fp16), _r8, vl); + vsse16_v_f16m1(pp + 9, 12 * sizeof(__fp16), _r9, vl); + vsse16_v_f16m1(pp + 10, 12 * sizeof(__fp16), _ra, vl); + vsse16_v_f16m1(pp + 11, 12 * sizeof(__fp16), _rb, vl); + pp += 12 * vl; + p0 += vl; + p1 += vl; + p2 += vl; + p3 += vl; + p4 += vl; + p5 += vl; + p6 += vl; + p7 += vl; + p8 += vl; + p9 += vl; + pa += vl; + pb += vl; + + n -= vl; + } + } +#endif // __riscv_vector + for (; jj + 7 < max_jj; jj += 8) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; + const float* p1 = (const float*)B + (j + jj + 1) * B_hstep + k; + const float* p2 = (const float*)B + (j + jj + 2) * B_hstep + k; + const float* p3 = (const float*)B + (j + jj + 3) * B_hstep + k; + const float* p4 = (const float*)B + (j + jj + 4) * B_hstep + k; + const float* p5 = (const float*)B + (j + jj + 5) * B_hstep + k; + const float* p6 = (const float*)B + (j + jj + 6) * B_hstep + k; + const float* p7 = (const float*)B + (j + jj + 7) * B_hstep + k; + + int kk = 0; + + int n = max_kk; + while (n > 0) + { + vl = vsetvl_e32m2(n); + vfloat16m1_t _r0 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p0, vl), vl); + vfloat16m1_t _r1 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p1, vl), vl); + vfloat16m1_t _r2 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p2, vl), vl); + vfloat16m1_t _r3 
= vfncvt_f_f_w_f16m1(vle32_v_f32m2(p3, vl), vl); + vfloat16m1_t _r4 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p4, vl), vl); + vfloat16m1_t _r5 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p5, vl), vl); + vfloat16m1_t _r6 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p6, vl), vl); + vfloat16m1_t _r7 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p7, vl), vl); + + vsseg8e16_v_f16m1(pp, _r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, vl); + + pp += 8 * vl; + p0 += vl; + p1 += vl; + p2 += vl; + p3 += vl; + p4 += vl; + p5 += vl; + p6 += vl; + p7 += vl; + n -= vl; + } + } + for (; jj + 3 < max_jj; jj += 4) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; + const float* p1 = (const float*)B + (j + jj + 1) * B_hstep + k; + const float* p2 = (const float*)B + (j + jj + 2) * B_hstep + k; + const float* p3 = (const float*)B + (j + jj + 3) * B_hstep + k; + + int kk = 0; + + int n = max_kk; + + while (n > 0) + { + vl = vsetvl_e32m2(n); + vfloat16m1_t _r0 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p0, vl), vl); + vfloat16m1_t _r1 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p1, vl), vl); + vfloat16m1_t _r2 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p2, vl), vl); + vfloat16m1_t _r3 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p3, vl), vl); + + vsseg4e16_v_f16m1(pp, _r0, _r1, _r2, _r3, vl); + + pp += 4 * vl; + p0 += vl; + p1 += vl; + p2 += vl; + p3 += vl; + n -= vl; + } + } + for (; jj + 1 < max_jj; jj += 2) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; + const float* p1 = (const float*)B + (j + jj + 1) * B_hstep + k; + + int kk = 0; + + int n = max_kk; + while (n > 0) + { + vl = vsetvl_e32m2(n); + vfloat16m1_t _r0 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p0, vl), vl); + vfloat16m1_t _r1 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p1, vl), vl); + + vsseg2e16_v_f16m1(pp, _r0, _r1, vl); + + pp += 2 * vl; + p0 += vl; + p1 += vl; + n -= vl; + } + } + for (; jj < max_jj; jj += 1) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; + + int kk = 0; + + int n = max_kk; + while (n > 0) + { + vl = vsetvl_e32m2(n); + vfloat16m1_t _r0 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p0, vl), vl); + + vse16_v_f16m1(pp, _r0, vl); + + pp += 1 * vl; + p0 += vl; + n -= vl; + } + } +} + +static void transpose_pack_B_tile_fp32_to_fp16(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + int vl; + const int B_hstep = B.dims == 3 ? 
(int)B.cstep : B.w; + + __fp16* pp = BT; + + int jj = 0; + + for (; jj + 11 < max_jj; jj += 12) + { + vl = 12; + const float* p0 = (const float*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vfloat16m2_t _r0 = vfncvt_f_f_w_f16m2(vle32_v_f32m4(p0, vl), vl); + vse16_v_f16m2(pp, _r0, vl); + pp += vl; + p0 += B_hstep; + } + } + + for (; jj + 7 < max_jj; jj += 8) + { + vl = 8; + const float* p0 = (const float*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vfloat16m1_t _r0 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p0, vl), vl); + vse16_v_f16m1(pp, _r0, vl); + pp += vl; + p0 += B_hstep; + } + } + + for (; jj + 3 < max_jj; jj += 4) + { + vl = 4; + const float* p0 = (const float*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vfloat16m1_t _r0 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p0, vl), vl); + vse16_v_f16m1(pp, _r0, vl); + pp += vl; + p0 += B_hstep; + } + } + + for (; jj + 1 < max_jj; jj += 2) + { + vl = 2; + const float* p0 = (const float*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vfloat16m1_t _r0 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p0, vl), vl); + vse16_v_f16m1(pp, _r0, vl); + pp += vl; + p0 += B_hstep; + } + } + + for (; jj < max_jj; jj += 1) + { + vl = 1; + const float* p0 = (const float*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vfloat16m1_t _r0 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(p0, vl), vl); + vse16_v_f16m1(pp, _r0, vl); + pp += vl; + p0 += B_hstep; + } + } +} + +static void transpose_unpack_output_tile_fp32_to_fp16(const Mat& topT, Mat& top_blob, int i, int max_ii, int j, int max_jj) +{ + int vl; + const int out_elempack = top_blob.elempack; + const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; + + const float* pp = topT; + + int ii = 0; +#if __riscv_vector + for (; ii + 7 < max_ii; ii += 8) + { + if (out_elempack == 4) + { + __fp16* p0 = (__fp16*)top_blob + j * out_hstep + (i + ii) * 4; + + for (int jj = 0; jj + 3 < max_jj; jj += 4) + { + vl = 8; + vfloat16m1_t _r0 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(pp, vl), vl); + vfloat16m1_t _r1 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(pp + 8, vl), vl); + vfloat16m1_t _r2 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(pp + 16, vl), vl); + vfloat16m1_t _r3 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(pp + 24, vl), vl); + + vsseg4e16_v_f16m1(p0, _r0, _r1, _r2, _r3, vl); + + pp += 32; + p0 += out_hstep * 4; + } + } + if (out_elempack == 1) + { + __fp16* p0 = (__fp16*)top_blob + j * out_hstep + (i + ii); + + for (int jj = 0; jj < max_jj; jj += 1) + { + vl = 8; + vfloat16m1_t _r0 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(pp, vl), vl); + vse16_v_f16m1(p0, _r0, vl); + + pp += 8; + p0 += out_hstep; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + if (out_elempack == 4) + { + __fp16* p0 = (__fp16*)top_blob + j * out_hstep + (i + ii) * 4; + + for (int jj = 0; jj + 3 < max_jj; jj += 4) + { + vl = 4; + vfloat16m1_t _r0 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(pp, vl), vl); + vfloat16m1_t _r1 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(pp + 4, vl), vl); + vfloat16m1_t _r2 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(pp + 8, vl), vl); + vfloat16m1_t _r3 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(pp + 12, vl), vl); + + vsseg4e16_v_f16m1(p0, _r0, _r1, _r2, _r3, vl); + + pp += 16; + p0 += out_hstep * 4; + } + } + if (out_elempack == 1) + { + __fp16* p0 = (__fp16*)top_blob + j * out_hstep + (i + ii); + + for (int jj = 0; jj < max_jj; jj += 1) + { + vl = 4; + vfloat16m1_t _r0 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(pp, vl), vl); + vse16_v_f16m1(p0, _r0, vl); + 
+ pp += 4; + p0 += out_hstep; + } + } + } +#endif // __riscv_vector + for (; ii + 1 < max_ii; ii += 2) + { +#if __riscv_vector + if (out_elempack == 4) + { + __fp16* p0 = (__fp16*)top_blob + j * out_hstep + (i + ii) * 4; + + for (int jj = 0; jj + 3 < max_jj; jj += 4) + { + p0[0] = (__fp16)(pp[0]); + p0[1] = (__fp16)(pp[2]); + p0[2] = (__fp16)(pp[4]); + p0[3] = (__fp16)(pp[6]); + p0[4] = (__fp16)(pp[1]); + p0[5] = (__fp16)(pp[3]); + p0[6] = (__fp16)(pp[5]); + p0[7] = (__fp16)(pp[7]); + pp += 8; + p0 += out_hstep * 4; + } + } +#endif // __riscv_vector + if (out_elempack == 1) + { + __fp16* p0 = (__fp16*)top_blob + j * out_hstep + (i + ii); + + for (int jj = 0; jj < max_jj; jj += 1) + { + p0[0] = (__fp16)(pp[0]); + p0[1] = (__fp16)(pp[1]); + pp += 2; + p0 += out_hstep; + } + } + } + for (; ii < max_ii; ii += 1) + { +#if __riscv_vector + if (out_elempack == 4) + { + __fp16* p0 = (__fp16*)top_blob + j * out_hstep + (i + ii) * 4; + + for (int jj = 0; jj + 3 < max_jj; jj += 4) + { + vl = 4; + vfloat16m1_t _r0 = vfncvt_f_f_w_f16m1(vle32_v_f32m2(pp, vl), vl); + vse16_v_f16m1(p0, _r0, vl); + + pp += 4; + p0 += out_hstep * 4; + } + } +#endif // __riscv_vector + if (out_elempack == 1) + { + __fp16* p0 = (__fp16*)top_blob + j * out_hstep + (i + ii); + + for (int jj = 0; jj < max_jj; jj += 1) + { + p0[0] = (__fp16)(pp[0]); + pp += 1; + p0 += out_hstep; + } + } + } +} + +static void gemm_transB_packed_tile_fp16s(const Mat& AT_tile, const Mat& BT_tile, const Mat& CT_tile, Mat& topT_tile, Mat& top_blob, int broadcast_type_C, float alpha, int i, int max_ii, int j, int max_jj, int k, int max_kk, bool k_end) +{ + int vl; + + const int out_elempack = top_blob.elempack; + const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; + + const __fp16* pAT = AT_tile; + const __fp16* pBT = BT_tile; + + const float* pC = CT_tile; + + float* outptr = topT_tile; + + int ii = 0; +#if __riscv_vector + for (; ii + 7 < max_ii; ii += 8) + { + __fp16* outptr0 = (__fp16*)top_blob + (i + ii) * out_hstep + j * out_elempack; + + const __fp16* pB = pBT; + + if (pC) + { + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)CT_tile + i + ii; + } + if (broadcast_type_C == 4) + { + pC = (const float*)CT_tile + j; + } + } + + int jj = 0; + for (; jj + 11 < max_jj; jj += 12) + { + vfloat32m2_t _sum0; + vfloat32m2_t _sum1; + vfloat32m2_t _sum2; + vfloat32m2_t _sum3; + vfloat32m2_t _sum4; + vfloat32m2_t _sum5; + vfloat32m2_t _sum6; + vfloat32m2_t _sum7; + vfloat32m2_t _sum8; + vfloat32m2_t _sum9; + vfloat32m2_t _suma; + vfloat32m2_t _sumb; + + vl = 8; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m2(0.f, vl); + _sum1 = vfmv_v_f_f32m2(0.f, vl); + _sum2 = vfmv_v_f_f32m2(0.f, vl); + _sum3 = vfmv_v_f_f32m2(0.f, vl); + _sum4 = vfmv_v_f_f32m2(0.f, vl); + _sum5 = vfmv_v_f_f32m2(0.f, vl); + _sum6 = vfmv_v_f_f32m2(0.f, vl); + _sum7 = vfmv_v_f_f32m2(0.f, vl); + _sum8 = vfmv_v_f_f32m2(0.f, vl); + _sum9 = vfmv_v_f_f32m2(0.f, vl); + _suma = vfmv_v_f_f32m2(0.f, vl); + _sumb = vfmv_v_f_f32m2(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + _sum8 = _sum0; + _sum9 = _sum0; + _suma = _sum0; + _sumb = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vle32_v_f32m2(pC, vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + _sum8 = _sum0; + _sum9 = _sum0; + 
_suma = _sum0; + _sumb = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = vle32_v_f32m2(pC, vl); + _sum1 = vle32_v_f32m2(pC + 4 * 2, vl); + _sum2 = vle32_v_f32m2(pC + 4 * 4, vl); + _sum3 = vle32_v_f32m2(pC + 4 * 6, vl); + _sum4 = vle32_v_f32m2(pC + 4 * 8, vl); + _sum5 = vle32_v_f32m2(pC + 4 * 10, vl); + _sum6 = vle32_v_f32m2(pC + 4 * 12, vl); + _sum7 = vle32_v_f32m2(pC + 4 * 14, vl); + _sum8 = vle32_v_f32m2(pC + 4 * 16, vl); + _sum9 = vle32_v_f32m2(pC + 4 * 18, vl); + _suma = vle32_v_f32m2(pC + 4 * 20, vl); + _sumb = vle32_v_f32m2(pC + 4 * 22, vl); + pC += 96; + } + if (broadcast_type_C == 4) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + _sum1 = vfmv_v_f_f32m2(pC[1], vl); + _sum2 = vfmv_v_f_f32m2(pC[2], vl); + _sum3 = vfmv_v_f_f32m2(pC[3], vl); + _sum4 = vfmv_v_f_f32m2(pC[4], vl); + _sum5 = vfmv_v_f_f32m2(pC[5], vl); + _sum6 = vfmv_v_f_f32m2(pC[6], vl); + _sum7 = vfmv_v_f_f32m2(pC[7], vl); + _sum8 = vfmv_v_f_f32m2(pC[8], vl); + _sum9 = vfmv_v_f_f32m2(pC[9], vl); + _suma = vfmv_v_f_f32m2(pC[10], vl); + _sumb = vfmv_v_f_f32m2(pC[11], vl); + pC += 12; + } + } + } + else + { + _sum0 = vle32_v_f32m2(outptr, vl); + _sum1 = vle32_v_f32m2(outptr + 4 * 2, vl); + _sum2 = vle32_v_f32m2(outptr + 4 * 4, vl); + _sum3 = vle32_v_f32m2(outptr + 4 * 6, vl); + _sum4 = vle32_v_f32m2(outptr + 4 * 8, vl); + _sum5 = vle32_v_f32m2(outptr + 4 * 10, vl); + _sum6 = vle32_v_f32m2(outptr + 4 * 12, vl); + _sum7 = vle32_v_f32m2(outptr + 4 * 14, vl); + _sum8 = vle32_v_f32m2(outptr + 4 * 16, vl); + _sum9 = vle32_v_f32m2(outptr + 4 * 18, vl); + _suma = vle32_v_f32m2(outptr + 4 * 20, vl); + _sumb = vle32_v_f32m2(outptr + 4 * 22, vl); + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat16m1_t _pA = vle16_v_f16m1(pA, vl); + + _sum0 = vfwmacc_vf_f32m2(_sum0, pB[0], _pA, vl); + _sum1 = vfwmacc_vf_f32m2(_sum1, pB[1], _pA, vl); + _sum2 = vfwmacc_vf_f32m2(_sum2, pB[2], _pA, vl); + _sum3 = vfwmacc_vf_f32m2(_sum3, pB[3], _pA, vl); + _sum4 = vfwmacc_vf_f32m2(_sum4, pB[4], _pA, vl); + _sum5 = vfwmacc_vf_f32m2(_sum5, pB[5], _pA, vl); + _sum6 = vfwmacc_vf_f32m2(_sum6, pB[6], _pA, vl); + _sum7 = vfwmacc_vf_f32m2(_sum7, pB[7], _pA, vl); + _sum8 = vfwmacc_vf_f32m2(_sum8, pB[8], _pA, vl); + _sum9 = vfwmacc_vf_f32m2(_sum9, pB[9], _pA, vl); + _suma = vfwmacc_vf_f32m2(_suma, pB[10], _pA, vl); + _sumb = vfwmacc_vf_f32m2(_sumb, pB[11], _pA, vl); + + pA += 8; + pB += 12; + } + + if (alpha != 1.f) + { + _sum0 = vfmul_vf_f32m2(_sum0, alpha, vl); + _sum1 = vfmul_vf_f32m2(_sum1, alpha, vl); + _sum2 = vfmul_vf_f32m2(_sum2, alpha, vl); + _sum3 = vfmul_vf_f32m2(_sum3, alpha, vl); + _sum4 = vfmul_vf_f32m2(_sum4, alpha, vl); + _sum5 = vfmul_vf_f32m2(_sum5, alpha, vl); + _sum6 = vfmul_vf_f32m2(_sum6, alpha, vl); + _sum7 = vfmul_vf_f32m2(_sum7, alpha, vl); + _sum8 = vfmul_vf_f32m2(_sum8, alpha, vl); + _sum9 = vfmul_vf_f32m2(_sum9, alpha, vl); + _suma = vfmul_vf_f32m2(_suma, alpha, vl); + _sumb = vfmul_vf_f32m2(_sumb, alpha, vl); + } + + if (k_end) + { + if (out_elempack == 4) + { + vl = 4; + + vse16_v_f16m1(outptr0, vfncvt_f_f_w_f16m1(_sum0, vl), vl); + vse16_v_f16m1(outptr0 + 4, vfncvt_f_f_w_f16m1(_sum1, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 2, vfncvt_f_f_w_f16m1(_sum2, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 3, vfncvt_f_f_w_f16m1(_sum3, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 4, vfncvt_f_f_w_f16m1(_sum4, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 5, vfncvt_f_f_w_f16m1(_sum5, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 6, vfncvt_f_f_w_f16m1(_sum6, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 7, vfncvt_f_f_w_f16m1(_sum7, vl), 
vl); + vse16_v_f16m1(outptr0 + 4 * 8, vfncvt_f_f_w_f16m1(_sum8, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 9, vfncvt_f_f_w_f16m1(_sum9, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 10, vfncvt_f_f_w_f16m1(_suma, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 11, vfncvt_f_f_w_f16m1(_sumb, vl), vl); + + _sum0 = vslidedown_vx_f32m2(_sum0, _sum0, 4, vl); + _sum1 = vslidedown_vx_f32m2(_sum1, _sum1, 4, vl); + _sum2 = vslidedown_vx_f32m2(_sum2, _sum2, 4, vl); + _sum3 = vslidedown_vx_f32m2(_sum3, _sum3, 4, vl); + _sum4 = vslidedown_vx_f32m2(_sum4, _sum4, 4, vl); + _sum5 = vslidedown_vx_f32m2(_sum5, _sum5, 4, vl); + _sum6 = vslidedown_vx_f32m2(_sum6, _sum6, 4, vl); + _sum7 = vslidedown_vx_f32m2(_sum7, _sum7, 4, vl); + _sum8 = vslidedown_vx_f32m2(_sum8, _sum8, 4, vl); + _sum9 = vslidedown_vx_f32m2(_sum9, _sum9, 4, vl); + _suma = vslidedown_vx_f32m2(_suma, _suma, 4, vl); + _sumb = vslidedown_vx_f32m2(_sumb, _sumb, 4, vl); + vse16_v_f16m1(outptr0 + out_hstep * 4, vfncvt_f_f_w_f16m1(_sum0, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep * 4 + 4, vfncvt_f_f_w_f16m1(_sum1, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep * 4 + 4 * 2, vfncvt_f_f_w_f16m1(_sum2, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep * 4 + 4 * 3, vfncvt_f_f_w_f16m1(_sum3, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep * 4 + 4 * 4, vfncvt_f_f_w_f16m1(_sum4, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep * 4 + 4 * 5, vfncvt_f_f_w_f16m1(_sum5, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep * 4 + 4 * 6, vfncvt_f_f_w_f16m1(_sum6, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep * 4 + 4 * 7, vfncvt_f_f_w_f16m1(_sum7, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep * 4 + 4 * 8, vfncvt_f_f_w_f16m1(_sum8, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep * 4 + 4 * 9, vfncvt_f_f_w_f16m1(_sum9, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep * 4 + 4 * 10, vfncvt_f_f_w_f16m1(_suma, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep * 4 + 4 * 11, vfncvt_f_f_w_f16m1(_sumb, vl), vl); + outptr0 += 48; + } + if (out_elempack == 1) + { + vl = 8; + vfloat16m1_t _sum0_f16 = vfncvt_f_f_w_f16m1(_sum0, vl); + vfloat16m1_t _sum1_f16 = vfncvt_f_f_w_f16m1(_sum1, vl); + vfloat16m1_t _sum2_f16 = vfncvt_f_f_w_f16m1(_sum2, vl); + vfloat16m1_t _sum3_f16 = vfncvt_f_f_w_f16m1(_sum3, vl); + vfloat16m1_t _sum4_f16 = vfncvt_f_f_w_f16m1(_sum4, vl); + vfloat16m1_t _sum5_f16 = vfncvt_f_f_w_f16m1(_sum5, vl); + vfloat16m1_t _sum6_f16 = vfncvt_f_f_w_f16m1(_sum6, vl); + vfloat16m1_t _sum7_f16 = vfncvt_f_f_w_f16m1(_sum7, vl); + vfloat16m1_t _sum8_f16 = vfncvt_f_f_w_f16m1(_sum8, vl); + vfloat16m1_t _sum9_f16 = vfncvt_f_f_w_f16m1(_sum9, vl); + vfloat16m1_t _suma_f16 = vfncvt_f_f_w_f16m1(_suma, vl); + vfloat16m1_t _sumb_f16 = vfncvt_f_f_w_f16m1(_sumb, vl); + + vsse16_v_f16m1(outptr0, out_hstep * sizeof(__fp16), _sum0_f16, vl); + vsse16_v_f16m1(outptr0 + 1, out_hstep * sizeof(__fp16), _sum1_f16, vl); + vsse16_v_f16m1(outptr0 + 2, out_hstep * sizeof(__fp16), _sum2_f16, vl); + vsse16_v_f16m1(outptr0 + 3, out_hstep * sizeof(__fp16), _sum3_f16, vl); + vsse16_v_f16m1(outptr0 + 4, out_hstep * sizeof(__fp16), _sum4_f16, vl); + vsse16_v_f16m1(outptr0 + 5, out_hstep * sizeof(__fp16), _sum5_f16, vl); + vsse16_v_f16m1(outptr0 + 6, out_hstep * sizeof(__fp16), _sum6_f16, vl); + vsse16_v_f16m1(outptr0 + 7, out_hstep * sizeof(__fp16), _sum7_f16, vl); + vsse16_v_f16m1(outptr0 + 8, out_hstep * sizeof(__fp16), _sum8_f16, vl); + vsse16_v_f16m1(outptr0 + 9, out_hstep * sizeof(__fp16), _sum9_f16, vl); + vsse16_v_f16m1(outptr0 + 10, out_hstep * sizeof(__fp16), _suma_f16, vl); + vsse16_v_f16m1(outptr0 + 11, out_hstep * sizeof(__fp16), 
_sumb_f16, vl); + outptr0 += 12; + } + } + else + { + vl = 8; + vse32_v_f32m2(outptr, _sum0, vl); + vse32_v_f32m2(outptr + 8 * 1, _sum1, vl); + vse32_v_f32m2(outptr + 8 * 2, _sum2, vl); + vse32_v_f32m2(outptr + 8 * 3, _sum3, vl); + vse32_v_f32m2(outptr + 8 * 4, _sum4, vl); + vse32_v_f32m2(outptr + 8 * 5, _sum5, vl); + vse32_v_f32m2(outptr + 8 * 6, _sum6, vl); + vse32_v_f32m2(outptr + 8 * 7, _sum7, vl); + vse32_v_f32m2(outptr + 8 * 8, _sum8, vl); + vse32_v_f32m2(outptr + 8 * 9, _sum9, vl); + vse32_v_f32m2(outptr + 8 * 10, _suma, vl); + vse32_v_f32m2(outptr + 8 * 11, _sumb, vl); + } + + outptr += 96; + } + for (; jj + 7 < max_jj; jj += 8) + { + vfloat32m2_t _sum0; + vfloat32m2_t _sum1; + vfloat32m2_t _sum2; + vfloat32m2_t _sum3; + vfloat32m2_t _sum4; + vfloat32m2_t _sum5; + vfloat32m2_t _sum6; + vfloat32m2_t _sum7; + vl = 8; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m2(0.f, vl); + _sum1 = vfmv_v_f_f32m2(0.f, vl); + _sum2 = vfmv_v_f_f32m2(0.f, vl); + _sum3 = vfmv_v_f_f32m2(0.f, vl); + _sum4 = vfmv_v_f_f32m2(0.f, vl); + _sum5 = vfmv_v_f_f32m2(0.f, vl); + _sum6 = vfmv_v_f_f32m2(0.f, vl); + _sum7 = vfmv_v_f_f32m2(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vle32_v_f32m2(pC, vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = vle32_v_f32m2(pC, vl); + _sum1 = vle32_v_f32m2(pC + 4 * 2, vl); + _sum2 = vle32_v_f32m2(pC + 4 * 4, vl); + _sum3 = vle32_v_f32m2(pC + 4 * 6, vl); + _sum4 = vle32_v_f32m2(pC + 4 * 8, vl); + _sum5 = vle32_v_f32m2(pC + 4 * 10, vl); + _sum6 = vle32_v_f32m2(pC + 4 * 12, vl); + _sum7 = vle32_v_f32m2(pC + 4 * 14, vl); + pC += 64; + } + if (broadcast_type_C == 4) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + _sum1 = vfmv_v_f_f32m2(pC[1], vl); + _sum2 = vfmv_v_f_f32m2(pC[2], vl); + _sum3 = vfmv_v_f_f32m2(pC[3], vl); + _sum4 = vfmv_v_f_f32m2(pC[4], vl); + _sum5 = vfmv_v_f_f32m2(pC[5], vl); + _sum6 = vfmv_v_f_f32m2(pC[6], vl); + _sum7 = vfmv_v_f_f32m2(pC[7], vl); + pC += 8; + } + } + } + else + { + _sum0 = vle32_v_f32m2(outptr, vl); + _sum1 = vle32_v_f32m2(outptr + 4 * 2, vl); + _sum2 = vle32_v_f32m2(outptr + 4 * 4, vl); + _sum3 = vle32_v_f32m2(outptr + 4 * 6, vl); + _sum4 = vle32_v_f32m2(outptr + 4 * 8, vl); + _sum5 = vle32_v_f32m2(outptr + 4 * 10, vl); + _sum6 = vle32_v_f32m2(outptr + 4 * 12, vl); + _sum7 = vle32_v_f32m2(outptr + 4 * 14, vl); + } + const __fp16* pA = pAT; + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat16m1_t _pA = vle16_v_f16m1(pA, vl); + + _sum0 = vfwmacc_vf_f32m2(_sum0, pB[0], _pA, vl); + _sum1 = vfwmacc_vf_f32m2(_sum1, pB[1], _pA, vl); + _sum2 = vfwmacc_vf_f32m2(_sum2, pB[2], _pA, vl); + _sum3 = vfwmacc_vf_f32m2(_sum3, pB[3], _pA, vl); + _sum4 = vfwmacc_vf_f32m2(_sum4, pB[4], _pA, vl); + _sum5 = vfwmacc_vf_f32m2(_sum5, pB[5], _pA, vl); + _sum6 = vfwmacc_vf_f32m2(_sum6, pB[6], _pA, vl); + _sum7 = vfwmacc_vf_f32m2(_sum7, pB[7], _pA, vl); + + pA += 8; + pB += 8; + } + + if (alpha != 1.f) + { + _sum0 = vfmul_vf_f32m2(_sum0, alpha, vl); + _sum1 = vfmul_vf_f32m2(_sum1, alpha, vl); + _sum2 = vfmul_vf_f32m2(_sum2, alpha, vl); + _sum3 = vfmul_vf_f32m2(_sum3, alpha, vl); + _sum4 = vfmul_vf_f32m2(_sum4, alpha, vl); + _sum5 = vfmul_vf_f32m2(_sum5, alpha, vl); + _sum6 = vfmul_vf_f32m2(_sum6, alpha, vl); + _sum7 = 
vfmul_vf_f32m2(_sum7, alpha, vl); + } + + if (k_end) + { + if (out_elempack == 4) + { + vl = 4; + + vse16_v_f16m1(outptr0, vfncvt_f_f_w_f16m1(_sum0, vl), vl); + vse16_v_f16m1(outptr0 + 4, vfncvt_f_f_w_f16m1(_sum1, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 2, vfncvt_f_f_w_f16m1(_sum2, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 3, vfncvt_f_f_w_f16m1(_sum3, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 4, vfncvt_f_f_w_f16m1(_sum4, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 5, vfncvt_f_f_w_f16m1(_sum5, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 6, vfncvt_f_f_w_f16m1(_sum6, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 7, vfncvt_f_f_w_f16m1(_sum7, vl), vl); + + vl = 8; + + _sum0 = vslidedown_vx_f32m2(_sum0, _sum0, 4, vl); + _sum1 = vslidedown_vx_f32m2(_sum1, _sum1, 4, vl); + _sum2 = vslidedown_vx_f32m2(_sum2, _sum2, 4, vl); + _sum3 = vslidedown_vx_f32m2(_sum3, _sum3, 4, vl); + _sum4 = vslidedown_vx_f32m2(_sum4, _sum4, 4, vl); + _sum5 = vslidedown_vx_f32m2(_sum5, _sum5, 4, vl); + _sum6 = vslidedown_vx_f32m2(_sum6, _sum6, 4, vl); + _sum7 = vslidedown_vx_f32m2(_sum7, _sum7, 4, vl); + + vl = 4; + vse16_v_f16m1(outptr0 + out_hstep * 4, vfncvt_f_f_w_f16m1(_sum0, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep * 4 + 4, vfncvt_f_f_w_f16m1(_sum1, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep * 4 + 4 * 2, vfncvt_f_f_w_f16m1(_sum2, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep * 4 + 4 * 3, vfncvt_f_f_w_f16m1(_sum3, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep * 4 + 4 * 4, vfncvt_f_f_w_f16m1(_sum4, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep * 4 + 4 * 5, vfncvt_f_f_w_f16m1(_sum5, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep * 4 + 4 * 6, vfncvt_f_f_w_f16m1(_sum6, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep * 4 + 4 * 7, vfncvt_f_f_w_f16m1(_sum7, vl), vl); + + outptr0 += 32; + } + if (out_elempack == 1) + { + vl = 8; + vfloat16m1_t _sum0_f16 = vfncvt_f_f_w_f16m1(_sum0, vl); + vfloat16m1_t _sum1_f16 = vfncvt_f_f_w_f16m1(_sum1, vl); + vfloat16m1_t _sum2_f16 = vfncvt_f_f_w_f16m1(_sum2, vl); + vfloat16m1_t _sum3_f16 = vfncvt_f_f_w_f16m1(_sum3, vl); + vfloat16m1_t _sum4_f16 = vfncvt_f_f_w_f16m1(_sum4, vl); + vfloat16m1_t _sum5_f16 = vfncvt_f_f_w_f16m1(_sum5, vl); + vfloat16m1_t _sum6_f16 = vfncvt_f_f_w_f16m1(_sum6, vl); + vfloat16m1_t _sum7_f16 = vfncvt_f_f_w_f16m1(_sum7, vl); + + vsse16_v_f16m1(outptr0, out_hstep * sizeof(__fp16), _sum0_f16, vl); + vsse16_v_f16m1(outptr0 + 1, out_hstep * sizeof(__fp16), _sum1_f16, vl); + vsse16_v_f16m1(outptr0 + 2, out_hstep * sizeof(__fp16), _sum2_f16, vl); + vsse16_v_f16m1(outptr0 + 3, out_hstep * sizeof(__fp16), _sum3_f16, vl); + vsse16_v_f16m1(outptr0 + 4, out_hstep * sizeof(__fp16), _sum4_f16, vl); + vsse16_v_f16m1(outptr0 + 5, out_hstep * sizeof(__fp16), _sum5_f16, vl); + vsse16_v_f16m1(outptr0 + 6, out_hstep * sizeof(__fp16), _sum6_f16, vl); + vsse16_v_f16m1(outptr0 + 7, out_hstep * sizeof(__fp16), _sum7_f16, vl); + + outptr0 += 8; + } + } + else + { + vl = 8; + vse32_v_f32m2(outptr, _sum0, vl); + vse32_v_f32m2(outptr + 4 * 2, _sum1, vl); + vse32_v_f32m2(outptr + 4 * 4, _sum2, vl); + vse32_v_f32m2(outptr + 4 * 6, _sum3, vl); + vse32_v_f32m2(outptr + 4 * 8, _sum4, vl); + vse32_v_f32m2(outptr + 4 * 10, _sum5, vl); + vse32_v_f32m2(outptr + 4 * 12, _sum6, vl); + vse32_v_f32m2(outptr + 4 * 14, _sum7, vl); + } + + outptr += 64; + } + for (; jj + 3 < max_jj; jj += 4) + { + vfloat32m2_t _sum0; + vfloat32m2_t _sum1; + vfloat32m2_t _sum2; + vfloat32m2_t _sum3; + vl = 8; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m2(0.f, vl); + _sum1 = vfmv_v_f_f32m2(0.f, vl); + _sum2 = vfmv_v_f_f32m2(0.f, vl); + _sum3 
= vfmv_v_f_f32m2(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vle32_v_f32m2(pC, vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = vle32_v_f32m2(pC, vl); + _sum1 = vle32_v_f32m2(pC + 4 * 2, vl); + _sum2 = vle32_v_f32m2(pC + 4 * 4, vl); + _sum3 = vle32_v_f32m2(pC + 4 * 6, vl); + pC += 32; + } + if (broadcast_type_C == 4) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + _sum1 = vfmv_v_f_f32m2(pC[1], vl); + _sum2 = vfmv_v_f_f32m2(pC[2], vl); + _sum3 = vfmv_v_f_f32m2(pC[3], vl); + pC += 4; + } + } + } + else + { + _sum0 = vle32_v_f32m2(outptr, vl); + _sum1 = vle32_v_f32m2(outptr + 4 * 2, vl); + _sum2 = vle32_v_f32m2(outptr + 4 * 4, vl); + _sum3 = vle32_v_f32m2(outptr + 4 * 6, vl); + } + + const __fp16* pA = pAT; + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat16m1_t _pA = vle16_v_f16m1(pA, vl); + _sum0 = vfwmacc_vf_f32m2(_sum0, pB[0], _pA, vl); + _sum1 = vfwmacc_vf_f32m2(_sum1, pB[1], _pA, vl); + _sum2 = vfwmacc_vf_f32m2(_sum2, pB[2], _pA, vl); + _sum3 = vfwmacc_vf_f32m2(_sum3, pB[3], _pA, vl); + + pA += 8; + pB += 4; + } + + if (alpha != 1.f) + { + _sum0 = vfmul_vf_f32m2(_sum0, alpha, vl); + _sum1 = vfmul_vf_f32m2(_sum1, alpha, vl); + _sum2 = vfmul_vf_f32m2(_sum2, alpha, vl); + _sum3 = vfmul_vf_f32m2(_sum3, alpha, vl); + } + + if (k_end) + { + if (out_elempack == 4) + { + vl = 4; + + vse16_v_f16m1(outptr0, vfncvt_f_f_w_f16m1(_sum0, vl), vl); + vse16_v_f16m1(outptr0 + 4, vfncvt_f_f_w_f16m1(_sum1, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 2, vfncvt_f_f_w_f16m1(_sum2, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 3, vfncvt_f_f_w_f16m1(_sum3, vl), vl); + + vl = 8; + + _sum0 = vslidedown_vx_f32m2(_sum0, _sum0, 4, vl); + _sum1 = vslidedown_vx_f32m2(_sum1, _sum1, 4, vl); + _sum2 = vslidedown_vx_f32m2(_sum2, _sum2, 4, vl); + _sum3 = vslidedown_vx_f32m2(_sum3, _sum3, 4, vl); + + vl = 4; + vse16_v_f16m1(outptr0 + out_hstep * 4, vfncvt_f_f_w_f16m1(_sum0, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep * 4 + 4, vfncvt_f_f_w_f16m1(_sum1, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep * 4 + 4 * 2, vfncvt_f_f_w_f16m1(_sum2, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep * 4 + 4 * 3, vfncvt_f_f_w_f16m1(_sum3, vl), vl); + outptr0 += 16; + } + if (out_elempack == 1) + { + vl = 8; + vfloat16m1_t _sum0_f16 = vfncvt_f_f_w_f16m1(_sum0, vl); + vfloat16m1_t _sum1_f16 = vfncvt_f_f_w_f16m1(_sum1, vl); + vfloat16m1_t _sum2_f16 = vfncvt_f_f_w_f16m1(_sum2, vl); + vfloat16m1_t _sum3_f16 = vfncvt_f_f_w_f16m1(_sum3, vl); + + vsse16_v_f16m1(outptr0, out_hstep * sizeof(__fp16), _sum0_f16, vl); + vsse16_v_f16m1(outptr0 + 1, out_hstep * sizeof(__fp16), _sum1_f16, vl); + vsse16_v_f16m1(outptr0 + 2, out_hstep * sizeof(__fp16), _sum2_f16, vl); + vsse16_v_f16m1(outptr0 + 3, out_hstep * sizeof(__fp16), _sum3_f16, vl); + outptr0 += 4; + } + } + else + { + vl = 8; + vse32_v_f32m2(outptr, _sum0, vl); + vse32_v_f32m2(outptr + 4 * 2, _sum1, vl); + vse32_v_f32m2(outptr + 4 * 4, _sum2, vl); + vse32_v_f32m2(outptr + 4 * 6, _sum3, vl); + } + + outptr += 32; + } + for (; jj + 1 < max_jj; jj += 2) + { + vfloat32m2_t _sum0; + vfloat32m2_t _sum1; + + vl = 8; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m2(0.f, vl); + _sum1 = vfmv_v_f_f32m2(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + _sum1 = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = 
vle32_v_f32m2(pC, vl); + _sum1 = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = vle32_v_f32m2(pC, vl); + _sum1 = vle32_v_f32m2(pC + 4 * 2, vl); + pC += 16; + } + if (broadcast_type_C == 4) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + _sum1 = vfmv_v_f_f32m2(pC[1], vl); + pC += 2; + } + } + } + else + { + _sum0 = vle32_v_f32m2(outptr, vl); + _sum1 = vle32_v_f32m2(outptr + 4 * 2, vl); + } + + const __fp16* pA = pAT; + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat16m1_t _pA = vle16_v_f16m1(pA, vl); + _sum0 = vfwmacc_vf_f32m2(_sum0, pB[0], _pA, vl); + _sum1 = vfwmacc_vf_f32m2(_sum1, pB[1], _pA, vl); + pA += 8; + pB += 2; + } + + if (alpha != 1.f) + { + _sum0 = vfmul_vf_f32m2(_sum0, alpha, vl); + _sum1 = vfmul_vf_f32m2(_sum1, alpha, vl); + } + + if (k_end) + { + if (out_elempack == 4) + { + vl = 4; + vse16_v_f16m1(outptr0, vfncvt_f_f_w_f16m1(_sum0, vl), vl); + vse16_v_f16m1(outptr0 + 4, vfncvt_f_f_w_f16m1(_sum1, vl), vl); + + vl = 8; + + _sum0 = vslidedown_vx_f32m2(_sum0, _sum0, 4, vl); + _sum1 = vslidedown_vx_f32m2(_sum1, _sum1, 4, vl); + + vl = 4; + vse16_v_f16m1(outptr0 + out_hstep * 4, vfncvt_f_f_w_f16m1(_sum0, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep * 4 + 4, vfncvt_f_f_w_f16m1(_sum1, vl), vl); + + outptr0 += 8; + } + if (out_elempack == 1) + { + vl = 8; + vfloat16m1_t _sum0_f16 = vfncvt_f_f_w_f16m1(_sum0, vl); + vfloat16m1_t _sum1_f16 = vfncvt_f_f_w_f16m1(_sum1, vl); + + vsse16_v_f16m1(outptr0, out_hstep * sizeof(__fp16), _sum0_f16, vl); + vsse16_v_f16m1(outptr0 + 1, out_hstep * sizeof(__fp16), _sum1_f16, vl); + + outptr0 += 2; + } + } + else + { + vl = 8; + vse32_v_f32m2(outptr, _sum0, vl); + vse32_v_f32m2(outptr + 4 * 2, _sum1, vl); + } + + outptr += 16; + } + for (; jj < max_jj; jj += 1) + { + vfloat32m2_t _sum0; + vl = 8; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m2(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vle32_v_f32m2(pC, vl); + } + if (broadcast_type_C == 3) + { + _sum0 = vle32_v_f32m2(pC, vl); + pC += 8; + } + if (broadcast_type_C == 4) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + pC += 1; + } + } + } + else + { + _sum0 = vle32_v_f32m2(outptr, vl); + } + + const __fp16* pA = pAT; + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat16m1_t _pA = vle16_v_f16m1(pA, vl); + + _sum0 = vfwmacc_vf_f32m2(_sum0, pB[0], _pA, vl); + pA += 8; + pB += 1; + } + + if (alpha != 1.f) + { + _sum0 = vfmul_vf_f32m2(_sum0, alpha, vl); + } + + if (k_end) + { + if (out_elempack == 4) + { + vl = 4; + vse16_v_f16m1(outptr0, vfncvt_f_f_w_f16m1(_sum0, vl), vl); + vl = 8; + _sum0 = vslidedown_vx_f32m2(_sum0, _sum0, 4, vl); + vl = 4; + vse16_v_f16m1(outptr0 + out_hstep * 4, vfncvt_f_f_w_f16m1(_sum0, vl), vl); + outptr0 += 4; + } + if (out_elempack == 1) + { + vl = 8; + vfloat16m1_t _sum0_f16 = vfncvt_f_f_w_f16m1(_sum0, vl); + vsse16_v_f16m1(outptr0, out_hstep * sizeof(__fp16), _sum0_f16, vl); + + outptr0++; + } + } + else + { + vse32_v_f32m2(outptr, _sum0, vl); + } + + outptr += 8; + } + + pAT += max_kk * 8; + } +#endif // __riscv_vector + for (; ii + 3 < max_ii; ii += 4) + { + __fp16* outptr0 = (__fp16*)top_blob + (i + ii) * out_hstep + j * out_elempack; + + const __fp16* pB = pBT; + + if (pC) + { + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)CT_tile + i + ii; + } + if (broadcast_type_C == 4) + { + pC = (const float*)CT_tile + j; + } + } + + int jj = 0; +#if __riscv_vector + for (; jj + 11 < max_jj; jj += 12) + { + vfloat32m2_t 
_sum0; + vfloat32m2_t _sum1; + vfloat32m2_t _sum2; + vfloat32m2_t _sum3; + vfloat32m2_t _sum4; + vfloat32m2_t _sum5; + vfloat32m2_t _sum6; + vfloat32m2_t _sum7; + vfloat32m2_t _sum8; + vfloat32m2_t _sum9; + vfloat32m2_t _suma; + vfloat32m2_t _sumb; + vl = 4; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m2(0.f, vl); + _sum1 = vfmv_v_f_f32m2(0.f, vl); + _sum2 = vfmv_v_f_f32m2(0.f, vl); + _sum3 = vfmv_v_f_f32m2(0.f, vl); + _sum4 = vfmv_v_f_f32m2(0.f, vl); + _sum5 = vfmv_v_f_f32m2(0.f, vl); + _sum6 = vfmv_v_f_f32m2(0.f, vl); + _sum7 = vfmv_v_f_f32m2(0.f, vl); + _sum8 = vfmv_v_f_f32m2(0.f, vl); + _sum9 = vfmv_v_f_f32m2(0.f, vl); + _suma = vfmv_v_f_f32m2(0.f, vl); + _sumb = vfmv_v_f_f32m2(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + _sum8 = _sum0; + _sum9 = _sum0; + _suma = _sum0; + _sumb = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vle32_v_f32m2(pC, vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + _sum8 = _sum0; + _sum9 = _sum0; + _suma = _sum0; + _sumb = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = vle32_v_f32m2(pC, vl); + _sum1 = vle32_v_f32m2(pC + 4, vl); + _sum2 = vle32_v_f32m2(pC + 8, vl); + _sum3 = vle32_v_f32m2(pC + 12, vl); + _sum4 = vle32_v_f32m2(pC + 16, vl); + _sum5 = vle32_v_f32m2(pC + 20, vl); + _sum6 = vle32_v_f32m2(pC + 24, vl); + _sum7 = vle32_v_f32m2(pC + 28, vl); + _sum8 = vle32_v_f32m2(pC + 32, vl); + _sum9 = vle32_v_f32m2(pC + 36, vl); + _suma = vle32_v_f32m2(pC + 40, vl); + _sumb = vle32_v_f32m2(pC + 44, vl); + pC += 48; + } + if (broadcast_type_C == 4) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + _sum1 = vfmv_v_f_f32m2(pC[1], vl); + _sum2 = vfmv_v_f_f32m2(pC[2], vl); + _sum3 = vfmv_v_f_f32m2(pC[3], vl); + _sum4 = vfmv_v_f_f32m2(pC[4], vl); + _sum5 = vfmv_v_f_f32m2(pC[5], vl); + _sum6 = vfmv_v_f_f32m2(pC[6], vl); + _sum7 = vfmv_v_f_f32m2(pC[7], vl); + _sum8 = vfmv_v_f_f32m2(pC[8], vl); + _sum9 = vfmv_v_f_f32m2(pC[9], vl); + _suma = vfmv_v_f_f32m2(pC[10], vl); + _sumb = vfmv_v_f_f32m2(pC[11], vl); + pC += 12; + } + } + } + else + { + _sum0 = vle32_v_f32m2(outptr, vl); + _sum1 = vle32_v_f32m2(outptr + 4 * 1, vl); + _sum2 = vle32_v_f32m2(outptr + 4 * 2, vl); + _sum3 = vle32_v_f32m2(outptr + 4 * 3, vl); + _sum4 = vle32_v_f32m2(outptr + 4 * 4, vl); + _sum5 = vle32_v_f32m2(outptr + 4 * 5, vl); + _sum6 = vle32_v_f32m2(outptr + 4 * 6, vl); + _sum7 = vle32_v_f32m2(outptr + 4 * 7, vl); + _sum8 = vle32_v_f32m2(outptr + 4 * 8, vl); + _sum9 = vle32_v_f32m2(outptr + 4 * 9, vl); + _suma = vle32_v_f32m2(outptr + 4 * 10, vl); + _sumb = vle32_v_f32m2(outptr + 4 * 11, vl); + } + + const __fp16* pA = pAT; + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat16m1_t _pA = vle16_v_f16m1(pA, vl); + + _sum0 = vfwmacc_vf_f32m2(_sum0, pB[0], _pA, vl); + _sum1 = vfwmacc_vf_f32m2(_sum1, pB[1], _pA, vl); + _sum2 = vfwmacc_vf_f32m2(_sum2, pB[2], _pA, vl); + _sum3 = vfwmacc_vf_f32m2(_sum3, pB[3], _pA, vl); + _sum4 = vfwmacc_vf_f32m2(_sum4, pB[4], _pA, vl); + _sum5 = vfwmacc_vf_f32m2(_sum5, pB[5], _pA, vl); + _sum6 = vfwmacc_vf_f32m2(_sum6, pB[6], _pA, vl); + _sum7 = vfwmacc_vf_f32m2(_sum7, pB[7], _pA, vl); + _sum8 = vfwmacc_vf_f32m2(_sum8, pB[8], _pA, vl); + _sum9 = vfwmacc_vf_f32m2(_sum9, pB[9], _pA, vl); + _suma = vfwmacc_vf_f32m2(_suma, pB[10], _pA, vl); + _sumb = vfwmacc_vf_f32m2(_sumb, pB[11], _pA, vl); + + pA 
+= 4; + pB += 12; + } + + if (alpha != 1.f) + { + _sum0 = vfmul_vf_f32m2(_sum0, alpha, vl); + _sum1 = vfmul_vf_f32m2(_sum1, alpha, vl); + _sum2 = vfmul_vf_f32m2(_sum2, alpha, vl); + _sum3 = vfmul_vf_f32m2(_sum3, alpha, vl); + _sum4 = vfmul_vf_f32m2(_sum4, alpha, vl); + _sum5 = vfmul_vf_f32m2(_sum5, alpha, vl); + _sum6 = vfmul_vf_f32m2(_sum6, alpha, vl); + _sum7 = vfmul_vf_f32m2(_sum7, alpha, vl); + _sum8 = vfmul_vf_f32m2(_sum8, alpha, vl); + _sum9 = vfmul_vf_f32m2(_sum9, alpha, vl); + _suma = vfmul_vf_f32m2(_suma, alpha, vl); + _sumb = vfmul_vf_f32m2(_sumb, alpha, vl); + } + + if (k_end) + { + if (out_elempack == 4) + { + vse16_v_f16m1(outptr0, vfncvt_f_f_w_f16m1(_sum0, vl), vl); + vse16_v_f16m1(outptr0 + 4, vfncvt_f_f_w_f16m1(_sum1, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 2, vfncvt_f_f_w_f16m1(_sum2, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 3, vfncvt_f_f_w_f16m1(_sum3, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 4, vfncvt_f_f_w_f16m1(_sum4, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 5, vfncvt_f_f_w_f16m1(_sum5, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 6, vfncvt_f_f_w_f16m1(_sum6, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 7, vfncvt_f_f_w_f16m1(_sum7, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 8, vfncvt_f_f_w_f16m1(_sum8, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 9, vfncvt_f_f_w_f16m1(_sum9, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 10, vfncvt_f_f_w_f16m1(_suma, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 11, vfncvt_f_f_w_f16m1(_sumb, vl), vl); + + outptr0 += 48; + } + if (out_elempack == 1) + { + vfloat16m1_t _sum0_f16 = vfncvt_f_f_w_f16m1(_sum0, vl); + vfloat16m1_t _sum1_f16 = vfncvt_f_f_w_f16m1(_sum1, vl); + vfloat16m1_t _sum2_f16 = vfncvt_f_f_w_f16m1(_sum2, vl); + vfloat16m1_t _sum3_f16 = vfncvt_f_f_w_f16m1(_sum3, vl); + vfloat16m1_t _sum4_f16 = vfncvt_f_f_w_f16m1(_sum4, vl); + vfloat16m1_t _sum5_f16 = vfncvt_f_f_w_f16m1(_sum5, vl); + vfloat16m1_t _sum6_f16 = vfncvt_f_f_w_f16m1(_sum6, vl); + vfloat16m1_t _sum7_f16 = vfncvt_f_f_w_f16m1(_sum7, vl); + vfloat16m1_t _sum8_f16 = vfncvt_f_f_w_f16m1(_sum8, vl); + vfloat16m1_t _sum9_f16 = vfncvt_f_f_w_f16m1(_sum9, vl); + vfloat16m1_t _suma_f16 = vfncvt_f_f_w_f16m1(_suma, vl); + vfloat16m1_t _sumb_f16 = vfncvt_f_f_w_f16m1(_sumb, vl); + + vsse16_v_f16m1(outptr0, out_hstep * sizeof(__fp16), _sum0_f16, vl); + vsse16_v_f16m1(outptr0 + 1, out_hstep * sizeof(__fp16), _sum1_f16, vl); + vsse16_v_f16m1(outptr0 + 2, out_hstep * sizeof(__fp16), _sum2_f16, vl); + vsse16_v_f16m1(outptr0 + 3, out_hstep * sizeof(__fp16), _sum3_f16, vl); + vsse16_v_f16m1(outptr0 + 4, out_hstep * sizeof(__fp16), _sum4_f16, vl); + vsse16_v_f16m1(outptr0 + 5, out_hstep * sizeof(__fp16), _sum5_f16, vl); + vsse16_v_f16m1(outptr0 + 6, out_hstep * sizeof(__fp16), _sum6_f16, vl); + vsse16_v_f16m1(outptr0 + 7, out_hstep * sizeof(__fp16), _sum7_f16, vl); + vsse16_v_f16m1(outptr0 + 8, out_hstep * sizeof(__fp16), _sum8_f16, vl); + vsse16_v_f16m1(outptr0 + 9, out_hstep * sizeof(__fp16), _sum9_f16, vl); + vsse16_v_f16m1(outptr0 + 10, out_hstep * sizeof(__fp16), _suma_f16, vl); + vsse16_v_f16m1(outptr0 + 11, out_hstep * sizeof(__fp16), _sumb_f16, vl); + + outptr0 += 12; + } + } + else + { + vse32_v_f32m2(outptr, _sum0, vl); + vse32_v_f32m2(outptr + 4, _sum1, vl); + vse32_v_f32m2(outptr + 4 * 2, _sum2, vl); + vse32_v_f32m2(outptr + 4 * 3, _sum3, vl); + vse32_v_f32m2(outptr + 4 * 4, _sum4, vl); + vse32_v_f32m2(outptr + 4 * 5, _sum5, vl); + vse32_v_f32m2(outptr + 4 * 6, _sum6, vl); + vse32_v_f32m2(outptr + 4 * 7, _sum7, vl); + vse32_v_f32m2(outptr + 4 * 8, _sum8, vl); + vse32_v_f32m2(outptr + 4 * 9, _sum9, vl); + 
vse32_v_f32m2(outptr + 4 * 10, _suma, vl); + vse32_v_f32m2(outptr + 4 * 11, _sumb, vl); + } + + outptr += 48; + } +#endif // __riscv_vector + for (; jj + 7 < max_jj; jj += 8) + { + vfloat32m2_t _sum0; + vfloat32m2_t _sum1; + vfloat32m2_t _sum2; + vfloat32m2_t _sum3; + vfloat32m2_t _sum4; + vfloat32m2_t _sum5; + vfloat32m2_t _sum6; + vfloat32m2_t _sum7; + vl = 4; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m2(0.f, vl); + _sum1 = vfmv_v_f_f32m2(0.f, vl); + _sum2 = vfmv_v_f_f32m2(0.f, vl); + _sum3 = vfmv_v_f_f32m2(0.f, vl); + _sum4 = vfmv_v_f_f32m2(0.f, vl); + _sum5 = vfmv_v_f_f32m2(0.f, vl); + _sum6 = vfmv_v_f_f32m2(0.f, vl); + _sum7 = vfmv_v_f_f32m2(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vle32_v_f32m2(pC, vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = vle32_v_f32m2(pC, vl); + _sum1 = vle32_v_f32m2(pC + 4, vl); + _sum2 = vle32_v_f32m2(pC + 8, vl); + _sum3 = vle32_v_f32m2(pC + 12, vl); + _sum4 = vle32_v_f32m2(pC + 16, vl); + _sum5 = vle32_v_f32m2(pC + 20, vl); + _sum6 = vle32_v_f32m2(pC + 24, vl); + _sum7 = vle32_v_f32m2(pC + 28, vl); + pC += 32; + } + if (broadcast_type_C == 4) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + _sum1 = vfmv_v_f_f32m2(pC[1], vl); + _sum2 = vfmv_v_f_f32m2(pC[2], vl); + _sum3 = vfmv_v_f_f32m2(pC[3], vl); + _sum4 = vfmv_v_f_f32m2(pC[4], vl); + _sum5 = vfmv_v_f_f32m2(pC[5], vl); + _sum6 = vfmv_v_f_f32m2(pC[6], vl); + _sum7 = vfmv_v_f_f32m2(pC[7], vl); + pC += 8; + } + } + } + else + { + _sum0 = vle32_v_f32m2(outptr, vl); + _sum1 = vle32_v_f32m2(outptr + 4 * 1, vl); + _sum2 = vle32_v_f32m2(outptr + 4 * 2, vl); + _sum3 = vle32_v_f32m2(outptr + 4 * 3, vl); + _sum4 = vle32_v_f32m2(outptr + 4 * 4, vl); + _sum5 = vle32_v_f32m2(outptr + 4 * 5, vl); + _sum6 = vle32_v_f32m2(outptr + 4 * 6, vl); + _sum7 = vle32_v_f32m2(outptr + 4 * 7, vl); + } + + const __fp16* pA = pAT; + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat16m1_t _pA = vle16_v_f16m1(pA, vl); + + _sum0 = vfwmacc_vf_f32m2(_sum0, pB[0], _pA, vl); + _sum1 = vfwmacc_vf_f32m2(_sum1, pB[1], _pA, vl); + _sum2 = vfwmacc_vf_f32m2(_sum2, pB[2], _pA, vl); + _sum3 = vfwmacc_vf_f32m2(_sum3, pB[3], _pA, vl); + _sum4 = vfwmacc_vf_f32m2(_sum4, pB[4], _pA, vl); + _sum5 = vfwmacc_vf_f32m2(_sum5, pB[5], _pA, vl); + _sum6 = vfwmacc_vf_f32m2(_sum6, pB[6], _pA, vl); + _sum7 = vfwmacc_vf_f32m2(_sum7, pB[7], _pA, vl); + + pA += 4; + pB += 8; + } + + if (alpha != 1.f) + { + _sum0 = vfmul_vf_f32m2(_sum0, alpha, vl); + _sum1 = vfmul_vf_f32m2(_sum1, alpha, vl); + _sum2 = vfmul_vf_f32m2(_sum2, alpha, vl); + _sum3 = vfmul_vf_f32m2(_sum3, alpha, vl); + _sum4 = vfmul_vf_f32m2(_sum4, alpha, vl); + _sum5 = vfmul_vf_f32m2(_sum5, alpha, vl); + _sum6 = vfmul_vf_f32m2(_sum6, alpha, vl); + _sum7 = vfmul_vf_f32m2(_sum7, alpha, vl); + } + + if (k_end) + { + if (out_elempack == 4) + { + vse16_v_f16m1(outptr0, vfncvt_f_f_w_f16m1(_sum0, vl), vl); + vse16_v_f16m1(outptr0 + 4, vfncvt_f_f_w_f16m1(_sum1, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 2, vfncvt_f_f_w_f16m1(_sum2, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 3, vfncvt_f_f_w_f16m1(_sum3, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 4, vfncvt_f_f_w_f16m1(_sum4, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 5, vfncvt_f_f_w_f16m1(_sum5, vl), vl); + 
vse16_v_f16m1(outptr0 + 4 * 6, vfncvt_f_f_w_f16m1(_sum6, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 7, vfncvt_f_f_w_f16m1(_sum7, vl), vl); + + outptr0 += 32; + } + if (out_elempack == 1) + { + vfloat16m1_t _sum0_f16 = vfncvt_f_f_w_f16m1(_sum0, vl); + vfloat16m1_t _sum1_f16 = vfncvt_f_f_w_f16m1(_sum1, vl); + vfloat16m1_t _sum2_f16 = vfncvt_f_f_w_f16m1(_sum2, vl); + vfloat16m1_t _sum3_f16 = vfncvt_f_f_w_f16m1(_sum3, vl); + vfloat16m1_t _sum4_f16 = vfncvt_f_f_w_f16m1(_sum4, vl); + vfloat16m1_t _sum5_f16 = vfncvt_f_f_w_f16m1(_sum5, vl); + vfloat16m1_t _sum6_f16 = vfncvt_f_f_w_f16m1(_sum6, vl); + vfloat16m1_t _sum7_f16 = vfncvt_f_f_w_f16m1(_sum7, vl); + + vsse16_v_f16m1(outptr0, out_hstep * sizeof(__fp16), _sum0_f16, vl); + vsse16_v_f16m1(outptr0 + 1, out_hstep * sizeof(__fp16), _sum1_f16, vl); + vsse16_v_f16m1(outptr0 + 2, out_hstep * sizeof(__fp16), _sum2_f16, vl); + vsse16_v_f16m1(outptr0 + 3, out_hstep * sizeof(__fp16), _sum3_f16, vl); + vsse16_v_f16m1(outptr0 + 4, out_hstep * sizeof(__fp16), _sum4_f16, vl); + vsse16_v_f16m1(outptr0 + 5, out_hstep * sizeof(__fp16), _sum5_f16, vl); + vsse16_v_f16m1(outptr0 + 6, out_hstep * sizeof(__fp16), _sum6_f16, vl); + vsse16_v_f16m1(outptr0 + 7, out_hstep * sizeof(__fp16), _sum7_f16, vl); + + // transpose4x8_ps(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7); + + outptr0 += 8; + } + } + else + { + vse32_v_f32m2(outptr, _sum0, vl); + vse32_v_f32m2(outptr + 4, _sum1, vl); + vse32_v_f32m2(outptr + 4 * 2, _sum2, vl); + vse32_v_f32m2(outptr + 4 * 3, _sum3, vl); + vse32_v_f32m2(outptr + 4 * 4, _sum4, vl); + vse32_v_f32m2(outptr + 4 * 5, _sum5, vl); + vse32_v_f32m2(outptr + 4 * 6, _sum6, vl); + vse32_v_f32m2(outptr + 4 * 7, _sum7, vl); + } + + outptr += 32; + } + for (; jj + 3 < max_jj; jj += 4) + { + vfloat32m2_t _sum0; + vfloat32m2_t _sum1; + vfloat32m2_t _sum2; + vfloat32m2_t _sum3; + + vl = 4; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m2(0.f, vl); + _sum1 = vfmv_v_f_f32m2(0.f, vl); + _sum2 = vfmv_v_f_f32m2(0.f, vl); + _sum3 = vfmv_v_f_f32m2(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vle32_v_f32m2(pC, vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = vle32_v_f32m2(pC, vl); + _sum1 = vle32_v_f32m2(pC + 4, vl); + _sum2 = vle32_v_f32m2(pC + 8, vl); + _sum3 = vle32_v_f32m2(pC + 12, vl); + pC += 16; + } + if (broadcast_type_C == 4) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + _sum1 = vfmv_v_f_f32m2(pC[1], vl); + _sum2 = vfmv_v_f_f32m2(pC[2], vl); + _sum3 = vfmv_v_f_f32m2(pC[3], vl); + pC += 4; + } + } + } + else + { + _sum0 = vle32_v_f32m2(outptr, vl); + _sum1 = vle32_v_f32m2(outptr + 4 * 1, vl); + _sum2 = vle32_v_f32m2(outptr + 4 * 2, vl); + _sum3 = vle32_v_f32m2(outptr + 4 * 3, vl); + } + + const __fp16* pA = pAT; + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat16m1_t _pA = vle16_v_f16m1(pA, vl); + + _sum0 = vfwmacc_vf_f32m2(_sum0, pB[0], _pA, vl); + _sum1 = vfwmacc_vf_f32m2(_sum1, pB[1], _pA, vl); + _sum2 = vfwmacc_vf_f32m2(_sum2, pB[2], _pA, vl); + _sum3 = vfwmacc_vf_f32m2(_sum3, pB[3], _pA, vl); + + pA += 4; + pB += 4; + } + + if (alpha != 1.f) + { + _sum0 = vfmul_vf_f32m2(_sum0, alpha, vl); + _sum1 = vfmul_vf_f32m2(_sum1, alpha, vl); + _sum2 = vfmul_vf_f32m2(_sum2, alpha, vl); + _sum3 = vfmul_vf_f32m2(_sum3, alpha, vl); + } + + if (k_end) + { + if (out_elempack == 4) + { + vse16_v_f16m1(outptr0, 
vfncvt_f_f_w_f16m1(_sum0, vl), vl); + vse16_v_f16m1(outptr0 + 4, vfncvt_f_f_w_f16m1(_sum1, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 2, vfncvt_f_f_w_f16m1(_sum2, vl), vl); + vse16_v_f16m1(outptr0 + 4 * 3, vfncvt_f_f_w_f16m1(_sum3, vl), vl); + + outptr0 += 16; + } + if (out_elempack == 1) + { + vfloat16m1_t _sum0_f16 = vfncvt_f_f_w_f16m1(_sum0, vl); + vfloat16m1_t _sum1_f16 = vfncvt_f_f_w_f16m1(_sum1, vl); + vfloat16m1_t _sum2_f16 = vfncvt_f_f_w_f16m1(_sum2, vl); + vfloat16m1_t _sum3_f16 = vfncvt_f_f_w_f16m1(_sum3, vl); + + vsse16_v_f16m1(outptr0, out_hstep * sizeof(__fp16), _sum0_f16, vl); + vsse16_v_f16m1(outptr0 + 1, out_hstep * sizeof(__fp16), _sum1_f16, vl); + vsse16_v_f16m1(outptr0 + 2, out_hstep * sizeof(__fp16), _sum2_f16, vl); + vsse16_v_f16m1(outptr0 + 3, out_hstep * sizeof(__fp16), _sum3_f16, vl); + // transpose4x4_ps(_sum0, _sum1, _sum2, _sum3); + + outptr0 += 4; + } + } + else + { + vse32_v_f32m2(outptr, _sum0, vl); + vse32_v_f32m2(outptr + 4, _sum1, vl); + vse32_v_f32m2(outptr + 4 * 2, _sum2, vl); + vse32_v_f32m2(outptr + 4 * 3, _sum3, vl); + } + + outptr += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + vfloat32m2_t _sum0; + vfloat32m2_t _sum1; + + vl = 4; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m2(0.f, vl); + _sum1 = vfmv_v_f_f32m2(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + _sum1 = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vle32_v_f32m2(pC, vl); + _sum1 = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = vle32_v_f32m2(pC, vl); + _sum1 = vle32_v_f32m2(pC + 4, vl); + pC += 8; + } + if (broadcast_type_C == 4) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + _sum1 = vfmv_v_f_f32m2(pC[1], vl); + pC += 2; + } + } + } + else + { + _sum0 = vle32_v_f32m2(outptr, vl); + _sum1 = vle32_v_f32m2(outptr + 4, vl); + } + + const __fp16* pA = pAT; + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat16m1_t _pA = vle16_v_f16m1(pA, vl); + + _sum0 = vfwmacc_vf_f32m2(_sum0, pB[0], _pA, vl); + _sum1 = vfwmacc_vf_f32m2(_sum1, pB[1], _pA, vl); + + pA += 4; + pB += 2; + } + + if (alpha != 1.f) + { + _sum0 = vfmul_vf_f32m2(_sum0, alpha, vl); + _sum1 = vfmul_vf_f32m2(_sum1, alpha, vl); + } + + if (k_end) + { + if (out_elempack == 4) + { + vse16_v_f16m1(outptr0, vfncvt_f_f_w_f16m1(_sum0, vl), vl); + vse16_v_f16m1(outptr0 + 4, vfncvt_f_f_w_f16m1(_sum1, vl), vl); + outptr0 += 8; + } + if (out_elempack == 1) + { + vfloat16m1_t _sum0_f16 = vfncvt_f_f_w_f16m1(_sum0, vl); + vfloat16m1_t _sum1_f16 = vfncvt_f_f_w_f16m1(_sum1, vl); + vsse16_v_f16m1(outptr0, out_hstep * sizeof(__fp16), _sum0_f16, vl); + vsse16_v_f16m1(outptr0 + 1, out_hstep * sizeof(__fp16), _sum1_f16, vl); + outptr0 += 2; + } + } + else + { + vse32_v_f32m2(outptr, _sum0, vl); + vse32_v_f32m2(outptr + 4, _sum1, vl); + } + + outptr += 8; + } + for (; jj < max_jj; jj += 1) + { + vfloat32m2_t _sum0; + + vl = 4; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m2(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vle32_v_f32m2(pC, vl); + } + if (broadcast_type_C == 3) + { + _sum0 = vle32_v_f32m2(pC, vl); + pC += 4; + } + if (broadcast_type_C == 4) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + pC += 1; + } + } + } + else + { + _sum0 = vle32_v_f32m2(outptr, vl); + } + + const __fp16* pA = pAT; + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat16m1_t _pA = vle16_v_f16m1(pA, vl); + + _sum0 = vfwmacc_vf_f32m2(_sum0, pB[0], _pA, vl); + + pA += 4; 
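+ // note: packed A supplies 4 fp16 rows per k step here, while this single-column tail consumes one B element per k step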
+ pB += 1; + } + + if (alpha != 1.f) + { + _sum0 = vfmul_vf_f32m2(_sum0, alpha, vl); + } + + if (k_end) + { + if (out_elempack == 4) + { + vse16_v_f16m1(outptr0, vfncvt_f_f_w_f16m1(_sum0, vl), vl); + outptr0 += 4; + } + if (out_elempack == 1) + { + vfloat16m1_t _sum0_f16 = vfncvt_f_f_w_f16m1(_sum0, vl); + vsse16_v_f16m1(outptr0, out_hstep * sizeof(__fp16), _sum0_f16, vl); + + outptr0++; + } + } + else + { + vse32_v_f32m2(outptr, _sum0, vl); + } + + outptr += 4; + } + + pAT += max_kk * 4; + } + for (; ii + 1 < max_ii; ii += 2) + { + __fp16* outptr0 = (__fp16*)top_blob + (i + ii) * out_hstep + j; + + const __fp16* pB = pBT; + + if (pC) + { + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)CT_tile + i + ii; + } + if (broadcast_type_C == 4) + { + pC = (const float*)CT_tile + j; + } + } + + int jj = 0; +#if __riscv_vector + for (; jj + 11 < max_jj; jj += 12) + { + vfloat32m2_t _sum00; + vfloat32m2_t _sum01; + vfloat32m2_t _sum02; + vfloat32m2_t _sum10; + vfloat32m2_t _sum11; + vfloat32m2_t _sum12; + + vl = 4; + + if (k == 0) + { + _sum00 = vfmv_v_f_f32m2(0.f, vl); + _sum01 = vfmv_v_f_f32m2(0.f, vl); + _sum02 = vfmv_v_f_f32m2(0.f, vl); + _sum10 = vfmv_v_f_f32m2(0.f, vl); + _sum11 = vfmv_v_f_f32m2(0.f, vl); + _sum12 = vfmv_v_f_f32m2(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum00 = vfmv_v_f_f32m2(pC[0], vl); + _sum01 = _sum00; + _sum02 = _sum00; + _sum10 = _sum00; + _sum11 = _sum00; + _sum12 = _sum00; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum00 = vfmv_v_f_f32m2(pC[0], vl); + _sum01 = _sum00; + _sum02 = _sum00; + _sum10 = vfmv_v_f_f32m2(pC[1], vl); + _sum11 = _sum10; + _sum12 = _sum10; + } + if (broadcast_type_C == 3) + { + vlseg2e32_v_f32m2(&_sum00, &_sum10, pC, vl); + vlseg2e32_v_f32m2(&_sum01, &_sum11, pC + 8, vl); + vlseg2e32_v_f32m2(&_sum02, &_sum12, pC + 16, vl); + + pC += 24; + } + if (broadcast_type_C == 4) + { + _sum00 = vle32_v_f32m2(pC, vl); + _sum01 = vle32_v_f32m2(pC + 4, vl); + _sum02 = vle32_v_f32m2(pC + 8, vl); + _sum10 = _sum00; + _sum11 = _sum01; + _sum12 = _sum02; + pC += 12; + } + } + } + else + { + vlseg2e32_v_f32m2(&_sum00, &_sum10, outptr, vl); + vlseg2e32_v_f32m2(&_sum01, &_sum11, outptr + 8, vl); + vlseg2e32_v_f32m2(&_sum02, &_sum12, outptr + 16, vl); + // float32x4x2_t _tmp01 = vld2q_f32(outptr); + // float32x4x2_t _tmp23 = vld2q_f32(outptr + 8); + // float32x4x2_t _tmp45 = vld2q_f32(outptr + 16); + } + + const __fp16* pA = pAT; + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat16m1_t _pB0 = vle16_v_f16m1(pB, vl); + vfloat16m1_t _pB1 = vle16_v_f16m1(pB + 4, vl); + vfloat16m1_t _pB2 = vle16_v_f16m1(pB + 8, vl); + + _sum00 = vfwmacc_vf_f32m2(_sum00, pA[0], _pB0, vl); + _sum01 = vfwmacc_vf_f32m2(_sum01, pA[0], _pB1, vl); + _sum02 = vfwmacc_vf_f32m2(_sum02, pA[0], _pB2, vl); + _sum10 = vfwmacc_vf_f32m2(_sum10, pA[1], _pB0, vl); + _sum11 = vfwmacc_vf_f32m2(_sum11, pA[1], _pB1, vl); + _sum12 = vfwmacc_vf_f32m2(_sum12, pA[1], _pB2, vl); + + pA += 2; + pB += 12; + } + + if (alpha != 1.f) + { + _sum00 = vfmul_vf_f32m2(_sum00, alpha, vl); + _sum01 = vfmul_vf_f32m2(_sum01, alpha, vl); + _sum02 = vfmul_vf_f32m2(_sum02, alpha, vl); + _sum10 = vfmul_vf_f32m2(_sum10, alpha, vl); + _sum11 = vfmul_vf_f32m2(_sum11, alpha, vl); + _sum12 = vfmul_vf_f32m2(_sum12, alpha, vl); + } + + if (k_end) + { + // if (out_elempack == 1) + { + vse16_v_f16m1(outptr0, vfncvt_f_f_w_f16m1(_sum00, vl), vl); + vse16_v_f16m1(outptr0 + 4, vfncvt_f_f_w_f16m1(_sum01, vl), vl); + vse16_v_f16m1(outptr0 + 8, vfncvt_f_f_w_f16m1(_sum02, 
vl), vl); + vse16_v_f16m1(outptr0 + out_hstep, vfncvt_f_f_w_f16m1(_sum10, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep + 4, vfncvt_f_f_w_f16m1(_sum11, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep + 8, vfncvt_f_f_w_f16m1(_sum12, vl), vl); + outptr0 += 12; + } + } + else + { + vsseg2e32_v_f32m1(outptr, vget_v_f32m2_f32m1(_sum00, 0), vget_v_f32m2_f32m1(_sum10, 0), vl); + vsseg2e32_v_f32m1(outptr + 8, vget_v_f32m2_f32m1(_sum01, 0), vget_v_f32m2_f32m1(_sum11, 0), vl); + vsseg2e32_v_f32m1(outptr + 16, vget_v_f32m2_f32m1(_sum02, 0), vget_v_f32m2_f32m1(_sum12, 0), vl); + + // vsseg2e16_v_f32m1(outptr, vget_v_f32m2_f32m1(_sum00, 0), vget_v_f32m2_f32m1(_sum10, 0), vl); + // float32x4x2_t _tmp01; + // _tmp01.val[0] = _sum0; + // _tmp01.val[1] = _sum1; + // float32x4x2_t _tmp23; + // _tmp23.val[0] = _sum01; + // _tmp23.val[1] = _sum11; + // float32x4x2_t _tmp45; + // _tmp45.val[0] = _sum02; + // _tmp45.val[1] = _sum12; + } + + outptr += 24; + } +#endif // __riscv_vector + for (; jj + 7 < max_jj; jj += 8) + { + vfloat32m2_t _sum0; + vfloat32m2_t _sum1; + + vl = 8; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m2(0.f, vl); + _sum1 = vfmv_v_f_f32m2(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + _sum1 = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + _sum1 = vfmv_v_f_f32m2(pC[1], vl); + } + if (broadcast_type_C == 3) + { + vfloat32m1_t _tmp0; + vfloat32m1_t _tmp1; + vfloat32m1_t _tmp2; + vfloat32m1_t _tmp3; + + vlseg2e32_v_f32m1(&_tmp0, &_tmp1, pC, vl); + vlseg2e32_v_f32m1(&_tmp2, &_tmp3, pC + 8, vl); + + _sum0 = vset_v_f32m1_f32m2(_sum0, 0, _tmp0); + _sum0 = vset_v_f32m1_f32m2(_sum0, 1, _tmp2); + _sum1 = vset_v_f32m1_f32m2(_sum1, 0, _tmp1); + _sum1 = vset_v_f32m1_f32m2(_sum1, 1, _tmp3); + // float32x4x2_t _tmp01 = vld2q_f32(pC); + // float32x4x2_t _tmp23 = vld2q_f32(pC + 8); + pC += 16; + } + if (broadcast_type_C == 4) + { + _sum0 = vle32_v_f32m2(pC, vl); + _sum1 = _sum0; + pC += 8; + } + } + } + else + { + vl = 4; + vfloat32m1_t _tmp0; + vfloat32m1_t _tmp1; + vfloat32m1_t _tmp2; + vfloat32m1_t _tmp3; + + vlseg2e32_v_f32m1(&_tmp0, &_tmp1, outptr, vl); + vlseg2e32_v_f32m1(&_tmp2, &_tmp3, outptr + 8, vl); + + _sum0 = vset_v_f32m1_f32m2(_sum0, 0, _tmp0); + _sum0 = vset_v_f32m1_f32m2(_sum0, 1, _tmp2); + _sum1 = vset_v_f32m1_f32m2(_sum1, 0, _tmp1); + _sum1 = vset_v_f32m1_f32m2(_sum1, 1, _tmp3); + // float32x4x2_t _tmp01 = vld2q_f32(outptr); + // float32x4x2_t _tmp23 = vld2q_f32(outptr + 8); + } + vl = 8; + + const __fp16* pA = pAT; + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat16m1_t _pB0 = vle16_v_f16m1(pB, vl); + _sum0 = vfwmacc_vf_f32m2(_sum0, pA[0], _pB0, vl); + _sum1 = vfwmacc_vf_f32m2(_sum1, pA[1], _pB0, vl); + + pA += 2; + pB += 8; + } + + if (alpha != 1.f) + { + _sum0 = vfmul_vf_f32m2(_sum0, alpha, vl); + _sum1 = vfmul_vf_f32m2(_sum1, alpha, vl); + } + + if (k_end) + { + // if (out_elempack == 1) + { + vse16_v_f16m1(outptr0, vfncvt_f_f_w_f16m1(_sum0, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep, vfncvt_f_f_w_f16m1(_sum1, vl), vl); + + outptr0 += 8; + } + } + else + { + // spmm? 
size => 32x32x micro kernel + vfloat32m1_t _tmp00 = vget_v_f32m2_f32m1(_sum0, 0); + vfloat32m1_t _tmp01 = vget_v_f32m2_f32m1(_sum1, 0); + vfloat32m1_t _tmp10 = vget_v_f32m2_f32m1(_sum0, 1); + vfloat32m1_t _tmp11 = vget_v_f32m2_f32m1(_sum1, 1); + vl = 4; + vsseg2e32_v_f32m1(outptr, _tmp00, _tmp01, vl); + vsseg2e32_v_f32m1(outptr + 8, _tmp10, _tmp11, vl); + // float32x4x2_t _tmp01; + // _tmp01.val[0] = _sum00; + // _tmp01.val[1] = _sum10; + // float32x4x2_t _tmp23; + // _tmp23.val[0] = _sum01; + // _tmp23.val[1] = _sum11; + } + outptr += 16; + } + for (; jj + 3 < max_jj; jj += 4) + { + vfloat32m2_t _sum0; + vfloat32m2_t _sum1; + + vl = 4; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m2(0.f, vl); + _sum1 = vfmv_v_f_f32m2(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + _sum1 = _sum0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + _sum1 = vfmv_v_f_f32m2(pC[1], vl); + } + if (broadcast_type_C == 3) + { + vfloat32m1_t _tmp0; + vfloat32m1_t _tmp1; + vlseg2e32_v_f32m1(&_tmp0, &_tmp1, pC, vl); + _sum0 = vset_v_f32m1_f32m2(_sum0, 0, _tmp0); + _sum1 = vset_v_f32m1_f32m2(_sum1, 0, _tmp1); + // float32x4x2_t _tmp01 = vld2q_f32(pC); + pC += 8; + } + if (broadcast_type_C == 4) + { + _sum0 = vle32_v_f32m2(pC, vl); + _sum1 = _sum0; + pC += 4; + } + } + } + else + { + vfloat32m1_t _tmp0; + vfloat32m1_t _tmp1; + vlseg2e32_v_f32m1(&_tmp0, &_tmp1, outptr, vl); + _sum0 = vset_v_f32m1_f32m2(_sum0, 0, _tmp0); + _sum1 = vset_v_f32m1_f32m2(_sum1, 0, _tmp1); + // float32x4x2_t _tmp01 = vuzpq_f32(_tmp0, _tmp1); + } + + const __fp16* pA = pAT; + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat16m1_t _pB = vle16_v_f16m1(pB, vl); + + _sum0 = vfwmacc_vf_f32m2(_sum0, pA[0], _pB, vl); + _sum1 = vfwmacc_vf_f32m2(_sum1, pA[1], _pB, vl); + // _pB0 = vslideup_vx_f16m1(_pB0, 4, vl); + + pA += 2; + pB += 4; + } + + if (alpha != 1.f) + { + _sum0 = vfmul_vf_f32m2(_sum0, alpha, vl); + _sum1 = vfmul_vf_f32m2(_sum1, alpha, vl); + } + + if (k_end) + { + // if (out_elempack == 1) + { + vse16_v_f16m1(outptr0, vfncvt_f_f_w_f16m1(_sum0, vl), vl); + vse16_v_f16m1(outptr0 + out_hstep, vfncvt_f_f_w_f16m1(_sum1, vl), vl); + outptr0 += 4; + } + } + else + { + vsseg2e32_v_f32m2(outptr, _sum0, _sum1, vl); + // float32x4x2_t _tmp01; + // _tmp01.val[0] = _sum0; + // _tmp01.val[1] = _sum1; + } + + outptr += 8; + } + for (; jj + 1 < max_jj; jj += 2) + { + float sum00; + float sum01; + float sum10; + float sum11; + + if (k == 0) + { + sum00 = 0.f; + sum01 = 0.f; + sum10 = 0.f; + sum11 = 0.f; + + if (pC) + { + if (broadcast_type_C == 0) + { + sum00 = pC[0]; + sum01 = pC[0]; + sum10 = pC[0]; + sum11 = pC[0]; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + sum00 = pC[0]; + sum01 = pC[1]; + sum10 = pC[0]; + sum11 = pC[1]; + } + if (broadcast_type_C == 3) + { + sum00 = pC[0]; + sum01 = pC[1]; + sum10 = pC[2]; + sum11 = pC[3]; + pC += 4; + } + if (broadcast_type_C == 4) + { + sum00 = pC[0]; + sum01 = pC[0]; + sum10 = pC[1]; + sum11 = pC[1]; + pC += 2; + } + } + } + else + { + sum00 = outptr[0]; + sum01 = outptr[1]; + sum10 = outptr[2]; + sum11 = outptr[3]; + } + + const __fp16* pA = pAT; + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __fp16 pA0 = pA[0]; + __fp16 pA1 = pA[1]; + __fp16 pB0 = pB[0]; + __fp16 pB1 = pB[1]; + + sum00 += pA0 * pB0; + sum01 += pA1 * pB0; + sum10 += pA0 * pB1; + sum11 += pA1 * pB1; + + pA += 2; + pB += 2; + } + + if (alpha != 1.f) + { + sum00 *= alpha; + sum01 *= alpha; + sum10 *= alpha; + sum11 *= 
alpha; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = (__fp16)(sum00); + outptr0[1] = (__fp16)(sum10); + outptr0[out_hstep] = (__fp16)(sum01); + outptr0[out_hstep + 1] = (__fp16)(sum11); + outptr0 += 2; + } + } + else + { + outptr[0] = sum00; + outptr[1] = sum01; + outptr[2] = sum10; + outptr[3] = sum11; + } + + outptr += 4; + } + for (; jj < max_jj; jj += 1) + { + float _sum0; + float _sum1; + + if (k == 0) + { + _sum0 = 0.f; + _sum1 = 0.f; + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = pC[0]; + _sum1 = pC[0]; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = pC[0]; + _sum1 = pC[1]; + } + if (broadcast_type_C == 3) + { + _sum0 = pC[0]; + _sum1 = pC[1]; + pC += 2; + } + if (broadcast_type_C == 4) + { + _sum0 = pC[0]; + _sum1 = pC[0]; + pC += 1; + } + } + } + else + { + _sum0 = outptr[0]; + _sum1 = outptr[1]; + } + + const __fp16* pA = pAT; + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __fp16 pA0 = pA[0]; + __fp16 pA1 = pA[1]; + __fp16 pB0 = pB[0]; + + _sum0 += pA0 * pB0; + _sum1 += pA1 * pB0; + pA += 2; + pB += 1; + } + + if (alpha != 1.f) + { + _sum0 *= alpha; + _sum1 *= alpha; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = (__fp16)(_sum0); + outptr0[out_hstep] = (__fp16)(_sum1); + outptr0++; + } + } + else + { + outptr[0] = _sum0; + outptr[1] = _sum1; + } + + outptr += 2; + } + + pAT += max_kk * 2; + } + for (; ii < max_ii; ii += 1) + { + __fp16* outptr0 = (__fp16*)top_blob + (i + ii) * out_hstep + j; + + const __fp16* pB = pBT; + + if (pC) + { + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)CT_tile + i + ii; + } + if (broadcast_type_C == 4) + { + pC = (const float*)CT_tile + j; + } + } + + int jj = 0; +#if __riscv_vector + for (; jj + 11 < max_jj; jj += 12) + { + vfloat32m2_t _sum0; + vfloat32m2_t _sum1; + vfloat32m2_t _sum2; + + vl = 4; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m2(0.f, vl); + _sum1 = vfmv_v_f_f32m2(0.f, vl); + _sum2 = vfmv_v_f_f32m2(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + _sum1 = vfmv_v_f_f32m2(pC[0], vl); + _sum2 = vfmv_v_f_f32m2(pC[0], vl); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + _sum0 = vle32_v_f32m2(pC, vl); + _sum1 = vle32_v_f32m2(pC + 4, vl); + _sum2 = vle32_v_f32m2(pC + 8, vl); + pC += 12; + } + } + } + else + { + _sum0 = vle32_v_f32m2(outptr, vl); + _sum1 = vle32_v_f32m2(outptr + 4, vl); + _sum2 = vle32_v_f32m2(outptr + 8, vl); + } + + const __fp16* pA = pAT; + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat16m1_t _pB0 = vle16_v_f16m1(pB, vl); + vfloat16m1_t _pB1 = vle16_v_f16m1(pB + 4, vl); + vfloat16m1_t _pB2 = vle16_v_f16m1(pB + 8, vl); + + _sum0 = vfwmacc_vf_f32m2(_sum0, pA[0], _pB0, vl); + _sum1 = vfwmacc_vf_f32m2(_sum1, pA[0], _pB1, vl); + _sum2 = vfwmacc_vf_f32m2(_sum2, pA[0], _pB2, vl); + + pA += 1; + pB += 12; + } + + if (alpha != 1.f) + { + _sum0 = vfmul_vf_f32m2(_sum0, alpha, vl); + _sum1 = vfmul_vf_f32m2(_sum1, alpha, vl); + _sum2 = vfmul_vf_f32m2(_sum2, alpha, vl); + } + + if (k_end) + { + // if (out_elempack == 1) + { + vse16_v_f16m1(outptr0, vfncvt_f_f_w_f16m1(_sum0, vl), vl); + vse16_v_f16m1(outptr0 + 4, vfncvt_f_f_w_f16m1(_sum1, vl), vl); + vse16_v_f16m1(outptr0 + 8, vfncvt_f_f_w_f16m1(_sum2, vl), vl); + outptr0 += 12; + } + } + else + { + vse32_v_f32m2(outptr, _sum0, vl); + vse32_v_f32m2(outptr + 4, _sum1, vl); + vse32_v_f32m2(outptr + 8, _sum2, vl); + } + + outptr += 12; + } +#endif // 
__riscv_vector + for (; jj + 7 < max_jj; jj += 8) + { + vfloat32m2_t _sum0; + vfloat32m2_t _sum1; + vl = 4; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m2(0.f, vl); + _sum1 = vfmv_v_f_f32m2(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vfmv_v_f_f32m2(pC[0], vl); + _sum1 = vfmv_v_f_f32m2(pC[0], vl); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + _sum0 = vle32_v_f32m2(pC, vl); + _sum1 = vle32_v_f32m2(pC + 4, vl); + pC += 8; + } + } + } + else + { + _sum0 = vle32_v_f32m2(outptr, vl); + _sum1 = vle32_v_f32m2(outptr + 4, vl); + } + + const __fp16* pA = pAT; + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat16m1_t _pB0 = vle16_v_f16m1(pB, vl); + vfloat16m1_t _pB1 = vle16_v_f16m1(pB + 4, vl); + + _sum0 = vfwmacc_vf_f32m2(_sum0, pA[0], _pB0, vl); + _sum1 = vfwmacc_vf_f32m2(_sum1, pA[0], _pB1, vl); + + pA += 1; + pB += 8; + } + + if (alpha != 1.f) + { + _sum0 = vfmul_vf_f32m2(_sum0, alpha, vl); + _sum1 = vfmul_vf_f32m2(_sum1, alpha, vl); + } + + if (k_end) + { + // if (out_elempack == 1) + { + vse16_v_f16m1(outptr0, vfncvt_f_f_w_f16m1(_sum0, vl), vl); + vse16_v_f16m1(outptr0 + 4, vfncvt_f_f_w_f16m1(_sum1, vl), vl); + outptr0 += 8; + } + } + else + { + vse32_v_f32m2(outptr, _sum0, vl); + vse32_v_f32m2(outptr + 4, _sum1, vl); + } + + outptr += 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + vfloat32m2_t _sum; + vl = 4; + + if (k == 0) + { + _sum = vfmv_v_f_f32m2(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum = vfmv_v_f_f32m2(pC[0], vl); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + _sum = vle32_v_f32m2(pC, vl); + pC += 4; + } + } + } + else + { + _sum = vle32_v_f32m2(outptr, vl); + } + + const __fp16* pA = pAT; + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat16m1_t _pB = vle16_v_f16m1(pB, vl); + + _sum = vfwmacc_vf_f32m2(_sum, pA[0], _pB, vl); + + pA += 1; + pB += 4; + } + + if (alpha != 1.f) + { + _sum = vfmul_vf_f32m2(_sum, alpha, vl); + } + + if (k_end) + { + // if (out_elempack == 1) + { + vse16_v_f16m1(outptr0, vfncvt_f_f_w_f16m1(_sum, vl), vl); + outptr0 += 4; + } + } + else + { + vse32_v_f32m2(outptr, _sum, vl); + } + + outptr += 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + float _sum0; + float _sum1; + + if (k == 0) + { + _sum0 = 0.f; + _sum1 = 0.f; + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = pC[0]; + _sum1 = pC[0]; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + _sum0 = pC[0]; + _sum1 = pC[1]; + pC += 2; + } + } + } + else + { + _sum0 = outptr[0]; + _sum1 = outptr[1]; + } + + const __fp16* pA = pAT; + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + __fp16 pA0 = pA[0]; + __fp16 pB0 = pB[0]; + __fp16 pB1 = pB[1]; + + _sum0 += pA0 * pB0; + _sum1 += pA0 * pB1; + + pA += 1; + pB += 2; + } + + if (alpha != 1.f) + { + _sum0 *= alpha; + _sum1 *= alpha; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = (__fp16)(_sum0); + outptr0[1] = (__fp16)(_sum1); + outptr0 += 2; + } + } + else + { + outptr[0] = _sum0; + outptr[1] = _sum1; + } + + outptr += 2; + } + for (; jj < max_jj; jj += 1) + { + float sum; + + if (k == 0) + { + sum = 0.f; + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + sum = pC[0]; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + sum = pC[0]; + pC += 1; + } + } + } + else + { + sum = outptr[0]; + } + + const __fp16* pA = pAT; + + 
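+ // scalar 1x1 tail: one fp16 element of A times one fp16 element of B per k step, accumulated into a float sum and scaled by alpha at the end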
int kk = 0; + for (; kk < max_kk; kk += 1) + { + __fp16 pA0 = pA[0]; + __fp16 pB0 = pB[0]; + + sum += pA0 * pB0; + pA += 1; + pB += 1; + } + + if (alpha != 1.f) + { + sum *= alpha; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = (__fp16)(sum); + outptr0++; + } + } + else + { + outptr[0] = sum; + } + + outptr += 1; + } + + pAT += max_kk; + } +} diff --git a/src/layer/riscv/gemm_riscv.cpp b/src/layer/riscv/gemm_riscv.cpp index 9b4b58ac6510..be48b43f9ac3 100644 --- a/src/layer/riscv/gemm_riscv.cpp +++ b/src/layer/riscv/gemm_riscv.cpp @@ -28,6 +28,9 @@ Gemm_riscv::Gemm_riscv() { #if __riscv_vector support_packing = true; +#if __riscv_zfh + support_fp16_storage = true; +#endif #endif // __riscv_vector one_blob_only = false; support_inplace = false; @@ -38,13 +41,16 @@ Gemm_riscv::Gemm_riscv() // even if the current hardware provides vector registers of more than 128 bits, // vl=4 is still used, even though this will waste the width of the vector register. vl = vsetvlmax_e32m1(); - vl = vl >= 4 ? 4 : vl; + // vl = vl >= 4 ? 4 : vl; #else vl = 0; #endif // __riscv_vector } -static void pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, size_t vl) +#include "gemm_bf16s_fp16s.h" +#include "gemm_fp16s.h" + +void pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, size_t vl) { const int elempack = A.elempack; const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; @@ -55,6 +61,18 @@ static void pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max #if __riscv_vector for (; ii + 7 < max_ii; ii += 8) { + if (elempack == 8) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k * 8; + + for (int kk = 0; kk < max_kk; kk++) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + vse32_v_f32m1(pp + 4, vle32_v_f32m1(p0 + 4, vl), vl); + pp += 8; + p0 += 8; + } + } if (elempack == 4) { const float* p0 = (const float*)A + (i + ii) * A_hstep + k * 4; @@ -3096,7 +3114,6 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons sum01 += pA[1] * pB[0]; sum10 += pA[0] * pB[1]; sum11 += pA[1] * pB[1]; - pA += 2; pB += 2; } @@ -3936,17 +3953,12 @@ static int gemm_AT_BT_riscv(const Mat& AT, const Mat& BT, const Mat& C, Mat& top return 0; } -int Gemm_riscv::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const -{ - std::vector bottom_blobs(1, bottom_blob); - std::vector top_blobs(1, top_blob); - int ret = forward(bottom_blobs, top_blobs, opt); - top_blob = top_blobs[0]; - return ret; -} - int Gemm_riscv::create_pipeline(const Option& opt) { + if (support_fp16_storage && opt.use_fp16_storage) + { + return create_pipeline_fp16s(opt); + } if (constantA) { const int M = constantM; @@ -4067,6 +4079,13 @@ int Gemm_riscv::create_pipeline(const Option& opt) int Gemm_riscv::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { + const Mat& bottom_blob = constantA ? 
AT_data : bottom_blobs[0]; + int elembits = bottom_blob.elembits(); + if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) + { + return forward_fp16s(bottom_blobs, top_blobs, opt); + } + int M; int N; if (constantA && constantB) @@ -4248,4 +4267,653 @@ int Gemm_riscv::forward(const std::vector& bottom_blobs, std::vector& return 0; } +// Add riscv fp16 by Xinyu Yang +static int gemm_riscv_fp16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +{ + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + + // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_bf16s_fp16s(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + + // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); + + int nn_M = (M + TILE_M - 1) / TILE_M; + int nn_N = (N + TILE_N - 1) / TILE_N; + int nn_K = (K + TILE_K - 1) / TILE_K; + + Mat ATX(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, nT, 2u, opt.workspace_allocator); + Mat BT(TILE_K * TILE_N, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 2u, opt.workspace_allocator); + + const int nn_NK = nn_N * nn_K; + + // pack B + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (transB) + { + pack_B_tile_bf16_fp16(B, BT_tile, j, max_jj, k, max_kk); + } + else + { + transpose_pack_B_tile_bf16_fp16(B, BT_tile, j, max_jj, k, max_kk); + } + } + + Mat topT; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT.create(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + const int i = ppi * TILE_M; + + // shadowed variable for less openmp task args + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + if (broadcast_type_C == 3) + { + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj, 4); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = ATX.channel(get_omp_thread_num()).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (j == 0) + { + if (transA) + { + transpose_pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); + } + else + { + pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); + } + } + + bool k_end = !output_transpose && k + TILE_K >= K; + float _alpha = k + TILE_K >= K ? 
alpha : 1.f; + + gemm_transB_packed_tile_fp16s(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, _alpha, i, max_ii, j, max_jj, k, max_kk, k_end); + } + + if (output_transpose) + { + transpose_unpack_output_tile_fp32_to_fp16(topT_tile, top_blob, i, max_ii, j, max_jj); + } + } + } + + return 0; +} + +static int gemm_AT_riscv_fp16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +{ + const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + + // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_bf16s_fp16s(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + + // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); + + int nn_M = (M + TILE_M - 1) / TILE_M; + int nn_N = (N + TILE_N - 1) / TILE_N; + int nn_K = (K + TILE_K - 1) / TILE_K; + + Mat BT(TILE_K * TILE_N, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 2u, opt.workspace_allocator); + + const int nn_NK = nn_N * nn_K; + + // pack B + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (transB) + { + pack_B_tile_bf16_fp16(B, BT_tile, j, max_jj, k, max_kk); + } + else + { + transpose_pack_B_tile_bf16_fp16(B, BT_tile, j, max_jj, k, max_kk); + } + } + + Mat topT; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT.create(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + const int i = ppi * TILE_M; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + if (broadcast_type_C == 3) + { + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj, 4); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + bool k_end = !output_transpose && k + TILE_K >= K; + float _alpha = k + TILE_K >= K ? alpha : 1.f; + + gemm_transB_packed_tile_fp16s(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, _alpha, i, max_ii, j, max_jj, k, max_kk, k_end); + } + + if (output_transpose) + { + transpose_unpack_output_tile_fp32_to_fp16(topT_tile, top_blob, i, max_ii, j, max_jj); + } + } + } + + return 0; +} + +static int gemm_BT_riscv_fp16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +{ + const int M = transA ? A.w : (A.dims == 3 ? 
A.c : A.h) * A.elempack; + + // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_bf16s_fp16s(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + + // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); + + int nn_M = (M + TILE_M - 1) / TILE_M; + // int nn_N = (N + TILE_N - 1) / TILE_N; + + Mat ATX(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, nT, 2u, opt.workspace_allocator); + + Mat topT; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT.create(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + const int i = ppi * TILE_M; + + // shadowed variable for less openmp task args + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + if (broadcast_type_C == 3) + { + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj, 4); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = ATX.channel(get_omp_thread_num()).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (j == 0) + { + if (transA) + { + transpose_pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); + } + else + { + pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk); + } + } + + bool k_end = !output_transpose && k + TILE_K >= K; + float _alpha = k + TILE_K >= K ? alpha : 1.f; + + gemm_transB_packed_tile_fp16s(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, _alpha, i, max_ii, j, max_jj, k, max_kk, k_end); + } + + if (output_transpose) + { + transpose_unpack_output_tile_fp32_to_fp16(topT_tile, top_blob, i, max_ii, j, max_jj); + } + } + } + + return 0; +} + +static int gemm_AT_BT_riscv_fp16s(const Mat& AT, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int N, int K, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) +{ + // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_bf16s_fp16s(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + + // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); + + int nn_M = (M + TILE_M - 1) / TILE_M; + // int nn_N = (N + TILE_N - 1) / TILE_N; + + Mat topT; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT.create(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + const int i = ppi * TILE_M; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + if (broadcast_type_C == 3) + { + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj, 4); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? 
topT_tile : C; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + bool k_end = !output_transpose && k + TILE_K >= K; + float _alpha = k + TILE_K >= K ? alpha : 1.f; + + gemm_transB_packed_tile_fp16s(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, _alpha, i, max_ii, j, max_jj, k, max_kk, k_end); + } + + if (output_transpose) + { + transpose_unpack_output_tile_fp32_to_fp16(topT_tile, top_blob, i, max_ii, j, max_jj); + } + } + } + + return 0; +} + +int Gemm_riscv::create_pipeline_fp16s(const Option& opt) +{ + if (constantA) + { + const int M = constantM; + const int K = constantK; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_bf16s_fp16s(M, 0, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, opt.num_threads); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + + AT_data.create(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 2u, (Allocator*)0); + if (AT_data.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_ii = std::min((M - i), TILE_M); + const int max_kk = std::min((K - k), TILE_K); + + Mat AT_tile = AT_data.channel(i / TILE_M).row_range(k / TILE_K, 1); + + if (transA) + { + transpose_pack_A_tile_fp32_to_fp16(A_data, AT_tile, i, max_ii, k, max_kk); + } + else + { + pack_A_tile_fp32_to_fp16(A_data, AT_tile, i, max_ii, k, max_kk); + } + } + } + + A_data.release(); + } + + if (constantB) + { + const int N = constantN; + const int K = constantK; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_bf16s_fp16s(0, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, opt.num_threads); + + const int nn_N = (N + TILE_N - 1) / TILE_N; + + BT_data.create(TILE_K * TILE_N, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 2u, (Allocator*)0); + if (BT_data.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppj = 0; ppj < nn_N; ppj++) + { + const int j = ppj * TILE_N; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat BT_tile = BT_data.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (transB) + { + pack_B_tile_fp32_to_fp16(B_data, BT_tile, j, max_jj, k, max_kk); + } + else + { + transpose_pack_B_tile_fp32_to_fp16(B_data, BT_tile, j, max_jj, k, max_kk); + } + } + } + + B_data.release(); + } + + if (constantC && constant_broadcast_type_C != -1) + { + CT_data = C_data; + + if (constant_broadcast_type_C == 3 && opt.use_packing_layout) + { + int C_elempack = constantM % 4 == 0 ? 
4 : 1; + convert_packing(C_data, CT_data, C_elempack, opt); + } + + // pre-multiply C with beta + if (beta != 1.f) + { + Mat C2; + C2.create_like(CT_data); + + const int size = CT_data.total() * CT_data.elempack; + for (int i = 0; i < size; i++) + { + C2[i] = CT_data[i] * beta; + } + + CT_data = C2; + } + + C_data.release(); + } + + if (constantA || constantB || constantC) + { + nT = opt.num_threads; + } + + return 0; +} + +int Gemm_riscv::forward_fp16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + int M; + int N; + if (constantA && constantB) + { + M = constantM; + N = constantN; + } + else if (constantA) + { + const Mat& B = bottom_blobs[0]; + M = constantM; + N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + } + else if (constantB) + { + const Mat& A = bottom_blobs[0]; + M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + N = constantN; + } + else + { + const Mat& A = bottom_blobs[0]; + const Mat& B = bottom_blobs[1]; + M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + } + + Mat C; + int broadcast_type_C = 0; + if (constantC) + { + C = CT_data; + broadcast_type_C = constant_broadcast_type_C; + } + else + { + if (constantA && constantB) + { + C = bottom_blobs.size() == 1 ? bottom_blobs[0] : Mat(); + } + else if (constantA) + { + C = bottom_blobs.size() == 2 ? bottom_blobs[1] : Mat(); + } + else if (constantB) + { + C = bottom_blobs.size() == 2 ? bottom_blobs[1] : Mat(); + } + else + { + C = bottom_blobs.size() == 3 ? bottom_blobs[2] : Mat(); + } + + if (!C.empty()) + { + if (C.dims == 1 && C.w == 1) + { + // scalar + broadcast_type_C = 0; + } + if (C.dims == 1 && C.w * C.elempack == M) + { + // M + // auto broadcast from h to w is the ncnn-style convention + broadcast_type_C = 1; + } + if (C.dims == 1 && C.w * C.elempack == N) + { + // N + broadcast_type_C = 4; + } + if (C.dims == 2 && C.w == 1 && C.h * C.elempack == M) + { + // Mx1 + broadcast_type_C = 2; + } + if (C.dims == 2 && C.w == N && C.h * C.elempack == M) + { + // MxN + broadcast_type_C = 3; + } + if (C.dims == 2 && C.w == N && C.h * C.elempack == 1) + { + // 1xN + broadcast_type_C = 4; + } + + // cast to fp32 + { + Mat CT_data; + cast_float16_to_float32(C, CT_data); + C = CT_data; + } + // pre-multiply C with beta + if (beta != 1.f) + { + Mat CT_data; + CT_data.create_like(C, opt.workspace_allocator); + + const int size = C.total() * C.elempack; + for (int i = 0; i < size; i++) + { + CT_data[i] = C[i] * beta; + } + + C = CT_data; + } + } + } + + int out_elempack = 1; + if (opt.use_packing_layout) + { + int outh = output_transpose ? N : M; + out_elempack = outh % 4 == 0 ? 4 : 1; + } + if (output_elempack) + out_elempack = output_elempack; + size_t out_elemsize = 2u * out_elempack; + + Mat& top_blob = top_blobs[0]; + if (output_transpose) + { + if (output_N1M) + top_blob.create(M, 1, N / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + else + top_blob.create(M, N / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + } + else + { + if (output_N1M) + top_blob.create(N, 1, M / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + else + top_blob.create(N, M / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + } + if (top_blob.empty()) + return -100; + + int _nT = nT ? 
nT : opt.num_threads; + if (nT != 0 && opt.num_threads != nT) + { + // force num_threads the same as in create_pipeline + // so we could use pre-packed A/B from the same tile config + NCNN_LOGE("opt.num_threads %d changed, gemm will use load-time value %d", opt.num_threads, nT); + } + + int ret = 0; + if (constantA && constantB) + { + ret = gemm_AT_BT_riscv_fp16s(AT_data, BT_data, C, top_blob, broadcast_type_C, constantM, constantN, constantK, output_transpose, alpha, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); + } + else if (constantA) + { + const Mat& B = bottom_blobs[0]; + ret = gemm_AT_riscv_fp16s(AT_data, B, C, top_blob, broadcast_type_C, constantM, constantK, transB, output_transpose, alpha, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); + } + else if (constantB) + { + const Mat& A = bottom_blobs[0]; + ret = gemm_BT_riscv_fp16s(A, BT_data, C, top_blob, broadcast_type_C, constantN, constantK, transA, output_transpose, alpha, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); + } + else + { + const Mat& A = bottom_blobs[0]; + const Mat& B = bottom_blobs[1]; + ret = gemm_riscv_fp16s(A, B, C, top_blob, broadcast_type_C, transA, transB, output_transpose, alpha, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); + } + + return ret; +} + } // namespace ncnn diff --git a/src/layer/riscv/gemm_riscv.h b/src/layer/riscv/gemm_riscv.h index 6bca092fb1f2..c6acef4d1f28 100644 --- a/src/layer/riscv/gemm_riscv.h +++ b/src/layer/riscv/gemm_riscv.h @@ -26,10 +26,12 @@ class Gemm_riscv : public Gemm virtual int create_pipeline(const Option& opt); - virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +protected: + int create_pipeline_fp16s(const Option& opt); + int forward_fp16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + // public: int nT; size_t vl; diff --git a/src/layer/riscv/riscv_usability.h b/src/layer/riscv/riscv_usability.h index e2824646f871..12950917c0a2 100644 --- a/src/layer/riscv/riscv_usability.h +++ b/src/layer/riscv/riscv_usability.h @@ -311,6 +311,18 @@ static inline void vlseg2e16_v_u16m4(vuint16m4_t* v0, vuint16m4_t* v1, const uin #if __riscv_zfh // f16m1, vsseg.v, 8/4/2 +static inline void vsseg4e16_v_f16mf2(float16_t* base, vfloat16mf2_t v0, vfloat16mf2_t v1, vfloat16mf2_t v2, vfloat16mf2_t v3, size_t vl) +{ + vfloat16mf2x4_t _tmp = vcreate_f16mf2x4(v0, v1, v2, v3); + vsseg4e16_v_f16mf2x4(base, _tmp, vl); +} + +static inline void vsseg2e16_v_f16mf2(float16_t* base, vfloat16mf2_t v0, vfloat16mf2_t v1, size_t vl) +{ + vfloat16mf2x2_t _tmp = vcreate_f16mf2x2(v0, v1); + vsseg2e16_v_f16mf2x2(base, _tmp, vl); +} + static inline void vsseg8e16_v_f16m1(float16_t* base, vfloat16m1_t v0, vfloat16m1_t v1, vfloat16m1_t v2, vfloat16m1_t v3, vfloat16m1_t v4, vfloat16m1_t v5, vfloat16m1_t v6, vfloat16m1_t v7, size_t vl) { vfloat16m1x8_t _tmp = vcreate_f16m1x8(v0, v1, v2, v3, v4, v5, v6, v7); @@ -614,6 +626,42 @@ static inline void transpose8x4_ps(vfloat32m1_t& _r0l, vfloat32m1_t& _r0h, _r3l = vle32_v_f32m1(ptr + 6 * 4, vl); _r3h = vle32_v_f32m1(ptr + 7 * 4, vl); } + +static inline void transpose4x4_f16(vfloat16mf2_t& _r0, vfloat16mf2_t& _r1, vfloat16mf2_t& _r2, vfloat16mf2_t& _r3, size_t vl) +{ + __fp16 tmp[4][4]; + vsse16_v_f16mf2(&tmp[0][0], sizeof(__fp16) * 4, _r0, vl); + vsse16_v_f16mf2(&tmp[0][1], sizeof(__fp16) * 4, _r1, vl); + vsse16_v_f16mf2(&tmp[0][2],
sizeof(__fp16) * 4, _r2, vl); + vsse16_v_f16mf2(&tmp[0][3], sizeof(__fp16) * 4, _r3, vl); + __fp16* ptr = (__fp16*)tmp; + _r0 = vle16_v_f16mf2(ptr + 0 * 4, vl); + _r1 = vle16_v_f16mf2(ptr + 1 * 4, vl); + _r2 = vle16_v_f16mf2(ptr + 2 * 4, vl); + _r3 = vle16_v_f16mf2(ptr + 3 * 4, vl); +} + +static inline void transpose8x8_f16(vfloat16m1_t& _r0, vfloat16m1_t& _r1, vfloat16m1_t& _r2, vfloat16m1_t& _r3, vfloat16m1_t& _r4, vfloat16m1_t& _r5, vfloat16m1_t& _r6, vfloat16m1_t& _r7, size_t vl) +{ + __fp16 tmp[8][8]; + vsse16_v_f16m1(&tmp[0][0], sizeof(__fp16) * 8, _r0, vl); + vsse16_v_f16m1(&tmp[0][1], sizeof(__fp16) * 8, _r1, vl); + vsse16_v_f16m1(&tmp[0][2], sizeof(__fp16) * 8, _r2, vl); + vsse16_v_f16m1(&tmp[0][3], sizeof(__fp16) * 8, _r3, vl); + vsse16_v_f16m1(&tmp[0][4], sizeof(__fp16) * 8, _r4, vl); + vsse16_v_f16m1(&tmp[0][5], sizeof(__fp16) * 8, _r5, vl); + vsse16_v_f16m1(&tmp[0][6], sizeof(__fp16) * 8, _r6, vl); + vsse16_v_f16m1(&tmp[0][7], sizeof(__fp16) * 8, _r7, vl); + __fp16* ptr = (__fp16*)tmp; + _r0 = vle16_v_f16m1(ptr + 0 * 8, vl); + _r1 = vle16_v_f16m1(ptr + 1 * 8, vl); + _r2 = vle16_v_f16m1(ptr + 2 * 8, vl); + _r3 = vle16_v_f16m1(ptr + 3 * 8, vl); + _r4 = vle16_v_f16m1(ptr + 4 * 8, vl); + _r5 = vle16_v_f16m1(ptr + 5 * 8, vl); + _r6 = vle16_v_f16m1(ptr + 6 * 8, vl); + _r7 = vle16_v_f16m1(ptr + 7 * 8, vl); +} #endif #endif // RISCV_USABILITY_H
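A note on the scratch-buffer transposes added to riscv_usability.h above: each source register is written down one column of a small on-stack tile with a strided store, and the transposed rows are then reloaded contiguously. A minimal scalar sketch of the same idea follows; it is illustrative only and not part of the patch, and fp16_t / transpose8x8_f16_ref are placeholder names so the snippet compiles without RVV or __fp16 support.

#include <cstring>

typedef unsigned short fp16_t; // stand-in for __fp16 storage; illustrative only

// Reference for the strided-store / contiguous-load transpose:
// lane i of register c is written to tmp[i][c], so row r of tmp
// holds lane r of every register, i.e. the transposed 8x8 block.
static void transpose8x8_f16_ref(fp16_t r[8][8])
{
    fp16_t tmp[8][8];
    for (int c = 0; c < 8; c++)     // c indexes the source register
        for (int i = 0; i < 8; i++) // i indexes the lane within the register
            tmp[i][c] = r[c][i];    // column write = store with a stride of 8 halfwords
    memcpy(r, tmp, sizeof(tmp));    // row reads = contiguous reloads, 8 halfwords apart
}

Because the scratch tile is 8 halfwords wide, the contiguous reloads in transpose8x8_f16 step by 8 elements per row (ptr + i * 8), while the 4x4 variant steps by 4.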