Skip to content

Commit

Permalink
[TFLite] Add support for GELU conversion
Browse files Browse the repository at this point in the history
This commit adds support for converting a TFLite fp32 GELU operation
to Relay.

Also includes some neighbouring cleanup of version checks to silence
warnings.

Change-Id: Ic43b1525c4b80cf7f47281c52bb9a8f2643c4073
  • Loading branch information
lhutton1 committed Apr 26, 2024
1 parent 278a6af commit 2c3a56f
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
21 changes: 21 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def __init__(self, model, subgraph, exp_tab):
"GATHER_ND": self.convert_gather_nd,
"GREATER_EQUAL": self.convert_greater_equal,
"GREATER": self.convert_greater,
"GELU": self.convert_gelu,
"HARD_SWISH": self.convert_hard_swish,
"L2_NORMALIZATION": self.convert_l2_normalization,
"L2_POOL_2D": self.convert_l2_pool2d,
Expand Down Expand Up @@ -1287,6 +1288,26 @@ def convert_elu(self, op):

return out

def convert_gelu(self, op):
    """Convert a TFLite GELU operator to Relay.

    Uses the exact (erf-based) formulation:
        gelu(x) = x * 0.5 * (1 + erf(x / sqrt(2)))
    Only floating-point inputs are supported; quantized GELU is rejected.
    """
    # No conversion path for quantized GELU yet.
    if self.is_quantized(op):
        raise tvm.error.OpNotImplemented(
            "The TFLite to Relay converter does not support quantized GELU operator yet."
        )

    tensors = self.get_input_tensors(op)
    assert len(tensors) == 1, "input tensors length should be 1"

    data = tensors[0]
    expr = self.get_expr(data.tensor_idx)
    dtype = self.get_tensor_str(data.tensor.Type()) if False else self.get_tensor_type_str(
        data.tensor.Type()
    )

    half = _expr.const(0.5, dtype=dtype)
    inv_sqrt2 = _expr.const(0.5**0.5, dtype=dtype)
    # x * (0.5 + erf(x / sqrt(2)) * 0.5)
    return expr * (half + _op.erf(expr * inv_sqrt2) * half)

def convert_square(self, op):
"""Convert TFLite SQUARE"""
input_tensors = self.get_input_tensors(op)
Expand Down
19 changes: 16 additions & 3 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2150,7 +2150,9 @@ def _test_unary_elemwise(math_op, data, quantized, quant_range=(-6, 6), int_quan
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="in")
out = math_op(in_data)
compare_tflite_with_tvm(data, ["in:0"], [in_data], [out])
compare_tflite_with_tvm(
data, ["in:0"], [in_data], [out], experimental_new_converter=True
)


def _unary_elewise_create_model(math_op, data, offset=0, int_quant_dtype=tf.int8):
Expand Down Expand Up @@ -2400,6 +2402,16 @@ def _test_elu(data, quantized, int_quant_dtype=tf.int8):
return _test_unary_elemwise(nn_ops.elu, data, quantized, int_quant_dtype=int_quant_dtype)


#######################################################################
# Gelu
# ----


def _test_gelu(data, quantized, int_quant_dtype=tf.int8):
    """One iteration of gelu.

    Parameters
    ----------
    data : np.ndarray
        Input data for the elementwise op.
    quantized : bool
        Whether to run the quantized path (currently unsupported for GELU).
    int_quant_dtype : tf.DType
        Integer dtype used when quantized is True.
    """
    # Docstring previously said "elu" — copy-paste error from _test_elu.
    return _test_unary_elemwise(nn_ops.gelu, data, quantized, int_quant_dtype=int_quant_dtype)


def _test_forward_unary_elemwise(test_op, int_quant_dtype=None, quantized=True, negative=True):
# input data
in_data, inq_data = [], []
Expand Down Expand Up @@ -2439,15 +2451,16 @@ def test_all_unary_elemwise():
_test_forward_unary_elemwise(_test_sin)
_test_forward_unary_elemwise(_test_neg)
_test_forward_unary_elemwise(_test_sqrt, negative=False)
_test_forward_unary_elemwise(_test_gelu, quantized=False)
# tensorflow version upgrade support
if tf.__version__ < LooseVersion("2.6.1"):
if package_version.parse(tf.VERSION) < package_version.parse("2.6.1"):
_test_forward_unary_elemwise(_test_rsqrt, negative=False, int_quant_dtype=tf.uint8)
else:
_test_forward_unary_elemwise(_test_rsqrt, negative=False, int_quant_dtype=tf.int8)
# ceil and cos come with TFLite 1.14.0.post1 fbs schema
if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
_test_forward_unary_elemwise(_test_ceil)
if tf.__version__ < LooseVersion("2.6.1"):
if package_version.parse(tf.VERSION) < package_version.parse("2.6.1"):
_test_forward_unary_elemwise(_test_cos, quantized=False)
else:
_test_forward_unary_elemwise(_test_cos, int_quant_dtype=tf.int8)
Expand Down

0 comments on commit 2c3a56f

Please sign in to comment.