Fix sdpa flash attention op for et llama deployment (#4322)
Summary:
Pull Request resolved: #4322

We retrofitted the CPU flash attention kernel from ATen. The retrofit made it
calculate attention for a) batched prefill and b) decode with different
start_pos values. For b, there was a bug when the kv cache's sequence-length
dimension is split across kv blocks: the attention calculation came out wrong.
A detailed comment in the code explains the issue.
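
To make the decode failure concrete, here is a small standalone Python sketch
(not the kernel code; the helper name num_kv_iterations is invented for
illustration) of how the old and fixed num_keys formulas change the number of
kv-split loop iterations for a single-token decode at start_pos = 130 with a
kv split size of 128:

# Standalone illustration: how many times the kernel's inner loop
# `for (n = 0; n < num_keys; n += kvSplitSize)` runs for one decode step.
def num_kv_iterations(num_keys, kv_split_size):
    return -(-num_keys // kv_split_size)  # ceiling division

start_pos, seq_len, kv_split_size = 130, 1, 128
kv_size = start_pos + seq_len  # cached keys the new token must attend to

old_num_keys = min(seq_len, kv_size)              # old formula (ignores start_pos) -> 1
new_num_keys = min(start_pos + seq_len, kv_size)  # fixed formula -> 131

print(num_kv_iterations(old_num_keys, kv_split_size))  # 1: keys 128..130 are never attended
print(num_kv_iterations(new_num_keys, kv_split_size))  # 2: all 131 cached keys are covered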

bypass-github-export-checks
ghstack-source-id: 234634902

Reviewed By: larryliu0820

Differential Revision: D60011925

fbshipit-source-id: 50921846b329e449a4a767cf28c7a55d507217bd
kimishpatel authored and facebook-github-bot committed Jul 22, 2024
1 parent 9d85965 commit 6dbb4dc
Showing 2 changed files with 168 additions and 1 deletion.
37 changes: 36 additions & 1 deletion examples/models/llama2/custom_ops/op_sdpa.cpp
@@ -239,6 +239,12 @@ void cpu_flash_attention(
at::Tensor value = v.transpose(1, 2);
*/

// Without this we have out-of-bounds writes for
// causal masking
static_assert(
kv_split_size > q_split_size,
"kv_split_size must be greater than q_split_size");

constexpr bool is_reduced_type =
torch::executor::is_reduced_floating_point<scalar_t>::value;

@@ -417,7 +423,35 @@ void cpu_flash_attention(
// Initialize max and sum
fill_stub(
qk_max_data, -std::numeric_limits<accum_t>::infinity(), qBlockSize);
int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
// The original flash SDPA was not really meant to be used
// for decode the way we use it here via start_pos.
// When num_keys is 1 during the decode phase, we
// still need to iterate through all the kv splits.
// Take start_pos = 130 and k_split_size = 128.
// Here we have to produce a [1 x 130] q @ k.T
// even though seq_len = 1.
// But if num_keys = 1 we never loop over
// all the kv splits.
// When k_split_size > 130 this is not an issue, because
// there is only one iteration of the following loop anyway.
// Outside of determining how many loop iterations are needed,
// num_keys participates only in causal attention.
// The rest of the q @ k.T and @ v calculation stays the same.
// We do not run into this bug when k_split_size >= start_pos + seq_len,
// since there is only one iteration and it applies
// causal attention correctly.
// However, when k_split_size < start_pos + seq_len we need
// more than one iteration, and without adjusting num_keys
// the loop runs only once.
// This is unique to this deployment of flash attention, since
// the original implementation was not deployed this way.

// Some of these bugs could be worked around via the attention mask,
// but that requires storing the attention mask in float, as the
// current code does not support a bool attention mask.
// For now, we just fix num_keys here instead.
int64_t num_keys =
is_causal ? std::min(m + start_pos + qBlockSize, kvSize) : kvSize;
auto j_kv = j / num_reps;
for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
@@ -452,6 +486,7 @@
// entries need to be masked out. In our example n = 4
// will qualify for that
if (is_causal && num_keys - n <= kvSplitSize) {
// For this to work, k_split_size must be greater than q_split_size
for (int32_t row = 0; row < qBlockSize; ++row) {
int64_t last_col = m + (row + start_pos) - n;
accum_t* row_ptr = qk_data + row * kvBlockSize;
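
As a reading aid for the causal-masking hunk above, here is a short Python
sketch (the helper name last_unmasked_col is invented for illustration; it
mirrors the indexing shown above rather than the kernel itself) of how
last_col picks the last key column a query row may attend to once start_pos
is accounted for:

# A query row at absolute position (start_pos + m + row) may attend to keys
# 0 .. start_pos + m + row. Inside a kv block starting at key index n, the
# last allowed column is therefore m + (row + start_pos) - n.
def last_unmasked_col(m, row, start_pos, n):
    return m + (row + start_pos) - n

# Decode example from the comment above: start_pos = 130, kv split size 128,
# so the second kv block starts at n = 128 and holds keys 128, 129, 130.
print(last_unmasked_col(m=0, row=0, start_pos=130, n=128))  # 2: columns 0..2 stay unmasked
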
132 changes: 132 additions & 0 deletions examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py
@@ -365,3 +365,135 @@ def test_sdpa_with_cache_mqa_3(self):
            q, k, v, self.k_cache, self.v_cache, 1, 1, None, 0, False
        )
        self.assertTrue(torch.allclose(ref_output, op_output))


class SDPATestForLargeSeqLength(unittest.TestCase):

    def setup_caches(self):
        self.k_cache = torch.zeros(
            (1, self.max_seq_len, self.n_heads_kv, self.head_dim)
        )
        self.v_cache = torch.zeros(
            (1, self.max_seq_len, self.n_heads_kv, self.head_dim)
        )
        self.mask = torch.full(
            (self.max_seq_len, self.max_seq_len),
            float("-inf"),
        )
        self.mask = torch.triu(self.mask, diagonal=1)

    def setUp(self):
        torch.manual_seed(42)
        self.n_heads_kv = 32
        self.n_heads_q = 32
        self.head_dim = 128
        self.max_seq_len = 2048
        self.setup_caches()

    def test_sdpa_with_cache_seq_len_130(self):
        self.n_heads_kv = 32
        self.n_heads_q = 32
        self.head_dim = 128
        self.max_seq_len = 2048
        self.setup_caches()
        seq_len = 130
        q = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
        k = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
        v = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
        start_pos = 0
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]
        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        op_output = torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
        )
        self.assertTrue(torch.allclose(ref_output, op_output))

        q = torch.rand((1, 1, self.n_heads_kv, self.head_dim))
        k = torch.rand((1, 1, self.n_heads_kv, self.head_dim))
        v = torch.rand((1, 1, self.n_heads_kv, self.head_dim))
        start_pos = seq_len
        seq_len = q.size(1)
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]
        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        op_output = torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
        )
        self.assertTrue(torch.allclose(ref_output, op_output))

    def test_sdpa_with_cache_seq_len_small(self):
        self.n_heads_kv = 4
        self.n_heads_q = 4
        self.head_dim = 4
        self.max_seq_len = 8
        self.setup_caches()
        q = torch.rand((1, 4, self.n_heads_q, 4))
        k = torch.rand((1, 4, self.n_heads_q, 4))
        v = torch.rand((1, 4, self.n_heads_q, 4))
        start_pos = 0
        seq_len = q.size(1)
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]
        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        op_output = torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
        )
        self.assertTrue(torch.allclose(ref_output, op_output))

        q = torch.rand((1, 1, self.n_heads_q, 4))
        k = torch.rand((1, 1, self.n_heads_q, 4))
        v = torch.rand((1, 1, self.n_heads_q, 4))
        start_pos = 4
        seq_len = q.size(1)
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]
        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        op_output = torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
        )
        self.assertTrue(torch.allclose(ref_output, op_output))

    def test_sdpa_with_cache_seq_len_llava_example(self):
        self.n_heads_kv = 32
        self.n_heads_q = 32
        self.head_dim = 128
        self.max_seq_len = 2048
        self.setup_caches()
        seq_len = 634
        q = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
        k = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
        v = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
        start_pos = 0
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]
        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        op_output = torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
        )
        self.assertTrue(torch.allclose(ref_output, op_output))

        q = torch.rand((1, 1, self.n_heads_kv, self.head_dim))
        k = torch.rand((1, 1, self.n_heads_kv, self.head_dim))
        v = torch.rand((1, 1, self.n_heads_kv, self.head_dim))
        start_pos = seq_len
        seq_len = q.size(1)
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]
        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        op_output = torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
        )
        self.assertTrue(torch.allclose(ref_output, op_output))
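
Note: _sdpa_with_kv_cache_ref and torch.ops.llama.sdpa_with_kv_cache are
defined elsewhere (earlier in this test file and in op_sdpa.cpp respectively)
and are not shown in this diff. Purely as a reading aid, a reference
computation matching the shapes used in these tests could look like the
sketch below; this is an assumption about the helper's semantics, not its
actual implementation.

import torch

def sdpa_with_kv_cache_sketch(q, k, v, k_cache, v_cache, attn_mask, start_pos, seq_len):
    # Assumed semantics: write the new k/v into the cache at
    # [start_pos : start_pos + seq_len], then attend the query against
    # everything cached so far using the sliced mask built in the tests.
    k_cache[:, start_pos : start_pos + seq_len] = k
    v_cache[:, start_pos : start_pos + seq_len] = v
    # (B, S, H, D) -> (B, H, S, D) for scaled_dot_product_attention
    q_t = q.transpose(1, 2)
    k_t = k_cache[:, : start_pos + seq_len].transpose(1, 2)
    v_t = v_cache[:, : start_pos + seq_len].transpose(1, 2)
    out = torch.nn.functional.scaled_dot_product_attention(
        q_t, k_t, v_t, attn_mask=attn_mask
    )
    return out.transpose(1, 2)  # back to (B, S, H, D)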
