Skip to content

Commit

Permalink
avx vnni int8, avx vnni int16, avx ne convert infrastructure (#5749)
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui authored Nov 14, 2024
1 parent e71fdf8 commit 9cefe9a
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 0 deletions.
48 changes: 48 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,15 @@ else()
set(CMAKE_REQUIRED_FLAGS "/arch:AVX2")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX2")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX2")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwsud_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX2")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256 _a; __m128bh _s = _mm256_cvtneps_avx_pbh(_a); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX512")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI)

Expand Down Expand Up @@ -534,6 +543,15 @@ else()
set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni -mavxvnniint8")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni -mavxvnniint16")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwsud_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxneconvert")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256 _a; __m128bh _s = _mm256_cvtneps_avx_pbh(_a); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX512 -mfma -mf16c -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512vnni")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI)

Expand All @@ -560,6 +578,15 @@ else()
set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxvnni")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI)

set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxvnni -mavxvnniint8")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8)

set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxvnni -mavxvnniint16")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwsud_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16)

set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxneconvert")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256 _a; __m128bh _s = _mm256_cvtneps_avx_pbh(_a); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT)

set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512vnni")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI)

Expand Down Expand Up @@ -603,9 +630,30 @@ else()
if(NCNN_AVX2)
option(NCNN_AVXVNNI "optimize x86 platform with avx vnni extension" ON)
endif()
if(NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8)
if(NCNN_AVXVNNI)
option(NCNN_AVXVNNIINT8 "optimize x86 platform with avx vnni int8 extension" ON)
endif()
else()
message(WARNING "The compiler does not support avx vnni int8 extension. NCNN_AVXVNNIINT8 will be OFF.")
endif()
if(NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16)
if(NCNN_AVXVNNI)
option(NCNN_AVXVNNIINT16 "optimize x86 platform with avx vnni int16 extension" ON)
endif()
else()
message(WARNING "The compiler does not support avx vnni int16 extension. NCNN_AVXVNNIINT16 will be OFF.")
endif()
else()
message(WARNING "The compiler does not support avx vnni extension. NCNN_AVXVNNI will be OFF.")
endif()
if(NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT)
if(NCNN_AVX2)
option(NCNN_AVXNECONVERT "optimize x86 platform with avx ne convert extension" ON)
endif()
else()
message(WARNING "The compiler does not support avx ne convert extension. NCNN_AVXNECONVERT will be OFF.")
endif()
if(NCNN_COMPILER_SUPPORT_X86_AVX512)
if(NCNN_AVX2)
option(NCNN_AVX512 "optimize x86 platform with avx512 extension" ON)
Expand Down
27 changes: 27 additions & 0 deletions cmake/ncnn_add_layer.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,15 @@ macro(ncnn_add_layer class)
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI)
ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__")
endif()
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT8)
ncnn_add_arch_opt_source(${class} avxvnniint8 "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__ /D__AVXVNNIINT8__")
endif()
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT16)
ncnn_add_arch_opt_source(${class} avxvnniint16 "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__ /D__AVXVNNIINT16__")
endif()
if(NCNN_RUNTIME_CPU AND NCNN_AVXNECONVERT)
ncnn_add_arch_opt_source(${class} avxneconvert "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXNECONVERT__")
endif()
if(NCNN_RUNTIME_CPU AND NCNN_AVX2)
ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__")
endif()
Expand Down Expand Up @@ -187,6 +196,15 @@ macro(ncnn_add_layer class)
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI)
ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 -mfma -mf16c -mavxvnni /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__")
endif()
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT8)
ncnn_add_arch_opt_source(${class} avxvnniint8 "/arch:AVX2 -mfma -mf16c -mavxvnni -mavxvnniint8 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__ /D__AVXVNNIINT8__")
endif()
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT16)
ncnn_add_arch_opt_source(${class} avxvnniint16 "/arch:AVX2 -mfma -mf16c -mavxvnni -mavxvnniint16 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__ /D__AVXVNNIINT16__")
endif()
if(NCNN_RUNTIME_CPU AND NCNN_AVXNECONVERT)
ncnn_add_arch_opt_source(${class} avxneconvert "/arch:AVX2 -mfma -mf16c -mavxneconvert /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXNECONVERT__")
endif()
if(NCNN_RUNTIME_CPU AND NCNN_AVX2)
ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 -mfma -mf16c /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__")
endif()
Expand Down Expand Up @@ -218,6 +236,15 @@ macro(ncnn_add_layer class)
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI)
ncnn_add_arch_opt_source(${class} avxvnni "-mavx2 -mfma -mf16c -mavxvnni")
endif()
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT8)
ncnn_add_arch_opt_source(${class} avxvnniint8 "-mavx2 -mfma -mf16c -mavxvnni -mavxvnniint8")
endif()
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT16)
ncnn_add_arch_opt_source(${class} avxvnniint16 "-mavx2 -mfma -mf16c -mavxvnni -mavxvnniint16")
endif()
if(NCNN_RUNTIME_CPU AND NCNN_AVXNECONVERT)
ncnn_add_arch_opt_source(${class} avxneconvert "-mavx2 -mfma -mf16c -mavxneconvert")
endif()
if(NCNN_RUNTIME_CPU AND NCNN_AVX2)
ncnn_add_arch_opt_source(${class} avx2 "-mavx2 -mfma -mf16c")
endif()
Expand Down
27 changes: 27 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,15 @@ if(NCNN_TARGET_ARCH STREQUAL "x86")
else()
target_compile_options(ncnn PRIVATE /arch:AVX /D__SSSE3__ /D__SSE4_1__ /D__FMA__)
endif()
if(NCNN_AVXVNNIINT8)
target_compile_options(ncnn PRIVATE /D__AVXVNNIINT8__)
endif()
if(NCNN_AVXVNNIINT16)
target_compile_options(ncnn PRIVATE /D__AVXVNNIINT16__)
endif()
if(NCNN_AVXNECONVERT)
target_compile_options(ncnn PRIVATE /D__AVXNECONVERT__)
endif()
if(NCNN_AVXVNNI)
target_compile_options(ncnn PRIVATE /D__AVXVNNI__)
elseif(NCNN_XOP)
Expand All @@ -460,6 +469,15 @@ if(NCNN_TARGET_ARCH STREQUAL "x86")
else()
target_compile_options(ncnn PRIVATE /arch:AVX -mfma /D__SSSE3__ /D__SSE4_1__ /D__FMA__)
endif()
if(NCNN_AVXVNNIINT8)
target_compile_options(ncnn PRIVATE -mavxvnniint8 /D__AVXVNNIINT8__)
endif()
if(NCNN_AVXVNNIINT16)
target_compile_options(ncnn PRIVATE -mavxvnniint16 /D__AVXVNNIINT16__)
endif()
if(NCNN_AVXNECONVERT)
target_compile_options(ncnn PRIVATE -mavxneconvert /D__AVXNECONVERT__)
endif()
if(NCNN_AVXVNNI)
target_compile_options(ncnn PRIVATE -mavxvnni /D__AVXVNNI__)
elseif(NCNN_XOP)
Expand All @@ -474,6 +492,15 @@ if(NCNN_TARGET_ARCH STREQUAL "x86")
else()
target_compile_options(ncnn PRIVATE -mavx -mfma)
endif()
if(NCNN_AVXVNNIINT8)
target_compile_options(ncnn PRIVATE -mavxvnniint8)
endif()
if(NCNN_AVXVNNIINT16)
target_compile_options(ncnn PRIVATE -mavxvnniint16)
endif()
if(NCNN_AVXNECONVERT)
target_compile_options(ncnn PRIVATE -mavxneconvert)
endif()
if(NCNN_AVXVNNI)
target_compile_options(ncnn PRIVATE -mavxvnni)
elseif(NCNN_XOP)
Expand Down
102 changes: 102 additions & 0 deletions src/cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ static int g_cpu_support_x86_xop;
static int g_cpu_support_x86_f16c;
static int g_cpu_support_x86_avx2;
static int g_cpu_support_x86_avx_vnni;
static int g_cpu_support_x86_avx_vnni_int8;
static int g_cpu_support_x86_avx_vnni_int16;
static int g_cpu_support_x86_avx_ne_convert;
static int g_cpu_support_x86_avx512;
static int g_cpu_support_x86_avx512_vnni;
static int g_cpu_support_x86_avx512_bf16;
Expand Down Expand Up @@ -617,6 +620,72 @@ static int get_cpu_support_x86_avx_vnni()
return cpu_info[0] & (1u << 4);
}

static int get_cpu_support_x86_avx_vnni_int8()
{
unsigned int cpu_info[4] = {0};
x86_cpuid(0, cpu_info);

int nIds = cpu_info[0];
if (nIds < 7)
return 0;

x86_cpuid(1, cpu_info);
// check AVX XSAVE OSXSAVE
if (!(cpu_info[2] & (1u << 28)) || !(cpu_info[2] & (1u << 26)) || !(cpu_info[2] & (1u << 27)))
return 0;

// check XSAVE enabled by kernel
if ((x86_get_xcr0() & 6) != 6)
return 0;

x86_cpuid_sublevel(7, 1, cpu_info);
return cpu_info[3] & (1u << 4);
}

static int get_cpu_support_x86_avx_vnni_int16()
{
unsigned int cpu_info[4] = {0};
x86_cpuid(0, cpu_info);

int nIds = cpu_info[0];
if (nIds < 7)
return 0;

x86_cpuid(1, cpu_info);
// check AVX XSAVE OSXSAVE
if (!(cpu_info[2] & (1u << 28)) || !(cpu_info[2] & (1u << 26)) || !(cpu_info[2] & (1u << 27)))
return 0;

// check XSAVE enabled by kernel
if ((x86_get_xcr0() & 6) != 6)
return 0;

x86_cpuid_sublevel(7, 1, cpu_info);
return cpu_info[3] & (1u << 10);
}

static int get_cpu_support_x86_avx_ne_convert()
{
unsigned int cpu_info[4] = {0};
x86_cpuid(0, cpu_info);

int nIds = cpu_info[0];
if (nIds < 7)
return 0;

x86_cpuid(1, cpu_info);
// check AVX XSAVE OSXSAVE
if (!(cpu_info[2] & (1u << 28)) || !(cpu_info[2] & (1u << 26)) || !(cpu_info[2] & (1u << 27)))
return 0;

// check XSAVE enabled by kernel
if ((x86_get_xcr0() & 6) != 6)
return 0;

x86_cpuid_sublevel(7, 1, cpu_info);
return cpu_info[3] & (1u << 5);
}

static int get_cpu_support_x86_avx512()
{
#if __APPLE__
Expand Down Expand Up @@ -1967,6 +2036,9 @@ static void initialize_global_cpu_info()
g_cpu_support_x86_f16c = get_cpu_support_x86_f16c();
g_cpu_support_x86_avx2 = get_cpu_support_x86_avx2();
g_cpu_support_x86_avx_vnni = get_cpu_support_x86_avx_vnni();
g_cpu_support_x86_avx_vnni_int8 = get_cpu_support_x86_avx_vnni_int8();
g_cpu_support_x86_avx_vnni_int16 = get_cpu_support_x86_avx_vnni_int16();
g_cpu_support_x86_avx_ne_convert = get_cpu_support_x86_avx_ne_convert();
g_cpu_support_x86_avx512 = get_cpu_support_x86_avx512();
g_cpu_support_x86_avx512_vnni = get_cpu_support_x86_avx512_vnni();
g_cpu_support_x86_avx512_bf16 = get_cpu_support_x86_avx512_bf16();
Expand Down Expand Up @@ -2489,6 +2561,36 @@ int cpu_support_x86_avx_vnni()
#endif
}

int cpu_support_x86_avx_vnni_int8()
{
try_initialize_global_cpu_info();
#if defined(__i386__) || defined(__x86_64__) || defined(_M_IX86) || defined(_M_X64)
return g_cpu_support_x86_avx_vnni_int8;
#else
return 0;
#endif
}

int cpu_support_x86_avx_vnni_int16()
{
try_initialize_global_cpu_info();
#if defined(__i386__) || defined(__x86_64__) || defined(_M_IX86) || defined(_M_X64)
return g_cpu_support_x86_avx_vnni_int16;
#else
return 0;
#endif
}

int cpu_support_x86_avx_ne_convert()
{
try_initialize_global_cpu_info();
#if defined(__i386__) || defined(__x86_64__) || defined(_M_IX86) || defined(_M_X64)
return g_cpu_support_x86_avx_ne_convert;
#else
return 0;
#endif
}

int cpu_support_x86_avx512()
{
try_initialize_global_cpu_info();
Expand Down
6 changes: 6 additions & 0 deletions src/cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ NCNN_EXPORT int cpu_support_x86_f16c();
NCNN_EXPORT int cpu_support_x86_avx2();
// avx_vnni = x86 avx vnni
NCNN_EXPORT int cpu_support_x86_avx_vnni();
// avx_vnni_int8 = x86 avx vnni int8
NCNN_EXPORT int cpu_support_x86_avx_vnni_int8();
// avx_vnni_int16 = x86 avx vnni int16
NCNN_EXPORT int cpu_support_x86_avx_vnni_int16();
// avx_ne_convert = x86 avx ne convert
NCNN_EXPORT int cpu_support_x86_avx_ne_convert();
// avx512 = x86 avx512f + avx512cd + avx512bw + avx512dq + avx512vl
NCNN_EXPORT int cpu_support_x86_avx512();
// avx512_vnni = x86 avx512 vnni
Expand Down
3 changes: 3 additions & 0 deletions src/platform.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
#cmakedefine01 NCNN_F16C
#cmakedefine01 NCNN_AVX2
#cmakedefine01 NCNN_AVXVNNI
#cmakedefine01 NCNN_AVXVNNIINT8
#cmakedefine01 NCNN_AVXVNNIINT16
#cmakedefine01 NCNN_AVXNECONVERT
#cmakedefine01 NCNN_AVX512
#cmakedefine01 NCNN_AVX512VNNI
#cmakedefine01 NCNN_AVX512BF16
Expand Down

0 comments on commit 9cefe9a

Please sign in to comment.