Skip to content

Commit

Permalink
Xu enable avx512 vnni (#612)
Browse files Browse the repository at this point in the history
* add avx512-vnni compiler check.
* enable avx512_vnni for vec kernels.
* fix build issue.
  • Loading branch information
xuhancn authored Mar 15, 2022
1 parent 59008ea commit b5f7770
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 7 deletions.
14 changes: 13 additions & 1 deletion cmake/Codegen.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,22 @@ if(CXX_AVX512_BF16_FOUND)
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX512") # TODO: CHECK HERE
else(MSVC)
list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -D__AVX512F__ -DCPU_CAPABILITY_AVX512 \
-mavx512f -mavx512bw -mavx512vl -mavx512dq -mavx512bf16 -mfma")
-DCPU_CAPABILITY_AVX512_VNNI -mavx512f -mavx512bw -mavx512vl -mavx512dq -mavx512vnni \
-mavx512bf16 -mfma")
endif(MSVC)
endif(CXX_AVX512_BF16_FOUND)

# Register the AVX512_VNNI capability level when the compiler probe set
# CXX_AVX512_VNNI_FOUND.
#  - CMAKE_CXX_FLAGS gains -DHAVE_AVX512_VNNI_CPU_DEFINITION, which the
#    dispatch stub sources test with #ifdef to expose the AVX512_VNNI slot.
#  - CPU_CAPABILITY_NAMES / CPU_CAPABILITY_FLAGS each gain one entry for this
#    level (presumably consumed pairwise later in this file -- confirm).
# The non-MSVC flags define CPU_CAPABILITY_AVX512_VNNI explicitly, matching the
# AVX512_BF16 block: the vec kernels select their VNNI code path with
# `#ifdef CPU_CAPABILITY_AVX512_VNNI`, so without this define a build at this
# level could silently compile the non-VNNI fallback.
if(CXX_AVX512_VNNI_FOUND)
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_AVX512_VNNI_CPU_DEFINITION")
  list(APPEND CPU_CAPABILITY_NAMES "AVX512_VNNI")
  if(MSVC)
    list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX512") # TODO: CHECK HERE
  else(MSVC)
    # The trailing backslash continues the quoted argument onto the next line.
    list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -D__AVX512F__ -DCPU_CAPABILITY_AVX512 \
-DCPU_CAPABILITY_AVX512_VNNI -mavx512f -mavx512bw -mavx512vl -mavx512dq -mavx512vnni -mfma")
  endif(MSVC)
endif(CXX_AVX512_VNNI_FOUND)

if(CXX_AVX512_FOUND)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_AVX512_CPU_DEFINITION")
list(APPEND CPU_CAPABILITY_NAMES "AVX512")
Expand Down
24 changes: 23 additions & 1 deletion cmake/Modules/FindAVX.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,30 @@ SET(AVX512_CODE "
}
")

# C test program stored in AVX512_VNNI_CODE.
# It broadcasts three byte values into 512-bit integer vectors and calls
# _mm512_dpbusds_epi32 on them, so the program only builds when the compiler
# accepts that AVX512-VNNI intrinsic with the flags under test.
# Presumably consumed by the CHECK_SSE(... "AVX512_VNNI" ...) probes in this
# module -- confirm against the macro definition.
# NOTE(review): whether the probe is only compiled or also executed is decided
# by CHECK_SSE, which is not visible here.
SET(AVX512_VNNI_CODE "
#include <stdint.h>
#include <immintrin.h>
int main() {
char a1 = 1;
char a2 = 2;
char a3 = 0;
__m512i src1 = _mm512_set1_epi8(a1);
__m512i src2 = _mm512_set1_epi8(a2);
__m512i src3 = _mm512_set1_epi8(a3);
// detect avx512_vnni
_mm512_dpbusds_epi32(src3, src1, src2);
return 0;
}
")

SET(AVX512_BF16_CODE "
#include <stdint.h>
#include <immintrin.h>
int main() {
__m512 src;
// detect avx512f and avx512bf16
// detect avx512f and avx512_bf16
_mm512_cvtneps_pbh(src);
return 0;
}
Expand Down Expand Up @@ -79,6 +96,11 @@ CHECK_SSE(CXX "AVX2" " ;-mavx2 -mfma;/arch:AVX2")
CHECK_SSE(C "AVX512" " ;-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma;/arch:AVX512")
CHECK_SSE(CXX "AVX512" " ;-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma;/arch:AVX512")

# Probe the C and C++ compilers for AVX512-VNNI support.
# GCC 9.2 documents the -mavx512vnni option:
# https://gcc.gnu.org/onlinedocs/gcc-9.2.0/gcc/x86-Options.html#x86-Options
# The third argument is a ';'-separated list of candidate flag sets: a blank
# entry, the GCC-style -m flags, and the MSVC-style /arch flag. Presumably
# CHECK_SSE tries them in turn -- confirm against the macro definition.
CHECK_SSE(C "AVX512_VNNI" " ;-mavx512f -mavx512dq -mavx512vl -mavx512bw -mavx512vnni -mfma;/arch:AVX512")
CHECK_SSE(CXX "AVX512_VNNI" " ;-mavx512f -mavx512dq -mavx512vl -mavx512bw -mavx512vnni -mfma;/arch:AVX512")

# GCC supports avx512bf16 starting with version 10.3.
# https://gcc.gnu.org/onlinedocs/gcc-10.3.0/gcc/x86-Options.html#x86-Options
CHECK_SSE(C "AVX512_BF16" " ;-mavx512f -mavx512dq -mavx512vl -mavx512bw -mavx512bf16 -mfma;/arch:AVX512")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ static inline __attribute__((always_inline)) void mul_and_sum_s8x128_to_s32x16(
auto b_3_i = _mm512_cvtepi8_epi16(b_3);
a_0_i = _mm512_madd_epi16(a_0_i, b_0_i);
a_2_i = _mm512_madd_epi16(a_2_i, b_2_i);
#ifdef AVX512_VNNI
#ifdef CPU_CAPABILITY_AVX512_VNNI
a_0_i = _mm512_dpwssd_epi32(a_0_i, a_1_i, b_1_i);
a_2_i = _mm512_dpwssd_epi32(a_2_i, a_3_i, b_3_i);
#else
Expand Down Expand Up @@ -524,7 +524,7 @@ static inline __attribute__((always_inline)) void mul_and_sum_s16x128_to_s32x16(
const __m512i* b16x4) {
auto a_0_i = _mm512_madd_epi16(a16x4[0], b16x4[0]);
auto a_2_i = _mm512_madd_epi16(a16x4[2], b16x4[2]);
#ifdef AVX512_VNNI
#ifdef CPU_CAPABILITY_AVX512_VNNI
a_0_i = _mm512_dpwssd_epi32(a_0_i, a16x4[1], b16x4[1]);
a_2_i = _mm512_dpwssd_epi32(a_2_i, a16x4[3], b16x4[3]);
#else
Expand Down Expand Up @@ -580,7 +580,7 @@ mul_and_sum_s8x128x2_to_s32x16x2(
a0_2_i = _mm512_madd_epi16(a0_2_i, b0_2_i);
a1_0_i = _mm512_madd_epi16(a1_0_i, b1_0_i);
a1_2_i = _mm512_madd_epi16(a1_2_i, b1_2_i);
#ifdef AVX512_VNNI
#ifdef CPU_CAPABILITY_AVX512_VNNI
a0_0_i = _mm512_dpwssd_epi32(a0_0_i, a0_1_i, b0_1_i);
a1_0_i = _mm512_dpwssd_epi32(a1_0_i, a1_1_i, b1_1_i);
a0_2_i = _mm512_dpwssd_epi32(a0_2_i, a0_3_i, b0_3_i);
Expand Down Expand Up @@ -611,7 +611,7 @@ mul_and_sum_s16x128x2_to_s32x16x2(
auto a1_0_i = _mm512_madd_epi16(a1_16x4[0], b1_16x4[0]);
auto a0_2_i = _mm512_madd_epi16(a0_16x4[2], b0_16x4[2]);
auto a1_2_i = _mm512_madd_epi16(a1_16x4[2], b1_16x4[2]);
#ifdef AVX512_VNNI
#ifdef CPU_CAPABILITY_AVX512_VNNI
a0_0_i = _mm512_dpwssd_epi32(a0_0_i, a0_16x4[1], b0_16x4[1]);
a1_0_i = _mm512_dpwssd_epi32(a1_0_i, a1_16x4[1], b1_16x4[1]);
a0_2_i = _mm512_dpwssd_epi32(a0_2_i, a0_16x4[3], b0_16x4[3]);
Expand Down
35 changes: 35 additions & 0 deletions intel_extension_for_pytorch/csrc/dyndisp/DispatchStub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ const char* CPUCapabilityToString(CPUCapability isa) {
return "AVX2";
case CPUCapability::AVX512:
return "AVX512";
case CPUCapability::AVX512_VNNI:
return "AVX512_VNNI";
case CPUCapability::AVX512_BF16:
return "AVX512_BF16";
case CPUCapability::NUM_OPTIONS:
Expand All @@ -35,6 +37,8 @@ CPUCapability _get_highest_cpu_support_isa_level() {
*/
if (CPUFeature::get_instance().isa_level_avx512_bf16()) {
return CPUCapability::AVX512_BF16;
} else if (CPUFeature::get_instance().isa_level_avx512_vnni()) {
return CPUCapability::AVX512_VNNI;
} else if (CPUFeature::get_instance().isa_level_avx512_core()) {
return CPUCapability::AVX512;
}
Expand All @@ -49,6 +53,9 @@ CPUCapability _get_highest_binary_support_isa_level() {
#ifdef HAVE_AVX512_BF16_CPU_DEFINITION
return CPUCapability::AVX512_BF16;
#endif
#ifdef HAVE_AVX512_VNNI_CPU_DEFINITION
return CPUCapability::AVX512_VNNI;
#endif
#ifdef HAVE_AVX512_CPU_DEFINITION
return CPUCapability::AVX512;
#endif
Expand All @@ -75,6 +82,8 @@ static CPUCapability compute_cpu_capability() {
if (envar) {
if (strcmp(envar, "avx512_bf16") == 0) {
manual_setup_isa_level = CPUCapability::AVX512_BF16;
} else if (strcmp(envar, "avx512_vnni") == 0) {
manual_setup_isa_level = CPUCapability::AVX512_VNNI;
} else if (strcmp(envar, "avx512") == 0) {
manual_setup_isa_level = CPUCapability::AVX512;
} else if (strcmp(envar, "avx2") == 0) {
Expand Down Expand Up @@ -113,6 +122,10 @@ void* DispatchStubImpl::get_call_ptr(
,
void* AVX512_BF16
#endif
#ifdef HAVE_AVX512_VNNI_CPU_DEFINITION
,
void* AVX512_VNNI
#endif
#ifdef HAVE_AVX512_CPU_DEFINITION
,
void* AVX512
Expand All @@ -134,6 +147,10 @@ void* DispatchStubImpl::get_call_ptr(
,
AVX512_BF16
#endif
#ifdef HAVE_AVX512_VNNI_CPU_DEFINITION
,
AVX512_VNNI
#endif
#ifdef HAVE_AVX512_CPU_DEFINITION
,
AVX512
Expand Down Expand Up @@ -169,6 +186,10 @@ void* DispatchStubImpl::choose_cpu_impl(
,
void* AVX512_BF16
#endif
#ifdef HAVE_AVX512_VNNI_CPU_DEFINITION
,
void* AVX512_VNNI
#endif
#ifdef HAVE_AVX512_CPU_DEFINITION
,
void* AVX512
Expand All @@ -194,6 +215,20 @@ void* DispatchStubImpl::choose_cpu_impl(
}
}
#endif
#ifdef HAVE_AVX512_VNNI_CPU_DEFINITION
if (capability >= static_cast<int>(CPUCapability::AVX512_VNNI)) {
// Quantization kernels have also been disabled on Windows
// for AVX512 because some of their tests are flaky on Windows.
// Ideally, we should have AVX512 kernels for all kernels.
if (C10_UNLIKELY(!AVX512_VNNI)) {
// dispatch to AVX2, since the AVX512 kernel is missing
TORCH_INTERNAL_ASSERT(AVX2, "DispatchStub: missing AVX2 kernel");
return AVX2;
} else {
return AVX512_VNNI;
}
}
#endif
#ifdef HAVE_AVX512_CPU_DEFINITION
if (capability >= static_cast<int>(CPUCapability::AVX512)) {
// Quantization kernels have also been disabled on Windows
Expand Down
18 changes: 17 additions & 1 deletion intel_extension_for_pytorch/csrc/dyndisp/DispatchStub.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ enum class CPUCapability {
DEFAULT = 0,
AVX2 = 1,
AVX512 = 2,
AVX512_BF16 = 3,
AVX512_VNNI = 3,
AVX512_BF16 = 4,
NUM_OPTIONS
};

Expand All @@ -79,6 +80,10 @@ struct TORCH_API DispatchStubImpl {
,
void* AVX512_BF16
#endif
#ifdef HAVE_AVX512_VNNI_CPU_DEFINITION
,
void* AVX512_VNNI
#endif
#ifdef HAVE_AVX512_CPU_DEFINITION
,
void* AVX512
Expand All @@ -100,6 +105,10 @@ struct TORCH_API DispatchStubImpl {
,
void* AVX512_BF16
#endif
#ifdef HAVE_AVX512_VNNI_CPU_DEFINITION
,
void* AVX512_VNNI
#endif
#ifdef HAVE_AVX512_CPU_DEFINITION
,
void* AVX512
Expand Down Expand Up @@ -140,6 +149,10 @@ struct DispatchStub<rT (*)(Args...), T> {
,
reinterpret_cast<void*>(AVX512_BF16)
#endif
#ifdef HAVE_AVX512_VNNI_CPU_DEFINITION
,
reinterpret_cast<void*>(AVX512_VNNI)
#endif
#ifdef HAVE_AVX512_CPU_DEFINITION
,
reinterpret_cast<void*>(AVX512)
Expand Down Expand Up @@ -170,6 +183,9 @@ struct DispatchStub<rT (*)(Args...), T> {
#ifdef HAVE_AVX512_BF16_CPU_DEFINITION
static FnPtr AVX512_BF16;
#endif
#ifdef HAVE_AVX512_VNNI_CPU_DEFINITION
static FnPtr AVX512_VNNI;
#endif
#ifdef HAVE_AVX512_CPU_DEFINITION
static FnPtr AVX512;
#endif
Expand Down

0 comments on commit b5f7770

Please sign in to comment.