diff --git a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc index b50b4f37caecd8..fd8b55a6b7deb9 100644 --- a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc @@ -18,6 +18,7 @@ #include #include #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/platform/enforce.h" @@ -225,3 +226,14 @@ REGISTER_PASS(conv_affine_channel_fuse_pass, paddle::framework::ir::ConvAffineChannelFusePass); REGISTER_PASS(conv_eltwiseadd_affine_channel_fuse_pass, paddle::framework::ir::ConvEltwiseAddAffineChannelFusePass); +REGISTER_PASS_CAPABILITY(conv_affine_channel_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("conv2d", 0) + .EQ("affine_channel", 0)); +REGISTER_PASS_CAPABILITY(conv_eltwiseadd_affine_channel_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("conv2d", 0) + .EQ("elementwise_add", 0) + .EQ("affine_channel", 0)); diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc index 9d3e0806ac79d8..fb787e08814429 100644 --- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc @@ -18,6 +18,7 @@ #include #include #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/platform/enforce.h" @@ -372,3 +373,14 @@ REGISTER_PASS(depthwise_conv_bn_fuse_pass, paddle::framework::ir::DepthwiseConvBNFusePass); REGISTER_PASS(depthwise_conv_eltwiseadd_bn_fuse_pass, paddle::framework::ir::DepthwiseConvEltwiseAddBNFusePass); +REGISTER_PASS_CAPABILITY(conv_bn_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("conv2d", 0) + .EQ("batch_norm", 0)); +REGISTER_PASS_CAPABILITY(conv_eltwiseadd_bn_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("conv2d", 0) + .EQ("elementwise_add", 0) + .EQ("batch_norm", 0)); diff --git a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc index bc560fef719e59..23f794c11c2392 100644 --- a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_version_registry.h" #define MAX_NUM_FC 10 @@ -388,3 +389,8 @@ void RepeatedFCReluFusePass::ApplyImpl(ir::Graph* graph) const { REGISTER_PASS(repeated_fc_relu_fuse_pass, paddle::framework::ir::RepeatedFCReluFusePass); +REGISTER_PASS_CAPABILITY(repeated_fc_relu_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("fc", 0) + .EQ("relu", 0)); diff --git a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc index 60bbee373a9ff4..74ba0093a17beb 100644 --- a/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc +++ b/paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc @@ -16,6 +16,7 @@ #include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/framework/ir/shuffle_channel_detect_pass.h" +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace framework { @@ -95,3 +96,8 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const { REGISTER_PASS(shuffle_channel_detect_pass, paddle::framework::ir::ShuffleChannelDetectPass); +REGISTER_PASS_CAPABILITY(shuffle_channel_detect_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("reshape2", 0) + .EQ("transpose2", 0));