From 02fbaf0ed9120a8f95155e63de42459f230584aa Mon Sep 17 00:00:00 2001 From: AndrewZhaoLuo Date: Thu, 16 Sep 2021 01:49:26 -0700 Subject: [PATCH] [Onnx] Fix NLL Loss tests (#8971) * support negatibve indices in gather * move check to Tensor level indexing, gathernd * add test, update transform.h * remove unneeded gather * missing gather nd change * update tests * proper tensor comparison * blacking * lint * fix error * turn on test * missing test case * revert changes * add normalize_gather_indices * undo change * update * more removing diffs * more undoing Co-authored-by: Andrew Zhao Luo --- python/tvm/relay/frontend/onnx.py | 30 +++++++++++++++++----- tests/python/frontend/onnx/test_forward.py | 5 ---- tests/python/relay/test_any.py | 4 +-- 3 files changed, 26 insertions(+), 13 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 22c81b9d96fc..b30db2e99418 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1561,7 +1561,12 @@ def _impl_common(cls, data, indices, batch_dims=0): indices_shape = infer_shape(indices) indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1))) index_rank = indices_shape[-1] - return _op.gather_nd(data, indices, batch_dims, index_rank) + return _op.gather_nd( + data, + indices, + batch_dims=batch_dims, + index_rank=index_rank, + ) @classmethod def _impl_v1(cls, inputs, attr, params): @@ -3554,6 +3559,11 @@ def _impl_v13(cls, inputs, attr, params): ) input_tensor, target_tensor = inputs[0], inputs[1] + + # Convert negative indices --> positive indices for gather ops, note we have to + # use the original target tensor to interact with ignore_index to have proper behavior. + normalized_target_tensor = normalize_gather_indices(input_tensor, target_tensor, 1) + if len(inputs) == 3: weight_tensor = inputs[2] else: @@ -3563,12 +3573,18 @@ def _impl_v13(cls, inputs, attr, params): dtype=input_tensor.type_annotation.dtype, ) - loss = -relay.gather(input_tensor, axis=1, indices=relay.expand_dims(target_tensor, 1)) + loss = -relay.gather( + input_tensor, + axis=1, + indices=relay.expand_dims(normalized_target_tensor, 1), + ) loss = relay.squeeze(loss, axis=[1]) - expanded_target_tensor = relay.expand_dims(target_tensor, 0) - expanded_target_tensor = relay.nn.batch_flatten(expanded_target_tensor) - flattened_weights = relay.gather_nd(weight_tensor, expanded_target_tensor) + expanded_normalized_target_tensor = relay.expand_dims(normalized_target_tensor, 0) + expanded_normalized_target_tensor = relay.nn.batch_flatten( + expanded_normalized_target_tensor + ) + flattened_weights = relay.gather_nd(weight_tensor, expanded_normalized_target_tensor) select_weights = relay.reshape_like(flattened_weights, loss) loss *= select_weights @@ -3578,7 +3594,9 @@ def _impl_v13(cls, inputs, attr, params): target_tensor, relay.const(ignore_index, dtype=target_tensor.type_annotation.dtype) ) mask_tensor = relay.const(1, dtype="int8") - relay.cast(mask_tensor, "int8") - loss *= relay.cast_like(mask_tensor, loss) + loss = relay.where( + mask_tensor, loss, relay.const(0, infer_type(loss).checked_type.dtype) + ) # This is not explained super clearly in the onnx spec, but masked values don't # contribute toward the final value in reduction diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 02a9413ae579..7318ff7a3c7c 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4771,11 +4771,6 @@ def verify_eyelike(indata): "test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded", "test_nllloss_NCd1d2d3d4d5_mean_weight_expanded", "test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded", - # These nllloss tests are flaky and sometimes gives NaNs - # Investigate it here: https://github.com/apache/tvm/issues/8918 - "test_nllloss_NCd1d2d3_none_no_weight_negative_ii", - # Investigate it here: https://github.com/apache/tvm/issues/8964 - "test_nllloss_NCd1d2d3_sum_weight_high_ii", "test_qlinearmatmul_2D", "test_qlinearmatmul_3D", "test_range_float_type_positive_delta_expanded", diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index d2e275a6a335..decddc1ef0a4 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -24,8 +24,8 @@ from tvm.relay.loops import while_loop from tvm.relay.testing import run_infer_type as infer_type -from utils.assert_diagnostic import DiagnosticTesting from utils import ref_funcs +from utils.assert_diagnostic import DiagnosticTesting def int32(val): @@ -2022,7 +2022,7 @@ def test_gather_nd(): def verify_gather_nd(data_shape, indices_shape, data_shape_np, indices_shape_np, batch_dims=0): x = relay.var("x", relay.TensorType(data_shape, "float32")) y = relay.var("y", relay.TensorType(indices_shape, "int32")) - z = relay.gather_nd(x, y, batch_dims, indices_shape[0]) + z = relay.gather_nd(x, y, batch_dims=batch_dims, index_rank=indices_shape[0]) mod = tvm.IRModule() mod["main"] = relay.Function([x, y], z)