
[hybrid performance] softmax mask fuse op #33841

Merged 1 commit into PaddlePaddle:develop on Jul 16, 2021

Conversation

@FeixLiu (Contributor) commented on Jun 29, 2021

PR types

New features

PR changes

OPs

Describe

Fuse the mask elementwise-add and the softmax into a single op, for use in transformer models.

general pass:

# declare Q, K, dk and mask
QK = matmul(Q, K.transpose()) / sqrt(dk)    # launch first op for matmul
tmp = QK + mask                             # launch second op for elementwise_add
res = softmax(tmp)                          # launch third op for softmax

fused pass:

# declare Q, K, dk and mask
QK = matmul(Q, K.transpose()) / sqrt(dk)    # launch first op for matmul
res = softmax_mask_fuse(QK, mask)           # launch second op for softmax_mask_fuse
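
For clarity, below is a minimal NumPy sketch of the computation the fused op replaces (an illustrative reference only, not the actual CUDA kernel; the helper name `softmax_mask_fuse_ref` and the shape comments are assumptions based on the examples further down):

```python
import numpy as np

def softmax_mask_fuse_ref(qk, mask):
    # qk:   [batch, heads, seq_q, seq_k]
    # mask: [batch, 1, seq_q, seq_k], broadcast over the head dimension
    tmp = qk + mask                               # the fused elementwise_add
    tmp -= tmp.max(axis=-1, keepdims=True)        # stabilize before exp
    exp = np.exp(tmp)
    return exp / exp.sum(axis=-1, keepdims=True)  # softmax over the last axis
```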

Performance, based on PaddleNLP/GPT under AMP:

| model size | hybrid size (dp, pp, mp) | no-fuse throughput | fused throughput | gain |
| --- | --- | --- | --- | --- |
| 117M | 8 (1,1,8) | 12318.38 | 15046.72 | +22% |
| 117M | 1 (1,1,1) | 8729.29 | 9859.29 | +13% |
| 345M | 8 (1,1,1) | 5317.19 | 6766.67 | +27% |
| 345M | 1 (1,1,1) | 3796.09 | 4322.57 | +14% |
| 1.3B | 8 (1,1,8) | 3970.90 | 4727.14 | +19% |
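
The gain column is the relative throughput improvement of the fused pass over the no-fuse pass; for example, for the 117M model on 8 cards, 15046.72 / 12318.38 - 1 ≈ 0.22, i.e. +22%.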

Loss curve, for PaddleNLP/GPT under AMP:
(screenshot: loss curves of the fused and no-fuse passes)

Average loss diff over 20,000 steps: 0.0077.

Loss diff between the fused pass and the no-fuse pass:
(screenshot: loss diff curve)

Currently, this op only supports the fp16 dtype.

To use this op from the Python side, use the following code in static mode:

import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.incubate as incubate

paddle.enable_static()  # this example builds a static graph

input_x = fluid.data(name="x", shape=[4, 4, 8, 132], dtype="float16")
input_mask = fluid.data(name="mask", shape=[4, 1, 8, 132], dtype="float16")
rst = incubate.softmax_mask_fuse(input_x, input_mask)

x_in_np = np.random.random((4, 4, 8, 132)).astype("float16")
# mask: 0 keeps a position, -10000.0 (a large negative value) masks it out
mask = np.random.randint(0, 2, (4, 1, 8, 132)).astype("float16")
mask_in_np = np.where(mask == 1, -10000.0, mask)

place = fluid.CUDAPlace(0)  # the fused fp16 kernel runs on GPU
exe = fluid.Executor(place)
fetches = exe.run(fluid.default_main_program(),
                  feed={"x": x_in_np,
                        "mask": mask_in_np},
                  fetch_list=[rst])

Use the following code in dynamic (dygraph) mode:

import numpy as np
import paddle.fluid as fluid
import paddle.incubate as incubate

x_in_np = np.random.random((2, 8, 8, 2040)).astype("float16")
# mask: 0 keeps a position, -10000.0 masks it out
mask = np.random.randint(0, 2, (2, 1, 8, 2040)).astype("float16")
mask_in_np = np.where(mask == 1, -10000.0, mask)
input_x = fluid.dygraph.to_variable(x_in_np)
input_mask = fluid.dygraph.to_variable(mask_in_np)

rst = incubate.softmax_mask_fuse(input_x, input_mask)
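
As an optional sanity check in dynamic mode, the fused result can be compared against the unfused add + softmax path. This is a sketch, not part of the PR's tests; it assumes the Paddle 2.x dygraph API on GPU and uses a loose tolerance for fp16:

```python
import paddle.nn.functional as F

# unfused reference: elementwise add followed by softmax over the last axis
rst_ref = F.softmax(input_x + input_mask, axis=-1)
print(np.allclose(rst.numpy(), rst_ref.numpy(), atol=1e-3))  # loose fp16 tolerance
```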

@paddle-bot-old commented:

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@ForFishes (Member) previously approved these changes on Jul 2, 2021:

LGTM

@Xreki (Contributor) previously approved these changes on Jul 5, 2021:

LGTM for the modification of atol and LGTM for op benchmark CI.

@paddle-bot-old commented on Jul 9, 2021:

Sorry to inform you that 170cde8's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@FeixLiu dismissed the stale reviews from Xreki and ForFishes via 3b18c23 on July 12, 2021.
@ForFishes (Member) left a comment:

LGTM

@XiaoguangHu01 (Contributor) left a comment:

LG API

@PangHua left a comment:

LGTM

@ForFishes merged commit 44bdbe9 into PaddlePaddle:develop on Jul 16, 2021.
@FeixLiu deleted the softmax_mask_fuse branch on July 16, 2021.
@FeixLiu changed the title from "softmax mask fuse op" to "[hybrid performance] softmax mask fuse op" on Oct 11, 2021.