From 9dcf71d8fe33f77ed316a95fcffaf1f7f883ff70 Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Thu, 20 Feb 2020 13:04:09 -0800 Subject: [PATCH] ONNX export: Slice op - Handle None value for ends (#14942) * ONNX export: Slice op - Handle None value for ends --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 11 +++++++---- python/mxnet/contrib/onnx/onnx2mx/_op_translations.py | 6 +++++- tests/python-pytest/onnx/test_cases.py | 1 + 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 65da6a3933da..bfd905cbae6a 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1495,9 +1495,12 @@ def convert_slice_axis(node, **kwargs): axes = int(attrs.get("axis")) starts = int(attrs.get("begin")) - ends = int(attrs.get("end", None)) - if not ends: - raise ValueError("Slice: ONNX doesnt't support 'None' in 'end' attribute") + ends = attrs.get("end", None) + if not ends or ends == 'None': + # ONNX doesn't support None for ends. Since ends=None depicts + # length of dimension, passing dimension in this case. + in_shape = kwargs['in_shape'][0] + ends = in_shape[axes] node = onnx.helper.make_node( "Slice", @@ -1505,7 +1508,7 @@ def convert_slice_axis(node, **kwargs): [name], axes=[axes], starts=[starts], - ends=[ends], + ends=[int(ends)], name=name, ) return [node] diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index 627181d6ae21..311fd86ef623 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -499,6 +499,8 @@ def split(attrs, inputs, proto_obj): def _slice(attrs, inputs, proto_obj): """Returns a slice of the input tensor along multiple axes.""" + input_tensor_data = proto_obj.model_metadata.get('input_tensor_data')[0] + input_shape = input_tensor_data[1] new_attrs = translation_utils._fix_attribute_names(attrs, {'axes' : 'axis', 'ends' : 'end', @@ -506,8 +508,10 @@ def _slice(attrs, inputs, proto_obj): # onnx slice provides slicing on multiple axis. Adding multiple slice_axis operator # for multiple axes from mxnet begin = new_attrs.get('begin') - end = new_attrs.get('end') + end = list(new_attrs.get('end')) axes = new_attrs.get('axis', tuple(range(len(begin)))) + for i, axis in enumerate(axes): + end[i] = None if end[i] >= input_shape[axis] else end[i] slice_op = symbol.slice_axis(inputs[0], axis=axes[0], begin=begin[0], end=end[0]) if len(axes) > 1: for i, axis in enumerate(axes): diff --git a/tests/python-pytest/onnx/test_cases.py b/tests/python-pytest/onnx/test_cases.py index 0f822d3b240f..9a72d58e0490 100644 --- a/tests/python-pytest/onnx/test_cases.py +++ b/tests/python-pytest/onnx/test_cases.py @@ -41,6 +41,7 @@ 'test_globalaveragepool', 'test_slice_cpu', 'test_slice_neg', + 'test_slice_end', 'test_reciprocal', 'test_sqrt', 'test_pow',