Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PIR API adaptor No.87】Migrate fused_rotary_position_embedding into pir #58911

Merged
merged 24 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 33 additions & 18 deletions python/paddle/incubate/nn/functional/fused_matmul_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle import _legacy_C_ops
from paddle import _C_ops, _legacy_C_ops
from paddle.base.layer_helper import LayerHelper
from paddle.framework import in_dynamic_mode
from paddle.framework import in_dynamic_mode, in_dynamic_or_pir_mode
from paddle.tensor.linalg import matmul


Expand Down Expand Up @@ -56,10 +56,15 @@ def fused_matmul_bias(
"""
if bias is None:
return matmul(x, y, transpose_x, transpose_y, name)
if in_dynamic_mode():
return _legacy_C_ops.fused_gemm_epilogue(
x, y, bias, 'trans_x', transpose_x, 'trans_y', transpose_y
)
if in_dynamic_or_pir_mode():
if in_dynamic_mode():
return _legacy_C_ops.fused_gemm_epilogue(
x, y, bias, 'trans_x', transpose_x, 'trans_y', transpose_y
)
else:
return _C_ops.fused_gemm_epilogue(
x, y, bias, transpose_x, transpose_y
)
enkilee marked this conversation as resolved.
Show resolved Hide resolved

helper = LayerHelper('fused_matmul_bias', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
Expand Down Expand Up @@ -145,18 +150,28 @@ def fused_linear_activation(
if activation is None:
activation = "none"

if in_dynamic_mode():
return _legacy_C_ops.fused_gemm_epilogue(
x,
y,
bias,
'trans_x',
trans_x,
'trans_y',
trans_y,
'activation',
activation,
)
if in_dynamic_or_pir_mode():
if in_dynamic_mode():
return _legacy_C_ops.fused_gemm_epilogue(
x,
y,
bias,
'trans_x',
trans_x,
'trans_y',
trans_y,
'activation',
activation,
)
else:
return _C_ops.fused_gemm_epilogue(
x,
y,
bias,
trans_x,
trans_y,
activation,
)
enkilee marked this conversation as resolved.
Show resolved Hide resolved

helper = LayerHelper('fused_matmul_bias', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from paddle import _C_ops
from paddle.base.layer_helper import LayerHelper
from paddle.framework import in_dynamic_mode
from paddle.framework import in_dynamic_or_pir_mode


def fused_rotary_position_embedding(
Expand Down Expand Up @@ -87,7 +87,7 @@ def fused_rotary_position_embedding(
[[ 0.07116699, -0.90966797],
[-0.03628540, -0.20202637]]]])
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.fused_rotary_position_embedding(
q, k, v, sin, cos, position_ids, use_neox_rotary_style
)
Expand Down
115 changes: 74 additions & 41 deletions python/paddle/incubate/nn/functional/fused_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1148,47 +1148,80 @@ def fused_multi_transformer(
'downgrade_in_infer' if mode == 'downscale_in_infer' else mode
) # semantic transfer

if in_dynamic_mode():
cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer(
x,
ln_scales,
ln_biases,
qkv_weights,
qkv_biases,
cache_kvs,
pre_caches,
rotary_embs,
time_step,
seq_lens,
attn_mask,
linear_weights,
linear_biases,
ffn_ln_scales,
ffn_ln_biases,
ffn1_weights,
ffn1_biases,
ffn2_weights,
ffn2_biases,
cache_kvs,
'pre_layer_norm',
pre_layer_norm,
'epsilon',
epsilon,
'dropout_rate',
dropout_rate,
'rotary_emb_dims',
rotary_emb_dims,
'is_test',
not training,
'dropout_implementation',
mode,
'act_method',
activation,
'trans_qkvw',
trans_qkvw,
'ring_id',
ring_id,
)
if in_dynamic_or_pir_mode():
if in_dynamic_mode():
cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer(
x,
ln_scales,
ln_biases,
qkv_weights,
qkv_biases,
cache_kvs,
pre_caches,
rotary_embs,
time_step,
seq_lens,
attn_mask,
linear_weights,
linear_biases,
ffn_ln_scales,
ffn_ln_biases,
ffn1_weights,
ffn1_biases,
ffn2_weights,
ffn2_biases,
cache_kvs,
'pre_layer_norm',
pre_layer_norm,
'epsilon',
epsilon,
'dropout_rate',
dropout_rate,
'rotary_emb_dims',
rotary_emb_dims,
'is_test',
not training,
'dropout_implementation',
mode,
'act_method',
activation,
'trans_qkvw',
trans_qkvw,
'ring_id',
ring_id,
)
else:
cache_kv_out, final_out = _C_ops.fused_multi_transformer(
x,
ln_scales,
ln_biases,
qkv_weights,
qkv_biases,
cache_kvs,
pre_caches,
rotary_embs,
time_step,
seq_lens,
attn_mask,
linear_weights,
linear_biases,
ffn_ln_scales,
ffn_ln_biases,
ffn1_weights,
ffn1_biases,
ffn2_weights,
ffn2_biases,
cache_kvs,
pre_layer_norm,
epsilon,
dropout_rate,
rotary_emb_dims,
not training,
mode,
activation,
trans_qkvw,
ring_id,
)
enkilee marked this conversation as resolved.
Show resolved Hide resolved
if cache_kvs is not None:
return final_out, cache_kv_out
return final_out
Expand Down
2 changes: 2 additions & 0 deletions test/legacy_test/test_fused_rotary_position_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import paddle
from paddle.base import core
from paddle.incubate.nn.functional import fused_rotary_position_embedding
from paddle.pir_utils import test_with_pir_api


def deal_qkv(init_q, init_k, init_v):
Expand Down Expand Up @@ -281,6 +282,7 @@ def test_fused_rope_position_ids(self):
p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05
)

@test_with_pir_api
def test_static(self):
tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos = self.get_inputs(
self.seed, True
Expand Down