Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate xpu_embedding_with_eltwise_add_fuse_pass #50590

Merged
merged 7 commits into from
Feb 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ if(WITH_XPU)
SRCS xpu/pass_utils.cc
DEPS pass)
set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils)
pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu)
pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(multi_encoder_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
Expand Down
74 changes: 39 additions & 35 deletions paddle/fluid/framework/ir/delete_dropout_op_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,46 +30,50 @@ namespace ir {
void DeleteDropoutOpPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "delete_dropout_op_pattern";
FusePassBase::Init(pattern_name, graph);
int found_subgraph_count = 0;

GraphPatternDetector gpd;
patterns::DeleteDropoutOpPattern pattern(gpd.mutable_pattern(), pattern_name);
pattern();
for (auto with_mask : {true, false}) {
GraphPatternDetector gpd;
patterns::DeleteDropoutOpPattern pattern(gpd.mutable_pattern(),
pattern_name);
pattern(with_mask);

int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE(dropout_op_x);
GET_IR_NODE(dropout_op);
GET_IR_NODE(dropout_op_out);
GET_IR_NODE(dropout_op_mask);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE(dropout_op_x);
GET_IR_NODE(dropout_op);
GET_IR_NODE(dropout_op_out);

// link dropout_op_out to pre_op
auto dropout_op_x_name = dropout_op_x->Var()->Name();
auto dropout_op_out_name = dropout_op_out->Var()->Name();
auto pre_ops = dropout_op_x->inputs;
if (pre_ops.empty()) return;
auto pre_op_desc = pre_ops[0]->Op();
auto pre_op_outs = pre_op_desc->Outputs();
for (auto& out_var : pre_op_outs) {
auto names = out_var.second;
for (size_t i = 0; i < names.size(); i++) {
if (names[i] == dropout_op_x_name) {
names[i] = dropout_op_out_name;
pre_op_desc->SetOutput(out_var.first, names);
break;
// link dropout_op_x to next_op
auto dropout_op_x_name = dropout_op_x->Var()->Name();
auto dropout_op_out_name = dropout_op_out->Var()->Name();
auto next_op_nodes = dropout_op_out->outputs;
for (auto next_op_node : next_op_nodes) {
auto next_op_desc = next_op_node->Op();
auto next_op_inputs = next_op_desc->Inputs();
for (auto& input_var : next_op_inputs) {
auto names = input_var.second;
for (size_t i = 0; i < names.size(); i++) {
if (names[i] == dropout_op_out_name) {
names[i] = dropout_op_x_name;
next_op_desc->SetInput(input_var.first, names);
break;
}
}
}
IR_NODE_LINK_TO(dropout_op_x, next_op_node);
}
}
IR_NODE_LINK_TO(pre_ops[0], dropout_op_out);

// delete useless node
std::unordered_set<const Node*> delete_nodes{
dropout_op_x, dropout_op, dropout_op_mask};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};

gpd(graph, handler);
// delete useless node
std::unordered_set<const Node*> delete_nodes{dropout_op, dropout_op_out};
if (with_mask) {
GET_IR_NODE(dropout_op_mask);
delete_nodes.insert(dropout_op_mask);
}
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
}
AddStatis(found_subgraph_count);
}

Expand Down
14 changes: 9 additions & 5 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3032,7 +3032,7 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
return concat_out;
}

void patterns::DeleteDropoutOpPattern::operator()() {
void patterns::DeleteDropoutOpPattern::operator()(bool with_mask) {
auto dropout_op_x = pattern->NewNode(dropout_op_x_repr())
->assert_is_op_input("dropout", "X")
->AsInput();
Expand All @@ -3042,10 +3042,14 @@ void patterns::DeleteDropoutOpPattern::operator()() {
std::string("upscale_in_train"));
auto dropout_op_out = pattern->NewNode(dropout_op_out_repr())
->assert_is_op_output("dropout", "Out");
auto dropout_op_mask = pattern->NewNode(dropout_op_mask_repr())
->assert_is_op_output("dropout", "Mask");
dropout_op->LinksFrom({dropout_op_x})
.LinksTo({dropout_op_out, dropout_op_mask});
if (with_mask) {
auto dropout_op_mask = pattern->NewNode(dropout_op_mask_repr())
->assert_is_op_output("dropout", "Mask");
dropout_op->LinksFrom({dropout_op_x})
.LinksTo({dropout_op_out, dropout_op_mask});
} else {
dropout_op->LinksFrom({dropout_op_x}).LinksTo({dropout_op_out});
}
}

void patterns::DeleteQuantOpFuse::operator()(PDNode *input_act_node,
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/graph_pattern_detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -1759,7 +1759,7 @@ struct DeleteDropoutOpPattern : public PatternBase {
DeleteDropoutOpPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "delete_dropout_op_pattern") {}

void operator()();
void operator()(bool with_mask);

PATTERN_DECL_NODE(dropout_op_x);
PATTERN_DECL_NODE(dropout_op);
Expand Down
Loading