Add a conversion of individual operations in FQ2I pass. #10239

Merged
11 changes: 11 additions & 0 deletions include/tvm/relay/dataflow_matcher.h
@@ -106,6 +106,17 @@ Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr, IRModule mod
*/
Expr PartitionPattern(DFPattern pattern, Expr expr, Map<String, ObjectRef> attrs, PackedFunc check);

/*!
* \brief Infer the type of an expression.
*
* \param expr The expression whose type should be inferred.
*
* \return The Expr with unambiguous type information filled in and its
* checked_type field populated with the result type.
*
*/
Expr InferType(const Expr& expr);

} // namespace relay
} // namespace tvm

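For reference, the new C++ InferType helper mirrors the existing Python-level InferType pass. A minimal sketch (not part of this diff) of the analogous Python flow, showing that type inference populates checked_type:

import tvm
from tvm import relay

# After running InferType, every node's checked_type is populated, which is
# what the C++ helper above provides for a single expression.
x = relay.var("x", shape=(1, 4), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.relu(x)))
mod = relay.transform.InferType()(mod)
print(mod["main"].checked_type)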
10 changes: 10 additions & 0 deletions python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -130,6 +130,16 @@ def global_avgpool2d(expr, type_map):
return [out, t]


@register_fake_quantization_to_integer("broadcast_to")
def broadcast_to(expr, type_map):
"""Rewrite a broadcast_to op"""
arg = expr.args[0]
t = type_map[arg]
shape = expr.attrs.shape
out = relay.op.broadcast_to(arg, shape)
return [out, t]


@register_fake_quantization_to_integer("rsqrt")
def rsqrt(expr, type_map):
"""Rewrite a rsqrt op"""
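A quick sketch (with assumed scale and zero-point constants, not a test from this PR) exercising the new broadcast_to registration:

import tvm
from tvm import relay

# Fake-quantized pattern: dequantize -> broadcast_to -> quantize.
# FQ2I should rewrite it to broadcast_to on the int8 tensor directly,
# propagating the input's affine type unchanged.
x = relay.var("x", shape=(1, 4), dtype="int8")
dq = relay.qnn.op.dequantize(x, relay.const(0.5), relay.const(0))
bc = relay.op.broadcast_to(dq, (2, 4))
q = relay.qnn.op.quantize(bc, relay.const(0.5), relay.const(0), out_dtype="int8")
mod = tvm.IRModule.from_expr(relay.Function([x], q))
mod = relay.transform.FakeQuantizationToInteger()(mod)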
24 changes: 21 additions & 3 deletions python/tvm/relay/transform/transform.py
@@ -1238,7 +1238,7 @@ def AnnotateSpans():
return _ffi_api.AnnotateSpans()


-def FakeQuantizationToInteger(hard_fail=False):
+def FakeQuantizationToInteger(hard_fail=False, use_qat=False):
# pylint: disable=anomalous-backslash-in-string
"""
Find regions of the graph of the form
@@ -1267,12 +1267,30 @@ def FakeQuantizationToInteger(hard_fail=False):
If true, raise an error.
If false, skip rewriting the subgraph.

use_qat : boolean
If true, run an additional QAT pass that converts individually enabled operations
with dequantized inputs. For example, if in the graph above op2 is not registered
with the FakeQuantizationToInteger attribute, op1 can still be converted. The
converted pattern is shown below:

.. code-block:: text

x w
| |
\\ /
op1
|
dq
|
op2
|
q

Returns
-------
ret : tvm.transform.Pass
-The registered SimplifyExpr pass.
+The registered FakeQuantizationToInteger pass.
"""
-return _ffi_api.FakeQuantizationToInteger(hard_fail)
+return _ffi_api.FakeQuantizationToInteger(hard_fail, use_qat)


def ToMixedPrecision(mixed_precision_type="float16", missing_op_mode=1):
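For illustration, invoking the pass with the new flag might look like the following usage sketch (the helper name is hypothetical):

import tvm
from tvm import relay

def quantize_with_qat(mod):
    # hard_fail=False skips subgraphs that cannot be rewritten; use_qat=True
    # additionally converts individually enabled ops with dequantized inputs.
    seq = tvm.transform.Sequential(
        [
            relay.transform.InferType(),
            relay.transform.FakeQuantizationToInteger(hard_fail=False, use_qat=True),
        ]
    )
    return seq(mod)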
269 changes: 265 additions & 4 deletions src/relay/transforms/fake_quantization_to_integer.cc
@@ -25,13 +25,18 @@

#include "fake_quantization_to_integer.h"

#include <tvm/ir/affine_type.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/qnn/attrs.h>
#include <tvm/relay/transform.h>

#include <unordered_map>

#include "../qnn/utils.h"

namespace tvm {
namespace relay {

@@ -69,6 +74,28 @@ namespace relay {
*
* After the second and third passes run, the first pass replaces the quantize with the
* rewritten subgraph and the processing continues
*
*
* After that, an additional QAT pass can be enabled with the use_qat flag. The goal of this pass
* is to find operations in those regions (which were not successfully converted by the main pass)
* that can still be converted into quantized form. The idea is to find and transform operations
* with dequantized inputs one by one. Only operations for which all parameters can be explicitly
* calculated are allowed. For example, if in the general pattern above op2 is not registered with
* the FTVMFakeQuantizationToInteger attribute, op1 can still be converted. The converted pattern
* is shown below:
*
* x w
* | |
* \ /
* op1
* |
* dq
* |
* op2
* |
* q
*
* This pass uses the same multi-pass approach.
*/
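To make the diagram concrete, here is a hedged Python sketch of such a region (erf stands in for an op assumed not to be registered for FQ2I; scales and zero points are arbitrary):

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 8), dtype="int8")
w = relay.var("w", shape=(4, 8), dtype="int8")
dq_x = relay.qnn.op.dequantize(x, relay.const(0.1), relay.const(0))
dq_w = relay.qnn.op.dequantize(w, relay.const(0.2), relay.const(0))
op1 = relay.nn.dense(dq_x, dq_w)  # registered for FQ2I: convertible on its own
op2 = relay.erf(op1)              # assumed unregistered: blocks the region rewrite
q = relay.qnn.op.quantize(op2, relay.const(0.1), relay.const(0), out_dtype="int8")
mod = tvm.IRModule.from_expr(relay.Function([x, w], q))
# The main pass leaves this region alone; with use_qat=True op1 is still converted.
mod = relay.transform.FakeQuantizationToInteger(use_qat=True)(mod)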

using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;
@@ -270,16 +297,250 @@ class FakeQuantizationRewriter : public MixedModeMutator {
const bool hard_fail_;
};

-Expr FakeQuantizationToInteger(const Expr& expr, const IRModule& mod, bool hard_fail) {
-return FakeQuantizationRewriter(hard_fail).Mutate(expr);
/* Checks whether the operation can be converted by the optional QAT pass.
* The following conditions must be satisfied:
* 1. The operation is registered with the FTVMFakeQuantizationToInteger attribute;
* 2. It is a unary operator, or an operator whose TensorAffineType is calculated during the
* FTVMFakeQuantizationToInteger conversion;
* 3. It is not one of the "key" operations: requantize, quantize, and dequantize (they lie at the
* boundaries of the regions defined to be quantized).
*/
bool is_op_enabled_for_optional_fq2i(const CallNode* call_node) {
// Review thread from the PR, preserved for context:
// Contributor: Could we add some comments and advice about what ops should be included in this
// list?
// Author: I added a comment (above) about how I selected these operations.
const Op op = Downcast<Op>(call_node->op);
static auto fqfq = Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");
static std::unordered_set<relay::Expr, tvm::ObjectHash, tvm::ObjectEqual> ops = {
Op::Get("broadcast_to"),
Op::Get("clip"),
Op::Get("expand_dims"),
Op::Get("max"),
Op::Get("maximum"),
Op::Get("min"),
Op::Get("minimum"),
Op::Get("nn.avg_pool2d"),
Op::Get("nn.batch_flatten"),
Op::Get("nn.batch_matmul"),
Op::Get("nn.bias_add"),
Op::Get("nn.conv2d"),
Op::Get("nn.conv2d_transpose"),
Op::Get("nn.dense"),
Op::Get("nn.depth_to_space"),
Op::Get("nn.global_avg_pool2d"),
Op::Get("nn.max_pool2d"),
Op::Get("nn.pad"),
Op::Get("nn.relu"),
Op::Get("reshape"),
Op::Get("split"),
Op::Get("squeeze"),
Op::Get("strided_slice"),
Op::Get("transpose")};

return ops.find(call_node->op) != ops.end() && fqfq.count(op);
}

class QATSubgraphExtractor : public ExprVisitor {
public:
const ExprSet GetSubgraph(const Expr& expr) {
expr_call_node_ = expr.as<CallNode>();
ICHECK(expr_call_node_ != nullptr);
ICHECK(is_op_enabled_for_optional_fq2i(expr_call_node_));

VisitExpr(expr);

ExprSet subgraph;
if (is_fake_quantized_) {
for (auto kv : this->visit_counter_) {
if (auto call_node = GetRef<ObjectRef>(kv.first).as<CallNode>()) {
if (call_node != expr_call_node_) {
subgraph.insert(Downcast<Expr>(GetRef<ObjectRef>(kv.first)));
}
}
}
}
return subgraph;
}
const AffineTypeMap GetAffineTypes() { return affine_types_; }
void VisitExpr(const Expr& expr) override {
// When looking for fake quantized subgraphs, we only support data-flow regions of the graph,
// i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we
// abort the rewrite.
if (expr.as<CallNode>() == nullptr && expr.as<OpNode>() == nullptr &&
expr.as<TupleNode>() == nullptr && expr.as<TupleGetItemNode>() == nullptr &&
expr.as<ConstantNode>() == nullptr) {
DLOG(INFO) << "FakeQuantizationToInteger found a non - dataflow op inside a fake quantize "
"region, aborting this rewrite";
is_fake_quantized_ = false;
} else {
ExprVisitor::VisitExpr(expr);
}
}

protected:
void VisitExpr_(const CallNode* call_node) override {
if (call_node->op == dequantize_op_) {
const auto* attrs = call_node->attrs.as<qnn::DequantizeAttrs>();
ICHECK(attrs != nullptr);

affine_types_.Set(
GetRef<Expr>(call_node),
TensorAffineType(
call_node->args[1], call_node->args[2],
tvm::relay::transform::InferTypeLocal(call_node->args[0]).as<TensorTypeNode>()->dtype,
attrs->axis));
} else if (call_node == expr_call_node_) {
for (auto arg : call_node->args) {
VisitExpr(arg);
}
} else {
// run normally on everything else.
ExprVisitor::VisitExpr_(call_node);
}
}

const Op dequantize_op_ = Op::Get("qnn.dequantize");
bool is_fake_quantized_ = true;
AffineTypeMap affine_types_;
const CallNode* expr_call_node_ = nullptr;
};

class QATSubgraphMutator : public ExprMutator {
public:
QATSubgraphMutator(ExprSet subgraph, AffineTypeMap affine_types, bool hard_fail)
: subgraph_(subgraph), affine_types_(affine_types), hard_fail_(hard_fail) {}

Expr MutateSubgraph(const Expr& expr) {
if (subgraph_.size() == 0) {
return expr;
}

quantize_node_ = expr.as<CallNode>();
ICHECK(quantize_node_);
ICHECK(is_op_enabled_for_optional_fq2i(quantize_node_));

for (auto node : subgraph_) {
const Op op = Downcast<Op>(node.as<CallNode>()->op);

if (node.as<CallNode>()->op != dequantize_op_) {
if (hard_fail_) {
LOG(FATAL) << "No dequantize op was found in the input arguments for "
<< AsText(op, false) << std::endl;
} else {
DLOG(INFO) << "No dequantize op was found in the input arguments for "
<< AsText(op, false) << std::endl;
return expr;
}
}
}
try {
return Mutate(expr);
} catch (std::exception& e) {
if (hard_fail_) {
throw e;
} else {
DLOG(INFO) << "Ran into an error rewriting a subgraph, skipping" << expr << std::endl;
return expr;
}
}
}

protected:
Expr VisitExpr_(const CallNode* call_node) {
Expr out;
static auto fqfq =
Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");

Op op = Downcast<Op>(call_node->op);
if (fqfq.count(op)) {
Expr expr;
if (op == dequantize_op_) {
expr = GetRef<Expr>(call_node);
} else {
expr = ExprMutator::VisitExpr_(call_node);
}
// Call the rewrite
Array<ObjectRef> vals = fqfq[op](expr, affine_types_);
// Save the outputs of the rewrite
ICHECK(vals.size() == 2)
<< "got the wrong number of returned arguments from FTVMFakeQuantizationToInteger for "
<< AsText(op, false);
out = Downcast<Expr>(vals[0]);

affine_types_.Set(out, Downcast<AffineType>(vals[1]));

if (call_node == quantize_node_) {
out = qnn::MakeDequantize(out, vals[1].as<TensorAffineTypeNode>()->scale,
vals[1].as<TensorAffineTypeNode>()->zero_point,
vals[1].as<TensorAffineTypeNode>()->axis);
}
} else {
ICHECK(false) << "When rewriting a fake quantized graph, found an invalid node "
<< AsText(GetRef<Expr>(call_node), false);
}
return out;
}

Expr VisitExpr_(const TupleNode* node) {
Expr expr = ExprMutator::VisitExpr_(node);
auto new_node = expr.as<TupleNode>();
Array<TensorAffineType> types;
for (Expr field : new_node->fields) {
ICHECK(affine_types_[field].as<TensorAffineTypeNode>());
types.push_back(Downcast<TensorAffineType>(affine_types_[field]));
}
affine_types_.Set(expr, TupleAffineType(types));
return expr;
}

Expr VisitExpr_(const TupleGetItemNode* node) {
Expr expr = ExprMutator::VisitExpr_(node);
auto tuple_type = affine_types_[expr.as<TupleGetItemNode>()->tuple].as<TupleAffineTypeNode>();
affine_types_.Set(expr, tuple_type->types[node->index]);
return expr;
}

ExprSet subgraph_;
AffineTypeMap affine_types_;
const bool hard_fail_;
const Op dequantize_op_ = Op::Get("qnn.dequantize");
const CallNode* quantize_node_ = nullptr;
};

class QATRewriter : public MixedModeMutator {
public:
explicit QATRewriter(bool hard_fail) : hard_fail_(hard_fail) {}

protected:
Expr Rewrite_(const CallNode* pre, const Expr& post) override {
if (const CallNode* call_node = post.as<CallNode>()) {
if (call_node->op.as<OpNode>() && is_op_enabled_for_optional_fq2i(call_node)) {
QATSubgraphExtractor extractor;
ExprSet subgraph = extractor.GetSubgraph(post);
AffineTypeMap affine_types = extractor.GetAffineTypes();
Expr out = QATSubgraphMutator(subgraph, affine_types, hard_fail_).MutateSubgraph(post);
return out;
}
}
return post;
}
const bool hard_fail_;
};

Expr FakeQuantizationToInteger(const Expr& expr, const IRModule& mod, bool hard_fail,
bool use_qat) {
auto fq_expr = FakeQuantizationRewriter(hard_fail).Mutate(expr);
if (use_qat) {
fq_expr = tvm::relay::InferType(fq_expr);
fq_expr = QATRewriter(hard_fail).Mutate(fq_expr);
}
return fq_expr;
}

namespace transform {

-Pass FakeQuantizationToInteger(bool hard_fail) {
+Pass FakeQuantizationToInteger(bool hard_fail, bool use_qat) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
-return Downcast<Function>(FakeQuantizationToInteger(f, m, hard_fail));
+return Downcast<Function>(FakeQuantizationToInteger(f, m, hard_fail, use_qat));
};
return CreateFunctionPass(pass_func, 0, "FakeQuantizationToInteger", {"InferType"});
}