[2/4] CUDNNv8 ResNet Fusion: Add fused_scale_bias_add_relu OP (Paddle#58504)

* Rename op

* Add fused_scale_bias_add_relu
Tom-Zheng authored Nov 13, 2023
1 parent 5511ae0 commit bcd2c52
Showing 14 changed files with 530 additions and 107 deletions.
6 changes: 4 additions & 2 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -73,7 +73,8 @@
'fused_embedding_eltwise_layernorm',
'fused_fc_elementwise_layernorm',
'fused_multi_transformer_xpu',
'fused_scale_bias_relu_conv_bnstats',
'fused_scale_bias_relu_conv_bn',
'fused_scale_bias_add_relu',
'fusion_transpose_flatten_concat',
'generate_sequence_xpu',
'layer_norm_act_xpu',
@@ -106,7 +107,8 @@
'embedding_grad_sparse',
'fused_batch_norm_act_',
'fused_bn_add_activation_',
'fused_scale_bias_relu_conv_bnstats',
'fused_scale_bias_relu_conv_bn',
'fused_scale_bias_add_relu',
'memcpy',
'print',
'recv_v2',
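Note: the two lists above tell the PIR ops-API generator which ops get Python bindings. A minimal usage sketch follows; the binding name and argument order are assumed to mirror the yaml signature in fused_ops.yaml below, and a GPU build with the cuDNN frontend is presumably required. Treat this as an illustration, not as API taken from this commit.

# Hypothetical call into the generated binding (argument order assumed
# from the yaml signature below; WITH_CUDNN_FRONTEND build assumed).
import paddle

x1 = paddle.randn([8, 56, 56, 32]).astype('float16')  # NHWC layout (assumed)
x2 = paddle.randn([8, 56, 56, 32]).astype('float16')
scale1 = paddle.ones([32]).astype('float16')
bias1 = paddle.zeros([32]).astype('float16')

# fuse_dual=False: the second branch is added unscaled, so scale2/bias2
# are passed as None (they are declared optional in the yaml).
y = paddle._C_ops.fused_scale_bias_add_relu(
    x1, scale1, bias1, x2, None, None,
    False,   # fuse_dual
    False)   # exhaustive_search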
16 changes: 13 additions & 3 deletions paddle/phi/api/yaml/fused_ops.yaml
@@ -206,14 +206,24 @@
backward: fused_rotary_position_embedding_grad
support_dygraph_mode : true

- op : fused_scale_bias_relu_conv_bnstats
- op : fused_scale_bias_add_relu
args : (Tensor x1, Tensor scale1, Tensor bias1, Tensor x2, Tensor scale2, Tensor bias2, bool fuse_dual, bool exhaustive_search)
optional : scale2, bias2
output : Tensor(y)
infer_meta :
func : FusedScaleBiasAddReluInferMeta
kernel :
func : fused_scale_bias_add_relu
data_type : x1

- op : fused_scale_bias_relu_conv_bn
args : (Tensor x, Tensor w, Tensor scale, Tensor bias, Tensor bn_scale, Tensor bn_bias, Tensor input_running_mean, Tensor input_running_var, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, str data_format, float momentum, float epsilon, bool fuse_prologue, bool exhaustive_search, int64_t accumulation_count = 0)
optional : scale, bias
output : Tensor(out), Tensor(out_running_mean), Tensor(out_running_var), Tensor(saved_mean), Tensor(saved_var), Tensor(eq_scale), Tensor(eq_bias)
infer_meta :
func : FusedScaleBiasReluConvBnstatsInferMeta
func : FusedScaleBiasReluConvBnInferMeta
kernel :
func : fused_scale_bias_relu_conv_bnstats
func : fused_scale_bias_relu_conv_bn
data_type : x

- op : fusion_gru
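Note: the yaml entry above defines only the interface. For orientation, here is a NumPy sketch of the semantics the signature suggests, inferred from the op name, the optional scale2/bias2 pair, and the cuDNN v8 scale-bias-add-relu pattern this PR series targets. The kernel source, not this sketch, is authoritative.

# Assumed reference semantics (not taken from the kernel):
#   y = relu(x1 * scale1 + bias1 + x2)                   when fuse_dual is off
#   y = relu(x1 * scale1 + bias1 + x2 * scale2 + bias2)  when fuse_dual is on
import numpy as np

def fused_scale_bias_add_relu_ref(x1, scale1, bias1, x2,
                                  scale2=None, bias2=None, fuse_dual=False):
    branch1 = x1 * scale1 + bias1                      # per-channel scale/bias
    branch2 = x2 * scale2 + bias2 if fuse_dual else x2
    return np.maximum(branch1 + branch2, 0.0)          # fused ReLU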
87 changes: 56 additions & 31 deletions paddle/phi/infermeta/fusion.cc
@@ -1687,41 +1687,40 @@ void LayerNormActXPUInferMeta(const MetaTensor& x,
y->set_layout(x.layout());
}

void FusedScaleBiasReluConvBnstatsInferMeta(
const MetaTensor& x,
const MetaTensor& w,
const MetaTensor& scale,
const MetaTensor& bias,
const MetaTensor& bn_scale,
const MetaTensor& bn_bias,
const MetaTensor& input_running_mean,
const MetaTensor& input_running_var,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::string& padding_algorithm,
int groups,
const std::string& data_format,
float momentum,
float epsilon,
bool fuse_prologue,
bool exhaustive_search,
int64_t accumulation_count,
MetaTensor* out,
MetaTensor* out_running_mean,
MetaTensor* out_running_var,
MetaTensor* saved_mean,
MetaTensor* saved_var,
MetaTensor* eq_scale,
MetaTensor* eq_bias) {
void FusedScaleBiasReluConvBnInferMeta(const MetaTensor& x,
const MetaTensor& w,
const MetaTensor& scale,
const MetaTensor& bias,
const MetaTensor& bn_scale,
const MetaTensor& bn_bias,
const MetaTensor& input_running_mean,
const MetaTensor& input_running_var,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::string& padding_algorithm,
int groups,
const std::string& data_format,
float momentum,
float epsilon,
bool fuse_prologue,
bool exhaustive_search,
int64_t accumulation_count,
MetaTensor* out,
MetaTensor* out_running_mean,
MetaTensor* out_running_var,
MetaTensor* saved_mean,
MetaTensor* saved_var,
MetaTensor* eq_scale,
MetaTensor* eq_bias) {
auto in_dims = x.dims();
auto filter_dims = w.dims();
// do some checks
PADDLE_ENFORCE_EQ(
in_dims.size(),
4,
phi::errors::InvalidArgument(
"The input of Op(FusedScaleBiasReluConvBnstats) should be a 4-D "
"The input of Op(FusedScaleBiasReluConvBn) should be a 4-D "
"Tensor. But "
"received: input's dimension is %u, input's shape is [%s].",
in_dims.size(),
@@ -1732,7 +1731,7 @@ void FusedScaleBiasReluConvBnstatsInferMeta(
filter_dims.size(),
phi::errors::InvalidArgument(
"The input's dimension and filter's dimension of "
"Op(FusedScaleBiasReluConvBnstats) should be equal. But received: "
"Op(FusedScaleBiasReluConvBn) should be equal. But received: "
"the input's"
" shape is [%s], "
"the input's dimension is %d; the filter's shape is [%s], "
@@ -1747,7 +1746,7 @@ void FusedScaleBiasReluConvBnstatsInferMeta(
data_format,
"NHWC",
phi::errors::InvalidArgument(
"Operator(FusedScaleBiasReluConvBnstats) only supports data format "
"Operator(FusedScaleBiasReluConvBn) only supports data format "
"of "
"channel last (NHWC) now. But received: data_format = '%s'.",
data_format));
@@ -1774,7 +1773,7 @@ void FusedScaleBiasReluConvBnstatsInferMeta(
filter_dims[1] * groups,
phi::errors::InvalidArgument(
"The number of input's channels should be equal to filter's channels "
"* groups for Op(FusedScaleBiasReluConvBnstats). But received: the "
"* groups for Op(FusedScaleBiasReluConvBn). But received: the "
"input's"
" channels is %d, "
"the input's shape is [%s]; the filter's channels is %d, the "
@@ -1821,6 +1820,32 @@ void FusedScaleBiasReluConvBnstatsInferMeta(
eq_bias->set_dims(c_dims);
}

void FusedScaleBiasAddReluInferMeta(const MetaTensor& x1,
const MetaTensor& scale1,
const MetaTensor& bias1,
const MetaTensor& x2,
const MetaTensor& scale2,
const MetaTensor& bias2,
bool fuse_dual,
bool exhaustive_search,
MetaTensor* y) {
// check optional inputs
if (fuse_dual) {
bool has_scale2 = !!scale2;
bool has_bias2 = !!bias2;
PADDLE_ENFORCE(has_scale2 && has_bias2,
phi::errors::InvalidArgument(
"Argument scale2 and bias2 should be provided when "
"fuse_dual is set, but got has_scale2=%d, has_bias2=%d, "
"fuse_dual=%d.",
has_scale2,
has_bias2,
fuse_dual));
}
// set output dims
y->set_dims(x1.dims());
}

void SqueezeExcitationInferMeta(const MetaTensor& x,
const MetaTensor& filter,
const MetaTensor& filter_max,
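Note: shape inference for the new op is intentionally thin: y simply inherits x1's dims, and the only validation is that scale2 and bias2 accompany fuse_dual. A small Python model of that contract, for illustration only:

# Python model of FusedScaleBiasAddReluInferMeta's behavior (illustrative).
def infer_fused_scale_bias_add_relu(x1_dims, scale2=None, bias2=None,
                                    fuse_dual=False):
    if fuse_dual and (scale2 is None or bias2 is None):
        raise ValueError("scale2 and bias2 should be provided when "
                         "fuse_dual is set")
    return tuple(x1_dims)  # mirrors y->set_dims(x1.dims())

print(infer_fused_scale_bias_add_relu([8, 56, 56, 32]))  # (8, 56, 56, 32)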
63 changes: 36 additions & 27 deletions paddle/phi/infermeta/fusion.h
@@ -451,33 +451,42 @@ void LayerNormActXPUInferMeta(const MetaTensor& x,
float act_param,
MetaTensor* y);

void FusedScaleBiasReluConvBnstatsInferMeta(
const MetaTensor& x,
const MetaTensor& w,
const MetaTensor& scale,
const MetaTensor& bias,
const MetaTensor& bn_scale,
const MetaTensor& bn_bias,
const MetaTensor& input_running_mean,
const MetaTensor& input_running_var,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::string& padding_algorithm,
int groups,
const std::string& data_format,
float momentum,
float epsilon,
bool fuse_prologue,
bool exhaustive_search,
int64_t accumulation_count,
MetaTensor* out,
MetaTensor* out_running_mean,
MetaTensor* out_running_var,
MetaTensor* saved_mean,
MetaTensor* saved_var,
MetaTensor* eq_scale,
MetaTensor* eq_bias);
void FusedScaleBiasReluConvBnInferMeta(const MetaTensor& x,
const MetaTensor& w,
const MetaTensor& scale,
const MetaTensor& bias,
const MetaTensor& bn_scale,
const MetaTensor& bn_bias,
const MetaTensor& input_running_mean,
const MetaTensor& input_running_var,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::string& padding_algorithm,
int groups,
const std::string& data_format,
float momentum,
float epsilon,
bool fuse_prologue,
bool exhaustive_search,
int64_t accumulation_count,
MetaTensor* out,
MetaTensor* out_running_mean,
MetaTensor* out_running_var,
MetaTensor* saved_mean,
MetaTensor* saved_var,
MetaTensor* eq_scale,
MetaTensor* eq_bias);

void FusedScaleBiasAddReluInferMeta(const MetaTensor& x1,
const MetaTensor& scale1,
const MetaTensor& bias1,
const MetaTensor& x2,
const MetaTensor& scale2,
const MetaTensor& bias2,
bool fuse_dual,
bool exhaustive_search,
MetaTensor* y);

void SqueezeExcitationInferMeta(const MetaTensor& x,
const MetaTensor& filter,
3 changes: 2 additions & 1 deletion paddle/phi/kernels/CMakeLists.txt
@@ -108,7 +108,8 @@ endif()

if(NOT WITH_CUDNN_FRONTEND)
list(REMOVE_ITEM kernel_cu
"fusion/gpu/fused_scale_bias_relu_conv_bnstats_kernel.cu")
"fusion/gpu/fused_scale_bias_relu_conv_bn_kernel.cu"
"fusion/gpu/fused_scale_bias_add_relu_kernel.cu")
endif()

set(cc_search_pattern
3 changes: 3 additions & 0 deletions paddle/phi/kernels/autotune/cache.cc
@@ -52,6 +52,9 @@ std::string AlgorithmTypeString(int64_t algo_type) {
return "scale_bias_relu_conv_bnstats";
} else if (algo_type == static_cast<int64_t>(AlgorithmType::kBNFinalize)) {
return "bn_finalize";
} else if (algo_type ==
static_cast<int64_t>(AlgorithmType::kScaleBiasAddRelu)) {
return "scale_bias_add_relu";
}
#endif
return std::to_string(algo_type);
5 changes: 3 additions & 2 deletions paddle/phi/kernels/autotune/cache.h
@@ -57,7 +57,8 @@ enum class AlgorithmType {
kConvBackwardFilterV8 = 12,
kScaleBiasReluConvBNstats = 13,
kBNFinalize = 14,
kAlgorithmCount = 15
kScaleBiasAddRelu = 15,
kAlgorithmCount = 16
#endif
};

@@ -181,7 +182,7 @@ class AutoTuneCache {
}
#ifdef PADDLE_WITH_CUDNN_FRONTEND
} else if (algo_type >= AlgorithmType::kConvForwardV8 &&
algo_type <= AlgorithmType::kBNFinalize) {
algo_type < AlgorithmType::kAlgorithmCount) {
int64_t key = static_cast<int64_t>(algo_type);
if (cudnn_v8_auto_tune_map_.find(key) == cudnn_v8_auto_tune_map_.end()) {
CudnnFrontendPlanCache cache;
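Note: the comparison change above is the subtle part of this hunk. The cache used to name kBNFinalize as the last cuDNN v8 algorithm type, which would have silently excluded the new kScaleBiasAddRelu entry; bounding by kAlgorithmCount makes future additions self-registering. A small Python model of the invariant follows; the values of the first two members are assumed, since only entries 12 through 16 appear in the hunk.

# Python model of the AlgorithmType range check (illustrative).
from enum import IntEnum

class AlgorithmType(IntEnum):
    kConvForwardV8 = 10         # assumed; not shown in the hunk
    kConvBackwardDataV8 = 11    # assumed; not shown in the hunk
    kConvBackwardFilterV8 = 12
    kScaleBiasReluConvBNstats = 13
    kBNFinalize = 14
    kScaleBiasAddRelu = 15
    kAlgorithmCount = 16

def uses_cudnn_v8_cache(algo: AlgorithmType) -> bool:
    # Old check: algo <= kBNFinalize -> misses kScaleBiasAddRelu.
    # New check: algo < kAlgorithmCount -> covers all current and future
    # cuDNN v8 algorithm types without further edits.
    return AlgorithmType.kConvForwardV8 <= algo < AlgorithmType.kAlgorithmCount

assert uses_cudnn_v8_cache(AlgorithmType.kScaleBiasAddRelu)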