[Relay][Quantization] Speed-aware quantization scheme improvement
vinx13 committed Mar 4, 2019
1 parent c8373ec commit 865b37c
Showing 3 changed files with 29 additions and 4 deletions.
python/tvm/relay/build_module.py (10 additions, 0 deletions)
@@ -22,6 +22,7 @@
     "FoldScaleAxis": 3,
     "AlterOpLayout": 3,
     "CanonicalizeOps": 3,
+    "EliminateCommonSubexpr": 3,
 }
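This registers the new pass at optimization level 3, so `cfg.pass_enabled("EliminateCommonSubexpr")` only returns true for `opt_level >= 3`. A minimal sketch (not part of the commit) of enabling it through the build config of this era; the tiny network is a placeholder:

```python
import tvm
from tvm import relay

# A placeholder network; any Relay function works here.
x = relay.var("x", shape=(1, 8), dtype="float32")
func = relay.Function([x], relay.nn.relu(x))

# opt_level=3 enables every pass registered at level 3 in the table above,
# including the new EliminateCommonSubexpr; opt_level <= 2 skips it.
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(func, target="llvm")
```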


@@ -162,6 +163,15 @@ def optimize(func, target=None, params=None):
     func = ir_pass.infer_type(func)
     func = ir_pass.simplify_inference(func)
 
+    if cfg.pass_enabled("EliminateCommonSubexpr"):
+        def fskip(e):
+            if isinstance(e, expr.Call) and e.op.name == 'cast' and e.attrs.dtype == 'int32':
+                return True
+            return False
+
+        func = ir_pass.infer_type(func)
+        func = ir_pass.eliminate_common_subexpr(func, fskip)
+
     if cfg.pass_enabled("CombineParallelConv2D"):
         func = ir_pass.infer_type(func)
         func = ir_pass.combine_parallel_conv2d(func)
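The skip callback matters because realizing a quantized graph produces many identical `cast(..., 'int32')` nodes; merging them into one shared node would leave a single cast feeding several consumers and block operator fusion. A sketch (placeholder graph, 0.5-era `ir_pass` API) of running the pass with the same predicate:

```python
from tvm import relay
from tvm.relay import ir_pass, expr

def fskip(e):
    # Leave int32 casts alone: each consumer keeps its own cast so it can
    # still be fused with its neighbors after realization.
    if isinstance(e, expr.Call) and e.op.name == 'cast' and e.attrs.dtype == 'int32':
        return True
    return False

x = relay.var("x", shape=(1, 16), dtype="int8")
# Two identical casts of the same value: ordinary CSE would merge them.
y = relay.cast(x, "int32") + relay.cast(x, "int32")
func = ir_pass.infer_type(relay.Function([x], y))
func = ir_pass.eliminate_common_subexpr(func, fskip)  # both casts survive
```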
python/tvm/relay/quantize/_annotate.py (2 additions, 0 deletions)
@@ -192,6 +192,8 @@ def add_rewrite(ref_call, new_args, ctx):
     else:
         # quantize rhs to INPUT field if it is not Constant
         rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
+    if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.ACTIVATION:
+        rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
 
     expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
     return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
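Both kinds are `ACTIVATION` when, for example, a residual connection adds two convolution outputs. The extra `INPUT`-kind simulated quantize on `rhs_expr` is what later lets realization narrow that operand to the input dtype (see the `quantize.cc` changes below). A sketch of a graph that reaches this branch; names and shapes are illustrative only:

```python
from tvm import relay

data = relay.var("data", shape=(1, 16, 7, 7))
w1 = relay.var("w1", shape=(16, 16, 3, 3))
w2 = relay.var("w2", shape=(16, 16, 3, 3))
c1 = relay.nn.conv2d(data, w1, padding=(1, 1))
c2 = relay.nn.conv2d(c1, w2, padding=(1, 1))
# Residual add: both operands are conv outputs, i.e. ACTIVATION-kind during
# annotation. Note c1 also has two consumers, which is the multiple-reference
# case handled by the fmulti_ref callback in quantize.cc below.
out = c1 + c2
func = relay.Function(relay.ir_pass.free_vars(out), out)
```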
src/relay/pass/quantize.cc (17 additions, 4 deletions)
@@ -124,7 +124,7 @@ TVM_REGISTER_API("relay._quantize.annotate")
     }
     return e;
   };
-  return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, nullptr);
+  return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, fmulti_ref);
 });


@@ -329,9 +329,11 @@ float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
 
 
 /* \brief Unify the dom scale of arguments */
-Array<Expr> UnifyDTypeScale(const Array<Expr>& args,
+Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args,
+                            const Array<Expr>& args,
                             DataType* dtype_ptr,
                             Expr* scale_ptr) {
+  static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
   const QConfig& cfg = QConfig::Current();
 
   std::vector<const QRealizeIntExprNode*> nptrs;
@@ -344,10 +346,17 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& args,
   }
 
   // unify the data type
+  CHECK_EQ(ref_args.size(), args.size());
   DataType dtype = cfg->dtype_activation;
   for (size_t i = 0; i < ret.size(); ++i) {
+    auto ref_arg = ref_args[i].as<CallNode>();
     if (nptrs[i]->dtype != dtype) {
       ret.Set(i, Cast(ret[i], dtype));
+    } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
+               ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
+      auto new_arg = Cast(ret[i], cfg->dtype_input);
+      new_arg = StopFusion(new_arg);
+      ret.Set(i, Cast(new_arg, dtype));
     }
   }
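The new `else if` handles arguments whose pre-rewrite form was an `INPUT`-kind `simulated_quantize`: they are narrowed to `cfg->dtype_input`, pinned with `StopFusion` so the narrow tensor is actually materialized at the fusion boundary, then widened back for the arithmetic. The tensor crossing the kernel boundary is therefore stored at 8 bits instead of 32. A sketch of the equivalent cast chain in the Python API, assuming the default `int8`/`int32` config and the `relay.annotation.stop_fusion` binding:

```python
from tvm import relay

x = relay.var("x", shape=(1, 16), dtype="int32")
narrow = relay.cast(x, "int8")                 # down to dtype_input
narrow = relay.annotation.stop_fusion(narrow)  # materialize the int8 tensor
wide = relay.cast(narrow, "int32")             # back to dtype_activation
```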

@@ -371,7 +380,7 @@ Expr AddRealize(const Call& ref_call,
   if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
     DataType dtype;
     Expr dom_scale;
-    Array<Expr> ret_args = UnifyDTypeScale(new_args, &dtype, &dom_scale);
+    Array<Expr> ret_args = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale);
     Expr ret = ForwardOp(ref_call, ret_args);
     return QRealizeIntExprNode::make(ret, dom_scale, dtype);
   }
@@ -387,15 +396,19 @@ Expr ConcatenateRealize(const Call& ref_call,
                         const Array<Expr>& new_args,
                         const NodeRef& ctx) {
   CHECK_EQ(new_args.size(), 1);
+  CHECK_EQ(ref_call->args.size(), 1);
 
   const auto* tuple = new_args[0].as<TupleNode>();
+  const auto* ref_tuple = ref_call->args[0].as<TupleNode>();
   CHECK(tuple);
+  CHECK(ref_tuple);
   const Array<Expr>& arr = tuple->fields;
+  const Array<Expr>& ref_arr = ref_tuple->fields;
 
   if (arr[0].as<QRealizeIntExprNode>()) {
     DataType dtype;
     Expr dom_scale;
-    Array<Expr> ret_args = UnifyDTypeScale(arr, &dtype, &dom_scale);
+    Array<Expr> ret_args = UnifyDTypeScale(ref_arr, arr, &dtype, &dom_scale);
     Expr ret = ForwardOp(ref_call, {TupleNode::make(ret_args)});
     return QRealizeIntExprNode::make(ret, dom_scale, dtype);
   } else {
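`AddRealize` and `ConcatenateRealize` now thread the original call's arguments (`ref_call->args`) into `UnifyDTypeScale`, so the unification step above can see which operands were annotated as `INPUT`-kind. A hedged end-to-end sketch of driving this path with the 0.5-era quantize API; the qconfig values shown are illustrative, not taken from the commit:

```python
from tvm import relay

def quantize_for_speed(func, params):
    # relay.quantize.quantize runs annotate -> calibrate -> realize;
    # the realize step is where AddRealize/ConcatenateRealize above run.
    with relay.quantize.qconfig(skip_k_conv=1, dtype_input="int8",
                                dtype_activation="int32"):
        return relay.quantize.quantize(func, params=params)
```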
