Add a rotate_half implementation to the fused_rope operator #56401
Conversation
Your PR was submitted successfully. Thank you for contributing to this open-source project!
@@ -102,5 +103,91 @@ __global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
  }
}

template <typename T, typename MPType, int VecSize = 2>
__global__ void VectorizedFusedRopeWithRotateHalfKernel(
- On CUDA you can define `__device__` functions, and a `__device__` function can be called from a `__global__` function.
- Splitting it into two `__global__` functions also works, but large blocks of copied code should still be avoided.
Fixed.
@@ -27,6 +29,7 @@ def fused_rotary_position_embedding(q, k=None, v=None, sin=None, cos=None):
    v (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if v must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
    sin (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if sin must be [seq_len, head_dim] or [1, 1, seq_len, head_dim] and head_dim must be a multiple of 2.
    cos (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if cos must be [seq_len, head_dim] or [1, 1, seq_len, head_dim] and head_dim must be a multiple of 2.
    use_neox_rotary_style(optional|bool): Use "rotate_every_two" when use_neox_rotary_style is True, use "ratate_half" when use_neox_rotary_style is False. Default True.
Explaining it this way is not intuitive; `rotate_every_two` and `rotate_half` are not terms everyone can be assumed to know.
Fixed.
# to [1, 1, seq_len, head_dim]
perm = [0, 2, 1, 3]
sin_tensor = paddle.transpose(x=sin_tensor, perm=perm)
cos_tensor = paddle.transpose(x=cos_tensor, perm=perm)
Whether `use_neox_rotary_style` is `True` or `False`, only the update logic for `q`, `k`, and `v` differs; the computation of `sin` and `cos` is identical in both cases, so there is no need to duplicate it in the two branches of the `if-else`.
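The reviewer's point can be sketched in plain Python (a hedged illustration with a hypothetical helper name, not Paddle's actual code): the sin/cos table depends only on the position and the head dimension, so it can be computed once, outside any style-specific branch.

```python
import math

def build_sin_cos(pos, head_dim, base=10000.0):
    # The sin/cos table is identical for rotate_every_two and rotate_half,
    # so it is built once, outside any use_neox_rotary_style branch.
    inv_freq = [1.0 / base ** (2 * (i // 2) / head_dim) for i in range(head_dim)]
    sin = [math.sin(pos * f) for f in inv_freq]
    cos = [math.cos(pos * f) for f in inv_freq]
    return sin, cos

# Both rotation styles would consume the same (sin, cos) pair:
sin, cos = build_sin_cos(pos=3, head_dim=8)
```

Only the q/k/v update step then needs to branch on the rotation style.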
Fixed.
if (position_ids.get_ptr()) {
  position_ids_data = position_ids->data<int64_t>();
  flag_position_ids = true;
There is no need to add such a flag; L63 already initializes `position_ids_data` to null, so the kernel can simply check whether the pointer is null.
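The suggested pattern can be sketched in plain Python (a hedged, hypothetical helper, not the actual kernel code): `None` plays the role of the null pointer, so no separate boolean flag is carried alongside the data.

```python
def gather_rope_index(position_ids, batch, seq_pos):
    # Instead of a separate flag_position_ids boolean, treat a null
    # position_ids (None here, a nullptr check in the CUDA kernel) as
    # "use the running sequence position directly".
    if position_ids is None:
        return seq_pos
    # Otherwise gather the position to use from position_ids[batch][seq_pos].
    return position_ids[batch][seq_pos]
```

The same null check works on the device side, which is why the extra flag is redundant.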
Fixed.
bool flag_position_ids = false;
if (position_ids.get_ptr()) {
  position_ids_data = position_ids->data<int64_t>();
Fixed.
@@ -27,6 +35,8 @@ def fused_rotary_position_embedding(q, k=None, v=None, sin=None, cos=None):
    v (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if v must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
    sin (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if sin must be [seq_len, head_dim] or [1, 1, seq_len, head_dim] and head_dim must be a multiple of 2.
    cos (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if cos must be [seq_len, head_dim] or [1, 1, seq_len, head_dim] and head_dim must be a multiple of 2.
    position_ids (optional|Tensor): The input tensor. The data type is int64. The shape if position_ids must be [batch_size, seq_len].
`The shape if` -> `The shape of`. Also, the parameter format in the docs should be `position_ids (Tensor, optional)`.
Fixed.
LGTM. I suggest changing the PR title to English.
import paddle
from paddle.incubate.nn.functional import fused_rotary_position_embedding

q = paddle.randn([1, 1, 4, 10], dtype='float16')
k = paddle.randn([1, 1, 4, 10], dtype='float16')
v = paddle.randn([1, 1, 4, 10], dtype='float16')
out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v)

# batch_size = 2
# seq_len = 8
# num_heads = 2
# head_dim = 10

x = paddle.randn([1, 1, 1, 10], dtype='float16')
y = paddle.randn([1, 1, 1, 10], dtype='float16')

# q, k, v: [batch_size, seq_len, num_heads, head_dim]
q = paddle.randn([2, 8, 2, 10], dtype='float16')
k = paddle.randn([2, 8, 2, 10], dtype='float16')
v = paddle.randn([2, 8, 2, 10], dtype='float16')

# sin, cos: [1, seq_len, 1, head_dim]
x = paddle.randn([1, 8, 1, 10], dtype='float16')
y = paddle.randn([1, 8, 1, 10], dtype='float16')
sin = paddle.sin(x)
cos = paddle.cos(y)
out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v, sin=sin, cos=cos)

# position_ids: [batch_size, seq_len]
position_ids = paddle.randint(high=8, shape=[2, 8], dtype='int64')

# out_q, out_k, out_v: [batch_size, seq_len, num_heads, head_dim]
out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v, sin=sin, cos=cos, position_ids=position_ids, use_neox_rotary_style=False)
print(out_q.shape)
# [2, 8, 2, 10]
Code examples must strictly follow Google style, i.e., each code line must be prefixed with `>>>` or `...`, and any output (such as `print(out_q.shape)`) must be followed by the exact result; see the API documentation writing guide on code examples and the example-code style specification.
Note:
- This example uses APIs with randomness such as `randn`; please add a seed so the output is deterministic and easy to check.
- Regarding `# required: gpu`: does this example require a GPU environment? If so, add a doctest directive at the top of the example: `>>> # doctest: +REQUIRES(env:GPU)`.
LGTM
Force-pushed from 22f0c81 to 7a56d1c (compare).
LGTM
LGTM. This API appears to be public; remember to add the Chinese documentation as well.
MPType* sin_value = out_sin;
MPType* cos_value = out_cos;

if (flag_sin_cos) {
This parameter's name does not seem easy to understand; maybe `reuse_***`?
OK, will change this in the next PR.
@@ -148,11 +148,11 @@
  optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask, gather_index

- op : fused_rotary_position_embedding
  args : (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos)
  args : (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos, Tensor position_ids, bool use_neox_rotary_style = true)
LGTM for adding inputs.
LGTM
@@ -86,21 +89,42 @@ void FusedRopeGradKernel(const Context& dev_ctx,
    sin_cos_data[1] = cos->data<T>();

    flag_sin_cos = true;

  if (position_ids.get_ptr()) {
This should be able to use `if (position_ids)` directly.
OK, will change this in the next PR.
      num_inputs,
      div_c);
if (use_neox_rotary_style) {
  VectorizedFusedRopeWithRotateEveryTwoKernel<T, MPType, vec_size>
I think renaming the kernel to `VectorizedFusedNeoxRopeKernel` might be better.
OK, will change this in the next PR.
* add rotate_half in fused_rope
* add position_ids in fused_rope
* modified examples about fused_rope
* add set_device in examples
PR types
Others
PR changes
OPs
Description
Pcard-70459
Adds a rotate_half implementation to the fused_rope operator, controlled by the `use_neox_rotary_style` flag: `true` selects the rotate_every_two implementation, `false` selects rotate_half; the default is `true`.
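The difference between the two conventions can be sketched in framework-free Python (a hedged illustration with hypothetical helper names; the actual kernels operate on vectorized GPU data): both apply the same 2-D rotation per element pair, and differ only in which elements are paired.

```python
def rotate_every_two(x, sin, cos):
    # NeoX style: rotate adjacent pairs (x[2i], x[2i+1]).
    out = list(x)
    for i in range(0, len(x), 2):
        out[i] = x[i] * cos[i] - x[i + 1] * sin[i]
        out[i + 1] = x[i + 1] * cos[i + 1] + x[i] * sin[i + 1]
    return out

def rotate_half(x, sin, cos):
    # GPT-J/LLaMA style: rotate pairs (x[i], x[i + head_dim // 2]).
    half = len(x) // 2
    out = list(x)
    for i in range(half):
        out[i] = x[i] * cos[i] - x[i + half] * sin[i]
        out[i + half] = x[i + half] * cos[i + half] + x[i] * sin[i + half]
    return out
```

With `use_neox_rotary_style=True` the operator follows the `rotate_every_two` pairing; with `False` it follows `rotate_half`.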