Skip to content

Commit

Permalink
[Legalize][QNN] Pass out_types to Legalize. Update QNN requantize to …
Browse files Browse the repository at this point in the history
…read from out_types.
  • Loading branch information
anijain2305 committed Aug 16, 2019
1 parent d3eb9cb commit 343eb26
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 33 deletions.
22 changes: 19 additions & 3 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,26 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos):
return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)

@reg.register_legalize("nn.conv2d")
def legalize_conv2d(attrs, inputs, arg_dtypes):
"""Legalize conv2d"""
def legalize_conv2d(attrs, inputs, types):
"""Legalize conv2d op.
Parameters
----------
attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized.
types : list of types
List of input types followed by the output type (the output type is the last element)
F: symbol
The context, can be either nnvm.sym or relay.op
Returns
-------
result : tvm.relay.Expr
The legalized expr.
"""
from ... import op
return topi.nn.conv2d_legalize(attrs, inputs, arg_dtypes, op)
return topi.nn.conv2d_legalize(attrs, inputs, types, op)

reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)

Expand Down
12 changes: 9 additions & 3 deletions src/relay/pass/legalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,17 @@ Expr Legalizer(const Call& ref_call, const Array<Expr>& new_args, const NodeRef&
Expr new_e;
bool modified = false;
if (fop_legalize.count(op)) {
tvm::Array<tvm::relay::Type> arg_types;
// Collect the input types, followed by the output type, to pass on to the Legalize API.
tvm::Array<tvm::relay::Type> types;
for (auto& expr : ref_call->args) {
arg_types.push_back(expr->checked_type());
types.push_back(expr->checked_type());
}
Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, arg_types);
types.push_back(ref_call->checked_type());

// Transform the op by calling the registered legalize function.
Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, types);

// Check if the transformation succeeded. If not, revert back to the original ref_call->op.
if (legalized_value.defined()) {
new_e = legalized_value;
modified = true;
Expand Down
4 changes: 2 additions & 2 deletions src/relay/qnn/op/dequantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,12 @@ Expr DequantizeLower(const Expr& input_tensor,

Expr DequantizeLegalize(const Attrs& attrs,
const Array<Expr>& new_args,
const Array<tvm::relay::Type>& arg_types) {
const Array<tvm::relay::Type>& types) {
CHECK_EQ(new_args.size(), 1);
auto& data = new_args[0];
const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
CHECK(dequantize_attrs != nullptr);
CHECK_EQ(arg_types.size(), 1);
CHECK_EQ(types.size(), 2);
return DequantizeLower(data, dequantize_attrs);
}

Expand Down
4 changes: 2 additions & 2 deletions src/relay/qnn/op/quantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@ Expr QuantizeLower(const Expr& input_tensor,

Expr QuantizeLegalize(const Attrs& attrs,
const Array<Expr>& new_args,
const Array<tvm::relay::Type>& arg_types) {
const Array<tvm::relay::Type>& types) {
CHECK_EQ(new_args.size(), 1);
auto& data = new_args[0];
const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
CHECK(quantize_attrs != nullptr);

CHECK_EQ(arg_types.size(), 1);
CHECK_EQ(types.size(), 2);
return QuantizeLower(data, quantize_attrs);
}

Expand Down
33 changes: 20 additions & 13 deletions src/relay/qnn/op/requantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplie
* 7) Cast to the out_dtype.
*/
Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
const Array<IndexExpr>& input_shape) {
const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
double double_multiplier = param->input_scale / param->output_scale;

// Choose high precision datatype to be int64. This is for avoiding overflow
Expand Down Expand Up @@ -173,10 +173,10 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
auto shifted_int64_t = Add(output_zp, scaled_int64_t);

// 7) Clip to the out_dtype min/max.
auto q_min = GetQmin(param->out_dtype);
auto q_max = GetQmax(param->out_dtype);
auto q_min = GetQmin(out_dtype);
auto q_max = GetQmax(out_dtype);
auto clipped_t = Clip(shifted_int64_t, q_min, q_max);
return Cast(clipped_t, param->out_dtype);
return Cast(clipped_t, out_dtype);
}

/*
Expand All @@ -193,25 +193,32 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
* Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
*/
Expr RequantizeLegalize(const Attrs& attrs, const Array<Expr>& new_args,
const Array<tvm::relay::Type>& arg_types) {
const Array<tvm::relay::Type>& types) {
CHECK_EQ(new_args.size(), 1);
auto& quantized_data = new_args[0];
const auto* param = attrs.as<RequantizeAttrs>();
CHECK(param != nullptr);

// Find input shape.
CHECK_EQ(arg_types.size(), 1);
auto input_dtype = arg_types[0];
auto input_tensor_type = input_dtype.as<TensorTypeNode>();
CHECK(input_tensor_type != nullptr) << "Type information missing."
<< " Please run infer_type pass.";
Array<IndexExpr> input_shape = input_tensor_type->shape;
CHECK_EQ(types.size(), 2);
auto in_type = types[0];
auto in_tensor_type = in_type.as<TensorTypeNode>();
CHECK(in_tensor_type != nullptr) << "Type information missing."
<< " Please run infer_type pass.";
Array<IndexExpr> input_shape = in_tensor_type->shape;

// Find the output dtype.
auto out_type = types[1];
auto out_tensor_type = out_type.as<TensorTypeNode>();
CHECK(out_tensor_type != nullptr) << "Type information missing."
<< " Please run infer_type pass.";
auto out_dtype = out_tensor_type->dtype;

// Check rounding validity.
CHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST")
<< "QNN requantize supports two rounding modes - UPWARD and "
<< "TONEAREST";
return RequantizeLower(quantized_data, param, input_shape);
return RequantizeLower(quantized_data, param, input_shape, out_dtype);
}

/*
Expand Down Expand Up @@ -261,7 +268,7 @@ The requantize operator converts one quantized tensor to another quantized
tensor. For the output tensor, we are provided with output scale and zero
point. The computation looks like this
Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input)
Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.RequantizeAttrs")
Expand Down
15 changes: 8 additions & 7 deletions tests/python/relay/test_pass_legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def before():
return y

@register_legalize("nn.conv2d", level=100)
def legalize_conv2d(attrs, inputs, arg_types):
def legalize_conv2d(attrs, inputs, types):
data, weight = inputs
weight = relay.multiply(weight, relay.const(2.0, "float32"))
return relay.nn.conv2d(data, weight, **attrs)
Expand Down Expand Up @@ -80,7 +80,7 @@ def before():
called = [False]

@register_legalize("nn.global_max_pool2d", level=101)
def legalize_conv2d(attrs, inputs, arg_types):
def legalize_conv2d(attrs, inputs, types):
called[0] = True
return None

Expand All @@ -103,12 +103,13 @@ def before():
return func

@register_legalize("concatenate", level=100)
def legalize_concatenate(attrs, inputs, arg_types):
def legalize_concatenate(attrs, inputs, types):
# Check that the correct multi-input case is handled.
assert len(inputs) == 1
assert isinstance(inputs[0], tvm.relay.expr.Tuple)
assert len(arg_types) == 1
assert isinstance(arg_types[0], tvm.relay.ty.TupleType)
assert len(types) == 2
assert isinstance(types[0], tvm.relay.ty.TupleType)
assert isinstance(types[1], tvm.relay.ty.TensorType)
return None

def expected():
Expand Down Expand Up @@ -153,9 +154,9 @@ def before():
return func

@register_legalize("nn.conv2d", level=101)
def legalize_conv2d(attrs, inputs, arg_types):
def legalize_conv2d(attrs, inputs, types):
from topi.arm_cpu.conv2d import _conv2d_legalize
return _conv2d_legalize(attrs, inputs, arg_types, tvm.relay.op)
return _conv2d_legalize(attrs, inputs, types, tvm.relay.op)

a = before()
b = run_opt_pass(a, transform.Legalize())
Expand Down
6 changes: 3 additions & 3 deletions topi/python/topi/nn/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,16 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N


@tvm.target.generic_func
def conv2d_legalize(attrs, inputs, arg_dtypes, F):
def conv2d_legalize(attrs, inputs, types, F):
"""Legalizes Conv2D op.
Parameters
----------
attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized.
arg_dtypes : list of types
List of types of input arguments
types : list of types
List of input types followed by the output type (the output type is the last element)
F: symbol
The context, can be either nnvm.sym or relay.op
Note
Expand Down

0 comments on commit 343eb26

Please sign in to comment.