diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc index 48a7a04ebb8af..70a20e25b4eaa 100644 --- a/src/relay/op/nn/upsampling.cc +++ b/src/relay/op/nn/upsampling.cc @@ -18,6 +18,31 @@ namespace relay { TVM_REGISTER_NODE_TYPE(UpSamplingAttrs); +template +Array > UpsamplingInferCorrectLayout( + const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array> &old_in_shapes) { + // NOTE: Discard "const" qualifier here. + T *params = const_cast(attrs.as()); + + if (new_in_layouts.defined()) { + CHECK_EQ(new_in_layouts.size(), 1); + + Layout raw_layout(params->layout); + Layout input = new_in_layouts[0]; + if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) && + input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) && + !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))) { + params->layout = input.name(); // modify self to follow the input layout + } + } + + Layout inferred_layout(params->layout); + return Array >{{inferred_layout}, {inferred_layout}}; +} + bool UpSamplingRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -91,6 +116,8 @@ RELAY_REGISTER_OP("nn.upsampling") .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) .add_type_rel("UpSampling", UpSamplingRel) +.set_attr("FInferCorrectLayout", + UpsamplingInferCorrectLayout) .set_attr("TOpPattern", kInjective) .set_attr( "FTVMCompute", [](const Attrs& attrs, @@ -101,14 +128,16 @@ RELAY_REGISTER_OP("nn.upsampling") CHECK(uattrs != nullptr); auto out_tt = out_type.as(); CHECK(out_tt) << "expected a tensor type: " << out_type; - CHECK(uattrs->layout == "NCHW" || uattrs->layout == "NHWC") + const auto layout = uattrs->layout; + const auto base_layout = layout.substr(0, 4); + CHECK(base_layout == "NCHW" || layout == "NHWC") << "unknown layout: " << uattrs->layout; Array oshape; - if (uattrs->layout == "NCHW") { + if (base_layout == "NCHW") { 
oshape.push_back(out_tt->shape[2]); oshape.push_back(out_tt->shape[3]); - } else if (uattrs->layout == "NHWC") { + } else if (layout == "NHWC") { oshape.push_back(out_tt->shape[1]); oshape.push_back(out_tt->shape[2]); } diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 975973d2b9522..8be3e3de1ea45 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -309,6 +309,49 @@ def expected(): assert(alpha_equal(a, b)) +def test_alter_layout_nchw_upsamping_op(): + """Test upsampling operators """ + def before(): + x = relay.var("x", shape=(1, 32, 28, 28)) + weight = relay.var('weight', shape=(32, 32, 3, 3)) + y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1)) + y = relay.nn.upsampling(y, scale=2) + y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2)) + y = relay.Function(free_vars(y), y) + return y + + @register_alter_op_layout("nn.conv2d", level=106) + def alter_conv2d(attrs, inputs, tinfos): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs['data_layout'] = 'NCHW16c' + return relay.nn.conv2d(data, weight, **new_attrs) + + def expected(): + x = relay.var("x", shape=(1, 32, 28, 28)) + weight = relay.var("weight") + x = relay.layout_transform(x, "NCHW", "NCHW16c") + y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1), + data_layout="NCHW16c") + y = relay.nn.upsampling(y, scale=2, layout="NCHW16c") + y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2), layout='NCHW16c') + y = relay.layout_transform(y, "NCHW16c", "NCHW") + y = relay.Function(free_vars(y), y) + return y + + a = before() + a = infer_type(a) + a = canonicalize_ops(a) + a = infer_type(a) + + a = alter_op_layout(a) + a = infer_type(a) + + b = expected() + b = infer_type(b) + + assert(alpha_equal(a, b)) + def test_alter_layout_scalar(): """Test alternating the layout of a conv2d.
The layout of broadcast operators and the weight should be changed accordingly. diff --git a/topi/include/topi/image/resize.h b/topi/include/topi/image/resize.h index ae1b9ff264253..9981980539bee 100644 --- a/topi/include/topi/image/resize.h +++ b/topi/include/topi/image/resize.h @@ -134,6 +134,45 @@ inline Tensor resize_nearest_neighbor_nchw(const Tensor& input, }, name, tag); } +/*! +* \brief Resize given tensor to given shape using nearest neighbour for NCHWc +* +* \param input The input tensor. +* \param shape Output shape to resize to. +* \param align_corners To preserve centers of 4 corner pixels +* \param name Name of the operation +* \param tag The tag to mark the operation +* +* \return A Tensor resized to given shape +*/ +inline Tensor resize_nearest_neighbor_nchwc(const Tensor& input, + const Array& shape, + bool align_corners = false, + std::string name = "tensor", + std::string tag = kInjective) { + Array out_shape; + out_shape.push_back(input->shape[0]); + out_shape.push_back(input->shape[1]); + out_shape.push_back(shape[0]); + out_shape.push_back(shape[1]); + out_shape.push_back(input->shape[4]); + + Expr h_ratio = shape[0] / input->shape[2]; + Expr w_ratio = shape[1] / input->shape[3]; + + return compute( + out_shape, [&](const Array& indices) { + Array idx; + idx.push_back(indices[0]); + idx.push_back(indices[1]); + idx.push_back(indices[2] / h_ratio); + idx.push_back(indices[3] / w_ratio); + idx.push_back(indices[4]); + + return input(idx); + }, name, tag); +} + /*! 
* \brief Resize given tensor to given shape using nearest neighbour * @@ -153,11 +192,17 @@ inline Tensor resize_nearest_neighbor(const Tensor& input, std::string name = "tensor", std::string tag = kInjective) { CHECK_EQ(align_corners, false) << "Align corners not supported for nearest neighbour"; - + auto base_layout = layout.substr(0, 4); if (layout == "NHWC") { return resize_nearest_neighbor_nhwc(input, shape, align_corners); - } else { + } else if (layout == "NCHW") { return resize_nearest_neighbor_nchw(input, shape, align_corners); + } else if (base_layout == "NCHW") { + // NCHWc + return resize_nearest_neighbor_nchwc(input, shape, align_corners); + } else { + LOG(FATAL) << "Unknown layout: " << layout; + return Tensor(); } } diff --git a/topi/python/topi/nn/upsampling.py b/topi/python/topi/nn/upsampling.py index 757d8fe674c22..4b4ddcefea4e1 100644 --- a/topi/python/topi/nn/upsampling.py +++ b/topi/python/topi/nn/upsampling.py @@ -30,8 +30,8 @@ def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'): 4-D with shape [batch, channel, in_height*scale, in_width*scale] or [batch, in_height*scale, in_width*scale, channel] """ - - if layout == "NCHW": + base_layout = layout[0:4] + if base_layout == "NCHW": out_shape = (simplify(data.shape[2] * scale), simplify(data.shape[3] * scale)) elif layout == "NHWC": out_shape = (simplify(data.shape[1] * scale), simplify(data.shape[2] * scale))