[relay][quantize] Register FInferCorrectLayout for simulated_quantize.

ConvertLayout previously could not propagate layouts through graphs carrying
quantize annotations, because relay.op.annotation.simulated_quantize had no
layout-inference function. SimQuantizeLayout keeps the data input in the
incoming layout (new layout if one was assigned upstream, otherwise the old
one), pins the three per-channel scalar params (dom_scale, clip_min, clip_max)
to the 1-D "C" layout, and mirrors the data layout on the output. A regression
test checks that ConvertLayout({"nn.conv2d": ["NHWC", "OHWI"]}) transforms a
conv2d + simulated_quantize + relu graph as expected.

NOTE(review): this patch was reconstructed from a garbled extraction that had
stripped all angle-bracket template arguments and newlines; template arguments
and line structure restored, and the first Python hunk's header recounted to
match its surviving context lines.

diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc
index 3c2f6eb96d6b..514be1f0a71c 100644
--- a/src/relay/quantize/realize.cc
+++ b/src/relay/quantize/realize.cc
@@ -34,6 +34,7 @@
 #include "../op/annotation/annotation.h"
 #include "../qnn/utils.h"
 #include "../transforms/fold_constant.h"
+#include "../transforms/infer_layout_utils.h"
 #include "./quantize.h"
 
 namespace tvm {
@@ -155,8 +156,26 @@ Expr QuantizeRealize(const Call& ref_call, const Array<Expr>& new_args, const Ob
   return QRealizeIntExpr(round_data, dom_scale, DataType::Float(32));
 }
 
+InferCorrectLayoutOutput SimQuantizeLayout(const Attrs& attrs, const Array<Layout>& new_in_layouts,
+                                           const Array<Layout>& old_in_layouts,
+                                           const Array<tvm::relay::Type>& old_in_types) {
+  Layout ret;
+
+  if (new_in_layouts.defined()) {
+    ICHECK_GE(new_in_layouts.size(), 1);
+    ret = new_in_layouts[0];
+  } else {
+    ICHECK_GE(old_in_layouts.size(), 1);
+    ret = old_in_layouts[0];
+  }
+  Layout channel_layout = Layout("C");
+  Array<Layout> input_layouts = {ret, channel_layout, channel_layout, channel_layout};
+  return InferCorrectLayoutOutput(input_layouts, {ret}, attrs);
+}
+
 RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
-    .set_attr<FQRealizeRewrite>("FQRealizeRewrite", QuantizeRealize);
+    .set_attr<FQRealizeRewrite>("FQRealizeRewrite", QuantizeRealize)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", SimQuantizeLayout);
 
 Expr Conv2dRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
   const QConfig& cfg = QConfig::Current();
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index 72d0232100dc..c3d579186d4a 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -21,4 +21,8 @@
 from tvm.relay import analysis, transform
 from tvm.relay.op import op as reg
 from tvm.relay.op import register_alter_op_layout
+from tvm.relay.quantize._annotate import (
+    attach_simulated_quantize,
+    QAnnotateKind,
+)
 from tvm.relay.transform.infer_layout_utils import InferCorrectLayoutOutput
@@ -2635,6 +2639,51 @@ def expected():
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b)
 
 
+def test_simulated_quantize_uses_specified_convert_layout():
+    def before():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var("weight", shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+        )
+        y = attach_simulated_quantize(y, QAnnotateKind.INPUT)
+        y = relay.nn.relu(y)
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var("weight", shape=(64, 64, 3, 3))
+        x = relay.layout_transform(x, "NCHW", "NHWC")
+        weight = relay.layout_transform(weight, "OIHW", "OHWI")
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="OHWI",
+        )
+        y = attach_simulated_quantize(y, QAnnotateKind.INPUT)
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, "NHWC", "NCHW")
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NHWC", "OHWI"]}))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b)
+
+
 @pytest.mark.parametrize(
     "data_layout, kernel_layout",
     [