Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

在fused_rope算子中增加rotate_half实现方式 #56401

Merged
merged 8 commits into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions paddle/phi/api/yaml/fused_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
support_dygraph_mode : true

- backward_op : fused_rotary_position_embedding_grad
forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
args : (Tensor sin, Tensor cos, Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad)
forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos, Tensor position_ids, bool use_neox_rotary_style) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
args : (Tensor sin, Tensor cos, Tensor position_ids, Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad, bool use_neox_rotary_style)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
optional : sin, cos, out_k_grad, out_v_grad, k_grad, v_grad
optional : sin, cos, position_ids, out_k_grad, out_v_grad, k_grad, v_grad
infer_meta :
func : FusedRopeGradInferMeta
kernel :
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/fused_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,11 @@
optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask, gather_index

- op : fused_rotary_position_embedding
args : (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos)
args : (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos, Tensor position_ids, bool use_neox_rotary_style = true)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for add inputs

output : Tensor(out_q), Tensor(out_k), Tensor(out_v)
infer_meta :
func : FusedRopeInferMeta
optional : k,v,sin,cos, out_k, out_v
optional : k, v, sin, cos, position_ids, out_k, out_v
kernel :
func : fused_rotary_position_embedding
data_type : q
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1219,9 +1219,11 @@ void IndexPutGradInferMeta(const MetaTensor& x,

void FusedRopeGradInferMeta(const MetaTensor& sin,
const MetaTensor& cos,
const MetaTensor& position_ids,
const MetaTensor& dout_q,
const MetaTensor& dout_k,
const MetaTensor& dout_v,
bool use_neox_rotary_style,
MetaTensor* dq,
MetaTensor* dk,
MetaTensor* dv) {
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,11 @@ void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,

void FusedRopeGradInferMeta(const MetaTensor& sin,
const MetaTensor& cos,
const MetaTensor& position_ids,
const MetaTensor& dout_q,
const MetaTensor& dout_k,
const MetaTensor& dout_v,
bool use_neox_rotary_style,
MetaTensor* dq,
MetaTensor* dk,
MetaTensor* dv);
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3848,6 +3848,8 @@ void FusedRopeInferMeta(const MetaTensor& q,
const MetaTensor& v,
const MetaTensor& sin,
const MetaTensor& cos,
const MetaTensor& position_ids,
bool use_neox_rotary_style,
MetaTensor* out_q,
MetaTensor* out_k,
MetaTensor* out_v) {
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/multiary.h
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,8 @@ void FusedRopeInferMeta(const MetaTensor& q,
const MetaTensor& v,
const MetaTensor& sin,
const MetaTensor& cos,
const MetaTensor& position_ids,
bool use_neox_rotary_style,
MetaTensor* out_q,
MetaTensor* out_k,
MetaTensor* out_v);
Expand Down
53 changes: 41 additions & 12 deletions paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ template <typename T, typename Context>
void FusedRopeGradKernel(const Context& dev_ctx,
const paddle::optional<DenseTensor>& sin,
const paddle::optional<DenseTensor>& cos,
const paddle::optional<DenseTensor>& position_ids,
const DenseTensor& dout_q,
const paddle::optional<DenseTensor>& dout_k,
const paddle::optional<DenseTensor>& dout_v,
bool use_neox_rotary_style,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv) {
Expand Down Expand Up @@ -58,6 +60,7 @@ void FusedRopeGradKernel(const Context& dev_ctx,
phi::Array<T*, 3> outs_data;
phi::Array<const T*, 3> ins_data;
phi::Array<const T*, 2> sin_cos_data;
const int64_t* position_ids_data;

ins_data[0] = dout_q.data<T>();
outs_data[0] = dq->data<T>();
Expand Down Expand Up @@ -88,19 +91,45 @@ void FusedRopeGradKernel(const Context& dev_ctx,
flag_sin_cos = true;
}

bool flag_position_ids = false;
if (position_ids.get_ptr()) {
position_ids_data = position_ids->data<int64_t>();

flag_position_ids = true;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

并不需要加这么个flag,L63将position_ids_data初始化为空,Kernel里面可以用指针是否为空判断

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

}

int sign = -1;
VectorizedFusedRopeKernel<T, MPType, vec_size>
<<<grid, block, 0, stream>>>(ins_data,
sin_cos_data,
flag_sin_cos,
sign,
batch_size,
seq_len,
num_heads,
head_dim,
outs_data,
num_inputs,
div_c);
if (use_neox_rotary_style) {
VectorizedFusedRopeWithRotateEveryTwoKernel<T, MPType, vec_size>
<<<grid, block, 0, stream>>>(ins_data,
sin_cos_data,
position_ids_data,
flag_sin_cos,
flag_position_ids,
sign,
batch_size,
seq_len,
num_heads,
head_dim,
outs_data,
num_inputs,
div_c);
} else {
VectorizedFusedRopeWithRotateHalfKernel<T, MPType, vec_size>
<<<grid, block, 0, stream>>>(ins_data,
sin_cos_data,
position_ids_data,
flag_sin_cos,
flag_position_ids,
sign,
batch_size,
seq_len,
num_heads,
head_dim,
outs_data,
num_inputs,
div_c);
}
}

} // namespace fusion
Expand Down
53 changes: 41 additions & 12 deletions paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ void FusedRopeKernel(const Context& dev_ctx,
const paddle::optional<DenseTensor>& v,
const paddle::optional<DenseTensor>& sin,
const paddle::optional<DenseTensor>& cos,
const paddle::optional<DenseTensor>& position_ids,
bool use_neox_rotary_style,
DenseTensor* out_q,
DenseTensor* out_k,
DenseTensor* out_v) {
Expand Down Expand Up @@ -59,6 +61,7 @@ void FusedRopeKernel(const Context& dev_ctx,
phi::Array<T*, 3> outs_data;
phi::Array<const T*, 3> ins_data;
phi::Array<const T*, 2> sin_cos_data;
const int64_t* position_ids_data;

ins_data[0] = q.data<T>();
outs_data[0] = out_q->data<T>();
Expand Down Expand Up @@ -125,19 +128,45 @@ void FusedRopeKernel(const Context& dev_ctx,
flag_sin_cos = true;
}

bool flag_position_ids = false;
if (position_ids.get_ptr()) {
position_ids_data = position_ids->data<int64_t>();

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要对position_ids的维度进行检查,且只有在传入了sincos的时候才需要用position_ids,且需要修改sin、cos的shape检查逻辑。

image

也就是说,sin、cos依据position_ids里面的坐标切片访问后,shape才是[1, seq_len, 1, head_dim],传进来的可能是一个比较大的shape

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

flag_position_ids = true;
}

int sign = 1;
VectorizedFusedRopeKernel<T, MPType, vec_size>
<<<grid, block, 0, stream>>>(ins_data,
sin_cos_data,
flag_sin_cos,
sign,
batch_size,
seq_len,
num_heads,
head_dim,
outs_data,
num_inputs,
div_c);
if (use_neox_rotary_style) {
VectorizedFusedRopeWithRotateEveryTwoKernel<T, MPType, vec_size>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我觉得kernel名字改成:

VectorizedFusedNeoxRopeKernel 是不是好点

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,下一个pr修改

<<<grid, block, 0, stream>>>(ins_data,
sin_cos_data,
position_ids_data,
flag_sin_cos,
flag_position_ids,
sign,
batch_size,
seq_len,
num_heads,
head_dim,
outs_data,
num_inputs,
div_c);
} else {
VectorizedFusedRopeWithRotateHalfKernel<T, MPType, vec_size>
<<<grid, block, 0, stream>>>(ins_data,
sin_cos_data,
position_ids_data,
flag_sin_cos,
flag_position_ids,
sign,
batch_size,
seq_len,
num_heads,
head_dim,
outs_data,
num_inputs,
div_c);
}
}
} // namespace fusion
} // namespace phi
Expand Down
Loading