Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Oct 25, 2022
1 parent be84c30 commit add9ee5
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 113 deletions.
175 changes: 82 additions & 93 deletions src/layer/x86/innerproduct_fp.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,32 +31,21 @@ static void innerproduct_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& w
}
#else // NCNN_RUNTIME_CPU

const int num_input = bottom_blob.w;
const int num_output = top_blob.w;
const int num_input = bottom_blob.w * bottom_blob.elempack;
const int outw = top_blob.w;
const int out_elempack = top_blob.elempack;

const float* bias_data_ptr = bias_data;
NCNN_LOGE("%d %d %d", num_input, outw, out_elempack);

int out_elempack = 1;
#if __SSE2__
if (opt.use_packing_layout)
{
#if __AVX512F__
out_elempack = num_output % 16 == 0 ? 16 : num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1;
#elif __AVX__
out_elempack = num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1;
#else
out_elempack = num_output % 4 == 0 ? 4 : 1;
#endif
}
#endif // __SSE2__
const float* bias_data_ptr = bias_data;

#if __SSE2__
#if __AVX__
#if __AVX512F__
if (out_elempack == 16)
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int p = 0; p < num_output / out_elempack; p++)
for (int p = 0; p < outw; p++)
{
__m512 _sum0 = _mm512_setzero_ps();
__m512 _sum1 = _mm512_setzero_ps();
Expand Down Expand Up @@ -201,7 +190,7 @@ static void innerproduct_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& w
if (out_elempack == 8)
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int p = 0; p < num_output / out_elempack; p++)
for (int p = 0; p < outw; p++)
{
__m256 _sum0 = _mm256_setzero_ps();
__m256 _sum1 = _mm256_setzero_ps();
Expand Down Expand Up @@ -346,7 +335,7 @@ static void innerproduct_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& w
if (out_elempack == 4)
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int p = 0; p < num_output / out_elempack; p++)
for (int p = 0; p < outw; p++)
{
__m128 _sum0 = _mm_setzero_ps();
#if __AVX__
Expand Down Expand Up @@ -497,25 +486,25 @@ static void innerproduct_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& w
{
#if __SSE2__
#if __AVX__
int remain_num_output_start = 0;
int nn_num_output = num_output >> 3;
int remain_outw_start = 0;
int nn_outw = outw >> 3;

#pragma omp parallel for num_threads(opt.num_threads)
for (int pp = 0; pp < nn_num_output; pp++)
for (int pp = 0; pp < nn_outw; pp++)
{
int p = pp * 8;

float sums[8] = {0.0f};
if (bias_data_ptr)
{
sums[0] = bias_data[p];
sums[1] = bias_data[p + 1];
sums[2] = bias_data[p + 2];
sums[3] = bias_data[p + 3];
sums[4] = bias_data[p + 4];
sums[5] = bias_data[p + 5];
sums[6] = bias_data[p + 6];
sums[7] = bias_data[p + 7];
sums[0] = bias_data_ptr[p];
sums[1] = bias_data_ptr[p + 1];
sums[2] = bias_data_ptr[p + 2];
sums[3] = bias_data_ptr[p + 3];
sums[4] = bias_data_ptr[p + 4];
sums[5] = bias_data_ptr[p + 5];
sums[6] = bias_data_ptr[p + 6];
sums[7] = bias_data_ptr[p + 7];
}

#if NCNN_IMPL_FP16S
Expand Down Expand Up @@ -645,25 +634,25 @@ static void innerproduct_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& w
_mm256_storeu_ps(outptr + p, _sums);
}

remain_num_output_start += (nn_num_output << 3);
nn_num_output = (num_output - remain_num_output_start) >> 2;
remain_outw_start += (nn_outw << 3);
nn_outw = (outw - remain_outw_start) >> 2;
#else
int remain_num_output_start = 0;
int nn_num_output = num_output >> 2;
int remain_outw_start = 0;
int nn_outw = outw >> 2;
#endif // __AVX__

#pragma omp parallel for num_threads(opt.num_threads)
for (int pp = 0; pp < nn_num_output; pp++)
for (int pp = 0; pp < nn_outw; pp++)
{
int p = remain_num_output_start + (pp * 4);
int p = remain_outw_start + (pp * 4);

float sums[4] = {0.0f};
if (bias_data_ptr)
{
sums[0] = bias_data[p];
sums[1] = bias_data[p + 1];
sums[2] = bias_data[p + 2];
sums[3] = bias_data[p + 3];
sums[0] = bias_data_ptr[p];
sums[1] = bias_data_ptr[p + 1];
sums[2] = bias_data_ptr[p + 2];
sums[3] = bias_data_ptr[p + 3];
}

#if NCNN_IMPL_FP16S
Expand Down Expand Up @@ -787,18 +776,18 @@ static void innerproduct_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& w
_mm_storeu_ps(outptr + p, _sums);
}

remain_num_output_start += (nn_num_output << 2);
remain_outw_start += (nn_outw << 2);
#else
int remain_num_output_start = 0;
int remain_outw_start = 0;
#endif // __SSE2__

#pragma omp parallel for num_threads(opt.num_threads)
for (int p = remain_num_output_start; p < num_output; p++)
for (int p = remain_outw_start; p < outw; p++)
{
float sum = 0.f;

if (bias_data_ptr)
sum = bias_data[p];
sum = bias_data_ptr[p];

#if NCNN_IMPL_FP16S
const unsigned short* w = weight_data_tm.row<const unsigned short>(p);
Expand Down Expand Up @@ -1042,22 +1031,22 @@ static void innerproduct_transform_kernel_sse(const Mat& weight_data, Mat& weigh
_mm256_storeu_si256((__m256i*)(g0 + 16 * 14), _re);
_mm256_storeu_si256((__m256i*)(g0 + 16 * 15), _rf);
#else
__m512 _r0 = _mm512_load_ps(k0);
__m512 _r1 = _mm512_load_ps(k1);
__m512 _r2 = _mm512_load_ps(k2);
__m512 _r3 = _mm512_load_ps(k3);
__m512 _r4 = _mm512_load_ps(k4);
__m512 _r5 = _mm512_load_ps(k5);
__m512 _r6 = _mm512_load_ps(k6);
__m512 _r7 = _mm512_load_ps(k7);
__m512 _r8 = _mm512_load_ps(k8);
__m512 _r9 = _mm512_load_ps(k9);
__m512 _ra = _mm512_load_ps(ka);
__m512 _rb = _mm512_load_ps(kb);
__m512 _rc = _mm512_load_ps(kc);
__m512 _rd = _mm512_load_ps(kd);
__m512 _re = _mm512_load_ps(ke);
__m512 _rf = _mm512_load_ps(kf);
__m512 _r0 = _mm512_loadu_ps(k0);
__m512 _r1 = _mm512_loadu_ps(k1);
__m512 _r2 = _mm512_loadu_ps(k2);
__m512 _r3 = _mm512_loadu_ps(k3);
__m512 _r4 = _mm512_loadu_ps(k4);
__m512 _r5 = _mm512_loadu_ps(k5);
__m512 _r6 = _mm512_loadu_ps(k6);
__m512 _r7 = _mm512_loadu_ps(k7);
__m512 _r8 = _mm512_loadu_ps(k8);
__m512 _r9 = _mm512_loadu_ps(k9);
__m512 _ra = _mm512_loadu_ps(ka);
__m512 _rb = _mm512_loadu_ps(kb);
__m512 _rc = _mm512_loadu_ps(kc);
__m512 _rd = _mm512_loadu_ps(kd);
__m512 _re = _mm512_loadu_ps(ke);
__m512 _rf = _mm512_loadu_ps(kf);

transpose16_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, _r8, _r9, _ra, _rb, _rc, _rd, _re, _rf);

Expand All @@ -1070,7 +1059,7 @@ static void innerproduct_transform_kernel_sse(const Mat& weight_data, Mat& weigh
_mm512_storeu_ps(g0 + 16 * 6, _r6);
_mm512_storeu_ps(g0 + 16 * 7, _r7);
_mm512_storeu_ps(g0 + 16 * 8, _r8);
_mm512_storeu_ps(g0 + 16 * 9,_r9);
_mm512_storeu_ps(g0 + 16 * 9, _r9);
_mm512_storeu_ps(g0 + 16 * 10, _ra);
_mm512_storeu_ps(g0 + 16 * 11, _rb);
_mm512_storeu_ps(g0 + 16 * 12, _rc);
Expand Down Expand Up @@ -1163,22 +1152,22 @@ static void innerproduct_transform_kernel_sse(const Mat& weight_data, Mat& weigh
_mm256_storeu_si256((__m256i*)(g0 + 16 * 6), _r6e);
_mm256_storeu_si256((__m256i*)(g0 + 16 * 7), _r7f);
#else
__m256 _r0 = _mm256_load_ps(k0);
__m256 _r1 = _mm256_load_ps(k1);
__m256 _r2 = _mm256_load_ps(k2);
__m256 _r3 = _mm256_load_ps(k3);
__m256 _r4 = _mm256_load_ps(k4);
__m256 _r5 = _mm256_load_ps(k5);
__m256 _r6 = _mm256_load_ps(k6);
__m256 _r7 = _mm256_load_ps(k7);
__m256 _r8 = _mm256_load_ps(k8);
__m256 _r9 = _mm256_load_ps(k9);
__m256 _ra = _mm256_load_ps(ka);
__m256 _rb = _mm256_load_ps(kb);
__m256 _rc = _mm256_load_ps(kc);
__m256 _rd = _mm256_load_ps(kd);
__m256 _re = _mm256_load_ps(ke);
__m256 _rf = _mm256_load_ps(kf);
__m256 _r0 = _mm256_loadu_ps(k0);
__m256 _r1 = _mm256_loadu_ps(k1);
__m256 _r2 = _mm256_loadu_ps(k2);
__m256 _r3 = _mm256_loadu_ps(k3);
__m256 _r4 = _mm256_loadu_ps(k4);
__m256 _r5 = _mm256_loadu_ps(k5);
__m256 _r6 = _mm256_loadu_ps(k6);
__m256 _r7 = _mm256_loadu_ps(k7);
__m256 _r8 = _mm256_loadu_ps(k8);
__m256 _r9 = _mm256_loadu_ps(k9);
__m256 _ra = _mm256_loadu_ps(ka);
__m256 _rb = _mm256_loadu_ps(kb);
__m256 _rc = _mm256_loadu_ps(kc);
__m256 _rd = _mm256_loadu_ps(kd);
__m256 _re = _mm256_loadu_ps(ke);
__m256 _rf = _mm256_loadu_ps(kf);

__m256 _tmp0 = _mm256_unpacklo_ps(_r0, _r1);
__m256 _tmp1 = _mm256_unpackhi_ps(_r0, _r1);
Expand Down Expand Up @@ -1397,14 +1386,14 @@ static void innerproduct_transform_kernel_sse(const Mat& weight_data, Mat& weigh
_mm256_storeu_si256((__m256i*)(g0 + 16 * 6), _r6);
_mm256_storeu_si256((__m256i*)(g0 + 16 * 7), _r7);
#else
__m512 _r0 = _mm512_load_ps(k0);
__m512 _r1 = _mm512_load_ps(k1);
__m512 _r2 = _mm512_load_ps(k2);
__m512 _r3 = _mm512_load_ps(k3);
__m512 _r4 = _mm512_load_ps(k4);
__m512 _r5 = _mm512_load_ps(k5);
__m512 _r6 = _mm512_load_ps(k6);
__m512 _r7 = _mm512_load_ps(k7);
__m512 _r0 = _mm512_loadu_ps(k0);
__m512 _r1 = _mm512_loadu_ps(k1);
__m512 _r2 = _mm512_loadu_ps(k2);
__m512 _r3 = _mm512_loadu_ps(k3);
__m512 _r4 = _mm512_loadu_ps(k4);
__m512 _r5 = _mm512_loadu_ps(k5);
__m512 _r6 = _mm512_loadu_ps(k6);
__m512 _r7 = _mm512_loadu_ps(k7);

__m512 _tmp0 = _mm512_unpacklo_ps(_r0, _r1);
__m512 _tmp1 = _mm512_unpackhi_ps(_r0, _r1);
Expand Down Expand Up @@ -1512,14 +1501,14 @@ static void innerproduct_transform_kernel_sse(const Mat& weight_data, Mat& weigh
_mm_storeu_si128((__m128i*)(g0 + 48), _r6);
_mm_storeu_si128((__m128i*)(g0 + 56), _r7);
#else
__m256 _r0 = _mm256_load_ps(k0);
__m256 _r1 = _mm256_load_ps(k1);
__m256 _r2 = _mm256_load_ps(k2);
__m256 _r3 = _mm256_load_ps(k3);
__m256 _r4 = _mm256_load_ps(k4);
__m256 _r5 = _mm256_load_ps(k5);
__m256 _r6 = _mm256_load_ps(k6);
__m256 _r7 = _mm256_load_ps(k7);
__m256 _r0 = _mm256_loadu_ps(k0);
__m256 _r1 = _mm256_loadu_ps(k1);
__m256 _r2 = _mm256_loadu_ps(k2);
__m256 _r3 = _mm256_loadu_ps(k3);
__m256 _r4 = _mm256_loadu_ps(k4);
__m256 _r5 = _mm256_loadu_ps(k5);
__m256 _r6 = _mm256_loadu_ps(k6);
__m256 _r7 = _mm256_loadu_ps(k7);

transpose8_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7);

Expand Down
40 changes: 20 additions & 20 deletions src/layer/x86/innerproduct_gemm_fp.h
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M

if (bias_data_ptr)
{
_sum0 = _mm512_set1_ps(bias_data[p]);
_sum0 = _mm512_set1_ps(bias_data_ptr[p]);
}

int i = 0;
Expand Down Expand Up @@ -506,10 +506,10 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M

if (bias_data_ptr)
{
_sum0 = _mm512_set1_ps(bias_data[p * 4 + 0]);
_sum1 = _mm512_set1_ps(bias_data[p * 4 + 1]);
_sum2 = _mm512_set1_ps(bias_data[p * 4 + 2]);
_sum3 = _mm512_set1_ps(bias_data[p * 4 + 3]);
_sum0 = _mm512_set1_ps(bias_data_ptr[p * 4 + 0]);
_sum1 = _mm512_set1_ps(bias_data_ptr[p * 4 + 1]);
_sum2 = _mm512_set1_ps(bias_data_ptr[p * 4 + 2]);
_sum3 = _mm512_set1_ps(bias_data_ptr[p * 4 + 3]);
}

int i = 0;
Expand Down Expand Up @@ -575,14 +575,14 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M

if (bias_data_ptr)
{
_sum0 = _mm512_set1_ps(bias_data[p * 8 + 0]);
_sum1 = _mm512_set1_ps(bias_data[p * 8 + 1]);
_sum2 = _mm512_set1_ps(bias_data[p * 8 + 2]);
_sum3 = _mm512_set1_ps(bias_data[p * 8 + 3]);
_sum4 = _mm512_set1_ps(bias_data[p * 8 + 4]);
_sum5 = _mm512_set1_ps(bias_data[p * 8 + 5]);
_sum6 = _mm512_set1_ps(bias_data[p * 8 + 6]);
_sum7 = _mm512_set1_ps(bias_data[p * 8 + 7]);
_sum0 = _mm512_set1_ps(bias_data_ptr[p * 8 + 0]);
_sum1 = _mm512_set1_ps(bias_data_ptr[p * 8 + 1]);
_sum2 = _mm512_set1_ps(bias_data_ptr[p * 8 + 2]);
_sum3 = _mm512_set1_ps(bias_data_ptr[p * 8 + 3]);
_sum4 = _mm512_set1_ps(bias_data_ptr[p * 8 + 4]);
_sum5 = _mm512_set1_ps(bias_data_ptr[p * 8 + 5]);
_sum6 = _mm512_set1_ps(bias_data_ptr[p * 8 + 6]);
_sum7 = _mm512_set1_ps(bias_data_ptr[p * 8 + 7]);
}

int i = 0;
Expand Down Expand Up @@ -946,7 +946,7 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M

if (bias_data_ptr)
{
_sum0 = _mm256_set1_ps(bias_data[p]);
_sum0 = _mm256_set1_ps(bias_data_ptr[p]);
}

int i = 0;
Expand Down Expand Up @@ -1021,10 +1021,10 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M

if (bias_data_ptr)
{
_sum0 = _mm256_set1_ps(bias_data[p * 4 + 0]);
_sum1 = _mm256_set1_ps(bias_data[p * 4 + 1]);
_sum2 = _mm256_set1_ps(bias_data[p * 4 + 2]);
_sum3 = _mm256_set1_ps(bias_data[p * 4 + 3]);
_sum0 = _mm256_set1_ps(bias_data_ptr[p * 4 + 0]);
_sum1 = _mm256_set1_ps(bias_data_ptr[p * 4 + 1]);
_sum2 = _mm256_set1_ps(bias_data_ptr[p * 4 + 2]);
_sum3 = _mm256_set1_ps(bias_data_ptr[p * 4 + 3]);
}

int i = 0;
Expand Down Expand Up @@ -1191,7 +1191,7 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M

if (bias_data_ptr)
{
_sum0 = _mm_set1_ps(bias_data[p]);
_sum0 = _mm_set1_ps(bias_data_ptr[p]);
}

int i = 0;
Expand Down Expand Up @@ -1266,7 +1266,7 @@ static void innerproduct_gemm_sse(const Mat& bottom_blob, Mat& top_blob, const M

if (bias_data_ptr)
{
sum = bias_data[p];
sum = bias_data_ptr[p];
}

int i = 0;
Expand Down

0 comments on commit add9ee5

Please sign in to comment.