Skip to content

Commit

Permalink
[ONNX] support additional nllloss tests (#9045)
Browse files Browse the repository at this point in the history
* initial dyn unsqueeze example

* simplify, properly unpack scalar

* basic tests

* squish bugs -- assign proper types

* working topi

* fix things

* temp work

* fix casting to int64

* shape encoding method for axis

* working shape encoding metric

* add comment

* move to non-rank encoded axis

* failing regime

* fix

* it works!

* add test

* add comment on shape func

* remove unused topi

* undo some file changes

* more cleanup

* newline

* clean up

* clean up

* enable multiple axis tests

* move tests to dynamic op

* Update docs

* add converter

* initial dyn unsqueeze example

* simplify, properly unpack scalar

* basic tests

* squish bugs -- assign proper types

* working topi

* fix things

* temp work

* fix casting to int64

* shape encoding method for axis

* working shape encoding metric

* add comment

* move to non-rank encoded axis

* failing regime

* fix

* it works!

* add test

* add comment on shape func

* remove unused topi

* undo some file changes

* more cleanup

* newline

* clean up

* clean up

* enable multiple axis tests

* move tests to dynamic op

* Update docs

* add converter

* working tests

* add test, remove unneeded file

* fix things

* more lint

* more lint

* pick things

* disable opencl tests

* unsqueeze tests

* clean up

* dyn stuff

* add num_newaxis

* add support

* black

* doc string

* remove bad merge

* fix default axis behavior

* rebase

* fix squeeze

* jostle ci

Co-authored-by: Andrew Zhao Luo <andrewzhaoluo@system76-pc.localdomain>
  • Loading branch information
AndrewZhaoLuo and Andrew Zhao Luo authored Sep 30, 2021
1 parent 7974e30 commit 3887628
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 21 deletions.
12 changes: 9 additions & 3 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1496,7 +1496,7 @@ def _impl_v13(cls, inputs, attr, params):
axis = relay.TupleGetItem(axes, i)
# Unpack scalar
axis = relay.reshape(axis, [])
axis = relay.If(
axis = relay.where(
axis >= relay.const(0, "int64"), axis, axis + relay.const(rank_input, "int64")
)
result = _op.expand_dims(result, axis)
Expand All @@ -1509,12 +1509,18 @@ class Squeeze(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
axis = attr.get("axes", None)
return _op.squeeze(*inputs, axis)
return _op.squeeze(inputs[0], axis)

@classmethod
def _impl_v13(cls, inputs, attr, params):
axis = inputs[1]
dtype = infer_type(axis).checked_type.dtype

if isinstance(axis, _expr.Constant):
constant_axes = list(inputs[1].data.numpy())
constant_axes = list(map(int, constant_axes))
return _op.squeeze(inputs[0], constant_axes)

rank = _op.shape_of(_op.shape_of(inputs[0], dtype), dtype)
axis = _op.where(axis < _op.const(0, dtype), axis + rank, axis)
return _op.squeeze(inputs[0], fold_constant(axis))
Expand Down Expand Up @@ -1640,7 +1646,7 @@ def normalize_gather_indices(data, indices, axis):
"""Make sure gather indicies aren't negative"""
ind_dtype = infer_type(indices).checked_type.dtype
# Normalize the indices to a positive range
s = _op.take(_op.shape_of(data, dtype=ind_dtype), _op.const(axis))
s = _op.take(_op.shape_of(data, dtype=ind_dtype), _op.const(axis, dtype="int64"))
cond = fold_constant(indices < _op.const(0, ind_dtype))
if isinstance(cond, _expr.Constant):
val = cond.data.numpy()
Expand Down
19 changes: 1 addition & 18 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4939,25 +4939,8 @@ def verify_eyelike(indata):
"test_maxpool_with_argmax_2d_precomputed_strides",
"test_maxunpool_export_with_output_shape",
"test_mvn",
# When unsqueeze is fully supported, remaining nllloss tests should work:
"test_nllloss_NC_expanded",
"test_nllloss_NCd1_expanded",
"test_nllloss_NCd1_ii_expanded",
"test_nllloss_NCd1_mean_weight_negative_ii_expanded",
"test_nllloss_NCd1_weight_expanded",
"test_nllloss_NCd1_weight_ii_expanded",
"test_nllloss_NCd1d2_expanded",
"test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded",
"test_nllloss_NCd1d2_reduction_mean_expanded",
"test_nllloss_NCd1d2_reduction_sum_expanded",
"test_nllloss_NCd1d2_with_weight_expanded",
"test_nllloss_NCd1d2_with_weight_reduction_mean_expanded",
"test_nllloss_NCd1d2_with_weight_reduction_sum_expanded",
"test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded",
# This test fails llvm with a lowering error:
"test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded",
"test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded",
"test_nllloss_NCd1d2d3d4d5_mean_weight_expanded",
"test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded",
"test_qlinearmatmul_2D",
"test_qlinearmatmul_3D",
"test_range_float_type_positive_delta_expanded",
Expand Down

0 comments on commit 3887628

Please sign in to comment.