diff --git a/csrc/cpu/aten/kernels/RotaryPositionEmbeddingKnl.cpp b/csrc/cpu/aten/kernels/RotaryPositionEmbeddingKnl.cpp index ba044e2e6..0894ef23e 100644 --- a/csrc/cpu/aten/kernels/RotaryPositionEmbeddingKnl.cpp +++ b/csrc/cpu/aten/kernels/RotaryPositionEmbeddingKnl.cpp @@ -86,7 +86,12 @@ std::tuple ApplyROPEKernel( auto out_stride_kb = concat_qkv ? key.stride(0) : 0; auto out_stride_ks = concat_qkv ? key.stride(1) : 0; auto emb_pos_ptr = t_emb_pos.data_ptr(); // [MP][HR] - auto pos_ptr = t_pos.data_ptr(); // [MB][S] + auto pos_ptr = t_pos.data_ptr(); // [B][S] or [1][S] + bool t_pos_no_repeated_for_batch = false; + if (t_pos.numel() != 1 && t_pos.size(0) == 1 && B > 1) { + // we do not perform t_pos.repeat here to avoid the overhead of copying + t_pos_no_repeated_for_batch = true; + } { #pragma omp parallel for collapse(3) for (int b = 0; b < B; b++) { @@ -106,7 +111,8 @@ std::tuple ApplyROPEKernel( sin_start = emb_pos_ptr + (p + s) * HR; cos_start = emb_pos_ptr + (p + s) * HR + COFF; } else { - p = pos_ptr[b * S + s]; + auto start_idx = t_pos_no_repeated_for_batch ? 0 : b * S; + p = pos_ptr[start_idx + s]; sin_start = emb_pos_ptr + p * HR; cos_start = emb_pos_ptr + p * HR + COFF; } diff --git a/tests/cpu/test_rope.py b/tests/cpu/test_rope.py index 3e9a8575a..70f482476 100644 --- a/tests/cpu/test_rope.py +++ b/tests/cpu/test_rope.py @@ -7,7 +7,7 @@ class FusedROPETester(TestCase): def setUp(self): - self.batch = 1 + self.batch = 2 self.seq_len = 32 self.max_seq_len = 384 self.head_size = 256 @@ -76,7 +76,10 @@ def hf_forward( query, key, position_ids, embed_positions, offset=None, rotary_dim=None ): embed_positions = _get_embed_positions(embed_positions, position_ids) - sincos = embed_positions.squeeze()[position_ids] + repeated_position_ids = position_ids.unsqueeze(-1).repeat( + 1, 1, embed_positions.shape[-1] + ) + sincos = torch.gather(embed_positions, 1, repeated_position_ids) sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1) if rotary_dim < self.head_size: