Skip to content

Commit

Permalink
Fix InferCorrectLayout for dynamic upsampling and add a regression test (#6712)
Browse files Browse the repository at this point in the history

* add a regression test

* fix dyn upsampling infer layout

* fix lint
  • Loading branch information
Matthew Brookhart authored Oct 21, 2020
1 parent 9ae386c commit f65e320
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/relay/op/dyn/nn/upsampling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@
* \brief upsampling operator
*/

#include "../../nn/upsampling.h"
#include "upsampling.h"

#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/tir/data_layout.h>

#include <utility>
#include <vector>

#include "../../op_common.h"
Expand All @@ -48,7 +49,6 @@ bool UpSamplingRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
if (scale_h == nullptr) return false;
if (scale_w == nullptr) return false;

CHECK_EQ(data->shape.size(), 4);
CHECK_EQ(scale_h->shape.size(), 0);
CHECK_EQ(scale_w->shape.size(), 0);
static const Layout kNCHW("NCHW");
Expand Down
69 changes: 69 additions & 0 deletions src/relay/op/dyn/nn/upsampling.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
*
* \file src/relay/op/dyn/nn/upsampling.h
* \brief implementation of the InferCorrectLayout pass for dynamic upsampling
*/

#ifndef TVM_RELAY_OP_DYN_NN_UPSAMPLING_H_
#define TVM_RELAY_OP_DYN_NN_UPSAMPLING_H_

#include <tvm/relay/attrs/nn.h>
#include <tvm/tir/data_layout.h>

#include "../../op_common.h"

namespace tvm {
namespace relay {
namespace dyn {

template <typename T>
Array<Array<Layout> > UpsamplingInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
// NOTE: Discard "const" qualifier here.
T* params = const_cast<T*>(attrs.as<T>());
if (new_in_layouts.defined()) {
CHECK_GT(new_in_layouts.size(), 0);

Layout raw_layout(params->layout);
Layout input = new_in_layouts[0];
if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
!input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h')) &&
(input.IndexOf(LayoutAxis::Get('D')) == -1 ||
(input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) &&
!input.Contains(LayoutAxis::Get('d'))))) {
params->layout = input.name(); // modify self to follow the input layout
}
}

Layout inferred_layout(params->layout);
Layout param_layout("NCHW");
return Array<Array<Layout> >{{inferred_layout, param_layout, param_layout}, {inferred_layout}};
}

} // namespace dyn
} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_OP_DYN_NN_UPSAMPLING_H_
39 changes: 39 additions & 0 deletions tests/python/relay/test_pass_alter_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,45 @@ def expected():
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


def test_alter_layout_nchw_dyn_upsamping_op():
"""Test upsamping operators """

def before():
x = relay.var("x", shape=(1, 32, 28, 28))
weight = relay.var("weight", shape=(32, 32, 3, 3))
y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
y = relay.nn.upsampling(y, scale_h=relay.const(2), scale_w=relay.const(2))
y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2))
y = relay.Function(analysis.free_vars(y), y)
return y

def alter_conv2d(attrs, inputs, tinfos, out_type):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs["data_layout"] = "NCHW16c"
return relay.nn.conv2d(data, weight, **new_attrs)

def expected():
x = relay.var("x", shape=(1, 32, 28, 28))
weight = relay.var("weight")
x = relay.layout_transform(x, "NCHW", "NCHW16c")
y = relay.nn.conv2d(
x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
)
y = relay.nn.upsampling(y, scale_h=relay.const(2), scale_w=relay.const(2), layout="NCHW16c")
y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2), layout="NCHW16c")
y = relay.layout_transform(y, "NCHW16c", "NCHW")
y = relay.Function(analysis.free_vars(y), y)
return y

with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = before()
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())

assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


@tvm.testing.uses_gpu
def test_alter_layout_strided_slice():
"""Test rewriting strided_slice during alter_iop_layout"""
Expand Down

0 comments on commit f65e320

Please sign in to comment.