Skip to content

Commit

Permalink
Fix for split op in BF16 inference (#39548)
Browse files Browse the repository at this point in the history
* Fix for split bf16 inference

* added test for pass

* changes after review
  • Loading branch information
jakpiase authored Feb 24, 2022
1 parent c969955 commit 75f91ce
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 48 deletions.
9 changes: 9 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2516,6 +2516,15 @@ PDNode *patterns::DuplicatedInputs::operator()() {
return op;
}

PDNode *patterns::DuplicatedOutputs::operator()() {
auto op = pattern->NewNode(op_repr())->assert_is_ops({"split"});
op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16";
});
return op;
}

PDNode *patterns::MKLDNNInPlace::operator()() {
const std::unordered_set<std::string> &supported_op_types = {
"abs", "gelu", "leaky_relu", "relu", "softmax", "sqrt", "swish", "tanh"};
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -1495,6 +1495,15 @@ struct DuplicatedInputs : public PatternBase {
PATTERN_DECL_NODE(op);
};

struct DuplicatedOutputs : public PatternBase {
DuplicatedOutputs(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "many_outputs_op") {}

PDNode* operator()();

PATTERN_DECL_NODE(op);
};

// Pattern used for enforcing inplace computation for in-place computation
// supporting DNNL ops. softmax, batch_norm and layer_norm
struct MKLDNNInPlace : public PatternBase {
Expand Down
166 changes: 120 additions & 46 deletions paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ bool IsPermittedOutputName(const std::string& output_name) {
}

void AddQuantize(Graph* g, ir::Node* op, ir::Node* op_in,
int* quantize_counter) {
int& quantize_counter) {
std::vector<std::string> input_names;

// Find the name of the input linking op to op_in
Expand Down Expand Up @@ -87,10 +87,10 @@ void AddQuantize(Graph* g, ir::Node* op, ir::Node* op_in,
IR_NODE_LINK_TO(op_in, quantize_op);
IR_NODE_LINK_TO(quantize_op, quantize_out_node);
IR_NODE_LINK_TO(quantize_out_node, op);
(*quantize_counter)++;
quantize_counter++;
}

void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) {
void AddQuantizes(Graph* g, ir::Node* op, int& quantize_counter) {
auto inputs = op->inputs;
PADDLE_ENFORCE_GE(inputs.size(), 1,
platform::errors::InvalidArgument(
Expand Down Expand Up @@ -127,7 +127,7 @@ void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) {
IR_NODE_LINK_TO(inputs[i], quantize_op);
IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[i]);
IR_NODE_LINK_TO(quantize_out_nodes[i], op);
(*quantize_counter)++;
quantize_counter++;
}

op->Op()->SetInput("X", quantize_out_node_names);
Expand All @@ -136,7 +136,7 @@ void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) {
// Operators like Concat and Sum have a single input name X, which actually
// consists of multiple inputs. Such operators require a different way to find
// pattern and add quantize ops.
void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int* quantize_counter) {
void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int& quantize_counter) {
GraphPatternDetector gpd;
patterns::DuplicatedInputs duplicated_inputs{gpd.mutable_pattern(),
"duplicated_inputs"};
Expand All @@ -151,7 +151,7 @@ void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int* quantize_counter) {

// Adding quantize ops before all operators except Concat and Sum, which have
// already been handled in AddReoderBeforeDuplicatedInputs
void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) {
void AddReoderBeforeSingleInputs(ir::Graph* graph, int& quantize_counter) {
GraphPatternDetector gpd;
patterns::FirstBfloat16Ops bfloat16_ops{gpd.mutable_pattern(),
"first_bfloat16_ops"};
Expand All @@ -169,60 +169,134 @@ void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) {

void CPUBFloat16Pass::SetInputDataType(ir::Graph* graph) const {
int quantize_counter = 0;
AddReoderBeforeDuplicatedInputs(graph, &quantize_counter);
AddReoderBeforeSingleInputs(graph, &quantize_counter);
AddReoderBeforeDuplicatedInputs(graph, quantize_counter);
AddReoderBeforeSingleInputs(graph, quantize_counter);
PrettyLogDetail("--- added %d quantize ops before bfloat16 op",
quantize_counter);
}

void CPUBFloat16Pass::SetOutputDataType(ir::Graph* graph) const {
void AddDequantize(Graph* g, ir::Node* op, ir::Node* op_out,
int& dequantize_counter) {
if (op->Op()->Type() == "prior_box") return;

// Find the name of the output linking op to op_out
std::vector<std::string> output_names;
for (auto name : op->Op()->OutputNames())
for (auto output_name : op->Op()->Output(name))
if (output_name == op_out->Name() && IsPermittedOutputName(name))
output_names.push_back(name);

if (output_names.empty()) return;

VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);

OpDesc deq_desc;
deq_desc.SetType("dequantize");
deq_desc.SetInput("Input",
std::vector<std::string>({dequantize_in_node->Name()}));
deq_desc.SetOutput("Output", std::vector<std::string>({op_out->Name()}));
deq_desc.SetAttr("Scale", 1.0f);
deq_desc.SetAttr("Shift", 0.0f);
auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied.

for (auto name = output_names.begin(); name < output_names.end(); name++)
op->Op()->SetOutput(*name,
std::vector<std::string>({dequantize_in_node->Name()}));

UnlinkNodes(op, op_out);
IR_NODE_LINK_TO(op, dequantize_in_node);
IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
IR_NODE_LINK_TO(dequantize_op, op_out);

dequantize_counter++;
}

void AddDequantizes(Graph* g, ir::Node* op, int& dequantize_counter) {
auto outputs = op->outputs;
PADDLE_ENFORCE_GE(outputs.size(), 1,
platform::errors::InvalidArgument(
"OP(%s)'s outputs(%d) must be equal or greater than 1.",
op->Name(), outputs.size()));
PADDLE_ENFORCE_EQ(op->inputs.size(), 1,
platform::errors::InvalidArgument(
"OP(%s)'s inputs(%d) must be equal to 1.", op->Name(),
op->inputs.size()));

OpDesc deq_desc;
deq_desc.SetType("dequantize");

std::vector<Node*> dequantize_in_nodes(outputs.size());
std::vector<std::string> dequantize_in_node_names(outputs.size());

for (size_t i = 0; i < outputs.size(); i++) {
VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
dequantize_in_nodes[i] = g->CreateVarNode(&dequantize_in_desc);
dequantize_in_node_names[i] = dequantize_in_nodes[i]->Name();

deq_desc.SetInput("Input",
std::vector<std::string>({dequantize_in_node_names[i]}));
deq_desc.SetOutput("Output",
std::vector<std::string>({outputs[i]->Name()}));

deq_desc.SetAttr("Scale", 1.f);
deq_desc.SetAttr("Shift", 0.0f);
deq_desc.SetAttr("bfloat16", true);
deq_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout")
? op->Op()->GetAttr("data_layout")
: std::string("NCHW"));
auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied.

UnlinkNodes(op, outputs[i]);
IR_NODE_LINK_TO(op, dequantize_in_nodes[i]);
IR_NODE_LINK_TO(dequantize_in_nodes[i], dequantize_op);
IR_NODE_LINK_TO(dequantize_op, outputs[i]);

dequantize_counter++;
}

op->Op()->SetOutput("Out", dequantize_in_node_names);
}

// Operators like split have a single output name Out, which actually
// consists of multiple outputs. Such operators require a different way to find
// pattern and add dequantize ops.
void AddReoderAfterDuplicatedOutputs(ir::Graph* graph,
int& dequantize_counter) {
GraphPatternDetector gpd;
patterns::DuplicatedOutputs duplicated_outputs{gpd.mutable_pattern(),
"duplicated_outputs"};
duplicated_outputs();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, duplicated_outputs);
AddDequantizes(g, op, dequantize_counter);
};
gpd(graph, handler);
}

// Adding dequantize ops after all operators except split, which has
// already been handled in AddReoderAfterDuplicatedOutputs
void AddReoderAfterSingleOutputs(ir::Graph* graph, int& dequantize_counter) {
GraphPatternDetector gpd;
patterns::LastBfloat16Ops bfloat16_ops{gpd.mutable_pattern(),
"last_bfloat16_ops"};
bfloat16_ops();
int dequantize_counter = 0;

auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
GET_IR_NODE_FROM_SUBGRAPH(op_out, op_out, bfloat16_ops);

if (op->Op()->Type() != "prior_box") {
// Find the name of the output linking op to op_out
std::vector<std::string> output_names;
for (auto name : op->Op()->OutputNames())
for (auto output_name : op->Op()->Output(name))
if (output_name == op_out->Name() && IsPermittedOutputName(name))
output_names.push_back(name);

if (output_names.empty()) return;

VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);

OpDesc deq_desc;
deq_desc.SetType("dequantize");
deq_desc.SetInput("Input",
std::vector<std::string>({dequantize_in_node->Name()}));
deq_desc.SetOutput("Output", std::vector<std::string>({op_out->Name()}));
deq_desc.SetAttr("Scale", 1.0f);
deq_desc.SetAttr("Shift", 0.0f);
auto dequantize_op =
g->CreateOpNode(&deq_desc); // OpDesc will be copied.

for (auto name = output_names.begin(); name < output_names.end(); name++)
op->Op()->SetOutput(
*name, std::vector<std::string>({dequantize_in_node->Name()}));

UnlinkNodes(op, op_out);
IR_NODE_LINK_TO(op, dequantize_in_node);
IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
IR_NODE_LINK_TO(dequantize_op, op_out);

dequantize_counter++;
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
if (op->Op()->Type() != "split") {
AddDequantize(g, op, op_out, dequantize_counter);
}
};
gpd(graph, handler);
}

void CPUBFloat16Pass::SetOutputDataType(ir::Graph* graph) const {
int dequantize_counter = 0;
AddReoderAfterDuplicatedOutputs(graph, dequantize_counter);
AddReoderAfterSingleOutputs(graph, dequantize_counter);
PrettyLogDetail("--- added %d dequantize ops after bfloat16 op",
dequantize_counter);
}
Expand Down
29 changes: 28 additions & 1 deletion paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetInput("Input", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
} else if (type == "concat" || type == "sum") {
} else if (type == "concat" || type == "sum" || type == "split") {
op->SetInput("X", inputs);
op->SetOutput("Out", outputs);
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
Expand Down Expand Up @@ -117,6 +117,7 @@ TEST(CpuBfloat16Pass, convolution) {
bool use_mkldnn = true;
int quant_op = 3;
int dequant_op = 3;
// each added op consists of 2 nodes
int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescConv(use_mkldnn), quant_op, dequant_op, added_nodes);
}
Expand All @@ -140,6 +141,7 @@ TEST(CpuBfloat16Pass, double_input_ops) {
bool use_mkldnn = true;
int quant_op = 4;
int dequant_op = 3;
// each added op consists of 2 nodes
int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescDoubleInput(use_mkldnn), quant_op, dequant_op,
added_nodes);
Expand All @@ -164,11 +166,35 @@ TEST(CpuBfloat16Pass, duplicated_input_ops) {
bool use_mkldnn = true;
int quant_op = 5;
int dequant_op = 3;
// each added op consists of 2 nodes
int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescDuplicatedInput(use_mkldnn), quant_op, dequant_op,
added_nodes);
}

ProgramDesc BuildProgramDescDuplicatedOutput(bool use_mkldnn) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "dropout", "Dropout", {"a"}, {"b"}, use_mkldnn, "float32");
SetOp(&prog, "split", "Split", {"b"}, {"c", "d"}, use_mkldnn, "bfloat16");
SetOp(&prog, "transpose2", "Transpose", {"c"}, {"e"}, use_mkldnn, "float32");
SetOp(&prog, "reshape2", "Reshape", {"d"}, {"f"}, use_mkldnn, "bfloat16");

return prog;
}

TEST(CpuBfloat16Pass, duplicated_output_ops) {
bool use_mkldnn = true;
int quant_op = 2;
int dequant_op = 3;
// each added op consists of 2 nodes
int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescDuplicatedOutput(use_mkldnn), quant_op, dequant_op,
added_nodes);
}

ProgramDesc BuildProgramDescDoubleOutputs(bool use_mkldnn) {
ProgramDesc prog;
for (auto& v : variable_names) {
Expand All @@ -190,6 +216,7 @@ TEST(CpuBfloat16Pass, double_outputs_ops) {
bool use_mkldnn = true;
int quant_op = 3;
int dequant_op = 3;
// each added op consists of 2 nodes
int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescDoubleOutputs(use_mkldnn), quant_op, dequant_op,
added_nodes);
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/kernels/cpu/split_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,5 @@ PD_REGISTER_KERNEL(split,
int64_t,
int,
bool,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}

0 comments on commit 75f91ce

Please sign in to comment.