Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for split op in BF16 inference #39548

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2516,6 +2516,15 @@ PDNode *patterns::DuplicatedInputs::operator()() {
return op;
}

PDNode *patterns::DuplicatedOutputs::operator()() {
auto op = pattern->NewNode(op_repr())->assert_is_ops({"split"});
op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16";
});
return op;
}

PDNode *patterns::MKLDNNInPlace::operator()() {
const std::unordered_set<std::string> &supported_op_types = {
"abs", "gelu", "leaky_relu", "relu", "softmax", "sqrt", "swish", "tanh"};
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -1495,6 +1495,15 @@ struct DuplicatedInputs : public PatternBase {
PATTERN_DECL_NODE(op);
};

struct DuplicatedOutputs : public PatternBase {
DuplicatedOutputs(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "many_outputs_op") {}

PDNode* operator()();

PATTERN_DECL_NODE(op);
};

// Pattern used for enforcing inplace computation for in-place computation
// supporting DNNL ops. softmax, batch_norm and layer_norm
struct MKLDNNInPlace : public PatternBase {
Expand Down
168 changes: 122 additions & 46 deletions paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ bool IsPermittedOutputName(const std::string& output_name) {
}

void AddQuantize(Graph* g, ir::Node* op, ir::Node* op_in,
int* quantize_counter) {
int& quantize_counter) {
std::vector<std::string> input_names;

// Find the name of the input linking op to op_in
Expand Down Expand Up @@ -87,10 +87,10 @@ void AddQuantize(Graph* g, ir::Node* op, ir::Node* op_in,
IR_NODE_LINK_TO(op_in, quantize_op);
IR_NODE_LINK_TO(quantize_op, quantize_out_node);
IR_NODE_LINK_TO(quantize_out_node, op);
(*quantize_counter)++;
quantize_counter++;
}

void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) {
void AddQuantizes(Graph* g, ir::Node* op, int& quantize_counter) {
auto inputs = op->inputs;
PADDLE_ENFORCE_GE(inputs.size(), 1,
platform::errors::InvalidArgument(
Expand Down Expand Up @@ -127,7 +127,7 @@ void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) {
IR_NODE_LINK_TO(inputs[i], quantize_op);
IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[i]);
IR_NODE_LINK_TO(quantize_out_nodes[i], op);
(*quantize_counter)++;
quantize_counter++;
}

op->Op()->SetInput("X", quantize_out_node_names);
Expand All @@ -136,7 +136,7 @@ void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) {
// Operators like Concat and Sum have a single input name X, which actually
// consists of multiple inputs. Such operators require a different way to find
// pattern and add quantize ops.
void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int* quantize_counter) {
void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int& quantize_counter) {
GraphPatternDetector gpd;
patterns::DuplicatedInputs duplicated_inputs{gpd.mutable_pattern(),
"duplicated_inputs"};
Expand All @@ -151,7 +151,7 @@ void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int* quantize_counter) {

// Adding quantize ops before all operators except Concat and Sum, which have
// already been handled in AddReoderBeforeDuplicatedInputs
void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) {
void AddReoderBeforeSingleInputs(ir::Graph* graph, int& quantize_counter) {
GraphPatternDetector gpd;
patterns::FirstBfloat16Ops bfloat16_ops{gpd.mutable_pattern(),
"first_bfloat16_ops"};
Expand All @@ -169,60 +169,136 @@ void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) {

void CPUBFloat16Pass::SetInputDataType(ir::Graph* graph) const {
int quantize_counter = 0;
AddReoderBeforeDuplicatedInputs(graph, &quantize_counter);
AddReoderBeforeSingleInputs(graph, &quantize_counter);
AddReoderBeforeDuplicatedInputs(graph, quantize_counter);
AddReoderBeforeSingleInputs(graph, quantize_counter);
PrettyLogDetail("--- added %d quantize ops before bfloat16 op",
quantize_counter);
}

void CPUBFloat16Pass::SetOutputDataType(ir::Graph* graph) const {
void AddDequantize(Graph* g, ir::Node* op, ir::Node* op_out,
int& dequantize_counter) {
std::vector<std::string> output_names;
jakpiase marked this conversation as resolved.
Show resolved Hide resolved

if (op->Op()->Type() != "prior_box") {
jakpiase marked this conversation as resolved.
Show resolved Hide resolved
// Find the name of the output linking op to op_out
std::vector<std::string> output_names;
for (auto name : op->Op()->OutputNames())
for (auto output_name : op->Op()->Output(name))
if (output_name == op_out->Name() && IsPermittedOutputName(name))
output_names.push_back(name);

if (output_names.empty()) return;
jakpiase marked this conversation as resolved.
Show resolved Hide resolved

VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);

OpDesc deq_desc;
deq_desc.SetType("dequantize");
deq_desc.SetInput("Input",
std::vector<std::string>({dequantize_in_node->Name()}));
deq_desc.SetOutput("Output", std::vector<std::string>({op_out->Name()}));
deq_desc.SetAttr("Scale", 1.0f);
deq_desc.SetAttr("Shift", 0.0f);
auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied.

for (auto name = output_names.begin(); name < output_names.end(); name++)
op->Op()->SetOutput(
*name, std::vector<std::string>({dequantize_in_node->Name()}));

UnlinkNodes(op, op_out);
IR_NODE_LINK_TO(op, dequantize_in_node);
IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
IR_NODE_LINK_TO(dequantize_op, op_out);

dequantize_counter++;
}
}

void AddDequantizes(Graph* g, ir::Node* op, int& dequantize_counter) {
auto outputs = op->outputs;
PADDLE_ENFORCE_GE(outputs.size(), 1,
platform::errors::InvalidArgument(
"OP(%s)'s outputs(%d) must be equal or greater than 1.",
op->Name(), outputs.size()));
PADDLE_ENFORCE_EQ(op->inputs.size(), 1,
platform::errors::InvalidArgument(
"OP(%s)'s inputs(%d) must be equal to 1.", op->Name(),
op->inputs.size()));

OpDesc deq_desc;
deq_desc.SetType("dequantize");

std::vector<Node*> dequantize_in_nodes(outputs.size());
std::vector<std::string> dequantize_in_node_names(outputs.size());

for (size_t i = 0; i < outputs.size(); i++) {
VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
dequantize_in_nodes[i] = g->CreateVarNode(&dequantize_in_desc);
dequantize_in_node_names[i] = dequantize_in_nodes[i]->Name();

deq_desc.SetInput("Input",
std::vector<std::string>({dequantize_in_node_names[i]}));
deq_desc.SetOutput("Output",
std::vector<std::string>({outputs[i]->Name()}));

deq_desc.SetAttr("Scale", 1.f);
deq_desc.SetAttr("Shift", 0.0f);
deq_desc.SetAttr("bfloat16", true);
deq_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout")
? op->Op()->GetAttr("data_layout")
: std::string("NCHW"));
auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied.

UnlinkNodes(op, outputs[i]);
IR_NODE_LINK_TO(op, dequantize_in_nodes[i]);
IR_NODE_LINK_TO(dequantize_in_nodes[i], dequantize_op);
IR_NODE_LINK_TO(dequantize_op, outputs[i]);

dequantize_counter++;
}

op->Op()->SetOutput("Out", dequantize_in_node_names);
}

// Operators like split have a single output name Out, which actually
// consists of multiple outputs. Such operators require a different way to find
// pattern and add dequantize ops.
void AddReoderAfterDuplicatedOutputs(ir::Graph* graph,
int& dequantize_counter) {
GraphPatternDetector gpd;
patterns::DuplicatedOutputs duplicated_outputs{gpd.mutable_pattern(),
"duplicated_outputs"};
duplicated_outputs();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, duplicated_outputs);
AddDequantizes(g, op, dequantize_counter);
};
gpd(graph, handler);
}

// Adding dequantize ops after all operators except split, which has
// already been handled in AddReoderAfterDuplicatedOutputs
void AddReoderAfterSingleOutputs(ir::Graph* graph, int& dequantize_counter) {
GraphPatternDetector gpd;
patterns::LastBfloat16Ops bfloat16_ops{gpd.mutable_pattern(),
"last_bfloat16_ops"};
bfloat16_ops();
int dequantize_counter = 0;

auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
GET_IR_NODE_FROM_SUBGRAPH(op_out, op_out, bfloat16_ops);

if (op->Op()->Type() != "prior_box") {
// Find the name of the output linking op to op_out
std::vector<std::string> output_names;
for (auto name : op->Op()->OutputNames())
for (auto output_name : op->Op()->Output(name))
if (output_name == op_out->Name() && IsPermittedOutputName(name))
output_names.push_back(name);

if (output_names.empty()) return;

VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);

OpDesc deq_desc;
deq_desc.SetType("dequantize");
deq_desc.SetInput("Input",
std::vector<std::string>({dequantize_in_node->Name()}));
deq_desc.SetOutput("Output", std::vector<std::string>({op_out->Name()}));
deq_desc.SetAttr("Scale", 1.0f);
deq_desc.SetAttr("Shift", 0.0f);
auto dequantize_op =
g->CreateOpNode(&deq_desc); // OpDesc will be copied.

for (auto name = output_names.begin(); name < output_names.end(); name++)
op->Op()->SetOutput(
*name, std::vector<std::string>({dequantize_in_node->Name()}));

UnlinkNodes(op, op_out);
IR_NODE_LINK_TO(op, dequantize_in_node);
IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
IR_NODE_LINK_TO(dequantize_op, op_out);

dequantize_counter++;
GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops);
if (op->Op()->Type() != "split") {
AddDequantize(g, op, op_out, dequantize_counter);
}
};
gpd(graph, handler);
}

void CPUBFloat16Pass::SetOutputDataType(ir::Graph* graph) const {
int dequantize_counter = 0;
AddReoderAfterDuplicatedOutputs(graph, dequantize_counter);
AddReoderAfterSingleOutputs(graph, dequantize_counter);
PrettyLogDetail("--- added %d dequantize ops after bfloat16 op",
dequantize_counter);
}
Expand Down
3 changes: 2 additions & 1 deletion paddle/pten/kernels/cpu/split_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,5 @@ PT_REGISTER_KERNEL(split,
int64_t,
int,
bool,
pten::dtype::float16) {}
pten::dtype::float16,
jakpiase marked this conversation as resolved.
Show resolved Hide resolved
pten::dtype::bfloat16) {}