diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 44393d3bb703..4202374f8da8 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1021,6 +1021,32 @@ def _impl_v1(cls, inputs, attr, params): return _op.log(_op.exp(beta * inputs[0]) + _expr.const(1.0)) * alpha +class Pow(OnnxOpConverter): + """Operator converter for Pow.""" + + @classmethod + def _impl_v13(cls, inputs, attr, params): + x = inputs[0] + y = inputs[1] + + x_type = infer_type(x).checked_type.dtype + output_type = x_type + y_type = infer_type(y).checked_type.dtype + + if not x_type.startswith("float"): + x_type = "float32" + x = _op.cast(x, x_type) + + if x_type != y_type: + y = _op.cast(y, x_type) + + # TODO: come up with good default integer pow() func for common backends + result = _op.power(x, y) + if x_type != output_type: + return _op.cast(result, output_type) + return result + + class Prelu(OnnxOpConverter): """Operator converter for Prelu.""" @@ -3654,7 +3680,7 @@ def _get_convert_map(opset): "Sinh": Renamer("sinh"), "Tan": Renamer("tan"), "Tanh": Renamer("tanh"), - "Pow": Renamer("power"), + "Pow": Pow.get_converter(opset), "PRelu": Prelu.get_converter(opset), "Sigmoid": Renamer("sigmoid"), "HardSigmoid": HardSigmoid.get_converter(opset), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index dcf0da6287dc..042a053d36fa 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4788,16 +4788,6 @@ def verify_eyelike(indata): # This nllloss test is flaky and sometimes gives NaNs # Investigate it here: https://github.com/apache/tvm/issues/8918 "test_nllloss_NCd1d2d3_none_no_weight_negative_ii", - "test_pow_types_float", - "test_pow_types_float32_int32", - "test_pow_types_float32_int64", - "test_pow_types_float32_uint32", - "test_pow_types_float32_uint64", - "test_pow_types_int", - "test_pow_types_int32_float32", - "test_pow_types_int32_int32", - "test_pow_types_int64_float32", - "test_pow_types_int64_int64", "test_qlinearmatmul_2D", "test_qlinearmatmul_3D", "test_range_float_type_positive_delta_expanded",