From 6dbb4dcfa3476088d7b9b24a8ebe12ce5f7a0142 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Mon, 22 Jul 2024 10:51:58 -0700
Subject: [PATCH] Fix sdpa flash attention op for et llama deployment (#4322)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/4322

We retrofitted the CPU flash attention kernel from ATen. The retrofit made it
calculate attention for a) batched prefill and b) decode with different
start_pos. For b, there was a bug when the kv cache's seqlen dim is split: as
a result, the attention calculation was not right. A detailed comment in the
code explains the issue.

bypass-github-export-checks

ghstack-source-id: 234634902

Reviewed By: larryliu0820

Differential Revision: D60011925

fbshipit-source-id: 50921846b329e449a4a767cf28c7a55d507217bd
---
 examples/models/llama2/custom_ops/op_sdpa.cpp |  37 ++++-
 .../custom_ops/test_sdpa_with_kv_cache.py     | 132 ++++++++++++++++++
 2 files changed, 168 insertions(+), 1 deletion(-)

diff --git a/examples/models/llama2/custom_ops/op_sdpa.cpp b/examples/models/llama2/custom_ops/op_sdpa.cpp
index 5bda250646..758973bcb7 100644
--- a/examples/models/llama2/custom_ops/op_sdpa.cpp
+++ b/examples/models/llama2/custom_ops/op_sdpa.cpp
@@ -239,6 +239,12 @@ void cpu_flash_attention(
   at::Tensor value = v.transpose(1, 2);
   */
 
+  // Without this we have out-of-bounds writes for
+  // causal masking
+  static_assert(
+      kv_split_size > q_split_size,
+      "kv_split_size must be greater than q_split_size");
+
   constexpr bool is_reduced_type =
       torch::executor::is_reduced_floating_point<scalar_t>::value;
@@ -417,7 +423,35 @@ void cpu_flash_attention(
       // Initialize max and sum
       fill_stub(
           qk_max_data, -std::numeric_limits<accum_t>::infinity(), qBlockSize);
-      int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
+      // The original flash sdpa was not really meant to be used
+      // for decode the way we use it here via start_pos.
+      // Thus, when num_keys is 1 during the decode phase, we
+      // still need to iterate through all the kv splits.
+      // Take start_pos = 130 and kv_split_size = 128.
+      // Here we have to produce a [1 x 130] q @ k.T
+      // when seq_len = 1.
+      // But if num_keys = 1 we do not actually loop over
+      // all kv splits.
+      // When kv_split_size > 130 this is not an issue because
+      // there is only one iteration of the following loop anyway.
+      // Outside of determining how many loop iterations are needed,
+      // num_keys participates only in causal attention.
+      // The rest of the calculation of q @ k.T and @ v.T stays the same.
+      // We do not run into this bug when kv_split_size > start_pos + seq_len,
+      // since there is only one iteration and it applies
+      // causal attention correctly.
+      // However, when kv_split_size < start_pos + seq_len we need
+      // more than one iteration, and without adjusting num_keys
+      // we do not get more than one iteration.
+      // This is unique to this deployment of flash attention, since
+      // the original implementation was never deployed this way.
+
+      // Some of these bugs could be resolved by relying on the attention
+      // mask instead, but that requires storing the attention mask in
+      // float, as the current code does not support a bool attention mask.
+      // For now, just fix num_keys here.
+      int64_t num_keys =
+          is_causal ? std::min(m + start_pos + qBlockSize, kvSize) : kvSize;
       auto j_kv = j / num_reps;
       for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
         int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
@@ -452,6 +486,7 @@ void cpu_flash_attention(
         // entries need masked out. In our example n = 4
         // will qualify for that
         if (is_causal && num_keys - n <= kvSplitSize) {
+          // For this to work, kv_split_size must be > q_split_size
           for (int32_t row = 0; row < qBlockSize; ++row) {
             int64_t last_col = m + (row + start_pos) - n;
             accum_t* row_ptr = qk_data + row * kvBlockSize;
diff --git a/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py b/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py
index 1b8f425b67..d71cb486b8 100644
--- a/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py
+++ b/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py
@@ -365,3 +365,135 @@ def test_sdpa_with_cache_mqa_3(self):
             q, k, v, self.k_cache, self.v_cache, 1, 1, None, 0, False
         )
         self.assertTrue(torch.allclose(ref_output, op_output))
+
+
+class SDPATestForLargeSeqLength(unittest.TestCase):
+
+    def setup_caches(self):
+        self.k_cache = torch.zeros(
+            (1, self.max_seq_len, self.n_heads_kv, self.head_dim)
+        )
+        self.v_cache = torch.zeros(
+            (1, self.max_seq_len, self.n_heads_kv, self.head_dim)
+        )
+        self.mask = torch.full(
+            (self.max_seq_len, self.max_seq_len),
+            float("-inf"),
+        )
+        self.mask = torch.triu(self.mask, diagonal=1)
+
+    def setUp(self):
+        torch.manual_seed(42)
+        self.n_heads_kv = 32
+        self.n_heads_q = 32
+        self.head_dim = 128
+        self.max_seq_len = 2048
+        self.setup_caches()
+
+    def test_sdpa_with_cache_seq_len_130(self):
+        self.n_heads_kv = 32
+        self.n_heads_q = 32
+        self.head_dim = 128
+        self.max_seq_len = 2048
+        self.setup_caches()
+        seq_len = 130
+        q = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
+        k = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
+        v = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
+        start_pos = 0
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
+        ref_output = _sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+        q = torch.rand((1, 1, self.n_heads_kv, self.head_dim))
+        k = torch.rand((1, 1, self.n_heads_kv, self.head_dim))
+        v = torch.rand((1, 1, self.n_heads_kv, self.head_dim))
+        start_pos = seq_len
+        seq_len = q.size(1)
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
+        ref_output = _sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+    def test_sdpa_with_cache_seq_len_small(self):
+        self.n_heads_kv = 4
+        self.n_heads_q = 4
+        self.head_dim = 4
+        self.max_seq_len = 8
+        self.setup_caches()
+        q = torch.rand((1, 4, self.n_heads_q, 4))
+        k = torch.rand((1, 4, self.n_heads_q, 4))
+        v = torch.rand((1, 4, self.n_heads_q, 4))
+        start_pos = 0
+        seq_len = q.size(1)
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
+        ref_output = _sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+        q = torch.rand((1, 1, self.n_heads_q, 4))
+        k = torch.rand((1, 1, self.n_heads_q, 4))
+        v = torch.rand((1, 1, self.n_heads_q, 4))
+        start_pos = 4
+        seq_len = q.size(1)
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
+        ref_output = _sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+    def test_sdpa_with_cache_seq_len_llava_example(self):
+        self.n_heads_kv = 32
+        self.n_heads_q = 32
+        self.head_dim = 128
+        self.max_seq_len = 2048
+        self.setup_caches()
+        seq_len = 634
+        q = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
+        k = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
+        v = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
+        start_pos = 0
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
+        ref_output = _sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+        q = torch.rand((1, 1, self.n_heads_kv, self.head_dim))
+        k = torch.rand((1, 1, self.n_heads_kv, self.head_dim))
+        v = torch.rand((1, 1, self.n_heads_kv, self.head_dim))
+        start_pos = seq_len
+        seq_len = q.size(1)
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
+        ref_output = _sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
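
Note: as a minimal standalone sketch of the num_keys arithmetic discussed in the op_sdpa.cpp comment above, the Python below counts how many KV splits the inner loop visits for a single decode step. The helper kv_splits_visited and the assumption that kvSize equals start_pos + seq_len after the cache update are illustrative only and not part of the patch.

# Minimal sketch (not part of the patch) of the KV-split loop arithmetic.
# Assumption: the cache holds start_pos + seq_len keys after the update.
def kv_splits_visited(start_pos, seq_len, kv_split_size, include_start_pos):
    """Count how many KV splits the loop visits for the first query block."""
    kv_size = start_pos + seq_len  # keys available after the cache update
    m = 0                          # offset of the first (and only) query block
    q_block_size = seq_len         # decode: a single query row
    if include_start_pos:
        # formula after this patch
        num_keys = min(m + start_pos + q_block_size, kv_size)
    else:
        # formula before this patch, which ignores start_pos
        num_keys = min(m + q_block_size, kv_size)
    return len(range(0, num_keys, kv_split_size))

# start_pos = 130, seq_len = 1, kv_split_size = 128, as in the comment.
# Old formula: num_keys = 1, so only one split is visited and the cached
# keys in the second split are never attended to.
print(kv_splits_visited(130, 1, 128, include_start_pos=False))  # -> 1
# New formula: num_keys = 131, so both splits are visited.
print(kv_splits_visited(130, 1, 128, include_start_pos=True))   # -> 2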