Fix the shape of input sin and cos for fused_rope. #56132

Merged: 2 commits, merged Aug 11, 2023
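In short: the fused rotary position embedding forward kernel previously checked sin/cos against a [1, 1, seq_len, head_dim] layout; this PR changes the expected broadcast shape to [1, seq_len, 1, head_dim] (a 2-D [seq_len, head_dim] remains accepted), matching q's [batch_size, seq_len, num_heads, head_dim] layout. The shape checks are rewritten as positive PADDLE_ENFORCE_EQ assertions with more informative error messages, and the unit test is updated to build sin/cos in the new layout and to cover num_heads > 1.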
1 change: 1 addition & 0 deletions paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu
@@ -19,6 +19,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/fusion/gpu/fused_rope_utils.h"

namespace phi {
namespace fusion {

49 changes: 31 additions & 18 deletions paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu
@@ -19,6 +19,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/fusion/gpu/fused_rope_utils.h"

namespace phi {
namespace fusion {

@@ -35,13 +36,14 @@ void FusedRopeKernel(const Context& dev_ctx,
int64_t numel = q.numel();
if (numel <= 0) return;
dev_ctx.template Alloc<T>(out_q);
// small size for broadcast

// q.shape: [batch_size, seq_len, num_heads, head_dim]
auto batch_size = q.dims()[0];
auto seq_len = q.dims()[1];
auto num_heads = q.dims()[2];
auto head_dim = q.dims()[3];
auto seq_len = q.dims()[1];
PADDLE_ENFORCE_NE(head_dim % 2,
1,
PADDLE_ENFORCE_EQ(head_dim % 2,
0,
phi::errors::InvalidArgument(
"The head_dim of input must be a multiple of 2."));

@@ -85,26 +87,37 @@
PADDLE_ENFORCE_EQ(sin.get_ptr()->dims(),
cos.get_ptr()->dims(),
phi::errors::InvalidArgument(
"The dims of sin and cos must be the same."));
"The dims of sin and cos must be the same. But "
"recieved sin's dims is {%s}, cos's dims is {%s}.",
sin.get_ptr()->dims(),
cos.get_ptr()->dims()));

auto sin_dims = sin.get_ptr()->dims();
int dims_size = sin_dims.size();
PADDLE_ENFORCE_NE((dims_size == 2 || dims_size == 4),
false,
phi::errors::InvalidArgument(
"The dims of sin and cos must be 2 or 4."));
PADDLE_ENFORCE_EQ(
(dims_size == 2 || dims_size == 4),
true,
phi::errors::InvalidArgument("The dims of sin and cos is expected to "
"be 2 or 4, but recieved %d.",
dims_size));
if (dims_size == 4) {
PADDLE_ENFORCE_NE(
(sin_dims[0] == 1 && sin_dims[1] == 1),
false,
// sin.shape: [1, seq_len, 1, head_dim]
PADDLE_ENFORCE_EQ(
(sin_dims[0] == 1 && sin_dims[2] == 1),
true,
phi::errors::InvalidArgument(
"The batch_size and num_heads of sin and cos must be 1."));
}
PADDLE_ENFORCE_NE(
(sin_dims[dims_size - 1] == head_dim &&
sin_dims[dims_size - 2] == seq_len),
false,
phi::errors::InvalidArgument("The seq_len and head_dim of sin and cos "
"must be the same as those of q."));
int sin_seq_len_dim = (dims_size) == 4 ? 1 : 0;
PADDLE_ENFORCE_EQ((sin_dims[dims_size - 1] == head_dim &&
sin_dims[sin_seq_len_dim] == seq_len),
true,
phi::errors::InvalidArgument(
"The seq_len and head_dim of sin and cos "
"must be the same as those of q. But recieved sin's "
"shape is {%s}, q's shape is {%s}.",
sin_dims,
q.dims()));

sin_cos_data[0] = sin->data<T>();
sin_cos_data[1] = cos->data<T>();
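To make the new contract concrete, here is a minimal usage sketch (not part of the PR). It assumes a CUDA build of Paddle and the Python entry point paddle.incubate.nn.functional.fused_rotary_position_embedding with a (q, k, v, sin, cos) argument order; check the signature in your Paddle version. The angle table is simplified and only illustrates the accepted shapes, not the op's exact sin/cos pairing convention.

import paddle
from paddle.incubate.nn.functional import fused_rotary_position_embedding

# q/k/v in the layout documented in the kernel:
# [batch_size, seq_len, num_heads, head_dim]; head_dim must be even.
batch_size, seq_len, num_heads, head_dim = 1, 8, 2, 16
q = paddle.randn([batch_size, seq_len, num_heads, head_dim], dtype='float32')
k = paddle.randn([batch_size, seq_len, num_heads, head_dim], dtype='float32')
v = paddle.randn([batch_size, seq_len, num_heads, head_dim], dtype='float32')

# A simple per-position angle table, reshaped to the 4-D layout the kernel
# now enforces: [1, seq_len, 1, head_dim] (batch and num_heads dims are 1).
# The equivalent 2-D layout [seq_len, head_dim] is also accepted.
pos = paddle.arange(seq_len, dtype='float32').reshape([seq_len, 1])
freq = paddle.arange(head_dim, dtype='float32').reshape([1, head_dim])
angles = pos / paddle.pow(paddle.to_tensor(10000.0), freq / head_dim)
sin = paddle.sin(angles).reshape([1, seq_len, 1, head_dim])
cos = paddle.cos(angles).reshape([1, seq_len, 1, head_dim])

# Runs on GPU only; a sin/cos shape such as [1, 1, seq_len, head_dim]
# now fails the PADDLE_ENFORCE_EQ checks added in this file.
out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v, sin, cos)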
20 changes: 14 additions & 6 deletions test/legacy_test/test_fused_rotary_position_embedding.py
@@ -64,27 +64,35 @@ def get_sin_cos_tensor(seq_len, head_dim, sign):

tensor_sin = paddle.reshape(
paddle.to_tensor(sin_sin),
[1, 1, seq_len, head_dim],
[1, seq_len, 1, head_dim],
)
tensor_cos = paddle.reshape(
paddle.to_tensor(cos_cos),
[1, 1, seq_len, head_dim],
[1, seq_len, 1, head_dim],
)

return tensor_sin, tensor_cos


def paddle_fused_rotary_position_embedding(init_q, init_k, init_v):
# permute q, k, v from [batch_size, seq_len, num_heads, head_dim]
# to [batch_size, num_heads, seq_len, head_dim]
q, k, v = deal_qkv(init_q, init_k, init_v)

sin_tensor, cos_tensor = get_sin_cos_tensor(q.shape[2], q.shape[3], -1)

# permute sin, cos from [1, seq_len, 1, head_dim]
# to [1, 1, seq_len, head_dim]
perm = [0, 2, 1, 3]
sin_tensor = paddle.transpose(x=sin_tensor, perm=perm)
cos_tensor = paddle.transpose(x=cos_tensor, perm=perm)

query = mult_qkv(q, cos_tensor, sin_tensor)
value = mult_qkv(v, cos_tensor, sin_tensor)
key = mult_qkv(k, cos_tensor, sin_tensor)

# permute the result back to [batch_size, seq_len, num_heads, head_dim]
r_query, r_key, r_value = deal_qkv(query, key, value)

return r_query, r_key, r_value


@@ -94,7 +102,7 @@ def paddle_fused_rotary_position_embedding(init_q, init_k, init_v):
)
class TestFusedRotaryPositionEmbedding(unittest.TestCase):
def setUp(self):
self.shape = [1, 16, 1, 16]
self.shape = [1, 8, 2, 16]
self.dtype = 'float32'
self.training = True
self.seed = 1203
@@ -138,7 +146,7 @@ def get_forward_backward(self, rope_function, seed, flag=0):

return fw, bw

def test_fused_dropout_add(self):
def test_fused_rope(self):
p_fw, p_bw = self.get_forward_backward(
paddle_fused_rotary_position_embedding, seed=self.seed
)
@@ -153,7 +161,7 @@
p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05
)

def test_fused_dropout_add_sin_cos(self):
def test_fused_rope_with_sin_cos(self):
p_fw, p_bw = self.get_forward_backward(
paddle_fused_rotary_position_embedding, seed=self.seed
)
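One more note on the test changes: the Python reference path works in the [batch_size, num_heads, seq_len, head_dim] layout, which is why the updated code transposes sin/cos from [1, seq_len, 1, head_dim] to [1, 1, seq_len, head_dim] before multiplying. A small standalone sketch of that shape bookkeeping (the names here are illustrative, not taken from the test):

import paddle

batch_size, seq_len, num_heads, head_dim = 1, 8, 2, 16

# q as the fused op sees it, then permuted the way the reference path does.
q = paddle.randn([batch_size, seq_len, num_heads, head_dim])
q_ref = paddle.transpose(q, perm=[0, 2, 1, 3])      # [1, 2, 8, 16]

# sin in the layout the fused kernel now expects ...
sin = paddle.randn([1, seq_len, 1, head_dim])        # [1, 8, 1, 16]
# ... transposed with the same perm so it broadcasts against q_ref
# over the batch and head dimensions.
sin_ref = paddle.transpose(sin, perm=[0, 2, 1, 3])   # [1, 1, 8, 16]

rotated = q_ref * sin_ref                            # broadcasts to [1, 2, 8, 16]
assert rotated.shape == q_ref.shape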