From 6d69ff397e5d0c694065eaf8867d7f2308175b04 Mon Sep 17 00:00:00 2001
From: AndrewZhaoLuo
Date: Wed, 9 Feb 2022 10:54:54 -0800
Subject: [PATCH] [QNN] Lookup operations for hard to implement operators (#10053)

* initial tanh impl
* smalls error
* support uint and int lookup into tables
* reinterpret cast, working tanh tests
* refactor relay func creation
* basic casting tests
* explicitly say do not handle multi-channel lookups
* add example funcs
* fix silent fail
* fix some bugs with floating point funcs not working
* add TODO
* add tood
* canonicalizations
* refactor integer lookup ops into own folder
* fq2i stuff
* clean up existing tests
* flesh out todo
* more tests
* test on keeping shape good
* lookup table fix
* replace canonicalization for rsqrt
* remove canonicalization of rsqrt
* add asf headers
* topi tests
* gather supports unsigned integer tests
* fix things
* move to legalization
* jostle ci
* linting
* use take instead of gather
* remove gather changes
* undo changes
* undo changes
* undo changes
* move thing in range
* lint
* remove unneeded line
* jostle

Co-authored-by: andrewzhaoluo (generated by with_the_same_user script)
---
 python/tvm/relay/qnn/op/__init__.py           |   4 +-
 python/tvm/relay/qnn/op/canonicalizations.py  | 160 ++++++++++++
 python/tvm/relay/qnn/op/legalizations.py      |  20 +-
 python/tvm/relay/qnn/op/op.py                 |  25 +-
 .../transform/fake_quantization_to_integer.py |   4 +
 src/relay/op/tensor/transform.cc              |   3 +-
 src/relay/qnn/op/rsqrt.cc                     |  42 +---
 .../relay/qnn/test_canonicalizations.py       | 231 ++++++++++++++++++
 tests/python/relay/test_op_level3.py          |   9 +-
 tests/python/relay/test_op_qnn_rsqrt.py       |   4 +-
 .../test_pass_fake_quantization_to_integer.py |  11 +-
 .../python/topi/python/test_topi_transform.py |  17 +-
 12 files changed, 461 insertions(+), 69 deletions(-)
 create mode 100644 python/tvm/relay/qnn/op/canonicalizations.py
 create mode 100644 tests/python/relay/qnn/test_canonicalizations.py

diff --git a/python/tvm/relay/qnn/op/__init__.py b/python/tvm/relay/qnn/op/__init__.py
index 848409360a9d..745050e286e8 100644
--- a/python/tvm/relay/qnn/op/__init__.py
+++ b/python/tvm/relay/qnn/op/__init__.py
@@ -18,5 +18,5 @@
 """QNN dialect related operators."""
 from __future__ import absolute_import as _abs
 from .qnn import *
-from .op import register_qnn_legalize
-from . import _qnn, legalizations, layout_conversions
+from .op import register_qnn_legalize, register_qnn_canonicalize
+from . import _qnn, legalizations, layout_conversions, canonicalizations
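Throughout the new module, quantized and real values relate by the usual affine scheme, real ~= (q - zero_point) * scale. A minimal NumPy sketch of both directions, for intuition only (the helper names here are illustrative, not part of the patch):

import numpy as np

def quantize(real, scale, zero_point, dtype="int8"):
    info = np.iinfo(dtype)
    q = np.round(real / scale) + zero_point
    return np.clip(q, info.min, info.max).astype(dtype)

def dequantize(q, scale, zero_point):
    return (q.astype("int32") - zero_point) * scale

x = np.array([-2.0, 0.0, 1.5], dtype="float32")
q = quantize(x, scale=0.015, zero_point=0)  # -> [-128, 0, 100] (clipped at -128)
print(dequantize(q, 0.015, 0))              # -> ~[-1.92, 0.0, 1.5]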
diff --git a/python/tvm/relay/qnn/op/canonicalizations.py b/python/tvm/relay/qnn/op/canonicalizations.py
new file mode 100644
index 000000000000..95e0cb60368d
--- /dev/null
+++ b/python/tvm/relay/qnn/op/canonicalizations.py
@@ -0,0 +1,160 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Consists of utilities and methods for lowering QNN into mainline Relay."""
+from typing import Callable
+
+import numpy as np
+import tvm
+from tvm import relay
+
+
+def run_const_expr(expr: "relay.Expr") -> np.ndarray:
+    """Evaluate a const expression, returning the result as a NumPy array."""
+    mod = tvm.IRModule.from_expr(expr)
+    vm_exe = relay.create_executor("vm", mod=mod)
+    return vm_exe.evaluate()().asnumpy()
+
+
+def create_integer_lookup_table(
+    floating_point_func: Callable[[np.ndarray], np.ndarray],
+    input_scale: "relay.Expr",
+    input_zero_point: "relay.Expr",
+    output_scale: "relay.Expr",
+    output_zero_point: "relay.Expr",
+    in_axis: int = -1,
+    out_axis: int = -1,
+    in_dtype: str = "uint8",
+    out_dtype: str = "uint8",
+) -> np.ndarray:
+    """
+    Return a table where each quantized input value indexes to the quantized output
+    approximating the given function.
+
+    Note this also supports mapping unsigned and signed integers to each other.
+
+    Args:
+      floating_point_func: The numpy function which this table is to approximate
+      input_scale: The scale of the quantized input tensor.
+      input_zero_point: The zero point of the quantized input tensor.
+      output_scale: The scale of the quantized output tensor.
+      output_zero_point: The zero point of the quantized output tensor.
+      in_axis: The axis for multi-channel quantization of the input, if applicable.
+      out_axis: The axis for multi-channel quantization of the output, if applicable.
+      in_dtype: The dtype of the input tensor.
+      out_dtype: The wanted dtype of the output tensor.
+
+    Returns:
+      A numpy array where values in quantized space will index to the output in quantized space
+      approximating the given function.
+    """
+    if not np.issubdtype(np.dtype(in_dtype), np.integer) or not np.issubdtype(
+        np.dtype(out_dtype), np.integer
+    ):
+        raise ValueError(
+            f"Only integer dtypes allowed, got {in_dtype} and {out_dtype} for in and out dtypes."
+        )
+
+    dtype_info = np.iinfo(in_dtype)
+    num_bits = dtype_info.bits
+
+    # Use TVM's quantization methods via Relay to be consistent
+    # inputs_quantized = np.array(range(dtype_info.min, dtype_info.max + 1)).astype(in_dtype)
+
+    # First generate a list of all num_bit integer patterns
+    inputs_quantized = np.array(range(0, 2 ** num_bits), dtype=f"uint{num_bits}")
+
+    # Reinterpret bits as the real datatype. Note: what we are doing here is a bit
+    # tricky; the canonical view of our lookup table uses the uintX form. When we
+    # run the lookup in the relay graph, we cast the bit pattern back into this form.
+    inputs_quantized = inputs_quantized.view(in_dtype)
+    inputs_quantized = relay.const(inputs_quantized, dtype=in_dtype)
+    inputs_dequantized = run_const_expr(
+        relay.qnn.op.dequantize(
+            inputs_quantized,
+            input_scale=input_scale,
+            input_zero_point=input_zero_point,
+            axis=in_axis,
+        )
+    )
+
+    output_dequantized = relay.const(floating_point_func(inputs_dequantized))
+    output_quantized = run_const_expr(
+        relay.qnn.op.quantize(
+            output_dequantized, output_scale, output_zero_point, out_axis, out_dtype
+        )
+    )
+
+    return output_quantized
+
+
+def create_integer_lookup_op(
+    input_arg: "relay.Expr",
+    floating_point_func: Callable[[np.ndarray], np.ndarray],
+    in_scale: "relay.Expr",
+    in_zero_point: "relay.Expr",
+    out_scale: "relay.Expr",
+    out_zero_point: "relay.Expr",
+    in_axis: int = -1,
+    out_axis: int = -1,
+    in_dtype: str = "uint8",
+    out_dtype: str = "uint8",
+) -> "relay.Expr":
+    """
+    Create a quantized version of the given floating point unary operation using table lookup.
+
+    Args:
+      input_arg: The quantized input to the final function.
+      floating_point_func: The numpy function which this table is to approximate
+      in_scale: The scale of the quantized input tensor.
+      in_zero_point: The zero point of the quantized input tensor.
+      out_scale: The scale of the quantized output tensor.
+      out_zero_point: The zero point of the quantized output tensor.
+      in_axis: The axis for multi-channel quantization of the input, if applicable.
+      out_axis: The axis for multi-channel quantization of the output, if applicable.
+      in_dtype: The dtype of the input tensor.
+      out_dtype: The wanted dtype of the output tensor.
+
+    Returns:
+      A Relay expression representing a quantized version of the given function.
+    """
+    # TODO: handle multi-channel quantization; the code below will fail with
+    # multi-channel qparams
+    in_scale = in_scale.data.numpy().item()
+    in_zero_point = in_zero_point.data.numpy().item()
+    out_scale = out_scale.data.numpy().item()
+    out_zero_point = out_zero_point.data.numpy().item()
+
+    lookup_table = create_integer_lookup_table(
+        floating_point_func,
+        relay.const(in_scale),
+        relay.const(in_zero_point, dtype="int32"),
+        relay.const(out_scale),
+        relay.const(out_zero_point, dtype="int32"),
+        in_axis=in_axis,
+        in_dtype=in_dtype,
+        out_axis=out_axis,
+        out_dtype=out_dtype,
+    )
+
+    in_dtype_info = np.iinfo(in_dtype)
+    in_dtype_num_bits = in_dtype_info.bits
+
+    lookup_table = relay.const(lookup_table)
+    index_tensor = relay.reinterpret(input_arg, f"uint{in_dtype_num_bits}")
+    result = relay.take(lookup_table, index_tensor, axis=0, mode="fast")
+    return result
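For intuition, the table construction above amounts to the following NumPy-only sketch. It skips the Relay round trip, so rounding may differ in the last ulp from TVM's quantize/dequantize, and it assumes scalar qparams (the function name is illustrative, not part of the patch):

import numpy as np

def lookup_table_numpy(f, in_scale, in_zp, out_scale, out_zp,
                       in_dtype="uint8", out_dtype="uint8"):
    bits = np.iinfo(in_dtype).bits
    # Enumerate every bit pattern, then reinterpret as the input dtype,
    # mirroring the canonical uintX view described above.
    patterns = np.arange(2 ** bits, dtype=f"uint{bits}").view(in_dtype)
    real = f((patterns.astype("float64") - in_zp) * in_scale)
    info = np.iinfo(out_dtype)
    return np.clip(np.round(real / out_scale) + out_zp, info.min, info.max).astype(out_dtype)

# table[x.view("uint8")] then approximates tanh directly in int8 space.
table = lookup_table_numpy(np.tanh, 1 / 64, 0, 1 / 128, 0, "int8", "int8")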
diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py
index 52fe6c8ebe2f..fd835d72fc09 100644
--- a/python/tvm/relay/qnn/op/legalizations.py
+++ b/python/tvm/relay/qnn/op/legalizations.py
@@ -17,12 +17,13 @@
 # pylint: disable=invalid-name, unused-argument
 """Backend QNN related feature registration"""
 import numpy as np
-
 import tvm
 from tvm import relay
 from tvm._ffi.base import TVMError
-from .. import op as reg
+from tvm.relay.qnn.op.canonicalizations import create_integer_lookup_op
+
 from ....topi.x86.utils import target_has_sse42
+from .. import op as reg
 
 #################################################
 # Register the functions for different operators.
@@ -46,6 +47,21 @@ def legalize_qnn_dense(attrs, inputs, types):
     return qnn_dense_legalize(attrs, inputs, types)
 
 
+# Registering QNN rsqrt legalization function.
+@reg.register_qnn_legalize("qnn.rsqrt")
+def legalize_qnn_rsqrt(attrs, inputs, types):
+    return create_integer_lookup_op(
+        input_arg=inputs[0],
+        floating_point_func=lambda arr: 1 / np.sqrt(arr),
+        in_scale=inputs[1],
+        in_zero_point=inputs[2],
+        out_scale=inputs[3],
+        out_zero_point=inputs[4],
+        in_dtype=types[0].dtype,
+        out_dtype=types[0].dtype,
+    )
+
+
 # Default to None. If overridden by target, this will not be run.
 # Generic QNN Conv2D legalization function.
 @tvm.target.generic_func
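The same pattern extends to any elementwise op with a known float reference. A hedged sketch for a hypothetical "qnn.sigmoid" — the op name and its availability are assumptions; this patch only registers qnn.rsqrt, and the sketch presumes an op with qnn.rsqrt's five-argument signature:

import numpy as np
from tvm.relay.qnn.op.op import register_qnn_legalize
from tvm.relay.qnn.op.canonicalizations import create_integer_lookup_op

# Hypothetical: assumes a "qnn.sigmoid" op exists with the
# (data, scale, zero_point, output_scale, output_zero_point) signature.
@register_qnn_legalize("qnn.sigmoid")
def legalize_qnn_sigmoid(attrs, inputs, types):
    return create_integer_lookup_op(
        input_arg=inputs[0],
        floating_point_func=lambda arr: 1 / (1 + np.exp(-arr)),  # float reference
        in_scale=inputs[1],
        in_zero_point=inputs[2],
        out_scale=inputs[3],
        out_zero_point=inputs[4],
        in_dtype=types[0].dtype,
        out_dtype=types[0].dtype,
    )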
diff --git a/python/tvm/relay/qnn/op/op.py b/python/tvm/relay/qnn/op/op.py
index 32a61229951c..335947b9f7ce 100644
--- a/python/tvm/relay/qnn/op/op.py
+++ b/python/tvm/relay/qnn/op/op.py
@@ -20,7 +20,10 @@ def register_qnn_legalize(op_name, legal_op=None, level=10):
-    """Register legal transformation function for a QNN op
+    """Register legal transformation function for a QNN op.
+
+    This helps QNN match hardware intrinsics better and is run before
+    canonicalization.
 
     Parameters
     ----------
@@ -34,3 +37,23 @@ def register_qnn_legalize(op_name, legal_op=None, level=10):
         The priority level
     """
     return tvm.ir.register_op_attr(op_name, "FTVMQnnLegalize", legal_op, level)
+
+
+def register_qnn_canonicalize(op_name, legal_op=None, level=10):
+    """Register canonicalization function for a QNN op.
+
+    This transforms QNN ops to mainline Relay components.
+
+    Parameters
+    ----------
+    op_name : str
+        The name of the operator
+
+    legal_op: function (Attrs, List[Expr], List[relay.Type]) -> Expr
+        The function for transforming an expr to another expr.
+
+    level : int
+        The priority level
+    """
+    return tvm.ir.register_op_attr(op_name, "FTVMQnnCanonicalize", legal_op, level)
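The two hooks share the registration machinery and differ only in the op attribute they populate: FTVMQnnLegalize runs first (hardware-specific rewrites), FTVMQnnCanonicalize runs last (lowering to mainline Relay). A minimal sketch of the new hook, again for a hypothetical op name not registered by this patch:

from tvm import relay
from tvm.relay.qnn.op.op import register_qnn_canonicalize

# Hypothetical op name; anything registered this way is rewritten into
# mainline Relay when relay.qnn.transform.CanonicalizeOps() runs.
@register_qnn_canonicalize("qnn.my_unary_op")
def canonicalize_my_unary_op(attrs, args, arg_types):
    # args: data, input_scale, input_zero_point, output_scale, output_zero_point
    dequantized = relay.qnn.op.dequantize(args[0], args[1], args[2])
    out = relay.exp(dequantized)  # stand-in float implementation
    return relay.qnn.op.quantize(out, args[3], args[4], out_dtype=arg_types[0].dtype)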
diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py
index 823f85fcb2e9..e84ba5557a70 100644
--- a/python/tvm/relay/transform/fake_quantization_to_integer.py
+++ b/python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -19,6 +19,10 @@
 import tvm
 from tvm import relay
 from tvm.ir import TensorAffineType, TupleAffineType
+
+# import to register canonicalization funcs for fq2i
+# pylint: disable=unused-import
+from tvm.relay.qnn.op import canonicalizations
 from tvm.tir import bijective_layout
 
 from ..op import register_fake_quantization_to_integer
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index ff506f684911..0407bd1c681b 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3376,8 +3376,7 @@ bool GatherRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
         << "Gather: expect indices type to be TensorType but get " << types[1];
     return false;
   }
-  ICHECK(indices->dtype.is_int() || indices->dtype.is_uint())
-      << "indices of gather must be tensor of integer";
+  ICHECK(indices->dtype.is_int()) << "indices of gather must be tensor of integer";
   const auto param = attrs.as<GatherAttrs>();
   ICHECK(param != nullptr);
   ICHECK(param->axis.defined());
diff --git a/src/relay/qnn/op/rsqrt.cc b/src/relay/qnn/op/rsqrt.cc
index 55814dff422b..93baa308a796 100644
--- a/src/relay/qnn/op/rsqrt.cc
+++ b/src/relay/qnn/op/rsqrt.cc
@@ -69,42 +69,9 @@ Expr MakeQuantizedRsqrt(Expr x, Expr scale, Expr zero_point, Expr output_scale,
   return Call(op, {x, scale, zero_point, output_scale, output_zero_point}, Attrs(), {});
 }
 
-/*
- * \brief Canonicalizes the QNN rsqrt op.
- * \param attrs The empty attribute.
- * \param new_args The new mutated args to the call node.
- * \param arg_types The types of input and output.
- * \return The sequence of Relay ops for add op.
- */
-Expr QnnRsqrtCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
-                          const Array<tvm::relay::Type>& arg_types) {
-  // At this time, due to the complexity of implementing this op in int8 or uint8,
-  // we dequantize the input, run the op in float, and then quantize the output (as below).
-  // This acts as a placeholder for future hardware enablement, where more hardware specific
-  // canonicalization can be provided.
-
-  // Get the args.
-  QnnUnaryOpArguments args(new_args);
-
-  // Get the input dtype and shape.
-  QnnUnaryOpTensorType input_type(arg_types, 0);
-
-  // Get the types for dequantize/quantize.
-  Array<tvm::relay::Type> types;
-  for (size_t i = 1; i < 5; ++i) {
-    types.push_back(arg_types[i]);
-  }
-
-  // Dequantize input.
-  auto dequantized_arg = Dequantize(args.x, args.scale, args.zero_point, types, -1);
-
-  // Compute Rsqrt(Q_x')
-  auto output = Rsqrt(dequantized_arg);
-
-  // Quantize output.
-  return Quantize(output, args.output_scale, args.output_zero_point, input_type.dtype, types, -1);
-}
-
+// Translation to Relay is done via canonicalization/legalization functions in Python,
+// e.g. python/tvm/relay/qnn/op/canonicalizations.py or
+// python/tvm/relay/qnn/op/legalizations.py
 RELAY_REGISTER_OP("qnn.rsqrt")
     .describe("Elementwise rsqrt for quantized tensors.")
     .set_num_inputs(5)
@@ -116,8 +83,7 @@ RELAY_REGISTER_OP("qnn.rsqrt")
                   "The quantization zero_point of the output tensor.")
     .set_support_level(11)
     .add_type_rel("QRsqrt", QnnRsqrtRel)
-    .set_attr<TNonComputational>("TNonComputational", true)
-    .set_attr<FTVMQnnCanonicalize>("FTVMQnnCanonicalize", QnnRsqrtCanonicalize);
+    .set_attr<TNonComputational>("TNonComputational", true);
 
 TVM_REGISTER_GLOBAL("relay.qnn.op._make.rsqrt").set_body_typed(MakeQuantizedRsqrt);
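Because qnn.rsqrt now lowers during legalization rather than canonicalization, any pipeline containing it must run the Legalize pass before CanonicalizeOps (the updated rsqrt tests below add exactly this). A minimal sketch of the ordering:

import tvm
from tvm import relay

def lower_qnn(mod: tvm.IRModule) -> tvm.IRModule:
    mod = relay.transform.InferType()(mod)
    mod = relay.qnn.transform.Legalize()(mod)         # qnn.rsqrt -> reinterpret + take(lookup_table)
    mod = relay.qnn.transform.CanonicalizeOps()(mod)  # remaining qnn ops -> mainline relay
    return mod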
diff --git a/tests/python/relay/qnn/test_canonicalizations.py b/tests/python/relay/qnn/test_canonicalizations.py
new file mode 100644
index 000000000000..0505a88c07bd
--- /dev/null
+++ b/tests/python/relay/qnn/test_canonicalizations.py
@@ -0,0 +1,231 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Callable
+
+import numpy as np
+from tvm import relay
+from tvm.relay.qnn.op import canonicalizations
+
+
+class TestIntegerTableLookupTable:
+    """Tests the creation of lookup tables for integer operations."""
+
+    def fake_identity_func_numpy(self, arr: np.ndarray):
+        return arr.astype("float32")
+
+    def fake_identity_func_relay(
+        self,
+        floating_point_func: Callable[[np.ndarray], np.ndarray],
+        input_arg=None,
+        in_scale=relay.const(1.0, dtype="float32"),
+        in_zero_point=relay.const(0, dtype="int32"),
+        out_scale=relay.const(1.0, dtype="float32"),
+        out_zero_point=relay.const(0, dtype="int32"),
+        in_axis=-1,
+        out_axis=-1,
+        in_dtype="uint8",
+        out_dtype="uint8",
+    ):
+        if input_arg is None:
+            input_arg = relay.const(np.arange(0, 256, dtype="uint8").view(in_dtype))
+
+        return (
+            canonicalizations.create_integer_lookup_op(
+                input_arg=input_arg,
+                floating_point_func=floating_point_func,
+                in_scale=in_scale,
+                in_zero_point=in_zero_point,
+                out_scale=out_scale,
+                out_zero_point=out_zero_point,
+                in_axis=in_axis,
+                out_axis=out_axis,
+                in_dtype=in_dtype,
+                out_dtype=out_dtype,
+            ),
+            input_arg.data.numpy(),
+        )
+
+    def dequantize_numpy(self, np_arr, np_scale=1.0, np_zero_point=0):
+        return (np_arr.astype("int32") - np_zero_point) * np_scale
+
+    def run_function_test(
+        self,
+        in_scale: float,
+        in_zero_point: int,
+        out_scale: float,
+        out_zero_point: int,
+        in_dtype: str,
+        out_dtype: str,
+        floating_point_func: Callable[[np.ndarray], np.ndarray],
+        input_arg: relay.Expr = None,
+        rtol=1e-7,
+        atol=0,
+    ):
+        relay_lookup, input_arg = self.fake_identity_func_relay(
+            input_arg=input_arg,
+            floating_point_func=floating_point_func,
+            in_scale=relay.const(in_scale, "float32"),
+            in_zero_point=relay.const(in_zero_point, "int32"),
+            out_scale=relay.const(out_scale, "float32"),
+            out_zero_point=relay.const(out_zero_point, "int32"),
+            in_dtype=in_dtype,
+            out_dtype=out_dtype,
+        )
+        result = canonicalizations.run_const_expr(relay_lookup)
+        np.testing.assert_allclose(
+            floating_point_func(
+                self.dequantize_numpy(input_arg, np_scale=in_scale, np_zero_point=in_zero_point)
+            ),
+            self.dequantize_numpy(result, np_scale=out_scale, np_zero_point=out_zero_point),
+            atol=atol,
+            rtol=rtol,
+        )
+
+    """Test mapping between different input/output dtypes"""
+
+    def test_int8_to_int8(self):
+        self.run_function_test(
+            in_scale=1.0,
+            in_zero_point=0,
+            out_scale=1.0,
+            out_zero_point=0,
+            in_dtype="int8",
+            out_dtype="int8",
+            floating_point_func=self.fake_identity_func_numpy,
+        )
+
+    def test_uint8_to_uint8(self):
+        self.run_function_test(
+            in_scale=1.0,
+            in_zero_point=128,
+            out_scale=1.0,
+            out_zero_point=128,
+            in_dtype="uint8",
+            out_dtype="uint8",
+            floating_point_func=self.fake_identity_func_numpy,
+        )
+
+    def test_int8_to_uint8(self):
+        self.run_function_test(
+            in_scale=1.0,
+            in_zero_point=0,
+            out_scale=1.0,
+            out_zero_point=128,
+            in_dtype="int8",
+            out_dtype="uint8",
+            floating_point_func=self.fake_identity_func_numpy,
+        )
+
+    def test_uint8_to_int8(self):
+        self.run_function_test(
+            in_scale=1.0,
+            in_zero_point=128,
+            out_scale=1.0,
+            out_zero_point=0,
+            in_dtype="uint8",
+            out_dtype="int8",
+            floating_point_func=self.fake_identity_func_numpy,
+        )
+
+    """Test different input shapes"""
+
+    def test_keep_input_shapes(self):
+        # identity function in floating point; input covers roughly (-2, 2)
+        self.run_function_test(
+            input_arg=relay.const(np.arange(-128, 128).astype("int8").reshape([2, 2, 8, 8])),
+            in_scale=0.015,
+            in_zero_point=0,
+            out_scale=16 / 256,
+            out_zero_point=0,
+            in_dtype="int8",
+            out_dtype="int8",
+            floating_point_func=self.fake_identity_func_numpy,
+            atol=0.03,
+            rtol=0.01,
+        )
+        self.run_function_test(
+            input_arg=relay.const(np.arange(-128, 128).astype("int8").reshape([2, 2, 64])),
+            in_scale=0.015,
+            in_zero_point=0,
+            out_scale=16 / 256,
+            out_zero_point=0,
+            in_dtype="int8",
+            out_dtype="int8",
+            floating_point_func=self.fake_identity_func_numpy,
+            atol=0.03,
+            rtol=0.01,
+        )
+        self.run_function_test(
+            input_arg=relay.const(np.arange(-128, 128).astype("int8").reshape([2, 128])),
+            in_scale=0.015,
+            in_zero_point=0,
+            out_scale=16 / 256,
+            out_zero_point=0,
+            in_dtype="int8",
+            out_dtype="int8",
+            floating_point_func=self.fake_identity_func_numpy,
+            atol=0.03,
+            rtol=0.01,
+        )
+
+    """Test mapping with different in/out qparams works."""
+
+    def test_different_in_out_qparams(self):
+        self.run_function_test(
+            in_scale=1.0,
+            in_zero_point=128,
+            out_scale=1.0,
+            out_zero_point=128,
+            in_dtype="uint8",
+            out_dtype="uint8",
+            floating_point_func=self.fake_identity_func_numpy,
+            atol=1,  # values span [-128, 128), so an absolute error of 1 is small
+            rtol=0,
+        )
+
+    """Test some simple functions"""
+
+    def test_tanh(self):
+        # 1 / 64 in_scale -- input range is ~(-2, 2), tanh(+-2) ~= +-1
+        # 1 / 128 out_scale -- output range is ~(-1, 1)
+        self.run_function_test(
+            input_arg=relay.const(np.arange(-128, 128).astype("int8")),
+            in_scale=1 / 64,
+            in_zero_point=0,
+            out_scale=1 / 128,
+            out_zero_point=0,
+            in_dtype="int8",
+            out_dtype="int8",
+            floating_point_func=np.tanh,
+            atol=0.01,
+            rtol=0.01,
+        )
+
+    def test_exp(self):
+        # input in floating point ~[-2, 2], final output ~[0, 8]
+        self.run_function_test(
+            input_arg=relay.const(np.arange(-128, 128).astype("int8")),
+            in_scale=0.015,
+            in_zero_point=0,
+            out_scale=16 / 256,
+            out_zero_point=0,
+            in_dtype="int8",
+            out_dtype="int8",
+            floating_point_func=np.exp,
+            atol=0.03,
+            rtol=0.01,
+        )
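The tolerances in these tests track the quantization step sizes. A rough bound for the identity-function cases, assuming one rounding at the input and one at the output (a back-of-the-envelope derivation, not taken from the patch):

# For f = identity with input scale s_in and output scale s_out, each
# quantize rounds to the nearest step, so the reconstructed output can
# differ from the ideal value by about s_in / 2 + s_out / 2.
s_in, s_out = 0.015, 16 / 256    # qparams from test_keep_input_shapes
print(s_in / 2 + s_out / 2)      # ~0.039; the tests use atol=0.03 plus rtol=0.01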
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 34f33240f5ac..e58ceabd1879 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1278,12 +1278,12 @@ def test_scatter_add(self, target, dev, ref_data, dshape, ishape, axis, dtype, i
     ],
 )
 def test_gather(target, dev, executor_kind, data, axis, indices, ref_res):
-    def verify_gather(data, axis, indices, ref_res, indices_dtype="int32"):
+    def verify_gather(data, axis, indices, ref_res):
         data = np.asarray(data, dtype="float32")
-        indices = np.asarray(indices, dtype=indices_dtype)
+        indices = np.asarray(indices, dtype="int32")
         ref_res = np.asarray(ref_res)
         d = relay.var("x", relay.TensorType(data.shape, "float32"))
-        i = relay.var("y", relay.TensorType(indices.shape, indices_dtype))
+        i = relay.var("y", relay.TensorType(indices.shape, "int32"))
         z = relay.gather(d, axis, i)
 
         func = relay.Function([d, i], z)
@@ -1294,9 +1294,6 @@ def verify_gather(data, axis, indices, ref_res, indices_dtype="int32"):
         tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
 
     verify_gather(data, axis, indices, ref_res)
-    verify_gather(data, axis, indices, ref_res, indices_dtype="uint32")
-
-    verify_gather(data, axis, indices, ref_res)
 
 
 def test_gather_nd(target, dev, executor_kind):
diff --git a/tests/python/relay/test_op_qnn_rsqrt.py b/tests/python/relay/test_op_qnn_rsqrt.py
index 1eb9b64057ca..0e40768343bd 100644
--- a/tests/python/relay/test_op_qnn_rsqrt.py
+++ b/tests/python/relay/test_op_qnn_rsqrt.py
@@ -15,8 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import tvm
 import numpy as np
+import tvm
 from tvm import relay
 
 
@@ -51,6 +51,7 @@ def test_saturation():
     func = relay.Function([x], y)
     mod = tvm.IRModule.from_expr(func)
     mod = relay.transform.InferType()(mod)
+    mod = relay.qnn.transform.Legalize()(mod)
     mod = relay.qnn.transform.CanonicalizeOps()(mod)
     func = mod["main"]
 
@@ -77,6 +78,7 @@ def test_saturation():
     func = relay.Function([x], y)
     mod = tvm.IRModule.from_expr(func)
     mod = relay.transform.InferType()(mod)
+    mod = relay.qnn.transform.Legalize()(mod)
    mod = relay.qnn.transform.CanonicalizeOps()(mod)
     func = mod["main"]
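The gather revert above (and in the topi tests below) is safe because the lookup path never hands unsigned indices to gather: it reinterprets the input's bits and uses take. In NumPy terms, for illustration:

import numpy as np

table = np.arange(256, dtype="uint8")           # stand-in lookup table
x = np.array([-128, -1, 0, 127], dtype="int8")  # quantized inputs
idx = x.view("uint8")                           # reinterpret bits -> [128, 255, 0, 127]
print(table[idx])                               # row lookup, like relay.take(table, idx, axis=0)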
diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py
index 7ef380a1a4c4..e9bfc640c7d5 100644
--- a/tests/python/relay/test_pass_fake_quantization_to_integer.py
+++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py
@@ -19,6 +19,7 @@
 import pytest
 import tvm
 from tvm import relay
+from tvm.relay.transform import fake_quantization_to_integer
 
 
 def compare_fq_to_int(expr, args, allow_rounding_error=False):
@@ -304,14 +305,14 @@ def test_fake_quantize_global_avg_pool():
 
 
 def test_fake_quantize_rsqrt():
-    x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
-    zero = relay.const(0)
+    x = relay.var("x", shape=[1, 3, 3, 3], dtype="int8")
+    mid_point = relay.const(-128)
 
-    x = relay.qnn.op.dequantize(x, relay.const(2.0), zero)
+    x = relay.qnn.op.dequantize(x, relay.const(0.125), mid_point)
     op = relay.rsqrt(x)
-    op = relay.qnn.op.quantize(op, relay.const(2.0), zero)
+    op = relay.qnn.op.quantize(op, relay.const(0.125), mid_point)
 
-    x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8")
+    x_np = np.random.randint(-128, 127, size=[1, 3, 3, 3], dtype="int8")
 
     compare_fq_to_int(op, [x_np], True)
diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py
index ddec14b16d01..730d22cba16a 100644
--- a/tests/python/topi/python/test_topi_transform.py
+++ b/tests/python/topi/python/test_topi_transform.py
@@ -18,11 +18,14 @@
 import numpy as np
 import pytest
 import tvm
-import tvm.testing
+from tvm import te
+from tvm import topi
+from tvm import relay
 import tvm.topi.testing
-from tvm import relay, te, topi
 from tvm.contrib.nvcc import have_fp16
 
+import tvm.testing
+
 
 def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
     A = te.placeholder(shape=in_shape, name="A")
@@ -1011,16 +1014,6 @@ def test_gather():
     verify_gather(np.random.randn(4, 7, 5), 1, np.random.randint(low=0, high=7, size=(4, 10, 5)))
     verify_gather(np.random.randn(4, 7, 5), 2, np.random.randint(low=0, high=5, size=(4, 7, 2)))
     verify_gather(np.random.randn(4, 7, 5), 2, np.random.randint(low=0, high=5, size=(4, 7, 10)))
-    verify_gather(
-        np.random.randn(4, 7, 5),
-        2,
-        np.random.randint(low=0, high=5, size=(4, 7, 10)).astype("uint32"),
-    )
-    verify_gather(
-        np.random.randn(4, 7, 5),
-        2,
-        np.random.randint(low=0, high=5, size=(4, 7, 10)).astype("uint8"),
-    )
 
 
 @tvm.testing.uses_gpu
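The fq2i test above exercises the end-to-end flow. A condensed sketch of the pass invocation that compare_fq_to_int performs under the hood (simplified; the result-comparison details are omitted):

import tvm
from tvm import relay

def fq2i(expr: "relay.Expr") -> tvm.IRModule:
    mod = tvm.IRModule.from_expr(expr)
    mod = relay.transform.InferType()(mod)
    # Rewrites dequantize -> float op -> quantize chains into qnn ops;
    # the qnn.rsqrt produced here is later legalized to the lookup table.
    return relay.transform.FakeQuantizationToInteger()(mod)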