Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Jul 19, 2023
1 parent 0dab232 commit ff913b1
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 175 deletions.
274 changes: 99 additions & 175 deletions src/layer/x86/convolution_packed_int8.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,78 +164,47 @@ static void convolution_transform_kernel_packed_int8(const Mat& kernel, Mat& ker
int p = 0;
for (; p + 15 < inch; p += 16)
{
__m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
_vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(maxk));

for (int k = 0; k < maxk; k++)
{
const signed char* k0 = kptr0 + k;
const signed char* k1 = kptr1 + k;
const signed char* k2 = kptr2 + k;
const signed char* k3 = kptr3 + k;
const signed char* k4 = kptr4 + k;
const signed char* k5 = kptr5 + k;
const signed char* k6 = kptr6 + k;
const signed char* k7 = kptr7 + k;
const signed char* k8 = kptr8 + k;
const signed char* k9 = kptr9 + k;
const signed char* ka = kptra + k;
const signed char* kb = kptrb + k;
const signed char* kc = kptrc + k;
const signed char* kd = kptrd + k;
const signed char* ke = kptre + k;
const signed char* kf = kptrf + k;
__m128i _w0 = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptr0 + k, sizeof(signed char)));
__m128i _w1 = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptr1 + k, sizeof(signed char)));
__m128i _w2 = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptr2 + k, sizeof(signed char)));
__m128i _w3 = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptr3 + k, sizeof(signed char)));
__m128i _w4 = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptr4 + k, sizeof(signed char)));
__m128i _w5 = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptr5 + k, sizeof(signed char)));
__m128i _w6 = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptr6 + k, sizeof(signed char)));
__m128i _w7 = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptr7 + k, sizeof(signed char)));
__m128i _w8 = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptr8 + k, sizeof(signed char)));
__m128i _w9 = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptr9 + k, sizeof(signed char)));
__m128i _wa = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptra + k, sizeof(signed char)));
__m128i _wb = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptrb + k, sizeof(signed char)));
__m128i _wc = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptrc + k, sizeof(signed char)));
__m128i _wd = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptrd + k, sizeof(signed char)));
__m128i _we = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptre + k, sizeof(signed char)));
__m128i _wf = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptrf + k, sizeof(signed char)));

transpose8x16_epi16(_w0, _w1, _w2, _w3, _w4, _w5, _w6, _w7, _w8, _w9, _wa, _wb, _wc, _wd, _we, _wf);

for (int i = 0; i < 8; i++)
{
g00[0] = k0[0];
g00[1] = k0[maxk];
g00[2] = k1[0];
g00[3] = k1[maxk];
g00[4] = k2[0];
g00[5] = k2[maxk];
g00[6] = k3[0];
g00[7] = k3[maxk];
g00[8] = k4[0];
g00[9] = k4[maxk];
g00[10] = k5[0];
g00[11] = k5[maxk];
g00[12] = k6[0];
g00[13] = k6[maxk];
g00[14] = k7[0];
g00[15] = k7[maxk];
g00[16] = k8[0];
g00[17] = k8[maxk];
g00[18] = k9[0];
g00[19] = k9[maxk];
g00[20] = ka[0];
g00[21] = ka[maxk];
g00[22] = kb[0];
g00[23] = kb[maxk];
g00[24] = kc[0];
g00[25] = kc[maxk];
g00[26] = kd[0];
g00[27] = kd[maxk];
g00[28] = ke[0];
g00[29] = ke[maxk];
g00[30] = kf[0];
g00[31] = kf[maxk];

g00 += 32;
k0 += maxk * 2;
k1 += maxk * 2;
k2 += maxk * 2;
k3 += maxk * 2;
k4 += maxk * 2;
k5 += maxk * 2;
k6 += maxk * 2;
k7 += maxk * 2;
k8 += maxk * 2;
k9 += maxk * 2;
ka += maxk * 2;
kb += maxk * 2;
kc += maxk * 2;
kd += maxk * 2;
ke += maxk * 2;
kf += maxk * 2;
}
_mm_storeu_si128((__m128i*)g00, _w0);
_mm_storeu_si128((__m128i*)(g00 + 16), _w1);
_mm_storeu_si128((__m128i*)(g00 + 16 * 2), _w2);
_mm_storeu_si128((__m128i*)(g00 + 16 * 3), _w3);
_mm_storeu_si128((__m128i*)(g00 + 16 * 4), _w4);
_mm_storeu_si128((__m128i*)(g00 + 16 * 5), _w5);
_mm_storeu_si128((__m128i*)(g00 + 16 * 6), _w6);
_mm_storeu_si128((__m128i*)(g00 + 16 * 7), _w7);
_mm_storeu_si128((__m128i*)(g00 + 16 * 8), _w8);
_mm_storeu_si128((__m128i*)(g00 + 16 * 9), _w9);
_mm_storeu_si128((__m128i*)(g00 + 16 * 10), _wa);
_mm_storeu_si128((__m128i*)(g00 + 16 * 11), _wb);
_mm_storeu_si128((__m128i*)(g00 + 16 * 12), _wc);
_mm_storeu_si128((__m128i*)(g00 + 16 * 13), _wd);
_mm_storeu_si128((__m128i*)(g00 + 16 * 14), _we);
_mm_storeu_si128((__m128i*)(g00 + 16 * 15), _wf);
g00 += 256;
}

kptr0 += maxk * 16;
Expand All @@ -257,78 +226,48 @@ static void convolution_transform_kernel_packed_int8(const Mat& kernel, Mat& ker
}
for (; p + 7 < inch; p += 8)
{
__m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
_vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(maxk));

for (int k = 0; k < maxk; k++)
{
const signed char* k0 = kptr0 + k;
const signed char* k1 = kptr1 + k;
const signed char* k2 = kptr2 + k;
const signed char* k3 = kptr3 + k;
const signed char* k4 = kptr4 + k;
const signed char* k5 = kptr5 + k;
const signed char* k6 = kptr6 + k;
const signed char* k7 = kptr7 + k;
const signed char* k8 = kptr8 + k;
const signed char* k9 = kptr9 + k;
const signed char* ka = kptra + k;
const signed char* kb = kptrb + k;
const signed char* kc = kptrc + k;
const signed char* kd = kptrd + k;
const signed char* ke = kptre + k;
const signed char* kf = kptrf + k;

for (int i = 0; i < 4; i++)
{
g00[0] = k0[0];
g00[1] = k0[maxk];
g00[2] = k1[0];
g00[3] = k1[maxk];
g00[4] = k2[0];
g00[5] = k2[maxk];
g00[6] = k3[0];
g00[7] = k3[maxk];
g00[8] = k4[0];
g00[9] = k4[maxk];
g00[10] = k5[0];
g00[11] = k5[maxk];
g00[12] = k6[0];
g00[13] = k6[maxk];
g00[14] = k7[0];
g00[15] = k7[maxk];
g00[16] = k8[0];
g00[17] = k8[maxk];
g00[18] = k9[0];
g00[19] = k9[maxk];
g00[20] = ka[0];
g00[21] = ka[maxk];
g00[22] = kb[0];
g00[23] = kb[maxk];
g00[24] = kc[0];
g00[25] = kc[maxk];
g00[26] = kd[0];
g00[27] = kd[maxk];
g00[28] = ke[0];
g00[29] = ke[maxk];
g00[30] = kf[0];
g00[31] = kf[maxk];

g00 += 32;
k0 += maxk * 2;
k1 += maxk * 2;
k2 += maxk * 2;
k3 += maxk * 2;
k4 += maxk * 2;
k5 += maxk * 2;
k6 += maxk * 2;
k7 += maxk * 2;
k8 += maxk * 2;
k9 += maxk * 2;
ka += maxk * 2;
kb += maxk * 2;
kc += maxk * 2;
kd += maxk * 2;
ke += maxk * 2;
kf += maxk * 2;
}
__m128i _w0 = _mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(kptr0 + k), _vindex, sizeof(signed char)));
__m128i _w1 = _mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(kptr1 + k), _vindex, sizeof(signed char)));
__m128i _w2 = _mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(kptr2 + k), _vindex, sizeof(signed char)));
__m128i _w3 = _mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(kptr3 + k), _vindex, sizeof(signed char)));
__m128i _w4 = _mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(kptr4 + k), _vindex, sizeof(signed char)));
__m128i _w5 = _mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(kptr5 + k), _vindex, sizeof(signed char)));
__m128i _w6 = _mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(kptr6 + k), _vindex, sizeof(signed char)));
__m128i _w7 = _mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(kptr7 + k), _vindex, sizeof(signed char)));
__m128i _w8 = _mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(kptr8 + k), _vindex, sizeof(signed char)));
__m128i _w9 = _mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(kptr9 + k), _vindex, sizeof(signed char)));
__m128i _wa = _mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(kptra + k), _vindex, sizeof(signed char)));
__m128i _wb = _mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(kptrb + k), _vindex, sizeof(signed char)));
__m128i _wc = _mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(kptrc + k), _vindex, sizeof(signed char)));
__m128i _wd = _mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(kptrd + k), _vindex, sizeof(signed char)));
__m128i _we = _mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(kptre + k), _vindex, sizeof(signed char)));
__m128i _wf = _mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(kptrf + k), _vindex, sizeof(signed char)));

__m128i _w08 = _mm_unpacklo_epi64(_w0, _w8);
__m128i _w19 = _mm_unpacklo_epi64(_w1, _w9);
__m128i _w2a = _mm_unpacklo_epi64(_w2, _wa);
__m128i _w3b = _mm_unpacklo_epi64(_w3, _wb);
__m128i _w4c = _mm_unpacklo_epi64(_w4, _wc);
__m128i _w5d = _mm_unpacklo_epi64(_w5, _wd);
__m128i _w6e = _mm_unpacklo_epi64(_w6, _we);
__m128i _w7f = _mm_unpacklo_epi64(_w7, _wf);

transpose8x8_epi16(_w08, _w19, _w2a, _w3b, _w4c, _w5d, _w6e, _w7f);

_mm_storeu_si128((__m128i*)g00, _w08);
_mm_storeu_si128((__m128i*)(g00 + 16), _w4c);
_mm_storeu_si128((__m128i*)(g00 + 16 * 2), _w19);
_mm_storeu_si128((__m128i*)(g00 + 16 * 3), _w5d);
_mm_storeu_si128((__m128i*)(g00 + 16 * 4), _w2a);
_mm_storeu_si128((__m128i*)(g00 + 16 * 5), _w6e);
_mm_storeu_si128((__m128i*)(g00 + 16 * 6), _w3b);
_mm_storeu_si128((__m128i*)(g00 + 16 * 7), _w7f);
g00 += 128;
}

kptr0 += maxk * 8;
Expand Down Expand Up @@ -402,46 +341,31 @@ static void convolution_transform_kernel_packed_int8(const Mat& kernel, Mat& ker
#if __AVX512F__
for (; p + 15 < inch; p += 16)
{
__m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
_vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(maxk));

for (int k = 0; k < maxk; k++)
{
const signed char* k0 = kptr0 + k;
const signed char* k1 = kptr1 + k;
const signed char* k2 = kptr2 + k;
const signed char* k3 = kptr3 + k;
const signed char* k4 = kptr4 + k;
const signed char* k5 = kptr5 + k;
const signed char* k6 = kptr6 + k;
const signed char* k7 = kptr7 + k;
__m128i _w0 = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptr0 + k, sizeof(signed char)));
__m128i _w1 = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptr1 + k, sizeof(signed char)));
__m128i _w2 = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptr2 + k, sizeof(signed char)));
__m128i _w3 = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptr3 + k, sizeof(signed char)));
__m128i _w4 = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptr4 + k, sizeof(signed char)));
__m128i _w5 = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptr5 + k, sizeof(signed char)));
__m128i _w6 = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptr6 + k, sizeof(signed char)));
__m128i _w7 = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, kptr7 + k, sizeof(signed char)));

for (int i = 0; i < 8; i++)
{
g00[0] = k0[0];
g00[1] = k0[maxk];
g00[2] = k1[0];
g00[3] = k1[maxk];
g00[4] = k2[0];
g00[5] = k2[maxk];
g00[6] = k3[0];
g00[7] = k3[maxk];
g00[8] = k4[0];
g00[9] = k4[maxk];
g00[10] = k5[0];
g00[11] = k5[maxk];
g00[12] = k6[0];
g00[13] = k6[maxk];
g00[14] = k7[0];
g00[15] = k7[maxk];
transpose8x8_epi16(_w0, _w1, _w2, _w3, _w4, _w5, _w6, _w7);

g00 += 16;
k0 += maxk * 2;
k1 += maxk * 2;
k2 += maxk * 2;
k3 += maxk * 2;
k4 += maxk * 2;
k5 += maxk * 2;
k6 += maxk * 2;
k7 += maxk * 2;
}
_mm_storeu_si128((__m128i*)g00, _w0);
_mm_storeu_si128((__m128i*)(g00 + 16), _w1);
_mm_storeu_si128((__m128i*)(g00 + 16 * 2), _w2);
_mm_storeu_si128((__m128i*)(g00 + 16 * 3), _w3);
_mm_storeu_si128((__m128i*)(g00 + 16 * 4), _w4);
_mm_storeu_si128((__m128i*)(g00 + 16 * 5), _w5);
_mm_storeu_si128((__m128i*)(g00 + 16 * 6), _w6);
_mm_storeu_si128((__m128i*)(g00 + 16 * 7), _w7);
g00 += 128;
}

kptr0 += maxk * 16;
Expand Down
54 changes: 54 additions & 0 deletions src/layer/x86/x86_usability.h
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,60 @@ static NCNN_FORCEINLINE void transpose16x8_epi16(__m256i& _r0, __m256i& _r1, __m
_r7 = _mm256_permute2x128_si256(_tmp6, _tmp7, _MM_SHUFFLE(0, 3, 0, 1));
}

static NCNN_FORCEINLINE void transpose8x16_epi16(__m128i& _r0, __m128i& _r1, __m128i& _r2, __m128i& _r3, __m128i& _r4, __m128i& _r5, __m128i& _r6, __m128i& _r7, __m128i& _r8, __m128i& _r9, __m128i& _ra, __m128i& _rb, __m128i& _rc, __m128i& _rd, __m128i& _re, __m128i& _rf)
{
__m128i _tmp0 = _mm_unpacklo_epi16(_r0, _r1);
__m128i _tmp1 = _mm_unpackhi_epi16(_r0, _r1);
__m128i _tmp2 = _mm_unpacklo_epi16(_r2, _r3);
__m128i _tmp3 = _mm_unpackhi_epi16(_r2, _r3);
__m128i _tmp4 = _mm_unpacklo_epi16(_r4, _r5);
__m128i _tmp5 = _mm_unpackhi_epi16(_r4, _r5);
__m128i _tmp6 = _mm_unpacklo_epi16(_r6, _r7);
__m128i _tmp7 = _mm_unpackhi_epi16(_r6, _r7);
__m128i _tmp8 = _mm_unpacklo_epi16(_r8, _r9);
__m128i _tmp9 = _mm_unpackhi_epi16(_r8, _r9);
__m128i _tmpa = _mm_unpacklo_epi16(_ra, _rb);
__m128i _tmpb = _mm_unpackhi_epi16(_ra, _rb);
__m128i _tmpc = _mm_unpacklo_epi16(_rc, _rd);
__m128i _tmpd = _mm_unpackhi_epi16(_rc, _rd);
__m128i _tmpe = _mm_unpacklo_epi16(_re, _rf);
__m128i _tmpf = _mm_unpackhi_epi16(_re, _rf);

__m128i _tmpg = _mm_unpacklo_epi32(_tmp0, _tmp2);
__m128i _tmph = _mm_unpackhi_epi32(_tmp0, _tmp2);
__m128i _tmpi = _mm_unpacklo_epi32(_tmp1, _tmp3);
__m128i _tmpj = _mm_unpackhi_epi32(_tmp1, _tmp3);
__m128i _tmpk = _mm_unpacklo_epi32(_tmp4, _tmp6);
__m128i _tmpl = _mm_unpackhi_epi32(_tmp4, _tmp6);
__m128i _tmpm = _mm_unpacklo_epi32(_tmp5, _tmp7);
__m128i _tmpn = _mm_unpackhi_epi32(_tmp5, _tmp7);
__m128i _tmpo = _mm_unpacklo_epi32(_tmp8, _tmpa);
__m128i _tmpp = _mm_unpackhi_epi32(_tmp8, _tmpa);
__m128i _tmpq = _mm_unpacklo_epi32(_tmp9, _tmpb);
__m128i _tmpr = _mm_unpackhi_epi32(_tmp9, _tmpb);
__m128i _tmps = _mm_unpacklo_epi32(_tmpc, _tmpe);
__m128i _tmpt = _mm_unpackhi_epi32(_tmpc, _tmpe);
__m128i _tmpu = _mm_unpacklo_epi32(_tmpd, _tmpf);
__m128i _tmpv = _mm_unpackhi_epi32(_tmpd, _tmpf);

_r0 = _mm_unpacklo_epi64(_tmpg, _tmpk);
_r1 = _mm_unpacklo_epi64(_tmpo, _tmps);
_r2 = _mm_unpackhi_epi64(_tmpg, _tmpk);
_r3 = _mm_unpackhi_epi64(_tmpo, _tmps);
_r4 = _mm_unpacklo_epi64(_tmph, _tmpl);
_r5 = _mm_unpacklo_epi64(_tmpp, _tmpt);
_r6 = _mm_unpackhi_epi64(_tmph, _tmpl);
_r7 = _mm_unpackhi_epi64(_tmpp, _tmpt);
_r8 = _mm_unpacklo_epi64(_tmpi, _tmpm);
_r9 = _mm_unpacklo_epi64(_tmpq, _tmpu);
_ra = _mm_unpackhi_epi64(_tmpi, _tmpm);
_rb = _mm_unpackhi_epi64(_tmpq, _tmpu);
_rc = _mm_unpacklo_epi64(_tmpj, _tmpn);
_rd = _mm_unpacklo_epi64(_tmpr, _tmpv);
_re = _mm_unpackhi_epi64(_tmpj, _tmpn);
_rf = _mm_unpackhi_epi64(_tmpr, _tmpv);
}

static NCNN_FORCEINLINE float _mm512_comp_reduce_add_ps(__m512 x)
{
const __m256 x256 = _mm256_add_ps(_mm512_castps512_ps256(x), _mm512_extractf32x8_ps(x, 1));
Expand Down

0 comments on commit ff913b1

Please sign in to comment.