
Fused attention op forward #35905

Merged
lanxianghit merged 32 commits into PaddlePaddle:develop on Oct 22, 2021

Conversation

limin2021 (Contributor) commented Sep 22, 2021

PR types

New features

PR changes

OPs

Describe

  1. Purpose: this PR aims to improve the computational performance of the attention module.
    To reduce the framework's op-scheduling overhead, this PR hand-implements the attention module at the C++ level and exposes it as one large fused attention op.
    To reduce memory-access overhead, this PR applies two optimizations:
    (1) the q, k, v projections share the input X, so the gemm, transpose, and bias-add there are reduced from three calls to one (see the weight-packing sketch under item 3);
    (2) kernel fusion passes data between what were separate CUDA kernels through registers.

  2. The computation logic implemented by fused_attention_op (the original post attached a figure of the computation graph; see the sketch below):
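A minimal NumPy sketch of this computation, pieced together from the description above and the doc-comment pseudocode quoted later in this review. The helper names, the 1/sqrt(head_dim) scaling, and the qkv_weight layout ([3, num_heads, head_dim, embed_dim]) are assumptions of the sketch, not the op's actual kernel code; dropout and the layer-norm scale/bias parameters are omitted for brevity.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def fused_attention_forward(x, qkv_weight, qkv_bias, linear_weight,
                            linear_bias, attn_mask=None, num_heads=4,
                            pre_layer_norm=True):
    b, s, d = x.shape                        # [batch, seq_len, embed_dim]
    head_dim = d // num_heads
    out = layer_norm(x) if pre_layer_norm else x
    # fused qkv projection: one gemm + bias-add instead of three
    w = qkv_weight.reshape(3 * d, d)         # packed [3*embed_dim, embed_dim]
    qkv = out @ w.T + qkv_bias.reshape(3 * d)
    qkv = qkv.reshape(b, s, 3, num_heads, head_dim).transpose(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]         # each [b, heads, s, head_dim]
    attn = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)
    if attn_mask is not None:                # additive mask
        attn = attn + attn_mask
    attn = softmax(attn)                     # attention dropout would go here
    out = (attn @ v).transpose(0, 2, 1, 3).reshape(b, s, d)
    out = out @ linear_weight                # out_linear
    # residual add + dropout (omitted) + final layer norm
    return layer_norm(x + out + linear_bias)
```

In the fused op, this whole sequence is scheduled as a single operator rather than as a chain of framework-level ops.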

  3. Differences between fused_attention_op and Paddle's existing MultiHeadAttention layer:
    (1) It covers a wider range of the computation; see the pseudocode sketch above.
    (2) The q, k, v weights are stored in a different format.
    Existing layer: three separate weight tensors, WQ, WK, WV.
    This PR: a single weight tensor, qkv_weight.
    The original post attached a figure showing how qkv_weight is obtained from WQ, WK, WV; see the sketch below.
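A minimal NumPy sketch of that packing, assuming embed_dim=128, num_heads=4, weights stored as [out_features, in_features], and [3, num_heads, head_dim, embed_dim] as the packed layout; these specifics are illustrative assumptions.

```python
import numpy as np

embed_dim, num_heads = 128, 4
head_dim = embed_dim // num_heads

# separate projection weights, each [embed_dim, embed_dim]
WQ, WK, WV = (np.random.rand(embed_dim, embed_dim) for _ in range(3))

# pack into one tensor so q, k and v come out of a single gemm
qkv_weight = np.stack([WQ, WK, WV]).reshape(3, num_heads, head_dim, embed_dim)

# one fused projection instead of three separate ones:
x = np.random.rand(2, 4, embed_dim)                # [batch, seq, embed_dim]
fused = x @ qkv_weight.reshape(3 * embed_dim, embed_dim).T
q, k, v = np.split(fused, 3, axis=-1)
assert np.allclose(q, x @ WQ.T)                    # same result, one gemm
```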

  4. Implementation:
    This PR is the forward implementation of fused_attention_op. Details:

(1) fused_attention_op.cc and fused_attention_op.cu
The C++ forward implementation of fused_attention_op. It builds on these PRs:
#34883, #35308, #35350, #35621, #35903, #36185

(2) functional/fused_transformer.py
The Python API for fused_attention_op. For now it includes only the dynamic-graph API; the static-graph API will be added in the next PR. A usage sketch follows below.
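A hypothetical dynamic-graph call, reconstructed from the example fragment quoted in a review thread further down this page; the function name, argument order, and shapes are assumptions based on that fragment, and per #36704 (referenced at the bottom of this page) the API was later moved under incubate.

```python
import paddle
import paddle.nn.functional as F

batch, seq_len, embed_dim, num_heads = 2, 4, 128, 4
head_dim = embed_dim // num_heads

x = paddle.rand((batch, seq_len, embed_dim), dtype='float32')
qkv_weight = paddle.rand((3, num_heads, head_dim, embed_dim), dtype='float32')
qkv_bias = paddle.rand((3, num_heads, head_dim), dtype='float32')
linear_weight = paddle.rand((embed_dim, embed_dim), dtype='float32')
linear_bias = paddle.rand((embed_dim,), dtype='float32')
attn_mask = paddle.ones((batch, num_heads, seq_len, seq_len), dtype='float32')

# positional args mirror the fragment quoted in the review below:
# pre_ln_scale/bias and ln_scale/bias left as None, pre_ln_epsilon=1e-5
output = F.fused_multihead_attention(x, qkv_weight, linear_weight, False,
                                     None, None, None, None, 1e-5, qkv_bias,
                                     linear_bias, attn_mask)
print(output.shape)  # [2, 4, 128]
```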

(3) test_fused_attention_op.py
The unittest script for fused_attention_op (dynamic graph, forward pass).

(4) paddle/fluid/operators/dropout_impl_util.h
The modifications to dropout_impl_util.h from #35820 were overwritten by #36185; this PR restores the contents to match #35820.

(5) Bug fix: removed a useless "print" in framework.py.

  5. Unittest:
    [screenshot of unittest results]

@paddle-bot-old

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

"H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])."
"It is applied to the output.")
.AsDispensable();
// AddInput("Seed",
Contributor:
The commented-out code can be removed.

Contributor Author:

Done.

});

AddComment(R"DOC(
Fused attention op:
Contributor:
The formatting here seems off.

Contributor Author:

Done.


namespace paddle {
namespace operators {} // namespace operators
} // namespace paddle
Contributor:
Is this file actually needed?

Contributor Author:
It is needed: pre-commit requires op.cc and op.cu to include op.h; otherwise the check fails.

Contributor:
A .h file shouldn't be strictly required; multi_dot, for example, has only multi_dot_op.cc.

Contributor Author:

Done.

import unittest


def _convert_attention_mask(attn_mask, dtype):
Contributor:
This could call the version in transformer instead of redefining it here.

Contributor Author:

Done.
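For context, the helper in question converts boolean or integer masks (1 = attend, 0 = mask) into additive float masks. The sketch below is modeled on the transformer module's version of this helper; the exact scaling constant is an assumption.

```python
import paddle

def convert_attention_mask(attn_mask, dtype):
    # already a float mask of the right dtype: used additively as-is
    if attn_mask is None or attn_mask.dtype == dtype:
        return attn_mask
    if attn_mask.dtype in (paddle.bool, paddle.int32, paddle.int64):
        # kept positions add 0.0; masked positions add -1e9 so that
        # softmax drives their attention weights to ~0
        return (paddle.cast(attn_mask, dtype) - 1.0) * 1e9
    return paddle.cast(attn_mask, dtype)

mask = paddle.to_tensor([[True, True, False]])
print(convert_attention_mask(mask, paddle.float32))  # [[0., 0., -1e9]]
```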

"(batch_size, seq_len, dim_embed),"
"but received dimensions of"
"Input is [%d]",
x_dim.size()));
Contributor:
Does "QKV_input" here refer to x? It's best to map it to the input name in the user-facing interface; error messages should let users accurately understand and locate the problem.
There are similar issues below.

Contributor Author:

Done. All occurrences have been checked.

@@ -0,0 +1,298 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Contributor:
Is this exactly the same as #35843?

Contributor Author:

Yes, it's the same. Once #35843 is merged, this just needs a merge here; no review needed for it in this PR.

TCChenlong (Contributor) previously approved these changes Sep 23, 2021:
LGTM

kolinwei previously approved these changes Sep 24, 2021
zkh2016 (Contributor) previously approved these changes Sep 24, 2021:
LGTM

xingfeng01 (Contributor) commented:
LGTM

xingfeng01 previously approved these changes Sep 24, 2021
zhangting2020 previously approved these changes Sep 24, 2021
lanxianghit previously approved these changes Sep 24, 2021
zhiqiu previously approved these changes Sep 24, 2021
zhiqiu (Contributor) left a comment:
LGTM for op_function_generator.cc

you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Contributor:
Mind the copyright header format.

Contributor Author:

Done.

OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW",
Contributor:
This is an input; isn't a name like OutXxx a bit odd for an input?

Contributor Author (limin2021) commented Sep 27, 2021:

This will be fixed together in the next PR (fused_attention_bw): OutLinear will be renamed to Linear, and other names will be polished as well.

platform::errors::InvalidArgument(
"'attn_dropout_prob' must be between 0.0 and 1.0."));
});
AddAttr<bool>("is_test1",
Contributor:
Try not to use names like xxx_1; they are hard to tell apart.

Contributor Author:

Done.

" train: out = input * mask / ( 1.0 - dropout_prob )"
" inference: out = input"
" dropout op can be removed from the program. the program will be "
"efficient")
Contributor:
This long explanation is repeated several times; how about factoring it into a macro or a function?

Contributor Author:

Done.

#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
Contributor:
A restriction has been added in cmake; ROCm is not supported.

Contributor Author:

Done.

LaunchResidualDropoutBiasGrad<T, uint8_t>(
dout, mask, dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train, rows_, cols_, dsrc, dbias, ctx);
cudaMemcpyAsync(dresidual, dout, rows_ * cols_ * sizeof(T),
Contributor:

memory::Copy

Contributor Author:

This file comes from #35843 and can be reviewed there; no review is needed here. After #35843 is merged, this PR will be updated via merge.


@unittest.skipIf(not core.is_compiled_with_cuda(),
"Paddle core is not compiled with CUDA")
class TestFusedAttentionOp(unittest.TestCase):
Contributor:
Op unit tests should inherit from OpTest.

Contributor Author:

Done.

import unittest


@unittest.skipIf(not core.is_compiled_with_cuda(),
Contributor:
Since only CUDA is supported, this restriction can be added in CMakeLists.txt.

Contributor Author:

Done.

@@ -60,6 +60,7 @@
from .conv import conv1d # noqa: F401
from .conv import conv1d_transpose # noqa: F401
from .common import linear # noqa: F401
from .common import fused_multihead_attention # noqa: F401
Contributor:
Should this functional interface be exposed to users?

Contributor Author (limin2021) commented Sep 26, 2021:
Yes, it needs to be exposed; it has been moved to functional/fused_transformer.py.

@@ -1502,6 +1502,33 @@ def linear(x, weight, bias=None, name=None):
return res


def fused_multihead_attention(x,
Contributor:
I'd still suggest implementing it in fused_transformer.py first.

Contributor Author:

Done.

auto *qkv_bias_out_data = qkv_bias_out->mutable_data<T>(ctx.GetPlace());

// get data ptr for FMHA.
auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data<T>());
Reviewer:
This variable doesn't seem to be used anywhere?

Contributor Author:

Done.

zhangting2020 (Contributor) previously approved these changes Oct 14, 2021:
LGTM

xingfeng01 previously approved these changes Oct 14, 2021
@limin2021 limin2021 dismissed stale reviews from xingfeng01 and zhangting2020 via 10687a6 October 21, 2021 06:02
xingfeng01 previously approved these changes Oct 21, 2021
out = out * v;
out = transpose(out, perm=[0, 2, 1, 3]);
out = out_linear(out);
out = layer_norm(x + dropout(linear_bias + out));
Contributor:
Isn't the formatting here a bit off?

Contributor Author:

Done.

None, None, None, None, 1e-5, qkv_bias,
linear_bias, attn_mask)
# [2, 4, 128]
print(output)
Contributor:
[2, 4, 128] here is the output shape, right? Should this be print(output.shape)?

Contributor Author:

Done.

TCChenlong (Contributor) left a comment:
TODO: Fix API Docs

zhiqiu (Contributor) left a comment:
LGTM for op_function_generator.cc

@lanxianghit lanxianghit merged commit d490621 into PaddlePaddle:develop Oct 22, 2021
limin2021 added a commit to limin2021/Paddle that referenced this pull request Oct 22, 2021
limin2021 added a commit to limin2021/Paddle that referenced this pull request Oct 25, 2021
limin2021 added a commit to limin2021/Paddle that referenced this pull request Oct 25, 2021
lanxianghit pushed a commit that referenced this pull request Oct 26, 2021
lanxianghit pushed a commit that referenced this pull request Oct 26, 2021
…ubate (#36704)

Moves the Python API interfaces added in PRs #35905 and #35843 to the incubate directory.
zkh2016 pushed a commit to zkh2016/Paddle that referenced this pull request Oct 26, 2021
lanxianghit pushed a commit that referenced this pull request Oct 26, 2021
* add op: fused_feedforward(backward) (#35611)

This PR contains the backward code of fused_feedforward.

Related kernel implementations: fused_dropout_act_bias, fused_residual_dropout_bias, fused_layernorm_residual_dropout_bias.

fused_feedforward is a fused operator: it fuses and wraps the ops of the transformer model's feed-forward layer so that the frontend exposes a single interface; the fusion reduces part of the memory-access and kernel-launch time and thereby improves performance (a sketch of the fused computation follows below).

* Move fused_attention and fused_feedforward functional api path to incubate (#36704)

Moves the Python API interfaces added in PRs #35905 and #35843 to the incubate directory.
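For reference, a minimal NumPy sketch of the computation fused_feedforward covers (pre-layer-norm variant, ReLU activation, dropout treated as identity; these choices are assumptions of the sketch, not the operator's fixed behavior).

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def fused_feedforward(x, w1, b1, w2, b2, pre_layer_norm=True):
    # x: [batch, seq, d_model]; w1: [d_model, d_ff]; w2: [d_ff, d_model]
    out = layer_norm(x) if pre_layer_norm else x
    out = np.maximum(out @ w1 + b1, 0.0)   # linear1 + relu (+ dropout)
    out = out @ w2 + b2                    # linear2 (+ dropout)
    out = x + out                          # residual connection
    return out if pre_layer_norm else layer_norm(out)
```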