Skip to content

Commit

Permalink
Add a conversion of individual operations in FQ2I pass.
Browse files Browse the repository at this point in the history
  • Loading branch information
Icemist committed Feb 14, 2022
1 parent 55849e6 commit c595a4d
Show file tree
Hide file tree
Showing 6 changed files with 528 additions and 8 deletions.
11 changes: 11 additions & 0 deletions include/tvm/relay/dataflow_matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,17 @@ Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr, IRModule mod
*/
Expr PartitionPattern(DFPattern pattern, Expr expr, Map<String, ObjectRef> attrs, PackedFunc check);

/*!
 * \brief Infer the type of an expression.
 *
 * \param expr The expression whose type should be inferred.
 *
 * \return Return An Expr with unambiguous type information filled in, as well as its
 * checked type field populated with the result type.
 *
 */
Expr InferType(const Expr& expr);

} // namespace relay
} // namespace tvm

Expand Down
39 changes: 39 additions & 0 deletions include/tvm/relay/qnn/op/dequantize.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relay/qnn/op/dequantize.h
 * \brief Relay dequantize.
 */
#ifndef TVM_RELAY_QNN_OP_DEQUANTIZE_H_
#define TVM_RELAY_QNN_OP_DEQUANTIZE_H_

#include <tvm/relay/expr.h>

namespace tvm {
namespace relay {
namespace qnn {

/*!
 * \brief Build a qnn.dequantize call converting quantized data back to float.
 * \param data The quantized input tensor.
 * \param input_scale The quantization scale.
 * \param input_zero_point The quantization zero point.
 * \param axis The channel axis for per-channel quantization.
 */
Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis);

}  // namespace qnn
}  // namespace relay
}  // namespace tvm

#endif  // TVM_RELAY_QNN_OP_DEQUANTIZE_H_
10 changes: 10 additions & 0 deletions python/tvm/relay/transform/fake_quantization_to_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,16 @@ def global_avgpool2d(expr, type_map):
return [out, t]


@register_fake_quantization_to_integer("broadcast_to")
def broadcast_to(expr, type_map):
    """Rewrite a broadcast_to op"""
    data = expr.args[0]
    affine_type = type_map[data]
    rewritten = relay.op.broadcast_to(data, expr.attrs.shape)
    return [rewritten, affine_type]


@register_fake_quantization_to_integer("rsqrt")
def rsqrt(expr, type_map):
"""Rewrite a rsqrt op"""
Expand Down
254 changes: 251 additions & 3 deletions src/relay/transforms/fake_quantization_to_integer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,13 @@

#include "fake_quantization_to_integer.h"

#include <tvm/ir/affine_type.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/qnn/attrs.h>
#include <tvm/relay/qnn/op/dequantize.h>
#include <tvm/relay/transform.h>

#include <unordered_map>
Expand All @@ -37,7 +41,8 @@ namespace relay {

/* Description of FakeQuantizationToInteger
*
* The purpose of this pass is to find regions of the graph that follow
* This pass consists of two parts, a basic one and an optional one.
* The purpose of the basic part is to find regions of the graph that follow
* the general pattern:
*
* x w
Expand All @@ -52,7 +57,7 @@ namespace relay {
*
* and convert them into subgraphs with actual integer operations on x and w
*
* The pass does this via a multi-pass approach:
* The basic part does this via a multi-pass approach:
*
* The main pass is a MixedModeMutator that traverses the full graph searching for
* quantize operations
Expand All @@ -69,6 +74,24 @@ namespace relay {
*
* After the second and third passes run, the first pass replaces the quantize with the
* rewritten subgraph and the processing continues
*
* The main idea of the optional part is to find and transform operations with dequantized inputs
 * one by one individually. Only operations from the allowed list are supported. For example, if on
* the above general pattern op2 is not registered with the FTVMFakeQuantizationToInteger
* attribute, op1 operation can still be converted. Converted pattern below:
*
* x w
* | |
* \ /
* op1
* |
* dq
* |
* op2
* |
* q
*
* The optional part works in the same multi-pass approach.
*/

using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;
Expand Down Expand Up @@ -270,8 +293,233 @@ class FakeQuantizationRewriter : public MixedModeMutator {
const bool hard_fail_;
};

/*!
 * \brief Check whether a call node is eligible for the optional (per-op)
 * part of the FQ2I pass.
 *
 * An op qualifies only if it is both on the allow-list below and has a
 * registered FTVMFakeQuantizationToInteger rewrite rule.
 *
 * \param call_node The call to test.
 * \return true when the op may be converted individually.
 */
bool is_op_enabled_for_optional_fq2i(const CallNode* call_node) {
  const Op op = Downcast<Op>(call_node->op);
  static auto fqfq = Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");
  // Allow-list of ops that are safe to rewrite one-by-one.
  static const std::unordered_set<Op, tvm::ObjectHash, tvm::ObjectEqual> ops = {
      Op::Get("reshape"),
      Op::Get("squeeze"),
      Op::Get("strided_slice"),
      Op::Get("transpose"),
      Op::Get("expand_dims"),
      Op::Get("nn.max_pool2d"),
      Op::Get("nn.batch_flatten"),
      Op::Get("nn.depth_to_space"),
      Op::Get("max"),
      Op::Get("min"),
      Op::Get("nn.avg_pool2d"),
      Op::Get("nn.global_avg_pool2d"),
      Op::Get("nn.bias_add"),
      Op::Get("nn.conv2d"),
      Op::Get("nn.conv2d_transpose"),
      Op::Get("nn.dense"),
      Op::Get("nn.batch_matmul"),
      Op::Get("split"),
      Op::Get("clip"),
      Op::Get("nn.relu"),
      Op::Get("nn.pad"),
      Op::Get("broadcast_to"),
      Op::Get("minimum"),
      Op::Get("maximum")};

  // Use the hash set's O(1) lookup instead of a linear std::find_if scan;
  // `op` is already an Op, so no further Downcast is needed.
  return ops.count(op) != 0 && fqfq.count(op) != 0;
}

/*!
 * \brief Collects the set of qnn.dequantize calls feeding one enabled op and
 * records the affine (scale/zero-point) type of each dequantized input.
 *
 * Used by the optional part of FQ2I to convert a single op individually.
 */
class OptionalSubgraphExtractor : public ExprVisitor {
 public:
  /*!
   * \brief Visit `expr` (which must be an enabled call) and return every
   * visited CallNode other than `expr` itself — i.e. its dequantize inputs.
   * Returns an empty set if a non-dataflow node was encountered.
   */
  const ExprSet GetSubgraph(const Expr& expr) {
    expr_call_node_ = expr.as<CallNode>();
    ICHECK(expr_call_node_ != nullptr);
    ICHECK(is_op_enabled_for_optional_fq2i(expr_call_node_));

    VisitExpr(expr);

    ExprSet subgraph;
    if (is_fake_quantized_) {
      // Everything visited except the target call itself belongs to the subgraph.
      for (auto kv : this->visit_counter_) {
        if (auto call_node = GetRef<ObjectRef>(kv.first).as<CallNode>()) {
          if (call_node != expr_call_node_) {
            subgraph.insert(Downcast<Expr>(GetRef<ObjectRef>(kv.first)));
          }
        }
      }
    }
    return subgraph;
  }
  /*! \brief Affine types recorded for the dequantize inputs seen during GetSubgraph. */
  const AffineTypeMap GetAffineTypes() { return affine_types_; }
  void VisitExpr(const Expr& expr) override {
    // When looking for fake quantized subgraphs, we only support data-flow regions of the graph,
    // i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we
    // abort the rewrite.
    if (expr.as<CallNode>() == nullptr && expr.as<OpNode>() == nullptr &&
        expr.as<TupleNode>() == nullptr && expr.as<TupleGetItemNode>() == nullptr &&
        expr.as<ConstantNode>() == nullptr) {
      DLOG(INFO) << "FakeQuantizationToInteger found a non - dataflow op inside a fake quantize "
                    "region, aborting this rewrite";
      is_fake_quantized_ = false;
    } else {
      ExprVisitor::VisitExpr(expr);
    }
  }

 protected:
  void VisitExpr_(const CallNode* call_node) override {
    if (call_node->op == dequantize_op_) {
      const auto* attrs = call_node->attrs.as<qnn::DequantizeAttrs>();
      ICHECK(attrs != nullptr);

      // Record the scale, zero point, quantized dtype, and axis for this input.
      affine_types_.Set(
          GetRef<Expr>(call_node),
          TensorAffineType(
              call_node->args[1], call_node->args[2],
              tvm::relay::transform::InferTypeLocal(call_node->args[0]).as<TensorTypeNode>()->dtype,
              attrs->axis));
    } else if (call_node == expr_call_node_) {
      // The target op itself: visit its arguments only (do not recurse past it).
      for (auto arg : call_node->args) {
        VisitExpr(arg);
      }
    } else {
      // run normally on everything else.
      ExprVisitor::VisitExpr_(call_node);
    }
  }

  const Op dequantize_op_ = Op::Get("qnn.dequantize");
  // Set to false if a non-dataflow node is found, which aborts extraction.
  bool is_fake_quantized_ = true;
  AffineTypeMap affine_types_;
  // The call being converted; its inputs form the subgraph.
  const CallNode* expr_call_node_ = nullptr;
};

/*!
 * \brief Rewrites one enabled op (plus its qnn.dequantize inputs) into the
 * integer domain, then re-attaches a dequantize on the output so the rest of
 * the graph is unchanged.
 */
class OptionalSubgraphMutator : public ExprMutator {
 public:
  /*!
   * \param subgraph The dequantize calls feeding the target op.
   * \param affine_types Affine types collected by OptionalSubgraphExtractor.
   * \param hard_fail If true, errors abort the pass; otherwise the rewrite is skipped.
   */
  OptionalSubgraphMutator(ExprSet subgraph, AffineTypeMap affine_types, bool hard_fail)
      : subgraph_(subgraph), affine_types_(affine_types), hard_fail_(hard_fail) {}

  /*!
   * \brief Rewrite `expr` (an enabled call) if every node of its subgraph is a
   * qnn.dequantize; otherwise return it unchanged (or LOG(FATAL) on hard_fail).
   */
  Expr MutateSubgraph(const Expr& expr) {
    if (subgraph_.size() == 0) {
      return expr;
    }

    quantize_node_ = expr.as<CallNode>();
    ICHECK(quantize_node_);
    ICHECK(is_op_enabled_for_optional_fq2i(quantize_node_));

    for (auto node : subgraph_) {
      const Op op = Downcast<Op>(node.as<CallNode>()->op);

      if (node.as<CallNode>()->op != dequantize_op_) {
        // Only modify the subgraph if we have translation
        // rules for every op
        if (hard_fail_) {
          LOG(FATAL) << "Found no rewrite rule for " << AsText(op, false) << std::endl;
        } else {
          DLOG(INFO) << "Found no rewrite rule for " << AsText(op, false) << std::endl;
          return expr;
        }
      }
    }
    try {
      return Mutate(expr);
    } catch (std::exception& e) {
      if (hard_fail_) {
        // Rethrow the original exception object; `throw e;` would copy and
        // slice it down to std::exception, losing the dynamic type.
        throw;
      } else {
        DLOG(INFO) << "Ran into an error rewriting a subgraph, skipping" << expr << std::endl;
        return expr;
      }
    }
  }

 protected:
  Expr VisitExpr_(const CallNode* call_node) override {
    Expr out;
    static auto fqfq =
        Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");

    Op op = Downcast<Op>(call_node->op);
    if (fqfq.count(op)) {
      Expr expr;
      if (op == dequantize_op_) {
        // Dequantize nodes are consumed as-is; their affine type is already recorded.
        expr = GetRef<Expr>(call_node);
      } else {
        expr = ExprMutator::VisitExpr_(call_node);
      }
      // Call the rewrite
      Array<ObjectRef> vals = fqfq[op](expr, affine_types_);
      // Save the outputs of the rewrite
      ICHECK(vals.size() == 2)
          << "got the wrong number of returned arguments from FTVMFakeQuantizationToInteger for "
          << AsText(op, false);
      out = Downcast<Expr>(vals[0]);

      affine_types_.Set(out, Downcast<AffineType>(vals[1]));

      if (call_node == quantize_node_) {
        // Re-attach a dequantize on the converted op so downstream float
        // consumers see the same values as before the rewrite.
        out = qnn::MakeDequantize(out, vals[1].as<TensorAffineTypeNode>()->scale,
                                  vals[1].as<TensorAffineTypeNode>()->zero_point,
                                  vals[1].as<TensorAffineTypeNode>()->axis);
      }
    } else {
      ICHECK(false) << "When rewriting a fake quantized graph, found an invalid node "
                    << AsText(GetRef<Expr>(call_node), false);
    }
    return out;
  }

  Expr VisitExpr_(const TupleNode* node) override {
    Expr expr = ExprMutator::VisitExpr_(node);
    auto new_node = expr.as<TupleNode>();
    Array<TensorAffineType> types;
    // A tuple's affine type is the tuple of its fields' tensor affine types.
    for (Expr field : new_node->fields) {
      ICHECK(affine_types_[field].as<TensorAffineTypeNode>());
      types.push_back(Downcast<TensorAffineType>(affine_types_[field]));
    }
    affine_types_.Set(expr, TupleAffineType(types));
    return expr;
  }

  Expr VisitExpr_(const TupleGetItemNode* node) override {
    Expr expr = ExprMutator::VisitExpr_(node);
    // Project the tuple's affine type down to the selected field.
    auto tuple_type = affine_types_[expr.as<TupleGetItemNode>()->tuple].as<TupleAffineTypeNode>();
    affine_types_.Set(expr, tuple_type->types[node->index]);
    return expr;
  }

  ExprSet subgraph_;
  AffineTypeMap affine_types_;
  const bool hard_fail_;
  const Op dequantize_op_ = Op::Get("qnn.dequantize");
  // The call being converted; a dequantize is appended after its rewrite.
  const CallNode* quantize_node_ = nullptr;
};

/*!
 * \brief Top-level mutator for the optional part of FQ2I: walks the graph and
 * converts each enabled op with dequantized inputs individually.
 */
class OptionalFakeQuantizationRewriter : public MixedModeMutator {
 public:
  /*! \param hard_fail If true, conversion failures abort the pass. */
  explicit OptionalFakeQuantizationRewriter(bool hard_fail) : hard_fail_(hard_fail) {}

 protected:
  Expr Rewrite_(const CallNode* pre, const Expr& post) override {
    if (const CallNode* call_node = post.as<CallNode>()) {
      // (Removed an unused `const Op op` local that was never read.)
      if (is_op_enabled_for_optional_fq2i(call_node)) {
        OptionalSubgraphExtractor extractor;
        ExprSet subgraph = extractor.GetSubgraph(post);
        AffineTypeMap affine_types = extractor.GetAffineTypes();
        Expr out = OptionalSubgraphMutator(subgraph, affine_types, hard_fail_).MutateSubgraph(post);
        return out;
      }
    }
    return post;
  }
  const bool hard_fail_;
};

/*!
 * \brief Entry point: run the basic whole-region FQ2I rewrite, re-infer types,
 * then run the optional per-op rewrite on what remains.
 */
Expr FakeQuantizationToInteger(const Expr& expr, const IRModule& mod, bool hard_fail) {
  // NOTE: removed a leftover `return FakeQuantizationRewriter(...).Mutate(expr);`
  // that made everything below unreachable.
  auto fq_expr = FakeQuantizationRewriter(hard_fail).Mutate(expr);
  // Re-infer types so the optional part sees populated checked_type fields.
  auto fq_inferred_expr = tvm::relay::InferType(fq_expr);
  auto ofq_expr = OptionalFakeQuantizationRewriter(hard_fail).Mutate(fq_inferred_expr);
  return ofq_expr;
}

namespace transform {
Expand Down
7 changes: 2 additions & 5 deletions src/relay/transforms/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include <tvm/ir/transform.h>
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/transform.h>
Expand Down Expand Up @@ -918,11 +919,7 @@ Type InferTypeLocal(const Expr& expr) {
mod = transform::InferType()(mod);

Type result_type;
if (expr.as<FunctionNode>()) {
result_type = mod->Lookup("main")->checked_type();
} else {
result_type = mod->Lookup("main").as<FunctionNode>()->body->checked_type();
}
result_type = relay::InferType(sub_graph)->checked_type();

expr->checked_type_ = result_type;
return result_type;
Expand Down
Loading

0 comments on commit c595a4d

Please sign in to comment.