Reset q block size for FP32 in flash attention (#2392)
Valentine233 authored Dec 22, 2023
1 parent 41cc950 commit 5ed3a24
Showing 1 changed file with 89 additions and 29 deletions.
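In short, the commit keeps the 32-wide q block for BF16 but lets FP32 pick a larger q block for longer query sequences (256 when q_seq_len >= 768, 64 when q_seq_len >= 192, otherwise 32), with the kv block size staying at 512. Below is a minimal standalone sketch of that selection logic; the helper name select_q_split_size is illustrative only and not part of the commit, where the block sizes are compile-time template parameters of cpu_flash_attention.

#include <cstdint>

// Illustrative helper only: in the actual kernel the block sizes are
// template parameters of cpu_flash_attention, so the selection happens
// in the if/else chain inside the dispatch lambda shown in the diff below.
inline int64_t select_q_split_size(bool is_bf16, int64_t q_seq_len) {
  if (is_bf16) {
    return 32; // BF16 keeps the original q block size
  }
  if (q_seq_len >= 768) {
    return 256; // long FP32 query sequences use a larger q block
  }
  if (q_seq_len >= 192) {
    return 64; // medium FP32 query sequences
  }
  return 32; // short FP32 query sequences keep the 32-wide q block
}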
118 changes: 89 additions & 29 deletions csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp
@@ -1,3 +1,11 @@
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/cpu/utils.h>

#include <ATen/Tensor.h>
#include <aten/FlashAttention.h>
#include <torch/all.h>
@@ -7,13 +15,6 @@
#include "mkl.h"
#include "vec/vec.h"

#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/cpu/utils.h>

inline void _mkl_gemm(
const CBLAS_LAYOUT layout,
const CBLAS_TRANSPOSE transa,
@@ -777,24 +778,85 @@ void flash_attention_kernel_impl(

AT_DISPATCH_FLOATING_TYPES_AND(
kBFloat16, query.scalar_type(), "flash_attention", [&] {
cpu_flash_attention<scalar_t, 32, 512>(
output,
logsumexp,
cum_seq_q,
cum_seq_k,
max_q,
max_k,
philox_seed,
philox_offset,
debug_attn_mask,
query,
key,
value,
dropout_p,
is_causal,
return_debug_mask,
attention_mask,
scale);
if (query.scalar_type() == kBFloat16) {
cpu_flash_attention<scalar_t, 32, 512>(
output,
logsumexp,
cum_seq_q,
cum_seq_k,
max_q,
max_k,
philox_seed,
philox_offset,
debug_attn_mask,
query,
key,
value,
dropout_p,
is_causal,
return_debug_mask,
attention_mask,
scale);
} else {
if (q_seq_len >= 768) {
cpu_flash_attention<scalar_t, 256, 512>(
output,
logsumexp,
cum_seq_q,
cum_seq_k,
max_q,
max_k,
philox_seed,
philox_offset,
debug_attn_mask,
query,
key,
value,
dropout_p,
is_causal,
return_debug_mask,
attention_mask,
scale);
} else if (q_seq_len >= 192) {
cpu_flash_attention<scalar_t, 64, 512>(
output,
logsumexp,
cum_seq_q,
cum_seq_k,
max_q,
max_k,
philox_seed,
philox_offset,
debug_attn_mask,
query,
key,
value,
dropout_p,
is_causal,
return_debug_mask,
attention_mask,
scale);
} else {
cpu_flash_attention<scalar_t, 32, 512>(
output,
logsumexp,
cum_seq_q,
cum_seq_k,
max_q,
max_k,
philox_seed,
philox_offset,
debug_attn_mask,
query,
key,
value,
dropout_p,
is_causal,
return_debug_mask,
attention_mask,
scale);
}
}
});
}

@@ -967,10 +1029,8 @@ flash_attention_mask_kernel(
}
} // anonymous namespace

ALSO_REGISTER_AVX512_DISPATCH(
flash_attention_kernel_stub,
&flash_attention_kernel);
ALSO_REGISTER_AVX512_DISPATCH(
REGISTER_DISPATCH(flash_attention_kernel_stub, &flash_attention_kernel);
REGISTER_DISPATCH(
flash_attention_mask_kernel_stub,
&flash_attention_mask_kernel);
