
Fix for split op in BF16 inference #39548

Merged

Changes from 3 commits
9 changes: 9 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2516,6 +2516,15 @@ PDNode *patterns::DuplicatedInputs::operator()() {
return op;
}

PDNode *patterns::DuplicatedOutputs::operator()() {
auto op = pattern->NewNode(op_repr())->assert_is_ops({"split"});
op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16";
});
return op;
}

PDNode *patterns::MKLDNNInPlace::operator()() {
const std::unordered_set<std::string> &supported_op_types = {
"abs", "gelu", "leaky_relu", "relu", "softmax", "sqrt", "swish", "tanh"};
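For context, the assert_more predicate above keeps only split ops explicitly tagged for bf16 execution. A minimal sketch of the check in isolation (assuming GetAttrIfExists falls back to an empty string when the attribute is absent; the node variable is hypothetical):

// Rejects any split op that does not carry mkldnn_data_type == "bfloat16";
// an untagged op yields "" and the predicate fails.
bool is_bf16_split =
    node->Op()->Type() == "split" &&
    node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") == "bfloat16";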
9 changes: 9 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1495,6 +1495,15 @@ struct DuplicatedInputs : public PatternBase {
PATTERN_DECL_NODE(op);
};

struct DuplicatedOutputs : public PatternBase {
DuplicatedOutputs(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "many_outputs_op") {}

PDNode* operator()();

PATTERN_DECL_NODE(op);
};

// Pattern used for enforcing inplace computation for in-place computation
// supporting DNNL ops. softmax, batch_norm and layer_norm
struct MKLDNNInPlace : public PatternBase {
166 changes: 120 additions & 46 deletions paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc
@@ -52,7 +52,7 @@ bool IsPermittedOutputName(const std::string& output_name) {
}

void AddQuantize(Graph* g, ir::Node* op, ir::Node* op_in,
int* quantize_counter) {
int& quantize_counter) {
std::vector<std::string> input_names;

// Find the name of the input linking op to op_in
Expand Down Expand Up @@ -87,10 +87,10 @@ void AddQuantize(Graph* g, ir::Node* op, ir::Node* op_in,
IR_NODE_LINK_TO(op_in, quantize_op);
IR_NODE_LINK_TO(quantize_op, quantize_out_node);
IR_NODE_LINK_TO(quantize_out_node, op);
(*quantize_counter)++;
quantize_counter++;
}
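The counter is now threaded by reference rather than through a pointer, so call sites drop the address-of operator. A before/after sketch of one hypothetical call site (the callers of AddQuantize live in the handler lambdas, which this hunk does not show):

int quantize_counter = 0;
// before: AddQuantize(g, op, op_in, &quantize_counter);
// after:
AddQuantize(g, op, op_in, quantize_counter);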

void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) {
void AddQuantizes(Graph* g, ir::Node* op, int& quantize_counter) {
auto inputs = op->inputs;
PADDLE_ENFORCE_GE(inputs.size(), 1,
platform::errors::InvalidArgument(
Expand Down Expand Up @@ -127,7 +127,7 @@ void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) {
IR_NODE_LINK_TO(inputs[i], quantize_op);
IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[i]);
IR_NODE_LINK_TO(quantize_out_nodes[i], op);
(*quantize_counter)++;
quantize_counter++;
}

op->Op()->SetInput("X", quantize_out_node_names);
@@ -136,7 +136,7 @@
// Operators like Concat and Sum have a single input name X, which actually
// consists of multiple inputs. Such operators require a different way to find
// pattern and add quantize ops.
void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int* quantize_counter) {
void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int& quantize_counter) {
GraphPatternDetector gpd;
patterns::DuplicatedInputs duplicated_inputs{gpd.mutable_pattern(),
"duplicated_inputs"};
@@ -151,7 +151,7 @@ void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int* quantize_counter) {

// Adding quantize ops before all operators except Concat and Sum, which have
// already been handled in AddReoderBeforeDuplicatedInputs
void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) {
void AddReoderBeforeSingleInputs(ir::Graph* graph, int& quantize_counter) {
GraphPatternDetector gpd;
patterns::FirstBfloat16Ops bfloat16_ops{gpd.mutable_pattern(),
"first_bfloat16_ops"};
@@ -169,60 +169,134 @@ void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) {

void CPUBFloat16Pass::SetInputDataType(ir::Graph* graph) const {
int quantize_counter = 0;
AddReoderBeforeDuplicatedInputs(graph, &quantize_counter);
AddReoderBeforeSingleInputs(graph, &quantize_counter);
AddReoderBeforeDuplicatedInputs(graph, quantize_counter);
AddReoderBeforeSingleInputs(graph, quantize_counter);
PrettyLogDetail("--- added %d quantize ops before bfloat16 op",
quantize_counter);
}

void CPUBFloat16Pass::SetOutputDataType(ir::Graph* graph) const {
void AddDequantize(Graph* g, ir::Node* op, ir::Node* op_out,
int& dequantize_counter) {
if (op->Op()->Type() == "prior_box") return;
baoachun (Contributor) commented on Feb 21, 2022:
Why does prior_box return directly? Could you add a description, or can we use a set to maintain the operators that require special handling?

(Contributor, Author) replied:
@wozna could you please advise us on this prior_box scenario?

wozna (Contributor) replied on Feb 23, 2022:
The prior_box output always produces floating-point results, because these are the generated prior boxes, so no dequantization is needed. So far only this operator behaves this way.

(Contributor) replied:
Jakub may add comments in the next PR, but this PR has passed all CIs and hopefully it can be merged.

// Find the name of the output linking op to op_out
std::vector<std::string> output_names;
for (auto name : op->Op()->OutputNames())
for (auto output_name : op->Op()->Output(name))
if (output_name == op_out->Name() && IsPermittedOutputName(name))
output_names.push_back(name);

if (output_names.empty()) return;

VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);

OpDesc deq_desc;
deq_desc.SetType("dequantize");
deq_desc.SetInput("Input",
std::vector<std::string>({dequantize_in_node->Name()}));
deq_desc.SetOutput("Output", std::vector<std::string>({op_out->Name()}));
deq_desc.SetAttr("Scale", 1.0f);
deq_desc.SetAttr("Shift", 0.0f);
auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied.

for (auto name = output_names.begin(); name < output_names.end(); name++)
op->Op()->SetOutput(*name,
std::vector<std::string>({dequantize_in_node->Name()}));

UnlinkNodes(op, op_out);
IR_NODE_LINK_TO(op, dequantize_in_node);
IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
IR_NODE_LINK_TO(dequantize_op, op_out);

dequantize_counter++;
}
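In effect, AddDequantize splices one dequantize op into the matched output edge; the rewiring, using the names from the function above:

// before:  op -> op_out
// after:   op -> dequantize_in -> dequantize -> op_out
// The op's permitted output slots are rewritten to name dequantize_in,
// and the new dequantize op writes the original op_out variable.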

void AddDequantizes(Graph* g, ir::Node* op, int& dequantize_counter) {
auto outputs = op->outputs;
PADDLE_ENFORCE_GE(outputs.size(), 1,
platform::errors::InvalidArgument(
"OP(%s)'s outputs(%d) must be equal or greater than 1.",
op->Name(), outputs.size()));
PADDLE_ENFORCE_EQ(op->inputs.size(), 1,
platform::errors::InvalidArgument(
"OP(%s)'s inputs(%d) must be equal to 1.", op->Name(),
op->inputs.size()));

OpDesc deq_desc;
deq_desc.SetType("dequantize");

std::vector<Node*> dequantize_in_nodes(outputs.size());
std::vector<std::string> dequantize_in_node_names(outputs.size());

for (size_t i = 0; i < outputs.size(); i++) {
VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
dequantize_in_nodes[i] = g->CreateVarNode(&dequantize_in_desc);
dequantize_in_node_names[i] = dequantize_in_nodes[i]->Name();

deq_desc.SetInput("Input",
std::vector<std::string>({dequantize_in_node_names[i]}));
deq_desc.SetOutput("Output",
std::vector<std::string>({outputs[i]->Name()}));

deq_desc.SetAttr("Scale", 1.f);
deq_desc.SetAttr("Shift", 0.0f);
deq_desc.SetAttr("bfloat16", true);
deq_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout")
? op->Op()->GetAttr("data_layout")
: std::string("NCHW"));
auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied.

UnlinkNodes(op, outputs[i]);
IR_NODE_LINK_TO(op, dequantize_in_nodes[i]);
IR_NODE_LINK_TO(dequantize_in_nodes[i], dequantize_op);
IR_NODE_LINK_TO(dequantize_op, outputs[i]);

dequantize_counter++;
}

op->Op()->SetOutput("Out", dequantize_in_node_names);
}
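For a bf16 split with two outputs, AddDequantizes therefore inserts one dequantize per output edge; sketched with the variable names used in the tester further down (the deq_in_* names stand in for the variables generated via PDNodeName):

// before:  split -> c, split -> d              (Out = {c, d})
// after:   split -> deq_in_0 -> dequantize -> c
//          split -> deq_in_1 -> dequantize -> d
// split's single output slot "Out" is rewritten to {deq_in_0, deq_in_1}.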

// Operators like split have a single output name Out, which actually
// consists of multiple outputs. Such operators require a different way to find
// pattern and add dequantize ops.
void AddReoderAfterDuplicatedOutputs(ir::Graph* graph,
int& dequantize_counter) {
GraphPatternDetector gpd;
patterns::DuplicatedOutputs duplicated_outputs{gpd.mutable_pattern(),
"duplicated_outputs"};
duplicated_outputs();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, duplicated_outputs);
AddDequantizes(g, op, dequantize_counter);
};
gpd(graph, handler);
}

// Adding dequantize ops after all operators except split, which has
// already been handled in AddReoderAfterDuplicatedOutputs
void AddReoderAfterSingleOutputs(ir::Graph* graph, int& dequantize_counter) {
GraphPatternDetector gpd;
patterns::LastBfloat16Ops bfloat16_ops{gpd.mutable_pattern(),
"last_bfloat16_ops"};
bfloat16_ops();
int dequantize_counter = 0;

auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
GET_IR_NODE_FROM_SUBGRAPH(op_out, op_out, bfloat16_ops);

if (op->Op()->Type() != "prior_box") {
// Find the name of the output linking op to op_out
std::vector<std::string> output_names;
for (auto name : op->Op()->OutputNames())
for (auto output_name : op->Op()->Output(name))
if (output_name == op_out->Name() && IsPermittedOutputName(name))
output_names.push_back(name);

if (output_names.empty()) return;

VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);

OpDesc deq_desc;
deq_desc.SetType("dequantize");
deq_desc.SetInput("Input",
std::vector<std::string>({dequantize_in_node->Name()}));
deq_desc.SetOutput("Output", std::vector<std::string>({op_out->Name()}));
deq_desc.SetAttr("Scale", 1.0f);
deq_desc.SetAttr("Shift", 0.0f);
auto dequantize_op =
g->CreateOpNode(&deq_desc); // OpDesc will be copied.

for (auto name = output_names.begin(); name < output_names.end(); name++)
op->Op()->SetOutput(
*name, std::vector<std::string>({dequantize_in_node->Name()}));

UnlinkNodes(op, op_out);
IR_NODE_LINK_TO(op, dequantize_in_node);
IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
IR_NODE_LINK_TO(dequantize_op, op_out);

dequantize_counter++;
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
if (op->Op()->Type() != "split") {
AddDequantize(g, op, op_out, dequantize_counter);
}
};
gpd(graph, handler);
}

void CPUBFloat16Pass::SetOutputDataType(ir::Graph* graph) const {
int dequantize_counter = 0;
AddReoderAfterDuplicatedOutputs(graph, dequantize_counter);
AddReoderAfterSingleOutputs(graph, dequantize_counter);
PrettyLogDetail("--- added %d dequantize ops after bfloat16 op",
dequantize_counter);
}
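Note that the ordering mirrors SetInputDataType above: the duplicated-outputs pass runs first, and AddReoderAfterSingleOutputs explicitly skips split, so the two passes never insert dequantizes twice for the same op.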
29 changes: 28 additions & 1 deletion paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc
@@ -45,7 +45,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetInput("Input", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
} else if (type == "concat" || type == "sum") {
} else if (type == "concat" || type == "sum" || type == "split") {
op->SetInput("X", inputs);
op->SetOutput("Out", outputs);
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
@@ -117,6 +117,7 @@ TEST(CpuBfloat16Pass, convolution) {
bool use_mkldnn = true;
int quant_op = 3;
int dequant_op = 3;
// each added op consists of 2 nodes
int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescConv(use_mkldnn), quant_op, dequant_op, added_nodes);
}
@@ -140,6 +141,7 @@ TEST(CpuBfloat16Pass, double_input_ops) {
bool use_mkldnn = true;
int quant_op = 4;
int dequant_op = 3;
// each added op consists of 2 nodes
int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescDoubleInput(use_mkldnn), quant_op, dequant_op,
added_nodes);
@@ -164,11 +166,35 @@ TEST(CpuBfloat16Pass, duplicated_input_ops) {
bool use_mkldnn = true;
int quant_op = 5;
int dequant_op = 3;
// each added op consists of 2 nodes
int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescDuplicatedInput(use_mkldnn), quant_op, dequant_op,
added_nodes);
}

ProgramDesc BuildProgramDescDuplicatedOutput(bool use_mkldnn) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "dropout", "Dropout", {"a"}, {"b"}, use_mkldnn, "float32");
SetOp(&prog, "split", "Split", {"b"}, {"c", "d"}, use_mkldnn, "bfloat16");
SetOp(&prog, "transpose2", "Transpose", {"c"}, {"e"}, use_mkldnn, "float32");
SetOp(&prog, "reshape2", "Reshape", {"d"}, {"f"}, use_mkldnn, "bfloat16");

return prog;
}

TEST(CpuBfloat16Pass, duplicated_output_ops) {
bool use_mkldnn = true;
int quant_op = 2;
int dequant_op = 3;
// each added op consists of 2 nodes
int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescDuplicatedOutput(use_mkldnn), quant_op, dequant_op,
added_nodes);
}
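The node arithmetic behind these expectations: every inserted quantize or dequantize contributes two graph nodes, the op node itself plus the newly created variable node that links it to the matched operator, hence for this test:

// added_nodes = quant_op * 2 + dequant_op * 2
//             = 2 * 2       + 3 * 2
//             = 10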

ProgramDesc BuildProgramDescDoubleOutputs(bool use_mkldnn) {
ProgramDesc prog;
for (auto& v : variable_names) {
@@ -190,6 +216,7 @@ TEST(CpuBfloat16Pass, double_outputs_ops) {
bool use_mkldnn = true;
int quant_op = 3;
int dequant_op = 3;
// each added op consists of 2 nodes
int added_nodes = quant_op * 2 + dequant_op * 2;
MainTest(BuildProgramDescDoubleOutputs(use_mkldnn), quant_op, dequant_op,
added_nodes);
3 changes: 2 additions & 1 deletion paddle/pten/kernels/cpu/split_kernel.cc
@@ -69,4 +69,5 @@ PT_REGISTER_KERNEL(split,
int64_t,
int,
bool,
pten::dtype::float16) {}
pten::dtype::float16,
pten::dtype::bfloat16) {}
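With pten::dtype::bfloat16 added to the registration list, the CPU split kernel can be dispatched for bf16 tensors at inference time, which the graph-pass changes above rely on.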