Skip to content

Commit

Permalink
Modified the examples according to the review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
tianhaodongbd committed Sep 1, 2023
1 parent b319375 commit 7620f7c
Showing 1 changed file with 38 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,33 +46,44 @@ def fused_rotary_position_embedding(
.. code-block:: python
# required: gpu
import paddle
from paddle.incubate.nn.functional import fused_rotary_position_embedding
# batch_size = 2
# seq_len = 8
# num_heads = 2
# head_dim = 10
# q, k, v: [batch_size, seq_len, num_heads, head_dim]
q = paddle.randn([2, 8, 2, 10], dtype='float16')
k = paddle.randn([2, 8, 2, 10], dtype='float16')
v = paddle.randn([2, 8, 2, 10], dtype='float16')
# sin, cos: [1, seq_len, 1, head_dim]
x = paddle.randn([1, 8, 1, 10], dtype='float16')
y = paddle.randn([1, 8, 1, 10], dtype='float16')
sin = paddle.sin(x)
cos = paddle.cos(y)
# position_ids: [batch_size, seq_len]
position_ids = paddle.randint(high=8, shape=[2, 8], dtype='int64')
# out_q, out_k, out_v: [batch_size, seq_len, num_heads, head_dim]
out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v, sin=sin, cos=cos, position_ids=position_ids, use_neox_rotary_style=False)
print(out_q.shape)
# [2, 8, 2, 10]
>>> # required: gpu
>>> # doctest: +REQUIRES(env:GPU)
>>> import paddle
>>> from paddle.incubate.nn.functional import fused_rotary_position_embedding
>>> # batch_size = 2
>>> # seq_len = 2
>>> # num_heads = 2
>>> # head_dim = 2
>>> paddle.seed(1024)
>>> # q, k, v: [batch_size, seq_len, num_heads, head_dim]
>>> q = paddle.randn([2, 2, 2, 2], dtype='float16')
>>> k = paddle.randn([2, 2, 2, 2], dtype='float16')
>>> v = paddle.randn([2, 2, 2, 2], dtype='float16')
>>> # sin, cos: [1, seq_len, 1, head_dim]
>>> x = paddle.randn([1, 2, 1, 2], dtype='float16')
>>> y = paddle.randn([1, 2, 1, 2], dtype='float16')
>>> sin = paddle.sin(x)
>>> cos = paddle.cos(y)
>>> # position_ids: [batch_size, seq_len]
>>> position_ids = paddle.randint(high=2, shape=[2, 2], dtype='int64')
>>> # out_q, out_k, out_v: [batch_size, seq_len, num_heads, head_dim]
>>> out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v, sin=sin, cos=cos, position_ids=position_ids, use_neox_rotary_style=False)
>>> print(out_q)
Tensor(shape=[2, 2, 2, 2], dtype=float16, place=Place(gpu:0), stop_gradient=True,
[[[[-0.54931641, 0.64990234],
[-1.08691406, 1.18261719]],
[[ 0.57812500, 0.11749268],
[-0.63281250, 0.15551758]]],
[[[-0.77050781, 0.07733154],
[-0.73730469, -0.16735840]],
[[ 0.07116699, -0.90966797],
[-0.03628540, -0.20202637]]]])
"""
if in_dynamic_mode():
return _C_ops.fused_rotary_position_embedding(
Expand Down

0 comments on commit 7620f7c

Please sign in to comment.