
Commit 21e8dfa

[Relay][Quantization] Speed-aware quantization scheme improvement (#2723)

* [Relay][Quantization] Speed-aware quantization scheme improvement

* Add comment

* Add use_stop_fusion to qconfig

* Update comment
vinx13 authored and ZihengJiang committed Mar 9, 2019
1 parent b0a0ae4 commit 21e8dfa
Showing 5 changed files with 45 additions and 8 deletions.
17 changes: 14 additions & 3 deletions python/tvm/relay/build_module.py
@@ -9,7 +9,7 @@
 from .. import nd as _nd, target as _target, autotvm
 from ..contrib import graph_runtime as _graph_rt
 from . import ir_pass
-from . import expr
+from . import expr as _expr
 from .backend import interpreter as _interpreter
 from .backend import graph_runtime_codegen as _graph_gen

@@ -22,6 +22,7 @@
     "FoldScaleAxis": 3,
     "AlterOpLayout": 3,
     "CanonicalizeOps": 3,
+    "EliminateCommonSubexpr": 3,
 }


@@ -126,8 +127,8 @@ def _bind_params_by_name(func, params):
         arg = name_dict[k]
         if arg is None:
             raise ValueError("Multiple args in the function have name %s" % k)
-        bind_dict[arg] = expr.const(v)
-    return expr.bind(func, bind_dict)
+        bind_dict[arg] = _expr.const(v)
+    return _expr.bind(func, bind_dict)


 def optimize(func, target=None, params=None):
@@ -162,6 +163,16 @@ def optimize(func, target=None, params=None):
         func = ir_pass.infer_type(func)
         func = ir_pass.simplify_inference(func)
 
+    if cfg.pass_enabled("EliminateCommonSubexpr"):
+        def fskip(expr):
+            if isinstance(expr, _expr.Call) and expr.op.name == 'cast' and \
+                expr.attrs.dtype == 'int32':
+                return True
+            return False
+
+        func = ir_pass.infer_type(func)
+        func = ir_pass.eliminate_common_subexpr(func, fskip)
+
     if cfg.pass_enabled("CombineParallelConv2D"):
         func = ir_pass.infer_type(func)
         func = ir_pass.combine_parallel_conv2d(func)
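The fskip callback above exempts cast-to-int32 nodes from common-subexpression elimination, so duplicated casts of low-bit values are not merged into a single int32 value that would then have to stay live in memory. A minimal, self-contained sketch of driving the same pass directly through the pre-0.6 ir_pass API used in this diff (the toy function and shapes are illustrative, not from this commit):

from tvm import relay
from tvm.relay import ir_pass

# Toy function with two identical casts that CSE would normally merge.
x = relay.var("x", shape=(4,), dtype="int8")
a = relay.cast(x, "int32")
b = relay.cast(x, "int32")
func = relay.Function([x], relay.add(a, b))

def fskip(e):
    # Same predicate as in the diff: skip cast-to-int32 nodes so the widened
    # int32 value is not kept alive in place of the low-bit int8 one.
    return isinstance(e, relay.expr.Call) and e.op.name == "cast" \
        and e.attrs.dtype == "int32"

func = ir_pass.infer_type(func)
func = ir_pass.eliminate_common_subexpr(func, fskip)  # both casts are preserved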
3 changes: 3 additions & 0 deletions python/tvm/relay/quantize/_annotate.py
@@ -192,6 +192,9 @@ def add_rewrite(ref_call, new_args, ctx):
         else:
             # quantize rhs to INPUT field if it is not Constant
             rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
+    if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.ACTIVATION:
+        # quantize rhs to INPUT field if both lhs and rhs are ACTIVATION
+        rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
 
     expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
     return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
5 changes: 5 additions & 0 deletions python/tvm/relay/quantize/quantize.py
@@ -58,6 +58,7 @@ class QConfig(NodeBase):
         "round_for_shift": True,
         "store_lowbit_output": True,
         "debug_enabled_ops": None,
+        "use_stop_fusion": True
     }

     # pylint: disable=no-member
@@ -129,6 +130,10 @@ def qconfig(**kwargs):
         Whether to store low-bit integer back as output before dequantizing.
         Some accelerators need this, e.g. VTA.
 
+    use_stop_fusion: boolean
+        Whether to add stop_fusion when casting to dtype_activation. stop_fusion
+        forces low-bit results to be stored in memory.
+
     Returns
     -------
     config: QConfig
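For context on how the new flag is consumed: qconfig is a context manager, so use_stop_fusion is toggled per quantization run. A hedged usage sketch (func and params stand for an already-imported Relay model and are placeholders, not part of this commit):

from tvm import relay

# func/params: an existing Relay function and its parameters (placeholders).
with relay.quantize.qconfig(global_scale=8.0,
                            use_stop_fusion=True):  # flag added in this commit
    qfunc = relay.quantize.quantize(func, params=params)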
26 changes: 21 additions & 5 deletions src/relay/pass/quantize.cc
@@ -124,7 +124,7 @@ TVM_REGISTER_API("relay._quantize.annotate")
     }
     return e;
   };
-  return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, nullptr);
+  return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, fmulti_ref);
 });


@@ -329,9 +329,11 @@ float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
 
 
 /* \brief Unify the dom scale of arguments */
-Array<Expr> UnifyDTypeScale(const Array<Expr>& args,
+Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args,
+                            const Array<Expr>& args,
                             DataType* dtype_ptr,
                             Expr* scale_ptr) {
+  static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
   const QConfig& cfg = QConfig::Current();
 
   std::vector<const QRealizeIntExprNode*> nptrs;
@@ -344,10 +346,19 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& args,
   }
 
   // unify the data type
+  CHECK_EQ(ref_args.size(), args.size());
   DataType dtype = cfg->dtype_activation;
   for (size_t i = 0; i < ret.size(); ++i) {
+    auto ref_arg = ref_args[i].as<CallNode>();
     if (nptrs[i]->dtype != dtype) {
       ret.Set(i, Cast(ret[i], dtype));
+    } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
+               ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
+      auto new_arg = Cast(ret[i], cfg->dtype_input);
+      if (cfg->use_stop_fusion) {
+        new_arg = StopFusion(new_arg);
+      }
+      ret.Set(i, Cast(new_arg, dtype));
     }
   }

@@ -371,7 +382,7 @@ Expr AddRealize(const Call& ref_call,
   if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
     DataType dtype;
     Expr dom_scale;
-    Array<Expr> ret_args = UnifyDTypeScale(new_args, &dtype, &dom_scale);
+    Array<Expr> ret_args = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale);
     Expr ret = ForwardOp(ref_call, ret_args);
     return QRealizeIntExprNode::make(ret, dom_scale, dtype);
   }
@@ -387,15 +398,19 @@ Expr ConcatenateRealize(const Call& ref_call,
                         const Array<Expr>& new_args,
                         const NodeRef& ctx) {
   CHECK_EQ(new_args.size(), 1);
+  CHECK_EQ(ref_call->args.size(), 1);
 
   const auto* tuple = new_args[0].as<TupleNode>();
+  const auto* ref_tuple = ref_call->args[0].as<TupleNode>();
   CHECK(tuple);
+  CHECK(ref_tuple);
   const Array<Expr>& arr = tuple->fields;
+  const Array<Expr>& ref_arr = ref_tuple->fields;
 
   if (arr[0].as<QRealizeIntExprNode>()) {
     DataType dtype;
     Expr dom_scale;
-    Array<Expr> ret_args = UnifyDTypeScale(arr, &dtype, &dom_scale);
+    Array<Expr> ret_args = UnifyDTypeScale(ref_arr, arr, &dtype, &dom_scale);
     Expr ret = ForwardOp(ref_call, {TupleNode::make(ret_args)});
     return QRealizeIntExprNode::make(ret, dom_scale, dtype);
   } else {
@@ -530,7 +545,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   p->stream << "skip_k_conv==" << op->skip_k_conv << ", ";
   p->stream << "round_for_shift==" << op->round_for_shift << ", ";
   p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", ";
-  p->stream << "debug_enabled_ops==" << op->debug_enabled_ops;
+  p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", ";
+  p->stream << "use_stop_fusion==" << op->use_stop_fusion;
   p->stream << ")";
 });

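To make the realize-side change concrete: when an argument of add/concatenate comes directly from a simulated_quantize of kind kQInput, UnifyDTypeScale now casts it down to dtype_input, optionally wraps it in stop_fusion, and only then widens it to dtype_activation, so the value crossing the fusion boundary is the low-bit one. A rough Relay-level sketch of the emitted pattern (assuming int8/int32 dtypes and that stop_fusion is exposed as relay.annotation.stop_fusion; an illustration, not code from this commit):

from tvm import relay

x = relay.var("x", shape=(1, 64, 56, 56), dtype="float32")
low = relay.cast(x, "int8")              # store the low-bit (dtype_input) value
low = relay.annotation.stop_fusion(low)  # emitted only when use_stop_fusion=True
wide = relay.cast(low, "int32")          # widen to dtype_activation for the add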
2 changes: 2 additions & 0 deletions src/relay/pass/quantize.h
@@ -110,6 +110,7 @@ class QConfigNode : public Node {
   bool round_for_shift = true;
   bool store_lowbit_output = true;
   Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
+  bool use_stop_fusion = true;
 
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("nbit_input", &nbit_input);
@@ -123,6 +124,7 @@
     v->Visit("round_for_shift", &round_for_shift);
     v->Visit("store_lowbit_output", &store_lowbit_output);
     v->Visit("debug_enabled_ops", &debug_enabled_ops);
+    v->Visit("use_stop_fusion", &use_stop_fusion);
   }
 
   static constexpr const char* _type_key = "relay.quantize.QConfig";
