[Relay] InferCorrectLayout for strided_slice & min_num_branches option in CombineParallelConv2D #2961

Merged (3 commits) on Apr 9, 2019
9 changes: 6 additions & 3 deletions python/tvm/relay/ir_pass.py
@@ -722,20 +722,23 @@ def fuse_ops(expr, opt_level=1):
    return _ir_pass.FuseOps(expr, opt_level)


def combine_parallel_conv2d(expr):
    """Fold multiple conv2d into one.
def combine_parallel_conv2d(expr, min_num_branches=3):
    """Combine multiple conv2d into one.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    min_num_branches : int
        The minimum number of parallel branches needed for the transformation to be applied.

    Returns
    -------
    transformed_expr : tvm.relay.Expr
        Transformed expression
    """
    return _ir_pass.CombineParallelConv2D(expr)
    return _ir_pass.CombineParallelConv2D(expr, min_num_branches)


def alter_op_layout(expr):
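For context, here is a minimal usage sketch of the updated pass (not part of the diff; the two-branch graph is invented for illustration, and the pass only fires here because the default threshold of 3 is lowered to 2):

import tvm.relay as relay

# Two parallel 1x1 convolutions consuming the same input.
x = relay.var("x", shape=(1, 4, 16, 16))
w1 = relay.var("w1", shape=(8, 4, 1, 1))
w2 = relay.var("w2", shape=(8, 4, 1, 1))
y1 = relay.nn.conv2d(x, w1, channels=8, kernel_size=(1, 1))
y2 = relay.nn.conv2d(x, w2, channels=8, kernel_size=(1, 1))
y = relay.concatenate([y1, y2], axis=1)

f = relay.Function(relay.ir_pass.free_vars(y), y)
f = relay.ir_pass.infer_type(f)
# The default min_num_branches=3 would leave a two-branch group untouched,
# so lower the threshold to 2 for this small example.
f = relay.ir_pass.combine_parallel_conv2d(f, min_num_branches=2)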
61 changes: 60 additions & 1 deletion src/relay/op/tensor/transform.cc
@@ -1722,6 +1722,64 @@ bool StridedSliceRel(const Array<Type>& types,
}


Array<Array<Layout> > StridedSliceInferCorrectLayout(
    const Attrs& attrs,
    const Array<Layout>& new_in_layouts,
    const Array<Layout>& old_in_layouts,
    const Array<Array<IndexExpr>>& old_in_shapes) {
  CHECK(old_in_layouts.defined());
  CHECK_EQ(old_in_layouts.size(), 1);
  CHECK(old_in_shapes.defined());
  CHECK_EQ(old_in_shapes.size(), 1);

  auto layout = old_in_layouts[0];
  if (layout.defined() && new_in_layouts.defined()) {
    CHECK_EQ(new_in_layouts.size(), 1);
    auto new_layout = new_in_layouts[0];
    auto shape = old_in_shapes[0];

    // NOTE: Discard "const" qualifier here.
    auto *params = const_cast<StridedSliceAttrs*>(attrs.as<StridedSliceAttrs>());

    Array<Integer> new_begin, new_end;

    for (size_t i = 0; i < params->begin.size(); i++) {
      const LayoutAxis& axis = layout[i];
      if (!axis.IsPrimal()) {
        // original layout that contains splitted axes is not supported
        return {{Layout::Undef()}, {Layout::Undef()}};
      }
      auto factor = new_layout.FactorOf(axis);
      if (factor == -1) {
        new_begin.push_back(params->begin[i]);
        new_end.push_back(params->end[i]);
      } else {
        if (params->strides.defined() && i < params->strides.size()) {
          auto stride = params->strides[i];
          // arbitrary stride is not supported
          if (stride.defined() && stride->value != 1) {
            return {{Layout::Undef()}, {Layout::Undef()}};
          }
        }
        int64_t begin = params->begin[i].defined() ? params->begin[i]->value : 0;
        int64_t end = params->end[i].defined() ? params->end[i]->value :
            shape[i].as<IntImm>()->value;
        if (begin % factor || end % factor) {
          // transform to original layout
          return {{Layout::Undef()}, {Layout::Undef()}};
        }
        new_begin.push_back(tvm::Integer(begin / factor));
        new_end.push_back(tvm::Integer(end / factor));
      }
    }
    layout = new_layout;
    params->begin = new_begin;
    params->end = new_end;
  }
  return {{layout}, {layout}};
}


// Positional relay function to create StridedSlice operator used by frontend FFI.
Expr MakeStridedSlice(Expr data,
                      Array<Integer> begin,
@@ -1783,7 +1841,8 @@ Examples::
.set_attrs_type_key("relay.attrs.StridedSliceAttrs")
.add_type_rel("StridedSlice", StridedSliceRel)
.set_attr<FTVMCompute>("FTVMCompute", StridedSliceCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", StridedSliceInferCorrectLayout);


// relay.split
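To make the index arithmetic in StridedSliceInferCorrectLayout easier to follow, here is a small Python sketch (not part of the diff; the helper name and plain-list arguments are invented for illustration) of how begin/end are rescaled when the new layout splits an axis:

def rescale_slice_indices(begin, end, strides, factors, shape):
    """Mimic the begin/end rewrite above.

    factors[i] is the split factor of axis i in the new layout
    (-1 when the axis is not split); returning None stands in for
    falling back to the original layout via Layout::Undef().
    """
    new_begin, new_end = [], []
    for i, factor in enumerate(factors):
        if factor == -1:                      # axis untouched by the new layout
            new_begin.append(begin[i])
            new_end.append(end[i])
            continue
        stride = strides[i] if strides and i < len(strides) else 1
        if stride not in (None, 1):           # arbitrary strides unsupported
            return None
        b = begin[i] if begin[i] is not None else 0
        e = end[i] if end[i] is not None else shape[i]
        if b % factor or e % factor:          # slice not expressible after the split
            return None
        new_begin.append(b // factor)
        new_end.append(e // factor)
    return new_begin, new_end

# NCHW -> NCHW4c splits C by 4, so on a (1, 32, 28, 28) tensor the slice
# begin=[0, 16], end=[None, None] becomes begin=[0, 4], end=[None, 8],
# matching the expected graph in the new alter_op_layout test below.
print(rescale_slice_indices([0, 16], [None, None], None, [-1, 4], [1, 32]))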
15 changes: 12 additions & 3 deletions src/relay/pass/combine_parallel_conv2d.cc
@@ -159,17 +159,23 @@ class BranchGroupFinder : private ExprVisitor {

class ParallelConv2DCombiner {
 public:
  explicit ParallelConv2DCombiner(uint64_t min_num_branches) : min_num_branches_(min_num_branches) {
  }

  Expr Combine(const Expr& expr) {
    auto groups = BranchGroupFinder().Find(expr);
    for (const Group& group : groups) {
      if (group.size() < 2) continue;
      if (group.size() < min_num_branches_) {
        continue;
      }
      CombineBranches(group);
    }
    return ExprSubst(expr, std::move(subst_map_));
  }

 private:
  std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map_;
  uint64_t min_num_branches_;

  std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
    int64_t num_filters = 0;  // number of filters of the transformed weight
@@ -343,11 +349,14 @@ class ParallelConv2DCombiner {
  }
};

Expr CombineParallelConv2D(const Expr& expr) { return ParallelConv2DCombiner().Combine(expr); }
/*! \brief Combine parallel conv2d if number of branches >= min_num_branches */
Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) {
  return ParallelConv2DCombiner(min_num_branches).Combine(expr);
}

TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = CombineParallelConv2D(args[0]);
    *ret = CombineParallelConv2D(args[0], args[1]);
});

} // namespace relay
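As a quick sanity sketch of the new threshold (plain Python, names invented for illustration): only groups containing at least min_num_branches parallel conv2d branches are rewritten, which is why the tests below pass min_num_branches=2 for their two-branch graphs.

def should_combine(num_branches, min_num_branches=3):
    # Mirrors the guard added in ParallelConv2DCombiner::Combine above.
    return num_branches >= min_num_branches

assert not should_combine(2)                  # default threshold skips pairs
assert should_combine(2, min_num_branches=2)  # tests lower the threshold
assert should_combine(3)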
43 changes: 43 additions & 0 deletions tests/python/relay/test_pass_alter_op_layout.py
@@ -472,6 +472,48 @@ def expected():
    assert(alpha_equal(a, b))


def test_alter_layout_strided_slice():
    """Test rewriting strided_slice during alter_op_layout"""
    def before():
        x = relay.var("x", shape=(1, 32, 28, 28))
        weight = relay.var('weight', shape=(32, 32, 3, 3))
        y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
        y = relay.strided_slice(y, begin=[0, 16], end=[None, None])
        y = relay.Function(free_vars(y), y)
        return y

    @register_alter_op_layout("nn.conv2d", level=109)
    def alter_conv2d(attrs, inputs, tinfos):
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs['data_layout'] = 'NCHW4c'
        return relay.nn.conv2d(data, weight, **new_attrs)

    def expected():
        x = relay.var("x", shape=(1, 32, 28, 28))
        weight = relay.var("weight")
        x = relay.layout_transform(x, "NCHW", "NCHW4c")
        y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1),
                            data_layout="NCHW4c")
        y = relay.strided_slice(y, begin=[0, 4], end=[None, 8])
        y = relay.layout_transform(y, "NCHW4c", "NCHW")
        y = relay.Function(free_vars(y), y)
        return y

    a = before()
    a = infer_type(a)
    a = canonicalize_ops(a)
    a = infer_type(a)

    a = alter_op_layout(a)
    a = infer_type(a)

    b = expected()
    b = infer_type(b)

    assert(alpha_equal(a, b))


if __name__ == "__main__":
    test_alter_op()
    test_alter_return_none()
@@ -482,3 +524,4 @@ def expected():
    test_alter_layout_scalar()
    test_alter_layout_concatenate()
    test_alter_layout_nchw_upsamping_op()
    test_alter_layout_strided_slice()
8 changes: 4 additions & 4 deletions tests/python/relay/test_pass_combine_parallel_conv2d.py
@@ -55,7 +55,7 @@ def check(x_shape, channels1, channels2, channels3, channels4):

        y_before = before(x, w1, w2, w3, w4)
        y = relay.ir_pass.infer_type(y_before)
        y = relay.ir_pass.combine_parallel_conv2d(y)
        y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
        y = relay.ir_pass.infer_type(y)
        y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4)
        y_expected = relay.ir_pass.infer_type(y_expected)
@@ -102,7 +102,7 @@ def check(x_shape, channels1, channels2):
        bias = relay.var("bias", shape=(channels2, 1, 1))
        y_before = before(x, w1, w2, scale1, scale2, bias)
        y = relay.ir_pass.infer_type(y_before)
        y = relay.ir_pass.combine_parallel_conv2d(y)
        y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
        y = relay.ir_pass.infer_type(y)
        y_expected = expected(x, w1, w2, scale1, scale2, bias, channels1, channels2)
        y_expected = relay.ir_pass.infer_type(y_expected)
@@ -142,7 +142,7 @@ def check(x_shape, channels1, channels2):
        scale2 = relay.var("scale2", shape=(1,))
        y_before = before(x, w1, w2, scale1, scale2)
        y = relay.ir_pass.infer_type(y_before)
        y = relay.ir_pass.combine_parallel_conv2d(y)
        y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
        y = relay.ir_pass.infer_type(y)
        y_expected = expected(x, w1, w2, scale1, scale2, channels1, channels2)
        y_expected = relay.ir_pass.infer_type(y_expected)
@@ -179,7 +179,7 @@ def check(x_shape, repeat):
        w = relay.var("w", shape=(out_c, in_c, 1, 1))
        y_before = before(x, w, repeat)
        y = relay.ir_pass.infer_type(y_before)
        y = relay.ir_pass.combine_parallel_conv2d(y)
        y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
        y = relay.ir_pass.infer_type(y)
        y_expected = expected(x, w, out_c, repeat)
        y_expected = relay.ir_pass.infer_type(y_expected)