
[cherry-pick] release/2.3 elementwise_mul and matmul mkldnn fix (#43725)
* Correct elementwise quantization (#43693)

* [Bug fix] Do not quantize the weight Y when both matmul inputs X and Y are outputs of other ops (#43297)

* fix matmul ops whose X and Y are both outputs of other ops: do not dequantize Y.

* fix CI format

* fix according to review

Co-authored-by: joanna.wozna.intel <joanna.wozna@intel.com>
lidanqing-intel and wozna authored Jun 23, 2022
1 parent d0bbf46 commit a7e0cde
Showing 4 changed files with 44 additions and 13 deletions.
14 changes: 14 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2069,6 +2069,20 @@ PDNode *patterns::Elementwise::operator()(PDNode *x_var, PDNode *y_var,
  return out_var;
}

PDNode *patterns::ElementwiseOp::operator()(
    const std::string elementwise_type) {
  auto elementwise_op =
      pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type);

  auto out_var = pattern->NewNode(elementwise_out_repr())
                     ->AsOutput()
                     ->assert_is_op_output(elementwise_type, "Out");

  elementwise_op->LinksTo({out_var});

  return out_var;
}

PDNode *patterns::ResidualElementwise::operator()(
    PDNode *op_var, PDNode *residual_var, const std::string elementwise_type,
    bool as_x) {
13 changes: 13 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1071,6 +1071,19 @@ struct Elementwise : public PatternBase {
  PATTERN_DECL_NODE(elementwise_out);
};

// Elementwise ops
// Forward pass for element-wise operators
// elementwise_out is the result of the operator
struct ElementwiseOp : public PatternBase {
  ElementwiseOp(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "elementwise") {}

  PDNode* operator()(const std::string elementwise_type);

  PATTERN_DECL_NODE(elementwise_op);
  PATTERN_DECL_NODE(elementwise_out);
};

// Residual Elementwise ops
// This pattern allows operator output to be X or Y
// and residual data Y or X, based on as_x flag
23 changes: 14 additions & 9 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
@@ -819,12 +819,9 @@ void CPUQuantizePass::QuantizeElementwise(
    Graph* graph, const std::string elementwise_type) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
-  patterns::Elementwise elementwise_pattern{pattern, name_scope_};
+  patterns::ElementwiseOp elementwise_pattern{pattern, name_scope_};

-  elementwise_pattern(
-      pattern->NewNode(elementwise_pattern.elementwise_x_repr()),
-      pattern->NewNode(elementwise_pattern.elementwise_y_repr()),
-      elementwise_type);
+  elementwise_pattern(elementwise_type);

  int quantize_elementwise_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
@@ -839,10 +836,18 @@
      return;
    }

-    GET_IR_NODE_FROM_SUBGRAPH(elementwise_x, elementwise_x,
-                              elementwise_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(elementwise_y, elementwise_y,
-                              elementwise_pattern);
+    auto x_name = elementwise_op->Op()->Input("X");
+    auto y_name = elementwise_op->Op()->Input("Y");
+    Node *elementwise_x = nullptr, *elementwise_y = nullptr;
+
+    for (auto& input : elementwise_op->inputs) {
+      if (input->Name() == x_name[0]) elementwise_x = input;
+      if (input->Name() == y_name[0]) elementwise_y = input;
+    }
+    if (!elementwise_x || !elementwise_y) {
+      return;
+    }
+
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
                              elementwise_pattern);

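For readers unfamiliar with the pattern API, the replacement logic above no longer asks the matched pattern for the X and Y variable nodes; it resolves them by name from the op node's input list and skips the op if either input cannot be found. Below is a minimal standalone sketch of that lookup, using hypothetical stand-in types (FakeNode, FakeOp, and ResolveInputs are illustrative names, not Paddle's actual ir::Node/OpDesc API):

#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-ins for illustration only; NOT Paddle's real IR classes.
struct FakeNode {
  std::string name;
};

struct FakeOp {
  std::vector<std::string> x_names;     // names bound to input slot "X"
  std::vector<std::string> y_names;     // names bound to input slot "Y"
  std::vector<const FakeNode*> inputs;  // graph nodes wired into the op
};

// Mirrors the idea of the new handler code: find the X and Y input nodes by
// name, and report failure when either cannot be resolved.
bool ResolveInputs(const FakeOp& op, const FakeNode** x, const FakeNode** y) {
  *x = nullptr;
  *y = nullptr;
  for (const FakeNode* input : op.inputs) {
    if (input->name == op.x_names[0]) *x = input;
    if (input->name == op.y_names[0]) *y = input;
  }
  return *x != nullptr && *y != nullptr;
}

int main() {
  FakeNode a{"a"}, b{"b"};
  FakeOp mul{{"a"}, {"b"}, {&a, &b}};
  const FakeNode *x = nullptr, *y = nullptr;
  std::cout << (ResolveInputs(mul, &x, &y) ? "both inputs found" : "skip op")
            << "\n";
  return 0;
}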
7 changes: 3 additions & 4 deletions paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc
@@ -354,10 +354,9 @@ bool QuantDequantMkldnnPass::IsInt8Weight(
  auto* op_desc = op_node->Op();
  auto var_name = op_desc->Input(weight_name)[0];
  auto* var = scope->FindVar(var_name);
-  PADDLE_ENFORCE_NOT_NULL(
-      var, platform::errors::NotFound(
-               "The input persistable [%s] var of [%s] op is not found.",
-               var_name, op_desc->Type()));
+  if (var == nullptr) {
+    return false;
+  }
  auto* weight_tensor = var->GetMutable<LoDTensor>();
  auto* weight_data = weight_tensor->data<float>();
  bool is_int8 = true;
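The intent of the change above: when the weight variable is not found in the scope, which is the case when, for example, matmul's Y is produced by another op rather than loaded as a persistable weight, IsInt8Weight now answers "not an int8 weight" instead of failing a hard PADDLE_ENFORCE check, so the pass simply does not dequantize Y. A minimal standalone sketch of that control flow follows, with a hypothetical stand-in scope (FakeScope and IsInt8WeightSketch are illustrative names, not Paddle's real Scope/Variable API; the actual per-element check is elided in the hunk above, so a placeholder check is used here):

#include <cmath>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in: maps variable names to float weight data.
using FakeScope = std::unordered_map<std::string, std::vector<float>>;

// Sketch of the new behaviour: a missing variable is no longer an error; it
// means the input is another op's output, so it is not a persistable int8
// weight and the caller should not try to dequantize it.
bool IsInt8WeightSketch(const FakeScope& scope, const std::string& var_name) {
  auto it = scope.find(var_name);
  if (it == scope.end()) return false;  // was: PADDLE_ENFORCE_NOT_NULL(...)
  for (float v : it->second) {
    // Placeholder for the real check (not shown in the truncated hunk):
    // here we simply treat "all values are whole numbers" as the criterion.
    if (v != std::round(v)) return false;
  }
  return true;
}

int main() {
  FakeScope scope{{"conv_w", {1.f, -3.f, 7.f}}};
  std::cout << IsInt8WeightSketch(scope, "conv_w") << "\n";        // 1
  std::cout << IsInt8WeightSketch(scope, "matmul_y_out") << "\n";  // 0: not in scope
  return 0;
}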
