diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 3ad230560f3a..114b8f617387 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -577,7 +577,7 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b [&](const Array& indices) { Array real_indices; for (int32_t i = 0; i < src_tensor_dim; ++i) { - real_indices.push_back(indices[i] * strides(i) + begin(i)); + real_indices.push_back(indices[i] * strides(i) + tvm::min(begin(i), x->shape[i] - 1)); } return x(real_indices); }, diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index ffeb0dd73171..cc66cd3c6fe8 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -930,8 +930,8 @@ class Selu(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - alpha = float(attr.get("alpha", 1.6732)) - gamma = float(attr.get("gamma", 1.0507)) + alpha = float(attr.get("alpha", 1.67326319217681884765625)) + gamma = float(attr.get("gamma", 1.05070102214813232421875)) return _expr.const(gamma) * ( _expr.const(-alpha) * _op.nn.relu(_expr.const(1.0) - _op.exp(inputs[0])) + _op.nn.relu(inputs[0]) @@ -948,6 +948,20 @@ def _impl_v1(cls, inputs, attr, params): return _op.tanh(_expr.const(beta) * inputs[0]) * _expr.const(alpha) +class Shrink(OnnxOpConverter): + """Operator converter for Shrink.""" + + @classmethod + def _impl_v9(cls, inputs, attr, params): + x = inputs[0] + dtype = infer_type(x).checked_type.dtype + lambd = _op.const(attr.get("lambd", 0.5), dtype=dtype) + bias = _op.const(attr.get("bias", 0.0), dtype=dtype) + + zeros = _op.zeros_like(x) + return _op.where(x < -lambd, x + bias, zeros) + _op.where(x > lambd, x - bias, zeros) + + class Softsign(OnnxOpConverter): """Operator converter for Softsign.""" @@ -1146,8 +1160,9 @@ class Unsqueeze(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - for axes in attr["axes"]: - inputs[0] = _op.expand_dims(inputs[0], axis=axes, num_newaxis=1) + axes = sorted(attr["axes"]) + for axis in axes: + inputs[0] = _op.expand_dims(inputs[0], axis=axis, num_newaxis=1) return inputs[0] @@ -1545,10 +1560,7 @@ class Softmax(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - # set default value when axis is not set in the model - if "axis" not in attr: - attr["axis"] = 1 - axis = attr["axis"] + axis = attr.get("axis", 1) ndim = len(infer_shape(inputs[0])) if axis < 0: axis += ndim @@ -1564,10 +1576,7 @@ class LogSoftmax(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - # set default value when axis is not set in the model - if "axis" not in attr: - attr["axis"] = 1 - axis = attr["axis"] + axis = attr.get("axis", 1) ndim = len(infer_shape(inputs[0])) if axis < 0: axis += ndim @@ -1579,6 +1588,40 @@ def _impl_v1(cls, inputs, attr, params): return x - m - _op.log(s) +class Hardmax(OnnxOpConverter): + """Operator converter for Hardmax.""" + + @classmethod + def _impl_v1(cls, inputs, attr, params): + axis = attr.get("axis", 1) + ndim = len(infer_shape(inputs[0])) + if axis < 0: + axis += ndim + dtype = infer_type(inputs[0]).checked_type.dtype + + if axis == 0: + pre = _op.const([1], "int64") + else: + pre = _op.prod( + _op.strided_slice(shape_of(inputs[0]), [0], [axis], [1]), axis=0, keepdims=True + ) + post = _op.prod( + _op.strided_slice(shape_of(inputs[0]), [axis], [2147483647], [1]), axis=0, keepdims=True + ) + newshape = _op.concatenate([pre, post], axis=0) + x = _op.reshape(inputs[0], fold_constant(newshape)) + argmax = _op.argmax(x, axis=1) + onehot = _op.one_hot( + argmax, + _op.const(1.0, dtype), + _op.const(0.0, dtype), + fold_constant(_op.take(shape_of(x), _op.const([1], "int64"))), + 1, + dtype, + ) + return _op.reshape(onehot, shape_of(inputs[0])) + + class OneHot(OnnxOpConverter): """Operator converter for OneHot.""" @@ -2717,7 +2760,8 @@ def _get_convert_map(opset): "Softmax": Softmax.get_converter(opset), "LogSoftmax": LogSoftmax.get_converter(opset), "OneHot": OneHot.get_converter(opset), - # 'Hardmax' + "Hardmax": Hardmax.get_converter(opset), + "Shrink": Shrink.get_converter(opset), "Softsign": Softsign.get_converter(opset), "Gemm": Gemm.get_converter(opset), "MatMul": MatMul.get_converter(opset), diff --git a/python/tvm/relay/op/dyn/_transform.py b/python/tvm/relay/op/dyn/_transform.py index a36b56214bc4..de8ee0895462 100644 --- a/python/tvm/relay/op/dyn/_transform.py +++ b/python/tvm/relay/op/dyn/_transform.py @@ -151,40 +151,51 @@ def _strided_slice_shape_func_input_data(data_shape, begin, end, strides, slice_ ndim = len(data_shape) out = output_tensor((ndim,), "int64") for i in const_range(ndim): + dim_size = int64(data_shape[i]) cbegin = int64(0) - cend = int64(data_shape[i]) + cend = dim_size cstride = int64(1) + if strides.shape[0] > i: cstride = int64(strides[i]) + if begin.shape[0] > i: cbegin = int64(begin[i]) - if cbegin < 0: - cbegin += int64(data_shape[i]) + elif cstride < 0: + cbegin = dim_size + if end.shape[0] <= i: - cend = int64(data_shape[i]) + if cstride < 0: + cend = int64(0) elif slice_mode != 0: cstride = int64(1) if end[i] < 0: - cend = int64(data_shape[i]) + cend = dim_size else: cend = cbegin + int64(end[i]) else: if end[i] > data_shape[i]: - cend = int64(data_shape[i]) - elif end[i] < -data_shape[i]: - cend = int64(-1) + cend = dim_size else: cend = int64(end[i]) - if cend < 0: - cend += int64(data_shape[i]) + assert cstride != 0, "Strides can't be zero." + + if cbegin < 0: + cbegin += dim_size + if cend < 0: + cend += dim_size + if cstride < 0: + if cend < 0: + cend = int64(-1) + if cbegin > dim_size - 1: + cbegin = dim_size - 1 slice_range = cbegin - cend step = -cstride else: slice_range = cend - cbegin step = cstride - out[i] = int64(ceil_div(slice_range, step)) return out diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 6d22b5afd0df..595a3b1c89b3 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4128,6 +4128,14 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): verify_cumsum(data, 1, 1, 1, type="int32") +""" + The following parameterized tests loads the tests that ONNX ships as + serialized ONNX files, inputs, and outputs. The goal of this test + is to ensure the ONNX importer is in line with the ONNX specification. + To allow these tests to run in CI before all pass, a number of tests that + are not yet supported are skipped. +""" + from onnx import numpy_helper f = onnx.__file__ @@ -4159,13 +4167,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): "test_eyelike_populate_off_main_diagonal/", "test_eyelike_with_dtype/", "test_eyelike_without_dtype/", - "test_hardmax_axis_0/", - "test_hardmax_axis_1/", - "test_hardmax_axis_2/", - "test_hardmax_default_axis/", - "test_hardmax_example/", - "test_hardmax_negative_axis/", - "test_hardmax_one_hot/", "test_isinf_negative/", "test_isinf_positive/", "test_matmulinteger/", @@ -4209,13 +4210,8 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): "test_scan9_sum/", "test_scan_sum/", "test_scatternd/", - "test_selu_default/", - "test_shrink_hard/", - "test_shrink_soft/", "test_simple_rnn_defaults/", "test_simple_rnn_with_initial_bias/", - "test_slice_neg_steps/", - "test_slice_start_out_of_bounds/", "test_strnormalizer_export_monday_casesensintive_lower/", "test_strnormalizer_export_monday_casesensintive_nochangecase/", "test_strnormalizer_export_monday_casesensintive_upper/", @@ -4235,7 +4231,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): "test_unique_sorted_with_axis_3d/", "test_unique_sorted_with_negative_axis/", "test_unique_sorted_without_axis/", - "test_unsqueeze_unsorted_axes/", "test_upsample_nearest/", ] diff --git a/tests/python/relay/dyn/test_dynamic_op_level4.py b/tests/python/relay/dyn/test_dynamic_op_level4.py index 3cb706440cad..43e5beba199f 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level4.py +++ b/tests/python/relay/dyn/test_dynamic_op_level4.py @@ -39,18 +39,19 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, # target numpy result x_data = np.random.uniform(size=dshape).astype("float32") ref_res = tvm.topi.testing.strided_slice_python(x_data, begin, end, strides, slice_mode) - data = [x_data, np.array(begin), np.array(end)] - - begin = relay.const(begin, dtype=dtype) - end = relay.const(end, dtype=dtype) + data = [x_data, np.array(begin, dtype=dtype), np.array(end, dtype=dtype)] + begin = relay.var("begin", shape=[len(begin)], dtype=dtype) + end = relay.var("end", shape=[len(end)], dtype=dtype) + inputs = [x, begin, end] if strides: - data.append(np.array(strides)) - strides = relay.const(strides, dtype=dtype) + data.append(np.array(strides, dtype=dtype)) + strides = relay.var("strides", shape=[len(strides)], dtype=dtype) + inputs.append(strides) z = relay.strided_slice(x, begin=begin, end=end, strides=strides, slice_mode=slice_mode) else: z = relay.strided_slice(x, begin=begin, end=end, slice_mode=slice_mode) - func = relay.Function([x], z) + func = relay.Function(inputs, z) func = run_infer_type(func) text = func.astext() @@ -60,7 +61,7 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, for target, dev in tvm.testing.enabled_targets(): mod = tvm.ir.IRModule.from_expr(func) intrp = relay.create_executor("vm", mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_data) + op_res = intrp.evaluate()(*data) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res) verify( @@ -79,6 +80,7 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3)) verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3)) verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3)) + verify((20, 10, 5), [20, 10, 4], [0, 0, 1], [-1, -3, -2], (19, 3, 2)) verify( (3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False )