From 0bce86f10a44f1f2082dca5bf7a15b5cf07a1389 Mon Sep 17 00:00:00 2001
From: Ivy Zhang
Date: Tue, 5 Jul 2022 15:41:25 +0800
Subject: [PATCH] [BYOC-DNNL] Rewrite downsize blocks for resnetv1 to get
 better performance (#11822)

* rewrite downsize blocks for resnetv1 to get better performance

* fix lint
---
 python/tvm/relay/op/contrib/dnnl.py | 179 ++++++++++++++++++++++++++++
 tests/python/contrib/test_dnnl.py   | 100 ++++++++++++++++
 2 files changed, 279 insertions(+)

diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index c251b66bfbc7..b3ef478f201d 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -782,6 +782,185 @@ def rewrite_dense_bias_gelu_reshape_last(mod):
     return mod


+class ResNetV1Rewrite(DFPatternCallback):
+    """
+    A callback that advances the downsize (stride-2) operation in a ResNetV1 block:
+    subgraphs matching Pattern #1 are rewritten into Pattern #2.
+    Pattern #1:
+    %26 = nn.conv2d(%25, ty=Tensor[(64, 256, 1, 1)]);
+    %27 = add(%26, ty=Tensor[(64, 1, 1)]);
+    %28 = nn.relu(%27);
+
+    %29 = nn.conv2d(%28, ty=Tensor[(64, 64, 3, 3)]);
+    %30 = add(%29, ty=Tensor[(64, 1, 1)]);
+    %31 = nn.relu(%30);
+
+    %32 = nn.conv2d(%31, ty=Tensor[(256, 64, 1, 1)]);
+    %33 = add(%32, ty=Tensor[(256, 1, 1)]);
+    %34 = add(%33, %25);
+    %35 = nn.relu(%34);
+
+    %36 = nn.conv2d(%35, ty=Tensor[(128, 256, 1, 1)], strides=[2, 2]);
+    %37 = add(%36, ty=Tensor[(128, 1, 1)]);
+    %38 = nn.relu(%37);
+
+    %39 = nn.conv2d(%38, ty=Tensor[(128, 128, 3, 3)]);
+    %40 = add(%39, ty=Tensor[(128, 1, 1)]);
+    %41 = nn.relu(%40);
+
+    %42 = nn.conv2d(%41, ty=Tensor[(512, 128, 1, 1)]);
+    %43 = nn.conv2d(%35, ty=Tensor[(512, 256, 1, 1)], strides=[2, 2]);
+    %44 = add(%42, ty=Tensor[(512, 1, 1)]);
+    %45 = add(%43, ty=Tensor[(512, 1, 1)]);
+
+    %46 = add(%44, %45);
+    %47 = nn.relu(%46);
+    Pattern #2:
+    %26 = nn.conv2d(%25, ty=Tensor[(64, 256, 1, 1)]);
+    %27 = add(%26, ty=Tensor[(64, 1, 1)]);
+    %28 = nn.relu(%27);
+
+    %29 = nn.conv2d(%28, ty=Tensor[(64, 64, 3, 3)], strides=[2, 2]);
+    %30 = add(%29, ty=Tensor[(64, 1, 1)]);
+    %31 = nn.relu(%30);
+
+    %32 = nn.conv2d(%31, ty=Tensor[(256, 64, 1, 1)]);
+    %33 = add(%32, ty=Tensor[(256, 1, 1)]);
+    %34 = nn.max_pool2d(%25, pool_size=[1, 1], strides=[2, 2], padding=[0, 0, 0, 0]);
+    %35 = add(%33, %34);
+    %36 = nn.relu(%35);
+
+    %37 = nn.conv2d(%36, ty=Tensor[(128, 256, 1, 1)]);
+    %38 = add(%37, ty=Tensor[(128, 1, 1)]);
+    %39 = nn.relu(%38);
+
+    %40 = nn.conv2d(%39, ty=Tensor[(128, 128, 3, 3)]);
+    %41 = add(%40, ty=Tensor[(128, 1, 1)]);
+    %42 = nn.relu(%41);
+
+    %43 = nn.conv2d(%42, ty=Tensor[(512, 128, 1, 1)]);
+    %44 = nn.conv2d(%36, ty=Tensor[(512, 256, 1, 1)]);
+    %45 = add(%43, ty=Tensor[(512, 1, 1)]);
+    %46 = add(%44, ty=Tensor[(512, 1, 1)]);
+    %47 = add(%45, %46);
+    %48 = nn.relu(%47);
+    """
+
+    def __init__(self):
+        super(ResNetV1Rewrite, self).__init__()
+        self.attr_lst = []
+        self.data = wildcard()
+        self.w1, self.b1 = wildcard(), wildcard()
+        self.w2, self.b2 = wildcard(), wildcard()
+        self.w3, self.b3 = wildcard(), wildcard()
+        self.w4, self.b4 = wildcard(), wildcard()
+        self.w5, self.b5 = wildcard(), wildcard()
+        self.w6, self.b6 = wildcard(), wildcard()
+        self.w7, self.b7 = wildcard(), wildcard()
+
+        conv1 = is_op("nn.conv2d")(self.data, self.w1).has_attr({"kernel_size": [1, 1]})
+        conv1 = is_op("add")(conv1, self.b1)
+        conv1 = is_op("nn.relu")(conv1)
+
+        conv2 = is_op("nn.conv2d")(conv1, self.w2).has_attr({"kernel_size": [3, 3]})
+        conv2 = is_op("add")(conv2, self.b2)
+        conv2 = is_op("nn.relu")(conv2)
+
+        conv3 = is_op("nn.conv2d")(conv2, self.w3).has_attr({"kernel_size": [1, 1]})
+        conv3 = is_op("add")(conv3, self.b3)
+        conv3 = is_op("add")(conv3, self.data)
+        conv3 = is_op("nn.relu")(conv3)
+
+        left_conv4 = is_op("nn.conv2d")(conv3, self.w4).has_attr({"strides": [2, 2]})
+        left_conv4 = is_op("add")(left_conv4, self.b4)
+        left_conv4 = is_op("nn.relu")(left_conv4)
+
+        left_conv5 = is_op("nn.conv2d")(left_conv4, self.w5).has_attr({"kernel_size": [3, 3]})
+        left_conv5 = is_op("add")(left_conv5, self.b5)
+        left_conv5 = is_op("nn.relu")(left_conv5)
+
+        left_conv6 = is_op("nn.conv2d")(left_conv5, self.w6).has_attr({"kernel_size": [1, 1]})
+        left_conv6 = is_op("add")(left_conv6, self.b6)
+
+        right_conv7 = is_op("nn.conv2d")(conv3, self.w7).has_attr({"strides": [2, 2]})
+        right_conv7 = is_op("add")(right_conv7, self.b7)
+
+        out = is_op("add")(left_conv6, right_conv7)
+        out = is_op("nn.relu")(out)
+        self.pattern = out
+
+    def get_attr(self, pre):
+        """Recursively collect the attributes of the conv2d operators in the matched subgraph."""
+
+        def visit_func(expr):
+            if isinstance(expr, _expr.Call) and expr.op == relay.op.get("nn.conv2d"):
+                self.attr_lst.append(expr.attrs)
+
+        _analysis.post_order_visit(pre, visit_func)
+
+    def callback(self, pre, post, node_map):
+        self.get_attr(pre)
+        data = node_map[self.data][0]
+        w1, b1 = node_map[self.w1][0], node_map[self.b1][0]
+        w2, b2 = node_map[self.w2][0], node_map[self.b2][0]
+        w3, b3 = node_map[self.w3][0], node_map[self.b3][0]
+        w4, b4 = node_map[self.w4][0], node_map[self.b4][0]
+        w5, b5 = node_map[self.w5][0], node_map[self.b5][0]
+        w6, b6 = node_map[self.w6][0], node_map[self.b6][0]
+        w7, b7 = node_map[self.w7][0], node_map[self.b7][0]
+
+        new_attrs = self.attr_lst[-7]
+        conv1 = relay.op.nn.conv2d(data, w1, **new_attrs)
+        conv1 = relay.op.add(conv1, b1)
+        conv1 = relay.op.nn.relu(conv1)
+
+        new_attrs = dict(self.attr_lst[-6])
+        new_attrs["strides"] = [2, 2]
+        conv2 = relay.op.nn.conv2d(conv1, w2, **new_attrs)
+        conv2 = relay.op.add(conv2, b2)
+        conv2 = relay.op.nn.relu(conv2)
+
+        new_attrs = self.attr_lst[-5]
+        conv3 = relay.op.nn.conv2d(conv2, w3, **new_attrs)
+        conv3 = relay.op.add(conv3, b3)
+        max_pool = relay.op.nn.max_pool2d(
+            data, pool_size=(1, 1), strides=(2, 2), layout=new_attrs["data_layout"]
+        )
+        conv3 = relay.op.add(conv3, max_pool)
+        conv3 = relay.op.nn.relu(conv3)
+
+        new_attrs = dict(self.attr_lst[-4])
+        new_attrs["strides"] = [1, 1]
+        left_conv4 = relay.op.nn.conv2d(conv3, w4, **new_attrs)
+        left_conv4 = relay.op.add(left_conv4, b4)
+        left_conv4 = relay.op.nn.relu(left_conv4)
+
+        new_attrs = self.attr_lst[-3]
+        left_conv5 = relay.op.nn.conv2d(left_conv4, w5, **new_attrs)
+        left_conv5 = relay.op.add(left_conv5, b5)
+        left_conv5 = relay.op.nn.relu(left_conv5)
+
+        new_attrs = self.attr_lst[-2]
+        left_conv6 = relay.op.nn.conv2d(left_conv5, w6, **new_attrs)
+        left_conv6 = relay.op.add(left_conv6, b6)
+
+        new_attrs = dict(self.attr_lst[-1])
+        new_attrs["strides"] = [1, 1]
+        right_conv7 = relay.op.nn.conv2d(conv3, w7, **new_attrs)
+        right_conv7 = relay.op.add(right_conv7, b7)
+
+        out = relay.op.add(left_conv6, right_conv7)
+        out = relay.op.nn.relu(out)
+        self.attr_lst = []
+        return out
+
+
+def rewrite_resnetv1(mod):
+    """Rewrite the ResNetV1 downsize block to reduce the computational complexity."""
+    mod["main"] = rewrite(ResNetV1Rewrite(), mod["main"])
+    return mod
+
+
 class LegalizeQnnOpForDnnl(DFPatternCallback):
     """Legalize QNN based patterns to match DNNL
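Note: the following is a minimal usage sketch, not part of the patch, showing how the new rewrite_resnetv1 pass could be applied to a Relay module ahead of DNNL partitioning. The CanonicalizeOps/SimplifyInference/FoldConstant passes are an assumed precondition so that bias_add and batch_norm are lowered to the conv2d + add + relu form the pattern above matches; prepare_for_dnnl is a hypothetical helper name.

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib import dnnl

    def prepare_for_dnnl(mod):
        # Assumed preprocessing: lower bias_add/batch_norm and fold constants so the
        # graph takes the conv2d + add + relu shape that ResNetV1Rewrite expects.
        seq = tvm.transform.Sequential(
            [
                relay.transform.CanonicalizeOps(),
                relay.transform.InferType(),
                relay.transform.SimplifyInference(),
                relay.transform.FoldConstant(),
            ]
        )
        mod = seq(mod)
        # Advance the stride-2 downsize in ResNetV1 blocks (function added by this patch).
        mod = dnnl.rewrite_resnetv1(mod)
        return mod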
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index 2138eda08697..078483798c6d 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -1128,6 +1128,106 @@ def get_graph(act=None):
     )


+def test_resnetv1_rewrite(run_module, dtype="float32"):
+    def get_graph():
+        data_shape = (1, 256, 56, 56)
+        w_shapes = [
+            (64, 256, 1, 1),
+            (64, 64, 3, 3),
+            (256, 64, 1, 1),
+            (128, 256, 1, 1),
+            (128, 128, 3, 3),
+            (512, 128, 1, 1),
+            (512, 256, 1, 1),
+        ]
+        x = relay.var("x", shape=data_shape, dtype=dtype)
+        weights = [relay.const(np.random.randint(0, 1, w).astype(dtype)) for w in w_shapes]
+        biases = [relay.const(np.random.randint(0, 1, w[0]).astype(dtype)) for w in w_shapes]
+
+        conv1 = relay.nn.conv2d(
+            x,
+            weights[0],
+            channels=w_shapes[0][0],
+            kernel_size=w_shapes[0][2:4],
+            padding=(w_shapes[0][2] // 2, w_shapes[0][3] // 2),
+        )
+        conv1 = relay.nn.bias_add(conv1, biases[0])
+        conv1 = relay.nn.relu(conv1)
+
+        conv2 = relay.nn.conv2d(
+            conv1,
+            weights[1],
+            channels=w_shapes[1][0],
+            kernel_size=w_shapes[1][2:4],
+            padding=(w_shapes[1][2] // 2, w_shapes[1][3] // 2),
+        )
+        conv2 = relay.nn.bias_add(conv2, biases[1])
+        conv2 = relay.nn.relu(conv2)
+
+        conv3 = relay.nn.conv2d(
+            conv2,
+            weights[2],
+            channels=w_shapes[2][0],
+            kernel_size=w_shapes[2][2:4],
+            padding=(w_shapes[2][2] // 2, w_shapes[2][3] // 2),
+        )
+        conv3 = relay.nn.bias_add(conv3, biases[2])
+        conv3 = relay.add(conv3, x)
+        conv3 = relay.nn.relu(conv3)
+
+        left_conv4 = relay.nn.conv2d(
+            conv3,
+            weights[3],
+            channels=w_shapes[3][0],
+            strides=(2, 2),
+            kernel_size=w_shapes[3][2:4],
+            padding=(w_shapes[3][2] // 2, w_shapes[3][3] // 2),
+        )
+        left_conv4 = relay.nn.bias_add(left_conv4, biases[3])
+        left_conv4 = relay.nn.relu(left_conv4)
+
+        left_conv5 = relay.nn.conv2d(
+            left_conv4,
+            weights[4],
+            channels=w_shapes[4][0],
+            kernel_size=w_shapes[4][2:4],
+            padding=(w_shapes[4][2] // 2, w_shapes[4][3] // 2),
+        )
+        left_conv5 = relay.nn.bias_add(left_conv5, biases[4])
+        left_conv5 = relay.nn.relu(left_conv5)
+
+        left_conv6 = relay.nn.conv2d(
+            left_conv5,
+            weights[5],
+            channels=w_shapes[5][0],
+            kernel_size=w_shapes[5][2:4],
+            padding=(w_shapes[5][2] // 2, w_shapes[5][3] // 2),
+        )
+        left_conv6 = relay.nn.bias_add(left_conv6, biases[5])
+
+        right_conv7 = relay.nn.conv2d(
+            conv3,
+            weights[6],
+            channels=w_shapes[6][0],
+            strides=(2, 2),
+            kernel_size=w_shapes[6][2:4],
+            padding=(w_shapes[6][2] // 2, w_shapes[6][3] // 2),
+        )
+        right_conv7 = relay.nn.bias_add(right_conv7, biases[6])
+
+        out = relay.add(left_conv6, right_conv7)
+        out = relay.nn.relu(out)
+
+        dic = {"x": data_shape}
+        param_lst = []
+        return out, dic, param_lst
+
+    net, dic, param_lst = get_graph()
+    net = tvm.IRModule.from_expr(net)
+    config = net, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
 def permute_shape(shape, l_from="", l_to=""):
     res_shape = []
     for label in l_to:
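For reference, a small verification sketch (illustrative only, not part of the patch): the rewrite moves the downsize onto an nn.max_pool2d on the shortcut path, so counting that op before and after running the pass over the graph built by get_graph() above is one quick way to confirm the pattern matched. count_op and the commented usage lines are hypothetical, not existing test utilities.

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib import dnnl

    def count_op(mod, op_name):
        # Count how many calls to `op_name` appear in the module's main function.
        count = 0

        def fvisit(expr):
            nonlocal count
            if isinstance(expr, relay.expr.Call) and expr.op == relay.op.get(op_name):
                count += 1

        relay.analysis.post_order_visit(mod["main"], fvisit)
        return count

    # net, dic, param_lst = get_graph()             # graph from the new test above
    # mod = tvm.IRModule.from_expr(net)
    # mod = relay.transform.InferType()(mod)
    # mod = relay.transform.CanonicalizeOps()(mod)  # lower bias_add to add so the pattern can match
    # assert count_op(mod, "nn.max_pool2d") == 0
    # mod = dnnl.rewrite_resnetv1(mod)              # pass added by this patch
    # assert count_op(mod, "nn.max_pool2d") == 1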