update username repo #7

Merged 56 commits on Nov 18, 2020. Changes below are shown from all commits.

Commits
93c3977
open a part of GPU unittest for windows (#28378)
zhwesky2010 Nov 16, 2020
a24d186
fix nccl init failed in parallel dygraph mode (#28497)
danleifeng Nov 16, 2020
1de3cdd
Fix summary api for rnn gru lstm (#28566)
LielinJiang Nov 16, 2020
1c3eef4
Fix vgg error when num_classes is given (#28557)
LielinJiang Nov 16, 2020
90805e2
Register op_version for new attribute use_addto (#28463)
zhiqiu Nov 16, 2020
f962bd3
Fix cudnn workspace limit in cudnn-8 (#28611)
zhiqiu Nov 16, 2020
8b97bb2
Update cmake for arm ft and fix a bug for Predictor dtor. (#28586)
jiweibo Nov 16, 2020
f7dd889
Support squeezed label as input in paddle.metric.Accuracy (#28535)
qingqing01 Nov 16, 2020
c4d22c8
modified timeout value for some ut (#28616)
XieYunshen Nov 16, 2020
a3bc3bc
Fix scaled_params append error in AdamW. (#28633)
guoshengCS Nov 16, 2020
110febd
Fix gradients with ignore_idx in softmax_with_cross_entropy (#28622)
guoshengCS Nov 16, 2020
cf2c42a
fix exec nightly error on mac (#28567)
XieYunshen Nov 16, 2020
2b1e7e5
Polish where english doc (#28595)
GaoWei8 Nov 16, 2020
c5c273c
[Dy2stat] Fix Using Tuple for Transpose in Dy2stat (#28574)
zhhsplendid Nov 16, 2020
89d27de
DataLoader support not auto collate batch (#28425)
heavengate Nov 16, 2020
b889a0c
add gaussian_random op_version (#28602)
pangyoki Nov 16, 2020
72e068f
fix test_multinomial (#28558)
pangyoki Nov 16, 2020
804271c
Op version python mkldnn_inplace test (#28354)
lidanqing-intel Nov 16, 2020
2cb71c0
Add checkpoint to quantize (#28612)
wozna Nov 16, 2020
ece1e4c
Add weighted random sampler (#28545)
heavengate Nov 16, 2020
a972c33
refine gather OP performance for dynamic mode (#28587)
wangchaochaohu Nov 16, 2020
8f2656e
fix the gradient bug for the topk v2
wawltor Nov 16, 2020
b2f7ab6
bug fix, test=develop (#28648)
Nov 16, 2020
d1e84f3
Add some ops for cacluating output scale, test=develop (#28644)
juncaipeng Nov 16, 2020
361a539
fix doc of save/load (#28645)
zhwesky2010 Nov 16, 2020
a083c76
adjust signal failed wait time (#28640)
chenwhql Nov 16, 2020
2cd10fc
fix 2.0 api docs (#28445)
zhupengyang Nov 17, 2020
65aac81
Fix fake_quant error when cout > 1024, test=develop (#28603)
juncaipeng Nov 17, 2020
68ee7f7
fix overwrite for gather OP of API2.0(#28659)
wangchaochaohu Nov 17, 2020
57dab95
add datanorm op new scale_w register (#28657)
Shixiaowei02 Nov 17, 2020
8040fa2
Fix output dtype inconsistent with input (#28649)
Aurelius84 Nov 17, 2020
d71c346
fix pool exclusive and delete disable_static (#28655)
LDOUBLEV Nov 17, 2020
bf14365
fix lstm OP compile error on windows (#28667)
zhwesky2010 Nov 17, 2020
82f0b5e
adapt pad const (#28585)
Nov 17, 2020
912a5c3
fix the matmul_v2 test for cuda11 (#28635)
wangchaochaohu Nov 17, 2020
80d2024
bug fix, test=develop (#28674)
Nov 17, 2020
cdc4e66
fix lenet num classes (#28642)
LielinJiang Nov 17, 2020
6d8d3d4
[oneDNN] Layer norm bf16 kernel (#28619)
jczaja Nov 17, 2020
e4f9415
update doc, test=document_fix (#28498)
Nov 17, 2020
b6f86b8
Fix Using "isinstance" in Loop, test=develop (#28641)
zhhsplendid Nov 17, 2020
11e32ba
Add matmtl_v2 to amp list (#28693)
zhiqiu Nov 17, 2020
5050e76
Support user-defined activation/weight quantize and preprocess. (#28570)
baiyfbupt Nov 18, 2020
358d6bc
Fix test_weight_decay_extend random failed on windows (#28643)
chenwhql Nov 18, 2020
f78211d
Add delta file for precision test
chalsliu Nov 18, 2020
858ffa0
Fix the dropout setting when not initialized in rnn_op. (#28561)
guoshengCS Nov 18, 2020
7eeb99f
Add basic hook classes for dygraph & implement reduce hook (#28584)
chenwhql Nov 18, 2020
db2e6ce
add two paddle-2.0 apis: paddle.static.io.save_inference_model and pa…
T8T9 Nov 18, 2020
532e4bb
fix docs (#28683)
LielinJiang Nov 18, 2020
01a14e1
Add with_pool args for vgg (#28684)
LielinJiang Nov 18, 2020
628fb29
modified the sys adress of quickly disable file (#28660)
XieYunshen Nov 18, 2020
e880c90
fix error when setting ut timeout value (#28696)
XieYunshen Nov 18, 2020
8c75b25
Support Tensor for attr_scale and attr_size (#28677)
tink2123 Nov 18, 2020
5a9f688
[Sharding] add new features (#28568)
JZ-LIANG Nov 18, 2020
20b1276
faster the compare ops dygraph model speed
wawltor Nov 18, 2020
19226ba
Simplify the timeline, to remove the prefix of each event. (#28723)
Xreki Nov 18, 2020
3d09929
Add check for non-dispensable input (#28666)
zhiqiu Nov 18, 2020
Files changed
21 changes: 13 additions & 8 deletions CMakeLists.txt
@@ -74,25 +74,30 @@ if(WIN32)
endforeach(flag_var)
endif()

# windows build turn off warnings.
# windows build turn off warnings, use parallel compiling.
foreach(flag_var
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO
CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE
CMAKE_C_FLAGS_MINSIZEREL CMAKE_C_FLAGS_RELWITHDEBINFO)
string(REGEX REPLACE "/W[1-4]" " /W0 " ${flag_var} "${${flag_var}}")
set(${flag_var} "${${flag_var}} /MP")
endforeach(flag_var)
foreach(flag_var CMAKE_CXX_FLAGS CMAKE_C_FLAGS)
set(${flag_var} "${${flag_var}} /w")
endforeach(flag_var)

set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /wd4068 /wd4129 /wd4244 /wd4267 /wd4297 /wd4530 /wd4577 /wd4819 /wd4838 /MP")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4068 /wd4129 /wd4244 /wd4267 /wd4297 /wd4530 /wd4577 /wd4819 /wd4838 /MP")
message(STATUS "Using parallel compiling (/MP)")
set(PADDLE_LINK_FLAGS "/IGNORE:4006 /IGNORE:4098 /IGNORE:4217 /IGNORE:4221")
set(CMAKE_STATIC_LINKER_FLAGS "${CMAKE_STATIC_LINKER_FLAGS} ${PADDLE_LINK_FLAGS}")
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} ${PADDLE_LINK_FLAGS}")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${PADDLE_LINK_FLAGS}")
# Windows Remove /Zi, /ZI for Release, MinSizeRel builds
foreach(flag_var
CMAKE_C_FLAGS CMAKE_C_FLAGS_RELEASE CMAKE_C_FLAGS_MINSIZEREL
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_RELEASE CMAKE_CXX_FLAGS_MINSIZEREL)
if(${flag_var} MATCHES "/Z[iI]")
string(REGEX REPLACE "/Z[iI]" "" ${flag_var} "${${flag_var}}")
endif()
endforeach(flag_var)

set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /wd4068 /wd4129 /wd4244 /wd4267 /wd4297 /wd4530 /wd4577 /wd4819 /wd4838")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4068 /wd4129 /wd4244 /wd4267 /wd4297 /wd4530 /wd4577 /wd4819 /wd4838")
else(WIN32)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=deprecated-declarations -Wno-deprecated-declarations")
endif(WIN32)
10 changes: 3 additions & 7 deletions cmake/init.cmake
@@ -1,7 +1,7 @@
# Attention: cmake will append these flags to compile command automatically.
# So if you want to add global option, change this file rather than flags.cmake

# NOT WIN32
# Linux
# DEBUG: default: "-g"
# RELEASE: default: "-O3 -DNDEBUG"
# RELWITHDEBINFO: default: "-O2 -g -DNDEBUG"
@@ -17,6 +17,8 @@ if(NOT WIN32)
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -DNDEBUG")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O2 -g -DNDEBUG")
set(CMAKE_CXX_FLAGS_MINSIZEREL "-Os -DNDEBUG")
else()
set(WIN_PROPS ${CMAKE_SOURCE_DIR}/cmake/paddle_win.props)
endif()

if(WITH_GPU)
@@ -25,9 +27,3 @@ if(WITH_GPU)
set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "-O2 -g -DNDEBUG")
set(CMAKE_CUDA_FLAGS_MINSIZEREL "-O1 -DNDEBUG")
endif()

if(WIN32)
set(WIN_PROPS ${CMAKE_SOURCE_DIR}/cmake/paddle_win.props)
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -Os -DNDEBUG")
endif()

4 changes: 2 additions & 2 deletions paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc
@@ -238,11 +238,11 @@ REGISTER_PASS(conv_eltwiseadd_affine_channel_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_affine_channel_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("affine_channel", 0));
REGISTER_PASS_CAPABILITY(conv_eltwiseadd_affine_channel_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("elementwise_add", 0)
.EQ("affine_channel", 0));
4 changes: 2 additions & 2 deletions paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
@@ -383,11 +383,11 @@ REGISTER_PASS(depthwise_conv_eltwiseadd_bn_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_bn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("batch_norm", 0));
REGISTER_PASS_CAPABILITY(conv_eltwiseadd_bn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("elementwise_add", 0)
.EQ("batch_norm", 0));
paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc
@@ -12,7 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h"

#include <string>

#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
@@ -119,7 +121,7 @@ REGISTER_PASS(conv_elementwise_add2_act_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_elementwise_add2_act_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("elementwise_add", 0)
.EQ("relu", 0)
.EQ("identity", 0));
paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc
@@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h"

#include <string>

#include "paddle/fluid/framework/ir/graph_viz_pass.h"
@@ -107,7 +108,7 @@ REGISTER_PASS(conv_elementwise_add_act_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_elementwise_add_act_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("elementwise_add", 0)
.EQ("relu", 0)
.EQ("identity", 0));
3 changes: 2 additions & 1 deletion paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc
@@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h"

#include <string>

#include "paddle/fluid/framework/ir/graph_viz_pass.h"
@@ -93,5 +94,5 @@ REGISTER_PASS(conv_elementwise_add_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_elementwise_add_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("elementwise_add", 0));
4 changes: 2 additions & 2 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2102,8 +2102,8 @@ PDNode *patterns::Bfloat16Placement::operator()(
const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>({"concat", "conv2d", "fusion_gru", "gelu",
"reshape2", "softmax", "sum",
"transpose2"});
"layer_norm", "reshape2", "softmax",
"sum", "transpose2"});
if (!bfloat16_enabled_op_types.empty()) {
supported_op_types = bfloat16_enabled_op_types;
}
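The new "layer_norm" entry matches commit 6d8d3d4 above ([oneDNN] Layer norm bf16 kernel, #28619). Note that a non-empty bfloat16_enabled_op_types set replaces this built-in list wholesale rather than extending it. A minimal sketch of supplying that override through the pass attribute, assuming the usual pass-registry pattern from Paddle's pass tests (the op selection is illustrative):

// Sketch, assuming the standard ir::Pass attribute mechanism: a non-empty
// "bfloat16_enabled_op_types" set overrides the default list above.
#include <string>
#include <unordered_set>

#include "paddle/fluid/framework/ir/pass.h"

paddle::framework::ir::Graph* PlaceBfloat16Ops(
    paddle::framework::ir::Graph* graph) {
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "cpu_bfloat16_placement_pass");
  // Restrict bf16 placement to just these two op types.
  pass->Set("bfloat16_enabled_op_types",
            new std::unordered_set<std::string>({"conv2d", "layer_norm"}));
  return pass->Apply(graph);
}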
paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
@@ -13,7 +13,9 @@
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"

#include <vector>

#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

@@ -107,29 +109,29 @@ REGISTER_PASS(conv_relu_mkldnn_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_relu_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("relu", 0));

REGISTER_PASS(conv_leaky_relu_mkldnn_fuse_pass,
paddle::framework::ir::Conv2DLeakyReLUFusePass);
REGISTER_PASS_CAPABILITY(conv_leaky_relu_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.LE("leaky_relu", 1));

REGISTER_PASS(conv_relu6_mkldnn_fuse_pass,
paddle::framework::ir::Conv2DReLU6FusePass);
REGISTER_PASS_CAPABILITY(conv_relu6_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("relu6", 0));

REGISTER_PASS(conv_swish_mkldnn_fuse_pass,
paddle::framework::ir::Conv2DSwishFusePass);
REGISTER_PASS_CAPABILITY(conv_swish_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("swish", 0));
paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc
@@ -13,8 +13,10 @@
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h"

#include <functional>
#include <vector>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
@@ -150,7 +152,7 @@ REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_bias_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("elementwise_add", 0));

REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass,
paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc
@@ -13,7 +13,9 @@
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h"

#include <vector>

#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

@@ -128,6 +130,6 @@ REGISTER_PASS(conv_concat_relu_mkldnn_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_concat_relu_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("concat", 0)
.EQ("relu", 0));
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc
@@ -13,11 +13,13 @@
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"

#include <functional>
#include <list>
#include <map>
#include <memory>
#include <tuple>

#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/op_version_registry.h"

@@ -226,19 +228,20 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
conv_output->AsIntermediate();

auto get_node_from_elementwise_add = [&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph)
auto get_node_from_elementwise_add =
[&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*, Node*> {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);

return std::make_tuple(elementwise_add_op, elementwise_add_y,
elementwise_add_out);
};
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);

return std::make_tuple(elementwise_add_op, elementwise_add_y,
elementwise_add_out);
};

return ExecuteHandleOnGraph<IdentityFuseHandle>(
&gpd, graph_with_stats,
@@ -263,19 +266,20 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
conv_output);
conv_output->AsIntermediate();

auto get_node_from_elementwise_add = [&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph)
auto get_node_from_elementwise_add =
[&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*, Node*> {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);

return std::make_tuple(elementwise_add_op, elementwise_add_x,
elementwise_add_out);
};
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);

return std::make_tuple(elementwise_add_op, elementwise_add_x,
elementwise_add_out);
};

return ExecuteHandleOnGraph<IdentityFuseHandle>(
&gpd, graph_with_stats,
@@ -302,16 +306,17 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
conv_x_output->AsIntermediate();
conv_y_output->AsIntermediate();

auto get_node_from_elementwise_add = [&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph)
auto get_node_from_elementwise_add =
[&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*> {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);

return std::make_tuple(elementwise_add_op, elementwise_add_out);
};
return std::make_tuple(elementwise_add_op, elementwise_add_out);
};

return ExecuteHandleOnGraph<ProjectionFuseHandle>(
&gpd, graph_with_stats,
@@ -345,5 +350,5 @@ REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_elementwise_add_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.LE("conv2d", 1)
.EQ("elementwise_add", 0));
6 changes: 6 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc
@@ -16,6 +16,7 @@ limitations under the License. */
#include <vector>

#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/string/pretty_log.h"

@@ -157,3 +158,8 @@ void CPUBFloat16Pass::ApplyImpl(ir::Graph* graph) const {
} // namespace paddle

REGISTER_PASS(cpu_bfloat16_pass, paddle::framework::ir::CPUBFloat16Pass);

REGISTER_PASS_CAPABILITY(cpu_bfloat16_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().GE(
"quantize", 1));
paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc
@@ -63,5 +63,5 @@ REGISTER_PASS(depthwise_conv_mkldnn_pass,
paddle::framework::ir::DepthwiseConvMKLDNNPass);
REGISTER_PASS_CAPABILITY(depthwise_conv_mkldnn_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"depthwise_conv2d", 0));
paddle::framework::compatible::OpVersionComparatorCombination().LE(
"depthwise_conv2d", 1));
8 changes: 8 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc
@@ -17,10 +17,12 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
@@ -215,3 +217,9 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
} // namespace paddle

REGISTER_PASS(mkldnn_inplace_pass, paddle::framework::ir::MKLDNNInPlacePass);
REGISTER_PASS_CAPABILITY(mkldnn_inplace_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("softmax", 0)
.EQ("elementwise_add", 0)
.EQ("tanh", 0));