[NNVM] Enhance operator fusion for more element wise patterns (#1548)
masahi authored and tqchen committed Aug 8, 2018
1 parent 0241fdc commit 1ed28ae
Showing 4 changed files with 158 additions and 14 deletions.
97 changes: 97 additions & 0 deletions nnvm/src/compiler/graph_fuse.cc
@@ -161,6 +161,103 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
      }
    }
  }

  /*
   The above algorithm will not fuse a node whose output is fed to more than one
   child node. This is because, in general, it does not make sense to fuse multiple
   child branches with their parent, as in the following example.

            conv2d
            /  |  \
           /   |   \
         op    op   op
          |    |    |
          |    |    |

   However, when all child branches meet at a certain node, there is a possibility for
   further operator fusion. For example, all nodes in the following subgraph can be fused
   into a single node, if the three 'in-between' nodes and the bottom node are all
   element-wise operations.

            conv2d
            /  |  \
           /   |   \
         op    op   op
          \    |    /
           \   |   /
          elemwise add
               |

   This pattern is not uncommon. For example, it arises when a conv2d op is followed by an
   exponential linear unit (ELU). If bias add and batch normalization are also present, they
   can be fused as well.

   In fact, the above fusion algorithm already fuses the three in-between nodes and the
   element-wise add node in the figure above. The code below then fuses the conv2d node with
   the already fused children. The following patterns are supported:

   * Any number of child nodes from the top node
   * The path from the top node to the bottom node can contain any number of element-wise ops.

   The only restriction is that in-between nodes cannot have more than one child.

   The overview of the algorithm below is as follows:

   1. Check whether all child nodes have been fused into a single op by the existing fusion
      algorithm
   2. Fuse the parent node with its children, and update its group id to the children's group id
   3. If the parent node originally belongs to another group (for example, conv + batch norm),
      propagate the new group id to its grandparent and upward
  */
  if (opt_level >= 1) {
    std::vector<std::vector<uint32_t> > children_group_ids(idx.num_nodes());
    std::vector<std::vector<uint32_t> > node_ids_per_group(idx.num_nodes());
    for (uint32_t nid = idx.num_nodes() - 1; nid != 0; --nid) {
      const auto& inode = idx[nid];
      if (inode.source->is_variable()) continue;
      CHECK_NE(group_vec[nid], -1);
      node_ids_per_group[group_vec[nid]].push_back(nid);
      if (inode.inputs.size() != 1) continue;
      const uint32_t parent_nid = inode.inputs[0].node_id;
      // if parent node has more than one child, record each child's group id.
      if (ref_count[parent_nid] > 1) children_group_ids[parent_nid].push_back(group_vec[nid]);
    }

    std::vector<int> new_group_id(idx.num_nodes(), -1);
    for (uint32_t nid = idx.num_nodes() - 1; nid != 0; --nid) {
      if (new_group_id[group_vec[nid]] != -1) {
        // propagate new group id from child
        group_vec[nid] = new_group_id[group_vec[nid]];
      }
      TOpPattern pt = op_pattern.get(idx[nid].source->op(), kOpaque);
      if (pt == kOpaque) continue;
      const auto& group_ids = children_group_ids[nid];
      if (group_ids.size() <= 1) continue;
      const uint32_t child_group_id = group_ids[0];
      const auto& children_node_ids = node_ids_per_group[child_group_id];

      auto is_same_group_id = [child_group_id](uint32_t id) {
        return id == child_group_id;
      };
      auto is_fusible_pattern = [&idx](uint32_t child_nid) {
        TOpPattern child_pt = op_pattern.get(idx[child_nid].source->op(), kOpaque);
        return child_pt <= kBroadcast;
      };
      // fuse this node with children if
      // all children belong to the same group and
      // all nodes in the group are element wise or broadcast op.
      const bool can_be_fused = std::all_of(group_ids.begin(), group_ids.end(), is_same_group_id) &&
                                std::all_of(children_node_ids.begin(), children_node_ids.end(), is_fusible_pattern);

      if (can_be_fused) {
        new_group_id[group_vec[nid]] = child_group_id;
        group_vec[nid] = child_group_id;
        for (uint32_t nid2 : node_ids_per_group[child_group_id]) {
          pattern_vec[nid2] = pattern_vec[nid];
          master_vec[nid2] = master_vec[nid];
        }
      }
    }
  }

  g.attrs["group_root"] = std::make_shared<any>(std::move(group_vec));
  g.attrs["group_master"] = std::make_shared<any>(std::move(master_vec));
  g.attrs["pattern"] = std::make_shared<any>(std::move(pattern_vec));
44 changes: 43 additions & 1 deletion nnvm/tests/python/compiler/test_op_fusion.py
@@ -5,7 +5,7 @@
from tvm.contrib import graph_runtime
from nnvm import symbol as sym
from nnvm.compiler import graph_util, graph_attr
-from nnvm.testing import ctx_list
+from nnvm.testing import ctx_list, utils

def test_ewise_injective():
    x = sym.Variable("x")
@@ -77,7 +77,49 @@ def test_injective_reduce_injective():
        np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)


def build_and_run(sym, params, data, out_shape, target, ctx, opt_level=2):
    with nnvm.compiler.build_config(opt_level=opt_level):
        graph, lib, params = nnvm.compiler.build(sym, target, shape={"data":data.shape}, params=params)
    module = graph_runtime.create(graph, lib, ctx)
    module.set_input(**params)
    module.set_input("data", data)
    module.run()
    out = module.get_output(0, tvm.nd.empty(out_shape))
    return out.asnumpy(), graph


def test_fuse_conv2d_elu():
    def elu(data):
        return -0.5 * sym.relu(1 - sym.exp(data)) + sym.relu(data)

    def get_sym(out_channel):
        data = sym.Variable(name="data")
        data = sym.conv2d(data=data, kernel_size=(3,3), channels=out_channel, padding=(1, 1),
                          layout="NCHW", kernel_layout="OIHW", use_bias=True)
        data = sym.batch_norm(data)
        data = elu(data)
        return data

    in_channel = 8
    out_channel = 16
    size = 64
    dshape = (1, in_channel, size, size)
    oshape = (1, out_channel, size, size)
    data = np.random.uniform(-1, 1, dshape).astype(np.float32)

    for target, ctx in ctx_list():
        sym1 = get_sym(out_channel)
        sym2 = get_sym(out_channel)
        _, params1 = utils.create_workload(sym1, 1, dshape[1:], seed=0)
        _, params2 = utils.create_workload(sym2, 1, dshape[1:], seed=0)
        output1, g1 = build_and_run(sym1, params1, data, oshape, target, ctx, opt_level=2)
        output2, g2 = build_and_run(sym2, params2, data, oshape, target, ctx, opt_level=0)
        np.testing.assert_allclose(output1, output2, rtol=1e-5, atol=1e-5)
        # data, conv weight, bias, batch norm gamma, batch norm beta, conv op
        assert g1.index.num_nodes == 6

if __name__ == "__main__":
    test_injective_reduce_injective()
    test_ewise_injective()
    test_conv_ewise_injective()
    test_fuse_conv2d_elu()
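
As a quick, hypothetical sanity check beyond the assertions above, the fused and unfused graphs can also be inspected directly; the names below refer to the objects created in test_fuse_conv2d_elu and are not part of the commit.

# Hypothetical inspection snippet, reusing g1 (opt_level=2) and g2 (opt_level=0)
# from the test above.
print(g1.ir())   # fused graph: conv2d, bias add, batch norm and ELU collapse together
print(g2.ir())   # unfused graph for comparison
print(g1.index.num_nodes, g2.index.num_nodes)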
5 changes: 1 addition & 4 deletions topi/python/topi/arm_cpu/conv2d.py
@@ -39,11 +39,10 @@ def decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype):
def schedule_conv2d_nchw_arm_cpu(cfg, outs):
    """TOPI schedule callback"""
    s = tvm.create_schedule([x.op for x in outs])
-   scheduled_ops = []

    def _callback(op):
        # schedule conv2d
-       if 'spatial_conv_output' in op.tag and op not in scheduled_ops:
+       if 'spatial_conv_output' in op.tag:
            output = op.output(0)
            conv = op.input_tensors[0]

@@ -65,8 +64,6 @@ def _callback(op):
            output = op.output(0)
            _schedule_winograd(cfg, s, output, outs[0])

-       scheduled_ops.append(op)

    traverse_inline(s, outs[0].op, _callback)
    return s

26 changes: 17 additions & 9 deletions topi/python/topi/util.py
@@ -5,25 +5,33 @@

from . import tag

-def traverse_inline(s, op, callback):
+def traverse_inline(s, final_op, callback):
    """Traverse computation graph and do auto inline

    Parameters
    ----------
    s: schedule
        The schedule
-   op: Operation
+   final_op: Operation
        The final output operator.
    callback: callable
        The callback function on each op
    """
-   if tag.is_injective(op.tag):
-       if op not in s.outputs:
-           s[op].compute_inline()
-       for tensor in op.input_tensors:
-           if tensor.op.input_tensors:
-               traverse_inline(s, tensor.op, callback)
-   callback(op)
+   visited = set()
+
+   def _traverse(op):
+       if op in visited:
+           return
+       visited.add(op)
+       if tag.is_injective(op.tag):
+           if op not in s.outputs:
+               s[op].compute_inline()
+           for tensor in op.input_tensors:
+               if tensor.op.input_tensors:
+                   _traverse(tensor.op)
+       callback(op)
+
+   _traverse(final_op)


def prod(x):
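
A minimal usage sketch (not from this commit) of how a TOPI schedule drives the reworked traverse_inline; the 'spatial_conv_output' tag and the schedule body are placeholders mirroring the arm_cpu conv2d callback above.

# Minimal sketch under the assumptions stated above; not part of the commit.
import tvm
from topi.util import traverse_inline

def schedule_example(outs):
    s = tvm.create_schedule([x.op for x in outs])

    def _callback(op):
        # traverse_inline now tracks visited ops itself, so the callback no
        # longer needs its own 'scheduled_ops' bookkeeping.
        if 'spatial_conv_output' in op.tag:
            output = op.output(0)
            # ... apply the real schedule to 'output' here ...

    traverse_inline(s, outs[0].op, _callback)
    return s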
