Skip to content

Commit

Permalink
[Relay] Fix reduce axis bug (#3422)
Browse files Browse the repository at this point in the history
* fix relay reduce axis bug

* add tests for reduce bug
  • Loading branch information
altanh authored and tqchen committed Jun 27, 2019
1 parent 7db5779 commit 1e9d014
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/tvm/relay/op/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def sum(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr
The computed result.
"""
axis = [axis] if axis and isinstance(axis, int) else axis
axis = [axis] if isinstance(axis, int) else axis
return _make.sum(data, axis, keepdims, exclude)


Expand Down Expand Up @@ -159,7 +159,7 @@ def all(data, axis=None, keepdims=False, exclude=False):
# [False, True, False]]
"""
axis = [axis] if axis and isinstance(axis, int) else axis
axis = [axis] if isinstance(axis, int) else axis
return _make.all(data, axis, keepdims, exclude)


Expand Down
2 changes: 2 additions & 0 deletions tests/python/relay/test_op_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,9 @@ def _wrapper(data, axis=None, keepdims=False):
[relay.argmax, _with_keepdims(np.argmax)]]:
verify_reduce(func, (d1, d2, d3, d4), None, False, False, ())
verify_reduce(func, (d1, d2, d3, d4), 2, True, False, (d1, d2, 1, d4))
verify_reduce(func, (d1, d2, d3, d4), 0, True, False, (1, d2, d3, d4))
verify_reduce(func, (d1, d2, d3), 1, True, False, (d1, 1, d3))
verify_reduce(func, (d1, d2, d3), 0, True, False, (1, d2, d3))
verify_reduce(func, (d1, d2, d3), None, True, False, (1, 1, 1))
verify_reduce(func, (d1, d2, d3), (0, 1), True, False, (1, 1, d3))
verify_reduce(func, (2, 3, 4), 1, True, False, (2, 1, 4))
Expand Down

0 comments on commit 1e9d014

Please sign in to comment.