From b8fcb227fee49b02b966001b7a18be92d9d02816 Mon Sep 17 00:00:00 2001
From: Luo Cheng
Date: Tue, 26 Mar 2024 19:02:44 +0800
Subject: [PATCH] [CPU] Optimize first inference latency for PagedAttention
 when running bf16 (#23620)

### Details:
- *Use a dedicated kernel for 2D f32 to bf16 conversion instead of multiple calls to `cpu_convert`.* `cpu_convert` invokes `parallel_for` internally, so when the copy count is small, e.g. a single head of size 128, each core copies only ~2 elements on a 60-core machine, which results in false sharing (see the sketch after this list). The cost drops from ~1700 ms to ~860 ms after the fix. The SDPA path copies a whole block of heads at a time, e.g. 32x128, so it is not easily affected, but very small prompt sizes can still hit the same problem.
- *Change the loop order from B,H,L to B,L,H to match the physical layout; this reduces the cost from ~860 ms to ~830 ms.*
- *Changes in vLLM: https://github.com/ilya-lavrenov/vllm/pull/15*
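To make the false-sharing point concrete, here is a minimal standalone sketch. It is illustrative only: the thread partitioning, the fixed 60-thread count, and the truncating bf16 conversion are simplified assumptions, not the actual `cpu_convert` or `attn_memcpy2d_kernel` sources.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <thread>
#include <vector>

// f32 -> bf16 by truncation (the real kernels round; this keeps the sketch short).
static uint16_t f32_to_bf16(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16);
}

// Before: every row launches a parallel region over `width` elements, the way a
// generic parallel converter would. With width == head_size == 128 and ~60
// threads, each thread writes only ~2 adjacent 2-byte bf16 values, so many
// threads land on the same 64-byte cache line: false sharing, plus a fork/join
// cost paid once per row.
static void convert_row_parallel(const float* src, uint16_t* dst,
                                 size_t width, size_t num_threads) {
    std::vector<std::thread> workers;
    const size_t chunk = (width + num_threads - 1) / num_threads;
    for (size_t t = 0; t < num_threads; t++) {
        const size_t begin = t * chunk;
        const size_t end = std::min(width, begin + chunk);
        if (begin >= end)
            break;
        workers.emplace_back([=] {
            for (size_t i = begin; i < end; i++)
                dst[i] = f32_to_bf16(src[i]);  // neighbouring threads share cache lines
        });
    }
    for (auto& w : workers)
        w.join();
}

// After: one core walks the whole 2D region sequentially with explicit row
// strides, which is the shape of the new attn_memcpy2d_kernel below.
static void convert_2d_sequential(const float* src, uint16_t* dst,
                                  size_t src_stride, size_t dst_stride,
                                  size_t width, size_t height) {
    for (size_t j = 0; j < height; j++) {
        for (size_t i = 0; i < width; i++)
            dst[i] = f32_to_bf16(src[i]);
        src += src_stride;
        dst += dst_stride;
    }
}

int main() {
    const size_t head_size = 128, rows = 32;
    std::vector<float> src(rows * head_size, 1.0f);
    std::vector<uint16_t> dst(rows * head_size);
    // Old scheme: one parallel conversion per row.
    for (size_t m = 0; m < rows; m++)
        convert_row_parallel(src.data() + m * head_size,
                             dst.data() + m * head_size, head_size, 60);
    // New scheme: one sequential strided pass over all rows.
    convert_2d_sequential(src.data(), dst.data(), head_size, head_size,
                          head_size, rows);
    return 0;
}
```

Each `convert_row_parallel` call pays a fork/join per row and lets many threads write into the same handful of cache lines; the sequential strided kernel touches each destination line from exactly one core.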
### Tickets:
 - *ticket-id*
---
 src/plugins/intel_cpu/CMakeLists.txt          |  2 +-
 .../nodes/kernels/scaled_attn/attn_memcpy.cpp | 44 ++++++++++++++++++-
 .../nodes/kernels/scaled_attn/attn_memcpy.hpp | 15 +++++--
 .../intel_cpu/src/nodes/scaled_attn.cpp       | 28 +++++++-----
 4 files changed, 71 insertions(+), 18 deletions(-)

diff --git a/src/plugins/intel_cpu/CMakeLists.txt b/src/plugins/intel_cpu/CMakeLists.txt
index c65bceae2a1d0b..5e494c29b86956 100644
--- a/src/plugins/intel_cpu/CMakeLists.txt
+++ b/src/plugins/intel_cpu/CMakeLists.txt
@@ -194,7 +194,7 @@ cross_compiled_file(${TARGET_NAME}
     ARCH AVX512F AVX2 ANY
         src/nodes/kernels/scaled_attn/attn_memcpy.cpp
     API src/nodes/kernels/scaled_attn/attn_memcpy.hpp
-    NAME attn_memcpy paged_attn_memcpy
+    NAME attn_memcpy paged_attn_memcpy attn_memcpy2d_kernel
     NAMESPACE ov::Extensions::Cpu::XARCH
 )
 cross_compiled_file(${TARGET_NAME}
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp
index c170464eeb47ee..2b0d532168b7a5 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp
@@ -83,7 +83,7 @@ static void paged_attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input,
                                      const ov::intel_cpu::PlainTensor& past_v_output,
                                      const ov::intel_cpu::PlainTensor& slot_mapping) {
     size_t B = k_input.m_dims[0], H = k_input.m_dims[1], L1 = k_input.m_dims[2], S = k_input.m_dims[3];
-    parallel_for3d(B, H, L1, [&](size_t b, size_t h, size_t m) {
+    parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) {
         auto block_idx = slot_mapping.ptr<int32_t>(b)[m];
         if (block_idx < 0) return;
         attn_copy(past_k_output.ptr<T2>(block_idx, h, 0),
@@ -101,7 +101,7 @@ static void paged_attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input,
                                      const ov::intel_cpu::PlainTensor& past_v_output,
                                      const ov::intel_cpu::PlainTensor& slot_mapping) {
     size_t B = k_input.m_dims[0], H = k_input.m_dims[1], L1 = k_input.m_dims[2], S = k_input.m_dims[3];
-    parallel_for3d(B, H, L1, [&](size_t b, size_t h, size_t m) {
+    parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) {
         auto block_idx = slot_mapping.ptr<int32_t>(b)[m];
         if (block_idx < 0) return;
         std::memcpy(past_k_output.ptr_v(block_idx, h, 0),
@@ -144,6 +144,46 @@ void paged_attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
     }
 }
 
+void attn_memcpy2d_kernel(void* src,
+                          void* dst,
+                          ov::element::Type src_type,
+                          ov::element::Type dst_type,
+                          size_t src_stride,
+                          size_t dst_stride,
+                          size_t width,
+                          size_t height) {
+    if (src_type == dst_type) {
+        auto src_u8 = reinterpret_cast<uint8_t*>(src);
+        auto dst_u8 = reinterpret_cast<uint8_t*>(dst);
+
+        for (size_t j = 0; j < height; j++) {
+            std::memcpy(dst_u8, src_u8, width * src_type.size());
+            dst_u8 += dst_stride * src_type.size();
+            src_u8 += src_stride * src_type.size();
+        }
+    } else if (src_type == ov::element::f32 && dst_type == ov::element::bf16) {
+        auto src_f = reinterpret_cast<float*>(src);
+        auto dst_f = reinterpret_cast<ov::bfloat16*>(dst);
+
+        for (size_t j = 0; j < height; j++) {
+            attn_copy(dst_f, src_f, width);
+            dst_f += dst_stride;
+            src_f += src_stride;
+        }
+    } else if (src_type == ov::element::f32 && dst_type == ov::element::f16) {
+        auto src_f = reinterpret_cast<float*>(src);
+        auto dst_f = reinterpret_cast<ov::float16*>(dst);
+
+        for (size_t j = 0; j < height; j++) {
+            attn_copy(dst_f, src_f, width);
+            dst_f += dst_stride;
+            src_f += src_stride;
+        }
+    } else {
+        OPENVINO_THROW("unsupported src type: ", src_type, ", dst type: ", dst_type, " in attn_memcpy2d_kernel");
+    }
+}
+
 }  // namespace XARCH
 }  // namespace Cpu
 }  // namespace Extensions
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.hpp
index 2c44534a8462d7..c0e5892db9926b 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.hpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.hpp
@@ -16,9 +16,9 @@ namespace Cpu {
 namespace XARCH {
 
 void attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
-            const ov::intel_cpu::PlainTensor& v_input,
-            const ov::intel_cpu::PlainTensor& past_k_output,
-            const ov::intel_cpu::PlainTensor& past_v_output);
+                 const ov::intel_cpu::PlainTensor& v_input,
+                 const ov::intel_cpu::PlainTensor& past_k_output,
+                 const ov::intel_cpu::PlainTensor& past_v_output);
 
 void paged_attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
                        const ov::intel_cpu::PlainTensor& v_input,
@@ -26,6 +26,15 @@ void paged_attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
                        const ov::intel_cpu::PlainTensor& past_v_output,
                        const ov::intel_cpu::PlainTensor& slot_mapping);
 
+void attn_memcpy2d_kernel(void* src,
+                          void* dst,
+                          ov::element::Type src_type,
+                          ov::element::Type dst_type,
+                          size_t src_stride,
+                          size_t dst_stride,
+                          size_t width,
+                          size_t height);
+
 }  // namespace XARCH
 }  // namespace Cpu
 }  // namespace Extensions
diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
index 273f81a07ad9c1..d5a711858edcdc 100644
--- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
+++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
@@ -477,19 +477,23 @@ struct MHAKernel {
                           wv_scratch_a ? &wv_scratch_a.at({tid, 0}) : nullptr);
             if (is_bf16) {
                 if (has_out_transpose) {
-                    for (size_t m = m_start; m < m_end; m++) {
-                        cpu_convert(&fp32_out.at({b, m, h, 0}),
-                                    &output_emb.at({b, m, h * head_size}),
-                                    ov::element::f32,
-                                    ov::element::bf16,
-                                    head_size);
-                    }
+                    attn_memcpy2d_kernel(&fp32_out.at({b, m_start, h, 0}),
+                                         &output_emb.at({b, m_start, h * head_size}),
+                                         ov::element::f32,
+                                         ov::element::bf16,
+                                         fp32_out.stride(1),
+                                         output_emb.stride(1),
+                                         head_size,
+                                         m_cnt);
                 } else {
-                    cpu_convert(&fp32_out.at({b, h, m_start, 0}),
-                                &output_emb.at({b, h, m_start, 0}),
-                                ov::element::f32,
-                                ov::element::bf16,
-                                m_cnt * head_size);
+                    attn_memcpy2d_kernel(&fp32_out.at({b, h, m_start, 0}),
+                                         &output_emb.at({b, h, m_start, 0}),
+                                         ov::element::f32,
+                                         ov::element::bf16,
+                                         0,
+                                         0,
+                                         m_cnt * head_size,
+                                         1);
                 }
             }
         });
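As a closing illustration of the scaled_attn.cpp change, here is a hypothetical standalone sketch showing how the two call shapes map onto the kernel's `(src_stride, dst_stride, width, height)` parameters. The buffer names and sizes are invented, and `memcpy2d` is a simplified truncating stand-in for `attn_memcpy2d_kernel`'s f32-to-bf16 branch, not the actual implementation.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// f32 -> bf16 by truncation; the real kernel converts via attn_copy.
static uint16_t f32_to_bf16(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16);
}

// Simplified stand-in for attn_memcpy2d_kernel's f32 -> bf16 branch:
// convert `height` rows of `width` elements, advancing by the given strides.
static void memcpy2d(const float* src, uint16_t* dst,
                     size_t src_stride, size_t dst_stride,
                     size_t width, size_t height) {
    for (size_t j = 0; j < height; j++) {
        for (size_t i = 0; i < width; i++)
            dst[i] = f32_to_bf16(src[i]);
        src += src_stride;
        dst += dst_stride;
    }
}

int main() {
    const size_t m_cnt = 4, head_size = 128, num_heads = 8;
    std::vector<float> fp32_out(m_cnt * num_heads * head_size, 1.0f);
    std::vector<uint16_t> output_emb(m_cnt * num_heads * head_size);

    // has_out_transpose == true: for one head h, consecutive query rows are
    // one full row of heads apart on both sides, so a single call converts
    // m_cnt strided rows of head_size elements (no per-row loop needed).
    memcpy2d(fp32_out.data(), output_emb.data(),
             /*src_stride=*/num_heads * head_size,
             /*dst_stride=*/num_heads * head_size,
             /*width=*/head_size, /*height=*/m_cnt);

    // has_out_transpose == false: the region is contiguous on both sides, so
    // strides 0/0 with height 1 degenerate into one flat conversion.
    memcpy2d(fp32_out.data(), output_emb.data(),
             /*src_stride=*/0, /*dst_stride=*/0,
             /*width=*/m_cnt * head_size, /*height=*/1);
    return 0;
}
```

Passing strides of 0 with height 1 makes the second call a single flat conversion of `m_cnt * head_size` elements, which is why the non-transposed call site in the patch needs no stride information at all.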