[flash attention] fix bugs for attention mask (#2987)
Valentine233 authored Jun 19, 2024
1 parent 52f8c48 commit 620a9bf
Showing 2 changed files with 287 additions and 103 deletions.
266 changes: 181 additions & 85 deletions csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp
@@ -283,76 +283,52 @@ inline Vectorized<float> exp_u20(Vectorized<float> data) {
#endif

// out = val * a + b
template <typename T1, typename T2>
// is_b_stride_zero: If the stride of b is 0 (mask broadcasting case),
// take b as a scalar pointer.
template <bool is_b_stride_zero, typename T1, typename T2>
inline void _scale_attn_mask_fusion_kernel(
T1* a,
T2* b,
const int& size,
T1* out,
T1& val) {
auto vec_size = at::vec::Vectorized<T1>::size();
auto vec_scale = at::vec::Vectorized<T1>(val);
for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
auto tmp0 = at::vec::Vectorized<T1>::loadu(a + i);
auto tmp1 = at::vec::Vectorized<T2>::loadu(b + i);
auto tmp2 = at::vec::convert<T1>(tmp1);
auto tmp3 = tmp0 * vec_scale + tmp2;
_store(out + i, tmp3);
}
for (long i = vec_size * (size / vec_size); i < size; i++) {
auto tmp0 = a[i];
auto tmp1 = (T1)b[i];
out[i] = tmp0 * val + tmp1;
}
}

// out = val * a + b
template <typename T1>
inline void _scale_attn_mask_fusion_kernel(
T1* a,
T1* b,
const int& size,
T1* out,
T1& val) {
auto vec_size = at::vec::Vectorized<T1>::size();
auto vec_scale = at::vec::Vectorized<T1>(val);
for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
auto tmp0 = at::vec::Vectorized<T1>::loadu(a + i);
auto tmp1 = at::vec::Vectorized<T1>::loadu(b + i);
auto tmp2 = tmp0 * vec_scale + tmp1;
_store(out + i, tmp2);
}
for (long i = vec_size * (size / vec_size); i < size; i++) {
auto tmp0 = a[i];
auto tmp1 = b[i];
out[i] = tmp0 * val + tmp1;
}
}

// out = b ? val * a : -inf
template <typename T1>
inline void _scale_attn_mask_fusion_kernel(
T1* a,
bool* b,
const int& size,
T1* out,
T1& val) {
auto vec_size = at::vec::Vectorized<T1>::size();
auto vec_scale = at::vec::Vectorized<T1>(val);
auto neg_inf = -std::numeric_limits<T1>::infinity();
auto vec_neg_inf = at::vec::Vectorized<T1>(neg_inf);
for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
auto tmp0 = at::vec::Vectorized<T1>::loadu(a + i);
auto tmp1 = at::vec::Vectorized<bool>::loadu(b + i);
auto tmp2 = at::vec::convert<T1>(tmp1);
auto tmp3 =
at::vec::Vectorized<T1>::blendv(vec_neg_inf, tmp0 * vec_scale, tmp2);
_store(out + i, tmp3);
}
for (long i = vec_size * (size / vec_size); i < size; i++) {
auto tmp0 = a[i];
auto tmp1 = b[i];
out[i] = tmp1 ? tmp0 * val : neg_inf;
const auto vec_size1 = at::vec::Vectorized<T1>::size();
const auto vec_size2 = at::vec::Vectorized<T2>::size();
constexpr int64_t T1_n =
(vec_size2 == vec_size1 * 2 && is_reduced_floating_point_v<T2>) ? 2 : 1;
constexpr int64_t T2_n = 1;
auto vec_scale = at::vec::VectorizedN<T1, T1_n>(val);
int64_t i = 0;
if (is_b_stride_zero) {
auto b_first_val = (T1)b[0];
auto b_first_vec = at::vec::VectorizedN<T2, T2_n>(b_first_val);
for (; i < size - (size % vec_size2); i += vec_size2) {
auto a_n = at::vec::VectorizedN<T1, T1_n>::loadu(a + i);
auto b_n = b_first_vec;
at::vec::VectorizedN<T1, T1_n> b_n_convert =
at::vec::convert<T1, T1_n, T2, T2_n, true>(b_n);
auto res = a_n * vec_scale + b_n_convert;
res.store(out + i);
}
for (; i < size; i++) {
auto tmp0 = a[i];
auto tmp1 = b_first_val;
out[i] = tmp0 * val + tmp1;
}
} else {
for (; i < size - (size % vec_size2); i += vec_size2) {
auto a_n = at::vec::VectorizedN<T1, T1_n>::loadu(a + i);
auto b_n = at::vec::VectorizedN<T2, T2_n>::loadu(b + i);
at::vec::VectorizedN<T1, T1_n> b_n_convert =
at::vec::convert<T1, T1_n, T2, T2_n, true>(b_n);
auto res = a_n * vec_scale + b_n_convert;
res.store(out + i);
}
for (; i < size; i++) {
auto tmp0 = a[i];
auto tmp1 = (T1)b[i];
out[i] = tmp0 * val + tmp1;
}
}
}
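
// --- Illustrative sketch (editorial note; not part of this commit). ---
// A minimal scalar reference, assuming float data, for what the vectorized
// kernel above computes. The name scale_attn_mask_reference and its plain
// pointer arguments are hypothetical; they only mirror a/b/out/val above.
// When is_b_stride_zero is true, b is read as a single broadcast scalar.
#include <cstdint>

template <bool is_b_stride_zero>
void scale_attn_mask_reference(
    const float* a,
    const float* b,
    int64_t size,
    float* out,
    float val) {
  for (int64_t i = 0; i < size; ++i) {
    // out = val * a + b, with b broadcast along the row when its stride is 0
    const float b_i = is_b_stride_zero ? b[0] : b[i];
    out[i] = a[i] * val + b_i;
  }
}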

@@ -425,6 +401,82 @@ inline void _mul_reduce_max_fusion_kernel(
vec_tmp_max));
}

// This function is used to produce an attn_mask in a standard format
inline std::optional<at::Tensor> convert_boolean_attn_mask(
const std::optional<at::Tensor>& attn_mask,
caffe2::TypeMeta dtype) {
// Pass through
if (!attn_mask.has_value()) {
return c10::nullopt;
}
// Convert boolean mask to additive mask
if (attn_mask->dtype() == at::kBool) {
auto new_attn_mask = at::zeros_like(attn_mask.value(), dtype);
new_attn_mask.masked_fill_(
attn_mask->logical_not(), -std::numeric_limits<double>::infinity());
return new_attn_mask;
}
// Otherwise, attn_mask represents an additive attention tensor
return attn_mask;
}
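
// --- Illustrative sketch (editorial note; not part of this commit). ---
// Scalar reference, assuming float output, for the conversion done by
// convert_boolean_attn_mask above: entries where the boolean mask is false
// become -inf (fully masked), entries where it is true become 0 (additive
// no-op). The helper name boolean_to_additive_mask_reference is hypothetical.
#include <cstdint>
#include <limits>

void boolean_to_additive_mask_reference(
    const bool* mask,
    int64_t size,
    float* additive_mask) {
  const float neg_inf = -std::numeric_limits<float>::infinity();
  for (int64_t i = 0; i < size; ++i) {
    additive_mask[i] = mask[i] ? 0.0f : neg_inf;
  }
}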

// Support mask shapes:
// 2d: ({Q_seq_len, 1} x {KV_seq_len, 1})
// 4d: ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1})
inline bool check_attn_mask_shape(
at::Tensor& attn_mask,
int64_t batchSize,
int64_t num_head,
int64_t qSize,
int64_t kvSize) {
if (attn_mask.size(-2) != qSize && attn_mask.size(-2) != 1) {
return false;
}
if (attn_mask.size(-1) != kvSize && attn_mask.size(-1) != 1) {
return false;
}
if (attn_mask.dim() == 2) {
return true;
} else if (attn_mask.dim() == 4) {
if ((attn_mask.size(0) == 1 || attn_mask.size(0) == batchSize) &&
(attn_mask.size(1) == 1 || attn_mask.size(1) == num_head)) {
return true;
}
}
return false;
}

// Reshape attention mask to 4d
inline void reshape_attn_mask_to_4d(
at::Tensor& attn_mask,
int64_t batchSize,
int64_t num_head,
int64_t qSize,
int64_t kvSize) {
TORCH_CHECK(
check_attn_mask_shape(attn_mask, batchSize, num_head, qSize, kvSize),
"IPEX flash_attention: Please use the following attn mask shapes: ",
"2d - ({Q_seq_len, 1} x {KV_seq_len, 1}); ",
"4d - ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1})");
int64_t attn_mask_size_0 = 1;
int64_t attn_mask_size_1 = 1;
if (attn_mask.dim() == 4) {
if (attn_mask.size(0) == batchSize) {
attn_mask_size_0 = batchSize;
}
if (attn_mask.size(1) == num_head) {
attn_mask_size_1 = num_head;
}
}
attn_mask = attn_mask
.view(
{attn_mask_size_0,
attn_mask_size_1,
attn_mask.size(-2),
attn_mask.size(-1)})
.expand({attn_mask_size_0, attn_mask_size_1, qSize, kvSize});
}
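
// --- Illustrative sketch (editorial note; not part of this commit). ---
// After reshape_attn_mask_to_4d, any mask dimension that was 1 is expanded
// without copying, so its stride becomes 0. A hypothetical helper showing how
// a broadcast-aware offset into the 4d mask can be computed; the strideB/H/M/N
// names mirror the mStrideB/mStrideH/mStrideM/mStrideN variables in the diff.
#include <cstdint>

int64_t mask_offset(
    int64_t b, int64_t h, int64_t m, int64_t n,
    int64_t strideB, int64_t strideH, int64_t strideM, int64_t strideN) {
  // A stride of 0 makes the corresponding index a no-op, so the same mask
  // value is reused (broadcast) along that dimension.
  return b * strideB + h * strideH + m * strideM + n * strideN;
}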

/*
*Calculate the flash attention SDPA.
*@template scalar_t: q/k/v data type
@@ -480,6 +532,12 @@ cpu_flash_attention(
int64_t num_head = query.size(2);
int64_t headSize = query.size(3);

// reshape mask
if (attention_mask.has_value()) {
reshape_attn_mask_to_4d(
attention_mask.value(), batchSize, num_head, qSize, kvSize);
}

// Strides
int64_t qStrideB = query.stride(0);
int64_t qStrideM = query.stride(1);
@@ -505,7 +563,13 @@ cpu_flash_attention(
? attention_mask.value().stride(1)
: 0;
int64_t mStrideM =
attention_mask.has_value() ? attention_mask.value().stride(2) : 0;
(attention_mask.has_value() && attention_mask.value().size(2) > 1)
? attention_mask.value().stride(2)
: 0;
int64_t mStrideN =
(attention_mask.has_value() && attention_mask.value().size(3) > 1)
? attention_mask.value().stride(3)
: 0;

int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
@@ -596,15 +660,24 @@ cpu_flash_attention(
// And apply scaling factor
if (attention_mask.has_value()) {
for (int64_t row = 0; row < qBlockSize; ++row) {
// qk <- attn_mask ? qk : -inf, if attn_mask is bool
// qk <- qk + attn_mask, else
_scale_attn_mask_fusion_kernel(
qk_data + row * kvBlockSize,
mask_data + i * mStrideB + j * mStrideH +
(m + row) * mStrideM + n,
kvBlockSize,
qk_data + row * kvBlockSize,
scaling_factor);
// qk <- qk * scaling_factor + attn_mask, else
if (mStrideN == 0) {
_scale_attn_mask_fusion_kernel</*is_stride_zero*/ true>(
qk_data + row * kvBlockSize,
mask_data + i * mStrideB + j * mStrideH +
(m + row) * mStrideM,
kvBlockSize,
qk_data + row * kvBlockSize,
scaling_factor);
} else {
_scale_attn_mask_fusion_kernel</*is_stride_zero*/ false>(
qk_data + row * kvBlockSize,
mask_data + i * mStrideB + j * mStrideH +
(m + row) * mStrideM + n,
kvBlockSize,
qk_data + row * kvBlockSize,
scaling_factor);
}
}
}
// Update coefficients with Softmax
@@ -737,6 +810,12 @@ cpu_flash_attention(
int64_t num_head = query.size(2);
int64_t headSize = query.size(3);

// reshape mask
if (attention_mask.has_value()) {
reshape_attn_mask_to_4d(
attention_mask.value(), batchSize, num_head, qSize, kvSize);
}

// Strides
int64_t qStrideB = query.stride(0);
int64_t qStrideM = query.stride(1);
@@ -762,7 +841,13 @@ cpu_flash_attention(
? attention_mask.value().stride(1)
: 0;
int64_t mStrideM =
attention_mask.has_value() ? attention_mask.value().stride(2) : 0;
(attention_mask.has_value() && attention_mask.value().size(2) > 1)
? attention_mask.value().stride(2)
: 0;
int64_t mStrideN =
(attention_mask.has_value() && attention_mask.value().size(3) > 1)
? attention_mask.value().stride(3)
: 0;

int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
@@ -1241,15 +1326,24 @@ cpu_flash_attention(
// And apply scaling factor
if (attention_mask.has_value()) {
for (int64_t row = 0; row < qBlockSize; ++row) {
// qk <- attn_mask ? qk : -inf, if attn_mask is bool
// qk <- qk + attn_mask, else
_scale_attn_mask_fusion_kernel(
qk_data + row * kvBlockSize,
mask_data + i * mStrideB + j * mStrideH +
(m + row) * mStrideM + n,
kvBlockSize,
qk_data + row * kvBlockSize,
scaling_factor);
// qk <- qk * scaling_factor + attn_mask, else
if (mStrideN == 0) {
_scale_attn_mask_fusion_kernel</*is_stride_zero*/ true>(
qk_data + row * kvBlockSize,
mask_data + i * mStrideB + j * mStrideH +
(m + row) * mStrideM,
kvBlockSize,
qk_data + row * kvBlockSize,
scaling_factor);
} else {
_scale_attn_mask_fusion_kernel</*is_stride_zero*/ false>(
qk_data + row * kvBlockSize,
mask_data + i * mStrideB + j * mStrideH +
(m + row) * mStrideM + n,
kvBlockSize,
qk_data + row * kvBlockSize,
scaling_factor);
}
}
}
// Update coefficients with Softmax
@@ -1558,6 +1652,8 @@ std::tuple<at::Tensor, at::Tensor> flash_attention_kernel(
attention_mask.value().stride(-1) == 1),
"IPEX flash_attention: Q/K/V/Mask should be continuous on the last dim");

std::optional<at::Tensor> attn_mask =
convert_boolean_attn_mask(attention_mask, query.dtype());
at::Tensor output =
at::empty({batchSize, qSize, num_head, headSize}, query.options());
const auto accumulate_dtype = at::toOpMathType(dtype);
@@ -1572,7 +1668,7 @@ std::tuple<at::Tensor, at::Tensor> flash_attention_kernel(
value,
dropout_p,
is_causal,
attention_mask,
attn_mask,
scale);

output = output.transpose(1, 2);