Skip to content

Commit

Permalink
changes after review
Browse files Browse the repository at this point in the history
  • Loading branch information
jakpiase committed Feb 18, 2022
1 parent 347f5c7 commit b43aeef
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 31 deletions.
60 changes: 29 additions & 31 deletions paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,41 +177,39 @@ void CPUBFloat16Pass::SetInputDataType(ir::Graph* graph) const {

void AddDequantize(Graph* g, ir::Node* op, ir::Node* op_out,
int& dequantize_counter) {
std::vector<std::string> output_names;

if (op->Op()->Type() != "prior_box") {
// Find the name of the output linking op to op_out
std::vector<std::string> output_names;
for (auto name : op->Op()->OutputNames())
for (auto output_name : op->Op()->Output(name))
if (output_name == op_out->Name() && IsPermittedOutputName(name))
output_names.push_back(name);
if (op->Op()->Type() == "prior_box") return;

if (output_names.empty()) return;

VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);

OpDesc deq_desc;
deq_desc.SetType("dequantize");
deq_desc.SetInput("Input",
std::vector<std::string>({dequantize_in_node->Name()}));
deq_desc.SetOutput("Output", std::vector<std::string>({op_out->Name()}));
deq_desc.SetAttr("Scale", 1.0f);
deq_desc.SetAttr("Shift", 0.0f);
auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied.
// Find the name of the output linking op to op_out
std::vector<std::string> output_names;
for (auto name : op->Op()->OutputNames())
for (auto output_name : op->Op()->Output(name))
if (output_name == op_out->Name() && IsPermittedOutputName(name))
output_names.push_back(name);

for (auto name = output_names.begin(); name < output_names.end(); name++)
op->Op()->SetOutput(
*name, std::vector<std::string>({dequantize_in_node->Name()}));
if (output_names.empty()) return;

UnlinkNodes(op, op_out);
IR_NODE_LINK_TO(op, dequantize_in_node);
IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
IR_NODE_LINK_TO(dequantize_op, op_out);
VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);

dequantize_counter++;
}
OpDesc deq_desc;
deq_desc.SetType("dequantize");
deq_desc.SetInput("Input",
std::vector<std::string>({dequantize_in_node->Name()}));
deq_desc.SetOutput("Output", std::vector<std::string>({op_out->Name()}));
deq_desc.SetAttr("Scale", 1.0f);
deq_desc.SetAttr("Shift", 0.0f);
auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied.

for (auto name = output_names.begin(); name < output_names.end(); name++)
op->Op()->SetOutput(*name,
std::vector<std::string>({dequantize_in_node->Name()}));

UnlinkNodes(op, op_out);
IR_NODE_LINK_TO(op, dequantize_in_node);
IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
IR_NODE_LINK_TO(dequantize_op, op_out);

dequantize_counter++;
}

void AddDequantizes(Graph* g, ir::Node* op, int& dequantize_counter) {
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ TEST(CpuBfloat16Pass, convolution) {
bool use_mkldnn = true;
int quant_op = 3;
int dequant_op = 3;
// each added op consists of 2 nodes
int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescConv(use_mkldnn), quant_op, dequant_op, added_nodes);
}
Expand All @@ -140,6 +141,7 @@ TEST(CpuBfloat16Pass, double_input_ops) {
bool use_mkldnn = true;
int quant_op = 4;
int dequant_op = 3;
// each added op consists of 2 nodes
int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescDoubleInput(use_mkldnn), quant_op, dequant_op,
added_nodes);
Expand All @@ -164,6 +166,7 @@ TEST(CpuBfloat16Pass, duplicated_input_ops) {
bool use_mkldnn = true;
int quant_op = 5;
int dequant_op = 3;
// each added op consists of 2 nodes
int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescDuplicatedInput(use_mkldnn), quant_op, dequant_op,
added_nodes);
Expand All @@ -186,6 +189,7 @@ TEST(CpuBfloat16Pass, duplicated_output_ops) {
bool use_mkldnn = true;
int quant_op = 2;
int dequant_op = 3;
// each added op consists of 2 nodes
int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescDuplicatedOutput(use_mkldnn), quant_op, dequant_op,
added_nodes);
Expand All @@ -212,6 +216,7 @@ TEST(CpuBfloat16Pass, double_outputs_ops) {
bool use_mkldnn = true;
int quant_op = 3;
int dequant_op = 3;
// each added op consists of 2 nodes
int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescDoubleOutputs(use_mkldnn), quant_op, dequant_op,
added_nodes);
Expand Down

0 comments on commit b43aeef

Please sign in to comment.