Skip to content

Commit

Permalink
fix things (apache#9146)
Browse files Browse the repository at this point in the history
Co-authored-by: Andrew Zhao Luo <andrewzhaoluo@system76-pc.localdomain>
  • Loading branch information
2 people authored and ylc committed Jan 7, 2022
1 parent e4fba79 commit bb37a8c
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1468,18 +1468,26 @@ class Unsqueeze(OnnxOpConverter):
"""Operator converter for Unsqueeze."""

@classmethod
def _impl_v1(cls, inputs, attr, params):
axes = sorted(attr["axes"])
def run_calculation(cls, tensor, axes):
axes = sorted(axes)
for axis in axes:
inputs[0] = _op.expand_dims(inputs[0], axis=axis, num_newaxis=1)
return inputs[0]
tensor = _op.expand_dims(tensor, axis=axis, num_newaxis=1)
return tensor

@classmethod
def _impl_v12(cls, inputs, attr, params):
def _impl_v1(cls, inputs, attr, params):
return cls.run_calculation(inputs[0], attr["axes"])

@classmethod
def _impl_v13(cls, inputs, attr, params):
if isinstance(inputs[1], _expr.Constant):
constant_axes = list(inputs[1].data.numpy())
constant_axes = list(map(int, constant_axes))
return cls.run_calculation(inputs[0], constant_axes)

rank_input = len(infer_type(inputs[0]).checked_type.shape)
num_new_axis = int(infer_type(inputs[1]).checked_type.shape[0])
axes = relay.split(inputs[1], num_new_axis).astuple()

result = inputs[0]

# TODO (AndrewZhaoLuo): investigate performance issues with consecutive
Expand Down

0 comments on commit bb37a8c

Please sign in to comment.