[Frontend][TFLite] Implement fake quant (apache#8780)
* [Frontend][TFLite] Implement fake quant

* remove unused variable

* fix linting errors

* add more tests

* use pytest parametrize instead of a separate function
euntaik authored and shingjan committed Aug 23, 2021
1 parent 95e55f5 commit 843a893
Showing 2 changed files with 67 additions and 1 deletion.
51 changes: 51 additions & 0 deletions python/tvm/relay/frontend/tflite.py
@@ -91,6 +91,7 @@ def __init__(self, model, subgraph, exp_tab):
"EQUAL": self.convert_equal,
"EXP": self.convert_exp,
"EXPAND_DIMS": self.convert_expand_dims,
"FAKE_QUANT": self.convert_fake_quant,
"FILL": self.convert_fill,
"FLOOR_DIV": self.convert_floor_div,
"FLOOR_MOD": self.convert_floor_mod,
@@ -3336,6 +3337,56 @@ def convert_densify(self, op):

        self.set_prefetched_node(output_tensor.tensor_idx, dense_weight)

    def convert_fake_quant(self, op):
        """Convert TFLite FAKE_QUANT"""
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 1, "input tensors length should be 1"

        input_tensor = input_tensors[0]
        in_expr = self.get_expr(input_tensor.tensor_idx)

        from tflite.BuiltinOptions import BuiltinOptions
        from tflite.FakeQuantOptions import FakeQuantOptions

        assert op.BuiltinOptionsType() == BuiltinOptions.FakeQuantOptions

        op_options = op.BuiltinOptions()
        fake_quant_options = FakeQuantOptions()
        fake_quant_options.Init(op_options.Bytes, op_options.Pos)

        opt_min = fake_quant_options.Min()
        opt_max = fake_quant_options.Max()
        narrow_range = fake_quant_options.NarrowRange()
        num_bits = fake_quant_options.NumBits()

        assert 2 <= num_bits <= 16

        # Integer range of the quantized values; narrow_range drops the lowest value.
        quant_min = 1 if narrow_range else 0
        quant_max = (1 << num_bits) - 1
        scale = (opt_max - opt_min) / (quant_max - quant_min)

        # Nudge the zero point onto the integer grid so that 0.0 is exactly
        # representable, clamping it into [quant_min, quant_max].
        zero_point_from_min = quant_min - opt_min / scale
        if zero_point_from_min <= quant_min:
            nudged_zero_point = quant_min
        elif zero_point_from_min >= quant_max:
            nudged_zero_point = quant_max
        else:
            nudged_zero_point = round(zero_point_from_min)

        # Float range implied by the nudged zero point.
        nudged_min = (quant_min - nudged_zero_point) * scale
        nudged_max = (quant_max - nudged_zero_point) * scale

        nudged_min_expr = _op.const(nudged_min)
        clamped = _op.clip(in_expr, nudged_min, nudged_max)
        clamped_shifted = _op.subtract(clamped, nudged_min_expr)

        # Quantize with round-half-up (floor(x + 0.5)), then dequantize.
        half = _op.const(0.5)
        one = _op.const(1.0)
        scale_expr = _op.const(scale)
        inv_scale = _op.divide(one, scale_expr)
        rounded = _op.floor(_op.add(_op.multiply(clamped_shifted, inv_scale), half))
        return _op.add(_op.multiply(rounded, scale_expr), nudged_min_expr)

    def get_expr(self, input_tensor_idx):
        return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))

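The nudging above follows TensorFlow's FakeQuantWithMinMaxArgs reference: the zero point is snapped onto the integer grid so that 0.0 stays exactly representable, and inputs are clamped to the nudged float range before the quantize/dequantize round trip. For checking expected outputs by hand, here is a minimal NumPy sketch of the same arithmetic (the standalone fake_quant_ref helper is ours, for illustration only; it is not part of the commit):

    import numpy as np

    def fake_quant_ref(x, opt_min, opt_max, num_bits=8, narrow_range=False):
        # Integer range of the quantized representation.
        quant_min = 1 if narrow_range else 0
        quant_max = (1 << num_bits) - 1
        scale = (opt_max - opt_min) / (quant_max - quant_min)
        # Nudge the zero point onto the integer grid, clamped to the valid range.
        zero_point_from_min = quant_min - opt_min / scale
        nudged_zero_point = min(quant_max, max(quant_min, round(zero_point_from_min)))
        # Float range implied by the nudged zero point.
        nudged_min = (quant_min - nudged_zero_point) * scale
        nudged_max = (quant_max - nudged_zero_point) * scale
        # Clamp, quantize with floor(x + 0.5), dequantize.
        clamped = np.clip(x, nudged_min, nudged_max)
        rounded = np.floor((clamped - nudged_min) / scale + 0.5)
        return rounded * scale + nudged_min

For example, fake_quant_ref(np.float32([3.55]), -6.0, 6.0, num_bits=8) gives about 3.53: with min=-6 and max=6 the scale is 12/255 ≈ 0.047, and values snap to the nearest grid point within the nudged range of roughly [-6.024, 5.976].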
17 changes: 16 additions & 1 deletion tests/python/frontend/tflite/test_forward.py
@@ -322,7 +322,6 @@ def compare_tflite_with_tvm(
            out_names=out_names,
            mode=mode,
        )

        # WARNING: the results could well be random values clipped to 0 or 255 because of a badly tuned output
        # range for the specific operator. While adding a test, ensure that we aren't getting only clipped values
        # in output tensors that still pass the assertion. For reference see _test_elemwise_qnn_out_range()
@@ -2618,6 +2617,22 @@ def test_forward_select():
    )


@pytest.mark.parametrize("quant_bits", [2, 4, 8, 16])
@pytest.mark.parametrize(
    "value, min, max", [[-10.11, -6, 6], [-3.55, -6, 6], [0, -6, 6], [3.55, -6, 6], [10.11, -6, 6]]
)
def test_forward_fake_quant(value, min, max, quant_bits):
    with tf.Graph().as_default():
        with tf.Session() as sess:
            input = tf.placeholder(tf.float32, shape=[1], name="input")
            out = tf.quantization.fake_quant_with_min_max_args(
                input, min=min, max=max, num_bits=quant_bits, name=None
            )

            in_data = np.float32(value)
            compare_tflite_with_tvm([in_data], ["input:0"], [input], [out])
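
Using the fake_quant_ref sketch above (our illustration, not part of the commit), the 8-bit parametrized cases behave as follows: the out-of-range inputs -10.11 and 10.11 come back clipped to the nudged bounds (about -6.024 and 5.976), 0 survives exactly thanks to the zero-point nudging, and in-range values such as 3.55 snap to the nearest grid point (about 3.53):

    for value in [-10.11, -3.55, 0.0, 3.55, 10.11]:
        print(value, "->", fake_quant_ref(np.float32([value]), -6.0, 6.0, num_bits=8))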


# Squeeze
# -------

