Skip to content

Commit

Permalink
Add min_num_branches option to CombineParallelConv2D
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Apr 9, 2019
1 parent fff9bc9 commit 4810130
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 10 deletions.
9 changes: 6 additions & 3 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,20 +722,23 @@ def fuse_ops(expr, opt_level=1):
return _ir_pass.FuseOps(expr, opt_level)


def combine_parallel_conv2d(expr, min_num_branches=3):
    """Combine multiple parallel conv2d operators into one.

    Convolutions that share the same input and compatible attributes are
    fused into a single, wider convolution whose output is then sliced per
    original branch. Groups smaller than ``min_num_branches`` are left
    untouched, since folding too few branches may not pay off.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    min_num_branches : int
        The minimum number of parallel branches required for the
        transformation to be applied. Defaults to 3 to preserve the
        previous behavior (groups of size <= 2 were skipped).

    Returns
    -------
    transformed_expr : tvm.relay.Expr
        Transformed expression
    """
    return _ir_pass.CombineParallelConv2D(expr, min_num_branches)


def alter_op_layout(expr):
Expand Down
13 changes: 10 additions & 3 deletions src/relay/pass/combine_parallel_conv2d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,13 @@ class BranchGroupFinder : private ExprVisitor {

class ParallelConv2DCombiner {
public:
explicit ParallelConv2DCombiner(uint64_t min_num_branches) : min_num_branches_(min_num_branches) {
}

Expr Combine(const Expr& expr) {
auto groups = BranchGroupFinder().Find(expr);
for (const Group& group : groups) {
if (group.size() <= 2) {
if (group.size() < min_num_branches_) {
continue;
}
CombineBranches(group);
Expand All @@ -172,6 +175,7 @@ class ParallelConv2DCombiner {

private:
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map_;
uint64_t min_num_branches_;

std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
int64_t num_filters = 0; // number of filters of the transformed weight
Expand Down Expand Up @@ -345,11 +349,14 @@ class ParallelConv2DCombiner {
}
};

/*!
 * \brief Combine parallel conv2d ops that share the same input into one
 *        conv2d, for each group whose number of branches >= min_num_branches.
 * \param expr The input expression.
 * \param min_num_branches Minimum number of parallel branches a group must
 *        have for the combination to be applied; smaller groups are skipped.
 * \return The transformed expression.
 */
Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) {
  return ParallelConv2DCombiner(min_num_branches).Combine(expr);
}

// Expose the pass to the Python frontend.
// args[0]: the input Relay expression; args[1]: min_num_branches (uint64).
TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = CombineParallelConv2D(args[0], args[1]);
  });

} // namespace relay
Expand Down
8 changes: 4 additions & 4 deletions tests/python/relay/test_pass_combine_parallel_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def check(x_shape, channels1, channels2, channels3, channels4):

y_before = before(x, w1, w2, w3, w4)
y = relay.ir_pass.infer_type(y_before)
y = relay.ir_pass.combine_parallel_conv2d(y)
y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
y = relay.ir_pass.infer_type(y)
y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4)
y_expected = relay.ir_pass.infer_type(y_expected)
Expand Down Expand Up @@ -102,7 +102,7 @@ def check(x_shape, channels1, channels2):
bias = relay.var("bias", shape=(channels2, 1, 1))
y_before = before(x, w1, w2, scale1, scale2, bias)
y = relay.ir_pass.infer_type(y_before)
y = relay.ir_pass.combine_parallel_conv2d(y)
y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
y = relay.ir_pass.infer_type(y)
y_expected = expected(x, w1, w2, scale1, scale2, bias, channels1, channels2)
y_expected = relay.ir_pass.infer_type(y_expected)
Expand Down Expand Up @@ -142,7 +142,7 @@ def check(x_shape, channels1, channels2):
scale2 = relay.var("scale2", shape=(1,))
y_before = before(x, w1, w2, scale1, scale2)
y = relay.ir_pass.infer_type(y_before)
y = relay.ir_pass.combine_parallel_conv2d(y)
y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
y = relay.ir_pass.infer_type(y)
y_expected = expected(x, w1, w2, scale1, scale2, channels1, channels2)
y_expected = relay.ir_pass.infer_type(y_expected)
Expand Down Expand Up @@ -179,7 +179,7 @@ def check(x_shape, repeat):
w = relay.var("w", shape=(out_c, in_c, 1, 1))
y_before = before(x, w, repeat)
y = relay.ir_pass.infer_type(y_before)
y = relay.ir_pass.combine_parallel_conv2d(y)
y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
y = relay.ir_pass.infer_type(y)
y_expected = expected(x, w, out_c, repeat)
y_expected = relay.ir_pass.infer_type(y_expected)
Expand Down

0 comments on commit 4810130

Please sign in to comment.