From beb950eb66308eeaa8c60e4db9a006948e2ba7bb Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Wed, 2 Dec 2020 14:40:17 -0800 Subject: [PATCH] Fuse MatMulIntegerToFloat only when scales are scalar (#6008) MatMulIntegerToFloat fusion fuses per-row and per-column MatMulInteger, which is not supported by the MatMulIntegerToFloat kernel now. Limit the fusion to per-matrix only before we supporting the per-channel fully. --- .../core/optimizer/matmul_integer_to_float.cc | 9 ++++++++- onnxruntime/core/optimizer/utils.cc | 2 +- onnxruntime/core/optimizer/utils.h | 5 ++++- .../test/optimizer/graph_transform_test.cc | 6 +++--- .../fusion/matmul_integer_to_float.onnx | Bin 1520 -> 1922 bytes .../fusion/matmul_integer_to_float.py | 5 +++++ 6 files changed, 21 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/optimizer/matmul_integer_to_float.cc b/onnxruntime/core/optimizer/matmul_integer_to_float.cc index 630455c62161b..6aa473cb72417 100644 --- a/onnxruntime/core/optimizer/matmul_integer_to_float.cc +++ b/onnxruntime/core/optimizer/matmul_integer_to_float.cc @@ -34,7 +34,7 @@ static bool CheckBiasShape(const TensorShapeProto* bias_shape) { /** MatMulIntegerToFloatFusion will fuse subgraph like below into MatMulIntegerToFloat: - A A_Zero B B_Zero A_Scale) B_Scale Bias (Const, Optional) + A A_Zero B B_Zero A_Scale B_Scale Bias (Const, Optional) \ | | / \ / | \ | | / \ / | \ | | / \ / | @@ -84,6 +84,13 @@ Status MatMulIntegerToFloatFusion::ApplyImpl(Graph& graph, bool& modified, int g continue; } + // A_Scale is scalar and B_Scale is scalar or 1D tensor + auto mul_node_input_defs = p_mul_node_right->InputDefs(); + if (!optimizer_utils::IsScalar(*mul_node_input_defs[0]) || + !optimizer_utils::IsScalar(*mul_node_input_defs[1])) { + continue; + } + Node& cast_node = *graph.GetNode(p_cast_node->Index()); Node& matmulinteger_node = *graph.GetNode(p_matmulinteger_node->Index()); Node& mul_node_right = *graph.GetNode(p_mul_node_right->Index()); diff --git a/onnxruntime/core/optimizer/utils.cc b/onnxruntime/core/optimizer/utils.cc index 39d178ae07638..fb658bfb848eb 100644 --- a/onnxruntime/core/optimizer/utils.cc +++ b/onnxruntime/core/optimizer/utils.cc @@ -24,7 +24,7 @@ bool IsFloatingPointDataType(const ONNX_NAMESPACE::TensorProto& tensor_proto) { return tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT || tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 || tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE; } -inline bool IsScalar(const NodeArg& input_arg) { +bool IsScalar(const NodeArg& input_arg) { auto shape = input_arg.Shape(); if (shape == nullptr) { // shape inferencing wasn't able to populate shape information for this NodeArg diff --git a/onnxruntime/core/optimizer/utils.h b/onnxruntime/core/optimizer/utils.h index 535edc9c77488..2d1025f88cbc6 100644 --- a/onnxruntime/core/optimizer/utils.h +++ b/onnxruntime/core/optimizer/utils.h @@ -15,6 +15,9 @@ namespace optimizer_utils { // Check if TensorProto contains a floating point type. bool IsFloatingPointDataType(const ONNX_NAMESPACE::TensorProto& tensor_proto); +// Check if input is a scalar +bool IsScalar(const NodeArg& input_arg); + /** Check whether a input is initializer with specified float value. @param expected_value is the expected value of the initializer. @param is_constant means whether the initializer is required to be constant. @@ -60,7 +63,7 @@ bool ValidateShape(const NodeArg& node_arg, const std::initializer_list */ bool CompareShape(const ONNX_NAMESPACE::TensorShapeProto& node_arg_shape, const ONNX_NAMESPACE::TensorShapeProto& node_arg_other_shape); -/** Check check whether each dimension is known for shape of node_arg +/** Check whether each dimension is known for shape of node_arg @returns false when shape is nullptr, or total dimension is not same as expected_dim_size length, or any dim is unknown (without dim value). */ diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 63bad72be407a..ec3147fd6e324 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -3069,9 +3069,9 @@ TEST_F(GraphTransformationTests, MatMulIntegerToFloatTest) { std::map op_to_count = CountOpsInGraph(graph); EXPECT_EQ(op_to_count["DynamicQuantizeLinear"], 1); - EXPECT_EQ(op_to_count["MatMulInteger"], 0); - EXPECT_EQ(op_to_count["Cast"], 0); - EXPECT_EQ(op_to_count["Mul"], 0); + EXPECT_EQ(op_to_count["MatMulInteger"], 1); + EXPECT_EQ(op_to_count["Cast"], 1); + EXPECT_EQ(op_to_count["Mul"], 2); EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 3); EXPECT_EQ(op_to_count["Add"], 1); } diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.onnx index 7ea69c580ee435be09f12b949f14fdb2efe3d403..a88a4bd6817605d0a41203ea1dfc76527f00b641 100644 GIT binary patch delta 284 zcmeys-Netz!EW`FZzHQQ%ftv-eiJU1#Q3TLF192f8E-QAA)}a(i4?zYVu^2Qj%Qv; zYIyvf8hDl#TQJh`PgC7A^|nIIu4ZXhqdC^J2y1StPTR>4F_G&ivXB$!`XQczkF zpO%xK2$YjzOHM2X(vu%EYpR=Yp%@~>0aa@v1rm=>$}cI&&y6?PY|N6%7!W1T#hVmg zSelqul3A6S5^o~J$Hl|JB*ekR#K8o_%uymRcMEZGaR6mlz@p+@oJsM;$%#3sKrJ8% dMu^1Z8LT0bFR*fpLJVRA8^jFK;>0At0|1dMOTqvE delta 28 jcmZqT|G>@4!EW`Odn2nc%jP)N1jfmH>>(^pOaeRrZvY07 diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.py b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.py index 1a270043baa65..293a45cb48383 100644 --- a/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.py +++ b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.py @@ -29,6 +29,7 @@ def GenerateModel(model_name): nodes.extend(MakeSubGraph("_1", True)) nodes.extend(MakeSubGraph("_2", True)) nodes.extend(MakeSubGraph("_3", False)) + nodes.extend(MakeSubGraph("_4", False)) initializers = [] initializers.extend(MakeInitializer("_1")) @@ -48,11 +49,15 @@ def GenerateModel(model_name): helper.make_tensor_value_info('b_quantized_2', TensorProto.UINT8, [2, 3]), helper.make_tensor_value_info('b_zp_2', TensorProto.UINT8, [1]), helper.make_tensor_value_info('b_scale_2', TensorProto.FLOAT, [1]), + helper.make_tensor_value_info('b_quantized_4', TensorProto.UINT8, [2, 3]), + helper.make_tensor_value_info('b_zp_4', TensorProto.UINT8, [3]), + helper.make_tensor_value_info('b_scale_4', TensorProto.FLOAT, [3]), ], [ # outputs helper.make_tensor_value_info('output_1', TensorProto.FLOAT, [3, 3]), helper.make_tensor_value_info('output_2', TensorProto.FLOAT, [3, 3]), helper.make_tensor_value_info('output_3', TensorProto.FLOAT, [3, 3]), + helper.make_tensor_value_info('output_4', TensorProto.FLOAT, [3, 3]), ], initializers)