Skip to content

Commit

Permalink
[Onnx] Pow support for other types (#8933)
Browse files Browse the repository at this point in the history
* update pow

* update pow

* remove duplicate

Co-authored-by: Andrew Zhao Luo <andrewzhaoluo@system76-pc.localdomain>
  • Loading branch information
AndrewZhaoLuo and Andrew Zhao Luo authored Sep 8, 2021
1 parent 475e9e0 commit 9a47fc0
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 11 deletions.
28 changes: 27 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,32 @@ def _impl_v1(cls, inputs, attr, params):
return _op.log(_op.exp(beta * inputs[0]) + _expr.const(1.0)) * alpha


class Pow(OnnxOpConverter):
"""Operator converter for Pow."""

@classmethod
def _impl_v13(cls, inputs, attr, params):
x = inputs[0]
y = inputs[1]

x_type = infer_type(x).checked_type.dtype
output_type = x_type
y_type = infer_type(y).checked_type.dtype

if not x_type.startswith("float"):
x_type = "float32"
x = _op.cast(x, x_type)

if x_type != y_type:
y = _op.cast(y, x_type)

# TODO: come up with good default integer pow() func for common backends
result = _op.power(x, y)
if x_type != output_type:
return _op.cast(result, output_type)
return result


class Prelu(OnnxOpConverter):
"""Operator converter for Prelu."""

Expand Down Expand Up @@ -3654,7 +3680,7 @@ def _get_convert_map(opset):
"Sinh": Renamer("sinh"),
"Tan": Renamer("tan"),
"Tanh": Renamer("tanh"),
"Pow": Renamer("power"),
"Pow": Pow.get_converter(opset),
"PRelu": Prelu.get_converter(opset),
"Sigmoid": Renamer("sigmoid"),
"HardSigmoid": HardSigmoid.get_converter(opset),
Expand Down
10 changes: 0 additions & 10 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4788,16 +4788,6 @@ def verify_eyelike(indata):
# This nllloss test is flaky and sometimes gives NaNs
# Investigate it here: https://github.com/apache/tvm/issues/8918
"test_nllloss_NCd1d2d3_none_no_weight_negative_ii",
"test_pow_types_float",
"test_pow_types_float32_int32",
"test_pow_types_float32_int64",
"test_pow_types_float32_uint32",
"test_pow_types_float32_uint64",
"test_pow_types_int",
"test_pow_types_int32_float32",
"test_pow_types_int32_int32",
"test_pow_types_int64_float32",
"test_pow_types_int64_int64",
"test_qlinearmatmul_2D",
"test_qlinearmatmul_3D",
"test_range_float_type_positive_delta_expanded",
Expand Down

0 comments on commit 9a47fc0

Please sign in to comment.