Commit

[XPU] fix dropout pass; add multi_encoder_xpu_fuse_pass & multi_encoder_xpu kernel (#50499)
zhupengyang authored Feb 16, 2023
1 parent df20728 commit c8aa640
Showing 19 changed files with 2,040 additions and 115 deletions.
11 changes: 9 additions & 2 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -213,10 +213,17 @@ endif()

 if(WITH_XPU)
   cc_library(
-    quant_utils
+    xpu_quant_utils
     SRCS xpu/quant_utils.cc
     DEPS pass)
-  pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS quant_utils)
+  cc_library(
+    xpu_pass_utils
+    SRCS xpu/pass_utils.cc
+    DEPS pass)
+  set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils)
+  pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
+  pass_library(multi_encoder_xpu_fuse_pass inference DIR xpu DEPS
+               ${XPU_PASS_DEPS})
 endif()

cc_library(
79 changes: 32 additions & 47 deletions paddle/fluid/framework/ir/delete_dropout_op_pass.cc
@@ -25,71 +25,52 @@ namespace paddle {
 namespace framework {
 namespace ir {

-#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
-#define GET_NODES                  \
-  GET_IR_NODE(any_op_out);         \
-  GET_IR_NODE(dropout_op);         \
-  GET_IR_NODE(dropout_op_out);     \
-  GET_IR_NODE(dropout_op_outmask); \
-  GET_IR_NODE(any_op2);
+#define GET_IR_NODE(node_) GET_IR_NODE_FROM_SUBGRAPH(node_, node_, pattern)

 void DeleteDropoutOpPass::ApplyImpl(ir::Graph* graph) const {
   const std::string pattern_name = "delete_dropout_op_pattern";
   FusePassBase::Init(pattern_name, graph);

   GraphPatternDetector gpd;

   patterns::DeleteDropoutOpPattern pattern(gpd.mutable_pattern(), pattern_name);
   pattern();

   int found_subgraph_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-    GET_NODES;
-    IR_NODE_LINK_TO(any_op_out, any_op2);
-    std::string any_op_out_name = any_op_out->Var()->Name();
-    std::string dropout_op_out_name = dropout_op_out->Var()->Name();
-
-    // any_op2
-    auto* any_op2_desc = any_op2->Op();
-    auto var_map = any_op2_desc->Inputs();
-    std::string arg_name = "";
-    for (auto& name_m : var_map) {
-      if (std::find(name_m.second.begin(),
-                    name_m.second.end(),
-                    dropout_op_out_name) != name_m.second.end()) {
-        arg_name = name_m.first;
-      }
-    }
-    if (arg_name.size() == 0) {
-      LOG(INFO) << "Delete dropout op pass: can not find the input "
-                << dropout_op_out_name;
-      return;
-    }
-
-    // modify the any_op2's inputs
-    for (auto& name_m : var_map) {
-      if (std::find(name_m.second.begin(),
-                    name_m.second.end(),
-                    dropout_op_out_name) != name_m.second.end()) {
-        std::vector<std::string> new_inputs;
-        for (auto& i_n : name_m.second) {
-          if (i_n != dropout_op_out_name) {
-            new_inputs.push_back(i_n);
-          }
-        }
-        new_inputs.push_back(any_op_out_name);
-        any_op2_desc->SetInput(name_m.first, new_inputs);
-        any_op2_desc->Flush();
-      }
-    }
-    any_op2_desc->Flush();
-
-    // Delete the unneeded nodes.
-    GraphSafeRemoveNodes(graph,
-                         {dropout_op, dropout_op_out, dropout_op_outmask});
+    GET_IR_NODE(dropout_op_x);
+    GET_IR_NODE(dropout_op);
+    GET_IR_NODE(dropout_op_out);
+    GET_IR_NODE(dropout_op_mask);
+
+    // link dropout_op_out to pre_op
+    auto dropout_op_x_name = dropout_op_x->Var()->Name();
+    auto dropout_op_out_name = dropout_op_out->Var()->Name();
+    auto pre_ops = dropout_op_x->inputs;
+    if (pre_ops.empty()) return;
+    auto pre_op_desc = pre_ops[0]->Op();
+    auto pre_op_outs = pre_op_desc->Outputs();
+    for (auto& out_var : pre_op_outs) {
+      auto names = out_var.second;
+      for (size_t i = 0; i < names.size(); i++) {
+        if (names[i] == dropout_op_x_name) {
+          names[i] = dropout_op_out_name;
+          pre_op_desc->SetOutput(out_var.first, names);
+          break;
+        }
+      }
+    }
+    IR_NODE_LINK_TO(pre_ops[0], dropout_op_out);
+
+    // delete useless node
+    std::unordered_set<const Node*> delete_nodes{
+        dropout_op_x, dropout_op, dropout_op_mask};
+    GraphSafeRemoveNodes(graph, delete_nodes);
     found_subgraph_count++;
   };

   gpd(graph, handler);
   AddStatis(found_subgraph_count);
 }

DeleteDropoutOpXPass::DeleteDropoutOpXPass() {
@@ -279,6 +260,10 @@ void DeleteDropoutOpXPass::ReplaceOutputVar(Node* op,

 REGISTER_PASS(delete_dropout_op_pass,
               paddle::framework::ir::DeleteDropoutOpPass);
+REGISTER_PASS_CAPABILITY(delete_dropout_op_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "dropout", 0));

 REGISTER_PASS(delete_dropout_op_x_pass,
               paddle::framework::ir::DeleteDropoutOpXPass);
31 changes: 12 additions & 19 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -3034,26 +3034,19 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
 }

 void patterns::DeleteDropoutOpPattern::operator()() {
-  auto any_op_out = pattern->NewNode(any_op_out_repr())
-                        ->assert_is_op_input("dropout", "X")
-                        ->AsInput();
-
-  auto dropout_op =
-      pattern->NewNode(dropout_op_repr())->assert_is_op("dropout");
-
+  auto dropout_op_x = pattern->NewNode(dropout_op_x_repr())
+                          ->assert_is_op_input("dropout", "X")
+                          ->AsInput();
+  auto dropout_op = pattern->NewNode(dropout_op_repr())
+                        ->assert_is_op("dropout")
+                        ->assert_op_attr("dropout_implementation",
+                                         std::string("upscale_in_train"));
   auto dropout_op_out = pattern->NewNode(dropout_op_out_repr())
-                            ->assert_is_op_output("dropout", "Out")
-                            ->AsIntermediate();
-
-  auto dropout_op_outmask = pattern->NewNode(dropout_op_outmask_repr())
-                                ->assert_is_op_output("dropout", "Mask")
-                                ->AsOutput();
-  auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
-
-  dropout_op->LinksFrom({any_op_out});
-  dropout_op_out->LinksFrom({dropout_op});
-  dropout_op_outmask->LinksFrom({dropout_op});
-  any_op2->LinksFrom({dropout_op_out});
+                            ->assert_is_op_output("dropout", "Out");
+  auto dropout_op_mask = pattern->NewNode(dropout_op_mask_repr())
+                             ->assert_is_op_output("dropout", "Mask");
+  dropout_op->LinksFrom({dropout_op_x})
+      .LinksTo({dropout_op_out, dropout_op_mask});
 }

void patterns::DeleteQuantOpFuse::operator()(PDNode *input_act_node,
5 changes: 2 additions & 3 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1763,11 +1763,10 @@ struct DeleteDropoutOpPattern : public PatternBase {

   void operator()();

-  PATTERN_DECL_NODE(any_op_out);
+  PATTERN_DECL_NODE(dropout_op_x);
   PATTERN_DECL_NODE(dropout_op);
   PATTERN_DECL_NODE(dropout_op_out);
-  PATTERN_DECL_NODE(dropout_op_outmask);
-  PATTERN_DECL_NODE(any_op2);
+  PATTERN_DECL_NODE(dropout_op_mask);
 };

struct DeleteQuantDequantOpPattern : public PatternBase {
17 changes: 2 additions & 15 deletions paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
@@ -176,15 +176,6 @@ class FcXPUFusePass : public FusePassBase {
                     const std::string& act_type) const;

   const std::string name_scope_{"fc_xpu_fuse_pass"};
-  const std::map<std::string, int> act_map_{{"", 0},
-                                            {"relu", 1},
-                                            {"sigmoid", 2},
-                                            {"tanh", 3},
-                                            {"gelu", 4},
-                                            {"leaky_relu", 5},
-                                            {"hard_swish", 14},
-                                            {"hard_sigmoid", 15},
-                                            {"relu6", 17}};
 };

void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const {
@@ -246,17 +237,13 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
   mul_w_max_var->SetPersistable(true);
   auto mul_w_max_tensor =
       scope->Var(mul_w_max_name)->GetMutable<phi::DenseTensor>();
-  auto* xpu_ctx = static_cast<phi::XPUContext*>(
-      platform::DeviceContextPool::Instance().Get(phi::XPUPlace()));
-  int max_ptr_size = xpu_ctx->x_context()->max_ptr_size();
   bool transpose_w = false;
   if (mul_type == "matmul") {
     transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("transpose_Y"));
   } else if (mul_type == "matmul_v2") {
     transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("trans_y"));
   }
-  QuantWeight<int16_t>(
-      mul_w_tensor, mul_w_max_tensor, !transpose_w, max_ptr_size);
+  QuantWeight<int16_t>(mul_w_tensor, mul_w_max_tensor, !transpose_w);
 }

 // Generate fc_xpu op
@@ -288,7 +275,7 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
fc_xpu_op_desc.SetAttr("act_type", 0);
fc_xpu_op_desc.SetAttr("act_alpha", 0.f);
if (act) {
fc_xpu_op_desc.SetAttr("act_type", act_map_.at(act_type));
fc_xpu_op_desc.SetAttr("act_type", ConvertActivationType(act_type));
if (act_type == "leaky_relu") {
fc_xpu_op_desc.SetAttr(
"act_alpha", PADDLE_GET_CONST(float, act->Op()->GetAttr("alpha")));
