diff --git a/cmake/developer_package/compile_flags/os_flags.cmake b/cmake/developer_package/compile_flags/os_flags.cmake
index fdfd7211c8e815..dae0368638f27b 100644
--- a/cmake/developer_package/compile_flags/os_flags.cmake
+++ b/cmake/developer_package/compile_flags/os_flags.cmake
@@ -4,6 +4,8 @@
 include(ProcessorCount)
 include(CheckCXXCompilerFlag)
+include(CheckCSourceCompiles)
+include(CheckCXXSourceCompiles)
 
 #
 # ov_disable_deprecated_warnings()
@@ -91,6 +93,53 @@ macro(ov_dev_package_no_errors)
     set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${ov_c_cxx_dev_no_errors}")
 endmacro()
 
+#
+# ov_check_compiler_supports_sve(lang flags)
+#
+# Checks whether the compiler for the given language supports SVE code compilation
+#
+macro(ov_check_compiler_supports_sve lang flags)
+    # Test program to compile
+    set(SVE_CODE "
+        #include <arm_sve.h>
+        int main() {
+            svfloat64_t a;
+            a = svdup_n_f64(0);
+            return 0;
+        }")
+
+    # Save the current state of required flags
+    set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
+
+    # Set the flags necessary for compiling the test code with SVE support
+    set(CMAKE_REQUIRED_FLAGS "${CMAKE_${lang}_FLAGS_INIT} ${flags}")
+
+    # Check if the source code compiles with the given flags for the specified language (C or C++)
+    if(lang STREQUAL "CXX")
+        CHECK_CXX_SOURCE_COMPILES("${SVE_CODE}" ${lang}_HAS_SVE)
+    else()
+        CHECK_C_SOURCE_COMPILES("${SVE_CODE}" ${lang}_HAS_SVE)
+    endif()
+
+    # If the compilation test succeeds, set variables indicating support
+    if(${lang}_HAS_SVE)
+        set(${lang}_SVE_FOUND TRUE CACHE BOOL "${lang} SVE support")
+        set(${lang}_SVE_FLAGS "${flags}" CACHE STRING "${lang} SVE flags")
+    endif()
+
+    # Restore the original state of required flags
+    set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
+
+    # If the compilation test fails, indicate that support was not found
+    if(NOT ${lang}_SVE_FOUND)
+        set(${lang}_SVE_FOUND FALSE CACHE BOOL "${lang} SVE support")
+        set(${lang}_SVE_FLAGS "" CACHE STRING "${lang} SVE flags")
+    endif()
+
+    # Mark the variables as advanced to hide them in the default CMake GUI
+    mark_as_advanced(${lang}_SVE_FOUND ${lang}_SVE_FLAGS)
+endmacro()
+
 #
 # ov_sse42_optimization_flags()
 #
@@ -208,6 +258,50 @@ macro(ov_arm_neon_fp16_optimization_flags flags)
     endif()
 endmacro()
 
+#
+# ov_arm_sve_optimization_flags(flags)
+#
+macro(ov_arm_sve_optimization_flags flags)
+    # Check for compiler SVE support
+    ov_check_compiler_supports_sve(CXX "-march=armv8-a+sve")
+    ov_check_compiler_supports_sve(C "-march=armv8-a+sve")
+
+    if(OV_COMPILER_IS_INTEL_LLVM)
+        message(WARNING "Unsupported CXX compiler ${CMAKE_CXX_COMPILER_ID}")
+    elseif(CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
+        # nothing should be required here
+    elseif(ANDROID)
+        if(ANDROID_ABI STREQUAL "arm64-v8a")
+            set(${flags} -Wno-unused-command-line-argument)
+            if(CXX_SVE_FOUND AND C_SVE_FOUND)
+                list(APPEND ${flags} -march=armv8-a+sve)
+            else()
+                message(WARNING "SVE is not supported by the compiler for Android ABI: ${ANDROID_ABI}")
+            endif()
+        else()
+            message(WARNING "SVE is not supported on this Android ABI: ${ANDROID_ABI}")
+        endif()
+    else()
+        if(AARCH64)
+            set(${flags} -O2)
+
+            # Add the SVE march flag if the compiler supports it
+            if(CXX_SVE_FOUND AND C_SVE_FOUND)
+                list(APPEND ${flags} -march=armv8-a+sve)
+            endif()
+            if(NOT CMAKE_CL_64)
+                list(APPEND ${flags} -ftree-vectorize)
+            endif()
+
+            set(${flags} ${${flags}})
+        elseif(ARM)
+            message(WARNING "SVE is not supported on 32-bit ARM architectures.")
+        else()
+            message(WARNING "SVE is not supported by architecture ${CMAKE_SYSTEM_PROCESSOR}")
+        endif()
+    endif()
+endmacro()
+
 #
 # ov_disable_all_warnings()
 #
${CMAKE_SYSTEM_PROCESSOR}") + endif() + endif() +endmacro() + # # ov_disable_all_warnings() # diff --git a/cmake/developer_package/cross_compile/cross_compiled_disp_gen.cmake b/cmake/developer_package/cross_compile/cross_compiled_disp_gen.cmake index c33d64635eb10b..fd534f3e600bfe 100644 --- a/cmake/developer_package/cross_compile/cross_compiled_disp_gen.cmake +++ b/cmake/developer_package/cross_compile/cross_compiled_disp_gen.cmake @@ -18,6 +18,7 @@ set(_CPU_CHECK_ANY "true") set(_CPU_CHECK_SSE42 "with_cpu_x86_sse42()") set(_CPU_CHECK_AVX "with_cpu_x86_avx()") set(_CPU_CHECK_NEON_FP16 "with_cpu_neon_fp16()") +set(_CPU_CHECK_SVE "with_cpu_sve()") set(_CPU_CHECK_AVX2 "with_cpu_x86_avx2()") set(_CPU_CHECK_AVX512F "with_cpu_x86_avx512f()") diff --git a/cmake/developer_package/cross_compile/cross_compiled_func.cmake b/cmake/developer_package/cross_compile/cross_compiled_func.cmake index 1e92fe3bfdaf8c..962aa5d373a4db 100644 --- a/cmake/developer_package/cross_compile/cross_compiled_func.cmake +++ b/cmake/developer_package/cross_compile/cross_compiled_func.cmake @@ -3,7 +3,7 @@ # ## list of available instruction sets -set(_ARCH_LIST ANY SSE42 AVX AVX2 AVX512F NEON_FP16) +set(_ARCH_LIST ANY SSE42 AVX AVX2 AVX512F NEON_FP16 SVE) set(_ACCEPTED_ARCHS_ANY "^(ANY)$") set(_ACCEPTED_ARCHS_SSE42 "^(ANY|SSE42)$") @@ -11,6 +11,7 @@ set(_ACCEPTED_ARCHS_AVX "^(ANY|SSE42|AVX)$") set(_ACCEPTED_ARCHS_AVX2 "^(ANY|SSE42|AVX|AVX2)$") set(_ACCEPTED_ARCHS_AVX512F "^(ANY|SSE42|AVX|AVX2|AVX512F)$") set(_ACCEPTED_ARCHS_NEON_FP16 "^(ANY|NEON_FP16)$") +set(_ACCEPTED_ARCHS_SVE "^(ANY|SVE)$") ## Arch specific definitions set(_DEFINE_ANY "") @@ -19,12 +20,14 @@ set(_DEFINE_AVX "HAVE_AVX" ${_DEFINE_SSE42}) set(_DEFINE_AVX2 "HAVE_AVX2" ${_DEFINE_AVX}) set(_DEFINE_AVX512F "HAVE_AVX512F" ${_DEFINE_AVX2}) set(_DEFINE_NEON_FP16 "HAVE_NEON_FP16" ${_DEFINE_ANY}) +set(_DEFINE_SVE "HAVE_SVE" ${_DEFINE_SVE}) ## Arch specific compile options ov_avx512_optimization_flags(_FLAGS_AVX512F) ov_avx2_optimization_flags (_FLAGS_AVX2) ov_sse42_optimization_flags (_FLAGS_SSE42) ov_arm_neon_fp16_optimization_flags(_FLAGS_NEON_FP16) +ov_arm_sve_optimization_flags(_FLAGS_SVE) set(_FLAGS_AVX "") ## TBD is not defined for OV project yet set(_FLAGS_ANY "") ## @@ -185,6 +188,8 @@ endfunction() function(_currently_requested_top_arch VAR) if(ENABLE_NEON_FP16) set(RES NEON_FP16) + elseif(ENABLE_SVE) + set(RES SVE) elseif(ENABLE_AVX512F) set(RES AVX512F) elseif(ENABLE_AVX2) diff --git a/cmake/developer_package/features.cmake b/cmake/developer_package/features.cmake index 8d1f3696c6759c..ae5313cea8a8b4 100644 --- a/cmake/developer_package/features.cmake +++ b/cmake/developer_package/features.cmake @@ -51,6 +51,8 @@ ov_dependent_option (ENABLE_AVX512F "Enable AVX512 optimizations" ON "X86_64 OR ov_dependent_option(ENABLE_NEON_FP16 "Enable ARM FP16 optimizations" ON "AARCH64" OFF) +ov_dependent_option(ENABLE_SVE "Enable SVE optimizations" ON "AARCH64" OFF) + # Type of build, we add this as an explicit option to default it to ON get_property(BUILD_SHARED_LIBS_DEFAULT GLOBAL PROPERTY TARGET_SUPPORTS_SHARED_LIBS) ov_option (BUILD_SHARED_LIBS "Build as a shared library" ${BUILD_SHARED_LIBS_DEFAULT}) diff --git a/src/inference/dev_api/openvino/runtime/system_conf.hpp b/src/inference/dev_api/openvino/runtime/system_conf.hpp index 59d56dfdd49d73..bebc2014ab8028 100644 --- a/src/inference/dev_api/openvino/runtime/system_conf.hpp +++ b/src/inference/dev_api/openvino/runtime/system_conf.hpp @@ -83,6 +83,13 @@ OPENVINO_RUNTIME_API bool with_cpu_x86_sse42(); */ 
diff --git a/src/inference/dev_api/openvino/runtime/system_conf.hpp b/src/inference/dev_api/openvino/runtime/system_conf.hpp
index 59d56dfdd49d73..bebc2014ab8028 100644
--- a/src/inference/dev_api/openvino/runtime/system_conf.hpp
+++ b/src/inference/dev_api/openvino/runtime/system_conf.hpp
@@ -83,6 +83,13 @@ OPENVINO_RUNTIME_API bool with_cpu_x86_sse42();
  */
 OPENVINO_RUNTIME_API bool with_cpu_neon_fp16();
 
+/**
+ * @brief Checks whether CPU supports ARM SVE capability
+ * @ingroup ov_dev_api_system_conf
+ * @return `true` if ARM SVE instructions are available, `false` otherwise
+ */
+OPENVINO_RUNTIME_API bool with_cpu_sve();
+
 /**
  * @brief Checks whether CPU supports AVX capability
  * @ingroup ov_dev_api_system_conf
diff --git a/src/inference/src/system_conf.cpp b/src/inference/src/system_conf.cpp
index 27c671d07ad5c9..3227b1a3034903 100644
--- a/src/inference/src/system_conf.cpp
+++ b/src/inference/src/system_conf.cpp
@@ -22,6 +22,7 @@
 #    include <sys/auxv.h>
 #    define ARM_COMPUTE_CPU_FEATURE_HWCAP_FPHP    (1 << 9)
 #    define ARM_COMPUTE_CPU_FEATURE_HWCAP_ASIMDHP (1 << 10)
+#    define ARM_COMPUTE_CPU_FEATURE_HWCAP_SVE     (1 << 24)
 #elif defined(__APPLE__) && defined(__aarch64__)
 #    include <sys/sysctl.h>
 #    include <sys/types.h>
@@ -114,6 +115,10 @@ bool with_cpu_neon_fp16() {
     return false;
 }
 
+bool with_cpu_sve() {
+    return false;
+}
+
 #else  // OPENVINO_ARCH_X86 || OPENVINO_ARCH_X86_64
 
 bool with_cpu_x86_sse42() {
@@ -173,6 +178,20 @@ bool with_cpu_neon_fp16() {
     return false;
 #    endif
 }
+bool with_cpu_sve() {
+#    if !defined(_WIN64) && !defined(BARE_METAL) && !defined(__APPLE__) && !defined(__OpenBSD__) && \
+        !defined(__arm__) && defined(__aarch64__)
+    const uint32_t hwcaps = getauxval(AT_HWCAP);
+    return hwcaps & ARM_COMPUTE_CPU_FEATURE_HWCAP_SVE;
+#    elif !defined(_WIN64) && !defined(BARE_METAL) && !defined(__APPLE__) && !defined(__OpenBSD__) && \
+        !defined(__aarch64__) && defined(__arm__)
+    return false;
+#    elif defined(__aarch64__) && defined(__APPLE__)
+    return false;
+#    else
+    return false;
+#    endif
+}
 
 #endif  // OPENVINO_ARCH_X86 || OPENVINO_ARCH_X86_64
 
 bool check_open_mp_env_vars(bool include_omp_num_threads) {
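For orientation: once `SVE` is registered in `_CPU_CHECK_SVE` above, the dispatcher that cross_compiled_disp_gen.cmake generates guards the SVE build of each kernel with this new runtime check. Roughly, in a minimal C++ sketch (the declarations below are illustrative stand-ins, not the literal generated code; the real check is `ov::with_cpu_sve()` from system_conf.hpp):

    #include <cstddef>

    bool with_cpu_sve();  // declared in openvino/runtime/system_conf.hpp (this patch)

    // Stand-ins for the per-arch builds of one cross-compiled function:
    namespace SVE { inline void attn_softmax(float*, std::size_t) { /* built with -march=armv8-a+sve, -DHAVE_SVE */ } }
    namespace ANY { inline void attn_softmax(float*, std::size_t) { /* portable fallback */ } }

    inline void attn_softmax(float* a, std::size_t n) {
        if (with_cpu_sve())                // expansion of _CPU_CHECK_SVE
            return SVE::attn_softmax(a, n);
        return ANY::attn_softmax(a, n);    // _CPU_CHECK_ANY is simply "true"
    }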
diff --git a/src/plugins/intel_cpu/CMakeLists.txt b/src/plugins/intel_cpu/CMakeLists.txt
index 04909c7d8f5a5a..3aac6f98fb3a80 100644
--- a/src/plugins/intel_cpu/CMakeLists.txt
+++ b/src/plugins/intel_cpu/CMakeLists.txt
@@ -277,6 +277,65 @@ target_include_directories(${TARGET_NAME} SYSTEM PRIVATE $)
 
+# Checking for SVE support in CXX and C
+INCLUDE(CheckCSourceCompiles)
+INCLUDE(CheckCXXSourceCompiles)
+
+SET(SVE_CODE "
+#include <arm_sve.h>
+int main() {
+    svfloat64_t a;
+    a = svdup_n_f64(0);
+    return 0;
+}
+")
+
+MACRO(CHECK_SVE lang flags)
+    # Save the current state of required flags
+    SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
+
+    # Set the flags necessary for compiling the test code with SVE support
+    SET(CMAKE_REQUIRED_FLAGS "${CMAKE_${lang}_FLAGS_INIT} ${flags}")
+
+    # Check if the source code compiles with the given flags for the specified language (C or C++)
+    IF(lang STREQUAL "CXX")
+        CHECK_CXX_SOURCE_COMPILES("${SVE_CODE}" ${lang}_HAS_SVE)
+    ELSE()
+        CHECK_C_SOURCE_COMPILES("${SVE_CODE}" ${lang}_HAS_SVE)
+    ENDIF()
+
+    # If the compilation test succeeds, set variables indicating support
+    IF(${lang}_HAS_SVE)
+        SET(${lang}_SVE_FOUND TRUE CACHE BOOL "${lang} SVE support")
+        SET(${lang}_SVE_FLAGS "${flags}" CACHE STRING "${lang} SVE flags")
+    ENDIF()
+
+    # Restore the original state of required flags
+    SET(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
+
+    # If the compilation test fails, indicate that support was not found
+    IF(NOT ${lang}_SVE_FOUND)
+        SET(${lang}_SVE_FOUND FALSE CACHE BOOL "${lang} SVE support")
+        SET(${lang}_SVE_FLAGS "" CACHE STRING "${lang} SVE flags")
+    ENDIF()
+
+    # Mark the variables as advanced to hide them in the default CMake GUI
+    MARK_AS_ADVANCED(${lang}_SVE_FOUND ${lang}_SVE_FLAGS)
+ENDMACRO()
+
+CHECK_SVE(CXX "-march=armv8-a+sve")
+CHECK_SVE(C "-march=armv8-a+sve")
+
+# ARCH lists for softmax.cpp and mha_single_token.cpp;
+# based on the results of the checks above, decide whether to add SVE
+set(SOFTMAX_ARCH_LIST AVX512F AVX2 NEON_FP16 ANY)
+set(MHA_SINGLE_TOKEN_ARCH_LIST AVX512F AVX2 NEON_FP16 ANY)
+
+if(CXX_SVE_FOUND AND C_SVE_FOUND)
+    list(APPEND SOFTMAX_ARCH_LIST SVE)
+    list(APPEND MHA_SINGLE_TOKEN_ARCH_LIST SVE)
+endif()
+
 # Cross compiled function
 # TODO: The same for proposal, proposalONNX, topk
 cross_compiled_file(${TARGET_NAME}
@@ -287,14 +347,14 @@ cross_compiled_file(${TARGET_NAME}
     NAMESPACE ov::Extensions::Cpu::XARCH
 )
 cross_compiled_file(${TARGET_NAME}
-    ARCH AVX512F AVX2 NEON_FP16 ANY
+    ARCH ${SOFTMAX_ARCH_LIST}
     src/nodes/kernels/scaled_attn/softmax.cpp
     API src/nodes/kernels/scaled_attn/softmax.hpp
     NAME attn_softmax
     NAMESPACE ov::Extensions::Cpu::XARCH
 )
 cross_compiled_file(${TARGET_NAME}
-    ARCH AVX512F AVX2 NEON_FP16 ANY
+    ARCH ${MHA_SINGLE_TOKEN_ARCH_LIST}
     src/nodes/kernels/scaled_attn/mha_single_token.cpp
     API src/nodes/kernels/scaled_attn/mha_single_token.hpp
     NAME mha_single_token
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp
index 2956c8a6a6b5b8..157cc5333a1cf0 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp
@@ -13,6 +13,9 @@
 #include "openvino/core/type/float16.hpp"
 
 #if defined(OPENVINO_ARCH_ARM64)
+#if defined(HAVE_SVE)
+#include "arm_sve.h"
+#endif
 #include "arm_neon.h"
 #endif
 
@@ -31,6 +34,10 @@ static constexpr size_t vec_len_f32_avx2 = vec_len_avx2 / sizeof(float);
 static constexpr size_t vec_len_f32_neon = vec_len_neon / sizeof(float);
 static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16);
 
+#if defined(HAVE_SVE)
+// svcntw() is not a constant expression in vector-length-agnostic builds,
+// so the SVE f32 lane count is a runtime constant rather than constexpr
+static const size_t vec_len_f32_sve = svcntw();
+#endif
+
 #ifdef HAVE_AVX512F
 inline __m512 cvt_bf16_to_fp32(const __m256i src) {
     __m512i y = _mm512_cvtepu16_epi32(src);
@@ -246,6 +253,79 @@ static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16);
 #endif
 
 #ifdef OPENVINO_ARCH_ARM64
+#if defined(HAVE_SVE)
+    inline svfloat32_t exp_ps_sve(svbool_t& pg, svfloat32_t& src) {
+        // Constants
+        const auto log2_e = svdup_n_f32(1.4426950409f);
+        const auto ln2 = svdup_n_f32(0.6931473921f);
+        const auto half_ln2_sq = svdup_n_f32(0.2413862043f);
+        const auto not_mask17 = svdup_n_u32(~((1u << 17) - 1));
+        const auto one = svdup_n_f32(1.0f);
+
+        // Algorithm starts here
+        svfloat32_t t0 = svmul_f32_z(pg, src, log2_e);  // y = x * log2(e)
+        svfloat32_t t1 = svrintm_f32_z(pg, t0);         // round down to int (float)
+        svint32_t t2 = svcvt_s32_f32_z(pg, t1);         // n
+
+        t1 = svsub_f32_z(pg, t0, t1);   // a = y - floor(y)
+        t1 = svadd_f32_z(pg, t1, one);  // b = a + 1
+
+        svuint32_t t3 = svlsr_n_u32_z(pg, svreinterpret_u32_f32(t1), 17);  // v = b >> 17 (u32)
+        svfloat32_t t4 = svexpa_f32(t3);                                   // c = fexpa(v)
+        t4 = svscale_f32_z(pg, t4, t2);                                    // fexpa(v) * 2^(n)
+
+        // and_(t2.d, t1.d, not_mask17.d)
+        svfloat32_t t5 = svreinterpret_f32_u32(svand_u32_z(pg, svreinterpret_u32_f32(t1), not_mask17));
+        t5 = svsub_f32_z(pg, t1, t5);                // z
+        t0 = svmla_f32_z(pg, ln2, t5, half_ln2_sq);  // ln2 + half_ln2_sq * z
+        t0 = svmla_f32_z(pg, one, t5, t0);           // 1 + (ln2 * z) + (half_ln2_sq * z * z)
+        t0 = svmul_f32_z(pg, t0, t4);                // final result
+
+        return t0;
+    }
+    inline svfloat32_t exp_ps_sve_legacy(svbool_t& pg, svfloat32_t& src) {
+        const auto c1 = svreinterpret_f32_u32(svdup_n_u32(0x3f7ffff6));
+        const auto c2 = svreinterpret_f32_u32(svdup_n_u32(0x3efffedb));
+        const auto c3 = svreinterpret_f32_u32(svdup_n_u32(0x3e2aaf33));
+        const auto c4 = svreinterpret_f32_u32(svdup_n_u32(0x3d2b9f17));
+        const auto c5 = svreinterpret_f32_u32(svdup_n_u32(0x3c072010));
+
+        const auto shift = svreinterpret_f32_u32(svdup_n_u32(0x4b00007f));  // 2^23 + 127 = 0x1.0000fep23f
+        const auto one = svdup_n_f32(1.0f);  // 1
+        const auto two = svdup_n_f32(2.0f);  // 2
+        const auto inv_ln2 = svreinterpret_f32_u32(svdup_n_u32(0x3fb8aa3b));
+        const auto neg_ln2_hi = svreinterpret_f32_u32(svdup_n_u32(0xbf317200));
+        const auto neg_ln2_lo = svreinterpret_f32_u32(svdup_n_u32(0xb5bfbe8e));
+
+        const auto inf = svdup_n_f32(std::numeric_limits<float>::infinity());
+        const auto max_input = svdup_n_f32(88.37f);   // approximately ln(2^127.5)
+        const auto zero = svdup_n_f32(0.f);
+        const auto min_input = svdup_n_f32(-86.64f);  // approximately ln(2^-125)
+
+        const auto z = svmla_f32_z(pg, shift, src, inv_ln2);
+        auto n = svsub_f32_z(pg, z, shift);
+        n = svsub_f32_z(pg, n, one);
+        const auto scale = svreinterpret_f32_u32(svlsl_n_u32_z(pg, svreinterpret_u32_f32(z), 23));  // 2^n
+
+        const auto r_hi = svmla_f32_z(pg, src, n, neg_ln2_hi);
+        const auto r = svmla_f32_z(pg, r_hi, n, neg_ln2_lo);
+        const auto r2 = svmul_f32_z(pg, r, r);
+
+        const auto p1 = svmul_f32_z(pg, c1, r);
+        const auto p23 = svmla_f32_z(pg, c2, c3, r);
+        const auto p45 = svmla_f32_z(pg, c4, c5, r);
+        const auto p2345 = svmla_f32_z(pg, p23, p45, r2);
+        const auto p12345 = svmla_f32_z(pg, p1, p2345, r2);
+
+        auto poly = svmla_f32_z(pg, scale, p12345, scale);
+        poly = svmul_f32_z(pg, poly, two);
+
+        poly = svsel_f32(svcmplt_f32(pg, src, min_input), zero, poly);
+        poly = svsel_f32(svcmpgt_f32(pg, src, max_input), inf, poly);
+
+        return poly;
+    }
+#endif
 inline float32x4_t exp_ps_neon_f32(const float32x4_t& src) {
     const auto c1 = vreinterpretq_f32_u32(vdupq_n_u32(0x3f7ffff6));
     const auto c2 = vreinterpretq_f32_u32(vdupq_n_u32(0x3efffedb));
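exp_ps_sve is written to be driven under an arbitrary lane predicate, which is what makes the vector-length-agnostic tail handling in the kernels below possible. A minimal usage sketch (the helper name is hypothetical and assumes common.hpp above is included):

    #if defined(HAVE_SVE)
    #include <arm_sve.h>
    #include <cstddef>
    #include <cstdint>

    // Hypothetical helper: exponentiate a float buffer in place with a
    // vector-length-agnostic loop; svwhilelt_b32 masks off the tail lanes.
    inline void exp_inplace_sve(float* a, std::size_t n) {
        for (std::size_t i = 0; i < n; i += svcntw()) {
            svbool_t pg = svwhilelt_b32(static_cast<uint32_t>(i), static_cast<uint32_t>(n));
            svfloat32_t v = svld1_f32(pg, a + i);
            v = exp_ps_sve(pg, v);   // predicate-aware exp from common.hpp above
            svst1_f32(pg, a + i, v);
        }
    }
    #endif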
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp
index 25ddbb1b4246b1..db12273ea27aff 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp
@@ -5,6 +5,7 @@
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -13,7 +14,6 @@
 #    include 
 #endif
-
 #include "openvino/core/type/bfloat16.hpp"
 #include "openvino/core/parallel.hpp"
 #include "mha_single_token.hpp"
@@ -21,7 +21,10 @@
 #include "softmax_kernel.hpp"
 
 #if defined(OPENVINO_ARCH_ARM64)
-#    include <arm_neon.h>
+#if defined(HAVE_SVE)
+#    include <arm_sve.h>
+#endif
+#    include <arm_neon.h>
 #endif
 
 namespace ov {
@@ -59,12 +62,21 @@ void cvt_copy(TA* dst, TB* src, size_t n) {
         mm256_uni_storeu_ps(dst + i, vb);
     }
 #elif defined(OPENVINO_ARCH_ARM64)
-    if (std::is_same<TA, float>::value && std::is_same<TB, float>::value) {
-        for (; i + vec_len_f32_neon <= n; i += vec_len_f32_neon) {
-            float32x4_t vb1 = __vld1q_f32(src + i);
-            __vst1q_f32(dst + i, vb1);
+#if defined(HAVE_SVE)
+    auto _dst = reinterpret_cast<float32_t*>(dst);
+    size_t inc = vec_len_f32_sve;
+    svbool_t pg = svptrue_b32();
+
+    while (i < n) {
+        if (n - i < vec_len_f32_sve) {
+            inc = n - i;
+            pg = svwhilelt_b32(0, static_cast<int>(inc));
         }
+        svfloat32_t b1 = svld1_f32(pg, src + i);
+        svst1_f32(pg, _dst + i, b1);
+        i += inc;
     }
+#else
 #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
     if (std::is_same<TA, ov::float16>::value && std::is_same<TB, ov::float16>::value) {
         for (; i + vec_len_f16_neon <= n; i += vec_len_f16_neon) {
@@ -72,6 +84,14 @@ void cvt_copy(TA* dst, TB* src, size_t n) {
             vst1q_f16(reinterpret_cast<float16_t*>(dst + i), vb1);
         }
     }
+#else
+    if (std::is_same<TA, float>::value && std::is_same<TB, float>::value) {
+        for (; i + vec_len_f32_neon <= n; i += vec_len_f32_neon) {
+            float32x4_t vb1 = __vld1q_f32(src + i);
+            __vst1q_f32(dst + i, vb1);
+        }
+    }
+#endif
 #endif
 #endif
     for (; i < n; i++) {
@@ -99,6 +119,27 @@ static void attn_acc_value(float* out, float weight, T* v, size_t S, float* scal
         mm256_uni_storeu_ps(out + i, v_out);
     }
 #elif defined(OPENVINO_ARCH_ARM64)
+#if defined(HAVE_SVE)
+    auto _v = reinterpret_cast<float32_t*>(v);
+    svfloat32_t attn_w_vec_fp32 = svdup_n_f32(weight);
+    size_t inc = vec_len_f32_sve;
+    svbool_t pg = svptrue_b32();
+
+    while (i < S) {
+        if (S - i < vec_len_f32_sve) {
+            inc = S - i;
+            pg = svwhilelt_b32(0, static_cast<int>(inc));
+        }
+        svfloat32_t v_value = svld1_f32(pg, _v + i);
+        svfloat32_t v_out = svld1_f32(pg, out + i);
+
+        // svmla with merging keeps the inactive lanes of v_out unchanged when
+        // fewer than vec_len elements are left
+        v_out = svmla_f32_m(pg, v_out, attn_w_vec_fp32, v_value);
+        svst1_f32(pg, out + i, v_out);
+        i += inc;
+    }
+#else
     float32x4_t attn_w_vec_fp32 = vdupq_n_f32(weight);
     for (; i + vec_len_f32_neon <= S; i += vec_len_f32_neon) {
         float32x4_t v_value = __vld1q_f32(v + i);
@@ -106,6 +147,7 @@ static void attn_acc_value(float* out, float weight, T* v, size_t S, float* scal
         v_out = vmlaq_f32(v_out, attn_w_vec_fp32, v_value);
         __vst1q_f32(out + i, v_out);
     }
+#endif
 #endif
     for (; i < S; i++) {
         out[i] += weight * v[i];
     }
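A side note on the `_m` suffix used in attn_acc_value above: with a partial tail predicate, zeroing forms (`_z`) clear the inactive lanes of the result register, while merging forms (`_m`) pass through the lanes of the first vector operand. A small standalone sketch of the difference (illustrative only; function name is hypothetical):

    #if defined(HAVE_SVE)
    #include <arm_sve.h>
    #include <cstddef>

    // Contrast zeroing vs. merging multiply-accumulate on a predicate that
    // covers just `tail` lanes of one SVE vector.
    inline void predication_demo(const float* v, float* out, float w, std::size_t tail) {
        svbool_t pg = svwhilelt_b32(0, static_cast<int>(tail));    // first `tail` lanes active
        svfloat32_t vv = svld1_f32(pg, v);                         // inactive lanes load as zero
        svfloat32_t vo = svld1_f32(pg, out);
        svfloat32_t z = svmla_f32_z(pg, vo, svdup_n_f32(w), vv);   // inactive lanes become 0
        svfloat32_t m = svmla_f32_m(pg, vo, svdup_n_f32(w), vv);   // inactive lanes keep vo
        (void)z;
        svst1_f32(pg, out, m);  // the predicated store only writes active lanes anyway
    }
    #endif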
@@ -357,7 +399,50 @@ static float sum_q_head(T* a, size_t n) {
     hsum(vsum0);
     sum = _mm256_cvtss_f32(vsum0);
 #elif defined(OPENVINO_ARCH_ARM64)
-    size_t vec_len_f32_neon = 4;
+#if defined(HAVE_SVE)
+    svfloat32_t sum0 = svdup_n_f32(0.0f);
+    svfloat32_t sum1 = svdup_n_f32(0.0f);
+    svfloat32_t sum2 = svdup_n_f32(0.0f);
+    svfloat32_t sum3 = svdup_n_f32(0.0f);
+    svbool_t pg = svptrue_b32();
+
+    for (; i + 4 * vec_len_f32_sve <= n; i += 4 * vec_len_f32_sve) {
+        svfloat32_t a0 = svld1_f32(pg, a + i);
+        svfloat32_t a1 = svld1_f32(pg, a + i + vec_len_f32_sve);
+        svfloat32_t a2 = svld1_f32(pg, a + i + vec_len_f32_sve * 2);
+        svfloat32_t a3 = svld1_f32(pg, a + i + vec_len_f32_sve * 3);
+
+        sum0 = svadd_f32_z(pg, a0, sum0);
+        sum1 = svadd_f32_z(pg, a1, sum1);
+        sum2 = svadd_f32_z(pg, a2, sum2);
+        sum3 = svadd_f32_z(pg, a3, sum3);
+    }
+    if (i + 2 * vec_len_f32_sve <= n) {
+        svfloat32_t a0 = svld1_f32(pg, a + i);
+        svfloat32_t a1 = svld1_f32(pg, a + i + vec_len_f32_sve);
+
+        sum0 = svadd_f32_z(pg, a0, sum0);
+        sum1 = svadd_f32_z(pg, a1, sum1);
+        i += 2 * vec_len_f32_sve;
+    }
+    if (i + vec_len_f32_sve <= n) {
+        svfloat32_t a0 = svld1_f32(pg, a + i);
+        sum0 = svadd_f32_z(pg, a0, sum0);
+        i += vec_len_f32_sve;
+    }
+    // Process the tail elements in parallel as well (if any)
+    if (i != n) {
+        svbool_t pg_rem = svwhilelt_b32(0, static_cast<int>(n - i));
+        svfloat32_t a0 = svld1_f32(pg_rem, a + i);
+        sum0 = svadd_f32_m(pg_rem, sum0, a0);
+        i = n;
+    }
+    float32_t sum_0 = svaddv_f32(pg, sum0);
+    float32_t sum_1 = svaddv_f32(pg, sum1);
+    float32_t sum_2 = svaddv_f32(pg, sum2);
+    float32_t sum_3 = svaddv_f32(pg, sum3);
+    sum = static_cast<float>(sum_0 + sum_1 + sum_2 + sum_3);
+#else
     float32x4_t vsum0 = vdupq_n_f32(0.0f);
     float32x4_t vsum1 = vdupq_n_f32(0.0f);
     float32x4_t vsum2 = vdupq_n_f32(0.0f);
@@ -398,7 +483,7 @@ static float sum_q_head(T* a, size_t n) {
     sum_low = vpadd_f32(sum_low, sum_low);
     sum = vget_lane_f32(sum_low, 0);
 #endif
-
+#endif
     for (; i < n; i++) {
         float tmp = a[i];
         sum += tmp;
@@ -497,6 +582,63 @@ static float dot_product(TA* a, TB* b, size_t n, float* scale, float* zp, float*
     sum = _mm256_cvtss_f32(vsum0);
 #elif defined(OPENVINO_ARCH_ARM64)
+#if defined(HAVE_SVE)
+    svbool_t pg = svptrue_b32();
+    svfloat32_t sum0 = svdup_n_f32(0.0f);
+    svfloat32_t sum1 = svdup_n_f32(0.0f);
+    svfloat32_t sum2 = svdup_n_f32(0.0f);
+    svfloat32_t sum3 = svdup_n_f32(0.0f);
+
+    auto _a = reinterpret_cast<float32_t*>(a);
+    auto _b = reinterpret_cast<float32_t*>(b);
+
+    for (; i + 4 * vec_len_f32_sve <= n; i += 4 * vec_len_f32_sve) {
+        svfloat32_t a0 = svld1_f32(pg, _a + i);
+        svfloat32_t a1 = svld1_f32(pg, _a + i + vec_len_f32_sve);
+        svfloat32_t a2 = svld1_f32(pg, _a + i + vec_len_f32_sve * 2);
+        svfloat32_t a3 = svld1_f32(pg, _a + i + vec_len_f32_sve * 3);
+
+        svfloat32_t b0 = svld1_f32(pg, _b + i);
+        svfloat32_t b1 = svld1_f32(pg, _b + i + vec_len_f32_sve);
+        svfloat32_t b2 = svld1_f32(pg, _b + i + vec_len_f32_sve * 2);
+        svfloat32_t b3 = svld1_f32(pg, _b + i + vec_len_f32_sve * 3);
+
+        sum0 = svmla_f32_z(pg, sum0, a0, b0);
+        sum1 = svmla_f32_z(pg, sum1, a1, b1);
+        sum2 = svmla_f32_z(pg, sum2, a2, b2);
+        sum3 = svmla_f32_z(pg, sum3, a3, b3);
+    }
+    if (i + 2 * vec_len_f32_sve <= n) {
+        svfloat32_t a0 = svld1_f32(pg, _a + i);
+        svfloat32_t a1 = svld1_f32(pg, _a + i + vec_len_f32_sve);
+
+        svfloat32_t b0 = svld1_f32(pg, _b + i);
+        svfloat32_t b1 = svld1_f32(pg, _b + i + vec_len_f32_sve);
+
+        sum0 = svmla_f32_z(pg, sum0, a0, b0);
+        sum1 = svmla_f32_z(pg, sum1, a1, b1);
+        i += 2 * vec_len_f32_sve;
+    }
+    if (i + vec_len_f32_sve <= n) {
+        svfloat32_t a0 = svld1_f32(pg, _a + i);
+        svfloat32_t b0 = svld1_f32(pg, _b + i);
+        sum0 = svmla_f32_z(pg, sum0, a0, b0);
+        i += vec_len_f32_sve;
+    }
+    // Process the tail elements in parallel as well (if any)
+    if (i != n) {
+        svbool_t pg_rem = svwhilelt_b32(0, static_cast<int>(n - i));
+        svfloat32_t a0 = svld1_f32(pg_rem, _a + i);
+        svfloat32_t b0 = svld1_f32(pg_rem, _b + i);
+        sum0 = svmla_f32_m(pg_rem, sum0, a0, b0);
+        i = n;
+    }
+    float32_t sum_0 = svaddv_f32(pg, sum0);
+    float32_t sum_1 = svaddv_f32(pg, sum1);
+    float32_t sum_2 = svaddv_f32(pg, sum2);
+    float32_t sum_3 = svaddv_f32(pg, sum3);
+    sum = static_cast<float>(sum_0 + sum_1 + sum_2 + sum_3);
+#else
     float32x4_t vsum0 = vdupq_n_f32(0.0f);
     float32x4_t vsum1 = vdupq_n_f32(0.0f);
     float32x4_t vsum2 = vdupq_n_f32(0.0f);
@@ -544,7 +686,7 @@ static float dot_product(TA* a, TB* b, size_t n, float* scale, float* zp, float*
     temp_sum = vpadd_f32(temp_sum, temp_sum);
     sum = vget_lane_f32(temp_sum, 0);
 #endif
-
+#endif
     for (; i < n; i++) {
         sum += a[i] * b[i];
     }
@@ -765,7 +907,7 @@ static float dot_product(TA* a, uint8_t* b, size_t n, float* scale, float* zp, f
 
 template <typename T>
 static void attn_reduce(T* dst, float* temp, size_t M, size_t S, size_t temp_stride) {
-    size_t i = 0;
+    int i = 0;
 #if defined(HAVE_AVX512F)
     for (; i + vec_len_f32_avx512 <= S; i+= vec_len_f32_avx512) {
         auto* src = temp + i;
@@ -790,6 +932,28 @@ static void attn_reduce(T* dst, float* temp, size_t M, size_t S, size_t temp_str
         mm256_uni_storeu_ps(dst + i, result_vec_fp32);
     }
 #elif defined(OPENVINO_ARCH_ARM64)
+#if defined(HAVE_SVE)
+    auto _dst = reinterpret_cast<float32_t*>(dst);
+    size_t inc = vec_len_f32_sve;
+    svbool_t pg = svptrue_b32();
+
+    while (i < S) {
+        if (S - i < vec_len_f32_sve) {
+            inc = S - i;
+            pg = svwhilelt_b32(0, static_cast<int>(inc));
+        }
+        auto* src = temp + i;
+        auto result_vec_fp32 = svdup_n_f32(0.0f);
+
+        for (size_t m = 0; m < M; m++) {
+            auto o_vec_fp32 = svld1_f32(pg, src);
+            result_vec_fp32 = svadd_f32_m(pg, result_vec_fp32, o_vec_fp32);
+            src += temp_stride;
+        }
+        svst1_f32(pg, _dst + i, result_vec_fp32);
+        i += inc;
+    }
+#else
     for (; i + vec_len_f32_neon <= S; i += vec_len_f32_neon) {
         auto* src = temp + i;
         auto result_vec_fp32 = vdupq_n_f32(0.0f);
@@ -800,6 +964,7 @@ static void attn_reduce(T* dst, float* temp, size_t M, size_t S, size_t temp_str
         }
         __vst1q_f32(dst + i, result_vec_fp32);
     }
+#endif
 #endif
     for (; i < S; i++) {
         auto* src = temp + i;
@@ -1239,6 +1404,7 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query,
         OPENVINO_THROW("Unsupported precision: ", query.get_precision());
     }
 }
+
 }  // namespace XARCH
 }  // namespace Cpu
 }  // namespace Extensions
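The four accumulators in sum_q_head and dot_product above exist purely to overlap multiply-accumulate latency; functionally they reduce to this single-accumulator form (a sketch with a hypothetical helper name, not code from the patch):

    #if defined(HAVE_SVE)
    #include <arm_sve.h>
    #include <cstddef>
    #include <cstdint>

    // Single-accumulator equivalent; the kernel unrolls this by four so more
    // independent FMA chains are in flight per iteration.
    inline float dot_sve_simple(const float* a, const float* b, std::size_t n) {
        svfloat32_t acc = svdup_n_f32(0.0f);
        for (std::size_t i = 0; i < n; i += svcntw()) {
            svbool_t pg = svwhilelt_b32(static_cast<uint32_t>(i), static_cast<uint32_t>(n));
            // merging MLA: lanes beyond n leave acc untouched
            acc = svmla_f32_m(pg, acc, svld1_f32(pg, a + i), svld1_f32(pg, b + i));
        }
        return svaddv_f32(svptrue_b32(), acc);  // horizontal add, as in the kernel
    }
    #endif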
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp
index 60c6a24ec5f2fa..284a71e1450051 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp
@@ -12,6 +12,9 @@
 #include 
 
 #if defined(OPENVINO_ARCH_ARM64)
+#if defined(HAVE_SVE)
+#include "arm_sve.h"
+#endif
 #include "arm_neon.h"
 #endif
 
@@ -656,6 +659,28 @@ inline void exp_reduce_sum(float* a, const float max, const size_t size, float&
     hsum(v_sum);
     sum = _mm256_cvtss_f32(v_sum);
 #elif defined(OPENVINO_ARCH_ARM64)
+#if defined(HAVE_SVE)
+    svfloat32_t v_a;
+    svfloat32_t v_max = svdup_n_f32(max);
+    svfloat32_t v_sum = svdup_n_f32(0.0f);
+    size_t vec_len_f32_sve = svcntw();
+    size_t inc = vec_len_f32_sve;
+    svbool_t pg = svptrue_b32();
+
+    while (i < size) {
+        if (size - i < vec_len_f32_sve) {
+            inc = size - i;
+            pg = svwhilelt_b32(0, static_cast<int>(inc));
+        }
+        v_a = svld1_f32(pg, a + i);
+        v_a = svsub_f32_z(pg, v_a, v_max);
+        v_a = exp_ps_sve(pg, v_a);
+        v_sum = svadd_f32_m(pg, v_sum, v_a);
+        svst1_f32(pg, a + i, v_a);
+        i += inc;
+    }
+    sum = svaddv_f32(svptrue_b32(), v_sum);
+#else
     float32x4_t v_a;
     float32x4_t v_max = vdupq_n_f32(max);
     float32x4_t v_sum = vdupq_n_f32(0.0f);
@@ -669,7 +694,7 @@ inline void exp_reduce_sum(float* a, const float max, const size_t size, float&
         i += vec_len_f32_neon;
     }
     sum = vaddvq_f32(v_sum);
-
+#endif
 #endif
     for (; i < size; i++) {
         a[i] = exp(a[i] - max);
@@ -780,6 +805,22 @@ inline void multiply_scalar(float* a, float* a_dst, const float val, const size_
         i += (size - i);
     }
 #elif defined(OPENVINO_ARCH_ARM64)
+#if defined(HAVE_SVE)
+    svfloat32_t v_scale = svdup_n_f32(val);
+    size_t inc = vec_len_f32_sve;
+    svbool_t pg = svptrue_b32();
+
+    while (i < size) {
+        if (size - i < vec_len_f32_sve) {
+            inc = size - i;
+            pg = svwhilelt_b32(0, static_cast<int>(inc));
+        }
+        svfloat32_t v_a = svld1_f32(pg, a + i);
+        v_a = svmul_f32_z(pg, v_a, v_scale);
+        svst1_f32(pg, a_dst + i, v_a);
+        i += inc;
+    }
+#else
     float32x4_t v_scale = vdupq_n_f32(val);
     while (i + vec_len_f32_neon <= size) {
         float32x4_t v_a = vld1q_f32(a + i);
@@ -787,6 +828,7 @@ inline void multiply_scalar(float* a, float* a_dst, const float val, const size_
         vst1q_f32(a_dst + i, v_a);
         i += vec_len_f32_neon;
     }
+#endif
 #endif
     for (; i < size; i++) {
         a_dst[i] = a[i] * val;
@@ -948,7 +990,7 @@ inline void attn_softmax_kernel(float* a,
     // divide sum
     float scalar = 1.0f / sum;
     if (dst_precision == ov::element::f32) {
-        multiply_scalar(a, static_cast<float*>(a_dst), scalar, len);
+        multiply_scalar(a, reinterpret_cast<float*>(a_dst), scalar, len);
         // apply causual mask to final result instead of attn_score
         if (total_size > len)
             memset(static_cast<float*>(a_dst) + len, 0, sizeof(float) * (total_size - len));
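For reference, the scalar computation that all the exp_reduce_sum variants implement is the numerically stable softmax numerator; the SVE path above just runs it a vector at a time under a predicate. A plain-C++ sketch of the same arithmetic (reference only, not code from the patch):

    #include <cmath>
    #include <cstddef>

    // Scalar reference for exp_reduce_sum: shift each logit by the row max
    // (for numerical stability), exponentiate in place, and accumulate the sum.
    inline float exp_reduce_sum_ref(float* a, float max, std::size_t n) {
        float sum = 0.0f;
        for (std::size_t i = 0; i < n; ++i) {
            a[i] = std::exp(a[i] - max);  // exp of the shifted logit
            sum += a[i];
        }
        return sum;
    }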