Skip to content

Commit

Permalink
[RELAY][OP] roi_align operator alter layout (apache#6443)
Browse files Browse the repository at this point in the history
* [RELAY][OP] roi_align operator alter layout

* [RELAY][OP] roi_align operator alter layout

* [RELAY][OP] roi_align operator alter layout

* [RELAY][OP] roi_align operator alter layout

Co-authored-by: honghua.cao <honghua.cao@streamcomputing.com>
  • Loading branch information
2 people authored and trevor-m committed Sep 18, 2020
1 parent 46d94f7 commit 930d78b
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 4 deletions.
2 changes: 2 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,8 @@ def nms_strategy_cuda(attrs, inputs, out_type, target):
def roi_align_strategy_cuda(attrs, inputs, out_type, target):
"""roi_align cuda strategy"""
strategy = _op.OpStrategy()
layout = attrs.layout
assert layout == "NCHW", "only support nchw for now"
strategy.add_implementation(
wrap_compute_roi_align(topi.vision.rcnn.roi_align_nchw),
wrap_topi_schedule(topi.cuda.schedule_roi_align),
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,8 @@ def _compute_roi_align(attrs, inputs, out_type):
def roi_align_strategy(attrs, inputs, out_type, target):
"""roi_align generic strategy"""
strategy = _op.OpStrategy()
layout = attrs.layout
assert layout == "NCHW", "only support nchw for now"
strategy.add_implementation(
wrap_compute_roi_align(topi.vision.rcnn.roi_align_nchw),
wrap_topi_schedule(topi.generic.schedule_roi_align),
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/op/strategy/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,8 @@ def sparse_dense_strategy_cpu(attrs, inputs, out_type, target):
def roi_align_strategy_cpu(attrs, inputs, out_type, target):
"""roi_align x86 strategy"""
strategy = _op.OpStrategy()
layout = attrs.layout
assert layout == "NCHW", "only support nchw for now"
strategy.add_implementation(
wrap_compute_roi_align(topi.x86.roi_align_nchw),
wrap_topi_schedule(topi.generic.schedule_roi_align),
Expand Down
43 changes: 43 additions & 0 deletions python/tvm/relay/op/vision/_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,49 @@
reg.register_strategy("vision.roi_align", strategy.roi_align_strategy)
reg.register_pattern("vision.roi_align", OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_convert_op_layout("vision.roi_align")
def convert_roi_align(attrs, inputs, tinfos, desired_layouts):
    """Convert Layout pass registration for roi_align op.

    Parameters
    ----------
    attrs : tvm.ir.Attrs
        Attributes of current roi_align
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    tinfos : list of types
        List of input and output types
    desired_layouts : list of layout strings
        List of layouts defining our desired
        layout for the data and rois inputs respectively.

    Returns
    -------
    result : tvm.relay.Expr
        The transformed expr
    """
    # pylint: disable=import-outside-toplevel
    from tvm import relay

    assert (
        len(desired_layouts) == 2
    ), "A desired layout is expected for both of vision.roi_align's inputs"

    data_layout, rois_layout = (str(layout) for layout in desired_layouts)
    assert data_layout != "default", "Data layout cannot be default"
    assert rois_layout == "default", "Rois layout must be default"

    # Only the 4-D data input is transformed; rois is a 2-D tensor and
    # keeps its original layout.
    if data_layout not in ("NCHW", "NHWC"):
        raise ValueError("Layout %s is not yet supported." % data_layout)

    updated_attrs = dict(attrs)
    updated_attrs["layout"] = data_layout
    data, rois = inputs
    return relay.vision.roi_align(data, rois, **updated_attrs)


# roi_pool
@reg.register_compute("vision.roi_pool")
def compute_roi_pool(attrs, inputs, _):
Expand Down
34 changes: 30 additions & 4 deletions src/relay/op/vision/rcnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>

#include "../../transforms/infer_layout_util.h"

namespace tvm {
namespace relay {

Expand All @@ -43,14 +45,36 @@ bool ROIAlignRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
CHECK(roi_align_attrs);
CHECK_EQ(dshape.size(), 4) << "Input data should be 4-D.";
CHECK_EQ(rshape.size(), 2) << "Input rois should be 2-D.";
CHECK_EQ(roi_align_attrs->layout, "NCHW") << "ROI Align only supports NCHW layout";
// assign output type
std::vector<IndexExpr> oshape(
{rshape[0], dshape[1], roi_align_attrs->pooled_size[0], roi_align_attrs->pooled_size[1]});
std::vector<IndexExpr> oshape;
if (roi_align_attrs->layout == "NCHW") {
oshape = {rshape[0], dshape[1], roi_align_attrs->pooled_size[0],
roi_align_attrs->pooled_size[1]};
} else {
CHECK_EQ(roi_align_attrs->layout, "NHWC") << "Unexpected ROI Align layout";
oshape = {rshape[0], roi_align_attrs->pooled_size[0], roi_align_attrs->pooled_size[1],
dshape[3]};
}

reporter->Assign(types[2], TensorType(oshape, data->dtype));
return true;
}

template <typename T>
Array<Array<Layout> > ROIAlignInferCorrectLayout(const Attrs& attrs,
                                                 const Array<Layout>& new_in_layouts,
                                                 const Array<Layout>& old_in_layouts,
                                                 const Array<tvm::relay::Type>& old_in_types) {
  // Read-only access to the attrs: no const_cast needed (the original
  // discarded the const qualifier although params is never mutated).
  const T* params = attrs.as<T>();
  CHECK(params != nullptr) << "Expected attrs of type " << T::_type_key;
  Layout data_layout = params->layout;

  // Layout inference needs to define the layout for all inputs and output data layouts.
  // For roi_align, the second input is a 2-D tensor with shape [num_roi, 5],
  // so we fix its layout as "N5".
  return Array<Array<Layout> >{{data_layout, Layout("N5")}, {data_layout}};
}

Expr MakeROIAlign(Expr data, Expr rois, Array<IndexExpr> pooled_size, double spatial_scale,
int sample_ratio, String layout) {
auto attrs = make_object<ROIAlignAttrs>();
Expand Down Expand Up @@ -78,7 +102,9 @@ RELAY_REGISTER_OP("vision.roi_align")
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("rois", "Tensor", "The input rois")
.set_support_level(5)
.add_type_rel("ROIAlign", ROIAlignRel);
.add_type_rel("ROIAlign", ROIAlignRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ROIAlignInferCorrectLayout<ROIAlignAttrs>);

TVM_REGISTER_NODE_TYPE(ROIPoolAttrs);

Expand Down
54 changes: 54 additions & 0 deletions tests/python/relay/test_pass_convert_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,59 @@ def expected():
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


def test_conv_roi_align_convert_layout():
    """ConvertLayout rewrites an NCHW conv2d + roi_align pair to NHWC,
    inserting layout transforms around the converted region while the
    2-D rois input is left untouched."""

    def before():
        # Original NCHW graph: conv2d feeding roi_align.
        data = relay.var("x", shape=(1, 64, 56, 56))
        kernel = relay.var("weight1", shape=(64, 64, 3, 3))
        out = relay.nn.conv2d(
            data,
            kernel,
            channels=64,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
        )
        rois = relay.var("rois", shape=(32, 5))
        out = relay.vision.roi_align(
            out, rois, pooled_size=(14, 14), spatial_scale=0.0625, sample_ratio=2, layout="NCHW"
        )
        return relay.Function(analysis.free_vars(out), out)

    def expected():
        # Same graph after conversion: data/weight transformed into the
        # NHWC region, output transformed back to NCHW; rois unchanged.
        data = relay.var("x", shape=(1, 64, 56, 56))
        kernel = relay.var("weight1", shape=(64, 64, 3, 3))
        data = relay.layout_transform(data, "NCHW", "NHWC")
        kernel = relay.layout_transform(kernel, "OIHW", "HWIO")
        out = relay.nn.conv2d(
            data,
            kernel,
            channels=64,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NHWC",
            kernel_layout="HWIO",
        )
        rois = relay.var("rois", shape=(32, 5))
        out = relay.vision.roi_align(
            out, rois, pooled_size=(14, 14), spatial_scale=0.0625, sample_ratio=2, layout="NHWC"
        )
        out = relay.layout_transform(out, "NHWC", "NCHW")
        return relay.Function(analysis.free_vars(out), out)

    desired_layouts = {
        "nn.conv2d": ["NHWC", "HWIO"],
        "vision.roi_align": ["NHWC", "default"],
    }
    a = run_opt_pass(before(), transform.ConvertLayout(desired_layouts))
    b = run_opt_pass(expected(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


def test_default_keyword():
""" Check that the default keyword selects correct TVM default layout. """

Expand Down Expand Up @@ -1005,5 +1058,6 @@ def expected():
test_qnn_conv_nhwc_convert_layout()
test_conv_convert_kernel_layout()
test_conv_transpose_convert_layout()
test_conv_roi_align_convert_layout()
test_default_keyword()
test_different_ops_convert_layout()

0 comments on commit 930d78b

Please sign in to comment.