[Legalize][QNN] Pass out_types to Legalize. Update QNN requantize to read from out_types. (apache#3782)
anijain2305 authored and wweic committed Sep 16, 2019
1 parent 25e056d commit 93ef544
Showing 8 changed files with 100 additions and 54 deletions.
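
The core interface change: Legalize callbacks previously received only the argument types (arg_types); they now receive a single types list whose last element is the inferred output type, so ops like qnn.requantize can read out_dtype from type inference instead of attributes. A minimal sketch of the new contract, using the register_legalize helper that appears in the test file below (its import path is not shown in this diff; the condition is hypothetical, for illustration only):

from tvm import relay
from tvm.relay.op import register_legalize  # assumed import path, not shown in this diff

@register_legalize("nn.conv2d", level=100)
def legalize_conv2d(attrs, inputs, types):
    # New contract: types == [input types..., inferred output type]
    *in_types, out_type = types
    if out_type.dtype == "float32":   # hypothetical decision based on the output type
        return None                   # None leaves the original call unchanged
    data, weight = inputs
    return relay.nn.conv2d(data, weight, **attrs)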
22 changes: 18 additions & 4 deletions python/tvm/relay/op/nn/_nn.py
@@ -206,10 +206,24 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos):
return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)

@reg.register_legalize("nn.conv2d")
def legalize_conv2d(attrs, inputs, arg_dtypes):
"""Legalize conv2d"""
from ... import op
return topi.nn.conv2d_legalize(attrs, inputs, arg_dtypes, op)
def legalize_conv2d(attrs, inputs, types):
"""Legalize conv2d op.
Parameters
----------
attrs : tvm.attrs.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
return topi.nn.conv2d_legalize(attrs, inputs, types)

reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)

12 changes: 9 additions & 3 deletions src/relay/pass/legalize.cc
@@ -42,11 +42,17 @@ Expr Legalizer(const Call& ref_call, const Array<Expr>& new_args, const NodeRef&
Expr new_e;
bool modified = false;
if (fop_legalize.count(op)) {
tvm::Array<tvm::relay::Type> arg_types;
// Collect input and output dtypes to pass on to Legalize API.
tvm::Array<tvm::relay::Type> types;
for (auto& expr : ref_call->args) {
arg_types.push_back(expr->checked_type());
types.push_back(expr->checked_type());
}
Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, arg_types);
types.push_back(ref_call->checked_type());

// Transform the op by calling the registered legalize function.
Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, types);

// Check if the transformation succeeded. If not, revert back to the original ref_call->op.
if (legalized_value.defined()) {
new_e = legalized_value;
modified = true;
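From the callback's point of view, the array the Legalizer above constructs follows a simple ordering convention: checked input types in argument order, then the call's own checked type appended last. A Python-side sketch of the equivalent construction (illustrative, not code from this commit):

def types_for_legalize(call):
    """Python-side view of the array the C++ Legalizer builds (sketch)."""
    types = [arg.checked_type for arg in call.args]  # input types, in argument order
    types.append(call.checked_type)                  # inferred output type, last
    return types  # len(types) == len(call.args) + 1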
4 changes: 2 additions & 2 deletions src/relay/qnn/op/dequantize.cc
@@ -74,12 +74,12 @@ Expr DequantizeLower(const Expr& input_tensor,

Expr DequantizeLegalize(const Attrs& attrs,
const Array<Expr>& new_args,
const Array<tvm::relay::Type>& arg_types) {
const Array<tvm::relay::Type>& types) {
CHECK_EQ(new_args.size(), 1);
auto& data = new_args[0];
const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
CHECK(dequantize_attrs != nullptr);
CHECK_EQ(arg_types.size(), 1);
CHECK_EQ(types.size(), 2);
return DequantizeLower(data, dequantize_attrs);
}

4 changes: 2 additions & 2 deletions src/relay/qnn/op/quantize.cc
@@ -85,13 +85,13 @@ Expr QuantizeLower(const Expr& input_tensor,

Expr QuantizeLegalize(const Attrs& attrs,
const Array<Expr>& new_args,
const Array<tvm::relay::Type>& arg_types) {
const Array<tvm::relay::Type>& types) {
CHECK_EQ(new_args.size(), 1);
auto& data = new_args[0];
const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
CHECK(quantize_attrs != nullptr);

CHECK_EQ(arg_types.size(), 1);
CHECK_EQ(types.size(), 2);
return QuantizeLower(data, quantize_attrs);
}

33 changes: 20 additions & 13 deletions src/relay/qnn/op/requantize.cc
@@ -109,7 +109,7 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplie
* 7) Cast to the out_dtype.
*/
Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
const Array<IndexExpr>& input_shape) {
const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
double double_multiplier = param->input_scale / param->output_scale;

// Choose high precision datatype to be int64. This is for avoiding overflow
@@ -173,10 +173,10 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
auto shifted_int64_t = Add(output_zp, scaled_int64_t);

// 7) Clip to the out_dtype min/max.
auto q_min = GetQmin(param->out_dtype);
auto q_max = GetQmax(param->out_dtype);
auto q_min = GetQmin(out_dtype);
auto q_max = GetQmax(out_dtype);
auto clipped_t = Clip(shifted_int64_t, q_min, q_max);
return Cast(clipped_t, param->out_dtype);
return Cast(clipped_t, out_dtype);
}

/*
@@ -193,25 +193,32 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
* Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input)
*/
Expr RequantizeLegalize(const Attrs& attrs, const Array<Expr>& new_args,
const Array<tvm::relay::Type>& arg_types) {
const Array<tvm::relay::Type>& types) {
CHECK_EQ(new_args.size(), 1);
auto& quantized_data = new_args[0];
const auto* param = attrs.as<RequantizeAttrs>();
CHECK(param != nullptr);

// Find input shape.
CHECK_EQ(arg_types.size(), 1);
auto input_dtype = arg_types[0];
auto input_tensor_type = input_dtype.as<TensorTypeNode>();
CHECK(input_tensor_type != nullptr) << "Type information missing."
<< " Please run infer_type pass.";
Array<IndexExpr> input_shape = input_tensor_type->shape;
CHECK_EQ(types.size(), 2);
auto in_type = types[0];
auto in_tensor_type = in_type.as<TensorTypeNode>();
CHECK(in_tensor_type != nullptr) << "Type information missing."
<< " Please run infer_type pass.";
Array<IndexExpr> input_shape = in_tensor_type->shape;

// Find the output dtype.
auto out_type = types[1];
auto out_tensor_type = out_type.as<TensorTypeNode>();
CHECK(out_tensor_type != nullptr) << "Type information missing."
<< " Please run infer_type pass.";
auto out_dtype = out_tensor_type->dtype;

// Check rounding validity.
CHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST")
<< "QNN requantize supports two rounding modes - UPWARD and "
<< "TONEAREST";
return RequantizeLower(quantized_data, param, input_shape);
return RequantizeLower(quantized_data, param, input_shape, out_dtype);
}

/*
@@ -261,7 +268,7 @@ The requantize operator converts one quantized
tensor. For the output tensor, we are provided with output scale and zero
point. The computation looks like this
Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input)
Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.RequantizeAttrs")
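The reworked RequantizeLower now takes the output dtype from type inference rather than from attrs. A hedged NumPy sketch of the arithmetic it implements (floating point for brevity; the real lowering uses int64 fixed-point math and the UPWARD/TONEAREST rounding modes checked above):

import numpy as np

def requantize_ref(q_in, in_scale, in_zp, out_scale, out_zp, out_dtype="int8"):
    # Q_out = zp_output + (scale_input / scale_output) * (Q_in - zp_input)
    real = (q_in.astype(np.int64) - in_zp) * (in_scale / out_scale)
    shifted = np.rint(real) + out_zp
    # Clip to the out_dtype min/max (the role of GetQmin/GetQmax), then cast.
    info = np.iinfo(np.dtype(out_dtype))
    return np.clip(shifted, info.min, info.max).astype(out_dtype)

requantize_ref(np.array([130]), 0.5, 0, 1.0, 0)  # -> array([65], dtype=int8)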
15 changes: 8 additions & 7 deletions tests/python/relay/test_pass_legalize.py
@@ -47,7 +47,7 @@ def before():
return y

@register_legalize("nn.conv2d", level=100)
def legalize_conv2d(attrs, inputs, arg_types):
def legalize_conv2d(attrs, inputs, types):
data, weight = inputs
weight = relay.multiply(weight, relay.const(2.0, "float32"))
return relay.nn.conv2d(data, weight, **attrs)
@@ -80,7 +80,7 @@ def before():
called = [False]

@register_legalize("nn.global_max_pool2d", level=101)
def legalize_conv2d(attrs, inputs, arg_types):
def legalize_conv2d(attrs, inputs, types):
called[0] = True
return None

@@ -103,12 +103,13 @@ def before():
return func

@register_legalize("concatenate", level=100)
def legalize_concatenate(attrs, inputs, arg_types):
def legalize_concatenate(attrs, inputs, types):
# Check that the correct multi-input case is handled.
assert len(inputs) == 1
assert isinstance(inputs[0], tvm.relay.expr.Tuple)
assert len(arg_types) == 1
assert isinstance(arg_types[0], tvm.relay.ty.TupleType)
assert len(types) == 2
assert isinstance(types[0], tvm.relay.ty.TupleType)
assert isinstance(types[1], tvm.relay.ty.TensorType)
return None

def expected():
@@ -153,9 +154,9 @@ def before():
return func

@register_legalize("nn.conv2d", level=101)
def legalize_conv2d(attrs, inputs, arg_types):
def legalize_conv2d(attrs, inputs, types):
from topi.arm_cpu.conv2d import _conv2d_legalize
return _conv2d_legalize(attrs, inputs, arg_types, tvm.relay.op)
return _conv2d_legalize(attrs, inputs, types)

a = before()
b = run_opt_pass(a, transform.Legalize())
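The concatenate test above pins down one subtlety of the new contract: a tuple argument counts as a single input, so types has length 2 — the input TupleType followed by the output TensorType. A hedged sketch of unpacking it inside a callback:

def legalize_concatenate(attrs, inputs, types):
    tuple_type, out_type = types       # len(types) == 2: tuple input + output
    field_types = tuple_type.fields    # per-tensor types of the concatenated args
    # ... inspect field_types / out_type, then rewrite or return None ...
    return None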
42 changes: 30 additions & 12 deletions topi/python/topi/arm_cpu/conv2d.py
@@ -18,10 +18,11 @@
"""Conv2D schedule for ARM CPU"""
from __future__ import absolute_import as _abs

import warnings
import logging

import tvm
from tvm import autotvm
from tvm import relay
import tvm.contrib.nnpack

from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform, \
@@ -35,6 +36,8 @@
from ..nn.util import get_const_int, get_pad_tuple
from ..nn.winograd_util import winograd_transform_matrices

logger = logging.getLogger('topi')

@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['direct'])
def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
"""TOPI compute callback for conv2d
@@ -679,7 +682,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
if layout != 'NCHW':
return None
if dilation != (1, 1):
warnings.warn("Does not support weight pre-transform for dilated convolution.")
logger.warning("Does not support weight pre-transform for dilated convolution.")
return None

data, kernel = tinfos[0:2]
@@ -794,31 +797,46 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
return None

@conv2d_legalize.register("arm_cpu")
def _conv2d_legalize(attrs, inputs, arg_types, F):
if F.__name__ != 'tvm.relay.op':
return None
def _conv2d_legalize(attrs, inputs, arg_types):
"""Legalizes Conv2D op.
Parameters
----------
attrs : tvm.attrs.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""

if attrs['data_layout'] == 'NHWC':
data, kernel = inputs
if attrs['kernel_layout'] == 'HWIO':
# Handle HWIO layout. This is common in TF graph.
kernel = F.transpose(kernel, axes=(3, 2, 0, 1))
kernel = relay.transpose(kernel, axes=(3, 2, 0, 1))
elif attrs['kernel_layout'] == 'HWOI':
# Handle HWOI layout. This is common in TF depthwise conv2d graph.
kernel = F.transpose(kernel, axes=(2, 3, 0, 1))
kernel = relay.transpose(kernel, axes=(2, 3, 0, 1))
elif attrs['kernel_layout'] != 'OIHW':
return None

warnings.warn("Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to "
+ "fallback to NCHW. This can result in performance degradation.")
logger.warning("Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to "
+ "fallback to NCHW. This can result in performance degradation.")
# Set new attrs for the transposed conv.
new_attrs = {k: attrs[k] for k in attrs.keys()}
new_attrs['data_layout'] = 'NCHW'
new_attrs['kernel_layout'] = 'OIHW'

# Convert from NHWC to NCHW.
data = F.transpose(data, axes=(0, 3, 1, 2))
conv = F.nn.conv2d(data, kernel, **new_attrs)
data = relay.transpose(data, axes=(0, 3, 1, 2))
conv = relay.nn.conv2d(data, kernel, **new_attrs)
# Convert back to original NHWC layout.
out = F.transpose(conv, axes=(0, 2, 3, 1))
out = relay.transpose(conv, axes=(0, 2, 3, 1))
return out
return None
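
For context, a hedged usage sketch of how this arm_cpu legalization fires (shapes and target string invented for illustration): an NHWC/HWIO conv2d is rewritten into transpose → NCHW conv2d → transpose once Legalize runs under an arm_cpu target, as implemented above.

import tvm
from tvm import relay
from tvm.relay import transform

data = relay.var("data", shape=(1, 56, 56, 64), dtype="float32")     # NHWC
weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="float32")  # HWIO
y = relay.nn.conv2d(data, weight, padding=(1, 1),
                    data_layout="NHWC", kernel_layout="HWIO")
mod = relay.Module.from_expr(relay.Function([data, weight], y))
mod = transform.InferType()(mod)                 # Legalize needs checked types
with tvm.target.create("llvm -device=arm_cpu"):  # selects the arm_cpu override
    mod = transform.Legalize()(mod)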
22 changes: 11 additions & 11 deletions topi/python/topi/nn/conv2d.py
@@ -72,22 +72,22 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N


@tvm.target.generic_func
def conv2d_legalize(attrs, inputs, arg_dtypes, F):
def conv2d_legalize(attrs, inputs, types):
"""Legalizes Conv2D op.
Parameters
----------
attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
attrs : tvm.attrs.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized.
arg_dtypes : list of types
List of types of input arguments
F: symbol
The context, can be either nnvm.sym or relay.op
Note
----
Unlike other TOPI functions, this function operates on both graph level and operator level,
so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
# not to change by default
return None
