diff --git a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py
index fe591a6be48fd..3fd80ffbaaa80 100644
--- a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py
+++ b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py
@@ -46,33 +46,44 @@ def fused_rotary_position_embedding(

         .. code-block:: python

-            # required: gpu
-            import paddle
-            from paddle.incubate.nn.functional import fused_rotary_position_embedding
-
-            # batch_size = 2
-            # seq_len = 8
-            # num_heads = 2
-            # head_dim = 10
-
-            # q, k, v: [batch_size, seq_len, num_heads, head_dim]
-            q = paddle.randn([2, 8, 2, 10], dtype='float16')
-            k = paddle.randn([2, 8, 2, 10], dtype='float16')
-            v = paddle.randn([2, 8, 2, 10], dtype='float16')
-
-            # sin, cos: [1, seq_len, 1, head_dim]
-            x = paddle.randn([1, 8, 1, 10], dtype='float16')
-            y = paddle.randn([1, 8, 1, 10], dtype='float16')
-            sin = paddle.sin(x)
-            cos = paddle.cos(y)
-
-            # position_ids: [batch_size, seq_len]
-            position_ids = paddle.randint(high=8, shape=[2, 8], dtype='int64')
-
-            # out_q, out_k, out_v: [batch_size, seq_len, num_heads, head_dim]
-            out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v, sin=sin, cos=cos, position_ids=position_ids, use_neox_rotary_style=False)
-            print(out_q.shape)
-            # [2, 8, 2, 10]
+            >>> # required: gpu
+            >>> # doctest: +REQUIRES(env:GPU)
+            >>> import paddle
+            >>> from paddle.incubate.nn.functional import fused_rotary_position_embedding
+
+            >>> # batch_size = 2
+            >>> # seq_len = 2
+            >>> # num_heads = 2
+            >>> # head_dim = 2
+
+            >>> paddle.seed(1024)
+
+            >>> # q, k, v: [batch_size, seq_len, num_heads, head_dim]
+            >>> q = paddle.randn([2, 2, 2, 2], dtype='float16')
+            >>> k = paddle.randn([2, 2, 2, 2], dtype='float16')
+            >>> v = paddle.randn([2, 2, 2, 2], dtype='float16')
+
+            >>> # sin, cos: [1, seq_len, 1, head_dim]
+            >>> x = paddle.randn([1, 2, 1, 2], dtype='float16')
+            >>> y = paddle.randn([1, 2, 1, 2], dtype='float16')
+            >>> sin = paddle.sin(x)
+            >>> cos = paddle.cos(y)
+
+            >>> # position_ids: [batch_size, seq_len]
+            >>> position_ids = paddle.randint(high=2, shape=[2, 2], dtype='int64')
+
+            >>> # out_q, out_k, out_v: [batch_size, seq_len, num_heads, head_dim]
+            >>> out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v, sin=sin, cos=cos, position_ids=position_ids, use_neox_rotary_style=False)
+            >>> print(out_q)
+            Tensor(shape=[2, 2, 2, 2], dtype=float16, place=Place(gpu:0), stop_gradient=True,
+                   [[[[-0.54931641,  0.64990234],
+                     [-1.08691406,  1.18261719]],
+                    [[ 0.57812500,  0.11749268],
+                     [-0.63281250,  0.15551758]]],
+                   [[[-0.77050781,  0.07733154],
+                     [-0.73730469, -0.16735840]],
+                    [[ 0.07116699, -0.90966797],
+                     [-0.03628540, -0.20202637]]]])
         """
     if in_dynamic_mode():
         return _C_ops.fused_rotary_position_embedding(