Skip to content

Commit

Permalink
[AlterLayout] NCHW upsampling
Browse files Browse the repository at this point in the history
  • Loading branch information
Bing Xu committed Mar 17, 2019
1 parent f8ac138 commit 6fb25d8
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 7 deletions.
35 changes: 32 additions & 3 deletions src/relay/op/nn/upsampling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,31 @@ namespace relay {

TVM_REGISTER_NODE_TYPE(UpSamplingAttrs);

template <typename T>
Array<Array<Layout> > UpsamplingInferCorrectLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
// NOTE: Discard "const" qualifier here.
T *params = const_cast<T*>(attrs.as<T>());

if (new_in_layouts.defined()) {
CHECK_EQ(new_in_layouts.size(), 1);

Layout raw_layout(params->layout);
Layout input = new_in_layouts[0];
if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
!input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))) {
params->layout = input.name(); // modify self to follow the input layout
}
}

Layout inferred_layout(params->layout);
return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
}

bool UpSamplingRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
Expand Down Expand Up @@ -91,6 +116,8 @@ RELAY_REGISTER_OP("nn.upsampling")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("UpSampling", UpSamplingRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
UpsamplingInferCorrectLayout<UpSamplingAttrs>)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const Attrs& attrs,
Expand All @@ -101,14 +128,16 @@ RELAY_REGISTER_OP("nn.upsampling")
CHECK(uattrs != nullptr);
auto out_tt = out_type.as<TensorTypeNode>();
CHECK(out_tt) << "expected a tensor type: " << out_type;
CHECK(uattrs->layout == "NCHW" || uattrs->layout == "NHWC")
const auto layout = uattrs->layout;
const auto base_layout = layout.substr(0, 4);
CHECK(base_layout == "NCHW" || layout == "NHWC")
<< "unknown layout: " << uattrs->layout;

Array<HalideIR::Expr> oshape;
if (uattrs->layout == "NCHW") {
if (base_layout == "NCHW") {
oshape.push_back(out_tt->shape[2]);
oshape.push_back(out_tt->shape[3]);
} else if (uattrs->layout == "NHWC") {
} else if (layout == "NHWC") {
oshape.push_back(out_tt->shape[1]);
oshape.push_back(out_tt->shape[2]);
}
Expand Down
49 changes: 47 additions & 2 deletions topi/include/topi/image/resize.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,45 @@ inline Tensor resize_nearest_neighbor_nchw(const Tensor& input,
}, name, tag);
}

/*!
* \brief Resize given tensor to given shape using nearest neighbour for NCHWc
*
* \param input The input tensor.
* \param shape Output shape to resize to.
* \param align_corners To preserve centers of 4 corner pixels
* \param name Name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor resized to given shape
*/
inline Tensor resize_nearest_neighbor_nchwc(const Tensor& input,
const Array<Expr>& shape,
bool align_corners = false,
std::string name = "tensor",
std::string tag = kInjective) {
Array<Expr> out_shape;
out_shape.push_back(input->shape[0]);
out_shape.push_back(input->shape[1]);
out_shape.push_back(shape[0]);
out_shape.push_back(shape[1]);
out_shape.push_back(input->shape[4]);

Expr h_ratio = shape[0] / input->shape[2];
Expr w_ratio = shape[1] / input->shape[3];

return compute(
out_shape, [&](const Array<Var>& indices) {
Array<Expr> idx;
idx.push_back(indices[0]);
idx.push_back(indices[1]);
idx.push_back(indices[2] / h_ratio);
idx.push_back(indices[3] / w_ratio);
idx.push_back(indices[4]);

return input(idx);
}, name, tag);
}

/*!
* \brief Resize given tensor to given shape using nearest neighbour
*
Expand All @@ -153,11 +192,17 @@ inline Tensor resize_nearest_neighbor(const Tensor& input,
std::string name = "tensor",
std::string tag = kInjective) {
CHECK_EQ(align_corners, false) << "Align corners not supported for nearest neighbour";

auto base_layout = layout.substr(0, 4);
if (layout == "NHWC") {
return resize_nearest_neighbor_nhwc(input, shape, align_corners);
} else {
} else if (layout == "NCHW") {
return resize_nearest_neighbor_nchw(input, shape, align_corners);
} else if (base_layout == "NCHW") {
// NCHWc
return resize_nearest_neighbor_nchwc(input, shape, align_corners);
} else {
LOG(FATAL) << "Unknown layout: " << layout;
return Tensor();
}
}

Expand Down
4 changes: 2 additions & 2 deletions topi/python/topi/nn/upsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'):
4-D with shape [batch, channel, in_height*scale, in_width*scale]
or [batch, in_height*scale, in_width*scale, channel]
"""

if layout == "NCHW":
base_layout = layout[0:4]
if base_layout == "NCHW":
out_shape = (simplify(data.shape[2] * scale), simplify(data.shape[3] * scale))
elif layout == "NHWC":
out_shape = (simplify(data.shape[1] * scale), simplify(data.shape[2] * scale))
Expand Down

0 comments on commit 6fb25d8

Please sign in to comment.