port flash attention from pt to ipex (#2317)
Valentine233 authored Dec 6, 2023
1 parent 799da9f commit 8d0426c
Showing 8 changed files with 930 additions and 255 deletions.
93 changes: 71 additions & 22 deletions csrc/cpu/aten/FlashAttention.cpp
@@ -6,38 +6,87 @@ namespace torch_ipex {
namespace cpu {

DEFINE_DISPATCH(flash_attention_kernel_stub);
+DEFINE_DISPATCH(flash_attention_mask_kernel_stub);

/*
- *Caculate the flash attention SDPA.
- *@param query
- *@param key
- *@param value
- *@param scale_attn
- *@param attention_mask
- *@return attn_outs
+ *Calculate the flash attention SDPA and substitute the PT one,
+ *in order to add optimizations which are hard to upstream, like TPP layout
+ *conversion.
 */
-at::Tensor flash_attention_forward_cpu(
-    at::Tensor query,
-    at::Tensor key,
-    at::Tensor value,
-    const double scale_attn,
-    at::Tensor attention_mask) {
+std::tuple<
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    c10::SymInt,
+    c10::SymInt,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor>
+flash_attention_forward_cpu(
+    const at::Tensor& query,
+    const at::Tensor& key,
+    const at::Tensor& value,
+    double dropout_p,
+    bool is_causal,
+    bool return_debug_mask,
+    c10::optional<double> scale) {
  return flash_attention_kernel_stub(
-      kCPU, query, key, value, scale_attn, attention_mask);
+      kCPU, query, key, value, dropout_p, is_causal, return_debug_mask, scale);
}

-} // namespace cpu
-} // namespace torch_ipex
+/*
+ *Calculate the flash attention SDPA with attention mask.
+ */
+std::tuple<
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    c10::SymInt,
+    c10::SymInt,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor>
+flash_attention_mask_forward_cpu(
+    const at::Tensor& query,
+    const at::Tensor& key,
+    const at::Tensor& value,
+    double dropout_p,
+    bool is_causal,
+    bool return_debug_mask,
+    c10::optional<at::Tensor> attention_mask,
+    c10::optional<double> scale) {
+  return flash_attention_mask_kernel_stub(
+      kCPU,
+      query,
+      key,
+      value,
+      dropout_p,
+      is_causal,
+      return_debug_mask,
+      attention_mask,
+      scale);
+}
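
Both forward entry points now return the nine-element tuple of the upstream aten::_scaled_dot_product_flash_attention schema instead of the single output tensor of the previous IPEX-only op. A minimal caller-side sketch against plain ATen, assuming a PyTorch build whose at::_scaled_dot_product_flash_attention matches the signature being overridden here; the binding names follow the upstream schema and are not spelled out in this diff. On CPU the call reaches flash_attention_forward_cpu once the override below is registered, and the stock kernel otherwise.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // [batch, num_heads, seq_len, head_dim] layout, as SDPA expects.
  auto query = at::randn({1, 16, 128, 64});
  auto key = at::randn({1, 16, 128, 64});
  auto value = at::randn({1, 16, 128, 64});

  auto [output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k,
        philox_seed, philox_offset, debug_attn_mask] =
      at::_scaled_dot_product_flash_attention(
          query, key, value,
          /*dropout_p=*/0.0, /*is_causal=*/true, /*return_debug_mask=*/false);

  // output carries the attention result; several of the other entries are
  // only meaningful for the CUDA/dropout paths.
  std::cout << output.sizes() << std::endl;
  return 0;
}
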

namespace {
+TORCH_LIBRARY_IMPL(aten, CPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("aten::_scaled_dot_product_flash_attention"),
+      TORCH_FN((&torch_ipex::cpu::flash_attention_forward_cpu)));
+}
+
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
m.def(
"flash_attention(Tensor query, Tensor key, Tensor value, \
float scale_attn, Tensor attention_mask)-> Tensor");
"flash_attention_mask(Tensor query, Tensor key, Tensor value, \
float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, \
*, Tensor? attention_mask=None, float? scale=None) -> \
(Tensor, Tensor, Tensor, Tensor, SymInt, SymInt, \
Tensor, Tensor, Tensor)");
m.impl(
"flash_attention",
"flash_attention_mask",
c10::DispatchKey::CPU,
torch_ipex::cpu::flash_attention_forward_cpu);
torch_ipex::cpu::flash_attention_mask_forward_cpu);
}
} // namespace

+} // namespace cpu
+} // namespace torch_ipex
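
The TORCH_LIBRARY_FRAGMENT block additionally exposes the masked variant as a custom op, torch_ipex::flash_attention_mask. A hedged sketch of reaching it from C++ through the dispatcher, assuming the IPEX CPU library is already loaded so the registration above has run; the tensor shapes and the zero additive mask are illustrative only.

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <iostream>
#include <tuple>

int main() {
  auto query = at::randn({1, 16, 128, 64});
  auto key = at::randn({1, 16, 128, 64});
  auto value = at::randn({1, 16, 128, 64});
  // Additive mask broadcastable against the [..., seq_q, seq_k] score matrix.
  auto mask = at::zeros({1, 1, 128, 128});

  // Look up the schema registered by the TORCH_LIBRARY_FRAGMENT block above;
  // the typed<> signature mirrors flash_attention_mask_forward_cpu.
  auto op = c10::Dispatcher::singleton()
                .findSchemaOrThrow("torch_ipex::flash_attention_mask", "")
                .typed<std::tuple<
                    at::Tensor, at::Tensor, at::Tensor, at::Tensor,
                    c10::SymInt, c10::SymInt,
                    at::Tensor, at::Tensor, at::Tensor>(
                    const at::Tensor&, const at::Tensor&, const at::Tensor&,
                    double, bool, bool,
                    c10::optional<at::Tensor>, c10::optional<double>)>();

  auto outs = op.call(query, key, value,
                      /*dropout_p=*/0.0, /*is_causal=*/false,
                      /*return_debug_mask=*/false,
                      /*attention_mask=*/mask, /*scale=*/c10::nullopt);
  std::cout << std::get<0>(outs).sizes() << std::endl;
  return 0;
}
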
90 changes: 78 additions & 12 deletions csrc/cpu/aten/FlashAttention.h
@@ -8,22 +8,88 @@ namespace cpu {

namespace {

-at::Tensor flash_attention(
-    at::Tensor query,
-    at::Tensor key,
-    at::Tensor value,
-    const double scale_attn,
-    at::Tensor attention_mask);
+std::tuple<
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    c10::SymInt,
+    c10::SymInt,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor>
+flash_attention(
+    const at::Tensor& query,
+    const at::Tensor& key,
+    const at::Tensor& value,
+    double dropout_p,
+    bool is_causal,
+    bool return_debug_mask,
+    c10::optional<double> scale);
+
+std::tuple<
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    c10::SymInt,
+    c10::SymInt,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor>
+flash_attention_mask(
+    const at::Tensor& query,
+    const at::Tensor& key,
+    const at::Tensor& value,
+    double dropout_p,
+    bool is_causal,
+    bool return_debug_mask,
+    c10::optional<at::Tensor> attention_mask,
+    c10::optional<double> scale);
}

-using flash_attention_kernel_fn = at::Tensor (*)(
-    at::Tensor query,
-    at::Tensor key,
-    at::Tensor value,
-    const double scale_attn,
-    at::Tensor attention_mask);
+using flash_attention_kernel_fn = std::tuple<
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    c10::SymInt,
+    c10::SymInt,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor> (*)(
+    const at::Tensor& query,
+    const at::Tensor& key,
+    const at::Tensor& value,
+    double dropout_p,
+    bool is_causal,
+    bool return_debug_mask,
+    c10::optional<double> scale);

DECLARE_DISPATCH(flash_attention_kernel_fn, flash_attention_kernel_stub);

+using flash_attention_mask_kernel_fn = std::tuple<
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    c10::SymInt,
+    c10::SymInt,
+    at::Tensor,
+    at::Tensor,
+    at::Tensor> (*)(
+    const at::Tensor& query,
+    const at::Tensor& key,
+    const at::Tensor& value,
+    double dropout_p,
+    bool is_causal,
+    bool return_debug_mask,
+    c10::optional<at::Tensor> attention_mask,
+    c10::optional<double> scale);
+
+DECLARE_DISPATCH(
+    flash_attention_mask_kernel_fn,
+    flash_attention_mask_kernel_stub);
+
} // namespace cpu
} // namespace torch_ipex
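
The DECLARE_DISPATCH/DEFINE_DISPATCH pairs above follow ATen's DispatchStub pattern, so the actual kernels (in the kernel files of this commit, not shown here) plug in by registering an implementation against each stub. The sketch below is illustrative only: the impl name, the naive math standing in for the real flash-attention kernel, the include path, and the use of ATen's REGISTER_DISPATCH macro (IPEX may use its own wrapper for this) are all assumptions, not taken from this diff.

#include <ATen/ATen.h>
#include <cmath>
#include <limits>
#include "FlashAttention.h"  // the header shown above; actual include path depends on the build

namespace torch_ipex {
namespace cpu {
namespace {

// Hypothetical kernel body: plain softmax(QK^T * scale)V as a stand-in for
// the blocked flash-attention math; dropout and the debug mask are ignored.
std::tuple<
    at::Tensor, at::Tensor, at::Tensor, at::Tensor,
    c10::SymInt, c10::SymInt,
    at::Tensor, at::Tensor, at::Tensor>
flash_attention_kernel_impl(
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    c10::optional<double> scale) {
  (void)dropout_p;
  (void)return_debug_mask;
  double scaling = scale.has_value()
      ? scale.value()
      : 1.0 / std::sqrt(static_cast<double>(query.size(-1)));
  auto scores = at::matmul(query, key.transpose(-2, -1)) * scaling;
  if (is_causal) {
    auto causal = at::ones({query.size(-2), key.size(-2)},
                           query.options().dtype(at::kBool))
                      .tril();
    scores = scores.masked_fill(causal.logical_not(),
                                -std::numeric_limits<double>::infinity());
  }
  auto logsumexp = at::logsumexp(scores, {-1});
  auto output = at::matmul(at::softmax(scores, -1), value);
  return std::make_tuple(
      output, logsumexp, at::Tensor(), at::Tensor(),
      c10::SymInt(query.size(-2)), c10::SymInt(key.size(-2)),
      at::Tensor(), at::Tensor(), at::Tensor());
}

} // namespace

// Ties the implementation to flash_attention_kernel_stub declared above;
// the registration macro name is an assumption (ATen spells it this way).
REGISTER_DISPATCH(flash_attention_kernel_stub, &flash_attention_kernel_impl);

} // namespace cpu
} // namespace torch_ipex
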
(Diffs for the remaining 6 changed files are not shown.)
