From 620a9bfd9db42813931a857e78fa3f5d298be200 Mon Sep 17 00:00:00 2001 From: Xuan Liao Date: Wed, 19 Jun 2024 08:04:31 +0800 Subject: [PATCH] [flash attention] fix bugs for attention mask (#2987) --- csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp | 266 +++++++++++++------ tests/cpu/test_cpu_ops.py | 124 +++++++-- 2 files changed, 287 insertions(+), 103 deletions(-) diff --git a/csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp b/csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp index d409c6667..d3252f77f 100644 --- a/csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp +++ b/csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp @@ -283,76 +283,52 @@ inline Vectorized exp_u20(Vectorized data) { #endif // out = val * a + b -template +// is_b_stride_zero: If the stride of b is 0 (mask broadcasting case), +// take b as a scalar pointer. +template inline void _scale_attn_mask_fusion_kernel( T1* a, T2* b, const int& size, T1* out, T1& val) { - auto vec_size = at::vec::Vectorized::size(); - auto vec_scale = at::vec::Vectorized(val); - for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(a + i); - auto tmp1 = at::vec::Vectorized::loadu(b + i); - auto tmp2 = at::vec::convert(tmp1); - auto tmp3 = tmp0 * vec_scale + tmp2; - _store(out + i, tmp3); - } - for (long i = vec_size * (size / vec_size); i < size; i++) { - auto tmp0 = a[i]; - auto tmp1 = (T1)b[i]; - out[i] = tmp0 * val + tmp1; - } -} - -// out = val * a + b -template -inline void _scale_attn_mask_fusion_kernel( - T1* a, - T1* b, - const int& size, - T1* out, - T1& val) { - auto vec_size = at::vec::Vectorized::size(); - auto vec_scale = at::vec::Vectorized(val); - for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(a + i); - auto tmp1 = at::vec::Vectorized::loadu(b + i); - auto tmp2 = tmp0 * vec_scale + tmp1; - _store(out + i, tmp2); - } - for (long i = vec_size * (size / vec_size); i < size; i++) { - auto tmp0 = a[i]; - auto tmp1 = b[i]; - out[i] = tmp0 * val + tmp1; - } -} - -// out = b ? val * a : -inf -template -inline void _scale_attn_mask_fusion_kernel( - T1* a, - bool* b, - const int& size, - T1* out, - T1& val) { - auto vec_size = at::vec::Vectorized::size(); - auto vec_scale = at::vec::Vectorized(val); - auto neg_inf = -std::numeric_limits::infinity(); - auto vec_neg_inf = at::vec::Vectorized(neg_inf); - for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(a + i); - auto tmp1 = at::vec::Vectorized::loadu(b + i); - auto tmp2 = at::vec::convert(tmp1); - auto tmp3 = - at::vec::Vectorized::blendv(vec_neg_inf, tmp0 * vec_scale, tmp2); - _store(out + i, tmp3); - } - for (long i = vec_size * (size / vec_size); i < size; i++) { - auto tmp0 = a[i]; - auto tmp1 = b[i]; - out[i] = tmp1 ? tmp0 * val : neg_inf; + const auto vec_size1 = at::vec::Vectorized::size(); + const auto vec_size2 = at::vec::Vectorized::size(); + constexpr int64_t T1_n = + (vec_size2 == vec_size1 * 2 && is_reduced_floating_point_v) ? 
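The three type-specialized overloads removed above are collapsed into a single template whose is_b_stride_zero parameter covers the broadcast case (mask stride 0 along the key dimension) by re-reading one mask value per row. A minimal PyTorch sketch of the scalar reference semantics this vectorized kernel implements, assuming a is a row of qk values in the accumulation dtype and b is the matching mask row (or a one-element tensor when broadcast); the function name is only for illustration:

import torch

def scale_attn_mask_fusion_ref(a: torch.Tensor, b: torch.Tensor,
                               val: float, is_b_stride_zero: bool) -> torch.Tensor:
    # out = a * val + b; the old boolean-mask branch is gone because boolean
    # masks are now converted to additive masks before reaching this kernel.
    if is_b_stride_zero:
        return a * val + b[0].to(a.dtype)   # one mask value broadcast across the row
    return a * val + b.to(a.dtype)          # elementwise additive mask
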
2 : 1; + constexpr int64_t T2_n = 1; + auto vec_scale = at::vec::VectorizedN(val); + int64_t i = 0; + if (is_b_stride_zero) { + auto b_first_val = (T1)b[0]; + auto b_first_vec = at::vec::VectorizedN(b_first_val); + for (; i < size - (size % vec_size2); i += vec_size2) { + auto a_n = at::vec::VectorizedN::loadu(a + i); + auto b_n = b_first_vec; + at::vec::VectorizedN b_n_convert = + at::vec::convert(b_n); + auto res = a_n * vec_scale + b_n_convert; + res.store(out + i); + } + for (; i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = b_first_val; + out[i] = tmp0 * val + tmp1; + } + } else { + for (; i < size - (size % vec_size2); i += vec_size2) { + auto a_n = at::vec::VectorizedN::loadu(a + i); + auto b_n = at::vec::VectorizedN::loadu(b + i); + at::vec::VectorizedN b_n_convert = + at::vec::convert(b_n); + auto res = a_n * vec_scale + b_n_convert; + res.store(out + i); + } + for (; i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = (T1)b[i]; + out[i] = tmp0 * val + tmp1; + } } } @@ -425,6 +401,82 @@ inline void _mul_reduce_max_fusion_kernel( vec_tmp_max)); } +// This function is used to produce an attn_mask in a standard format +inline std::optional convert_boolean_attn_mask( + const std::optional& attn_mask, + caffe2::TypeMeta dtype) { + // Pass through + if (!attn_mask.has_value()) { + return c10::nullopt; + } + // Convert boolean mask to additive mask + if (attn_mask->dtype() == at::kBool) { + auto new_attn_mask = at::zeros_like(attn_mask.value(), dtype); + new_attn_mask.masked_fill_( + attn_mask->logical_not(), -std::numeric_limits::infinity()); + return new_attn_mask; + } + // Otherwise, attn_mask represents an additive attention tensor + return attn_mask; +} + +// Support mask shapes: +// 2d: ({Q_seq_len, 1} x {KV_seq_len, 1}) +// 4d: ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1}) +inline bool check_attn_mask_shape( + at::Tensor& attn_mask, + int64_t batchSize, + int64_t num_head, + int64_t qSize, + int64_t kvSize) { + if (attn_mask.size(-2) != qSize && attn_mask.size(-2) != 1) { + return false; + } + if (attn_mask.size(-1) != kvSize && attn_mask.size(-1) != 1) { + return false; + } + if (attn_mask.dim() == 2) { + return true; + } else if (attn_mask.dim() == 4) { + if ((attn_mask.size(0) == 1 || attn_mask.size(0) == batchSize) && + (attn_mask.size(1) == 1 || attn_mask.size(1) == num_head)) { + return true; + } + } + return false; +} + +// Reshape attention mask to 4d +inline void reshape_attn_mask_to_4d( + at::Tensor& attn_mask, + int64_t batchSize, + int64_t num_head, + int64_t qSize, + int64_t kvSize) { + TORCH_CHECK( + check_attn_mask_shape(attn_mask, batchSize, num_head, qSize, kvSize), + "IPEX flash_attention: Please use the following attn mask shapes: ", + "2d - ({Q_seq_len, 1} x {KV_seq_len, 1}); ", + "4d - ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1})"); + int64_t attn_mask_size_0 = 1; + int64_t attn_mask_size_1 = 1; + if (attn_mask.dim() == 4) { + if (attn_mask.size(0) == batchSize) { + attn_mask_size_0 = batchSize; + } + if (attn_mask.size(1) == num_head) { + attn_mask_size_1 = num_head; + } + } + attn_mask = attn_mask + .view( + {attn_mask_size_0, + attn_mask_size_1, + attn_mask.size(-2), + attn_mask.size(-1)}) + .expand({attn_mask_size_0, attn_mask_size_1, qSize, kvSize}); +} + /* *Caculate the flash attention SDPA. 
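The helpers added above normalize the mask before the kernel runs: a boolean mask is turned into an additive one, and 2d or broadcastable 4d masks are viewed and expanded to the canonical [Batch, Num_heads, Q_seq_len, KV_seq_len] layout. A hedged Python equivalent of convert_boolean_attn_mask, for illustration only (the authoritative code is the C++ above):

import torch

def convert_boolean_attn_mask_ref(attn_mask, dtype):
    if attn_mask is None:
        return None                      # pass through
    if attn_mask.dtype == torch.bool:
        # masked positions become -inf, attended positions stay 0
        new_mask = torch.zeros_like(attn_mask, dtype=dtype)
        new_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
        return new_mask
    return attn_mask                     # already an additive mask
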
*@template scalar_t: q/k/v data type @@ -480,6 +532,12 @@ cpu_flash_attention( int64_t num_head = query.size(2); int64_t headSize = query.size(3); + // reshape mask + if (attention_mask.has_value()) { + reshape_attn_mask_to_4d( + attention_mask.value(), batchSize, num_head, qSize, kvSize); + } + // Strides int64_t qStrideB = query.stride(0); int64_t qStrideM = query.stride(1); @@ -505,7 +563,13 @@ cpu_flash_attention( ? attention_mask.value().stride(1) : 0; int64_t mStrideM = - attention_mask.has_value() ? attention_mask.value().stride(2) : 0; + (attention_mask.has_value() && attention_mask.value().size(2) > 1) + ? attention_mask.value().stride(2) + : 0; + int64_t mStrideN = + (attention_mask.has_value() && attention_mask.value().size(3) > 1) + ? attention_mask.value().stride(3) + : 0; int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size; int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size; @@ -596,15 +660,24 @@ cpu_flash_attention( // And apply scaling factor if (attention_mask.has_value()) { for (int64_t row = 0; row < qBlockSize; ++row) { - // qk <- attn_mask ? qk : -inf, if attn_mask is bool - // qk <- qk + attn_mask, else - _scale_attn_mask_fusion_kernel( - qk_data + row * kvBlockSize, - mask_data + i * mStrideB + j * mStrideH + - (m + row) * mStrideM + n, - kvBlockSize, - qk_data + row * kvBlockSize, - scaling_factor); + // qk <- qk * scaling_factor + attn_mask, else + if (mStrideN == 0) { + _scale_attn_mask_fusion_kernel( + qk_data + row * kvBlockSize, + mask_data + i * mStrideB + j * mStrideH + + (m + row) * mStrideM, + kvBlockSize, + qk_data + row * kvBlockSize, + scaling_factor); + } else { + _scale_attn_mask_fusion_kernel( + qk_data + row * kvBlockSize, + mask_data + i * mStrideB + j * mStrideH + + (m + row) * mStrideM + n, + kvBlockSize, + qk_data + row * kvBlockSize, + scaling_factor); + } } } // Update coefficients with Softmax @@ -737,6 +810,12 @@ cpu_flash_attention( int64_t num_head = query.size(2); int64_t headSize = query.size(3); + // reshape mask + if (attention_mask.has_value()) { + reshape_attn_mask_to_4d( + attention_mask.value(), batchSize, num_head, qSize, kvSize); + } + // Strides int64_t qStrideB = query.stride(0); int64_t qStrideM = query.stride(1); @@ -762,7 +841,13 @@ cpu_flash_attention( ? attention_mask.value().stride(1) : 0; int64_t mStrideM = - attention_mask.has_value() ? attention_mask.value().stride(2) : 0; + (attention_mask.has_value() && attention_mask.value().size(2) > 1) + ? attention_mask.value().stride(2) + : 0; + int64_t mStrideN = + (attention_mask.has_value() && attention_mask.value().size(3) > 1) + ? attention_mask.value().stride(3) + : 0; int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size; int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size; @@ -1241,15 +1326,24 @@ cpu_flash_attention( // And apply scaling factor if (attention_mask.has_value()) { for (int64_t row = 0; row < qBlockSize; ++row) { - // qk <- attn_mask ? 
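mStrideM and mStrideN are forced to 0 whenever the corresponding mask dimension has size 1, so a broadcast mask never advances along that dimension, and the call sites select the is_b_stride_zero kernel when mStrideN == 0. A small stock-PyTorch illustration of the same idea (tensor shapes and names are only for the example):

import torch

mask = torch.randn(2, 1, 129, 1)          # size-1 head and KV dims are broadcast
mask4d = mask.expand(2, 4, 129, 533)      # logical 4d view, no data copy
assert mask4d.stride(1) == 0 and mask4d.stride(3) == 0
# A zero stride along KV means every key position in a row reads the same mask
# element, which is exactly the scalar-broadcast path of the fused kernel.
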
qk : -inf, if attn_mask is bool - // qk <- qk + attn_mask, else - _scale_attn_mask_fusion_kernel( - qk_data + row * kvBlockSize, - mask_data + i * mStrideB + j * mStrideH + - (m + row) * mStrideM + n, - kvBlockSize, - qk_data + row * kvBlockSize, - scaling_factor); + // qk <- qk * scaling_factor + attn_mask, else + if (mStrideN == 0) { + _scale_attn_mask_fusion_kernel( + qk_data + row * kvBlockSize, + mask_data + i * mStrideB + j * mStrideH + + (m + row) * mStrideM, + kvBlockSize, + qk_data + row * kvBlockSize, + scaling_factor); + } else { + _scale_attn_mask_fusion_kernel( + qk_data + row * kvBlockSize, + mask_data + i * mStrideB + j * mStrideH + + (m + row) * mStrideM + n, + kvBlockSize, + qk_data + row * kvBlockSize, + scaling_factor); + } } } // Update coefficients with Softmax @@ -1558,6 +1652,8 @@ std::tuple flash_attention_kernel( attention_mask.value().stride(-1) == 1), "IPEX flash_attention: Q/K/V/Mask should be continuous on the last dim"); + std::optional attn_mask = + convert_boolean_attn_mask(attention_mask, query.dtype()); at::Tensor output = at::empty({batchSize, qSize, num_head, headSize}, query.options()); const auto accumulate_dtype = at::toOpMathType(dtype); @@ -1572,7 +1668,7 @@ std::tuple flash_attention_kernel( value, dropout_p, is_causal, - attention_mask, + attn_mask, scale); output = output.transpose(1, 2); diff --git a/tests/cpu/test_cpu_ops.py b/tests/cpu/test_cpu_ops.py index c9fcd0f0f..f8491f00c 100644 --- a/tests/cpu/test_cpu_ops.py +++ b/tests/cpu/test_cpu_ops.py @@ -1402,14 +1402,10 @@ def test_cat(self): self.assertTrue(y7.size() == torch.Size([8, 2])) self.assertTrue(y7.dtype == datatype) - def test_flash_attention(self): + def test_flash_attention_without_mask(self): dtypes = [torch.float, torch.double, torch.bfloat16, torch.float16] for dtype in dtypes: - for causal, has_attention_mask in [ - [False, False], - [True, False], - [False, True], - ]: + for causal in [True, False]: for batch_size, seq_len, n_head, head_dim in itertools.product( [2, 12], [1, 129, 267, 533, 1030], [1, 3, 4], [7, 8, 16] ): @@ -1447,30 +1443,18 @@ def test_flash_attention(self): q2 = q2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2) v2 = v2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2) - mask = ( - torch.randn( - (batch_size, 1, seq_len, seq_len), - device="cpu", - dtype=dtype, - requires_grad=False, - ) - if has_attention_mask - else None - ) actual = torch.ops.torch_ipex.flash_attention( q, k, v, dropout_p=0.0, is_causal=causal, - attention_mask=mask, )[0] math_ref = ( torch._scaled_dot_product_attention_math( q2, k2, v2, - attn_mask=mask, dropout_p=0.0, is_causal=causal, ) @@ -1480,6 +1464,110 @@ def test_flash_attention(self): math_ref = math_ref.to(dtype) torch.testing.assert_close(actual, math_ref, atol=atol, rtol=rtol) + def test_flash_attention_with_mask(self): + dtypes = [torch.float, torch.double, torch.bfloat16, torch.float16] + for dtype in dtypes: + for mask_dim in [2, 4]: + batch_size, seq_len, n_head, head_dim = 2, 129, 4, 8 + atol = 1e-5 + rtol = 5e-6 + if dtype is torch.bfloat16: + atol = 2e-2 + rtol = 2e-2 + if dtype is torch.float16: + atol = 1e-2 + rtol = 1e-2 + attn_mask_dtypes = ( + [dtype, torch.bool, torch.float] + if dtype in [torch.bfloat16, torch.float16] + else [dtype, torch.bool] + ) + for attn_mask_dtype in attn_mask_dtypes: + for attn_mask_shape in ( + itertools.product([seq_len, 1], [seq_len, 1]) + if mask_dim == 2 + else itertools.product( + [batch_size, 1], [n_head, 1], [seq_len, 1], [seq_len, 1] + ) + ): + n_embd = 
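The new test below sweeps every combination of size-1 and full-size mask dimensions for both 2d and 4d masks, in boolean and floating dtypes. For reference, the same broadcast-mask shapes can be exercised against stock PyTorch SDPA; this is a hedged usage sketch independent of the IPEX op:

import torch
import torch.nn.functional as F

B, H, Q, KV, D = 2, 4, 129, 129, 8
q = torch.randn(B, H, Q, D)
k = torch.randn(B, H, KV, D)
v = torch.randn(B, H, KV, D)

bool_mask_2d = torch.ones(Q, KV, dtype=torch.bool).tril()   # [Q, KV] boolean mask
add_mask_4d = torch.randn(B, 1, 1, KV)                      # broadcast over heads and query rows

out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask_2d)
out_add = F.scaled_dot_product_attention(q, k, v, attn_mask=add_mask_4d)
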
n_head * head_dim + x = torch.randn( + (batch_size, seq_len, 3 * n_head * head_dim), + device="cpu", + dtype=dtype, + requires_grad=False, + ) + x2 = x.clone() + + q, k, v = x.split(n_embd, dim=2) + q2, k2, v2 = x2.split(n_embd, dim=2) + + if dtype in [torch.bfloat16, torch.float16]: + q2 = q2.float() + k2 = k2.float() + v2 = v2.float() + + # (B, nh, T, hs) + k = k.view(batch_size, seq_len, n_head, head_dim).transpose( + 1, 2 + ) + q = q.view(batch_size, seq_len, n_head, head_dim).transpose( + 1, 2 + ) + v = v.view(batch_size, seq_len, n_head, head_dim).transpose( + 1, 2 + ) + k2 = k2.view(batch_size, seq_len, n_head, head_dim).transpose( + 1, 2 + ) + q2 = q2.view(batch_size, seq_len, n_head, head_dim).transpose( + 1, 2 + ) + v2 = v2.view(batch_size, seq_len, n_head, head_dim).transpose( + 1, 2 + ) + + if attn_mask_dtype == torch.bool: + mask = torch.ones( + attn_mask_shape, + dtype=torch.bool, + device="cpu", + requires_grad=False, + ).tril(diagonal=0) + # _scaled_dot_product_attention_math does the type conversion outside + mask2 = torch.zeros_like(mask, dtype=dtype) + mask2[mask == False] = -float("inf") # noqa: E712 + else: + mask = torch.randn( + attn_mask_shape, + dtype=attn_mask_dtype, + device="cpu", + requires_grad=False, + ) + mask2 = mask + actual = torch.ops.torch_ipex.flash_attention( + q, + k, + v, + dropout_p=0.0, + attention_mask=mask, + )[0] + math_ref = ( + torch._scaled_dot_product_attention_math( + q2, + k2, + v2, + attn_mask=mask2, + dropout_p=0.0, + ) + )[0] + + if dtype in [torch.bfloat16, torch.float16]: + math_ref = math_ref.to(dtype) + torch.testing.assert_close( + actual, math_ref, atol=atol, rtol=rtol + ) + def test_flash_attention_stride0(self): input_shape = ( 1,