diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index db6e053628bf..4d607e46c97f 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -91,6 +91,7 @@ def __init__(self, model, subgraph, exp_tab):
             "EQUAL": self.convert_equal,
             "EXP": self.convert_exp,
             "EXPAND_DIMS": self.convert_expand_dims,
+            "FAKE_QUANT": self.convert_fake_quant,
             "FILL": self.convert_fill,
             "FLOOR_DIV": self.convert_floor_div,
             "FLOOR_MOD": self.convert_floor_mod,
@@ -3336,6 +3337,56 @@ def convert_densify(self, op):
 
         self.set_prefetched_node(output_tensor.tensor_idx, dense_weight)
 
+    def convert_fake_quant(self, op):
+        """Convert TFLite FAKE_QUANT"""
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+
+        input_tensor = input_tensors[0]
+        in_expr = self.get_expr(input_tensor.tensor_idx)
+
+        from tflite.BuiltinOptions import BuiltinOptions
+        from tflite.FakeQuantOptions import FakeQuantOptions
+
+        assert op.BuiltinOptionsType() == BuiltinOptions.FakeQuantOptions
+
+        op_options = op.BuiltinOptions()
+        fake_quant_options = FakeQuantOptions()
+        fake_quant_options.Init(op_options.Bytes, op_options.Pos)
+
+        opt_min = fake_quant_options.Min()
+        opt_max = fake_quant_options.Max()
+        narrow_range = fake_quant_options.NarrowRange()
+        num_bits = fake_quant_options.NumBits()
+
+        assert 2 <= num_bits <= 16
+
+        quant_min = 1 if narrow_range else 0
+        quant_max = (1 << num_bits) - 1
+        scale = (opt_max - opt_min) / (quant_max - quant_min)
+
+        zero_point_from_min = quant_min - opt_min / scale
+        if zero_point_from_min <= quant_min:
+            nudged_zero_point = quant_min
+        elif zero_point_from_min >= quant_max:
+            nudged_zero_point = quant_max
+        else:
+            nudged_zero_point = round(zero_point_from_min)
+
+        nudged_min = (quant_min - nudged_zero_point) * scale
+        nudged_max = (quant_max - nudged_zero_point) * scale
+
+        nudged_min_expr = _op.const(nudged_min)
+        clamped = _op.clip(in_expr, nudged_min, nudged_max)
+        clamped_shifted = _op.subtract(clamped, nudged_min_expr)
+
+        half = _op.const(0.5)
+        one = _op.const(1.0)
+        scale_expr = _op.const(scale)
+        inv_scale = _op.divide(one, scale_expr)
+        rounded = _op.floor(_op.add(_op.multiply(clamped_shifted, inv_scale), half))
+        return _op.add(_op.multiply(rounded, scale_expr), nudged_min_expr)
+
     def get_expr(self, input_tensor_idx):
         return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))
 
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 7b7f1b1c43b8..f2941030f0ab 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -322,7 +322,6 @@ def compare_tflite_with_tvm(
             out_names=out_names,
             mode=mode,
         )
-
         # WARNING: the results could well be random values clipped to 0 or 255 because of badly tuned output
         # range for the specific operator. While adding test ensure that we aren't getting only clipped values
         # in output tensors that still pass the assertion. For reference see _test_elemwise_qnn_out_range()
@@ -2618,6 +2617,22 @@ def test_forward_select():
     )
 
 
+@pytest.mark.parametrize("quant_bits", [2, 4, 8, 16])
+@pytest.mark.parametrize(
+    "value, min, max", [[-10.11, -6, 6], [-3.55, -6, 6], [0, -6, 6], [3.55, -6, 6], [10.11, -6, 6]]
+)
+def test_forward_fake_quant(value, min, max, quant_bits):
+    with tf.Graph().as_default():
+        with tf.Session() as sess:
+            input = tf.placeholder(tf.float32, shape=[1], name="input")
+            out = tf.quantization.fake_quant_with_min_max_args(
+                input, min=min, max=max, num_bits=quant_bits, name=None
+            )
+
+            in_data = np.float32(value)
+            compare_tflite_with_tvm([in_data], ["input:0"], [input], [out])
+
+
 # Squeeze
 # -------
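For reviewers, here is a minimal plain-NumPy sketch of the nudge/clamp/round/dequantize pipeline that `convert_fake_quant` builds as Relay ops, handy for checking the math offline. The helper name `fake_quant_reference` and the printed example are illustrative only and not part of the patch:

```python
import numpy as np

def fake_quant_reference(x, opt_min, opt_max, num_bits=8, narrow_range=False):
    # Same derivation of the quantization grid as convert_fake_quant above.
    quant_min = 1 if narrow_range else 0
    quant_max = (1 << num_bits) - 1
    scale = (opt_max - opt_min) / (quant_max - quant_min)

    # Nudge the zero point onto an integer grid index so that 0.0 is exactly
    # representable, mirroring TF's FakeQuantWithMinMaxArgs semantics.
    zero_point_from_min = quant_min - opt_min / scale
    if zero_point_from_min <= quant_min:
        nudged_zero_point = quant_min
    elif zero_point_from_min >= quant_max:
        nudged_zero_point = quant_max
    else:
        nudged_zero_point = round(zero_point_from_min)

    nudged_min = (quant_min - nudged_zero_point) * scale
    nudged_max = (quant_max - nudged_zero_point) * scale

    # Clamp, quantize with round-half-up (floor(v + 0.5)), then dequantize.
    clamped = np.clip(x, nudged_min, nudged_max)
    rounded = np.floor((clamped - nudged_min) / scale + 0.5)
    return rounded * scale + nudged_min

# With min=-6, max=6 and 8 bits: scale = 12 / 255, so 3.55 snaps to the
# nearest grid point (~3.5294); out-of-range inputs such as 10.11 clamp
# to the nudged [min, max] interval.
print(fake_quant_reference(np.float32(3.55), -6.0, 6.0))
```

Note that the converter expresses the op entirely in float32 arithmetic (clip/subtract/multiply/floor/add) rather than going through QNN quantize/dequantize, which matches the float-in/float-out semantics of FAKE_QUANT.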
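Assuming the usual TVM test environment (TF 1.x-style sessions, as elsewhere in test_forward.py), the new cases can be run in isolation with:

```
pytest tests/python/frontend/tflite/test_forward.py -k test_forward_fake_quant
```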