[CPU] Optimize first inference latency for PagedAttention when running bf16 (#23620)

### Details:
- *Use a dedicated kernel for the 2D f32-to-bf16 conversion instead of multiple
calls to cpu_convert.*
  - cpu_convert invokes parallel_for internally, so when the copy count is
small (e.g. a single head of size 128), each core copies only ~2 elements on a
60-core machine, which causes false sharing (see the sketch after this list).
The fix reduces the cost from ~1700 ms to ~860 ms. The SDPA path copies a
block of heads at a time (e.g. 32*128), so it is not easily affected, but a
very small prompt size would suffer from the same problem.
- *Change the loop order from B,H,L to B,L,H to match the physical layout: for
a fixed token, its cache block stores all heads contiguously, so iterating
heads innermost keeps the writes sequential. This reduces the cost from
~860 ms to ~830 ms.*
- *Changes in vLLM: ilya-lavrenov/vllm#15*
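
A back-of-the-envelope check of the false-sharing claim (a standalone sketch; the 128-element and 60-core figures come from the description above, and the 64-byte cache line is a typical x86-64 assumption):

```cpp
#include <cstdio>

int main() {
    const double head_size = 128;  // elements converted per cpu_convert call
    const double cores = 60;       // threads the inner parallel_for splits across
    const double elem_bytes = 2;   // sizeof(ov::bfloat16)
    const double cache_line = 64;  // typical x86-64 cache-line size, in bytes

    const double per_core = head_size / cores;                     // ~2.1 elements per thread
    const double bytes_per_core = per_core * elem_bytes;           // ~4.3 bytes per thread
    const double threads_per_line = cache_line / bytes_per_core;   // ~15 threads per line

    // ~15 threads storing into the same 64-byte line makes that line
    // ping-pong between cores on every write: the false sharing described above.
    std::printf("%.1f elements/thread, %.1f bytes/thread, ~%.0f threads per cache line\n",
                per_core, bytes_per_core, threads_per_line);
    return 0;
}
```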

### Tickets:
 - *ticket-id*
luo-cheng2021 authored Mar 26, 2024
1 parent 63bb8d1 commit b8fcb22
Showing 4 changed files with 71 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/CMakeLists.txt
@@ -194,7 +194,7 @@ cross_compiled_file(${TARGET_NAME}
ARCH AVX512F AVX2 ANY
src/nodes/kernels/scaled_attn/attn_memcpy.cpp
API src/nodes/kernels/scaled_attn/attn_memcpy.hpp
-NAME attn_memcpy paged_attn_memcpy
+NAME attn_memcpy paged_attn_memcpy attn_memcpy2d_kernel
NAMESPACE ov::Extensions::Cpu::XARCH
)
cross_compiled_file(${TARGET_NAME}
src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp
@@ -83,7 +83,7 @@ static void paged_attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input,
const ov::intel_cpu::PlainTensor& past_v_output,
const ov::intel_cpu::PlainTensor& slot_mapping) {
size_t B = k_input.m_dims[0], H = k_input.m_dims[1], L1 = k_input.m_dims[2], S = k_input.m_dims[3];
-parallel_for3d(B, H, L1, [&](size_t b, size_t h, size_t m) {
+parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) {
auto block_idx = slot_mapping.ptr<int32_t>(b)[m];
if (block_idx < 0) return;
attn_copy(past_k_output.ptr<T2>(block_idx, h, 0),
@@ -101,7 +101,7 @@ static void paged_attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input,
const ov::intel_cpu::PlainTensor& past_v_output,
const ov::intel_cpu::PlainTensor& slot_mapping) {
size_t B = k_input.m_dims[0], H = k_input.m_dims[1], L1 = k_input.m_dims[2], S = k_input.m_dims[3];
-parallel_for3d(B, H, L1, [&](size_t b, size_t h, size_t m) {
+parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) {
auto block_idx = slot_mapping.ptr<int32_t>(b)[m];
if (block_idx < 0) return;
std::memcpy(past_k_output.ptr_v(block_idx, h, 0),
@@ -144,6 +144,46 @@ void paged_attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
}
}

+void attn_memcpy2d_kernel(void* src,
+void* dst,
+ov::element::Type src_type,
+ov::element::Type dst_type,
+size_t src_stride,
+size_t dst_stride,
+size_t width,
+size_t height) {
+if (src_type == dst_type) {
+auto src_u8 = reinterpret_cast<uint8_t*>(src);
+auto dst_u8 = reinterpret_cast<uint8_t*>(dst);
+
+for (size_t j = 0; j < height; j++) {
+std::memcpy(dst_u8, src_u8, width * src_type.size());
+dst_u8 += dst_stride * src_type.size();
+src_u8 += src_stride * src_type.size();
+}
+} else if (src_type == ov::element::f32 && dst_type == ov::element::bf16) {
+auto src_f = reinterpret_cast<float*>(src);
+auto dst_f = reinterpret_cast<ov::bfloat16*>(dst);
+
+for (size_t j = 0; j < height; j++) {
+attn_copy<ov::bfloat16, float>(dst_f, src_f, width);
+dst_f += dst_stride;
+src_f += src_stride;
+}
+} else if (src_type == ov::element::f32 && dst_type == ov::element::f16) {
+auto src_f = reinterpret_cast<float*>(src);
+auto dst_f = reinterpret_cast<ov::float16*>(dst);
+
+for (size_t j = 0; j < height; j++) {
+attn_copy<ov::float16, float>(dst_f, src_f, width);
+dst_f += dst_stride;
+src_f += src_stride;
+}
+} else {
+OPENVINO_THROW("unsupported src type: ", src_type, ", dst type: ", dst_type, " in attn_memcpy2d_kernel");
+}
+}

} // namespace XARCH
} // namespace Cpu
} // namespace Extensions
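
For orientation, a minimal hypothetical call of the new kernel (the include path, shapes, and strides here are invented for the example; the commit's real call sites are in scaled_attn.cpp below). Both strides are in elements, which lets the kernel walk two differently laid-out 2D regions:

```cpp
#include <vector>
#include "nodes/kernels/scaled_attn/attn_memcpy.hpp"  // declares attn_memcpy2d_kernel

void convert_tile_example() {
    std::vector<float> src(4 * 128, 1.0f);       // packed source: row stride 128
    std::vector<ov::bfloat16> dst(4 * 8 * 128);  // destination rows 8 * 128 apart (8 heads)
    ov::Extensions::Cpu::XARCH::attn_memcpy2d_kernel(
        src.data(), dst.data(),
        ov::element::f32, ov::element::bf16,
        /*src_stride=*/128, /*dst_stride=*/8 * 128,
        /*width=*/128, /*height=*/4);
}
```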
src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.hpp
@@ -16,16 +16,25 @@ namespace Cpu {
namespace XARCH {

void attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
-const ov::intel_cpu::PlainTensor& v_input,
-const ov::intel_cpu::PlainTensor& past_k_output,
-const ov::intel_cpu::PlainTensor& past_v_output);
+const ov::intel_cpu::PlainTensor& v_input,
+const ov::intel_cpu::PlainTensor& past_k_output,
+const ov::intel_cpu::PlainTensor& past_v_output);

void paged_attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
const ov::intel_cpu::PlainTensor& v_input,
const ov::intel_cpu::PlainTensor& past_k_output,
const ov::intel_cpu::PlainTensor& past_v_output,
const ov::intel_cpu::PlainTensor& slot_mapping);

+void attn_memcpy2d_kernel(void* src,
+void* dst,
+ov::element::Type src_type,
+ov::element::Type dst_type,
+size_t src_stride,
+size_t dst_stride,
+size_t width,
+size_t height);

} // namespace XARCH
} // namespace Cpu
} // namespace Extensions
28 changes: 16 additions & 12 deletions src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
@@ -477,19 +477,23 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
wv_scratch_a ? &wv_scratch_a.at<T>({tid, 0}) : nullptr);
if (is_bf16) {
if (has_out_transpose) {
-for (size_t m = m_start; m < m_end; m++) {
-cpu_convert(&fp32_out.at<float>({b, m, h, 0}),
-&output_emb.at<T>({b, m, h * head_size}),
-ov::element::f32,
-ov::element::bf16,
-head_size);
-}
+attn_memcpy2d_kernel(&fp32_out.at<float>({b, m_start, h, 0}),
+&output_emb.at<T>({b, m_start, h * head_size}),
+ov::element::f32,
+ov::element::bf16,
+fp32_out.stride(1),
+output_emb.stride(1),
+head_size,
+m_cnt);
} else {
-cpu_convert(&fp32_out.at<float>({b, h, m_start, 0}),
-&output_emb.at<T>({b, h, m_start, 0}),
-ov::element::f32,
-ov::element::bf16,
-m_cnt * head_size);
+attn_memcpy2d_kernel(&fp32_out.at<float>({b, h, m_start, 0}),
+&output_emb.at<T>({b, h, m_start, 0}),
+ov::element::f32,
+ov::element::bf16,
+0,
+0,
+m_cnt * head_size,
+1);
}
}
});
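
A note on the two branches above: in the transposed case, one head's bf16 rows are interleaved with the other heads, so the kernel receives real row strides (fp32_out.stride(1) and output_emb.stride(1)) and height m_cnt; in the non-transposed case, source and destination are both contiguous over the whole m_cnt * head_size range, so the region is passed as a single row (height 1) and the strides, never advanced past the first row, can simply be 0.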