fix rope for BS > 1 (#2912) (#2928)
* fix rope for BS > 1

* fix unit test
zhuhaozhe authored May 27, 2024
1 parent 9419cc1 commit 2d02768
Showing 2 changed files with 13 additions and 4 deletions.
10 changes: 8 additions & 2 deletions csrc/cpu/aten/kernels/RotaryPositionEmbeddingKnl.cpp
@@ -86,7 +86,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> ApplyROPEKernel(
   auto out_stride_kb = concat_qkv ? key.stride(0) : 0;
   auto out_stride_ks = concat_qkv ? key.stride(1) : 0;
   auto emb_pos_ptr = t_emb_pos.data_ptr<float>(); // [MP][HR]
-  auto pos_ptr = t_pos.data_ptr<long>(); // [MB][S]
+  auto pos_ptr = t_pos.data_ptr<long>(); // [B][S] or [1][S]
+  bool t_pos_no_repeated_for_batch = false;
+  if (t_pos.numel() != 1 && t_pos.size(0) == 1 && B > 1) {
+    // we do not perform t_pos.repeat here to avoid the overhead of copying
+    t_pos_no_repeated_for_batch = true;
+  }
   {
 #pragma omp parallel for collapse(3)
     for (int b = 0; b < B; b++) {
@@ -106,7 +111,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> ApplyROPEKernel(
           sin_start = emb_pos_ptr + (p + s) * HR;
           cos_start = emb_pos_ptr + (p + s) * HR + COFF;
         } else {
-          p = pos_ptr[b * S + s];
+          auto start_idx = t_pos_no_repeated_for_batch ? 0 : b * S;
+          p = pos_ptr[start_idx + s];
           sin_start = emb_pos_ptr + p * HR;
           cos_start = emb_pos_ptr + p * HR + COFF;
         }
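The kernel change above makes a position tensor of shape [1][S] broadcast across the batch instead of being indexed as if it were [B][S], which would read past the end of the buffer for b > 0. A minimal PyTorch sketch of the new indexing rule; the helper name gather_position is chosen here for illustration only and is not part of the kernel:

import torch

def gather_position(t_pos: torch.Tensor, b: int, s: int) -> int:
    # t_pos holds position ids with shape [B, S] or, when every batch
    # element shares the same positions, [1, S].
    pos = t_pos.reshape(-1)  # row-major view, like the raw data_ptr in the kernel
    # Shared positions: every batch element starts at offset 0 instead of b * S.
    start_idx = 0 if t_pos.size(0) == 1 else b * t_pos.size(1)
    return int(pos[start_idx + s])

# Shared positions for a batch of 2, sequence length 4.
t_pos = torch.arange(4).unsqueeze(0)     # shape [1, 4]
print(gather_position(t_pos, b=1, s=3))  # 3, instead of reading out of bounds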
7 changes: 5 additions & 2 deletions tests/cpu/test_rope.py
@@ -7,7 +7,7 @@

 class FusedROPETester(TestCase):
     def setUp(self):
-        self.batch = 1
+        self.batch = 2
         self.seq_len = 32
         self.max_seq_len = 384
         self.head_size = 256
@@ -76,7 +76,10 @@ def hf_forward(
             query, key, position_ids, embed_positions, offset=None, rotary_dim=None
         ):
             embed_positions = _get_embed_positions(embed_positions, position_ids)
-            sincos = embed_positions.squeeze()[position_ids]
+            repeated_position_ids = position_ids.unsqueeze(-1).repeat(
+                1, 1, embed_positions.shape[-1]
+            )
+            sincos = torch.gather(embed_positions, 1, repeated_position_ids)
             sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)

             if rotary_dim < self.head_size:
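The test change applies the same idea on the reference side: embed_positions.squeeze()[position_ids] only selects the right rows when the batch dimension is 1, while torch.gather keeps each batch row paired with its own position ids. A small self-contained sketch of the gather-based selection; the shapes here are illustrative, not the tester's defaults:

import torch

batch, seq_len, max_pos, rot = 2, 4, 16, 8
embed_positions = torch.randn(batch, max_pos, rot)     # sin/cos table per batch element
position_ids = torch.arange(seq_len).repeat(batch, 1)  # shape [batch, seq_len]

# Gather row position_ids[b, s] from embed_positions[b] for every (b, s).
repeated = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
sincos = torch.gather(embed_positions, 1, repeated)    # shape [batch, seq_len, rot]

# Reference: an explicit loop over the batch produces the same selection.
ref = torch.stack([embed_positions[b, position_ids[b]] for b in range(batch)])
assert torch.equal(sincos, ref)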
