[BYOC-DNNL] Rewrite downsize blocks for ResNetV1 to get better performance #11822

Merged
merged 2 commits on Jul 5, 2022
179 changes: 179 additions & 0 deletions python/tvm/relay/op/contrib/dnnl.py
@@ -782,6 +782,185 @@ def rewrite_dense_bias_gelu_reshape_last(mod):
    return mod


class ResNetV1Rewrite(DFPatternCallback):
    """
    A callback that advances the downsize (stride-2) operation in ResNetV1 blocks:
    subgraphs matching Pattern #1 below are rewritten into Pattern #2.
    Pattern #1:
    %26 = nn.conv2d(%25, ty=Tensor[(64, 256, 1, 1)]);
    %27 = add(%26, ty=Tensor[(64, 1, 1)]);
    %28 = nn.relu(%27);

    %29 = nn.conv2d(%28, ty=Tensor[(64, 64, 3, 3)]);
    %30 = add(%29, ty=Tensor[(64, 1, 1)]);
    %31 = nn.relu(%30);

    %32 = nn.conv2d(%31, ty=Tensor[(256, 64, 1, 1)]);
    %33 = add(%32, ty=Tensor[(256, 1, 1)]);
    %34 = add(%33, %25);
    %35 = nn.relu(%34);

    %36 = nn.conv2d(%35, ty=Tensor[(128, 256, 1, 1)], strides=[2, 2]);
    %37 = add(%36, ty=Tensor[(128, 1, 1)]);
    %38 = nn.relu(%37);

    %39 = nn.conv2d(%38, ty=Tensor[(128, 128, 3, 3)]);
    %40 = add(%39, ty=Tensor[(128, 1, 1)]);
    %41 = nn.relu(%40);

    %42 = nn.conv2d(%41, ty=Tensor[(512, 128, 1, 1)]);
    %43 = nn.conv2d(%35, ty=Tensor[(512, 256, 1, 1)], strides=[2, 2]);
    %44 = add(%42, ty=Tensor[(512, 1, 1)]);
    %45 = add(%43, ty=Tensor[(512, 1, 1)]);

    %46 = add(%44, %45);
    %47 = nn.relu(%46);
    Pattern #2:
    %26 = nn.conv2d(%25, ty=Tensor[(64, 256, 1, 1)]);
    %27 = add(%26, ty=Tensor[(64, 1, 1)]);
    %28 = nn.relu(%27);

    %29 = nn.conv2d(%28, ty=Tensor[(64, 64, 3, 3)], strides=[2, 2]);
    %30 = add(%29, ty=Tensor[(64, 1, 1)]);
    %31 = nn.relu(%30);

    %32 = nn.conv2d(%31, ty=Tensor[(256, 64, 1, 1)]);
    %33 = add(%32, ty=Tensor[(256, 1, 1)]);
    %34 = nn.max_pool2d(%25, pool_size=[1, 1], strides=[2, 2], padding=[0, 0, 0, 0]);
    %35 = add(%33, %34);
    %36 = nn.relu(%35);

    %37 = nn.conv2d(%36, ty=Tensor[(128, 256, 1, 1)]);
    %38 = add(%37, ty=Tensor[(128, 1, 1)]);
    %39 = nn.relu(%38);

    %40 = nn.conv2d(%39, ty=Tensor[(128, 128, 3, 3)]);
    %41 = add(%40, ty=Tensor[(128, 1, 1)]);
    %42 = nn.relu(%41);

    %43 = nn.conv2d(%42, ty=Tensor[(512, 128, 1, 1)]);
    %44 = nn.conv2d(%36, ty=Tensor[(512, 256, 1, 1)]);
    %45 = add(%43, ty=Tensor[(512, 1, 1)]);
    %46 = add(%44, ty=Tensor[(512, 1, 1)]);
    %47 = add(%45, %46);
    %48 = nn.relu(%47);
    """

    def __init__(self):
        super(ResNetV1Rewrite, self).__init__()
        self.attr_lst = []
        self.data = wildcard()
        self.w1, self.b1 = wildcard(), wildcard()
        self.w2, self.b2 = wildcard(), wildcard()
        self.w3, self.b3 = wildcard(), wildcard()
        self.w4, self.b4 = wildcard(), wildcard()
        self.w5, self.b5 = wildcard(), wildcard()
        self.w6, self.b6 = wildcard(), wildcard()
        self.w7, self.b7 = wildcard(), wildcard()

        conv1 = is_op("nn.conv2d")(self.data, self.w1).has_attr({"kernel_size": [1, 1]})
        conv1 = is_op("add")(conv1, self.b1)
        conv1 = is_op("nn.relu")(conv1)

        conv2 = is_op("nn.conv2d")(conv1, self.w2).has_attr({"kernel_size": [3, 3]})
        conv2 = is_op("add")(conv2, self.b2)
        conv2 = is_op("nn.relu")(conv2)

        conv3 = is_op("nn.conv2d")(conv2, self.w3).has_attr({"kernel_size": [1, 1]})
        conv3 = is_op("add")(conv3, self.b3)
        conv3 = is_op("add")(conv3, self.data)
        conv3 = is_op("nn.relu")(conv3)

        left_conv4 = is_op("nn.conv2d")(conv3, self.w4).has_attr({"strides": [2, 2]})
        left_conv4 = is_op("add")(left_conv4, self.b4)
        left_conv4 = is_op("nn.relu")(left_conv4)

        left_conv5 = is_op("nn.conv2d")(left_conv4, self.w5).has_attr({"kernel_size": [3, 3]})
        left_conv5 = is_op("add")(left_conv5, self.b5)
        left_conv5 = is_op("nn.relu")(left_conv5)

        left_conv6 = is_op("nn.conv2d")(left_conv5, self.w6).has_attr({"kernel_size": [1, 1]})
        left_conv6 = is_op("add")(left_conv6, self.b6)

        right_conv7 = is_op("nn.conv2d")(conv3, self.w7).has_attr({"strides": [2, 2]})
        right_conv7 = is_op("add")(right_conv7, self.b7)

        out = is_op("add")(left_conv6, right_conv7)
        out = is_op("nn.relu")(out)
        self.pattern = out

    def get_attr(self, pre):
        """Recursively collect the attributes of all nn.conv2d ops in the matched subgraph."""

        def visit_func(expr):
            if isinstance(expr, _expr.Call) and expr.op == relay.op.get("nn.conv2d"):
                self.attr_lst.append(expr.attrs)

        _analysis.post_order_visit(pre, visit_func)

    def callback(self, pre, post, node_map):
        self.get_attr(pre)
        data = node_map[self.data][0]
        w1, b1 = node_map[self.w1][0], node_map[self.b1][0]
        w2, b2 = node_map[self.w2][0], node_map[self.b2][0]
        w3, b3 = node_map[self.w3][0], node_map[self.b3][0]
        w4, b4 = node_map[self.w4][0], node_map[self.b4][0]
        w5, b5 = node_map[self.w5][0], node_map[self.b5][0]
        w6, b6 = node_map[self.w6][0], node_map[self.b6][0]
        w7, b7 = node_map[self.w7][0], node_map[self.b7][0]

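        # post_order_visit collects the nn.conv2d attributes of the matched subgraph in
        # topological order, so attr_lst[-7] .. attr_lst[-1] correspond to conv1 .. right_conv7
        # of the pattern above (attr_lst is cleared again at the end of this callback).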
        new_attrs = self.attr_lst[-7]
        conv1 = relay.op.nn.conv2d(data, w1, **new_attrs)
        conv1 = relay.op.add(conv1, b1)
        conv1 = relay.op.nn.relu(conv1)

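        # Advance the downsample: the 3x3 conv of the preceding identity block takes
        # stride 2 here (Pattern #2), replacing the stride-2 1x1 convs of Pattern #1.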
        new_attrs = dict(self.attr_lst[-6])
        new_attrs["strides"] = [2, 2]
        conv2 = relay.op.nn.conv2d(conv1, w2, **new_attrs)
        conv2 = relay.op.add(conv2, b2)
        conv2 = relay.op.nn.relu(conv2)

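        # The identity branch now only needs every other spatial position, so it is
        # subsampled with a 1x1 max_pool of stride 2 before the residual add.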
        new_attrs = self.attr_lst[-5]
        conv3 = relay.op.nn.conv2d(conv2, w3, **new_attrs)
        conv3 = relay.op.add(conv3, b3)
        max_pool = relay.op.nn.max_pool2d(
            data, pool_size=(1, 1), strides=(2, 2), layout=new_attrs["data_layout"]
        )
        conv3 = relay.op.add(conv3, max_pool)
        conv3 = relay.op.nn.relu(conv3)

        new_attrs = dict(self.attr_lst[-4])
        new_attrs["strides"] = [1, 1]
        left_conv4 = relay.op.nn.conv2d(conv3, w4, **new_attrs)
        left_conv4 = relay.op.add(left_conv4, b4)
        left_conv4 = relay.op.nn.relu(left_conv4)

        new_attrs = self.attr_lst[-3]
        left_conv5 = relay.op.nn.conv2d(left_conv4, w5, **new_attrs)
        left_conv5 = relay.op.add(left_conv5, b5)
        left_conv5 = relay.op.nn.relu(left_conv5)

        new_attrs = self.attr_lst[-2]
        left_conv6 = relay.op.nn.conv2d(left_conv5, w6, **new_attrs)
        left_conv6 = relay.op.add(left_conv6, b6)

        new_attrs = dict(self.attr_lst[-1])
        new_attrs["strides"] = [1, 1]
        right_conv7 = relay.op.nn.conv2d(conv3, w7, **new_attrs)
        right_conv7 = relay.op.add(right_conv7, b7)

        out = relay.op.add(left_conv6, right_conv7)
        out = relay.op.nn.relu(out)
        self.attr_lst = []
        return out


def rewrite_resnetv1(mod):
    """Rewrite the ResNetV1 downsize block to reduce the computational complexity."""
    mod["main"] = rewrite(ResNetV1Rewrite(), mod["main"])
    return mod


class LegalizeQnnOpForDnnl(DFPatternCallback):
"""Legalize QNN based patterns to match DNNL

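Note: rewrite_resnetv1 is a standalone helper in tvm.relay.op.contrib.dnnl, so it can also be applied to a Relay module on its own; whether partition_for_dnnl invokes it automatically is not visible in this diff. A minimal usage sketch under that assumption (`mod` is any Relay IRModule whose ResNetV1 blocks are already in the conv2d/add/relu form of Pattern #1, e.g. after SimplifyInference and FoldConstant):

    from tvm import relay
    from tvm.relay.op.contrib.dnnl import rewrite_resnetv1

    mod = relay.transform.InferType()(mod)
    mod = rewrite_resnetv1(mod)  # downsize blocks matching Pattern #1 become Pattern #2
    # ... then continue with the usual BYOC-DNNL partitioning and compilation flow.
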
100 changes: 100 additions & 0 deletions tests/python/contrib/test_dnnl.py
@@ -1128,6 +1128,106 @@ def get_graph(act=None):
    )


def test_resnetv1_rewrite(run_module, dtype="float32"):
    def get_graph():
        data_shape = (1, 256, 56, 56)
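        # Weight shapes of the seven convolutions in the ResNetV1 downsize block,
        # in the same order as conv1 .. right_conv7 of ResNetV1Rewrite's Pattern #1.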
        w_shapes = [
            (64, 256, 1, 1),
            (64, 64, 3, 3),
            (256, 64, 1, 1),
            (128, 256, 1, 1),
            (128, 128, 3, 3),
            (512, 128, 1, 1),
            (512, 256, 1, 1),
        ]
        x = relay.var("x", shape=data_shape, dtype=dtype)
        weights = [relay.const(np.random.randint(0, 1, w).astype(dtype)) for w in w_shapes]
        biases = [relay.const(np.random.randint(0, 1, w[0]).astype(dtype)) for w in w_shapes]

        conv1 = relay.nn.conv2d(
            x,
            weights[0],
            channels=w_shapes[0][0],
            kernel_size=w_shapes[0][2:4],
            padding=(w_shapes[0][2] // 2, w_shapes[0][3] // 2),
        )
        conv1 = relay.nn.bias_add(conv1, biases[0])
        conv1 = relay.nn.relu(conv1)

        conv2 = relay.nn.conv2d(
            conv1,
            weights[1],
            channels=w_shapes[1][0],
            kernel_size=w_shapes[1][2:4],
            padding=(w_shapes[1][2] // 2, w_shapes[1][3] // 2),
        )
        conv2 = relay.nn.bias_add(conv2, biases[1])
        conv2 = relay.nn.relu(conv2)

        conv3 = relay.nn.conv2d(
            conv2,
            weights[2],
            channels=w_shapes[2][0],
            kernel_size=w_shapes[2][2:4],
            padding=(w_shapes[2][2] // 2, w_shapes[2][3] // 2),
        )
        conv3 = relay.nn.bias_add(conv3, biases[2])
        conv3 = relay.add(conv3, x)
        conv3 = relay.nn.relu(conv3)

        left_conv4 = relay.nn.conv2d(
            conv3,
            weights[3],
            channels=w_shapes[3][0],
            strides=(2, 2),
            kernel_size=w_shapes[3][2:4],
            padding=(w_shapes[3][2] // 2, w_shapes[3][3] // 2),
        )
        left_conv4 = relay.nn.bias_add(left_conv4, biases[3])
        left_conv4 = relay.nn.relu(left_conv4)

        left_conv5 = relay.nn.conv2d(
            left_conv4,
            weights[4],
            channels=w_shapes[4][0],
            kernel_size=w_shapes[4][2:4],
            padding=(w_shapes[4][2] // 2, w_shapes[4][3] // 2),
        )
        left_conv5 = relay.nn.bias_add(left_conv5, biases[4])
        left_conv5 = relay.nn.relu(left_conv5)

        left_conv6 = relay.nn.conv2d(
            left_conv5,
            weights[5],
            channels=w_shapes[5][0],
            kernel_size=w_shapes[5][2:4],
            padding=(w_shapes[5][2] // 2, w_shapes[5][3] // 2),
        )
        left_conv6 = relay.nn.bias_add(left_conv6, biases[5])

        right_conv7 = relay.nn.conv2d(
            conv3,
            weights[6],
            channels=w_shapes[6][0],
            strides=(2, 2),
            kernel_size=w_shapes[6][2:4],
            padding=(w_shapes[6][2] // 2, w_shapes[6][3] // 2),
        )
        right_conv7 = relay.nn.bias_add(right_conv7, biases[6])

        out = relay.add(left_conv6, right_conv7)
        out = relay.nn.relu(out)

        dic = {"x": data_shape}
        param_lst = []
        return out, dic, param_lst

    net, dic, param_lst = get_graph()
    net = tvm.IRModule.from_expr(net)
    config = net, dic, param_lst
    run_and_verify_func(config, run_module=run_module, dtype=dtype)


def permute_shape(shape, l_from="", l_to=""):
    res_shape = []
    for label in l_to:
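
For reference, a rough way to check outside of run_and_verify_func that the rewrite fires on the graph built above; the count_ops helper and the canonicalization step are illustrative assumptions, not part of this PR (the pattern matches plain add, so bias_add is lowered first):

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.dnnl import rewrite_resnetv1

    def count_ops(func, op_name):
        """Hypothetical helper: count calls to `op_name` inside a Relay function."""
        count = [0]

        def visit(expr):
            if isinstance(expr, relay.Call) and expr.op == relay.op.get(op_name):
                count[0] += 1

        relay.analysis.post_order_visit(func, visit)
        return count[0]

    # `net` is the expression built exactly as in get_graph() above.
    mod = tvm.IRModule.from_expr(net)
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.CanonicalizeOps()(mod)  # nn.bias_add -> broadcast add
    mod = rewrite_resnetv1(mod)
    print("nn.max_pool2d:", count_ops(mod["main"], "nn.max_pool2d"))  # 1 if the pattern matched
    print("nn.conv2d:", count_ops(mod["main"], "nn.conv2d"))  # still 7 convolutions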