From 23153a32aa023589b206af67ab81965fcfeed60c Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 6 Sep 2021 00:37:00 -0700 Subject: [PATCH 1/3] update pow --- python/tvm/relay/frontend/onnx.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 29221884702c..27b806741ed7 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1010,6 +1010,21 @@ def _impl_v1(cls, inputs, attr, params): return _op.log(_op.exp(beta * inputs[0]) + _expr.const(1.0)) * alpha +class Pow(OnnxOpConverter): + """Operator converter for Pow.""" + + @classmethod + def _impl_v13(cls, inputs, attr, params): + x = inputs[0] + y = inputs[1] + x_type = infer_type(x).checked_type.dtype + y_type = infer_type(y).checked_type.dtype + + if x_type != y_type: + y = _op.cast(y, x_type) + return _op.power(x, y) + + class Prelu(OnnxOpConverter): """Operator converter for Prelu.""" @@ -3644,7 +3659,7 @@ def _get_convert_map(opset): "Sinh": Renamer("sinh"), "Tan": Renamer("tan"), "Tanh": Renamer("tanh"), - "Pow": Renamer("power"), + "Pow": Pow.get_converter(opset), "PRelu": Prelu.get_converter(opset), "Sigmoid": Renamer("sigmoid"), "HardSigmoid": HardSigmoid.get_converter(opset), From 0464ce47c33c84907124a1671fe61d543d621518 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Mon, 6 Sep 2021 00:54:16 -0700 Subject: [PATCH 2/3] update pow --- python/tvm/relay/frontend/onnx.py | 17 ++++++++++++++++- tests/python/frontend/onnx/test_forward.py | 10 ---------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 27b806741ed7..62ab1faa5c5b 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1010,6 +1010,10 @@ def _impl_v1(cls, inputs, attr, params): return _op.log(_op.exp(beta * inputs[0]) + _expr.const(1.0)) * alpha +class Pow(OnnxOpConverter): + """Operator converter for Pow.""" + + class Pow(OnnxOpConverter): """Operator converter for Pow.""" @@ -1017,12 +1021,23 @@ class Pow(OnnxOpConverter): def _impl_v13(cls, inputs, attr, params): x = inputs[0] y = inputs[1] + x_type = infer_type(x).checked_type.dtype + output_type = x_type y_type = infer_type(y).checked_type.dtype + if not x_type.startswith("float"): + x_type = "float32" + x = _op.cast(x, x_type) + if x_type != y_type: y = _op.cast(y, x_type) - return _op.power(x, y) + + # TODO: come up with good default integer pow() func for common backends + result = _op.power(x, y) + if x_type != output_type: + return _op.cast(result, output_type) + return result class Prelu(OnnxOpConverter): diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 9e3b48de8764..e9a5f5fc0829 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4766,16 +4766,6 @@ def verify_eyelike(indata): # This nllloss test is flaky and sometimes gives NaNs # Investigate it here: https://github.com/apache/tvm/issues/8918 "test_nllloss_NCd1d2d3_none_no_weight_negative_ii", - "test_pow_types_float", - "test_pow_types_float32_int32", - "test_pow_types_float32_int64", - "test_pow_types_float32_uint32", - "test_pow_types_float32_uint64", - "test_pow_types_int", - "test_pow_types_int32_float32", - "test_pow_types_int32_int32", - "test_pow_types_int64_float32", - "test_pow_types_int64_int64", "test_qlinearmatmul_2D", "test_qlinearmatmul_3D", "test_range_float_type_positive_delta_expanded", From 185df11a252137a47b171290b678ecfcb58c2281 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Mon, 6 Sep 2021 00:56:04 -0700 Subject: [PATCH 3/3] remove duplicate --- python/tvm/relay/frontend/onnx.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 62ab1faa5c5b..3c958e622cf6 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1010,10 +1010,6 @@ def _impl_v1(cls, inputs, attr, params): return _op.log(_op.exp(beta * inputs[0]) + _expr.const(1.0)) * alpha -class Pow(OnnxOpConverter): - """Operator converter for Pow.""" - - class Pow(OnnxOpConverter): """Operator converter for Pow."""