[2/4] CUDNNv8 ResNet Fusion: Add fused_scale_bias_add_relu OP #58504

Merged
6 changes: 4 additions & 2 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -73,7 +73,8 @@
     'fused_embedding_eltwise_layernorm',
     'fused_fc_elementwise_layernorm',
     'fused_multi_transformer_xpu',
-    'fused_scale_bias_relu_conv_bnstats',
+    'fused_scale_bias_relu_conv_bn',
+    'fused_scale_bias_add_relu',
     'fusion_transpose_flatten_concat',
     'generate_sequence_xpu',
     'layer_norm_act_xpu',
@@ -103,7 +104,8 @@
     'embedding_grad_sparse',
     'fused_batch_norm_act_',
     'fused_bn_add_activation_',
-    'fused_scale_bias_relu_conv_bnstats',
+    'fused_scale_bias_relu_conv_bn',
+    'fused_scale_bias_add_relu',
     'memcpy',
     'print',
     'recv_v2',
16 changes: 13 additions & 3 deletions paddle/phi/api/yaml/fused_ops.yaml
@@ -206,14 +206,24 @@
   backward: fused_rotary_position_embedding_grad
   support_dygraph_mode : true
 
-- op : fused_scale_bias_relu_conv_bnstats
+- op : fused_scale_bias_add_relu
+  args : (Tensor x1, Tensor scale1, Tensor bias1, Tensor x2, Tensor scale2, Tensor bias2, bool fuse_dual, bool exhaustive_search)
+  optional : scale2, bias2
+  output : Tensor(y)
Contributor: Generally, a single output is named `out`.

Contributor Author: This will be fixed in the next PR.
+  infer_meta :
+    func : FusedScaleBiasAddReluInferMeta
+  kernel :
+    func : fused_scale_bias_add_relu
+    data_type : x1
+
+- op : fused_scale_bias_relu_conv_bn
   args : (Tensor x, Tensor w, Tensor scale, Tensor bias, Tensor bn_scale, Tensor bn_bias, Tensor input_running_mean, Tensor input_running_var, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, str data_format, float momentum, float epsilon, bool fuse_prologue, bool exhaustive_search, int64_t accumulation_count = 0)
   optional : scale, bias
   output : Tensor(out), Tensor(out_running_mean), Tensor(out_running_var), Tensor(saved_mean), Tensor(saved_var), Tensor(eq_scale), Tensor(eq_bias)
   infer_meta :
-    func : FusedScaleBiasReluConvBnstatsInferMeta
+    func : FusedScaleBiasReluConvBnInferMeta
   kernel :
-    func : fused_scale_bias_relu_conv_bnstats
+    func : fused_scale_bias_relu_conv_bn
     data_type : x
 
 - op : fusion_gru
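For reference, the signature above implies the fused computation y = relu((scale1 * x1 + bias1) + rhs), where rhs is (scale2 * x2 + bias2) when fuse_dual is true and plain x2 otherwise; scale2 and bias2 are optional for exactly this reason. A minimal element-wise C++ sketch of those semantics (illustrative only, not the actual cuDNN v8 fused kernel; the per-channel NHWC scale/bias tensors are simplified to scalars here):

#include <algorithm>
#include <cstddef>
#include <vector>

// Reference semantics for fused_scale_bias_add_relu (sketch only).
// Scalar scale/bias stand in for the per-channel tensors of the real op.
std::vector<float> FusedScaleBiasAddReluRef(const std::vector<float>& x1,
                                            float scale1,
                                            float bias1,
                                            const std::vector<float>& x2,
                                            float scale2,
                                            float bias2,
                                            bool fuse_dual) {
  std::vector<float> y(x1.size());
  for (std::size_t i = 0; i < x1.size(); ++i) {
    float lhs = scale1 * x1[i] + bias1;
    // With fuse_dual, the second branch also gets its own scale and bias.
    float rhs = fuse_dual ? (scale2 * x2[i] + bias2) : x2[i];
    y[i] = std::max(lhs + rhs, 0.0f);  // ReLU
  }
  return y;
}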
87 changes: 56 additions & 31 deletions paddle/phi/infermeta/fusion.cc
@@ -1687,41 +1687,40 @@ void LayerNormActXPUInferMeta(const MetaTensor& x,
   y->set_layout(x.layout());
 }
 
-void FusedScaleBiasReluConvBnstatsInferMeta(
-    const MetaTensor& x,
-    const MetaTensor& w,
-    const MetaTensor& scale,
-    const MetaTensor& bias,
-    const MetaTensor& bn_scale,
-    const MetaTensor& bn_bias,
-    const MetaTensor& input_running_mean,
-    const MetaTensor& input_running_var,
-    const std::vector<int>& paddings,
-    const std::vector<int>& dilations,
-    const std::vector<int>& strides,
-    const std::string& padding_algorithm,
-    int groups,
-    const std::string& data_format,
-    float momentum,
-    float epsilon,
-    bool fuse_prologue,
-    bool exhaustive_search,
-    int64_t accumulation_count,
-    MetaTensor* out,
-    MetaTensor* out_running_mean,
-    MetaTensor* out_running_var,
-    MetaTensor* saved_mean,
-    MetaTensor* saved_var,
-    MetaTensor* eq_scale,
-    MetaTensor* eq_bias) {
+void FusedScaleBiasReluConvBnInferMeta(const MetaTensor& x,
+                                       const MetaTensor& w,
+                                       const MetaTensor& scale,
+                                       const MetaTensor& bias,
+                                       const MetaTensor& bn_scale,
+                                       const MetaTensor& bn_bias,
+                                       const MetaTensor& input_running_mean,
+                                       const MetaTensor& input_running_var,
+                                       const std::vector<int>& paddings,
+                                       const std::vector<int>& dilations,
+                                       const std::vector<int>& strides,
+                                       const std::string& padding_algorithm,
+                                       int groups,
+                                       const std::string& data_format,
+                                       float momentum,
+                                       float epsilon,
+                                       bool fuse_prologue,
+                                       bool exhaustive_search,
+                                       int64_t accumulation_count,
+                                       MetaTensor* out,
+                                       MetaTensor* out_running_mean,
+                                       MetaTensor* out_running_var,
+                                       MetaTensor* saved_mean,
+                                       MetaTensor* saved_var,
+                                       MetaTensor* eq_scale,
+                                       MetaTensor* eq_bias) {
   auto in_dims = x.dims();
   auto filter_dims = w.dims();
   // do some checks
   PADDLE_ENFORCE_EQ(
       in_dims.size(),
       4,
       phi::errors::InvalidArgument(
-          "The input of Op(FusedScaleBiasReluConvBnstats) should be a 4-D "
+          "The input of Op(FusedScaleBiasReluConvBn) should be a 4-D "
           "Tensor. But "
           "received: input's dimension is %u, input's shape is [%s].",
           in_dims.size(),
@@ -1732,7 +1731,7 @@ void FusedScaleBiasReluConvBnstatsInferMeta(
       filter_dims.size(),
       phi::errors::InvalidArgument(
           "The input's dimension and filter's dimension of "
-          "Op(FusedScaleBiasReluConvBnstats) should be equal. But received: "
+          "Op(FusedScaleBiasReluConvBn) should be equal. But received: "
           "the input's"
           " shape is [%s], "
           "the input's dimension is %d; the filter's shape is [%s], "
@@ -1747,7 +1746,7 @@ void FusedScaleBiasReluConvBnstatsInferMeta(
       data_format,
       "NHWC",
       phi::errors::InvalidArgument(
-          "Operator(FusedScaleBiasReluConvBnstats) only supports data format "
+          "Operator(FusedScaleBiasReluConvBn) only supports data format "
           "of "
           "channel last (NHWC) now. But recieved: data_format = '%s'.",
           data_format));
@@ -1774,7 +1773,7 @@ void FusedScaleBiasReluConvBnstatsInferMeta(
       filter_dims[1] * groups,
       phi::errors::InvalidArgument(
           "The number of input's channels should be equal to filter's channels "
-          "* groups for Op(FusedScaleBiasReluConvBnstats). But received: the "
+          "* groups for Op(FusedScaleBiasReluConvBn). But received: the "
           "input's"
           " channels is %d, "
           "the input's shape is [%s]; the filter's channels is %d, the "
@@ -1821,6 +1820,32 @@
   eq_bias->set_dims(c_dims);
 }
 
+void FusedScaleBiasAddReluInferMeta(const MetaTensor& x1,
+                                    const MetaTensor& scale1,
+                                    const MetaTensor& bias1,
+                                    const MetaTensor& x2,
+                                    const MetaTensor& scale2,
+                                    const MetaTensor& bias2,
+                                    bool fuse_dual,
+                                    bool exhaustive_search,
+                                    MetaTensor* y) {
+  // check optional inputs
+  if (fuse_dual) {
+    bool has_scale2 = !!scale2;
+    bool has_bias2 = !!bias2;
+    PADDLE_ENFORCE(has_scale2 && has_bias2,
+                   phi::errors::InvalidArgument(
+                       "Argument scale2 and bias2 should be provided when "
+                       "fuse_dual is set, but got has_scale2=%d, has_bias2=%d, "
+                       "fuse_dual=%d.",
+                       has_scale2,
+                       has_bias2,
+                       fuse_dual));
+  }
+  // set output dims
+  y->set_dims(x1.dims());
Contributor: Strictly speaking, dtype and layout also need to be set here.

Contributor Author: This will be fixed in the next PR.
+}
+
 void SqueezeExcitationInferMeta(const MetaTensor& x,
                                 const MetaTensor& filter,
                                 const MetaTensor& filter_max,
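As the review thread above notes, the new InferMeta currently sets only the output dims. A minimal sketch of the dtype/layout propagation promised for the next PR; the helper name is hypothetical, and the calls mirror what the neighboring LayerNormActXPUInferMeta already does with set_layout:

#include "paddle/phi/core/meta_tensor.h"

namespace phi {

// Hypothetical follow-up (not part of this PR): also propagate dtype and
// layout from x1 when filling in the output MetaTensor.
void SetFusedAddReluOutputMeta(const MetaTensor& x1, MetaTensor* y) {
  y->set_dims(x1.dims());
  y->set_dtype(x1.dtype());    // element type follows x1
  y->set_layout(x1.layout());  // layout (e.g. NHWC) follows x1
}

}  // namespace phi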
63 changes: 36 additions & 27 deletions paddle/phi/infermeta/fusion.h
@@ -451,33 +451,42 @@ void LayerNormActXPUInferMeta(const MetaTensor& x,
                               float act_param,
                               MetaTensor* y);
 
-void FusedScaleBiasReluConvBnstatsInferMeta(
-    const MetaTensor& x,
-    const MetaTensor& w,
-    const MetaTensor& scale,
-    const MetaTensor& bias,
-    const MetaTensor& bn_scale,
-    const MetaTensor& bn_bias,
-    const MetaTensor& input_running_mean,
-    const MetaTensor& input_running_var,
-    const std::vector<int>& paddings,
-    const std::vector<int>& dilations,
-    const std::vector<int>& strides,
-    const std::string& padding_algorithm,
-    int groups,
-    const std::string& data_format,
-    float momentum,
-    float epsilon,
-    bool fuse_prologue,
-    bool exhaustive_search,
-    int64_t accumulation_count,
-    MetaTensor* out,
-    MetaTensor* out_running_mean,
-    MetaTensor* out_running_var,
-    MetaTensor* saved_mean,
-    MetaTensor* saved_var,
-    MetaTensor* eq_scale,
-    MetaTensor* eq_bias);
+void FusedScaleBiasReluConvBnInferMeta(const MetaTensor& x,
+                                       const MetaTensor& w,
+                                       const MetaTensor& scale,
+                                       const MetaTensor& bias,
+                                       const MetaTensor& bn_scale,
+                                       const MetaTensor& bn_bias,
+                                       const MetaTensor& input_running_mean,
+                                       const MetaTensor& input_running_var,
+                                       const std::vector<int>& paddings,
+                                       const std::vector<int>& dilations,
+                                       const std::vector<int>& strides,
+                                       const std::string& padding_algorithm,
+                                       int groups,
+                                       const std::string& data_format,
+                                       float momentum,
+                                       float epsilon,
+                                       bool fuse_prologue,
+                                       bool exhaustive_search,
+                                       int64_t accumulation_count,
+                                       MetaTensor* out,
+                                       MetaTensor* out_running_mean,
+                                       MetaTensor* out_running_var,
+                                       MetaTensor* saved_mean,
+                                       MetaTensor* saved_var,
+                                       MetaTensor* eq_scale,
+                                       MetaTensor* eq_bias);
+
+void FusedScaleBiasAddReluInferMeta(const MetaTensor& x1,
+                                    const MetaTensor& scale1,
+                                    const MetaTensor& bias1,
+                                    const MetaTensor& x2,
+                                    const MetaTensor& scale2,
+                                    const MetaTensor& bias2,
+                                    bool fuse_prologue,
+                                    bool exhaustive_search,
+                                    MetaTensor* y);
 
 void SqueezeExcitationInferMeta(const MetaTensor& x,
                                 const MetaTensor& filter,
3 changes: 2 additions & 1 deletion paddle/phi/kernels/CMakeLists.txt
@@ -108,7 +108,8 @@ endif()

 if(NOT WITH_CUDNN_FRONTEND)
   list(REMOVE_ITEM kernel_cu
-       "fusion/gpu/fused_scale_bias_relu_conv_bnstats_kernel.cu")
+       "fusion/gpu/fused_scale_bias_relu_conv_bn_kernel.cu"
+       "fusion/gpu/fused_scale_bias_add_relu_kernel.cu")
 endif()
 
 set(cc_search_pattern
3 changes: 3 additions & 0 deletions paddle/phi/kernels/autotune/cache.cc
@@ -52,6 +52,9 @@ std::string AlgorithmTypeString(int64_t algo_type) {
return "scale_bias_relu_conv_bnstats";
} else if (algo_type == static_cast<int64_t>(AlgorithmType::kBNFinalize)) {
return "bn_finalize";
} else if (algo_type ==
static_cast<int64_t>(AlgorithmType::kScaleBiasAddRelu)) {
return "scale_bias_add_relu";
}
#endif
return std::to_string(algo_type);
5 changes: 3 additions & 2 deletions paddle/phi/kernels/autotune/cache.h
@@ -57,7 +57,8 @@ enum class AlgorithmType {
   kConvBackwardFilterV8 = 12,
   kScaleBiasReluConvBNstats = 13,
   kBNFinalize = 14,
-  kAlgorithmCount = 15
+  kScaleBiasAddRelu = 15,
+  kAlgorithmCount = 16
 #endif
 };
 
@@ -181,7 +182,7 @@ class AutoTuneCache {
     }
 #ifdef PADDLE_WITH_CUDNN_FRONTEND
     } else if (algo_type >= AlgorithmType::kConvForwardV8 &&
-               algo_type <= AlgorithmType::kBNFinalize) {
+               algo_type < AlgorithmType::kAlgorithmCount) {
       int64_t key = static_cast<int64_t>(algo_type);
       if (cudnn_v8_auto_tune_map_.find(key) == cudnn_v8_auto_tune_map_.end()) {
         CudnnFrontendPlanCache cache;
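One consequence of the widened bound above: with algo_type < AlgorithmType::kAlgorithmCount as the upper limit instead of <= kBNFinalize, any cuDNN v8 algorithm type appended to the enum (such as kScaleBiasAddRelu = 15 in this PR) is routed into cudnn_v8_auto_tune_map_ without editing the check again. A small standalone illustration; the enum values are copied from the diff, except kConvForwardV8 = 10, which is an assumption inferred from kConvBackwardFilterV8 = 12:

#include <cassert>

// Mirrors the relevant slice of phi's AlgorithmType after this PR.
enum class AlgorithmType {
  kConvForwardV8 = 10,         // value assumed for illustration
  kConvBackwardFilterV8 = 12,  // from the diff above
  kScaleBiasReluConvBNstats = 13,
  kBNFinalize = 14,
  kScaleBiasAddRelu = 15,  // added by this PR
  kAlgorithmCount = 16     // bumped by this PR
};

int main() {
  auto algo = AlgorithmType::kScaleBiasAddRelu;
  // The old upper bound (<= kBNFinalize) would have excluded the new type:
  assert(!(algo <= AlgorithmType::kBNFinalize));
  // The new bound (< kAlgorithmCount) picks it up automatically:
  assert(algo >= AlgorithmType::kConvForwardV8 &&
         algo < AlgorithmType::kAlgorithmCount);
  return 0;
}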