-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
在fused_rope算子中增加rotate_half实现方式 #56401
Changes from 3 commits
c4c4bf3
8cc3cf5
2459faa
2e8be33
b319375
7620f7c
7a56d1c
8619245
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,9 +27,11 @@ template <typename T, typename Context> | |
void FusedRopeGradKernel(const Context& dev_ctx, | ||
const paddle::optional<DenseTensor>& sin, | ||
const paddle::optional<DenseTensor>& cos, | ||
const paddle::optional<DenseTensor>& position_ids, | ||
const DenseTensor& dout_q, | ||
const paddle::optional<DenseTensor>& dout_k, | ||
const paddle::optional<DenseTensor>& dout_v, | ||
bool use_neox_rotary_style, | ||
DenseTensor* dq, | ||
DenseTensor* dk, | ||
DenseTensor* dv) { | ||
|
@@ -58,6 +60,7 @@ void FusedRopeGradKernel(const Context& dev_ctx, | |
phi::Array<T*, 3> outs_data; | ||
phi::Array<const T*, 3> ins_data; | ||
phi::Array<const T*, 2> sin_cos_data; | ||
const int64_t* position_ids_data; | ||
|
||
ins_data[0] = dout_q.data<T>(); | ||
outs_data[0] = dq->data<T>(); | ||
|
@@ -88,19 +91,45 @@ void FusedRopeGradKernel(const Context& dev_ctx, | |
flag_sin_cos = true; | ||
} | ||
|
||
bool flag_position_ids = false; | ||
if (position_ids.get_ptr()) { | ||
position_ids_data = position_ids->data<int64_t>(); | ||
|
||
flag_position_ids = true; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 并不需要加这么个 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
} | ||
|
||
int sign = -1; | ||
VectorizedFusedRopeKernel<T, MPType, vec_size> | ||
<<<grid, block, 0, stream>>>(ins_data, | ||
sin_cos_data, | ||
flag_sin_cos, | ||
sign, | ||
batch_size, | ||
seq_len, | ||
num_heads, | ||
head_dim, | ||
outs_data, | ||
num_inputs, | ||
div_c); | ||
if (use_neox_rotary_style) { | ||
VectorizedFusedRopeWithRotateEveryTwoKernel<T, MPType, vec_size> | ||
<<<grid, block, 0, stream>>>(ins_data, | ||
sin_cos_data, | ||
position_ids_data, | ||
flag_sin_cos, | ||
flag_position_ids, | ||
sign, | ||
batch_size, | ||
seq_len, | ||
num_heads, | ||
head_dim, | ||
outs_data, | ||
num_inputs, | ||
div_c); | ||
} else { | ||
VectorizedFusedRopeWithRotateHalfKernel<T, MPType, vec_size> | ||
<<<grid, block, 0, stream>>>(ins_data, | ||
sin_cos_data, | ||
position_ids_data, | ||
flag_sin_cos, | ||
flag_position_ids, | ||
sign, | ||
batch_size, | ||
seq_len, | ||
num_heads, | ||
head_dim, | ||
outs_data, | ||
num_inputs, | ||
div_c); | ||
} | ||
} | ||
|
||
} // namespace fusion | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,8 @@ void FusedRopeKernel(const Context& dev_ctx, | |
const paddle::optional<DenseTensor>& v, | ||
const paddle::optional<DenseTensor>& sin, | ||
const paddle::optional<DenseTensor>& cos, | ||
const paddle::optional<DenseTensor>& position_ids, | ||
bool use_neox_rotary_style, | ||
DenseTensor* out_q, | ||
DenseTensor* out_k, | ||
DenseTensor* out_v) { | ||
|
@@ -59,6 +61,7 @@ void FusedRopeKernel(const Context& dev_ctx, | |
phi::Array<T*, 3> outs_data; | ||
phi::Array<const T*, 3> ins_data; | ||
phi::Array<const T*, 2> sin_cos_data; | ||
const int64_t* position_ids_data; | ||
|
||
ins_data[0] = q.data<T>(); | ||
outs_data[0] = out_q->data<T>(); | ||
|
@@ -125,19 +128,45 @@ void FusedRopeKernel(const Context& dev_ctx, | |
flag_sin_cos = true; | ||
} | ||
|
||
bool flag_position_ids = false; | ||
if (position_ids.get_ptr()) { | ||
position_ids_data = position_ids->data<int64_t>(); | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
flag_position_ids = true; | ||
} | ||
|
||
int sign = 1; | ||
VectorizedFusedRopeKernel<T, MPType, vec_size> | ||
<<<grid, block, 0, stream>>>(ins_data, | ||
sin_cos_data, | ||
flag_sin_cos, | ||
sign, | ||
batch_size, | ||
seq_len, | ||
num_heads, | ||
head_dim, | ||
outs_data, | ||
num_inputs, | ||
div_c); | ||
if (use_neox_rotary_style) { | ||
VectorizedFusedRopeWithRotateEveryTwoKernel<T, MPType, vec_size> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 我觉得kernel名字改成: VectorizedFusedNeoxRopeKernel 是不是好点 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 好的,下一个pr修改 |
||
<<<grid, block, 0, stream>>>(ins_data, | ||
sin_cos_data, | ||
position_ids_data, | ||
flag_sin_cos, | ||
flag_position_ids, | ||
sign, | ||
batch_size, | ||
seq_len, | ||
num_heads, | ||
head_dim, | ||
outs_data, | ||
num_inputs, | ||
div_c); | ||
} else { | ||
VectorizedFusedRopeWithRotateHalfKernel<T, MPType, vec_size> | ||
<<<grid, block, 0, stream>>>(ins_data, | ||
sin_cos_data, | ||
position_ids_data, | ||
flag_sin_cos, | ||
flag_position_ids, | ||
sign, | ||
batch_size, | ||
seq_len, | ||
num_heads, | ||
head_dim, | ||
outs_data, | ||
num_inputs, | ||
div_c); | ||
} | ||
} | ||
} // namespace fusion | ||
} // namespace phi | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for add inputs