Skip to content

Commit

Permalink
add compatibility registry
Browse files Browse the repository at this point in the history
  • Loading branch information
cryoco committed Sep 16, 2020
1 parent 654b929 commit 3951be4
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 0 deletions.
12 changes: 12 additions & 0 deletions paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/enforce.h"

Expand Down Expand Up @@ -225,3 +226,14 @@ REGISTER_PASS(conv_affine_channel_fuse_pass,
paddle::framework::ir::ConvAffineChannelFusePass);
REGISTER_PASS(conv_eltwiseadd_affine_channel_fuse_pass,
paddle::framework::ir::ConvEltwiseAddAffineChannelFusePass);
REGISTER_PASS_CAPABILITY(conv_affine_channel_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.EQ("affine_channel", 0));
REGISTER_PASS_CAPABILITY(conv_eltwiseadd_affine_channel_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.EQ("elementwise_add", 0)
.EQ("affine_channel", 0));
12 changes: 12 additions & 0 deletions paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/enforce.h"

Expand Down Expand Up @@ -372,3 +373,14 @@ REGISTER_PASS(depthwise_conv_bn_fuse_pass,
paddle::framework::ir::DepthwiseConvBNFusePass);
REGISTER_PASS(depthwise_conv_eltwiseadd_bn_fuse_pass,
paddle::framework::ir::DepthwiseConvEltwiseAddBNFusePass);
REGISTER_PASS_CAPABILITY(conv_bn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.EQ("batch_norm", 0));
REGISTER_PASS_CAPABILITY(conv_eltwiseadd_bn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.EQ("elementwise_add", 0)
.EQ("batch_norm", 0));
6 changes: 6 additions & 0 deletions paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License. */
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"

#define MAX_NUM_FC 10

Expand Down Expand Up @@ -388,3 +389,8 @@ void RepeatedFCReluFusePass::ApplyImpl(ir::Graph* graph) const {

REGISTER_PASS(repeated_fc_relu_fuse_pass,
paddle::framework::ir::RepeatedFCReluFusePass);
REGISTER_PASS_CAPABILITY(repeated_fc_relu_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("fc", 0)
.EQ("relu", 0));
6 changes: 6 additions & 0 deletions paddle/fluid/framework/ir/shuffle_channel_detect_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/shuffle_channel_detect_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -95,3 +96,8 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {

REGISTER_PASS(shuffle_channel_detect_pass,
paddle::framework::ir::ShuffleChannelDetectPass);
REGISTER_PASS_CAPABILITY(shuffle_channel_detect_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("reshape2", 0)
.EQ("transpose2", 0));

0 comments on commit 3951be4

Please sign in to comment.