Skip to content

Commit

Permalink
workaround msvc
Browse files: browse the repository at this point in the history
  • Loading branch information
nihui committed Jun 9, 2024
1 parent fda521f commit 36e53f3
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions src/layer/x86/lstm_int8.h
Original file line number Diff line number Diff line change
Expand Up @@ -1676,11 +1676,14 @@ static void lstm_dynamic_quantize_scale2int8(const float* ptr, int size, float s
__m128 _p = _mm_loadu_ps(ptr);
_p = _mm_mul_ps(_p, _scale);
*(int32_t*)outptr = float2int8_sse(_p);
#ifndef _MSC_VER
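// MSVC has trouble with this in-place +127 offset, so on MSVC it is skipped here and added right before each _mm_dpbusd_epi32 instead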
#if __AVXVNNI__ || __AVX512VNNI__
outptr[0] += 127;
outptr[1] += 127;
outptr[2] += 127;
outptr[3] += 127;
#endif
#endif
ptr += 4;
outptr += 4;
Expand Down Expand Up @@ -1781,10 +1784,17 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
}
for (; i + 3 < num_output; i += 4)
{
#ifdef _MSC_VER
hs[0] = 0;
hs[1] = 0;
hs[2] = 0;
hs[3] = 0;
#else
hs[0] = 127;
hs[1] = 127;
hs[2] = 127;
hs[3] = 127;
#endif
hs += 4;
}
for (; i < num_output; i++)
Expand Down Expand Up @@ -2602,6 +2612,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
__m128i _sum1 = _mm_setzero_si128();
int i = 0;
#if __AVXVNNI__ || __AVX512VNNI__
#ifdef _MSC_VER
__m128i _v127 = _mm_set1_epi8(127);
#endif
#if defined(__x86_64__) || defined(_M_X64)
__m128i _sum2 = _mm_setzero_si128();
__m128i _sum3 = _mm_setzero_si128();
Expand All @@ -2613,6 +2626,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
__m128i _w2 = _mm_loadu_si128((const __m128i*)(kptr + 32));
__m128i _w3 = _mm_loadu_si128((const __m128i*)(kptr + 48));

#ifdef _MSC_VER
_xi = _mm_add_epi8(_xi, _v127);
#endif
_sum0 = _mm_dpbusd_epi32(_sum0, _xi, _w0);
_sum1 = _mm_dpbusd_epi32(_sum1, _xi, _w1);
_sum2 = _mm_dpbusd_epi32(_sum2, _xi, _w2);
Expand All @@ -2637,6 +2653,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
__m128i _w0 = _mm_loadu_si128((const __m128i*)kptr);
__m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16));

#ifdef _MSC_VER
_xi = _mm_add_epi8(_xi, _v127);
#endif
_sum0 = _mm_dpbusd_epi32(_sum0, _xi, _w0);
_sum1 = _mm_dpbusd_epi32(_sum1, _xi, _w1);

Expand All @@ -2652,6 +2671,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
__m128i _xi = _mm_castps_si128(_mm_load1_ps((const float*)(x + i)));
__m128i _w = _mm_loadu_si128((const __m128i*)kptr);

#ifdef _MSC_VER
_xi = _mm_add_epi8(_xi, _v127);
#endif
_lstm_IFOGx0 = _mm_dpbusd_epi32(_lstm_IFOGx0, _xi, _w);

kptr += 16;
Expand Down Expand Up @@ -2819,6 +2841,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
__m128i _w2 = _mm_loadu_si128((const __m128i*)(kptr + 32));
__m128i _w3 = _mm_loadu_si128((const __m128i*)(kptr + 48));

#ifdef _MSC_VER
_h_cont = _mm_add_epi8(_h_cont, _v127);
#endif
_sum0 = _mm_dpbusd_epi32(_sum0, _h_cont, _w0);
_sum1 = _mm_dpbusd_epi32(_sum1, _h_cont, _w1);
_sum2 = _mm_dpbusd_epi32(_sum2, _h_cont, _w2);
Expand All @@ -2843,6 +2868,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
__m128i _w0 = _mm_loadu_si128((const __m128i*)kptr);
__m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16));

#ifdef _MSC_VER
_h_cont = _mm_add_epi8(_h_cont, _v127);
#endif
_sum0 = _mm_dpbusd_epi32(_sum0, _h_cont, _w0);
_sum1 = _mm_dpbusd_epi32(_sum1, _h_cont, _w1);

Expand All @@ -2858,6 +2886,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
__m128i _h_cont = _mm_castps_si128(_mm_load1_ps((const float*)(hs + i)));
__m128i _w = _mm_loadu_si128((const __m128i*)kptr);

#ifdef _MSC_VER
_h_cont = _mm_add_epi8(_h_cont, _v127);
#endif
_lstm_IFOGh0 = _mm_dpbusd_epi32(_lstm_IFOGh0, _h_cont, _w);

kptr += 16;
Expand Down

0 comments on commit 36e53f3

Please sign in to comment.