diff --git a/python/tvm/relay/qnn/op/layout_conversions.py b/python/tvm/relay/qnn/op/layout_conversions.py
index a7c90daf36a4..1a3b1771d6ce 100644
--- a/python/tvm/relay/qnn/op/layout_conversions.py
+++ b/python/tvm/relay/qnn/op/layout_conversions.py
@@ -78,3 +78,51 @@ def convert_qnn_conv2d(attrs, inputs, tinfos, desired_layouts):
         return relay.qnn.op.conv2d(*inputs, **new_attrs)
 
     raise ValueError("Layout %s is not yet supported" % desired_data_layout)
+
+
+@reg.register_convert_op_layout("qnn.conv2d_transpose")
+def convert_qnn_conv2d_transpose(attrs, inputs, tinfos, desired_layouts):
+    """Convert Layout pass registration for QNN conv2d_transpose op.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of the current transposed convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    tinfos : list of types
+        List of input and output types
+    desired_layouts : list of layout strings
+        List of layouts defining the desired layouts for the data and
+        kernel inputs respectively.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The transformed expr
+    """
+    # pylint: disable=import-outside-toplevel
+    from tvm import relay
+
+    assert (
+        len(desired_layouts) == 2
+    ), "A desired layout is expected for both of qnn.conv2d_transpose's inputs"
+    desired_data_layout, desired_kernel_layout = map(str, desired_layouts)
+    assert desired_data_layout != "default", "Data layout cannot be default"
+
+    new_attrs = dict(attrs)
+    new_attrs["data_layout"] = desired_data_layout
+
+    if desired_kernel_layout != "default":
+        new_attrs["kernel_layout"] = desired_kernel_layout
+        return relay.qnn.op.conv2d_transpose(*inputs, **new_attrs)
+
+    # Handle default kernel layouts
+    if desired_data_layout == "NCHW":
+        new_attrs["kernel_layout"] = "OIHW"
+        return relay.qnn.op.conv2d_transpose(*inputs, **new_attrs)
+    if desired_data_layout == "NHWC":
+        new_attrs["kernel_layout"] = "HWIO"
+        return relay.qnn.op.conv2d_transpose(*inputs, **new_attrs)
+
+    raise ValueError("Layout %s is not yet supported" % desired_data_layout)
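For context, the handler registered above is selected by the `"qnn.conv2d_transpose"` key in the layout map handed to `relay.transform.ConvertLayout`. A minimal sketch of the three branches it takes is below; the mapping values are illustrative examples, not requirements of the patch:

```python
from tvm import relay

# Illustrative layout maps exercising the three branches of
# convert_qnn_conv2d_transpose (example values, not part of the patch):
explicit = {"qnn.conv2d_transpose": ["NCHW", "OIHW"]}  # kernel layout given explicitly
nchw_default = {"qnn.conv2d_transpose": ["NCHW", "default"]}  # handler falls back to OIHW
nhwc_default = {"qnn.conv2d_transpose": ["NHWC", "default"]}  # handler falls back to HWIO

# Each map can be turned into a pass in the usual way:
convert = relay.transform.ConvertLayout(nchw_default)
```

Any other data layout falls through to the `ValueError` at the end of the handler.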
"float32"), + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + out_dtype="int32", + ) + y = relay.qnn.op.requantize( + y, + relay.const(1, "float32"), + relay.const(1, "int32"), + relay.const(1, "float32"), + relay.const(1, "int32"), + axis=1, + out_dtype="int32", + ) + y = relay.nn.relu(y) + y = relay.layout_transform(y, "NCHW", "NHWC") + y = relay.Function(relay.analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass(a, transform.ConvertLayout({"qnn.conv2d_transpose": ["NCHW", "default"]})) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + def test_conv_convert_kernel_layout(): """Check that convolution kernel layout is correctly transformed."""