Skip to content

Commit

Permalink
ONNX bitshfit (apache#7800)
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart authored and Trevor Morris committed May 6, 2021
1 parent 6e0f137 commit 43310ca
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
19 changes: 19 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2782,6 +2782,24 @@ def _impl_v1(cls, inputs, attr, params):
return cls._op_dispatch(operator, inputs, attr, params)


class BitShift(OnnxOpConverter):
"""Operator converter for NonZero"""

@classmethod
def _impl_v11(cls, inputs, attr, params):
if len(inputs) != 2:
raise ValueError("Bitshift expects 2 inputs")

direction = attr.get("direction", "LEFT").decode("ascii")
if direction == "LEFT":
out = _op.left_shift(*inputs)
elif direction == "RIGHT":
out = _op.right_shift(*inputs)
else:
raise ValueError("Unsupported Shift Direction: " + direction)
return out


# compatible operators that do NOT require any conversion.
_identity_list = []

Expand All @@ -2796,6 +2814,7 @@ def _get_convert_map(opset):
# defs/experimental
"Identity": Renamer("copy"),
"Affine": Affine.get_converter(opset),
"BitShift": BitShift.get_converter(opset),
"ThresholdedRelu": ThresholdedRelu.get_converter(opset),
"ScaledTanh": ScaledTanh.get_converter(opset),
"ParametricSoftplus": ParametricSoftPlus.get_converter(opset),
Expand Down
8 changes: 0 additions & 8 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4137,14 +4137,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):

unsupported_onnx_tests = [
"test_basic_convinteger/",
"test_bitshift_left_uint16/",
"test_bitshift_left_uint32/",
"test_bitshift_left_uint64/",
"test_bitshift_left_uint8/",
"test_bitshift_right_uint16/",
"test_bitshift_right_uint32/",
"test_bitshift_right_uint64/",
"test_bitshift_right_uint8/",
"test_cast_DOUBLE_to_FLOAT16/",
"test_cast_FLOAT16_to_DOUBLE/",
"test_cast_FLOAT16_to_FLOAT/",
Expand Down

0 comments on commit 43310ca

Please sign in to comment.