Commit
add sin and cos optional parameters to fused_rope op (PaddlePaddle#55415)
tianhaodongbd authored and wyf committed Aug 30, 2023
1 parent 82514c2 commit e01a173
Showing 11 changed files with 233 additions and 150 deletions.
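For reference (not stated in the diff itself, but consistent with the kernel code below): the new optional sin and cos inputs carry precomputed rotary-embedding angles, and when they are absent the kernels compute the same values on the fly. In the standard RoPE formulation, for sequence position p, channel pair i, and head dimension d,

\theta_{p,i} = p \cdot 10000^{-2i/d},
\mathrm{out}_{2i} = x_{2i}\cos\theta_{p,i} - x_{2i+1}\sin\theta_{p,i},
\mathrm{out}_{2i+1} = x_{2i}\sin\theta_{p,i} + x_{2i+1}\cos\theta_{p,i},

which matches the per-pair rotation applied in the forward kernel further down.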
6 changes: 3 additions & 3 deletions paddle/phi/api/yaml/fused_backward.yaml
@@ -17,10 +17,10 @@
support_dygraph_mode : true

- backward_op : fused_rotary_position_embedding_grad
forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
args : (Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad)
forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
args : (Tensor sin, Tensor cos, Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
optional : out_k_grad, out_v_grad, k_grad, v_grad
optional : sin, cos, out_k_grad, out_v_grad, k_grad, v_grad
infer_meta :
func : FusedRopeGradInferMeta
kernel :
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/fused_ops.yaml
@@ -109,11 +109,11 @@
optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask, gather_index

- op : fused_rotary_position_embedding
args : (Tensor q, Tensor k, Tensor v)
args : (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos)
output : Tensor(out_q), Tensor(out_k), Tensor(out_v)
infer_meta :
func : FusedRopeInferMeta
optional : k,v, out_k, out_v
optional : k,v,sin,cos, out_k, out_v
kernel :
func : fused_rotary_position_embedding
data_type : q
4 changes: 3 additions & 1 deletion paddle/phi/infermeta/backward.cc
@@ -1217,7 +1217,9 @@ void IndexPutGradInferMeta(const MetaTensor& x,
}
}

void FusedRopeGradInferMeta(const MetaTensor& dout_q,
void FusedRopeGradInferMeta(const MetaTensor& sin,
const MetaTensor& cos,
const MetaTensor& dout_q,
const MetaTensor& dout_k,
const MetaTensor& dout_v,
MetaTensor* dq,
4 changes: 3 additions & 1 deletion paddle/phi/infermeta/backward.h
@@ -184,7 +184,9 @@ void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
MetaTensor* x_grad,
MetaTensor* y_grad);

void FusedRopeGradInferMeta(const MetaTensor& dout_q,
void FusedRopeGradInferMeta(const MetaTensor& sin,
const MetaTensor& cos,
const MetaTensor& dout_q,
const MetaTensor& dout_k,
const MetaTensor& dout_v,
MetaTensor* dq,
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -3617,6 +3617,8 @@ void FusedConvInferMeta(const MetaTensor& input,
void FusedRopeInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
const MetaTensor& sin,
const MetaTensor& cos,
MetaTensor* out_q,
MetaTensor* out_k,
MetaTensor* out_v) {
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -717,6 +717,8 @@ void WeightOnlyMatmulInferMeta(const MetaTensor& x,
void FusedRopeInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
const MetaTensor& sin,
const MetaTensor& cos,
MetaTensor* out_q,
MetaTensor* out_k,
MetaTensor* out_v);
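The bodies of these infer-meta functions are not shown in the hunks above. As an illustration only (a minimal sketch, not the actual Paddle implementation), a function with this signature would typically just propagate shape and dtype from each present input to its output; the sin/cos inputs affect values, not output metadata:

// Illustrative sketch only -- assumes phi::MetaTensor's set_dims/set_dtype/
// initialized() accessors; the real FusedRopeInferMeta body may differ.
void FusedRopeInferMeta(const MetaTensor& q,
                        const MetaTensor& k,
                        const MetaTensor& v,
                        const MetaTensor& sin,
                        const MetaTensor& cos,
                        MetaTensor* out_q,
                        MetaTensor* out_k,
                        MetaTensor* out_v) {
  out_q->set_dims(q.dims());
  out_q->set_dtype(q.dtype());
  if (k.initialized()) {  // k and v are optional inputs
    out_k->set_dims(k.dims());
    out_k->set_dtype(k.dtype());
  }
  if (v.initialized()) {
    out_v->set_dims(v.dims());
    out_v->set_dtype(v.dtype());
  }
}

The grad variant (FusedRopeGradInferMeta) would analogously copy metadata from each present dout_* tensor to the corresponding *_grad output.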
75 changes: 17 additions & 58 deletions paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu
@@ -18,68 +18,14 @@
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/fusion/gpu/fused_rope_utils.h"
namespace phi {
namespace fusion {
template <typename T, typename MPType, int VecSize = 2>
__global__ void VectorizedFusedRopeGradKernel(phi::Array<const T*, 3> ins_data,
int batch_size,
int seq_len,
int num_heads,
int head_dim,
phi::Array<T*, 3> outs_data,
int num_inputs,
MPType div_c) {
int index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
int stride = gridDim.x * blockDim.x * VecSize;
int size = batch_size * seq_len * num_heads * head_dim;
MPType sin_value[VecSize];
MPType cos_value[VecSize];
MPType result[VecSize];
T store[VecSize];
using VecType = phi::AlignedVector<T, VecSize>;
constexpr int kVectorsPerThread = VecSize / 2;

for (; index < size; index += stride) {
#pragma unroll
for (int nx = 0; nx < VecSize; ++nx) {
// get sin_index and cos_index
int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
int pos_seq = index_wc / (num_heads * head_dim);
MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
MPType indicses =
static_cast<MPType>(1) /
pow(static_cast<MPType>(10000), idx * static_cast<MPType>(div_c));
MPType value = pos_seq * indicses;
sin_value[nx] = sin(value);
cos_value[nx] = cos(value);
}

#pragma unroll
for (int iter = 0; iter < 3; iter++) {
if (iter > num_inputs) break;
const T* input = ins_data[iter] + index;
VecType* out = reinterpret_cast<VecType*>(outs_data[iter] + index);

#pragma unroll
for (int nx = 0; nx < kVectorsPerThread; ++nx) {
int pr_index = nx * 2;
int ls_index = pr_index + 1;

MPType p0 = static_cast<MPType>(input[pr_index]);
MPType p1 = static_cast<MPType>(input[ls_index]);
result[pr_index] = cos_value[pr_index] * p0 + sin_value[ls_index] * p1;
result[ls_index] = cos_value[ls_index] * p1 - sin_value[pr_index] * p0;

store[pr_index] = static_cast<T>(result[pr_index]);
store[ls_index] = static_cast<T>(result[ls_index]);
}
out[0] = *(reinterpret_cast<VecType*>(store));
}
}
}

template <typename T, typename Context>
void FusedRopeGradKernel(const Context& dev_ctx,
const paddle::optional<DenseTensor>& sin,
const paddle::optional<DenseTensor>& cos,
const DenseTensor& dout_q,
const paddle::optional<DenseTensor>& dout_k,
const paddle::optional<DenseTensor>& dout_v,
@@ -111,6 +57,7 @@ void FusedRopeGradKernel(const Context& dev_ctx,

phi::Array<T*, 3> outs_data;
phi::Array<const T*, 3> ins_data;
phi::Array<const T*, 2> sin_cos_data;

ins_data[0] = dout_q.data<T>();
outs_data[0] = dq->data<T>();
@@ -135,8 +82,20 @@ void FusedRopeGradKernel(const Context& dev_ctx,
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType div_c = static_cast<MPType>(1.0f / head_dim);

VectorizedFusedRopeGradKernel<T, MPType, vec_size>
bool flag_sin_cos = false;
if (sin.get_ptr() && cos.get_ptr()) {
sin_cos_data[0] = sin->data<T>();
sin_cos_data[1] = cos->data<T>();

flag_sin_cos = true;
}

int sign = -1;
VectorizedFusedRopeKernel<T, MPType, vec_size>
<<<grid, block, 0, stream>>>(ins_data,
sin_cos_data,
flag_sin_cos,
sign,
batch_size,
seq_len,
num_heads,
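Note that the grad kernel now passes sign = -1 to the shared rotation kernel. Per the inline kernel removed above, this corresponds to applying the inverse (transposed) rotation to the incoming gradients g (i.e. dout_q/dout_k/dout_v):

\frac{\partial L}{\partial x_{2i}} = g_{2i}\cos\theta_{p,i} + g_{2i+1}\sin\theta_{p,i},
\frac{\partial L}{\partial x_{2i+1}} = g_{2i+1}\cos\theta_{p,i} - g_{2i}\sin\theta_{p,i}.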
103 changes: 41 additions & 62 deletions paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu
@@ -18,76 +18,17 @@
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/fusion/gpu/fused_rope_utils.h"
namespace phi {
namespace fusion {

template <typename T, typename MPType, int VecSize = 2>
__global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
int batch_size,
int seq_len,
int num_heads,
int head_dim,
phi::Array<T*, 3> outs_data,
int num_inputs,
MPType div_c) {
int index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
int stride = gridDim.x * blockDim.x * VecSize;
int size = batch_size * seq_len * num_heads * head_dim;
MPType sin_value[VecSize];
MPType cos_value[VecSize];
MPType result[VecSize];
T store[VecSize];
using VecType = phi::AlignedVector<T, VecSize>;
constexpr int kVectorsPerThread = VecSize / 2;

for (; index < size; index += stride) {
#pragma unroll
for (int nx = 0; nx < VecSize; ++nx) {
// get sin_index and cos_index
int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
int pos_seq = index_wc / (num_heads * head_dim);
MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
MPType indicses =
static_cast<MPType>(1) /
pow(static_cast<MPType>(10000), idx * static_cast<MPType>(div_c));
MPType value = pos_seq * indicses;
sin_value[nx] = sin(value);
cos_value[nx] = cos(value);
}

#pragma unroll
for (int iter = 0; iter < 3; iter++) {
if (iter > num_inputs) break;
const T* input = ins_data[iter] + index;
VecType* out = reinterpret_cast<VecType*>(outs_data[iter] + index);

#pragma unroll
for (int nx = 0; nx < kVectorsPerThread; ++nx) {
int pr_index = nx * 2;
int ls_index = pr_index + 1;

MPType p0 = static_cast<MPType>(input[pr_index]);
MPType p1 = static_cast<MPType>(input[ls_index]);

result[pr_index] = cos_value[pr_index] * p0;
result[pr_index] -= sin_value[pr_index] * p1;

result[ls_index] = sin_value[ls_index] * p0;
result[ls_index] += cos_value[ls_index] * p1;

store[pr_index] = static_cast<T>(result[pr_index]);
store[ls_index] = static_cast<T>(result[ls_index]);
}
out[0] = *(reinterpret_cast<VecType*>(store));
}
}
}

template <typename T, typename Context>
void FusedRopeKernel(const Context& dev_ctx,
const DenseTensor& q,
const paddle::optional<DenseTensor>& k,
const paddle::optional<DenseTensor>& v,
const paddle::optional<DenseTensor>& sin,
const paddle::optional<DenseTensor>& cos,
DenseTensor* out_q,
DenseTensor* out_k,
DenseTensor* out_v) {
@@ -116,6 +57,7 @@ void FusedRopeKernel(const Context& dev_ctx,

phi::Array<T*, 3> outs_data;
phi::Array<const T*, 3> ins_data;
phi::Array<const T*, 2> sin_cos_data;

ins_data[0] = q.data<T>();
outs_data[0] = out_q->data<T>();
@@ -140,8 +82,45 @@ void FusedRopeKernel(const Context& dev_ctx,
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType div_c = static_cast<MPType>(1.0f / head_dim);

bool flag_sin_cos = false;

if (sin.get_ptr() && cos.get_ptr()) {
PADDLE_ENFORCE_EQ(sin.get_ptr()->dims(),
cos.get_ptr()->dims(),
phi::errors::InvalidArgument(
"The dims of sin and cos must be the same."));
auto sin_dims = sin.get_ptr()->dims();
int dims_size = sin_dims.size();
PADDLE_ENFORCE_NE((dims_size == 2 || dims_size == 4),
false,
phi::errors::InvalidArgument(
"The dims of sin and cos must be 2 or 4."));
if (dims_size == 4) {
PADDLE_ENFORCE_NE(
(sin_dims[0] == 1 && sin_dims[1] == 1),
false,
phi::errors::InvalidArgument(
"The batch_size and num_heads of sin and cos must be 1."));
}
PADDLE_ENFORCE_NE(
(sin_dims[dims_size - 1] == head_dim &&
sin_dims[dims_size - 2] == seq_len),
false,
phi::errors::InvalidArgument("The seq_len and head_dim of sin and cos "
"must be the same as those of q."));

sin_cos_data[0] = sin->data<T>();
sin_cos_data[1] = cos->data<T>();

flag_sin_cos = true;
}

int sign = 1;
VectorizedFusedRopeKernel<T, MPType, vec_size>
<<<grid, block, 0, stream>>>(ins_data,
sin_cos_data,
flag_sin_cos,
sign,
batch_size,
seq_len,
num_heads,
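Both kernels now delegate the per-pair rotation to a shared VectorizedFusedRopeKernel declared in the new paddle/phi/kernels/fusion/gpu/fused_rope_utils.h, which is not rendered in this view. The sketch below is a hypothetical reconstruction, inferred from the removed inline kernels and the call sites above; it assumes a (seq_len, head_dim) layout for the sin/cos cache, which is what the shape checks in FusedRopeKernel effectively enforce (2-D, or 4-D with two leading singleton dims). The actual header may differ.

// Hypothetical contents of fused_rope_utils.h -- a sketch only, assuming the
// same includes as the .cu files above (aligned_vector.h, phi::Array, etc.).
template <typename T, typename MPType, int VecSize = 2>
__global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
                                          phi::Array<const T*, 2> sin_cos_data,
                                          bool flag_sin_cos,
                                          int sign,
                                          int batch_size,
                                          int seq_len,
                                          int num_heads,
                                          int head_dim,
                                          phi::Array<T*, 3> outs_data,
                                          int num_inputs,
                                          MPType div_c) {
  int index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
  int stride = gridDim.x * blockDim.x * VecSize;
  int size = batch_size * seq_len * num_heads * head_dim;
  MPType sin_value[VecSize];
  MPType cos_value[VecSize];
  MPType result[VecSize];
  T store[VecSize];
  using VecType = phi::AlignedVector<T, VecSize>;
  constexpr int kVectorsPerThread = VecSize / 2;

  for (; index < size; index += stride) {
#pragma unroll
    for (int nx = 0; nx < VecSize; ++nx) {
      int index_wc = (index + nx) % (seq_len * num_heads * head_dim);
      int pos_seq = index_wc / (num_heads * head_dim);
      int pos_head = index_wc % head_dim;
      if (flag_sin_cos) {
        // Read the precomputed cache; assumes (seq_len, head_dim) layout
        // broadcast over batch and heads.
        int sc_index = pos_seq * head_dim + pos_head;
        sin_value[nx] = static_cast<MPType>(sin_cos_data[0][sc_index]);
        cos_value[nx] = static_cast<MPType>(sin_cos_data[1][sc_index]);
      } else {
        // Same on-the-fly computation as the removed inline kernels.
        MPType idx = static_cast<MPType>(pos_head / 2 * 2.0);
        MPType freq = static_cast<MPType>(1) /
                      pow(static_cast<MPType>(10000), idx * div_c);
        MPType value = pos_seq * freq;
        sin_value[nx] = sin(value);
        cos_value[nx] = cos(value);
      }
    }

#pragma unroll
    for (int iter = 0; iter < 3; iter++) {
      if (iter > num_inputs) break;
      const T* input = ins_data[iter] + index;
      VecType* out = reinterpret_cast<VecType*>(outs_data[iter] + index);

#pragma unroll
      for (int nx = 0; nx < kVectorsPerThread; ++nx) {
        int pr_index = nx * 2;
        int ls_index = pr_index + 1;
        MPType p0 = static_cast<MPType>(input[pr_index]);
        MPType p1 = static_cast<MPType>(input[ls_index]);
        // sign = +1 applies the forward rotation, sign = -1 its inverse
        // (used by the grad kernel).
        result[pr_index] =
            cos_value[pr_index] * p0 - sign * sin_value[pr_index] * p1;
        result[ls_index] =
            sign * sin_value[ls_index] * p0 + cos_value[ls_index] * p1;
        store[pr_index] = static_cast<T>(result[pr_index]);
        store[ls_index] = static_cast<T>(result[ls_index]);
      }
      out[0] = *(reinterpret_cast<VecType*>(store));
    }
  }
}

Keeping a single kernel parameterized by a sign flag avoids duplicating the vectorized load/store logic between the forward and backward passes.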