From 7b8b8faf2333239fd36a8bb05613b319b3e6d12d Mon Sep 17 00:00:00 2001 From: "Tian Zheng (Engrg-Hardware 1)" Date: Fri, 27 Oct 2023 08:27:15 +0000 Subject: [PATCH 1/2] Rename op --- .../pir/dialect/op_generator/ops_api_gen.py | 4 +- paddle/phi/api/yaml/fused_ops.yaml | 6 +- paddle/phi/infermeta/fusion.cc | 61 +++++++++---------- paddle/phi/infermeta/fusion.h | 53 ++++++++-------- paddle/phi/kernels/CMakeLists.txt | 2 +- ...> fused_scale_bias_relu_conv_bn_kernel.cu} | 59 +++++++++--------- test/legacy_test/CMakeLists.txt | 2 +- ... test_fused_scale_bias_relu_conv_bn_op.py} | 12 ++-- test/white_list/op_accuracy_white_list.py | 2 +- tools/gpups_test.sh | 2 +- 10 files changed, 98 insertions(+), 105 deletions(-) rename paddle/phi/kernels/fusion/gpu/{fused_scale_bias_relu_conv_bnstats_kernel.cu => fused_scale_bias_relu_conv_bn_kernel.cu} (92%) rename test/legacy_test/{test_fused_scale_bias_relu_conv_bnstats_op.py => test_fused_scale_bias_relu_conv_bn_op.py} (95%) diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index c2076b25cd514..e265b7c44e03c 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -73,7 +73,7 @@ 'fused_embedding_eltwise_layernorm', 'fused_fc_elementwise_layernorm', 'fused_multi_transformer_xpu', - 'fused_scale_bias_relu_conv_bnstats', + 'fused_scale_bias_relu_conv_bn', 'fusion_transpose_flatten_concat', 'generate_sequence_xpu', 'layer_norm_act_xpu', @@ -103,7 +103,7 @@ 'embedding_grad_sparse', 'fused_batch_norm_act_', 'fused_bn_add_activation_', - 'fused_scale_bias_relu_conv_bnstats', + 'fused_scale_bias_relu_conv_bn', 'memcpy', 'print', 'recv_v2', diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index 45186294ce979..ba5e5bbaa07ab 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -206,14 +206,14 @@ backward: fused_rotary_position_embedding_grad support_dygraph_mode : true -- op : fused_scale_bias_relu_conv_bnstats +- op : fused_scale_bias_relu_conv_bn args : (Tensor x, Tensor w, Tensor scale, Tensor bias, Tensor bn_scale, Tensor bn_bias, Tensor input_running_mean, Tensor input_running_var, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, str data_format, float momentum, float epsilon, bool fuse_prologue, bool exhaustive_search, int64_t accumulation_count = 0) optional : scale, bias output : Tensor(out), Tensor(out_running_mean), Tensor(out_running_var), Tensor(saved_mean), Tensor(saved_var), Tensor(eq_scale), Tensor(eq_bias) infer_meta : - func : FusedScaleBiasReluConvBnstatsInferMeta + func : FusedScaleBiasReluConvBnInferMeta kernel : - func : fused_scale_bias_relu_conv_bnstats + func : fused_scale_bias_relu_conv_bn data_type : x - op : fusion_gru diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index e7062879573c5..06a457ccaba09 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -1687,33 +1687,32 @@ void LayerNormActXPUInferMeta(const MetaTensor& x, y->set_layout(x.layout()); } -void FusedScaleBiasReluConvBnstatsInferMeta( - const MetaTensor& x, - const MetaTensor& w, - const MetaTensor& scale, - const MetaTensor& bias, - const MetaTensor& bn_scale, - const MetaTensor& bn_bias, - const MetaTensor& input_running_mean, - const MetaTensor& input_running_var, - const std::vector& paddings, - const std::vector& dilations, - const std::vector& strides, - const std::string& padding_algorithm, - int groups, - const std::string& data_format, - float momentum, - float epsilon, - bool fuse_prologue, - bool exhaustive_search, - int64_t accumulation_count, - MetaTensor* out, - MetaTensor* out_running_mean, - MetaTensor* out_running_var, - MetaTensor* saved_mean, - MetaTensor* saved_var, - MetaTensor* eq_scale, - MetaTensor* eq_bias) { +void FusedScaleBiasReluConvBnInferMeta(const MetaTensor& x, + const MetaTensor& w, + const MetaTensor& scale, + const MetaTensor& bias, + const MetaTensor& bn_scale, + const MetaTensor& bn_bias, + const MetaTensor& input_running_mean, + const MetaTensor& input_running_var, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const std::string& padding_algorithm, + int groups, + const std::string& data_format, + float momentum, + float epsilon, + bool fuse_prologue, + bool exhaustive_search, + int64_t accumulation_count, + MetaTensor* out, + MetaTensor* out_running_mean, + MetaTensor* out_running_var, + MetaTensor* saved_mean, + MetaTensor* saved_var, + MetaTensor* eq_scale, + MetaTensor* eq_bias) { auto in_dims = x.dims(); auto filter_dims = w.dims(); // do some checks @@ -1721,7 +1720,7 @@ void FusedScaleBiasReluConvBnstatsInferMeta( in_dims.size(), 4, phi::errors::InvalidArgument( - "The input of Op(FusedScaleBiasReluConvBnstats) should be a 4-D " + "The input of Op(FusedScaleBiasReluConvBn) should be a 4-D " "Tensor. But " "received: input's dimension is %u, input's shape is [%s].", in_dims.size(), @@ -1732,7 +1731,7 @@ void FusedScaleBiasReluConvBnstatsInferMeta( filter_dims.size(), phi::errors::InvalidArgument( "The input's dimension and filter's dimension of " - "Op(FusedScaleBiasReluConvBnstats) should be equal. But received: " + "Op(FusedScaleBiasReluConvBn) should be equal. But received: " "the input's" " shape is [%s], " "the input's dimension is %d; the filter's shape is [%s], " @@ -1747,7 +1746,7 @@ void FusedScaleBiasReluConvBnstatsInferMeta( data_format, "NHWC", phi::errors::InvalidArgument( - "Operator(FusedScaleBiasReluConvBnstats) only supports data format " + "Operator(FusedScaleBiasReluConvBn) only supports data format " "of " "channel last (NHWC) now. But recieved: data_format = '%s'.", data_format)); @@ -1774,7 +1773,7 @@ void FusedScaleBiasReluConvBnstatsInferMeta( filter_dims[1] * groups, phi::errors::InvalidArgument( "The number of input's channels should be equal to filter's channels " - "* groups for Op(FusedScaleBiasReluConvBnstats). But received: the " + "* groups for Op(FusedScaleBiasReluConvBn). But received: the " "input's" " channels is %d, " "the input's shape is [%s]; the filter's channels is %d, the " diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index b6b9c64314ca8..5b3de8aea0d60 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -451,33 +451,32 @@ void LayerNormActXPUInferMeta(const MetaTensor& x, float act_param, MetaTensor* y); -void FusedScaleBiasReluConvBnstatsInferMeta( - const MetaTensor& x, - const MetaTensor& w, - const MetaTensor& scale, - const MetaTensor& bias, - const MetaTensor& bn_scale, - const MetaTensor& bn_bias, - const MetaTensor& input_running_mean, - const MetaTensor& input_running_var, - const std::vector& paddings, - const std::vector& dilations, - const std::vector& strides, - const std::string& padding_algorithm, - int groups, - const std::string& data_format, - float momentum, - float epsilon, - bool fuse_prologue, - bool exhaustive_search, - int64_t accumulation_count, - MetaTensor* out, - MetaTensor* out_running_mean, - MetaTensor* out_running_var, - MetaTensor* saved_mean, - MetaTensor* saved_var, - MetaTensor* eq_scale, - MetaTensor* eq_bias); +void FusedScaleBiasReluConvBnInferMeta(const MetaTensor& x, + const MetaTensor& w, + const MetaTensor& scale, + const MetaTensor& bias, + const MetaTensor& bn_scale, + const MetaTensor& bn_bias, + const MetaTensor& input_running_mean, + const MetaTensor& input_running_var, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const std::string& padding_algorithm, + int groups, + const std::string& data_format, + float momentum, + float epsilon, + bool fuse_prologue, + bool exhaustive_search, + int64_t accumulation_count, + MetaTensor* out, + MetaTensor* out_running_mean, + MetaTensor* out_running_var, + MetaTensor* saved_mean, + MetaTensor* saved_var, + MetaTensor* eq_scale, + MetaTensor* eq_bias); void SqueezeExcitationInferMeta(const MetaTensor& x, const MetaTensor& filter, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index cc2594c9a720a..617095fbff4d1 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -108,7 +108,7 @@ endif() if(NOT WITH_CUDNN_FRONTEND) list(REMOVE_ITEM kernel_cu - "fusion/gpu/fused_scale_bias_relu_conv_bnstats_kernel.cu") + "fusion/gpu/fused_scale_bias_relu_conv_bn_kernel.cu") endif() set(cc_search_pattern diff --git a/paddle/phi/kernels/fusion/gpu/fused_scale_bias_relu_conv_bnstats_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_scale_bias_relu_conv_bn_kernel.cu similarity index 92% rename from paddle/phi/kernels/fusion/gpu/fused_scale_bias_relu_conv_bnstats_kernel.cu rename to paddle/phi/kernels/fusion/gpu/fused_scale_bias_relu_conv_bn_kernel.cu index f891b94bf1eb7..ff2e85ed16ee8 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_scale_bias_relu_conv_bnstats_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_scale_bias_relu_conv_bn_kernel.cu @@ -472,34 +472,33 @@ void BNFinalizeImpl(const Context& dev_ctx, } template -void FusedScaleBiasReluConvBnstatsKernel( - const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& w, - const paddle::optional& scale, - const paddle::optional& bias, - const DenseTensor& bn_scale, - const DenseTensor& bn_bias, - const DenseTensor& input_running_mean, - const DenseTensor& input_running_var, - const std::vector& paddings, - const std::vector& dilations, - const std::vector& strides, - const std::string& padding_algorithm, - int groups, - const std::string& data_format, - float momentum, - float epsilon, - bool fuse_prologue, - bool exhaustive_search, - int64_t accumulation_count, - DenseTensor* out, - DenseTensor* out_running_mean, - DenseTensor* out_running_var, - DenseTensor* saved_mean, - DenseTensor* saved_var, - DenseTensor* eq_scale, - DenseTensor* eq_bias) { +void FusedScaleBiasReluConvBnKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& w, + const paddle::optional& scale, + const paddle::optional& bias, + const DenseTensor& bn_scale, + const DenseTensor& bn_bias, + const DenseTensor& input_running_mean, + const DenseTensor& input_running_var, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const std::string& padding_algorithm, + int groups, + const std::string& data_format, + float momentum, + float epsilon, + bool fuse_prologue, + bool exhaustive_search, + int64_t accumulation_count, + DenseTensor* out, + DenseTensor* out_running_mean, + DenseTensor* out_running_var, + DenseTensor* saved_mean, + DenseTensor* saved_var, + DenseTensor* eq_scale, + DenseTensor* eq_bias) { auto cudnn_version = phi::backends::gpu::DnnVersion(); PADDLE_ENFORCE_GE(cudnn_version, 8800, @@ -606,10 +605,10 @@ void FusedScaleBiasReluConvBnstatsKernel( } // namespace fusion } // namespace phi -PD_REGISTER_KERNEL(fused_scale_bias_relu_conv_bnstats, +PD_REGISTER_KERNEL(fused_scale_bias_relu_conv_bn, GPU, ALL_LAYOUT, - phi::fusion::FusedScaleBiasReluConvBnstatsKernel, + phi::fusion::FusedScaleBiasReluConvBnKernel, phi::dtype::float16) { kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index 03860aaa01a20..fb6b02ce37b37 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -504,7 +504,7 @@ if(NOT WITH_GPU endif() if(NOT WITH_CUDNN_FRONTEND) - list(REMOVE_ITEM TEST_OPS test_fused_scale_bias_relu_conv_bnstats_op) + list(REMOVE_ITEM TEST_OPS test_fused_scale_bias_relu_conv_bn_op) endif() # Some ops need to check results when gc is enabled diff --git a/test/legacy_test/test_fused_scale_bias_relu_conv_bnstats_op.py b/test/legacy_test/test_fused_scale_bias_relu_conv_bn_op.py similarity index 95% rename from test/legacy_test/test_fused_scale_bias_relu_conv_bnstats_op.py rename to test/legacy_test/test_fused_scale_bias_relu_conv_bn_op.py index 4beef728d9e11..4ed12fecc77f9 100644 --- a/test/legacy_test/test_fused_scale_bias_relu_conv_bnstats_op.py +++ b/test/legacy_test/test_fused_scale_bias_relu_conv_bn_op.py @@ -39,9 +39,9 @@ def skip_unit_test(): @skip_check_grad_ci(reason="no grap op") @unittest.skipIf(skip_unit_test(), skip_msg) -class TestFusedScaleBiasReluConvBnstatsOp(OpTest): +class TestFusedScaleBiasReluConvBnOp(OpTest): def setUp(self): - self.__class__.op_type = "fused_scale_bias_relu_conv_bnstats" + self.__class__.op_type = "fused_scale_bias_relu_conv_bn" self.dtype = np.float16 self.outputs = None self.padding_algorithm = "EXIPLICIT" @@ -219,17 +219,13 @@ def init_attr(self): self.exhaustive_search = False -class TestFusedScaleBiasReluConvBnstatsOpNoPrologue( - TestFusedScaleBiasReluConvBnstatsOp -): +class TestFusedScaleBiasReluConvBnOpNoPrologue(TestFusedScaleBiasReluConvBnOp): def init_attr(self): self.fuse_prologue = False self.exhaustive_search = False -class TestFusedScaleBiasReluConvBnstatsOpExhaustive( - TestFusedScaleBiasReluConvBnstatsOp -): +class TestFusedScaleBiasReluConvBnOpExhaustive(TestFusedScaleBiasReluConvBnOp): def init_attr(self): self.fuse_prologue = True self.exhaustive_search = True diff --git a/test/white_list/op_accuracy_white_list.py b/test/white_list/op_accuracy_white_list.py index d2520739339eb..ce0e7dde85fe8 100644 --- a/test/white_list/op_accuracy_white_list.py +++ b/test/white_list/op_accuracy_white_list.py @@ -87,7 +87,7 @@ NO_FP16_COMPARED_WITH_FP32_OP_LIST = [ 'fake_quantize_moving_average_abs_max', - 'fused_scale_bias_relu_conv_bnstats', + 'fused_scale_bias_relu_conv_bn', 'p_norm', ] diff --git a/tools/gpups_test.sh b/tools/gpups_test.sh index 822f0a11fec21..982e850deeb13 100644 --- a/tools/gpups_test.sh +++ b/tools/gpups_test.sh @@ -95,7 +95,7 @@ parallel_list="^init_phi_test$|\ ^test_fused_multi_transformer_int8_op$|\ ^test_fused_residual_dropout_bias$|\ ^test_fused_rotary_position_embedding$|\ -^test_fused_scale_bias_relu_conv_bnstats_op$|\ +^test_fused_scale_bias_relu_conv_bn_op$|\ ^test_fused_token_prune_op$|\ ^test_fused_transformer_encoder_layer$|\ ^test_fused_transformer_with_amp_decorator$|\ From 99c5d30fabc147ae4d64d1b9b123bdbcddecac72 Mon Sep 17 00:00:00 2001 From: "Tian Zheng (Engrg-Hardware 1)" Date: Sun, 29 Oct 2023 07:35:32 +0000 Subject: [PATCH 2/2] Add fused_scale_bias_add_relu --- .../pir/dialect/op_generator/ops_api_gen.py | 2 + paddle/phi/api/yaml/fused_ops.yaml | 10 + paddle/phi/infermeta/fusion.cc | 26 ++ paddle/phi/infermeta/fusion.h | 10 + paddle/phi/kernels/CMakeLists.txt | 3 +- paddle/phi/kernels/autotune/cache.cc | 3 + paddle/phi/kernels/autotune/cache.h | 5 +- .../gpu/fused_scale_bias_add_relu_kernel.cu | 244 ++++++++++++++++++ test/legacy_test/CMakeLists.txt | 1 + .../test_fused_scale_bias_add_relu_op.py | 130 ++++++++++ test/white_list/op_accuracy_white_list.py | 1 + tools/gpups_test.sh | 1 + 12 files changed, 433 insertions(+), 3 deletions(-) create mode 100644 paddle/phi/kernels/fusion/gpu/fused_scale_bias_add_relu_kernel.cu create mode 100644 test/legacy_test/test_fused_scale_bias_add_relu_op.py diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index e265b7c44e03c..642cf224c5c18 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -74,6 +74,7 @@ 'fused_fc_elementwise_layernorm', 'fused_multi_transformer_xpu', 'fused_scale_bias_relu_conv_bn', + 'fused_scale_bias_add_relu', 'fusion_transpose_flatten_concat', 'generate_sequence_xpu', 'layer_norm_act_xpu', @@ -104,6 +105,7 @@ 'fused_batch_norm_act_', 'fused_bn_add_activation_', 'fused_scale_bias_relu_conv_bn', + 'fused_scale_bias_add_relu', 'memcpy', 'print', 'recv_v2', diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index ba5e5bbaa07ab..0e764467ddc3c 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -206,6 +206,16 @@ backward: fused_rotary_position_embedding_grad support_dygraph_mode : true +- op : fused_scale_bias_add_relu + args : (Tensor x1, Tensor scale1, Tensor bias1, Tensor x2, Tensor scale2, Tensor bias2, bool fuse_dual, bool exhaustive_search) + optional : scale2, bias2 + output : Tensor(y) + infer_meta : + func : FusedScaleBiasAddReluInferMeta + kernel : + func : fused_scale_bias_add_relu + data_type : x1 + - op : fused_scale_bias_relu_conv_bn args : (Tensor x, Tensor w, Tensor scale, Tensor bias, Tensor bn_scale, Tensor bn_bias, Tensor input_running_mean, Tensor input_running_var, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, str data_format, float momentum, float epsilon, bool fuse_prologue, bool exhaustive_search, int64_t accumulation_count = 0) optional : scale, bias diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 06a457ccaba09..e6fa399e3b7c5 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -1820,6 +1820,32 @@ void FusedScaleBiasReluConvBnInferMeta(const MetaTensor& x, eq_bias->set_dims(c_dims); } +void FusedScaleBiasAddReluInferMeta(const MetaTensor& x1, + const MetaTensor& scale1, + const MetaTensor& bias1, + const MetaTensor& x2, + const MetaTensor& scale2, + const MetaTensor& bias2, + bool fuse_dual, + bool exhaustive_search, + MetaTensor* y) { + // check optional inputs + if (fuse_dual) { + bool has_scale2 = !!scale2; + bool has_bias2 = !!bias2; + PADDLE_ENFORCE(has_scale2 && has_bias2, + phi::errors::InvalidArgument( + "Argument scale2 and bias2 should be provided when " + "fuse_dual is set, but got has_scale2=%d, has_bias2=%d, " + "fuse_dual=%d.", + has_scale2, + has_bias2, + fuse_dual)); + } + // set output dims + y->set_dims(x1.dims()); +} + void SqueezeExcitationInferMeta(const MetaTensor& x, const MetaTensor& filter, const MetaTensor& filter_max, diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index 5b3de8aea0d60..7b1508d5a639c 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -478,6 +478,16 @@ void FusedScaleBiasReluConvBnInferMeta(const MetaTensor& x, MetaTensor* eq_scale, MetaTensor* eq_bias); +void FusedScaleBiasAddReluInferMeta(const MetaTensor& x1, + const MetaTensor& scale1, + const MetaTensor& bias1, + const MetaTensor& x2, + const MetaTensor& scale2, + const MetaTensor& bias2, + bool fuse_prologue, + bool exhaustive_search, + MetaTensor* y); + void SqueezeExcitationInferMeta(const MetaTensor& x, const MetaTensor& filter, const MetaTensor& filter_max, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 617095fbff4d1..5f5330df4dfac 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -108,7 +108,8 @@ endif() if(NOT WITH_CUDNN_FRONTEND) list(REMOVE_ITEM kernel_cu - "fusion/gpu/fused_scale_bias_relu_conv_bn_kernel.cu") + "fusion/gpu/fused_scale_bias_relu_conv_bn_kernel.cu" + "fusion/gpu/fused_scale_bias_add_relu_kernel.cu") endif() set(cc_search_pattern diff --git a/paddle/phi/kernels/autotune/cache.cc b/paddle/phi/kernels/autotune/cache.cc index ba48e2e00ce54..0f1c9171264d1 100644 --- a/paddle/phi/kernels/autotune/cache.cc +++ b/paddle/phi/kernels/autotune/cache.cc @@ -52,6 +52,9 @@ std::string AlgorithmTypeString(int64_t algo_type) { return "scale_bias_relu_conv_bnstats"; } else if (algo_type == static_cast(AlgorithmType::kBNFinalize)) { return "bn_finalize"; + } else if (algo_type == + static_cast(AlgorithmType::kScaleBiasAddRelu)) { + return "scale_bias_add_relu"; } #endif return std::to_string(algo_type); diff --git a/paddle/phi/kernels/autotune/cache.h b/paddle/phi/kernels/autotune/cache.h index 34b98e28f50c7..ff47bfffcc448 100644 --- a/paddle/phi/kernels/autotune/cache.h +++ b/paddle/phi/kernels/autotune/cache.h @@ -57,7 +57,8 @@ enum class AlgorithmType { kConvBackwardFilterV8 = 12, kScaleBiasReluConvBNstats = 13, kBNFinalize = 14, - kAlgorithmCount = 15 + kScaleBiasAddRelu = 15, + kAlgorithmCount = 16 #endif }; @@ -181,7 +182,7 @@ class AutoTuneCache { } #ifdef PADDLE_WITH_CUDNN_FRONTEND } else if (algo_type >= AlgorithmType::kConvForwardV8 && - algo_type <= AlgorithmType::kBNFinalize) { + algo_type < AlgorithmType::kAlgorithmCount) { int64_t key = static_cast(algo_type); if (cudnn_v8_auto_tune_map_.find(key) == cudnn_v8_auto_tune_map_.end()) { CudnnFrontendPlanCache cache; diff --git a/paddle/phi/kernels/fusion/gpu/fused_scale_bias_add_relu_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_scale_bias_add_relu_kernel.cu new file mode 100644 index 0000000000000..ff5edd689f7f3 --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/fused_scale_bias_add_relu_kernel.cu @@ -0,0 +1,244 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/phi/backends/gpu/cuda/cudnn_helper.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/core/flags.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/autotune/cache.h" +#include "paddle/phi/kernels/cpu/conv_util.h" +#include "paddle/phi/kernels/funcs/batch_norm_utils.h" +#include "paddle/phi/kernels/gpudnn/conv_cudnn_frontend.h" + +PHI_DECLARE_bool(cudnn_deterministic); +PHI_DECLARE_bool(cudnn_exhaustive_search); + +namespace phi { +namespace fusion { + +using helper = phi::CudnnFrontendConvHelper; + +template +using CudnnDataType = phi::backends::gpu::CudnnDataType; + +template +void FusedScaleBiasAddReluKernel(const Context& dev_ctx, + const DenseTensor& x1, + const DenseTensor& scale1, + const DenseTensor& bias1, + const DenseTensor& x2, + const paddle::optional& scale2, + const paddle::optional& bias2, + bool fuse_dual, + bool exhaustive_search, + DenseTensor* y) { + PADDLE_ENFORCE_GE(dev_ctx.GetComputeCapability(), + 80, + phi::errors::PreconditionNotMet( + "This op only supports Ampere and later devices, " + "but got compute capability: %d.", + dev_ctx.GetComputeCapability())); + auto& plan_cache = phi::autotune::AutoTuneCache::Instance().GetConvV8( + phi::autotune::AlgorithmType::kScaleBiasAddRelu); + + // exhaustive search + exhaustive_search = exhaustive_search || FLAGS_cudnn_exhaustive_search; + bool deterministic = FLAGS_cudnn_deterministic; + PADDLE_ENFORCE_EQ(exhaustive_search && deterministic, + false, + phi::errors::InvalidArgument( + "Cann't set exhaustive_search True and " + "FLAGS_cudnn_deterministic True at same time.")); + + // alloc output variables + dev_ctx.template Alloc(y); + + // get handles + auto handle = dev_ctx.cudnn_handle(); + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); + // create tensor descriptors + cudnnTensorFormat_t layout_format = CUDNN_TENSOR_NHWC; + auto tensor_format = phi::backends::gpu::ToCudnnDataType(x1.dtype()); + auto tensor_format_math = CUDNN_DATA_FLOAT; + auto compute_dtype = CUDNN_DATA_FLOAT; + + auto dim_x = + phi::backends::gpu::TransformDimOrder(phi::vectorize(x1.dims())); + std::vector dim_c(dim_x.size(), 1); + dim_c[1] = dim_x[1]; // [1, C, 1, 1] + + std::vector data_ptrs; + std::vector uids; + int64_t uid = 100; + + // inputs + auto x1_desc = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format); + data_ptrs.push_back(const_cast(x1.data())); + uids.push_back(uid); + + auto x2_desc = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format); + data_ptrs.push_back(const_cast(x2.data())); + uids.push_back(uid); + + auto scale1_desc = helper::GetGeneralTensorDescriptor( + dim_c, layout_format, ++uid, 16, tensor_format); + data_ptrs.push_back(const_cast(scale1.data())); + uids.push_back(uid); + + auto bias1_desc = helper::GetGeneralTensorDescriptor( + dim_c, layout_format, ++uid, 16, tensor_format); + data_ptrs.push_back(const_cast(bias1.data())); + uids.push_back(uid); + + // dispensable inputs + auto scale2_desc = helper::GetGeneralTensorDescriptor( + dim_c, layout_format, ++uid, 16, tensor_format); + if (fuse_dual) { + data_ptrs.push_back(const_cast(scale2->data())); + uids.push_back(uid); + } + + auto bias2_desc = helper::GetGeneralTensorDescriptor( + dim_c, layout_format, ++uid, 16, tensor_format); + if (fuse_dual) { + data_ptrs.push_back(const_cast(bias2->data())); + uids.push_back(uid); + } + + // outputs + auto y_desc = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format); + data_ptrs.push_back(y->data()); + uids.push_back(uid); + + // virtual outputs + auto after_scale1 = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + + auto after_bias1 = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + + auto after_scale2 = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + + auto after_bias2 = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + + auto after_add = helper::GetGeneralTensorDescriptor( + dim_x, layout_format, ++uid, 16, tensor_format_math, true); + + // build ops + auto scale1_op = helper::MakePointwiseOp( + CUDNN_POINTWISE_MUL, compute_dtype, x1_desc, scale1_desc, after_scale1); + + auto bias1_op = helper::MakePointwiseOp(CUDNN_POINTWISE_ADD, + compute_dtype, + after_scale1, + bias1_desc, + after_bias1); + + auto scale2_op = helper::MakePointwiseOp( + CUDNN_POINTWISE_MUL, compute_dtype, x2_desc, scale2_desc, after_scale2); + + auto bias2_op = helper::MakePointwiseOp(CUDNN_POINTWISE_ADD, + compute_dtype, + after_scale2, + bias2_desc, + after_bias2); + + cudnn_frontend::Tensor* tensor_to_add = fuse_dual ? &after_bias2 : &x2_desc; + + auto add_op = helper::MakePointwiseOp(CUDNN_POINTWISE_ADD, + compute_dtype, + after_bias1, + *tensor_to_add, + after_add); + + auto relu_desc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_RELU_FWD) + .setComputeType(compute_dtype) + .build(); + + auto relu_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(after_add) + .setyDesc(y_desc) + .setpwDesc(relu_desc) + .build(); + + // build op graph + std::vector ops; + if (fuse_dual) { + ops = std::vector( + {&scale1_op, &bias1_op, &scale2_op, &bias2_op, &add_op, &relu_op}); + } else { + ops = std::vector( + {&scale1_op, &bias1_op, &add_op, &relu_op}); + } + + auto op_graph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle) + .setOperationGraph(ops.size(), ops.data()) + .build(); + VLOG(6) << op_graph.describe(); + + cudnn_frontend::feature_vector_t feature_vector; + phi::autotune::BuildFeatureVector(&feature_vector, dim_x, fuse_dual); + + if (plan_cache.FindPlan(feature_vector, handle)) { + const cudnn_frontend::ExecutionPlan* cached_plan = nullptr; + int64_t workspace_size = 0; + plan_cache.GetPlanAndWorkspaceSize( + feature_vector, &cached_plan, &workspace_size, handle); + helper::ExecutePlan(handle, + &workspace_handle, + &data_ptrs, + &uids, + cached_plan->get_raw_desc(), + workspace_size); + return; + } + + auto plans = helper::FindExecutionPlans(&op_graph, + exhaustive_search, + deterministic, + &data_ptrs, + &uids, + handle, + &workspace_handle); + + helper::ExecutePlansAndCache(handle, + &workspace_handle, + &data_ptrs, + &uids, + &plans, + exhaustive_search, + feature_vector, + &plan_cache); +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fused_scale_bias_add_relu, + GPU, + ALL_LAYOUT, + phi::fusion::FusedScaleBiasAddReluKernel, + phi::dtype::float16) {} diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index fb6b02ce37b37..6e73892602043 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -505,6 +505,7 @@ endif() if(NOT WITH_CUDNN_FRONTEND) list(REMOVE_ITEM TEST_OPS test_fused_scale_bias_relu_conv_bn_op) + list(REMOVE_ITEM TEST_OPS test_fused_scale_bias_add_relu_op) endif() # Some ops need to check results when gc is enabled diff --git a/test/legacy_test/test_fused_scale_bias_add_relu_op.py b/test/legacy_test/test_fused_scale_bias_add_relu_op.py new file mode 100644 index 0000000000000..44952e1ea23d1 --- /dev/null +++ b/test/legacy_test/test_fused_scale_bias_add_relu_op.py @@ -0,0 +1,130 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np +from op_test import OpTest, skip_check_grad_ci + +import paddle +from paddle.base import core + + +def skip_unit_test(): + return ( + not paddle.is_compiled_with_cuda() + or paddle.device.cuda.get_device_capability()[0] < 8 + ) + + +skip_msg = "only support with cuda and Ampere or later devices" + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(skip_unit_test(), skip_msg) +class TestFusedScaleBiasAddReluOp(OpTest): + def setUp(self): + self.op_type = "fused_scale_bias_add_relu" + self.dtype = np.float16 + self.outputs = None + + self.init_test_case() + self.init_attr() + + self.attrs = { + 'fuse_dual': self.fuse_dual, + 'exhaustive_search': self.exhaustive_search, + } + + c_dim = self.input_size[-1] + x1_input = np.random.random(self.input_size).astype(self.dtype) - 0.5 + x2_input = np.random.random(self.input_size).astype(self.dtype) - 0.5 + scale1_input = np.random.random(c_dim).astype(self.dtype) - 0.5 + scale2_input = np.random.random(c_dim).astype(self.dtype) - 0.5 + bias1_input = np.random.random(c_dim).astype(self.dtype) - 0.5 + bias2_input = np.random.random(c_dim).astype(self.dtype) - 0.5 + + # calculate reference output + reshaped_scale1_input = scale1_input.reshape(1, 1, 1, c_dim) + reshaped_scale2_input = scale2_input.reshape(1, 1, 1, c_dim) + reshaped_bias1_input = bias1_input.reshape(1, 1, 1, c_dim) + reshaped_bias2_input = bias2_input.reshape(1, 1, 1, c_dim) + + after_bias1 = x1_input * reshaped_scale1_input + reshaped_bias1_input + after_bias2 = x2_input * reshaped_scale2_input + reshaped_bias2_input + + if self.fuse_dual: + after_add = after_bias1 + after_bias2 + else: + after_add = after_bias1 + x2_input + + y_output = np.maximum(after_add, 0).astype(self.dtype) + + if self.fuse_dual: + self.inputs = { + 'x1': x1_input, + 'scale1': scale1_input, + 'bias1': bias1_input, + 'x2': x2_input, + 'scale2': scale2_input, + 'bias2': bias2_input, + } + else: + self.inputs = { + 'x1': x1_input, + 'scale1': scale1_input, + 'bias1': bias1_input, + 'x2': x2_input, + } + + self.outputs = { + 'y': y_output, + } + + def has_cuda(self): + return core.is_compiled_with_cuda() + + def test_check_output(self): + if self.has_cuda(): + place = core.CUDAPlace(0) + self.check_output_with_place(place, check_dygraph=False, atol=2e-2) + + def init_test_case(self): + self.input_size = [2, 8, 8, 16] + + def init_attr(self): + self.fuse_dual = False + self.exhaustive_search = False + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(skip_unit_test(), skip_msg) +class TestFusedScaleBiasAddReluOpDual(TestFusedScaleBiasAddReluOp): + def init_attr(self): + self.fuse_dual = True + self.exhaustive_search = False + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(skip_unit_test(), skip_msg) +class TestFusedScaleBiasAddReluOpExhaustive(TestFusedScaleBiasAddReluOp): + def init_attr(self): + self.fuse_dual = False + self.exhaustive_search = True + + +if __name__ == '__main__': + np.random.seed(0) + unittest.main() diff --git a/test/white_list/op_accuracy_white_list.py b/test/white_list/op_accuracy_white_list.py index ce0e7dde85fe8..5ad871e071ba4 100644 --- a/test/white_list/op_accuracy_white_list.py +++ b/test/white_list/op_accuracy_white_list.py @@ -88,6 +88,7 @@ NO_FP16_COMPARED_WITH_FP32_OP_LIST = [ 'fake_quantize_moving_average_abs_max', 'fused_scale_bias_relu_conv_bn', + 'fused_scale_bias_add_relu', 'p_norm', ] diff --git a/tools/gpups_test.sh b/tools/gpups_test.sh index 982e850deeb13..6221a4b4f90e1 100644 --- a/tools/gpups_test.sh +++ b/tools/gpups_test.sh @@ -96,6 +96,7 @@ parallel_list="^init_phi_test$|\ ^test_fused_residual_dropout_bias$|\ ^test_fused_rotary_position_embedding$|\ ^test_fused_scale_bias_relu_conv_bn_op$|\ +^test_fused_scale_bias_add_relu_op$|\ ^test_fused_token_prune_op$|\ ^test_fused_transformer_encoder_layer$|\ ^test_fused_transformer_with_amp_decorator$|\