From 49b632208a7aa132fcd86afa019ef2ae29ae012b Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Wed, 25 Aug 2021 19:10:40 -0700 Subject: [PATCH] correct arg passing --- python/tvm/topi/reduction.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/python/tvm/topi/reduction.py b/python/tvm/topi/reduction.py index cba43297f293..45d07af577a3 100644 --- a/python/tvm/topi/reduction.py +++ b/python/tvm/topi/reduction.py @@ -167,7 +167,7 @@ def min(data, axis=None, keepdims=False): return cpp.min(data, axis, keepdims) -def argmax(data, axis=None, keepdims=False, exclude=False, select_last_index=False): +def argmax(data, axis=None, keepdims=False, select_last_index=False): """Returns the indices of the maximum values along an axis. Parameters @@ -185,14 +185,18 @@ def argmax(data, axis=None, keepdims=False, exclude=False, select_last_index=Fal with size one. With this option, the result will broadcast correctly against the input array. + select_last_index: bool + Whether to select the last index if the maximum element appears multiple times, else + select the first index. + Returns ------- ret : tvm.te.Tensor """ - return cpp.argmax(data, axis, keepdims, exclude=exclude, select_last_index=select_last_index) + return cpp.argmax(data, axis, keepdims, select_last_index) -def argmin(data, axis=None, keepdims=False, exclude=False, select_last_index=False): +def argmin(data, axis=None, keepdims=False, select_last_index=False): """Returns the indices of the minimum values along an axis. Parameters @@ -210,11 +214,15 @@ def argmin(data, axis=None, keepdims=False, exclude=False, select_last_index=Fal with size one. With this option, the result will broadcast correctly against the input array. + select_last_index: bool + Whether to select the last index if the minimum element appears multiple times, else + select the first index. + Returns ------- ret : tvm.te.Tensor """ - return cpp.argmin(data, axis, keepdims, exclude, select_last_index) + return cpp.argmin(data, axis, keepdims, select_last_index) def prod(data, axis=None, keepdims=False):