diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 4444b15dfb12..ba2c6b4b54e7 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1468,18 +1468,26 @@ class Unsqueeze(OnnxOpConverter): """Operator converter for Unsqueeze.""" @classmethod - def _impl_v1(cls, inputs, attr, params): - axes = sorted(attr["axes"]) + def run_calculation(cls, tensor, axes): + axes = sorted(axes) for axis in axes: - inputs[0] = _op.expand_dims(inputs[0], axis=axis, num_newaxis=1) - return inputs[0] + tensor = _op.expand_dims(tensor, axis=axis, num_newaxis=1) + return tensor @classmethod - def _impl_v12(cls, inputs, attr, params): + def _impl_v1(cls, inputs, attr, params): + return cls.run_calculation(inputs[0], attr["axes"]) + + @classmethod + def _impl_v13(cls, inputs, attr, params): + if isinstance(inputs[1], _expr.Constant): + constant_axes = list(inputs[1].data.numpy()) + constant_axes = list(map(int, constant_axes)) + return cls.run_calculation(inputs[0], constant_axes) + rank_input = len(infer_type(inputs[0]).checked_type.shape) num_new_axis = int(infer_type(inputs[1]).checked_type.shape[0]) axes = relay.split(inputs[1], num_new_axis).astuple() - result = inputs[0] # TODO (AndrewZhaoLuo): investigate performance issues with consecutive