diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 8308298e7087..b0464439b01a 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -534,7 +534,7 @@ def unary(expr, type_map): out_t.scale, out_t.zero_point, ) - return [out, x_t] + return [out, out_t] return register_fake_quantization_to_integer(op_name, unary) diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index d0c8cca6b78d..38520ff2df7f 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -318,23 +318,36 @@ def test_fake_quantize_global_avg_pool(): class TestUnaryQNNOp: - def helper_test_fake_quantize_unary_op(self, fp32_op, scale=0.125): - x = relay.var("x", shape=[1, 3, 3, 3], dtype="int8") - mid_point = relay.const(-128) + def helper_test_fake_quantize_unary_op(self, fp32_op, pos_values=False): + for dtype in ["int8", "uint8"]: + x = relay.var("x", shape=[1, 3, 3, 3], dtype=dtype) - x = relay.qnn.op.dequantize(x, relay.const(scale), mid_point) - op = fp32_op(x) - op = relay.qnn.op.quantize(op, relay.const(scale), mid_point) + zero = -128 if dtype == "int8" else 0 + if pos_values: + # Use a positive range for quanitzed ops that only work on positive values + input_mid_point = relay.const(zero) + output_mid_point = relay.const(zero) + else: + input_mid_point = relay.const(np.random.randint(0, 255) + zero) + output_mid_point = relay.const(np.random.randint(0, 255) + zero) - x_np = np.random.randint(-128, 127, size=[1, 3, 3, 3], dtype="int8") + input_scale = relay.const(np.random.rand()) + output_scale = relay.const(np.random.rand()) - compare_fq_to_int(op, [x_np], True) + x = relay.qnn.op.dequantize(x, input_scale, input_mid_point) + op = fp32_op(x) + + op = relay.qnn.op.quantize(op, output_scale, output_mid_point, out_dtype=dtype) + + x_np = np.random.randint(0 + zero, 255 + zero, size=[1, 3, 3, 3], dtype=dtype) + + compare_fq_to_int(op, [x_np], True) def test_sqrt(self): - self.helper_test_fake_quantize_unary_op(fp32_op=relay.sqrt) + self.helper_test_fake_quantize_unary_op(fp32_op=relay.sqrt, pos_values=True) def test_rsqrt(self): - self.helper_test_fake_quantize_unary_op(fp32_op=relay.rsqrt) + self.helper_test_fake_quantize_unary_op(fp32_op=relay.rsqrt, pos_values=True) def test_exp(self): self.helper_test_fake_quantize_unary_op(fp32_op=relay.exp) @@ -349,7 +362,7 @@ def test_tanh(self): self.helper_test_fake_quantize_unary_op(fp32_op=relay.tanh) def test_log(self): - self.helper_test_fake_quantize_unary_op(fp32_op=relay.log) + self.helper_test_fake_quantize_unary_op(fp32_op=relay.log, pos_values=True) def test_fake_quantize_reshape():