Skip to content

Commit

Permalink
[TOPI] Fix in interpretation of empty axis parameter in reduction functions
Browse files Browse the repository at this point in the history
  • Loading branch information
padreofthegame committed May 18, 2023
1 parent 28e9801 commit 11fbe42
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 8 deletions.
6 changes: 3 additions & 3 deletions include/tvm/topi/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,6 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T
return T;
}
}

auto T = tvm::te::compute(
targets->shape,
[&](const tvm::Array<tvm::tir::Var>& target_indices) {
Expand All @@ -710,9 +709,10 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T
tvm::tir::make_const(predictions->dtype, 0));
},
name, tag);
return topi::divide(topi::sum(T, {}), topi::sum(W, {}));
return topi::divide(topi::sum(T, tvm::Array<Integer>(nullptr)),
topi::sum(W, tvm::Array<Integer>(nullptr)));
} else if (reduction == "sum") {
return topi::sum(T, {});
return topi::sum(T, tvm::Array<Integer>(nullptr));
} else { // reduction == "none"
return T;
}
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/topi/reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@ using FCommReduce = std::function<Array<PrimExpr>(Array<PrimExpr> exprs, const A
* \param ndim Number of dimensions in the target.
* \param axis The axis parameter.
*
* \return A non-empty sorted array of valid dimension indices, with no duplicates.
* If the input axis is empty, the result will be an axis including all dimensions.
* \return A sorted array of valid dimension indices, with no duplicates.
* If the input axis is None, the result will be an axis including all dimensions.
* If any input element is negative, it will be treated as an offset from the
* last dimension (same as python indexing rules).
*/
inline std::vector<int> GetRealAxis(int ndim, const Array<Integer>& axis) {
std::vector<int> real_axis;
if (!axis.defined() || axis.size() == 0) {
if (!axis.defined()) {
for (int i = 0; i < ndim; ++i) {
real_axis.push_back(i);
}
Expand All @@ -75,7 +75,7 @@ inline std::vector<int> GetRealAxis(int ndim, const Array<Integer>& axis) {
if (val < 0) {
val += ndim;
}
ICHECK_LE(val, ndim) << " exceeds the maximum dimension " << ndim;
ICHECK_LT(val, ndim) << " exceeds the maximum dimension " << ndim;
ICHECK_GE(val, 0);
real_axis.push_back(static_cast<int>(val));
}
Expand Down
33 changes: 32 additions & 1 deletion tests/python/topi/python/test_topi_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,33 @@
((128, 24, 128, 24), 2, False, "any", "bool"),
((128, 24, 128, 24), 2, False, "sum", "bool"),
((128, 24, 128, 24), 0, True, "sum", "bool"),
((3, 4, 5), None, False, "prod", "float32"),
((3, 4, 5), (2,), False, "prod", "float32"),
((3, 4, 5), (1, 2), True, "prod", "float32"),
((3, 4, 5), (), False, "sum", "float32"),
((3, 4, 5), (), True, "sum", "float32"),
((3, 4, 5), (0, 1, 2), False, "sum", "float32"),
((3, 4, 5), (0, 1, 2), True, "sum", "float32"),
((3, 4, 5), (), False, "prod", "float32"),
((3, 4, 5), (), True, "prod", "float32"),
((3, 4, 5), (0, 1, 2), False, "prod", "float32"),
((3, 4, 5), (0, 1, 2), True, "prod", "float32"),
((3, 4, 5), (), False, "min", "float32"),
((3, 4, 5), (), True, "min", "float32"),
((3, 4, 5), (0, 1, 2), False, "min", "float32"),
((3, 4, 5), (0, 1, 2), True, "min", "float32"),
((3, 4, 5), (), False, "max", "float32"),
((3, 4, 5), (), True, "max", "float32"),
((3, 4, 5), (0, 1, 2), False, "max", "float32"),
((3, 4, 5), (0, 1, 2), True, "max", "float32"),
((3, 4, 5), (), False, "any", "bool"),
((3, 4, 5), (), True, "any", "bool"),
((3, 4, 5), (0, 1, 2), False, "any", "bool"),
((3, 4, 5), (0, 1, 2), True, "any", "bool"),
((3, 4, 5), (), False, "all", "bool"),
((3, 4, 5), (), True, "all", "bool"),
((3, 4, 5), (0, 1, 2), False, "all", "bool"),
((3, 4, 5), (0, 1, 2), True, "all", "bool"),
)


Expand All @@ -63,6 +90,8 @@ def ref_data(in_shape, axis, keepdims, reduce_type, dtype):
out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims, dtype="bool")
else:
out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims)
elif reduce_type == "prod":
out_npy = in_npy_map.prod(axis=axis, keepdims=keepdims)
elif reduce_type == "all" and dtype == "bool":
out_npy = in_npy_map.all(axis=axis, keepdims=keepdims)
elif reduce_type == "any" and dtype == "bool":
Expand Down Expand Up @@ -108,7 +137,7 @@ def _my_npy_argmin(arr, axis, keepdims):

def test_reduce_map(target, dev, ref_data, in_shape, axis, keepdims, reduce_type, dtype):
target = tvm.target.Target(target)
if target.kind.name == "vulkan" and reduce_type in ["sum", "any", "all"]:
if target.kind.name == "vulkan" and reduce_type in ["sum", "prod", "any", "all"]:
pytest.xfail(f"Vulkan backend has known errors on {reduce_type}")

in_npy, in_npy_map, out_npy = ref_data
Expand All @@ -122,6 +151,8 @@ def test_reduce_map(target, dev, ref_data, in_shape, axis, keepdims, reduce_type
B = topi.sum(A, axis=axis, keepdims=keepdims)
else:
B = topi.sum(A1, axis=axis, keepdims=keepdims)
elif reduce_type == "prod":
B = topi.prod(A1, axis=axis, keepdims=keepdims)
elif reduce_type == "all":
B = topi.all(A, axis=axis, keepdims=keepdims)
elif reduce_type == "any":
Expand Down

0 comments on commit 11fbe42

Please sign in to comment.