[Quantization]: Update simulated_quantize to infer correct layout (#1…
f2013519 authored May 18, 2023
1 parent b4475b8 commit 28e9801
Showing 2 changed files with 69 additions and 1 deletion.
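In short, this commit registers an FInferCorrectLayout function for relay.op.annotation.simulated_quantize so that the ConvertLayout pass can propagate the requested data layout through simulated-quantize annotations. The sketch below is condensed from the new test added in this commit (the shapes and the NHWC/OHWI targets come from that test) and is illustrative only, not part of the change itself.

import tvm
from tvm import relay
from tvm.relay.quantize._annotate import attach_simulated_quantize, QAnnotateKind

x = relay.var("x", shape=(1, 64, 56, 56))
w = relay.var("w", shape=(64, 64, 3, 3))
y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1))
y = attach_simulated_quantize(y, QAnnotateKind.INPUT)
y = relay.nn.relu(y)

mod = tvm.IRModule.from_expr(relay.Function([x, w], y))
mod = relay.transform.InferType()(mod)
# After this change, simulated_quantize keeps the NHWC layout chosen for conv2d,
# matching the expected() graph in the new test below.
mod = relay.transform.ConvertLayout({"nn.conv2d": ["NHWC", "OHWI"]})(mod)
print(mod)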
src/relay/quantize/realize.cc (21 changes: 20 additions & 1 deletion)
@@ -34,6 +34,7 @@
#include "../op/annotation/annotation.h"
#include "../qnn/utils.h"
#include "../transforms/fold_constant.h"
#include "../transforms/infer_layout_utils.h"
#include "./quantize.h"

namespace tvm {
@@ -155,8 +156,26 @@ Expr QuantizeRealize(const Call& ref_call, const Array<Expr>& new_args, const Ob
  return QRealizeIntExpr(round_data, dom_scale, DataType::Float(32));
}

InferCorrectLayoutOutput SimQuantizeLayout(const Attrs& attrs, const Array<Layout>& new_in_layouts,
                                           const Array<Layout>& old_in_layouts,
                                           const Array<tvm::relay::Type>& old_in_types) {
  Layout ret;

  if (new_in_layouts.defined()) {
    ICHECK_GE(new_in_layouts.size(), 1);
    ret = new_in_layouts[0];
  } else {
    ICHECK_GE(old_in_layouts.size(), 1);
    ret = old_in_layouts[0];
  }
  Layout channel_layout = Layout("C");
  Array<Layout> input_layouts = {ret, channel_layout, channel_layout, channel_layout};
  return InferCorrectLayoutOutput(input_layouts, {ret}, attrs);
}

RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
    .set_attr<FForwardRewrite>("FQRealizeRewrite", QuantizeRealize);
    .set_attr<FForwardRewrite>("FQRealizeRewrite", QuantizeRealize)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", SimQuantizeLayout);

Expr Conv2dRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
  const QConfig& cfg = QConfig::Current();
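For readers more familiar with the Python side, here is an illustrative sketch of what SimQuantizeLayout does, written against the InferCorrectLayoutOutput helper that the test file below imports. It is not part of the commit, only a mirror of the C++ logic; how it would be registered (e.g. via reg.register_infer_correct_layout, as other ops do in test_pass_convert_op_layout.py) is assumed rather than shown here.

import tvm
from tvm.relay.transform.infer_layout_utils import InferCorrectLayoutOutput

def sim_quantize_layout(attrs, new_in_layouts, old_in_layouts, old_in_types):
    # Prefer the layout pushed down by ConvertLayout; otherwise keep the old data layout.
    ret = new_in_layouts[0] if new_in_layouts else old_in_layouts[0]
    # simulated_quantize has four inputs (data, dom_scale, clip_min, clip_max);
    # the three quantization parameters are tagged with the channel layout "C".
    channel = tvm.tir.layout("C")
    return InferCorrectLayoutOutput([ret, channel, channel, channel], [ret], attrs)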
tests/python/relay/test_pass_convert_op_layout.py (49 changes: 49 additions & 0 deletions)
@@ -21,6 +21,10 @@
from tvm.relay import analysis, transform
from tvm.relay.op import op as reg
from tvm.relay.op import register_alter_op_layout
from tvm.relay.quantize._annotate import (
    attach_simulated_quantize,
    QAnnotateKind,
)
from tvm.relay.transform.infer_layout_utils import InferCorrectLayoutOutput


@@ -2635,6 +2639,51 @@ def expected():
    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b)


def test_simulated_quantize_uses_specified_convert_layout():
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var("weight", shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(
            x,
            weight,
            channels=64,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
        )
        y = attach_simulated_quantize(y, QAnnotateKind.INPUT)
        y = relay.nn.relu(y)
        y = relay.Function(analysis.free_vars(y), y)
        return y

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var("weight", shape=(64, 64, 3, 3))
        x = relay.layout_transform(x, "NCHW", "NHWC")
        weight = relay.layout_transform(weight, "OIHW", "OHWI")
        y = relay.nn.conv2d(
            x,
            weight,
            channels=64,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NHWC",
            kernel_layout="OHWI",
        )
        y = attach_simulated_quantize(y, QAnnotateKind.INPUT)
        y = relay.nn.relu(y)
        y = relay.layout_transform(y, "NHWC", "NCHW")
        y = relay.Function(analysis.free_vars(y), y)
        return y

    a = before()
    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NHWC", "OHWI"]}))
    b = run_opt_pass(expected(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b)


@pytest.mark.parametrize(
"data_layout, kernel_layout",
[
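As a usage note, the new test can be run on its own with pytest -k simulated_quantize against tests/python/relay/test_pass_convert_op_layout.py; its structural-equality check confirms that simulated_quantize stays inside the NHWC region produced for conv2d, with the layout_transform back to NCHW placed after the relu.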
