From dd05785565b89d88e2caf84bac44192ff16e43ca Mon Sep 17 00:00:00 2001
From: Alexey Voronov
Date: Thu, 17 Feb 2022 03:37:55 +0300
Subject: [PATCH 1/3] Add a conversion of individual operations in FQ2I pass.

---
 include/tvm/relay/dataflow_matcher.h | 11 +
 include/tvm/relay/qnn/op/dequantize.h | 39 +++
 .../transform/fake_quantization_to_integer.py | 10 +
 .../fake_quantization_to_integer.cc | 254 +++++++++++++++++-
 src/relay/transforms/type_infer.cc | 7 +-
 .../test_pass_fake_quantization_to_integer.py | 215 +++++++++++++++
 6 files changed, 528 insertions(+), 8 deletions(-)
 create mode 100644 include/tvm/relay/qnn/op/dequantize.h

diff --git a/include/tvm/relay/dataflow_matcher.h b/include/tvm/relay/dataflow_matcher.h
index 10e461645c8b..8dd5fbdd5eac 100644
--- a/include/tvm/relay/dataflow_matcher.h
+++ b/include/tvm/relay/dataflow_matcher.h
@@ -106,6 +106,17 @@ Expr RewritePatterns(Array callbacks, Expr expr, IRModule mod
 */
 Expr PartitionPattern(DFPattern pattern, Expr expr, Map attrs, PackedFunc check);
+/*!
+ * \brief Infer the type of an expression.
+ *
+ * \param expr The expression whose type is to be inferred
+ *
+ * \return An Expr with unambiguous type information filled in, as well as its
+ * checked type field populated with the result type.
+ *
+ */
+Expr InferType(const Expr& expr);
+
 } // namespace relay
 } // namespace tvm
diff --git a/include/tvm/relay/qnn/op/dequantize.h b/include/tvm/relay/qnn/op/dequantize.h
new file mode 100644
index 000000000000..469751726eb1
--- /dev/null
+++ b/include/tvm/relay/qnn/op/dequantize.h
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/qnn/op/dequantize.h
+ * \brief Relay dequantize.
+ */ +#ifndef TVM_RELAY_QNN_OP_DEQUANTIZE_H_ +#define TVM_RELAY_QNN_OP_DEQUANTIZE_H_ + +#include + +namespace tvm { +namespace relay { +namespace qnn { + +Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis); + +} // namespace qnn +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_QNN_OP_DEQUANTIZE_H_ \ No newline at end of file diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index e84ba5557a70..ce0177904b87 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -130,6 +130,16 @@ def global_avgpool2d(expr, type_map): return [out, t] +@register_fake_quantization_to_integer("broadcast_to") +def broadcast_to(expr, type_map): + """Rewrite a broadcast_to op""" + arg = expr.args[0] + t = type_map[arg] + shape = expr.attrs.shape + out = relay.op.broadcast_to(arg, shape) + return [out, t] + + @register_fake_quantization_to_integer("rsqrt") def rsqrt(expr, type_map): """Rewrite a rsqrt op""" diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc index 4273fc29cec8..11e0064dd50a 100644 --- a/src/relay/transforms/fake_quantization_to_integer.cc +++ b/src/relay/transforms/fake_quantization_to_integer.cc @@ -25,9 +25,13 @@ #include "fake_quantization_to_integer.h" +#include +#include +#include #include #include #include +#include #include #include @@ -37,7 +41,8 @@ namespace relay { /* Description of FakeQuantizationToInteger * - * The purpose of this pass is to find regions of the graph that follow + * This pass consists of two parts, a basic one and an optional one. + * The purpose of the basic part is to find regions of the graph that follow * the general pattern: * * x w @@ -52,7 +57,7 @@ namespace relay { * * and convert them into subgraphs with actual integer operations on x and w * - * The pass does this via a multi-pass approach: + * The basic part does this via a multi-pass approach: * * The main pass is a MixedModeMutator that traverses the full graph searching for * quantize operations @@ -69,6 +74,24 @@ namespace relay { * * After the second and third passes run, the first pass replaces the quantize with the * rewritten subgraph and the processing continues + * + * The main idea of the optional part is to find and transform operations with dequantized inputs + * one by one individually. Only operations from the allowed list are allowed. For example, if on + * the above general pattern op2 is not registered with the FTVMFakeQuantizationToInteger + * attribute, op1 operation can still be converted. Converted pattern below: + * + * x w + * | | + * \ / + * op1 + * | + * dq + * | + * op2 + * | + * q + * + * The optional part works in the same multi-pass approach. 
*/ using ExprSet = std::unordered_set; @@ -270,8 +293,233 @@ class FakeQuantizationRewriter : public MixedModeMutator { const bool hard_fail_; }; +bool is_op_enabled_for_optional_fq2i(const CallNode* call_node) { + const Op op = Downcast(call_node->op); + static auto fqfq = Op::GetAttrMap("FTVMFakeQuantizationToInteger"); + static std::unordered_set ops = { + Op::Get("reshape"), + Op::Get("squeeze"), + Op::Get("strided_slice"), + Op::Get("transpose"), + Op::Get("expand_dims"), + Op::Get("nn.max_pool2d"), + Op::Get("nn.batch_flatten"), + Op::Get("nn.depth_to_space"), + Op::Get("max"), + Op::Get("min"), + Op::Get("nn.avg_pool2d"), + Op::Get("nn.global_avg_pool2d"), + Op::Get("nn.bias_add"), + Op::Get("nn.conv2d"), + Op::Get("nn.conv2d_transpose"), + Op::Get("nn.dense"), + Op::Get("nn.batch_matmul"), + Op::Get("split"), + Op::Get("clip"), + Op::Get("nn.relu"), + Op::Get("nn.pad"), + Op::Get("broadcast_to"), + Op::Get("minimum"), + Op::Get("maximum")}; + + auto is_enabled = [&](const auto i) { return i == call_node->op; }; + auto result = std::find_if(std::begin(ops), std::end(ops), is_enabled); + return result != ops.end() && fqfq.count(Downcast(op)); +} + +class OptionalSubgraphExtractor : public ExprVisitor { + public: + const ExprSet GetSubgraph(const Expr& expr) { + expr_call_node_ = expr.as(); + ICHECK(expr_call_node_ != nullptr); + ICHECK(is_op_enabled_for_optional_fq2i(expr_call_node_)); + + VisitExpr(expr); + + ExprSet subgraph; + if (is_fake_quantized_) { + for (auto kv : this->visit_counter_) { + if (auto call_node = GetRef(kv.first).as()) { + if (call_node != expr_call_node_) { + subgraph.insert(Downcast(GetRef(kv.first))); + } + } + } + } + return subgraph; + } + const AffineTypeMap GetAffineTypes() { return affine_types_; } + void VisitExpr(const Expr& expr) override { + // When looking for fake quantized subgraphs, we only support data-flow regions of the graph, + // i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we + // abort the rewrite. + if (expr.as() == nullptr && expr.as() == nullptr && + expr.as() == nullptr && expr.as() == nullptr && + expr.as() == nullptr) { + DLOG(INFO) << "FakeQuantizationToInteger found a non - dataflow op inside a fake quantize " + "region, aborting this rewrite"; + is_fake_quantized_ = false; + } else { + ExprVisitor::VisitExpr(expr); + } + } + + protected: + void VisitExpr_(const CallNode* call_node) override { + if (call_node->op == dequantize_op_) { + const auto* attrs = call_node->attrs.as(); + ICHECK(attrs != nullptr); + + affine_types_.Set( + GetRef(call_node), + TensorAffineType( + call_node->args[1], call_node->args[2], + tvm::relay::transform::InferTypeLocal(call_node->args[0]).as()->dtype, + attrs->axis)); + } else if (call_node == expr_call_node_) { + for (auto arg : call_node->args) { + VisitExpr(arg); + } + } else { + // run normally on everything else. 
+ ExprVisitor::VisitExpr_(call_node); + } + } + + const Op dequantize_op_ = Op::Get("qnn.dequantize"); + bool is_fake_quantized_ = true; + AffineTypeMap affine_types_; + const CallNode* expr_call_node_ = nullptr; +}; + +class OptionalSubgraphMutator : public ExprMutator { + public: + OptionalSubgraphMutator(ExprSet subgraph, AffineTypeMap affine_types, bool hard_fail) + : subgraph_(subgraph), affine_types_(affine_types), hard_fail_(hard_fail) {} + + Expr MutateSubgraph(const Expr& expr) { + if (subgraph_.size() == 0) { + return expr; + } + + quantize_node_ = expr.as(); + ICHECK(quantize_node_); + ICHECK(is_op_enabled_for_optional_fq2i(quantize_node_)); + + for (auto node : subgraph_) { + const Op op = Downcast(node.as()->op); + + if (node.as()->op != dequantize_op_) { + // Only modify the subgraph if we have translation + // rules for every op + if (hard_fail_) { + LOG(FATAL) << "Found no rewrite rule for " << AsText(op, false) << std::endl; + } else { + DLOG(INFO) << "Found no rewrite rule for " << AsText(op, false) << std::endl; + return expr; + } + } + } + try { + return Mutate(expr); + } catch (std::exception& e) { + if (hard_fail_) { + throw e; + } else { + DLOG(INFO) << "Ran into an error rewriting a subgraph, skipping" << expr << std::endl; + return expr; + } + } + } + + protected: + Expr VisitExpr_(const CallNode* call_node) { + Expr out; + static auto fqfq = + Op::GetAttrMap("FTVMFakeQuantizationToInteger"); + + Op op = Downcast(call_node->op); + if (fqfq.count(op)) { + Expr expr; + if (op == dequantize_op_) { + expr = GetRef(call_node); + } else { + expr = ExprMutator::VisitExpr_(call_node); + } + // Call the rewrite + Array vals = fqfq[op](expr, affine_types_); + // Save the outputs of the rewrite + ICHECK(vals.size() == 2) + << "got the wrong number of returned arguments from FTVMFakeQuantizationToInteger for " + << AsText(op, false); + out = Downcast(vals[0]); + + affine_types_.Set(out, Downcast(vals[1])); + + if (call_node == quantize_node_) { + out = qnn::MakeDequantize(out, vals[1].as()->scale, + vals[1].as()->zero_point, + vals[1].as()->axis); + } + } else { + ICHECK(false) << "When rewriting a fake quantized graph, found an invalid node " + << AsText(GetRef(call_node), false); + } + return out; + } + + Expr VisitExpr_(const TupleNode* node) { + Expr expr = ExprMutator::VisitExpr_(node); + auto new_node = expr.as(); + Array types; + for (Expr field : new_node->fields) { + ICHECK(affine_types_[field].as()); + types.push_back(Downcast(affine_types_[field])); + } + affine_types_.Set(expr, TupleAffineType(types)); + return expr; + } + + Expr VisitExpr_(const TupleGetItemNode* node) { + Expr expr = ExprMutator::VisitExpr_(node); + auto tuple_type = affine_types_[expr.as()->tuple].as(); + affine_types_.Set(expr, tuple_type->types[node->index]); + return expr; + } + + ExprSet subgraph_; + AffineTypeMap affine_types_; + const bool hard_fail_; + const Op dequantize_op_ = Op::Get("qnn.dequantize"); + const CallNode* quantize_node_ = nullptr; +}; + +class OptionalFakeQuantizationRewriter : public MixedModeMutator { + public: + explicit OptionalFakeQuantizationRewriter(bool hard_fail) : hard_fail_(hard_fail) {} + + protected: + Expr Rewrite_(const CallNode* pre, const Expr& post) override { + if (const CallNode* call_node = post.as()) { + const Op op = Downcast(call_node->op); + if (is_op_enabled_for_optional_fq2i(call_node)) { + OptionalSubgraphExtractor extractor; + ExprSet subgraph = extractor.GetSubgraph(post); + AffineTypeMap affine_types = extractor.GetAffineTypes(); + Expr out 
= OptionalSubgraphMutator(subgraph, affine_types, hard_fail_).MutateSubgraph(post); + return out; + } + } + return post; + } + const bool hard_fail_; +}; + Expr FakeQuantizationToInteger(const Expr& expr, const IRModule& mod, bool hard_fail) { - return FakeQuantizationRewriter(hard_fail).Mutate(expr); + auto fq_expr = FakeQuantizationRewriter(hard_fail).Mutate(expr); + auto fq_inferred_expr = tvm::relay::InferType(fq_expr); + auto ofq_expr = OptionalFakeQuantizationRewriter(hard_fail).Mutate(fq_inferred_expr); + return ofq_expr; } namespace transform { diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 456e210f7343..59b96ea1481b 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -41,6 +41,7 @@ #include #include #include +#include #include #include #include @@ -918,11 +919,7 @@ Type InferTypeLocal(const Expr& expr) { mod = transform::InferType()(mod); Type result_type; - if (expr.as()) { - result_type = mod->Lookup("main")->checked_type(); - } else { - result_type = mod->Lookup("main").as()->body->checked_type(); - } + result_type = relay::InferType(sub_graph)->checked_type(); expr->checked_type_ = result_type; return result_type; diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index 2aeb8e3bd554..b2752dd38831 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -17,6 +17,7 @@ # pylint: disable=unused-wildcard-import import numpy as np import pytest +import sys import tvm from tvm import relay from tvm.relay.transform import fake_quantization_to_integer @@ -626,3 +627,217 @@ def conv2d(expr, type_map): # pylint: disable=unused-variable # Catch a generic exception because the tvm FFI eats the python exception type with pytest.raises(Exception): mod_int = tvm.relay.transform.FakeQuantizationToInteger(hard_fail=True)(mod) + + +def compare_expected_fq_to_int(expr, expected_expr, args, allow_rounding_error=False): + mod = tvm.IRModule.from_expr(expr) + mod_def = tvm.relay.transform.InferType()(mod) + mod_int = tvm.relay.transform.FakeQuantizationToInteger(False)(mod_def) + mod_exp = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(expected_expr)) + print("mod_def\n", mod_def, "\n") + print("mod_int\n", mod_int, "\n") + print("mod_exp\n", mod_exp, "\n") + assert not tvm.ir.structural_equal(mod, mod_int) + assert tvm.ir.structural_equal(mod_int, mod_exp) + result_def = ( + relay.create_executor("vm", mod=mod_def, device=tvm.cpu(), target="llvm") + .evaluate()(*args) + .numpy() + ) + result_int = ( + relay.create_executor("vm", mod=mod_int, device=tvm.cpu(), target="llvm") + .evaluate()(*args) + .numpy() + ) + result_exp = ( + relay.create_executor("vm", mod=mod_exp, device=tvm.cpu(), target="llvm") + .evaluate()(*args) + .numpy() + ) + print("result_def\n", result_def) + print("result_int\n", result_int) + print("result_exp\n", result_exp) + if allow_rounding_error: + assert np.all(np.abs(result_def.astype("int32") - result_int.astype("int32")) <= 1) + else: + assert np.array_equal(result_def, result_int) + + assert np.array_equal(result_int, result_exp) + + +def test_fq2i_optional_op_chaind_with_disabled_op(): + shape_x = [1, 4, 2] + shape_w = [1, 4, 2] + a = relay.var("a", shape=shape_x, dtype="int8") + b = relay.var("b", shape=shape_w, dtype="int8") + + op0 = relay.qnn.op.dequantize(a, relay.const(2.0), relay.const(0)) + op1 = 
relay.qnn.op.dequantize(b, relay.const(6.0), relay.const(0)) + op2 = relay.op.nn.batch_matmul(op0, op1) + op3 = relay.op.add(op2, relay.const(1.0)) + expr = relay.op.erf(op3) + + op0 = relay.qnn.op.qnn.batch_matmul( + a, b, relay.const(0), relay.const(0), relay.const(2.0), relay.const(6.0) + ) + op1 = relay.qnn.op.qnn.dequantize(op0, relay.const(12.0), relay.const(0)) + op2 = relay.op.add(op1, relay.const(1.0)) + expected_expr = relay.op.erf(op2) + + x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8") + w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8") + compare_expected_fq_to_int(expr, expected_expr, [x_np, w_np], False) + + +def test_fq2i_optional_negative(): + shape_x = [1, 4, 2] + shape_w = [1, 4, 2] + a = relay.var("a", shape=shape_x, dtype="int8") + b = relay.var("b", shape=shape_w, dtype="int8") + + op0 = relay.qnn.op.dequantize(a, relay.const(2.0), relay.const(0)) + op1 = relay.qnn.op.dequantize(b, relay.const(6.0), relay.const(0)) + op2 = relay.op.add(op1, relay.const(1.0)) + op3 = relay.op.nn.batch_matmul(op0, op2) + expr = relay.op.erf(op3) + + expected_expr = expr + + x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8") + w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8") + compare_expected_fq_to_int(expr, expected_expr, [x_np, w_np], False) + + +def test_fq2i_optional_args(): + # pron one + shape_x = [1, 4, 2] + shape_w = [1, 4, 2] + a = relay.var("a", shape=shape_x, dtype="int8") + b = relay.var("b", shape=shape_w, dtype="int8") + + op0 = relay.qnn.op.dequantize(a, relay.const(2.0), relay.const(0)) + op1 = relay.qnn.op.dequantize(b, relay.const(6.0), relay.const(0)) + expr = relay.op.nn.batch_matmul(op0, op1) + + op0 = relay.qnn.op.qnn.batch_matmul( + a, b, relay.const(0), relay.const(0), relay.const(2.0), relay.const(6.0) + ) + expected_expr = relay.qnn.op.qnn.dequantize(op0, relay.const(12.0), relay.const(0)) + + x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8") + w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8") + compare_expected_fq_to_int(expr, expected_expr, [x_np, w_np], False) + + +def test_fq2i_optional_idly(): + shape_x = [1, 4, 2] + shape_w = [1, 4, 2] + a = relay.var("a", shape=shape_x, dtype="int8") + b = relay.var("b", shape=shape_w, dtype="int8") + + op0 = relay.qnn.op.dequantize(a, relay.const(2.0), relay.const(0)) + op1 = relay.qnn.op.dequantize(b, relay.const(6.0), relay.const(0)) + op2 = relay.op.nn.batch_matmul(op0, op1) + op3 = relay.op.add(op2, relay.const(1.0)) + expr = relay.qnn.op.quantize(op3, relay.const(1.0), relay.const(0), out_dtype="int8") + + op0 = relay.qnn.op.batch_matmul( + a, b, relay.const(0), relay.const(0), relay.const(2.0), relay.const(6.0) + ) + op1 = relay.qnn.op.quantize( + relay.const(1.0), relay.const(12.0), relay.const(0), out_dtype="int32" + ) + op2 = relay.qnn.op.add( + op0, + op1, + relay.const(12.0), + relay.const(0), + relay.const(12.0), + relay.const(0), + relay.const(12.0), + relay.const(0), + ) + expected_expr = relay.qnn.op.requantize( + op2, relay.const(12.0), relay.const(0), relay.const(1.0), relay.const(0), out_dtype="int8" + ) + + x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8") + w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8") + compare_expected_fq_to_int(expr, expected_expr, [x_np, w_np], False) + + +def test_fq2i_optional_op_chain(): + shape_x = [1, 2, 4] + shape_w = [2] + a = relay.var("a", shape=shape_x, dtype="int8") + b = relay.var("b", shape=shape_w, dtype="int8") + + op0 = relay.qnn.op.dequantize(a, 
relay.const(2.0), relay.const(0)) + op1 = relay.qnn.op.dequantize(b, relay.const(6.0), relay.const(0)) + op2 = relay.op.reshape(op0, (1, 4, 2)) + op3 = relay.op.broadcast_to(op1, (2, 2, 2)) + op4 = relay.op.nn.batch_matmul(op2, op3) + expr = relay.op.erf(op4) + + op0 = relay.op.reshape(a, (1, 4, 2)) + op1 = relay.op.broadcast_to(b, (2, 2, 2)) + op3 = relay.qnn.op.qnn.batch_matmul( + op0, op1, relay.const(0), relay.const(0), relay.const(2.0), relay.const(6.0) + ) + op4 = relay.qnn.op.qnn.dequantize(op3, relay.const(12.0), relay.const(0)) + expected_expr = relay.op.erf(op4) + + x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8") + w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8") + compare_expected_fq_to_int(expr, expected_expr, [x_np, w_np], False) + + +def test_fq2i_optional_one_arg(): + shape_x = [1, 2, 4] + a = relay.var("a", shape=shape_x, dtype="int8") + + op0 = relay.qnn.op.dequantize(a, relay.const(2.0), relay.const(0)) + + op1 = relay.op.reshape(op0, (1, 4, 2)) + expr = relay.op.erf(op1) + + op0 = relay.op.reshape(a, (1, 4, 2)) + op1 = relay.qnn.op.dequantize(op0, relay.const(2.0), relay.const(0)) + expected_expr = relay.op.erf(op1) + x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8") + compare_expected_fq_to_int(expr, expected_expr, [x_np], False) + + +def test_fq2i_optional_intermediate_infertype(): + shape_x = [1, 2, 4] + x = relay.var("x", shape=shape_x, dtype="float32") + const_0 = relay.const(np.random.uniform(size=[1, 4, 2]).astype("float32")) + + op0 = relay.qnn.op.quantize(x, relay.const(17.0), relay.const(0), out_dtype="int8") + op1 = relay.qnn.op.dequantize(op0, relay.const(17.0), relay.const(0)) + op2 = relay.op.reshape(op1, (1, 4, 2)) + op3 = relay.qnn.op.quantize(op2, relay.const(10.0), relay.const(0), out_dtype="int8") + op4 = relay.qnn.op.quantize(const_0, relay.const(1.0), relay.const(8), out_dtype="int8") + op5 = relay.qnn.op.dequantize(op3, relay.const(10.0), relay.const(0)) + op6 = relay.qnn.op.dequantize(op4, relay.const(4.0), relay.const(9)) + op7 = relay.op.nn.batch_matmul(op5, op6) + expr = relay.op.add(op7, relay.const(5.0)) + + op0 = relay.qnn.op.quantize(x, relay.const(17.0), relay.const(0), out_dtype="int8") + op1 = relay.op.reshape(op0, (1, 4, 2)) + op2 = relay.qnn.op.requantize( + op1, relay.const(17.0), relay.const(0), relay.const(10.0), relay.const(0), out_dtype="int8" + ) + op3 = relay.qnn.op.quantize(const_0, relay.const(1.0), relay.const(8), out_dtype="int8") + op4 = relay.qnn.op.batch_matmul( + op2, op3, relay.const(0), relay.const(9), relay.const(10.0), relay.const(4.0) + ) + op5 = relay.qnn.op.dequantize(op4, relay.const(40.0), relay.const(0)) + expected_expr = relay.op.add(op5, relay.const(5.0)) + + x_np = np.random.randint(-128, 127, size=shape_x, dtype="int32").astype("float32") + compare_expected_fq_to_int(expr, expected_expr, [x_np], False) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) From aa378d2e96107c67c236bc9feeebec191f0de038 Mon Sep 17 00:00:00 2001 From: Alexey Voronov Date: Thu, 17 Feb 2022 03:39:40 +0300 Subject: [PATCH 2/3] apply review comments --- include/tvm/relay/qnn/op/dequantize.h | 39 -------- python/tvm/relay/transform/transform.py | 24 ++++- .../fake_quantization_to_integer.cc | 97 +++++++++++-------- src/relay/transforms/type_infer.cc | 2 - .../test_pass_fake_quantization_to_integer.py | 49 +++++----- 5 files changed, 102 insertions(+), 109 deletions(-) delete mode 100644 include/tvm/relay/qnn/op/dequantize.h diff --git 
a/include/tvm/relay/qnn/op/dequantize.h b/include/tvm/relay/qnn/op/dequantize.h deleted file mode 100644 index 469751726eb1..000000000000 --- a/include/tvm/relay/qnn/op/dequantize.h +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/relay/executor.h - * \brief Relay dequantize. - */ -#ifndef TVM_RELAY_QNN_OP_DEQUANTIZE_H_ -#define TVM_RELAY_QNN_OP_DEQUANTIZE_H_ - -#include - -namespace tvm { -namespace relay { -namespace qnn { - -Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis); - -} // namespace qnn -} // namespace relay -} // namespace tvm - -#endif // TVM_RELAY_QNN_OP_DEQUANTIZE_H_ \ No newline at end of file diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 14f1caae06eb..99c61c5bd96f 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1238,7 +1238,7 @@ def AnnotateSpans(): return _ffi_api.AnnotateSpans() -def FakeQuantizationToInteger(hard_fail=False): +def FakeQuantizationToInteger(hard_fail=False, use_qat=False): # pylint: disable=anomalous-backslash-in-string """ Find regions of the graph of the form @@ -1267,12 +1267,30 @@ def FakeQuantizationToInteger(hard_fail=False): If true, raise an error. If false, skip rewriting the subgraph. + use_qat : boolean + To perform an additional QAT pass - convert enabled operations with dequantized inputs. + Example: in the graph above op2 is not registered with the FakeQuantizationToInteger + attribute, op1 operation can still be converted. Converted pattern below: + + .. code-block:: text + + x w + | | + \\ / + op1 + | + dq + | + op2 + | + q + Returns ------- ret : tvm.transform.Pass - The registered SimplifyExpr pass. + The registered FakeQuantizationToInteger pass. """ - return _ffi_api.FakeQuantizationToInteger(hard_fail) + return _ffi_api.FakeQuantizationToInteger(hard_fail, use_qat) def ToMixedPrecision(mixed_precision_type="float16", missing_op_mode=1): diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc index 11e0064dd50a..1c7b6b6c37cb 100644 --- a/src/relay/transforms/fake_quantization_to_integer.cc +++ b/src/relay/transforms/fake_quantization_to_integer.cc @@ -31,18 +31,18 @@ #include #include #include -#include #include #include +#include "../qnn/utils.h" + namespace tvm { namespace relay { /* Description of FakeQuantizationToInteger * - * This pass consists of two parts, a basic one and an optional one. 
- * The purpose of the basic part is to find regions of the graph that follow
+ * The purpose of this pass is to find regions of the graph that follow
  * the general pattern:
  *
  * x w
  * | |
  * dq dq
  * \ /
  * op1
  * |
  * op2
  * |
  * q
  *
@@ -57,7 +57,7 @@ namespace relay {
  *
  * and convert them into subgraphs with actual integer operations on x and w
  *
- * The basic part does this via a multi-pass approach:
+ * The pass does this via a multi-pass approach:
  *
  * The main pass is a MixedModeMutator that traverses the full graph searching for
  * quantize operations
@@ -75,10 +75,14 @@ namespace relay {
  * After the second and third passes run, the first pass replaces the quantize with the
  * rewritten subgraph and the processing continues
  *
- * The main idea of the optional part is to find and transform operations with dequantized inputs
- * one by one individually. Only operations from the allowed list are allowed. For example, if on
- * the above general pattern op2 is not registered with the FTVMFakeQuantizationToInteger
- * attribute, op1 operation can still be converted. Converted pattern below:
+ *
+ * After that, an additional QAT pass can be enabled by the use_qat flag. The goal of the pass is to
+ * find operations in those regions (which were not successfully converted by the main pass) that
+ * can still be converted into quantized form. The idea is to find and transform operations with
+ * dequantized inputs one by one individually. Only operations for which all parameters can be
+ * explicitly calculated are allowed. For example, if in the above general pattern op2 is not
+ * registered with the FTVMFakeQuantizationToInteger attribute, the op1 operation can still be
+ * converted. Converted pattern below:
  *
  * x w
  * | |
  * \ /
  * op1
  * |
  * dq
  * |
  * op2
  * |
  * q
  *
- * The optional part works in the same multi-pass approach.
+ * This pass works in the same multi-pass approach.
  */
 using ExprSet = std::unordered_set;
@@ -293,41 +297,49 @@ class FakeQuantizationRewriter : public MixedModeMutator {
   const bool hard_fail_;
 };
+/* Checks whether the given operation is enabled for conversion by the QAT pass.
+ * The following conditions must be satisfied:
+ * 1. Operations registered for FTVMFakeQuantizationToInteger;
+ * 2. Unary operators, or operators with the TensorAffineType calculated during
+ * FTVMFakeQuantizationToInteger conversion;
+ * 3. Not one of the "key" operations: requantize, quantize and dequantize (they are at the boundaries
+ * of regions defined to be quantized).
+ */
 bool is_op_enabled_for_optional_fq2i(const CallNode* call_node) {
   const Op op = Downcast(call_node->op);
   static auto fqfq = Op::GetAttrMap("FTVMFakeQuantizationToInteger");
   static std::unordered_set ops = {
-      Op::Get("reshape"),
-      Op::Get("squeeze"),
-      Op::Get("strided_slice"),
-      Op::Get("transpose"),
+      Op::Get("broadcast_to"),
+      Op::Get("clip"),
       Op::Get("expand_dims"),
-      Op::Get("nn.max_pool2d"),
-      Op::Get("nn.batch_flatten"),
-      Op::Get("nn.depth_to_space"),
       Op::Get("max"),
+      Op::Get("maximum"),
       Op::Get("min"),
+      Op::Get("minimum"),
       Op::Get("nn.avg_pool2d"),
-      Op::Get("nn.global_avg_pool2d"),
+      Op::Get("nn.batch_flatten"),
+      Op::Get("nn.batch_matmul"),
       Op::Get("nn.bias_add"),
       Op::Get("nn.conv2d"),
       Op::Get("nn.conv2d_transpose"),
       Op::Get("nn.dense"),
-      Op::Get("nn.batch_matmul"),
-      Op::Get("split"),
-      Op::Get("clip"),
-      Op::Get("nn.relu"),
+      Op::Get("nn.depth_to_space"),
+      Op::Get("nn.global_avg_pool2d"),
+      Op::Get("nn.max_pool2d"),
       Op::Get("nn.pad"),
-      Op::Get("broadcast_to"),
-      Op::Get("minimum"),
-      Op::Get("maximum")};
+      Op::Get("nn.relu"),
+      Op::Get("reshape"),
+      Op::Get("split"),
+      Op::Get("squeeze"),
+      Op::Get("strided_slice"),
+      Op::Get("transpose")};
   auto is_enabled = [&](const auto i) { return i == call_node->op; };
   auto result = std::find_if(std::begin(ops), std::end(ops), is_enabled);
   return result != ops.end() && fqfq.count(Downcast(op));
 }
-class OptionalSubgraphExtractor : public ExprVisitor {
+class QATSubgraphExtractor : public ExprVisitor {
  public:
   const ExprSet GetSubgraph(const Expr& expr) {
     expr_call_node_ = expr.as();
@@ -392,9 +404,9 @@ class OptionalSubgraphExtractor : public ExprVisitor {
   const CallNode* expr_call_node_ = nullptr;
 };
-class OptionalSubgraphMutator : public ExprMutator {
+class QATSubgraphMutator : public ExprMutator {
  public:
-  OptionalSubgraphMutator(ExprSet subgraph, AffineTypeMap affine_types, bool hard_fail)
+  QATSubgraphMutator(ExprSet subgraph, AffineTypeMap affine_types, bool hard_fail)
       : subgraph_(subgraph), affine_types_(affine_types), hard_fail_(hard_fail) {}
   Expr MutateSubgraph(const Expr& expr) {
@@ -410,12 +422,12 @@ class OptionalSubgraphMutator : public ExprMutator {
       const Op op = Downcast(node.as()->op);
       if (node.as()->op != dequantize_op_) {
-        // Only modify the subgraph if we have translation
-        // rules for every op
         if (hard_fail_) {
-          LOG(FATAL) << "Found no rewrite rule for " << AsText(op, false) << std::endl;
+          LOG(FATAL) << "No dequantization was found in the input arguments for "
+                     << AsText(op, false) << std::endl;
         } else {
-          DLOG(INFO) << "Found no rewrite rule for " << AsText(op, false) << std::endl;
+          DLOG(INFO) << "No dequantization was found in the input arguments for "
+                     << AsText(op, false) << std::endl;
           return expr;
         }
       }
@@ -494,19 +506,19 @@ class OptionalSubgraphMutator : public ExprMutator {
   const CallNode* quantize_node_ = nullptr;
 };
-class OptionalFakeQuantizationRewriter : public MixedModeMutator {
+class QATRewriter : public MixedModeMutator {
  public:
-  explicit OptionalFakeQuantizationRewriter(bool hard_fail) : hard_fail_(hard_fail) {}
+  explicit QATRewriter(bool hard_fail) : hard_fail_(hard_fail) {}
 protected:
   Expr Rewrite_(const CallNode* pre, const Expr& post) override {
     if (const CallNode* call_node = post.as()) {
       const Op op = Downcast(call_node->op);
       if (is_op_enabled_for_optional_fq2i(call_node)) {
-        OptionalSubgraphExtractor extractor;
+        QATSubgraphExtractor extractor;
         ExprSet subgraph = extractor.GetSubgraph(post);
         AffineTypeMap affine_types = extractor.GetAffineTypes();
-        Expr out =
OptionalSubgraphMutator(subgraph, affine_types, hard_fail_).MutateSubgraph(post); + Expr out = QATSubgraphMutator(subgraph, affine_types, hard_fail_).MutateSubgraph(post); return out; } } @@ -515,19 +527,22 @@ class OptionalFakeQuantizationRewriter : public MixedModeMutator { const bool hard_fail_; }; -Expr FakeQuantizationToInteger(const Expr& expr, const IRModule& mod, bool hard_fail) { +Expr FakeQuantizationToInteger(const Expr& expr, const IRModule& mod, bool hard_fail, + bool use_qat) { auto fq_expr = FakeQuantizationRewriter(hard_fail).Mutate(expr); - auto fq_inferred_expr = tvm::relay::InferType(fq_expr); - auto ofq_expr = OptionalFakeQuantizationRewriter(hard_fail).Mutate(fq_inferred_expr); - return ofq_expr; + if (use_qat) { + fq_expr = tvm::relay::InferType(fq_expr); + fq_expr = QATRewriter(hard_fail).Mutate(fq_expr); + } + return fq_expr; } namespace transform { -Pass FakeQuantizationToInteger(bool hard_fail) { +Pass FakeQuantizationToInteger(bool hard_fail, bool use_qat) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { - return Downcast(FakeQuantizationToInteger(f, m, hard_fail)); + return Downcast(FakeQuantizationToInteger(f, m, hard_fail, use_qat)); }; return CreateFunctionPass(pass_func, 0, "FakeQuantizationToInteger", {"InferType"}); } diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 59b96ea1481b..7de43eb36882 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -915,8 +915,6 @@ Type InferTypeLocal(const Expr& expr) { */ SameTypedSubgraphExtractor subgraph_extractor; Expr sub_graph = subgraph_extractor(expr); - auto mod = IRModule::FromExpr(sub_graph); - mod = transform::InferType()(mod); Type result_type; result_type = relay::InferType(sub_graph)->checked_type(); diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index b2752dd38831..13eb86f83d79 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -17,7 +17,6 @@ # pylint: disable=unused-wildcard-import import numpy as np import pytest -import sys import tvm from tvm import relay from tvm.relay.transform import fake_quantization_to_integer @@ -629,14 +628,11 @@ def conv2d(expr, type_map): # pylint: disable=unused-variable mod_int = tvm.relay.transform.FakeQuantizationToInteger(hard_fail=True)(mod) -def compare_expected_fq_to_int(expr, expected_expr, args, allow_rounding_error=False): +def compare_expected_fq_qat_to_int(expr, expected_expr, args, allow_rounding_error=False): mod = tvm.IRModule.from_expr(expr) mod_def = tvm.relay.transform.InferType()(mod) - mod_int = tvm.relay.transform.FakeQuantizationToInteger(False)(mod_def) + mod_int = tvm.relay.transform.FakeQuantizationToInteger(False, True)(mod_def) mod_exp = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(expected_expr)) - print("mod_def\n", mod_def, "\n") - print("mod_int\n", mod_int, "\n") - print("mod_exp\n", mod_exp, "\n") assert not tvm.ir.structural_equal(mod, mod_int) assert tvm.ir.structural_equal(mod_int, mod_exp) result_def = ( @@ -654,9 +650,6 @@ def compare_expected_fq_to_int(expr, expected_expr, args, allow_rounding_error=F .evaluate()(*args) .numpy() ) - print("result_def\n", result_def) - print("result_int\n", result_int) - print("result_exp\n", result_exp) if allow_rounding_error: assert np.all(np.abs(result_def.astype("int32") - 
result_int.astype("int32")) <= 1) else: @@ -665,7 +658,8 @@ def compare_expected_fq_to_int(expr, expected_expr, args, allow_rounding_error=F assert np.array_equal(result_int, result_exp) -def test_fq2i_optional_op_chaind_with_disabled_op(): +def test_fq_qat_op_positive_part(): + # Only the first operation is converted, since the next operation("add") is not enabled. shape_x = [1, 4, 2] shape_w = [1, 4, 2] a = relay.var("a", shape=shape_x, dtype="int8") @@ -686,10 +680,11 @@ def test_fq2i_optional_op_chaind_with_disabled_op(): x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8") w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8") - compare_expected_fq_to_int(expr, expected_expr, [x_np, w_np], False) + compare_expected_fq_qat_to_int(expr, expected_expr, [x_np, w_np]) -def test_fq2i_optional_negative(): +def test_fq_qat_negative_all(): + # None of the operations are converted, since the first operation("add") is not enabled. shape_x = [1, 4, 2] shape_w = [1, 4, 2] a = relay.var("a", shape=shape_x, dtype="int8") @@ -705,11 +700,11 @@ def test_fq2i_optional_negative(): x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8") w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8") - compare_expected_fq_to_int(expr, expected_expr, [x_np, w_np], False) + compare_expected_fq_qat_to_int(expr, expected_expr, [x_np, w_np]) -def test_fq2i_optional_args(): - # pron one +def test_fq_qat_positive_single(): + # The single operation is converted. shape_x = [1, 4, 2] shape_w = [1, 4, 2] a = relay.var("a", shape=shape_x, dtype="int8") @@ -726,10 +721,11 @@ def test_fq2i_optional_args(): x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8") w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8") - compare_expected_fq_to_int(expr, expected_expr, [x_np, w_np], False) + compare_expected_fq_qat_to_int(expr, expected_expr, [x_np, w_np]) -def test_fq2i_optional_idly(): +def test_fq_qat_positive_nothing_to_do(): + # All operations are converted by the non-QAT pass. shape_x = [1, 4, 2] shape_w = [1, 4, 2] a = relay.var("a", shape=shape_x, dtype="int8") @@ -763,10 +759,11 @@ def test_fq2i_optional_idly(): x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8") w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8") - compare_expected_fq_to_int(expr, expected_expr, [x_np, w_np], False) + compare_expected_fq_qat_to_int(expr, expected_expr, [x_np, w_np]) -def test_fq2i_optional_op_chain(): +def test_fq_qat_positive_couple(): + # Several consecutive operations are converted. shape_x = [1, 2, 4] shape_w = [2] a = relay.var("a", shape=shape_x, dtype="int8") @@ -789,10 +786,11 @@ def test_fq2i_optional_op_chain(): x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8") w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8") - compare_expected_fq_to_int(expr, expected_expr, [x_np, w_np], False) + compare_expected_fq_qat_to_int(expr, expected_expr, [x_np, w_np]) -def test_fq2i_optional_one_arg(): +def test_fq_positive_single_arg_part(): + # The single-argument operation is converted. 
shape_x = [1, 2, 4] a = relay.var("a", shape=shape_x, dtype="int8") @@ -805,10 +803,11 @@ def test_fq2i_optional_one_arg(): op1 = relay.qnn.op.dequantize(op0, relay.const(2.0), relay.const(0)) expected_expr = relay.op.erf(op1) x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8") - compare_expected_fq_to_int(expr, expected_expr, [x_np], False) + compare_expected_fq_qat_to_int(expr, expected_expr, [x_np]) -def test_fq2i_optional_intermediate_infertype(): +def test_fq_qat_intermediate_infertype(): + # Complex conversion of non-QAT and QAT passes that form FakeQuantizationToInteger. shape_x = [1, 2, 4] x = relay.var("x", shape=shape_x, dtype="float32") const_0 = relay.const(np.random.uniform(size=[1, 4, 2]).astype("float32")) @@ -836,8 +835,10 @@ def test_fq2i_optional_intermediate_infertype(): expected_expr = relay.op.add(op5, relay.const(5.0)) x_np = np.random.randint(-128, 127, size=shape_x, dtype="int32").astype("float32") - compare_expected_fq_to_int(expr, expected_expr, [x_np], False) + compare_expected_fq_qat_to_int(expr, expected_expr, [x_np]) if __name__ == "__main__": + import sys + sys.exit(pytest.main([__file__] + sys.argv[1:])) From d77b7af34e04e914a4dcb4ef8338a5a46438bbb0 Mon Sep 17 00:00:00 2001 From: Alexey Voronov Date: Thu, 17 Feb 2022 10:10:32 +0300 Subject: [PATCH 3/3] apply review comments 2 --- src/relay/transforms/fake_quantization_to_integer.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc index 1c7b6b6c37cb..4fd034edc199 100644 --- a/src/relay/transforms/fake_quantization_to_integer.cc +++ b/src/relay/transforms/fake_quantization_to_integer.cc @@ -308,7 +308,7 @@ class FakeQuantizationRewriter : public MixedModeMutator { bool is_op_enabled_for_optional_fq2i(const CallNode* call_node) { const Op op = Downcast(call_node->op); static auto fqfq = Op::GetAttrMap("FTVMFakeQuantizationToInteger"); - static std::unordered_set ops = { + static std::unordered_set ops = { Op::Get("broadcast_to"), Op::Get("clip"), Op::Get("expand_dims"), @@ -334,9 +334,7 @@ bool is_op_enabled_for_optional_fq2i(const CallNode* call_node) { Op::Get("strided_slice"), Op::Get("transpose")}; - auto is_enabled = [&](const auto i) { return i == call_node->op; }; - auto result = std::find_if(std::begin(ops), std::end(ops), is_enabled); - return result != ops.end() && fqfq.count(Downcast(op)); + return ops.find(call_node->op) != ops.end() && fqfq.count(Downcast(op)); } class QATSubgraphExtractor : public ExprVisitor {
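
For reference, a minimal usage sketch of the use_qat option introduced by these patches; it is not part of the patch itself, and the small graph below is illustrative, mirroring the single-argument test above:

import tvm
from tvm import relay

# Build a small fake-quantized graph: dequantize -> reshape -> erf.
x = relay.var("x", shape=[1, 2, 4], dtype="int8")
dq = relay.qnn.op.dequantize(x, relay.const(2.0), relay.const(0))
expr = relay.op.erf(relay.op.reshape(dq, (1, 4, 2)))

mod = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(expr))
# hard_fail=False skips subgraphs that cannot be rewritten instead of raising;
# use_qat=True additionally converts supported ops whose inputs are dequantized.
mod_int = tvm.relay.transform.FakeQuantizationToInteger(hard_fail=False, use_qat=True)(mod)
print(mod_int)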