Skip to content

Commit

Permalink
[AlterLayout] NCHW upsampling
Browse files Browse the repository at this point in the history
  • Loading branch information
Bing Xu committed Mar 17, 2019
1 parent f8ac138 commit a6844de
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 7 deletions.
35 changes: 32 additions & 3 deletions src/relay/op/nn/upsampling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,31 @@ namespace relay {

TVM_REGISTER_NODE_TYPE(UpSamplingAttrs);

// \brief Layout inference for the upsampling op during AlterOpLayout.
//
// If the producer proposes a new input layout (e.g. NCHW16c coming from an
// altered conv2d), adopt it as the op's own layout, provided the H and W
// axes keep the positions they have in the originally configured layout and
// are not split into sub-axes ('h'/'w'), which this op cannot handle.
// Otherwise the op keeps its configured layout and the pass inserts a
// layout_transform around it.
//
// \param attrs          The op attributes (must be of type T); mutated in
//                       place to record the adopted layout.
// \param new_in_layouts Layouts proposed by producers (may be undefined).
// \param old_in_layouts Previous input layouts (unused here).
// \param old_in_shapes  Previous input shapes (unused here).
// \return One input layout and one output layout, both the inferred layout.
template <typename T>
Array<Array<Layout> > UpsamplingInferCorrectLayout(
    const Attrs& attrs,
    const Array<Layout>& new_in_layouts,
    const Array<Layout>& old_in_layouts,
    const Array<Array<IndexExpr>> &old_in_shapes) {
  // NOTE: Discard "const" qualifier here; the layout string in the attrs is
  // updated so downstream passes see the layout the op was adapted to.
  T *params = const_cast<T*>(attrs.as<T>());
  CHECK(params != nullptr);  // guard against attrs of an unexpected type

  if (new_in_layouts.defined()) {
    CHECK_EQ(new_in_layouts.size(), 1);

    Layout raw_layout(params->layout);
    Layout input = new_in_layouts[0];
    if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
        input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
        !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))) {
      params->layout = input.name();  // modify self to follow the input layout
    }
  }

  Layout inferred_layout(params->layout);
  return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
}

bool UpSamplingRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
Expand Down Expand Up @@ -91,6 +116,8 @@ RELAY_REGISTER_OP("nn.upsampling")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("UpSampling", UpSamplingRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
UpsamplingInferCorrectLayout<UpSamplingAttrs>)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const Attrs& attrs,
Expand All @@ -101,14 +128,16 @@ RELAY_REGISTER_OP("nn.upsampling")
CHECK(uattrs != nullptr);
auto out_tt = out_type.as<TensorTypeNode>();
CHECK(out_tt) << "expected a tensor type: " << out_type;
CHECK(uattrs->layout == "NCHW" || uattrs->layout == "NHWC")
const auto layout = uattrs->layout;
const auto base_layout = layout.substr(0, 4);
CHECK(base_layout == "NCHW" || layout == "NHWC")
<< "unknown layout: " << uattrs->layout;

Array<HalideIR::Expr> oshape;
if (uattrs->layout == "NCHW") {
if (base_layout == "NCHW") {
oshape.push_back(out_tt->shape[2]);
oshape.push_back(out_tt->shape[3]);
} else if (uattrs->layout == "NHWC") {
} else if (layout == "NHWC") {
oshape.push_back(out_tt->shape[1]);
oshape.push_back(out_tt->shape[2]);
}
Expand Down
46 changes: 46 additions & 0 deletions tests/python/relay/test_pass_alter_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,51 @@ def expected():

assert(alpha_equal(a, b))


def test_alter_layout_nchw_upsamping_op():
    """Upsampling (and the following pool) must adopt the NCHW16c layout
    that alter_op_layout picks for the preceding conv2d."""
    def before():
        data = relay.var("x", shape=(1, 32, 28, 28))
        weight = relay.var('weight', shape=(32, 32, 3, 3))
        out = relay.nn.conv2d(data, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
        out = relay.nn.upsampling(out, scale=2)
        out = relay.nn.avg_pool2d(out, pool_size=(2, 2), strides=(2, 2))
        return relay.Function(free_vars(out), out)

    # Force conv2d onto the packed NCHW16c layout for this test.
    @register_alter_op_layout("nn.conv2d", level=108)
    def alter_conv2d(attrs, inputs, tinfos):
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs['data_layout'] = 'NCHW16c'
        return relay.nn.conv2d(data, weight, **new_attrs)

    def expected():
        data = relay.var("x", shape=(1, 32, 28, 28))
        weight = relay.var("weight")
        data = relay.layout_transform(data, "NCHW", "NCHW16c")
        out = relay.nn.conv2d(data, weight, channels=32, kernel_size=(3, 3), padding=(1, 1),
                              data_layout="NCHW16c")
        out = relay.nn.upsampling(out, scale=2, layout="NCHW16c")
        out = relay.nn.avg_pool2d(out, pool_size=(2, 2), strides=(2, 2), layout='NCHW16c')
        out = relay.layout_transform(out, "NCHW16c", "NCHW")
        return relay.Function(free_vars(out), out)

    actual = infer_type(before())
    actual = infer_type(canonicalize_ops(actual))
    actual = infer_type(alter_op_layout(actual))

    reference = infer_type(expected())

    assert alpha_equal(actual, reference)


if __name__ == "__main__":
test_alter_op()
test_alter_return_none()
Expand All @@ -420,3 +465,4 @@ def expected():
test_alter_layout_broadcast_op()
test_alter_layout_scalar()
test_alter_layout_concatenate()
test_alter_layout_nchw_upsamping_op()
49 changes: 47 additions & 2 deletions topi/include/topi/image/resize.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,45 @@ inline Tensor resize_nearest_neighbor_nchw(const Tensor& input,
}, name, tag);
}

/*!
 * \brief Resize given tensor to given shape using nearest neighbour for NCHWc
 *
 * \param input The input tensor in packed NCHWc layout (5-D).
 * \param shape Output spatial shape [out_h, out_w] to resize to.
 * \param align_corners To preserve centers of 4 corner pixels
 *        (NOTE(review): currently unused by this nearest-neighbour kernel)
 * \param name Name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor resized to given shape
 */
inline Tensor resize_nearest_neighbor_nchwc(const Tensor& input,
                                            const Array<Expr>& shape,
                                            bool align_corners = false,
                                            std::string name = "tensor",
                                            std::string tag = kInjective) {
  const Expr out_h = shape[0];
  const Expr out_w = shape[1];
  // Keep N, outer C, and the inner channel chunk; only H and W change.
  Array<Expr> out_shape{input->shape[0], input->shape[1],
                        out_h, out_w, input->shape[4]};

  // Integer ratios: assumes output extents are exact multiples of the input
  // extents (same convention as the NCHW variant).
  const Expr h_ratio = out_h / input->shape[2];
  const Expr w_ratio = out_w / input->shape[3];

  return compute(
    out_shape, [&](const Array<Var>& indices) {
      Array<Expr> src_idx{indices[0], indices[1],
                          indices[2] / h_ratio,
                          indices[3] / w_ratio,
                          indices[4]};
      return input(src_idx);
    }, name, tag);
}

/*!
* \brief Resize given tensor to given shape using nearest neighbour
*
Expand All @@ -153,11 +192,17 @@ inline Tensor resize_nearest_neighbor(const Tensor& input,
std::string name = "tensor",
std::string tag = kInjective) {
CHECK_EQ(align_corners, false) << "Align corners not supported for nearest neighbour";

auto base_layout = layout.substr(0, 4);
if (layout == "NHWC") {
return resize_nearest_neighbor_nhwc(input, shape, align_corners);
} else {
} else if (layout == "NCHW") {
return resize_nearest_neighbor_nchw(input, shape, align_corners);
} else if (base_layout == "NCHW") {
// NCHWc
return resize_nearest_neighbor_nchwc(input, shape, align_corners);
} else {
LOG(FATAL) << "Unknown layout: " << layout;
return Tensor();
}
}

Expand Down
4 changes: 2 additions & 2 deletions topi/python/topi/nn/upsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'):
4-D with shape [batch, channel, in_height*scale, in_width*scale]
or [batch, in_height*scale, in_width*scale, channel]
"""

if layout == "NCHW":
base_layout = layout[0:4]
if base_layout == "NCHW":
out_shape = (simplify(data.shape[2] * scale), simplify(data.shape[3] * scale))
elif layout == "NHWC":
out_shape = (simplify(data.shape[1] * scale), simplify(data.shape[2] * scale))
Expand Down

0 comments on commit a6844de

Please sign in to comment.