From 53d5adaaa38b126bb5eb9a866fe20cd45a556c3f Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 10:59:26 -0700 Subject: [PATCH 01/52] support select_last_index for argmin/max --- include/tvm/relay/attrs/reduce.h | 36 +++++++++++++ include/tvm/relay/expr_functor.h | 6 +-- include/tvm/topi/reduction.h | 86 ++++++++++++++++++++----------- python/tvm/relay/frontend/onnx.py | 11 ++++ src/relay/op/tensor/reduce.cc | 35 ++++++++++--- 5 files changed, 133 insertions(+), 41 deletions(-) diff --git a/include/tvm/relay/attrs/reduce.h b/include/tvm/relay/attrs/reduce.h index 14b75ff1c0a8..274ccc8c352c 100644 --- a/include/tvm/relay/attrs/reduce.h +++ b/include/tvm/relay/attrs/reduce.h @@ -61,6 +61,42 @@ struct ReduceAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes for Reduce operators which reduce by finding a single element. E.g. argmin */ +struct OneElementReduceAttrs : public tvm::AttrsNode { + Array axis; + bool keepdims; + bool select_last_index; + bool exclude; + + TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") { + TVM_ATTR_FIELD(axis) + .set_default(NullValue>()) + .describe(R"code(The axis or axes along which to perform the reduction. + + The default, `axis=()`, will compute over all elements into a + scalar array with shape `(1,)`. + + If `axis` is int, a reduction is performed on a particular axis. + + If `axis` is a tuple of ints, a reduction is performed on all the axes + specified in the tuple. + + If `exclude` is true, reduction will be performed on the axes that are + NOT in axis instead.)code"); + + TVM_ATTR_FIELD(keepdims).set_default(false).describe( + "If this is set to `True`, the reduced axes are left " + "in the result as dimension with size one."); + TVM_ATTR_FIELD(select_last_index) + .set_default(false) + .describe( + "Whether to select the last index if the target element appears multiple times, else " + "select the first index which the target element appears"); + TVM_ATTR_FIELD(exclude).set_default(false).describe( + "Whether to perform reduction on axis that are NOT in axis instead."); + } +}; + struct VarianceAttrs : public tvm::AttrsNode { Array axis; bool keepdims; diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 688ad8254fa8..1932882ad198 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -241,9 +241,9 @@ class MixedModeVisitor : public ::tvm::relay::ExprVisitor { * \brief VisitExpr is finalized to preserve call expansion of dataflow regions */ void VisitExpr(const Expr& expr) final; - void VisitExpr_(const CallNode* op) override; - void VisitExpr_(const TupleNode* op) override; - void VisitExpr_(const TupleGetItemNode* op) override; + virtual void VisitExpr_(const CallNode* op) override; + virtual void VisitExpr_(const TupleNode* op) override; + virtual void VisitExpr_(const TupleGetItemNode* op) override; protected: /*! diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index 15d1455bb267..cf8bdad502dd 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -431,6 +431,40 @@ inline Tensor max(const Tensor& data, const Array& axis, bool keepdims return CommReduce(data, axis, MaxOp, keepdims, atleast1d); } +inline FCommReduce MakeSinglePassReducer( + std::function comparison_op, + std::function initial_value_generator, String name) { + // Create a Commutative Reducer with a comparison operation, and method to get the initial value. + auto fcombine = [&](Array lhs, Array rhs) { + Array result; + result.push_back(tvm::tir::Select(comparison_op(lhs[1], rhs[1]), lhs[0], rhs[0])); // idx + result.push_back(tvm::tir::Select(comparison_op(lhs[1], rhs[1]), lhs[1], rhs[1])); // val + return result; + }; + auto fidentity = [&](std::vector types) { + Array result; + result.push_back(tvm::tir::make_const(types[0], -1)); // idx + result.push_back(initial_value_generator(types[1])); // val + return result; + }; + return MakeCommReducer(fcombine, fidentity, name); +} + +inline FCommReduce MakeArgminReducer(bool select_last_index = false) { + std::function comparison_op; + if (select_last_index) { + comparison_op = [](Var lhs, Var rhs) { return lhs <= rhs; }; + } else { + comparison_op = [](Var lhs, Var rhs) { return lhs < rhs; }; + } + + std::function initial_value_generator = [](const DataType& data_type) { + return tvm::max_value(data_type); + }; + + return MakeSinglePassReducer(comparison_op, initial_value_generator, "argmin"); +} + /*! * \brief Creates an operation that finds the indices of the minimum * values over a given axis. @@ -442,41 +476,30 @@ inline Tensor max(const Tensor& data, const Array& axis, bool keepdims * left in the result as dimensions with size one. This enables the result * to broadcast correctly against the input array. * \param atleast1d Whether the output need to be atleast1d. + * \param select_last_index Whether to select the last index if the minimum element + * appears multiple times, else select the first index. * * \return A Tensor whose op member is the argmin operation */ inline Tensor argmin(const Tensor& data, const Array& axis, bool keepdims = false, - bool atleast1d = false) { - auto fcombine = [](Array lhs, Array rhs) { - Array result; - result.push_back(tvm::tir::Select(lhs[1] <= rhs[1], lhs[0], rhs[0])); // idx - result.push_back(tvm::tir::Select(lhs[1] <= rhs[1], lhs[1], rhs[1])); // val - return result; - }; - auto fidentity = [](std::vector types) { - Array result; - result.push_back(tvm::tir::make_const(types[0], -1)); // idx - result.push_back(tvm::max_value(types[1])); // val - return result; - }; - auto func = MakeCommReducer(fcombine, fidentity, "argmin"); - return CommReduceIdx(data, axis, func, keepdims, atleast1d); + bool atleast1d = false, bool select_last_index = false) { + auto reducer = MakeArgminReducer(select_last_index); + return CommReduceIdx(data, axis, reducer, keepdims, atleast1d); } -inline FCommReduce MakeArgmaxReducer() { - auto fcombine = [](Array lhs, Array rhs) { - Array result; - result.push_back(tvm::tir::Select(lhs[1] >= rhs[1], lhs[0], rhs[0])); // idx - result.push_back(tvm::tir::Select(lhs[1] >= rhs[1], lhs[1], rhs[1])); // val - return result; - }; - auto fidentity = [](std::vector types) { - Array result; - result.push_back(tvm::tir::make_const(types[0], -1)); // idx - result.push_back(tvm::min_value(types[1])); // val - return result; +inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { + std::function comparison_op; + if (select_last_index) { + comparison_op = [](Var lhs, Var rhs) { return lhs >= rhs; }; + } else { + comparison_op = [](Var lhs, Var rhs) { return lhs > rhs; }; + } + + std::function initial_value_generator = [](const DataType& data_type) { + return tvm::min_value(data_type); }; - return MakeCommReducer(fcombine, fidentity, "argmax"); + + return MakeSinglePassReducer(comparison_op, initial_value_generator, "argmax"); } /*! @@ -490,12 +513,13 @@ inline FCommReduce MakeArgmaxReducer() { * left in the result as dimensions with size one. This enables the result * to broadcast correctly against the input array. * \param atleast1d Whether the output need to be atleast1d. - * + * \param select_last_index Whether to select the last index if the maximum element + * appears multiple times, else select the first index. * \return A Tensor whose op member is the argmax operation */ inline Tensor argmax(const Tensor& data, const Array& axis, bool keepdims = false, - bool atleast1d = false) { - auto reducer = MakeArgmaxReducer(); + bool atleast1d = false, bool select_last_index = false) { + auto reducer = MakeArgmaxReducer(select_last_index); return CommReduceIdx(data, axis, reducer, keepdims, atleast1d); } diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 5471f67ea106..6fc896ca5e01 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1846,6 +1846,17 @@ def _impl_v1(cls, inputs, attr, params): return _op.cast(AttrCvt("argmin")(inputs, attr), "int64") + @classmethod + def _impl_v13(cls, inputs, attr, params): + if "select_last_index" in attr: + raise NotImplementedError("select_last_index not supported in ArgMin") + axis = attr.get("axis", 0) + keepdims = attr.get("keepdims", True) + select_last_index = attr.get("select_last_index", False) + attr = {"axis": axis, "keepdims": keepdims} + return _op.cast(AttrCvt("argmin")(inputs, attr), "int64") + # return _op.argmin() + class Softmax(OnnxOpConverter): """Operator converter for Softmax.""" diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index f08af1e7e4ad..04ab183f9019 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -203,13 +203,34 @@ Array ReduceCompute(const Attrs& attrs, const Array& inp auto axes = param->axis; if (param->exclude) { axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis); - if (axes.size() == 0) { - return {topi::identity(inputs[0])}; - } } + + if (axes.size() == 0) { + return {topi::identity(inputs[0])}; + } + return {f(inputs[0], axes, param->keepdims, false)}; } +template +Array OneElementReduceCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type, F f) { + const OneElementReduceAttrs* param = attrs.as(); + ICHECK(param != nullptr); + if (inputs[0]->shape.size() == 0) { + return {topi::identity(inputs[0])}; + } + auto axes = param->axis; + if (param->exclude) { + axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis); + } + + if (axes.size() == 0) { + return {topi::identity(inputs[0])}; + } + return {f(inputs[0], axes, param->keepdims, false, param->select_last_index)}; +} + /*! * \brief ReduceShapeImpl get the outshape for the reduction operator * \param in_shape Shape of input data. @@ -333,7 +354,7 @@ Expr MakeReduce(Expr data, Array axis, bool keepdims, bool exclude, Str Array ArgMaxCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return ReduceCompute(attrs, inputs, out_type, topi::argmax); + return OneElementReduceCompute(attrs, inputs, out_type, topi::argmax); } RELAY_REGISTER_REDUCE_OP("argmax") @@ -341,7 +362,7 @@ RELAY_REGISTER_REDUCE_OP("argmax") values over a given axis. )code" TVM_ADD_FILELINE) - .set_attrs_type() + .set_attrs_type() .set_support_level(4) .add_type_rel("ArgReduce", ArgReduceRel) .set_attr("FTVMCompute", ArgMaxCompute) @@ -349,7 +370,7 @@ values over a given axis. Array ArgMinCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return ReduceCompute(attrs, inputs, out_type, topi::argmin); + return OneElementReduceCompute(attrs, inputs, out_type, topi::argmin); } RELAY_REGISTER_REDUCE_OP("argmin") @@ -357,7 +378,7 @@ RELAY_REGISTER_REDUCE_OP("argmin") values over a given axis. )code" TVM_ADD_FILELINE) - .set_attrs_type() + .set_attrs_type() .set_support_level(4) .add_type_rel("ArgReduce", ArgReduceRel) .set_attr("FTVMCompute", ArgMinCompute) From 95a6517aa418358dbfb34da862c51bfe75ed9e61 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 11:13:59 -0700 Subject: [PATCH 02/52] reverse conditions which made on accident --- include/tvm/topi/reduction.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index cf8bdad502dd..19b8195a5116 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -453,9 +453,9 @@ inline FCommReduce MakeSinglePassReducer( inline FCommReduce MakeArgminReducer(bool select_last_index = false) { std::function comparison_op; if (select_last_index) { - comparison_op = [](Var lhs, Var rhs) { return lhs <= rhs; }; - } else { comparison_op = [](Var lhs, Var rhs) { return lhs < rhs; }; + } else { + comparison_op = [](Var lhs, Var rhs) { return lhs <= rhs; }; } std::function initial_value_generator = [](const DataType& data_type) { @@ -490,9 +490,9 @@ inline Tensor argmin(const Tensor& data, const Array& axis, bool keepdi inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { std::function comparison_op; if (select_last_index) { - comparison_op = [](Var lhs, Var rhs) { return lhs >= rhs; }; - } else { comparison_op = [](Var lhs, Var rhs) { return lhs > rhs; }; + } else { + comparison_op = [](Var lhs, Var rhs) { return lhs >= rhs; }; } std::function initial_value_generator = [](const DataType& data_type) { From 5e1f06ac5bedebafd338de727296c4e114591c85 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 11:21:41 -0700 Subject: [PATCH 03/52] forward args in reduce.py --- python/tvm/relay/op/reduce.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py index 368ffb5ab0ca..dfc8e71026af 100644 --- a/python/tvm/relay/op/reduce.py +++ b/python/tvm/relay/op/reduce.py @@ -17,13 +17,13 @@ """Reduce operators.""" # pylint: disable=redefined-builtin +from ..expr import Tuple, TupleWrapper from . import _make -from .tensor import sqrt, log, exp +from .tensor import exp, log, sqrt from .transform import squeeze -from ..expr import Tuple, TupleWrapper -def argmax(data, axis=None, keepdims=False, exclude=False): +def argmax(data, axis=None, keepdims=False, exclude=False, select_last_index=False): """Returns the indices of the maximum values along an axis. Parameters @@ -45,16 +45,20 @@ def argmax(data, axis=None, keepdims=False, exclude=False): If `exclude` is true, reduction will be performed on the axes that are NOT in axis instead. + select_last_index : bool + Whether to select the last index or the first index if the max element appears in multiple indices, + default is False (first index). + Returns ------- result : relay.Expr The computed result. """ axis = [axis] if isinstance(axis, int) else axis - return _make.argmax(data, axis, keepdims, exclude) + return _make.argmax(data, axis, keepdims, exclude, select_last_index) -def argmin(data, axis=None, keepdims=False, exclude=False): +def argmin(data, axis=None, keepdims=False, exclude=False, select_last_index=False): """Returns the indices of the minimum values along an axis. Parameters @@ -76,13 +80,17 @@ def argmin(data, axis=None, keepdims=False, exclude=False): If `exclude` is true, reduction will be performed on the axes that are NOT in axis instead. + select_last_index : bool + Whether to select the last index or the first index if the min element appears in multiple indices, + default is False (first index). + Returns ------- result : relay.Expr The computed result. """ axis = [axis] if isinstance(axis, int) else axis - return _make.argmin(data, axis, keepdims, exclude) + return _make.argmin(data, axis, keepdims, exclude, select_last_index) def sum(data, axis=None, keepdims=False, exclude=False): From 962e38a8329c67c425dc3cdae222acb0e2ffe3e9 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 11:52:46 -0700 Subject: [PATCH 04/52] make proper nodes for reduction ops --- src/relay/op/tensor/reduce.cc | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 04ab183f9019..96702a925089 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -345,6 +345,16 @@ Expr MakeReduce(Expr data, Array axis, bool keepdims, bool exclude, Str return Call(Op::Get(op_name), {data}, Attrs(attrs), {}); } +Expr MakeOneElementReduce(Expr data, Array axis, bool keepdims, bool exclude, + bool select_last_index, String op_name) { + auto attrs = make_object(); + attrs->axis = std::move(axis); + attrs->keepdims = keepdims; + attrs->exclude = exclude; + attrs->select_last_index = select_last_index; + return Call(Op::Get(op_name), {data}, Attrs(attrs), {}); +} + #define RELAY_REGISTER_REDUCE_OP(OpName) \ TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ .set_body_typed([](Expr data, Array axis, bool keepdims, bool exclude) { \ @@ -352,12 +362,20 @@ Expr MakeReduce(Expr data, Array axis, bool keepdims, bool exclude, Str }); \ RELAY_REGISTER_OP(OpName).set_num_inputs(1).add_argument("data", "Tensor", "The input tensor.") +#define RELAY_REGISTER_ONE_ELEMENT_REDUCE_OP(OpName) \ + TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ + .set_body_typed([](Expr data, Array axis, bool keepdims, bool exclude, \ + bool select_last_index) { \ + return MakeOneElementReduce(data, axis, keepdims, exclude, select_last_index, OpName); \ + }); \ + RELAY_REGISTER_OP(OpName).set_num_inputs(1).add_argument("data", "Tensor", "The input tensor.") + Array ArgMaxCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return OneElementReduceCompute(attrs, inputs, out_type, topi::argmax); } -RELAY_REGISTER_REDUCE_OP("argmax") +RELAY_REGISTER_ONE_ELEMENT_REDUCE_OP("argmax") .describe(R"code(Creates an operation that finds the indices of the maximum values over a given axis. @@ -373,7 +391,7 @@ Array ArgMinCompute(const Attrs& attrs, const Array& inp return OneElementReduceCompute(attrs, inputs, out_type, topi::argmin); } -RELAY_REGISTER_REDUCE_OP("argmin") +RELAY_REGISTER_ONE_ELEMENT_REDUCE_OP("argmin") .describe(R"code(Creates an operation that finds the indices of the minimum values over a given axis. From f92089b42c1ff31740c028c2f69526cbae29cb9b Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 12:56:07 -0700 Subject: [PATCH 05/52] remove complicated nested lambdas --- include/tvm/topi/reduction.h | 53 +++++++++++++---------------------- src/relay/op/tensor/reduce.cc | 13 ++++----- 2 files changed, 26 insertions(+), 40 deletions(-) diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index 19b8195a5116..4f2d566e4538 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -431,38 +431,22 @@ inline Tensor max(const Tensor& data, const Array& axis, bool keepdims return CommReduce(data, axis, MaxOp, keepdims, atleast1d); } -inline FCommReduce MakeSinglePassReducer( - std::function comparison_op, - std::function initial_value_generator, String name) { +inline FCommReduce MakeArgminReducer(bool select_last_index = false) { // Create a Commutative Reducer with a comparison operation, and method to get the initial value. auto fcombine = [&](Array lhs, Array rhs) { Array result; - result.push_back(tvm::tir::Select(comparison_op(lhs[1], rhs[1]), lhs[0], rhs[0])); // idx - result.push_back(tvm::tir::Select(comparison_op(lhs[1], rhs[1]), lhs[1], rhs[1])); // val + auto comparison = select_last_index ? lhs[1] < rhs[1] : lhs[1] <= rhs[1]; + result.push_back(tvm::tir::Select(comparison, lhs[0], rhs[0])); // idx + result.push_back(tvm::tir::Select(comparison, lhs[1], rhs[1])); // val return result; }; auto fidentity = [&](std::vector types) { Array result; result.push_back(tvm::tir::make_const(types[0], -1)); // idx - result.push_back(initial_value_generator(types[1])); // val + result.push_back(tvm::max_value(types[1])); // val return result; }; - return MakeCommReducer(fcombine, fidentity, name); -} - -inline FCommReduce MakeArgminReducer(bool select_last_index = false) { - std::function comparison_op; - if (select_last_index) { - comparison_op = [](Var lhs, Var rhs) { return lhs < rhs; }; - } else { - comparison_op = [](Var lhs, Var rhs) { return lhs <= rhs; }; - } - - std::function initial_value_generator = [](const DataType& data_type) { - return tvm::max_value(data_type); - }; - - return MakeSinglePassReducer(comparison_op, initial_value_generator, "argmin"); + return MakeCommReducer(fcombine, fidentity, "argmin"); } /*! @@ -488,18 +472,21 @@ inline Tensor argmin(const Tensor& data, const Array& axis, bool keepdi } inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { - std::function comparison_op; - if (select_last_index) { - comparison_op = [](Var lhs, Var rhs) { return lhs > rhs; }; - } else { - comparison_op = [](Var lhs, Var rhs) { return lhs >= rhs; }; - } - - std::function initial_value_generator = [](const DataType& data_type) { - return tvm::min_value(data_type); + // Create a Commutative Reducer with a comparison operation, and method to get the initial value. + auto fcombine = [&](Array lhs, Array rhs) { + Array result; + auto comparison = select_last_index ? lhs[1] > rhs[1] : lhs[1] >= rhs[1]; + result.push_back(tvm::tir::Select(comparison, lhs[0], rhs[0])); // idx + result.push_back(tvm::tir::Select(comparison, lhs[1], rhs[1])); // val + return result; }; - - return MakeSinglePassReducer(comparison_op, initial_value_generator, "argmax"); + auto fidentity = [&](std::vector types) { + Array result; + result.push_back(tvm::tir::make_const(types[0], -1)); // idx + result.push_back(tvm::min_value(types[1])); // val + return result; + }; + return MakeCommReducer(fcombine, fidentity, "argmax"); } /*! diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 96702a925089..b5643a10f8d4 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -203,10 +203,9 @@ Array ReduceCompute(const Attrs& attrs, const Array& inp auto axes = param->axis; if (param->exclude) { axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis); - } - - if (axes.size() == 0) { - return {topi::identity(inputs[0])}; + if (axes.size() == 0) { + return {topi::identity(inputs[0])}; + } } return {f(inputs[0], axes, param->keepdims, false)}; @@ -223,11 +222,11 @@ Array OneElementReduceCompute(const Attrs& attrs, const Arrayaxis; if (param->exclude) { axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis); + if (axes.size() == 0) { + return {topi::identity(inputs[0])}; + } } - if (axes.size() == 0) { - return {topi::identity(inputs[0])}; - } return {f(inputs[0], axes, param->keepdims, false, param->select_last_index)}; } From 9edc8e686580ddbf900cf62347d79a1fd95631ba Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 15:24:49 -0700 Subject: [PATCH 06/52] fix lambda capture for conversion --- include/tvm/relay/attrs/reduce.h | 2 +- include/tvm/topi/reduction.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/tvm/relay/attrs/reduce.h b/include/tvm/relay/attrs/reduce.h index 274ccc8c352c..44ea79fcb517 100644 --- a/include/tvm/relay/attrs/reduce.h +++ b/include/tvm/relay/attrs/reduce.h @@ -62,7 +62,7 @@ struct ReduceAttrs : public tvm::AttrsNode { }; /*! \brief Attributes for Reduce operators which reduce by finding a single element. E.g. argmin */ -struct OneElementReduceAttrs : public tvm::AttrsNode { +struct OneElementReduceAttrs : public tvm::AttrsNode { Array axis; bool keepdims; bool select_last_index; diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index 4f2d566e4538..ad2d80eefbca 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -433,7 +433,7 @@ inline Tensor max(const Tensor& data, const Array& axis, bool keepdims inline FCommReduce MakeArgminReducer(bool select_last_index = false) { // Create a Commutative Reducer with a comparison operation, and method to get the initial value. - auto fcombine = [&](Array lhs, Array rhs) { + auto fcombine = [=](Array lhs, Array rhs) { Array result; auto comparison = select_last_index ? lhs[1] < rhs[1] : lhs[1] <= rhs[1]; result.push_back(tvm::tir::Select(comparison, lhs[0], rhs[0])); // idx @@ -473,7 +473,7 @@ inline Tensor argmin(const Tensor& data, const Array& axis, bool keepdi inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { // Create a Commutative Reducer with a comparison operation, and method to get the initial value. - auto fcombine = [&](Array lhs, Array rhs) { + auto fcombine = [=](Array lhs, Array rhs) { Array result; auto comparison = select_last_index ? lhs[1] > rhs[1] : lhs[1] >= rhs[1]; result.push_back(tvm::tir::Select(comparison, lhs[0], rhs[0])); // idx From 9e4e69a8c4fada0a1774ae558dc7a3338a938faf Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 16:22:40 -0700 Subject: [PATCH 07/52] forward more arguments --- src/topi/reduction.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/topi/reduction.cc b/src/topi/reduction.cc index 55c59162e68c..3d1c6f9f7d5b 100644 --- a/src/topi/reduction.cc +++ b/src/topi/reduction.cc @@ -45,11 +45,11 @@ TVM_REGISTER_GLOBAL("topi.max").set_body([](TVMArgs args, TVMRetValue* rv) { }); TVM_REGISTER_GLOBAL("topi.argmin").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = topi::argmin(args[0], ArrayOrInt(args[1]), args[2]); + *rv = topi::argmin(args[0], ArrayOrInt(args[1]), args[2], false, args[3]); }); TVM_REGISTER_GLOBAL("topi.argmax").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = topi::argmax(args[0], ArrayOrInt(args[1]), args[2]); + *rv = topi::argmax(args[0], ArrayOrInt(args[1]), args[2], false, args[3]); }); TVM_REGISTER_GLOBAL("topi.prod").set_body([](TVMArgs args, TVMRetValue* rv) { From 5cf47727cef3b1f734e9b63d0faf4e82570a1148 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 16:36:27 -0700 Subject: [PATCH 08/52] forward more args --- python/tvm/topi/reduction.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/topi/reduction.py b/python/tvm/topi/reduction.py index 77f9ad447ed1..cba43297f293 100644 --- a/python/tvm/topi/reduction.py +++ b/python/tvm/topi/reduction.py @@ -167,7 +167,7 @@ def min(data, axis=None, keepdims=False): return cpp.min(data, axis, keepdims) -def argmax(data, axis=None, keepdims=False): +def argmax(data, axis=None, keepdims=False, exclude=False, select_last_index=False): """Returns the indices of the maximum values along an axis. Parameters @@ -189,10 +189,10 @@ def argmax(data, axis=None, keepdims=False): ------- ret : tvm.te.Tensor """ - return cpp.argmax(data, axis, keepdims) + return cpp.argmax(data, axis, keepdims, exclude=exclude, select_last_index=select_last_index) -def argmin(data, axis=None, keepdims=False): +def argmin(data, axis=None, keepdims=False, exclude=False, select_last_index=False): """Returns the indices of the minimum values along an axis. Parameters @@ -214,7 +214,7 @@ def argmin(data, axis=None, keepdims=False): ------- ret : tvm.te.Tensor """ - return cpp.argmin(data, axis, keepdims) + return cpp.argmin(data, axis, keepdims, exclude, select_last_index) def prod(data, axis=None, keepdims=False): From 4f5a662289eaa1a736de54060d23e4404acc97cd Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 16:42:02 -0700 Subject: [PATCH 09/52] enable onnx tests --- python/tvm/relay/frontend/onnx.py | 32 ++++++--------- tests/python/frontend/onnx/test_forward.py | 20 +-------- tests/python/relay/test_op_level4.py | 48 +++++++++++++++++++--- 3 files changed, 57 insertions(+), 43 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 6fc896ca5e01..3c027981ac31 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -32,24 +32,12 @@ from .. import loops as _loops from .. import op as _op from .. import qnn as _qnn +from .. import random as _random from .. import ty as _ty from .. import vision as _vision -from .. import random as _random -from .common import ( - AttrCvt, - Renamer, - fold_constant, - get_name, - get_relay_op, - infer_channels, - infer_shape, - infer_type, - infer_value, - new_var, - unbind, - gru_cell, - lstm_cell, -) +from .common import (AttrCvt, Renamer, fold_constant, get_name, get_relay_op, + gru_cell, infer_channels, infer_shape, infer_type, + infer_value, lstm_cell, new_var, unbind) __all__ = ["from_onnx"] @@ -1832,6 +1820,13 @@ def _impl_v1(cls, inputs, attr, params): attr = {"axis": axis, "keepdims": keepdims} return _op.cast(AttrCvt("argmax")(inputs, attr), "int64") + @classmethod + def _impl_v13(cls, inputs, attr, params): + axis = attr.get("axis", 0) + keepdims = attr.get("keepdims", True) + select_last_index = attr.get("select_last_index", False) + attr = {"axis": axis, "keepdims": keepdims, "select_last_index": select_last_index} + return _op.cast(AttrCvt("argmax")(inputs, attr), "int64") class ArgMin(OnnxOpConverter): """Operator converter for ArgMin.""" @@ -1848,14 +1843,11 @@ def _impl_v1(cls, inputs, attr, params): @classmethod def _impl_v13(cls, inputs, attr, params): - if "select_last_index" in attr: - raise NotImplementedError("select_last_index not supported in ArgMin") axis = attr.get("axis", 0) keepdims = attr.get("keepdims", True) select_last_index = attr.get("select_last_index", False) - attr = {"axis": axis, "keepdims": keepdims} + attr = {"axis": axis, "keepdims": keepdims, "select_last_index": select_last_index} return _op.cast(AttrCvt("argmin")(inputs, attr), "int64") - # return _op.argmin() class Softmax(OnnxOpConverter): """Operator converter for Softmax.""" diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 9e0eb1f75217..67a94757a2d7 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -17,7 +17,6 @@ import glob import os import re -import glob import numpy as np import pytest @@ -236,7 +235,8 @@ def verify_with_ort( def quantize_and_verify_with_ort(onnx_model, input_names, input_shapes, target, dev): - from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType + from onnxruntime.quantization import (CalibrationDataReader, QuantType, + quantize_static) input_arrays = [np.random.random(shape).astype("float32") for shape in input_shapes] @@ -4680,22 +4680,6 @@ def verify_eyelike(indata): "test_adagrad_multiple", "test_adam", "test_adam_multiple", - "test_argmax_default_axis_example_select_last_index", - "test_argmax_default_axis_random_select_last_index", - "test_argmax_keepdims_example_select_last_index", - "test_argmax_keepdims_random_select_last_index", - "test_argmax_negative_axis_keepdims_example_select_last_index", - "test_argmax_negative_axis_keepdims_random_select_last_index", - "test_argmax_no_keepdims_example_select_last_index", - "test_argmax_no_keepdims_random_select_last_index", - "test_argmin_default_axis_example_select_last_index", - "test_argmin_default_axis_random_select_last_index", - "test_argmin_keepdims_example_select_last_index", - "test_argmin_keepdims_random_select_last_index", - "test_argmin_negative_axis_keepdims_example_select_last_index", - "test_argmin_negative_axis_keepdims_random_select_last_index", - "test_argmin_no_keepdims_example_select_last_index", - "test_argmin_no_keepdims_random_select_last_index", "test_cast_BFLOAT16_to_FLOAT", "test_cast_DOUBLE_to_FLOAT16", "test_cast_FLOAT_to_BFLOAT16", diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index df77c33658de..ec347224b0be 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -14,14 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm -from tvm import te import numpy as np -from tvm import relay +import numpy.random +import tvm +import tvm.testing +import tvm.topi.testing +from tvm import relay, te from tvm.relay import transform from tvm.relay.testing import run_infer_type -import tvm.topi.testing -import tvm.testing @tvm.testing.uses_gpu @@ -342,6 +342,44 @@ def _unbiased_func(a, axis=None, dtype=None, keepdims=None): verify_reduce(func, (128, 24, 128), (0, 2), True, False, (1, 24, 1)) +@tvm.testing.uses_gpu +def test_argmin_argmax_get_last_elements(): + def get_test_case(shape, gt_func, test_argmin=False): + total_ele = np.product(shape) + arr = np.zeros(total_ele) + target_value = -1 if test_argmin else 1 + arr[: total_ele // 3] = target_value + np.random.shuffle(arr) + arr = arr.reshape(shape) + ans = gt_func(np.flip(arr)) + return arr, len(arr) - ans - 1 + + funcs_and_gt_funcs = [(relay.argmax, np.argmax), (relay.argmin, np.argmin)] + lengths = [5, 10, 15] + for func, gt_func in funcs_and_gt_funcs: + for shape in lengths: + x_in = relay.var("x_in", shape=[shape]) + try: + output = func(x_in, select_last_index=True) + except: + breakpoint() + arr, ans = get_test_case(shape, gt_func, test_argmin=func == relay.argmin) + + mod = tvm.IRModule.from_expr(output) + for target, dev in tvm.testing.enabled_targets(): + op_res = relay.create_executor( + "graph", mod=mod, device=dev, target=target + ).evaluate()(arr) + print(target) + print(dev) + print(arr) + print(ans) + print(op_res) + print() + + raise ValueError("WHAT") + + def verify_mean_var_std(funcs, shape, axis, keepdims): test_func = funcs[0] ref_func = funcs[1] From 75cb60852556603377853304272575f88b86d1d5 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Mon, 23 Aug 2021 17:42:42 -0700 Subject: [PATCH 10/52] wrapping casts to remove ambiguity --- include/tvm/topi/reduction.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index ad2d80eefbca..9bd863873531 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -435,7 +435,8 @@ inline FCommReduce MakeArgminReducer(bool select_last_index = false) { // Create a Commutative Reducer with a comparison operation, and method to get the initial value. auto fcombine = [=](Array lhs, Array rhs) { Array result; - auto comparison = select_last_index ? lhs[1] < rhs[1] : lhs[1] <= rhs[1]; + // Cast to resolve ambiguous operators + auto comparison = select_last_index ? PrimExpr(lhs[1]) < PrimExpr(rhs[1]) : lhs[1] <= rhs[1]; result.push_back(tvm::tir::Select(comparison, lhs[0], rhs[0])); // idx result.push_back(tvm::tir::Select(comparison, lhs[1], rhs[1])); // val return result; @@ -475,7 +476,8 @@ inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { // Create a Commutative Reducer with a comparison operation, and method to get the initial value. auto fcombine = [=](Array lhs, Array rhs) { Array result; - auto comparison = select_last_index ? lhs[1] > rhs[1] : lhs[1] >= rhs[1]; + // Cast to resolve ambiguous operators + auto comparison = select_last_index ? PrimExpr(lhs[1]) > PrimExpr(rhs[1]) : lhs[1] >= rhs[1]; result.push_back(tvm::tir::Select(comparison, lhs[0], rhs[0])); // idx result.push_back(tvm::tir::Select(comparison, lhs[1], rhs[1])); // val return result; From 7a60353312be0bf76a41253774fdf20df1218382 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 18:06:57 -0700 Subject: [PATCH 11/52] revert changes extraneous --- include/tvm/relay/expr_functor.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 1932882ad198..688ad8254fa8 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -241,9 +241,9 @@ class MixedModeVisitor : public ::tvm::relay::ExprVisitor { * \brief VisitExpr is finalized to preserve call expansion of dataflow regions */ void VisitExpr(const Expr& expr) final; - virtual void VisitExpr_(const CallNode* op) override; - virtual void VisitExpr_(const TupleNode* op) override; - virtual void VisitExpr_(const TupleGetItemNode* op) override; + void VisitExpr_(const CallNode* op) override; + void VisitExpr_(const TupleNode* op) override; + void VisitExpr_(const TupleGetItemNode* op) override; protected: /*! From 6a9e82f92f01565cabac6992ae7778c916edc77d Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 22:27:32 -0700 Subject: [PATCH 12/52] correct incorrect attrs being used for ops --- src/relay/op/tensor/reduce.cc | 41 +++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index b5643a10f8d4..7881e03e8b37 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -289,22 +289,16 @@ inline std::vector ReduceShapeImpl(const std::vector& in_s } } -/*! - * \brief ArgReduceRel Output type and shape relation evaluation function. - * \param num_inputs Number of input types in the args. - * \param attrs The additional attributes of the operator. - * \param reporter The reporter to report solution to. - * \return false if This relation cannot be resolved. true if this relation has been resolved. - */ -bool ArgReduceRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { +template +bool GenericReduceRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { ICHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) return false; ICHECK(static_cast(data->shape.size()) != 0); std::vector in_shape(data->shape.begin(), data->shape.end()); - const ReduceAttrs* param = attrs.as(); + const T* param = attrs.as(); ICHECK(param != nullptr); // assign output type and shape @@ -312,6 +306,29 @@ bool ArgReduceRel(const Array& types, int num_inputs, const Attrs& attrs, reporter->Assign(types[1], TensorType(oshape, DataType::Int(32))); return true; } +/*! + * \brief ArgReduceRel Output type and shape relation evaluation function. + * \param num_inputs Number of input types in the args. + * \param attrs The additional attributes of the operator. + * \param reporter The reporter to report solution to. + * \return false if This relation cannot be resolved. true if this relation has been resolved. + */ +bool ArgReduceRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + return GenericReduceRel(types, num_inputs, attrs, reporter); +} + +/*! + * \brief SingleElementArgReduceRel Output type and shape relation evaluation function. + * \param num_inputs Number of input types in the args. + * \param attrs The additional attributes of the operator. + * \param reporter The reporter to report solution to. + * \return false if This relation cannot be resolved. true if this relation has been resolved. + */ +bool SingleElementArgReduceRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + return GenericReduceRel(types, num_inputs, attrs, reporter); +} /*! * \brief ReduceRel Output type and shape relation evaluation function. @@ -381,7 +398,7 @@ values over a given axis. )code" TVM_ADD_FILELINE) .set_attrs_type() .set_support_level(4) - .add_type_rel("ArgReduce", ArgReduceRel) + .add_type_rel("ArgReduce", SingleElementArgReduceRel) .set_attr("FTVMCompute", ArgMaxCompute) .set_attr("TOpPattern", kCommReduce); @@ -397,7 +414,7 @@ values over a given axis. )code" TVM_ADD_FILELINE) .set_attrs_type() .set_support_level(4) - .add_type_rel("ArgReduce", ArgReduceRel) + .add_type_rel("ArgReduce", SingleElementArgReduceRel) .set_attr("FTVMCompute", ArgMinCompute) .set_attr("TOpPattern", kCommReduce); From 0fb5db57f1552612971bf164e9f66c6c43c67145 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 22:32:42 -0700 Subject: [PATCH 13/52] change attributes --- include/tvm/relay/attrs/reduce.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/relay/attrs/reduce.h b/include/tvm/relay/attrs/reduce.h index 44ea79fcb517..8c4794e9ac00 100644 --- a/include/tvm/relay/attrs/reduce.h +++ b/include/tvm/relay/attrs/reduce.h @@ -68,7 +68,7 @@ struct OneElementReduceAttrs : public tvm::AttrsNode { bool select_last_index; bool exclude; - TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") { + TVM_DECLARE_ATTRS(OneElementReduceAttrs, "relay.attrs.OneElementReduceAttrs") { TVM_ATTR_FIELD(axis) .set_default(NullValue>()) .describe(R"code(The axis or axes along which to perform the reduction. From 55a412d46532ac51bcfdb1e97a407c7592dd5b29 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Mon, 23 Aug 2021 22:41:50 -0700 Subject: [PATCH 14/52] remove old impl --- python/tvm/relay/frontend/onnx.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 3c027981ac31..1d9f1bc3ab28 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1811,15 +1811,6 @@ def _impl_v1(cls, inputs, attr, params): class ArgMax(OnnxOpConverter): """Operator converter for ArgMax.""" - @classmethod - def _impl_v1(cls, inputs, attr, params): - if "select_last_index" in attr: - raise NotImplementedError("select_last_index not supported in ArgMax") - axis = attr.get("axis", 0) - keepdims = attr.get("keepdims", True) - attr = {"axis": axis, "keepdims": keepdims} - return _op.cast(AttrCvt("argmax")(inputs, attr), "int64") - @classmethod def _impl_v13(cls, inputs, attr, params): axis = attr.get("axis", 0) @@ -1831,16 +1822,6 @@ def _impl_v13(cls, inputs, attr, params): class ArgMin(OnnxOpConverter): """Operator converter for ArgMin.""" - @classmethod - def _impl_v1(cls, inputs, attr, params): - if "select_last_index" in attr: - raise NotImplementedError("select_last_index not supported in ArgMin") - axis = attr.get("axis", 0) - keepdims = attr.get("keepdims", True) - attr = {"axis": axis, "keepdims": keepdims} - return _op.cast(AttrCvt("argmin")(inputs, attr), "int64") - - @classmethod def _impl_v13(cls, inputs, attr, params): axis = attr.get("axis", 0) From 93173fc543a008199bc8fdb5a9da9b2528c2ab09 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 22:49:50 -0700 Subject: [PATCH 15/52] register new attribute node --- src/relay/op/tensor/reduce.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 7881e03e8b37..75483ea8b015 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -38,6 +38,7 @@ namespace tvm { namespace relay { TVM_REGISTER_NODE_TYPE(ReduceAttrs); +TVM_REGISTER_NODE_TYPE(OneElementReduceAttrs); TVM_REGISTER_NODE_TYPE(VarianceAttrs); /*! From 47b7eed226b0d25b4181e88eb3f9b962079788fb Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 22:53:36 -0700 Subject: [PATCH 16/52] clean up test --- tests/python/relay/test_op_level4.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index ec347224b0be..6415976bfd59 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -359,10 +359,7 @@ def get_test_case(shape, gt_func, test_argmin=False): for func, gt_func in funcs_and_gt_funcs: for shape in lengths: x_in = relay.var("x_in", shape=[shape]) - try: - output = func(x_in, select_last_index=True) - except: - breakpoint() + output = func(x_in, select_last_index=True) arr, ans = get_test_case(shape, gt_func, test_argmin=func == relay.argmin) mod = tvm.IRModule.from_expr(output) @@ -370,14 +367,7 @@ def get_test_case(shape, gt_func, test_argmin=False): op_res = relay.create_executor( "graph", mod=mod, device=dev, target=target ).evaluate()(arr) - print(target) - print(dev) - print(arr) - print(ans) - print(op_res) - print() - - raise ValueError("WHAT") + assert op_res.numpy().item() == ans def verify_mean_var_std(funcs, shape, axis, keepdims): From e62513b795edc2082f54f5d1f36676ba407077cc Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 23:01:12 -0700 Subject: [PATCH 17/52] reformat --- tests/python/frontend/onnx/test_forward.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 67a94757a2d7..a1d821686ed5 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -235,8 +235,7 @@ def verify_with_ort( def quantize_and_verify_with_ort(onnx_model, input_names, input_shapes, target, dev): - from onnxruntime.quantization import (CalibrationDataReader, QuantType, - quantize_static) + from onnxruntime.quantization import CalibrationDataReader, QuantType, quantize_static input_arrays = [np.random.random(shape).astype("float32") for shape in input_shapes] From e9ea784fdbac364b535bcdc1de304b34cea47115 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 23:03:39 -0700 Subject: [PATCH 18/52] reformat --- python/tvm/relay/frontend/onnx.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 1d9f1bc3ab28..8376d56fe089 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -35,9 +35,20 @@ from .. import random as _random from .. import ty as _ty from .. import vision as _vision -from .common import (AttrCvt, Renamer, fold_constant, get_name, get_relay_op, - gru_cell, infer_channels, infer_shape, infer_type, - infer_value, lstm_cell, new_var, unbind) +from .common import ( + AttrCvt, + Renamer, + fold_constant, + get_name, + get_relay_op, + infer_channels, + infer_shape, + infer_type, + infer_value, + lstm_cell, + new_var, + unbind, +) __all__ = ["from_onnx"] @@ -1819,6 +1830,7 @@ def _impl_v13(cls, inputs, attr, params): attr = {"axis": axis, "keepdims": keepdims, "select_last_index": select_last_index} return _op.cast(AttrCvt("argmax")(inputs, attr), "int64") + class ArgMin(OnnxOpConverter): """Operator converter for ArgMin.""" @@ -1830,6 +1842,7 @@ def _impl_v13(cls, inputs, attr, params): attr = {"axis": axis, "keepdims": keepdims, "select_last_index": select_last_index} return _op.cast(AttrCvt("argmin")(inputs, attr), "int64") + class Softmax(OnnxOpConverter): """Operator converter for Softmax.""" From 587e94a438ca5cf5508b8e8240f022e80e3174ed Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 23:11:29 -0700 Subject: [PATCH 19/52] coolio --- python/tvm/relay/op/reduce.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py index dfc8e71026af..23accebfd0ec 100644 --- a/python/tvm/relay/op/reduce.py +++ b/python/tvm/relay/op/reduce.py @@ -46,8 +46,8 @@ def argmax(data, axis=None, keepdims=False, exclude=False, select_last_index=Fal NOT in axis instead. select_last_index : bool - Whether to select the last index or the first index if the max element appears in multiple indices, - default is False (first index). + Whether to select the last index or the first index if the max element appears in + multiple indices, default is False (first index). Returns ------- @@ -81,8 +81,8 @@ def argmin(data, axis=None, keepdims=False, exclude=False, select_last_index=Fal NOT in axis instead. select_last_index : bool - Whether to select the last index or the first index if the min element appears in multiple indices, - default is False (first index). + Whether to select the last index or the first index if the min element appears in + multiple indices, default is False (first index). Returns ------- From d048e2509be0199334424ef8f4d4ef9e9577bcb4 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 24 Aug 2021 12:07:49 -0700 Subject: [PATCH 20/52] stable comparison --- include/tvm/topi/reduction.h | 45 +++++++++++++++++++++++++++++------- 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index 9bd863873531..e4318957160f 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -435,10 +435,24 @@ inline FCommReduce MakeArgminReducer(bool select_last_index = false) { // Create a Commutative Reducer with a comparison operation, and method to get the initial value. auto fcombine = [=](Array lhs, Array rhs) { Array result; - // Cast to resolve ambiguous operators - auto comparison = select_last_index ? PrimExpr(lhs[1]) < PrimExpr(rhs[1]) : lhs[1] <= rhs[1]; - result.push_back(tvm::tir::Select(comparison, lhs[0], rhs[0])); // idx - result.push_back(tvm::tir::Select(comparison, lhs[1], rhs[1])); // val + + // These variables compare the actual values of the array + auto is_smaller = lhs[1] < rhs[1]; + auto is_same = lhs[1] == rhs[1]; + + // This checks if the indices are correct for the reduction. E.g. for select_last_index + // it gives precedence for later indices of the same element and precedence for sooner + // indices if not select_last_index; + PrimExpr proper_index; + if (select_last_index) { + proper_index = lhs[0] > rhs[0]; + } else { + proper_index = lhs[0] < rhs[0]; + } + + PrimExpr update_index = is_smaller || (is_same && proper_index); + result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0])); // idx + result.push_back(tvm::tir::Select(is_smaller, lhs[1], rhs[1])); // val return result; }; auto fidentity = [&](std::vector types) { @@ -476,10 +490,25 @@ inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { // Create a Commutative Reducer with a comparison operation, and method to get the initial value. auto fcombine = [=](Array lhs, Array rhs) { Array result; - // Cast to resolve ambiguous operators - auto comparison = select_last_index ? PrimExpr(lhs[1]) > PrimExpr(rhs[1]) : lhs[1] >= rhs[1]; - result.push_back(tvm::tir::Select(comparison, lhs[0], rhs[0])); // idx - result.push_back(tvm::tir::Select(comparison, lhs[1], rhs[1])); // val + + // These variables compare the actual values of the array + auto is_bigger = lhs[1] > rhs[1]; + auto is_same = lhs[1] == rhs[1]; + + // This checks if the indices are correct for the reduction. E.g. for select_last_index + // it gives precedence for later indices of the same element and precedence for sooner + // indices if not select_last_index; + PrimExpr proper_index; + if (select_last_index) { + proper_index = lhs[0] > rhs[0]; + } else { + proper_index = lhs[0] < rhs[0]; + } + + PrimExpr update_index = is_bigger || (is_same && proper_index); + result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0])); // idx + result.push_back(tvm::tir::Select(is_bigger, lhs[1], rhs[1])); // val + LOG(WARNING) << result; return result; }; auto fidentity = [&](std::vector types) { From 71ab1f37add055a93380b6a21e058df762a5794d Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 24 Aug 2021 12:15:23 -0700 Subject: [PATCH 21/52] casts to avoid ambiguity --- include/tvm/topi/reduction.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index e4318957160f..d34afc8af909 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -445,9 +445,9 @@ inline FCommReduce MakeArgminReducer(bool select_last_index = false) { // indices if not select_last_index; PrimExpr proper_index; if (select_last_index) { - proper_index = lhs[0] > rhs[0]; + proper_index = PrimExpr(lhs[0]) > PrimExpr(rhs[0]); } else { - proper_index = lhs[0] < rhs[0]; + proper_index = PrimExpr(lhs[0]) < PrimExpr(rhs[0]); } PrimExpr update_index = is_smaller || (is_same && proper_index); @@ -500,9 +500,9 @@ inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { // indices if not select_last_index; PrimExpr proper_index; if (select_last_index) { - proper_index = lhs[0] > rhs[0]; + proper_index = PrimExpr(lhs[0]) > PrimExpr(rhs[0]); } else { - proper_index = lhs[0] < rhs[0]; + proper_index = PrimExpr(lhs[0]) < PrimExpr(rhs[0]); } PrimExpr update_index = is_bigger || (is_same && proper_index); From aecf630b3a441729688362edfeca20bafd555e3d Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 24 Aug 2021 12:24:25 -0700 Subject: [PATCH 22/52] casting more --- include/tvm/topi/reduction.h | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index d34afc8af909..caf327ea2dce 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -436,18 +436,24 @@ inline FCommReduce MakeArgminReducer(bool select_last_index = false) { auto fcombine = [=](Array lhs, Array rhs) { Array result; + // Casting to avoid operator ambiguity + PrimExpr lhs_idx = static_cast(lhs[0]); + PrimExpr rhs_idx = static_cast(rhs[0]); + PrimExpr lhs_val = static_cast(lhs[1]); + PrimExpr rhs_val = static_cast(rhs[1]); + // These variables compare the actual values of the array - auto is_smaller = lhs[1] < rhs[1]; - auto is_same = lhs[1] == rhs[1]; + auto is_smaller = lhs_val < rhs_val; + auto is_same = lhs_val == rhs_val; // This checks if the indices are correct for the reduction. E.g. for select_last_index // it gives precedence for later indices of the same element and precedence for sooner // indices if not select_last_index; PrimExpr proper_index; if (select_last_index) { - proper_index = PrimExpr(lhs[0]) > PrimExpr(rhs[0]); + proper_index = lhs_idx > rhs_idx; } else { - proper_index = PrimExpr(lhs[0]) < PrimExpr(rhs[0]); + proper_index = lhs_idx < rhs_idx; } PrimExpr update_index = is_smaller || (is_same && proper_index); @@ -491,18 +497,24 @@ inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { auto fcombine = [=](Array lhs, Array rhs) { Array result; + // Casting to avoid operator ambiguity + PrimExpr lhs_idx = static_cast(lhs[0]); + PrimExpr rhs_idx = static_cast(rhs[0]); + PrimExpr lhs_val = static_cast(lhs[1]); + PrimExpr rhs_val = static_cast(rhs[1]); + // These variables compare the actual values of the array - auto is_bigger = lhs[1] > rhs[1]; - auto is_same = lhs[1] == rhs[1]; + auto is_bigger = lhs_val > rhs_val; + auto is_same = lhs_val == rhs_val; // This checks if the indices are correct for the reduction. E.g. for select_last_index // it gives precedence for later indices of the same element and precedence for sooner // indices if not select_last_index; PrimExpr proper_index; if (select_last_index) { - proper_index = PrimExpr(lhs[0]) > PrimExpr(rhs[0]); + proper_index = lhs_idx > rhs_idx; } else { - proper_index = PrimExpr(lhs[0]) < PrimExpr(rhs[0]); + proper_index = lhs_idx < rhs_idx; } PrimExpr update_index = is_bigger || (is_same && proper_index); From 423d092952d40e66de94422d1600a70cf49292b9 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Wed, 25 Aug 2021 19:10:40 -0700 Subject: [PATCH 23/52] correct arg passing --- python/tvm/topi/reduction.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/python/tvm/topi/reduction.py b/python/tvm/topi/reduction.py index cba43297f293..45d07af577a3 100644 --- a/python/tvm/topi/reduction.py +++ b/python/tvm/topi/reduction.py @@ -167,7 +167,7 @@ def min(data, axis=None, keepdims=False): return cpp.min(data, axis, keepdims) -def argmax(data, axis=None, keepdims=False, exclude=False, select_last_index=False): +def argmax(data, axis=None, keepdims=False, select_last_index=False): """Returns the indices of the maximum values along an axis. Parameters @@ -185,14 +185,18 @@ def argmax(data, axis=None, keepdims=False, exclude=False, select_last_index=Fal with size one. With this option, the result will broadcast correctly against the input array. + select_last_index: bool + Whether to select the last index if the maximum element appears multiple times, else + select the first index. + Returns ------- ret : tvm.te.Tensor """ - return cpp.argmax(data, axis, keepdims, exclude=exclude, select_last_index=select_last_index) + return cpp.argmax(data, axis, keepdims, select_last_index) -def argmin(data, axis=None, keepdims=False, exclude=False, select_last_index=False): +def argmin(data, axis=None, keepdims=False, select_last_index=False): """Returns the indices of the minimum values along an axis. Parameters @@ -210,11 +214,15 @@ def argmin(data, axis=None, keepdims=False, exclude=False, select_last_index=Fal with size one. With this option, the result will broadcast correctly against the input array. + select_last_index: bool + Whether to select the last index if the minimum element appears multiple times, else + select the first index. + Returns ------- ret : tvm.te.Tensor """ - return cpp.argmin(data, axis, keepdims, exclude, select_last_index) + return cpp.argmin(data, axis, keepdims, select_last_index) def prod(data, axis=None, keepdims=False): From 2faf06d0ba58d5b0d0f9a7e809cb7b1465745982 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 10:59:26 -0700 Subject: [PATCH 24/52] support select_last_index for argmin/max --- include/tvm/relay/attrs/reduce.h | 36 +++++++++++++ include/tvm/relay/expr_functor.h | 6 +-- include/tvm/topi/reduction.h | 86 ++++++++++++++++++++----------- python/tvm/relay/frontend/onnx.py | 11 ++++ src/relay/op/tensor/reduce.cc | 35 ++++++++++--- 5 files changed, 133 insertions(+), 41 deletions(-) diff --git a/include/tvm/relay/attrs/reduce.h b/include/tvm/relay/attrs/reduce.h index 14b75ff1c0a8..274ccc8c352c 100644 --- a/include/tvm/relay/attrs/reduce.h +++ b/include/tvm/relay/attrs/reduce.h @@ -61,6 +61,42 @@ struct ReduceAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes for Reduce operators which reduce by finding a single element. E.g. argmin */ +struct OneElementReduceAttrs : public tvm::AttrsNode { + Array axis; + bool keepdims; + bool select_last_index; + bool exclude; + + TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") { + TVM_ATTR_FIELD(axis) + .set_default(NullValue>()) + .describe(R"code(The axis or axes along which to perform the reduction. + + The default, `axis=()`, will compute over all elements into a + scalar array with shape `(1,)`. + + If `axis` is int, a reduction is performed on a particular axis. + + If `axis` is a tuple of ints, a reduction is performed on all the axes + specified in the tuple. + + If `exclude` is true, reduction will be performed on the axes that are + NOT in axis instead.)code"); + + TVM_ATTR_FIELD(keepdims).set_default(false).describe( + "If this is set to `True`, the reduced axes are left " + "in the result as dimension with size one."); + TVM_ATTR_FIELD(select_last_index) + .set_default(false) + .describe( + "Whether to select the last index if the target element appears multiple times, else " + "select the first index which the target element appears"); + TVM_ATTR_FIELD(exclude).set_default(false).describe( + "Whether to perform reduction on axis that are NOT in axis instead."); + } +}; + struct VarianceAttrs : public tvm::AttrsNode { Array axis; bool keepdims; diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 688ad8254fa8..1932882ad198 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -241,9 +241,9 @@ class MixedModeVisitor : public ::tvm::relay::ExprVisitor { * \brief VisitExpr is finalized to preserve call expansion of dataflow regions */ void VisitExpr(const Expr& expr) final; - void VisitExpr_(const CallNode* op) override; - void VisitExpr_(const TupleNode* op) override; - void VisitExpr_(const TupleGetItemNode* op) override; + virtual void VisitExpr_(const CallNode* op) override; + virtual void VisitExpr_(const TupleNode* op) override; + virtual void VisitExpr_(const TupleGetItemNode* op) override; protected: /*! diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index 15d1455bb267..cf8bdad502dd 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -431,6 +431,40 @@ inline Tensor max(const Tensor& data, const Array& axis, bool keepdims return CommReduce(data, axis, MaxOp, keepdims, atleast1d); } +inline FCommReduce MakeSinglePassReducer( + std::function comparison_op, + std::function initial_value_generator, String name) { + // Create a Commutative Reducer with a comparison operation, and method to get the initial value. + auto fcombine = [&](Array lhs, Array rhs) { + Array result; + result.push_back(tvm::tir::Select(comparison_op(lhs[1], rhs[1]), lhs[0], rhs[0])); // idx + result.push_back(tvm::tir::Select(comparison_op(lhs[1], rhs[1]), lhs[1], rhs[1])); // val + return result; + }; + auto fidentity = [&](std::vector types) { + Array result; + result.push_back(tvm::tir::make_const(types[0], -1)); // idx + result.push_back(initial_value_generator(types[1])); // val + return result; + }; + return MakeCommReducer(fcombine, fidentity, name); +} + +inline FCommReduce MakeArgminReducer(bool select_last_index = false) { + std::function comparison_op; + if (select_last_index) { + comparison_op = [](Var lhs, Var rhs) { return lhs <= rhs; }; + } else { + comparison_op = [](Var lhs, Var rhs) { return lhs < rhs; }; + } + + std::function initial_value_generator = [](const DataType& data_type) { + return tvm::max_value(data_type); + }; + + return MakeSinglePassReducer(comparison_op, initial_value_generator, "argmin"); +} + /*! * \brief Creates an operation that finds the indices of the minimum * values over a given axis. @@ -442,41 +476,30 @@ inline Tensor max(const Tensor& data, const Array& axis, bool keepdims * left in the result as dimensions with size one. This enables the result * to broadcast correctly against the input array. * \param atleast1d Whether the output need to be atleast1d. + * \param select_last_index Whether to select the last index if the minimum element + * appears multiple times, else select the first index. * * \return A Tensor whose op member is the argmin operation */ inline Tensor argmin(const Tensor& data, const Array& axis, bool keepdims = false, - bool atleast1d = false) { - auto fcombine = [](Array lhs, Array rhs) { - Array result; - result.push_back(tvm::tir::Select(lhs[1] <= rhs[1], lhs[0], rhs[0])); // idx - result.push_back(tvm::tir::Select(lhs[1] <= rhs[1], lhs[1], rhs[1])); // val - return result; - }; - auto fidentity = [](std::vector types) { - Array result; - result.push_back(tvm::tir::make_const(types[0], -1)); // idx - result.push_back(tvm::max_value(types[1])); // val - return result; - }; - auto func = MakeCommReducer(fcombine, fidentity, "argmin"); - return CommReduceIdx(data, axis, func, keepdims, atleast1d); + bool atleast1d = false, bool select_last_index = false) { + auto reducer = MakeArgminReducer(select_last_index); + return CommReduceIdx(data, axis, reducer, keepdims, atleast1d); } -inline FCommReduce MakeArgmaxReducer() { - auto fcombine = [](Array lhs, Array rhs) { - Array result; - result.push_back(tvm::tir::Select(lhs[1] >= rhs[1], lhs[0], rhs[0])); // idx - result.push_back(tvm::tir::Select(lhs[1] >= rhs[1], lhs[1], rhs[1])); // val - return result; - }; - auto fidentity = [](std::vector types) { - Array result; - result.push_back(tvm::tir::make_const(types[0], -1)); // idx - result.push_back(tvm::min_value(types[1])); // val - return result; +inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { + std::function comparison_op; + if (select_last_index) { + comparison_op = [](Var lhs, Var rhs) { return lhs >= rhs; }; + } else { + comparison_op = [](Var lhs, Var rhs) { return lhs > rhs; }; + } + + std::function initial_value_generator = [](const DataType& data_type) { + return tvm::min_value(data_type); }; - return MakeCommReducer(fcombine, fidentity, "argmax"); + + return MakeSinglePassReducer(comparison_op, initial_value_generator, "argmax"); } /*! @@ -490,12 +513,13 @@ inline FCommReduce MakeArgmaxReducer() { * left in the result as dimensions with size one. This enables the result * to broadcast correctly against the input array. * \param atleast1d Whether the output need to be atleast1d. - * + * \param select_last_index Whether to select the last index if the maximum element + * appears multiple times, else select the first index. * \return A Tensor whose op member is the argmax operation */ inline Tensor argmax(const Tensor& data, const Array& axis, bool keepdims = false, - bool atleast1d = false) { - auto reducer = MakeArgmaxReducer(); + bool atleast1d = false, bool select_last_index = false) { + auto reducer = MakeArgmaxReducer(select_last_index); return CommReduceIdx(data, axis, reducer, keepdims, atleast1d); } diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 5471f67ea106..6fc896ca5e01 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1846,6 +1846,17 @@ def _impl_v1(cls, inputs, attr, params): return _op.cast(AttrCvt("argmin")(inputs, attr), "int64") + @classmethod + def _impl_v13(cls, inputs, attr, params): + if "select_last_index" in attr: + raise NotImplementedError("select_last_index not supported in ArgMin") + axis = attr.get("axis", 0) + keepdims = attr.get("keepdims", True) + select_last_index = attr.get("select_last_index", False) + attr = {"axis": axis, "keepdims": keepdims} + return _op.cast(AttrCvt("argmin")(inputs, attr), "int64") + # return _op.argmin() + class Softmax(OnnxOpConverter): """Operator converter for Softmax.""" diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index f08af1e7e4ad..04ab183f9019 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -203,13 +203,34 @@ Array ReduceCompute(const Attrs& attrs, const Array& inp auto axes = param->axis; if (param->exclude) { axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis); - if (axes.size() == 0) { - return {topi::identity(inputs[0])}; - } } + + if (axes.size() == 0) { + return {topi::identity(inputs[0])}; + } + return {f(inputs[0], axes, param->keepdims, false)}; } +template +Array OneElementReduceCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type, F f) { + const OneElementReduceAttrs* param = attrs.as(); + ICHECK(param != nullptr); + if (inputs[0]->shape.size() == 0) { + return {topi::identity(inputs[0])}; + } + auto axes = param->axis; + if (param->exclude) { + axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis); + } + + if (axes.size() == 0) { + return {topi::identity(inputs[0])}; + } + return {f(inputs[0], axes, param->keepdims, false, param->select_last_index)}; +} + /*! * \brief ReduceShapeImpl get the outshape for the reduction operator * \param in_shape Shape of input data. @@ -333,7 +354,7 @@ Expr MakeReduce(Expr data, Array axis, bool keepdims, bool exclude, Str Array ArgMaxCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return ReduceCompute(attrs, inputs, out_type, topi::argmax); + return OneElementReduceCompute(attrs, inputs, out_type, topi::argmax); } RELAY_REGISTER_REDUCE_OP("argmax") @@ -341,7 +362,7 @@ RELAY_REGISTER_REDUCE_OP("argmax") values over a given axis. )code" TVM_ADD_FILELINE) - .set_attrs_type() + .set_attrs_type() .set_support_level(4) .add_type_rel("ArgReduce", ArgReduceRel) .set_attr("FTVMCompute", ArgMaxCompute) @@ -349,7 +370,7 @@ values over a given axis. Array ArgMinCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return ReduceCompute(attrs, inputs, out_type, topi::argmin); + return OneElementReduceCompute(attrs, inputs, out_type, topi::argmin); } RELAY_REGISTER_REDUCE_OP("argmin") @@ -357,7 +378,7 @@ RELAY_REGISTER_REDUCE_OP("argmin") values over a given axis. )code" TVM_ADD_FILELINE) - .set_attrs_type() + .set_attrs_type() .set_support_level(4) .add_type_rel("ArgReduce", ArgReduceRel) .set_attr("FTVMCompute", ArgMinCompute) From edbc0f19e5f8e085241ecc49eec7167dc384f2ba Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 11:13:59 -0700 Subject: [PATCH 25/52] reverse conditions which made on accident --- include/tvm/topi/reduction.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index cf8bdad502dd..19b8195a5116 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -453,9 +453,9 @@ inline FCommReduce MakeSinglePassReducer( inline FCommReduce MakeArgminReducer(bool select_last_index = false) { std::function comparison_op; if (select_last_index) { - comparison_op = [](Var lhs, Var rhs) { return lhs <= rhs; }; - } else { comparison_op = [](Var lhs, Var rhs) { return lhs < rhs; }; + } else { + comparison_op = [](Var lhs, Var rhs) { return lhs <= rhs; }; } std::function initial_value_generator = [](const DataType& data_type) { @@ -490,9 +490,9 @@ inline Tensor argmin(const Tensor& data, const Array& axis, bool keepdi inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { std::function comparison_op; if (select_last_index) { - comparison_op = [](Var lhs, Var rhs) { return lhs >= rhs; }; - } else { comparison_op = [](Var lhs, Var rhs) { return lhs > rhs; }; + } else { + comparison_op = [](Var lhs, Var rhs) { return lhs >= rhs; }; } std::function initial_value_generator = [](const DataType& data_type) { From ba7f57c9120c93ffc504f12fc9c6a24413285f2c Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 11:21:41 -0700 Subject: [PATCH 26/52] forward args in reduce.py --- python/tvm/relay/op/reduce.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py index 368ffb5ab0ca..dfc8e71026af 100644 --- a/python/tvm/relay/op/reduce.py +++ b/python/tvm/relay/op/reduce.py @@ -17,13 +17,13 @@ """Reduce operators.""" # pylint: disable=redefined-builtin +from ..expr import Tuple, TupleWrapper from . import _make -from .tensor import sqrt, log, exp +from .tensor import exp, log, sqrt from .transform import squeeze -from ..expr import Tuple, TupleWrapper -def argmax(data, axis=None, keepdims=False, exclude=False): +def argmax(data, axis=None, keepdims=False, exclude=False, select_last_index=False): """Returns the indices of the maximum values along an axis. Parameters @@ -45,16 +45,20 @@ def argmax(data, axis=None, keepdims=False, exclude=False): If `exclude` is true, reduction will be performed on the axes that are NOT in axis instead. + select_last_index : bool + Whether to select the last index or the first index if the max element appears in multiple indices, + default is False (first index). + Returns ------- result : relay.Expr The computed result. """ axis = [axis] if isinstance(axis, int) else axis - return _make.argmax(data, axis, keepdims, exclude) + return _make.argmax(data, axis, keepdims, exclude, select_last_index) -def argmin(data, axis=None, keepdims=False, exclude=False): +def argmin(data, axis=None, keepdims=False, exclude=False, select_last_index=False): """Returns the indices of the minimum values along an axis. Parameters @@ -76,13 +80,17 @@ def argmin(data, axis=None, keepdims=False, exclude=False): If `exclude` is true, reduction will be performed on the axes that are NOT in axis instead. + select_last_index : bool + Whether to select the last index or the first index if the min element appears in multiple indices, + default is False (first index). + Returns ------- result : relay.Expr The computed result. """ axis = [axis] if isinstance(axis, int) else axis - return _make.argmin(data, axis, keepdims, exclude) + return _make.argmin(data, axis, keepdims, exclude, select_last_index) def sum(data, axis=None, keepdims=False, exclude=False): From dbf6dc123389f5ba9d2255af1593b55cfe6c0a09 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 11:52:46 -0700 Subject: [PATCH 27/52] make proper nodes for reduction ops --- src/relay/op/tensor/reduce.cc | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 04ab183f9019..96702a925089 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -345,6 +345,16 @@ Expr MakeReduce(Expr data, Array axis, bool keepdims, bool exclude, Str return Call(Op::Get(op_name), {data}, Attrs(attrs), {}); } +Expr MakeOneElementReduce(Expr data, Array axis, bool keepdims, bool exclude, + bool select_last_index, String op_name) { + auto attrs = make_object(); + attrs->axis = std::move(axis); + attrs->keepdims = keepdims; + attrs->exclude = exclude; + attrs->select_last_index = select_last_index; + return Call(Op::Get(op_name), {data}, Attrs(attrs), {}); +} + #define RELAY_REGISTER_REDUCE_OP(OpName) \ TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ .set_body_typed([](Expr data, Array axis, bool keepdims, bool exclude) { \ @@ -352,12 +362,20 @@ Expr MakeReduce(Expr data, Array axis, bool keepdims, bool exclude, Str }); \ RELAY_REGISTER_OP(OpName).set_num_inputs(1).add_argument("data", "Tensor", "The input tensor.") +#define RELAY_REGISTER_ONE_ELEMENT_REDUCE_OP(OpName) \ + TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ + .set_body_typed([](Expr data, Array axis, bool keepdims, bool exclude, \ + bool select_last_index) { \ + return MakeOneElementReduce(data, axis, keepdims, exclude, select_last_index, OpName); \ + }); \ + RELAY_REGISTER_OP(OpName).set_num_inputs(1).add_argument("data", "Tensor", "The input tensor.") + Array ArgMaxCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return OneElementReduceCompute(attrs, inputs, out_type, topi::argmax); } -RELAY_REGISTER_REDUCE_OP("argmax") +RELAY_REGISTER_ONE_ELEMENT_REDUCE_OP("argmax") .describe(R"code(Creates an operation that finds the indices of the maximum values over a given axis. @@ -373,7 +391,7 @@ Array ArgMinCompute(const Attrs& attrs, const Array& inp return OneElementReduceCompute(attrs, inputs, out_type, topi::argmin); } -RELAY_REGISTER_REDUCE_OP("argmin") +RELAY_REGISTER_ONE_ELEMENT_REDUCE_OP("argmin") .describe(R"code(Creates an operation that finds the indices of the minimum values over a given axis. From fa4dd4322ba92c1b166b12c4ad113dacf3555254 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 12:56:07 -0700 Subject: [PATCH 28/52] remove complicated nested lambdas --- include/tvm/topi/reduction.h | 53 +++++++++++++---------------------- src/relay/op/tensor/reduce.cc | 13 ++++----- 2 files changed, 26 insertions(+), 40 deletions(-) diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index 19b8195a5116..4f2d566e4538 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -431,38 +431,22 @@ inline Tensor max(const Tensor& data, const Array& axis, bool keepdims return CommReduce(data, axis, MaxOp, keepdims, atleast1d); } -inline FCommReduce MakeSinglePassReducer( - std::function comparison_op, - std::function initial_value_generator, String name) { +inline FCommReduce MakeArgminReducer(bool select_last_index = false) { // Create a Commutative Reducer with a comparison operation, and method to get the initial value. auto fcombine = [&](Array lhs, Array rhs) { Array result; - result.push_back(tvm::tir::Select(comparison_op(lhs[1], rhs[1]), lhs[0], rhs[0])); // idx - result.push_back(tvm::tir::Select(comparison_op(lhs[1], rhs[1]), lhs[1], rhs[1])); // val + auto comparison = select_last_index ? lhs[1] < rhs[1] : lhs[1] <= rhs[1]; + result.push_back(tvm::tir::Select(comparison, lhs[0], rhs[0])); // idx + result.push_back(tvm::tir::Select(comparison, lhs[1], rhs[1])); // val return result; }; auto fidentity = [&](std::vector types) { Array result; result.push_back(tvm::tir::make_const(types[0], -1)); // idx - result.push_back(initial_value_generator(types[1])); // val + result.push_back(tvm::max_value(types[1])); // val return result; }; - return MakeCommReducer(fcombine, fidentity, name); -} - -inline FCommReduce MakeArgminReducer(bool select_last_index = false) { - std::function comparison_op; - if (select_last_index) { - comparison_op = [](Var lhs, Var rhs) { return lhs < rhs; }; - } else { - comparison_op = [](Var lhs, Var rhs) { return lhs <= rhs; }; - } - - std::function initial_value_generator = [](const DataType& data_type) { - return tvm::max_value(data_type); - }; - - return MakeSinglePassReducer(comparison_op, initial_value_generator, "argmin"); + return MakeCommReducer(fcombine, fidentity, "argmin"); } /*! @@ -488,18 +472,21 @@ inline Tensor argmin(const Tensor& data, const Array& axis, bool keepdi } inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { - std::function comparison_op; - if (select_last_index) { - comparison_op = [](Var lhs, Var rhs) { return lhs > rhs; }; - } else { - comparison_op = [](Var lhs, Var rhs) { return lhs >= rhs; }; - } - - std::function initial_value_generator = [](const DataType& data_type) { - return tvm::min_value(data_type); + // Create a Commutative Reducer with a comparison operation, and method to get the initial value. + auto fcombine = [&](Array lhs, Array rhs) { + Array result; + auto comparison = select_last_index ? lhs[1] > rhs[1] : lhs[1] >= rhs[1]; + result.push_back(tvm::tir::Select(comparison, lhs[0], rhs[0])); // idx + result.push_back(tvm::tir::Select(comparison, lhs[1], rhs[1])); // val + return result; }; - - return MakeSinglePassReducer(comparison_op, initial_value_generator, "argmax"); + auto fidentity = [&](std::vector types) { + Array result; + result.push_back(tvm::tir::make_const(types[0], -1)); // idx + result.push_back(tvm::min_value(types[1])); // val + return result; + }; + return MakeCommReducer(fcombine, fidentity, "argmax"); } /*! diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 96702a925089..b5643a10f8d4 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -203,10 +203,9 @@ Array ReduceCompute(const Attrs& attrs, const Array& inp auto axes = param->axis; if (param->exclude) { axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis); - } - - if (axes.size() == 0) { - return {topi::identity(inputs[0])}; + if (axes.size() == 0) { + return {topi::identity(inputs[0])}; + } } return {f(inputs[0], axes, param->keepdims, false)}; @@ -223,11 +222,11 @@ Array OneElementReduceCompute(const Attrs& attrs, const Arrayaxis; if (param->exclude) { axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis); + if (axes.size() == 0) { + return {topi::identity(inputs[0])}; + } } - if (axes.size() == 0) { - return {topi::identity(inputs[0])}; - } return {f(inputs[0], axes, param->keepdims, false, param->select_last_index)}; } From 78cc734fcffc0550167baa4fa404a151cf960559 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 15:24:49 -0700 Subject: [PATCH 29/52] fix lambda capture for conversion --- include/tvm/relay/attrs/reduce.h | 2 +- include/tvm/topi/reduction.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/tvm/relay/attrs/reduce.h b/include/tvm/relay/attrs/reduce.h index 274ccc8c352c..44ea79fcb517 100644 --- a/include/tvm/relay/attrs/reduce.h +++ b/include/tvm/relay/attrs/reduce.h @@ -62,7 +62,7 @@ struct ReduceAttrs : public tvm::AttrsNode { }; /*! \brief Attributes for Reduce operators which reduce by finding a single element. E.g. argmin */ -struct OneElementReduceAttrs : public tvm::AttrsNode { +struct OneElementReduceAttrs : public tvm::AttrsNode { Array axis; bool keepdims; bool select_last_index; diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index 4f2d566e4538..ad2d80eefbca 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -433,7 +433,7 @@ inline Tensor max(const Tensor& data, const Array& axis, bool keepdims inline FCommReduce MakeArgminReducer(bool select_last_index = false) { // Create a Commutative Reducer with a comparison operation, and method to get the initial value. - auto fcombine = [&](Array lhs, Array rhs) { + auto fcombine = [=](Array lhs, Array rhs) { Array result; auto comparison = select_last_index ? lhs[1] < rhs[1] : lhs[1] <= rhs[1]; result.push_back(tvm::tir::Select(comparison, lhs[0], rhs[0])); // idx @@ -473,7 +473,7 @@ inline Tensor argmin(const Tensor& data, const Array& axis, bool keepdi inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { // Create a Commutative Reducer with a comparison operation, and method to get the initial value. - auto fcombine = [&](Array lhs, Array rhs) { + auto fcombine = [=](Array lhs, Array rhs) { Array result; auto comparison = select_last_index ? lhs[1] > rhs[1] : lhs[1] >= rhs[1]; result.push_back(tvm::tir::Select(comparison, lhs[0], rhs[0])); // idx From 0979f4d6f3214f99a3f58ffbfc2c3aae1dd211fe Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 16:22:40 -0700 Subject: [PATCH 30/52] forward more arguments --- src/topi/reduction.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/topi/reduction.cc b/src/topi/reduction.cc index 55c59162e68c..3d1c6f9f7d5b 100644 --- a/src/topi/reduction.cc +++ b/src/topi/reduction.cc @@ -45,11 +45,11 @@ TVM_REGISTER_GLOBAL("topi.max").set_body([](TVMArgs args, TVMRetValue* rv) { }); TVM_REGISTER_GLOBAL("topi.argmin").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = topi::argmin(args[0], ArrayOrInt(args[1]), args[2]); + *rv = topi::argmin(args[0], ArrayOrInt(args[1]), args[2], false, args[3]); }); TVM_REGISTER_GLOBAL("topi.argmax").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = topi::argmax(args[0], ArrayOrInt(args[1]), args[2]); + *rv = topi::argmax(args[0], ArrayOrInt(args[1]), args[2], false, args[3]); }); TVM_REGISTER_GLOBAL("topi.prod").set_body([](TVMArgs args, TVMRetValue* rv) { From 647413e3cb42db7045429080b6d462566ba0c05b Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 16:36:27 -0700 Subject: [PATCH 31/52] forward more args --- python/tvm/topi/reduction.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/topi/reduction.py b/python/tvm/topi/reduction.py index 77f9ad447ed1..cba43297f293 100644 --- a/python/tvm/topi/reduction.py +++ b/python/tvm/topi/reduction.py @@ -167,7 +167,7 @@ def min(data, axis=None, keepdims=False): return cpp.min(data, axis, keepdims) -def argmax(data, axis=None, keepdims=False): +def argmax(data, axis=None, keepdims=False, exclude=False, select_last_index=False): """Returns the indices of the maximum values along an axis. Parameters @@ -189,10 +189,10 @@ def argmax(data, axis=None, keepdims=False): ------- ret : tvm.te.Tensor """ - return cpp.argmax(data, axis, keepdims) + return cpp.argmax(data, axis, keepdims, exclude=exclude, select_last_index=select_last_index) -def argmin(data, axis=None, keepdims=False): +def argmin(data, axis=None, keepdims=False, exclude=False, select_last_index=False): """Returns the indices of the minimum values along an axis. Parameters @@ -214,7 +214,7 @@ def argmin(data, axis=None, keepdims=False): ------- ret : tvm.te.Tensor """ - return cpp.argmin(data, axis, keepdims) + return cpp.argmin(data, axis, keepdims, exclude, select_last_index) def prod(data, axis=None, keepdims=False): From f694e5899c7bdfd0616de24df0226f10f27a2576 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 16:42:02 -0700 Subject: [PATCH 32/52] enable onnx tests --- python/tvm/relay/frontend/onnx.py | 32 ++++++--------- tests/python/frontend/onnx/test_forward.py | 20 +-------- tests/python/relay/test_op_level4.py | 48 +++++++++++++++++++--- 3 files changed, 57 insertions(+), 43 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 6fc896ca5e01..3c027981ac31 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -32,24 +32,12 @@ from .. import loops as _loops from .. import op as _op from .. import qnn as _qnn +from .. import random as _random from .. import ty as _ty from .. import vision as _vision -from .. import random as _random -from .common import ( - AttrCvt, - Renamer, - fold_constant, - get_name, - get_relay_op, - infer_channels, - infer_shape, - infer_type, - infer_value, - new_var, - unbind, - gru_cell, - lstm_cell, -) +from .common import (AttrCvt, Renamer, fold_constant, get_name, get_relay_op, + gru_cell, infer_channels, infer_shape, infer_type, + infer_value, lstm_cell, new_var, unbind) __all__ = ["from_onnx"] @@ -1832,6 +1820,13 @@ def _impl_v1(cls, inputs, attr, params): attr = {"axis": axis, "keepdims": keepdims} return _op.cast(AttrCvt("argmax")(inputs, attr), "int64") + @classmethod + def _impl_v13(cls, inputs, attr, params): + axis = attr.get("axis", 0) + keepdims = attr.get("keepdims", True) + select_last_index = attr.get("select_last_index", False) + attr = {"axis": axis, "keepdims": keepdims, "select_last_index": select_last_index} + return _op.cast(AttrCvt("argmax")(inputs, attr), "int64") class ArgMin(OnnxOpConverter): """Operator converter for ArgMin.""" @@ -1848,14 +1843,11 @@ def _impl_v1(cls, inputs, attr, params): @classmethod def _impl_v13(cls, inputs, attr, params): - if "select_last_index" in attr: - raise NotImplementedError("select_last_index not supported in ArgMin") axis = attr.get("axis", 0) keepdims = attr.get("keepdims", True) select_last_index = attr.get("select_last_index", False) - attr = {"axis": axis, "keepdims": keepdims} + attr = {"axis": axis, "keepdims": keepdims, "select_last_index": select_last_index} return _op.cast(AttrCvt("argmin")(inputs, attr), "int64") - # return _op.argmin() class Softmax(OnnxOpConverter): """Operator converter for Softmax.""" diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 9e0eb1f75217..67a94757a2d7 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -17,7 +17,6 @@ import glob import os import re -import glob import numpy as np import pytest @@ -236,7 +235,8 @@ def verify_with_ort( def quantize_and_verify_with_ort(onnx_model, input_names, input_shapes, target, dev): - from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType + from onnxruntime.quantization import (CalibrationDataReader, QuantType, + quantize_static) input_arrays = [np.random.random(shape).astype("float32") for shape in input_shapes] @@ -4680,22 +4680,6 @@ def verify_eyelike(indata): "test_adagrad_multiple", "test_adam", "test_adam_multiple", - "test_argmax_default_axis_example_select_last_index", - "test_argmax_default_axis_random_select_last_index", - "test_argmax_keepdims_example_select_last_index", - "test_argmax_keepdims_random_select_last_index", - "test_argmax_negative_axis_keepdims_example_select_last_index", - "test_argmax_negative_axis_keepdims_random_select_last_index", - "test_argmax_no_keepdims_example_select_last_index", - "test_argmax_no_keepdims_random_select_last_index", - "test_argmin_default_axis_example_select_last_index", - "test_argmin_default_axis_random_select_last_index", - "test_argmin_keepdims_example_select_last_index", - "test_argmin_keepdims_random_select_last_index", - "test_argmin_negative_axis_keepdims_example_select_last_index", - "test_argmin_negative_axis_keepdims_random_select_last_index", - "test_argmin_no_keepdims_example_select_last_index", - "test_argmin_no_keepdims_random_select_last_index", "test_cast_BFLOAT16_to_FLOAT", "test_cast_DOUBLE_to_FLOAT16", "test_cast_FLOAT_to_BFLOAT16", diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index df77c33658de..ec347224b0be 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -14,14 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm -from tvm import te import numpy as np -from tvm import relay +import numpy.random +import tvm +import tvm.testing +import tvm.topi.testing +from tvm import relay, te from tvm.relay import transform from tvm.relay.testing import run_infer_type -import tvm.topi.testing -import tvm.testing @tvm.testing.uses_gpu @@ -342,6 +342,44 @@ def _unbiased_func(a, axis=None, dtype=None, keepdims=None): verify_reduce(func, (128, 24, 128), (0, 2), True, False, (1, 24, 1)) +@tvm.testing.uses_gpu +def test_argmin_argmax_get_last_elements(): + def get_test_case(shape, gt_func, test_argmin=False): + total_ele = np.product(shape) + arr = np.zeros(total_ele) + target_value = -1 if test_argmin else 1 + arr[: total_ele // 3] = target_value + np.random.shuffle(arr) + arr = arr.reshape(shape) + ans = gt_func(np.flip(arr)) + return arr, len(arr) - ans - 1 + + funcs_and_gt_funcs = [(relay.argmax, np.argmax), (relay.argmin, np.argmin)] + lengths = [5, 10, 15] + for func, gt_func in funcs_and_gt_funcs: + for shape in lengths: + x_in = relay.var("x_in", shape=[shape]) + try: + output = func(x_in, select_last_index=True) + except: + breakpoint() + arr, ans = get_test_case(shape, gt_func, test_argmin=func == relay.argmin) + + mod = tvm.IRModule.from_expr(output) + for target, dev in tvm.testing.enabled_targets(): + op_res = relay.create_executor( + "graph", mod=mod, device=dev, target=target + ).evaluate()(arr) + print(target) + print(dev) + print(arr) + print(ans) + print(op_res) + print() + + raise ValueError("WHAT") + + def verify_mean_var_std(funcs, shape, axis, keepdims): test_func = funcs[0] ref_func = funcs[1] From 576c56bcb710bd3bdfea65fa0590b66863c3f57a Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Mon, 23 Aug 2021 17:42:42 -0700 Subject: [PATCH 33/52] wrapping casts to remove ambiguity --- include/tvm/topi/reduction.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index ad2d80eefbca..9bd863873531 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -435,7 +435,8 @@ inline FCommReduce MakeArgminReducer(bool select_last_index = false) { // Create a Commutative Reducer with a comparison operation, and method to get the initial value. auto fcombine = [=](Array lhs, Array rhs) { Array result; - auto comparison = select_last_index ? lhs[1] < rhs[1] : lhs[1] <= rhs[1]; + // Cast to resolve ambiguous operators + auto comparison = select_last_index ? PrimExpr(lhs[1]) < PrimExpr(rhs[1]) : lhs[1] <= rhs[1]; result.push_back(tvm::tir::Select(comparison, lhs[0], rhs[0])); // idx result.push_back(tvm::tir::Select(comparison, lhs[1], rhs[1])); // val return result; @@ -475,7 +476,8 @@ inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { // Create a Commutative Reducer with a comparison operation, and method to get the initial value. auto fcombine = [=](Array lhs, Array rhs) { Array result; - auto comparison = select_last_index ? lhs[1] > rhs[1] : lhs[1] >= rhs[1]; + // Cast to resolve ambiguous operators + auto comparison = select_last_index ? PrimExpr(lhs[1]) > PrimExpr(rhs[1]) : lhs[1] >= rhs[1]; result.push_back(tvm::tir::Select(comparison, lhs[0], rhs[0])); // idx result.push_back(tvm::tir::Select(comparison, lhs[1], rhs[1])); // val return result; From 67b576228037448ca428c3318f0ee2b3badb74c7 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 18:06:57 -0700 Subject: [PATCH 34/52] revert changes extraneous --- include/tvm/relay/expr_functor.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 1932882ad198..688ad8254fa8 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -241,9 +241,9 @@ class MixedModeVisitor : public ::tvm::relay::ExprVisitor { * \brief VisitExpr is finalized to preserve call expansion of dataflow regions */ void VisitExpr(const Expr& expr) final; - virtual void VisitExpr_(const CallNode* op) override; - virtual void VisitExpr_(const TupleNode* op) override; - virtual void VisitExpr_(const TupleGetItemNode* op) override; + void VisitExpr_(const CallNode* op) override; + void VisitExpr_(const TupleNode* op) override; + void VisitExpr_(const TupleGetItemNode* op) override; protected: /*! From 6d59d1c1286304e979a23a83d0b0c3d01f2c1875 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 22:27:32 -0700 Subject: [PATCH 35/52] correct incorrect attrs being used for ops --- src/relay/op/tensor/reduce.cc | 41 +++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index b5643a10f8d4..7881e03e8b37 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -289,22 +289,16 @@ inline std::vector ReduceShapeImpl(const std::vector& in_s } } -/*! - * \brief ArgReduceRel Output type and shape relation evaluation function. - * \param num_inputs Number of input types in the args. - * \param attrs The additional attributes of the operator. - * \param reporter The reporter to report solution to. - * \return false if This relation cannot be resolved. true if this relation has been resolved. - */ -bool ArgReduceRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { +template +bool GenericReduceRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { ICHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) return false; ICHECK(static_cast(data->shape.size()) != 0); std::vector in_shape(data->shape.begin(), data->shape.end()); - const ReduceAttrs* param = attrs.as(); + const T* param = attrs.as(); ICHECK(param != nullptr); // assign output type and shape @@ -312,6 +306,29 @@ bool ArgReduceRel(const Array& types, int num_inputs, const Attrs& attrs, reporter->Assign(types[1], TensorType(oshape, DataType::Int(32))); return true; } +/*! + * \brief ArgReduceRel Output type and shape relation evaluation function. + * \param num_inputs Number of input types in the args. + * \param attrs The additional attributes of the operator. + * \param reporter The reporter to report solution to. + * \return false if This relation cannot be resolved. true if this relation has been resolved. + */ +bool ArgReduceRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + return GenericReduceRel(types, num_inputs, attrs, reporter); +} + +/*! + * \brief SingleElementArgReduceRel Output type and shape relation evaluation function. + * \param num_inputs Number of input types in the args. + * \param attrs The additional attributes of the operator. + * \param reporter The reporter to report solution to. + * \return false if This relation cannot be resolved. true if this relation has been resolved. + */ +bool SingleElementArgReduceRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + return GenericReduceRel(types, num_inputs, attrs, reporter); +} /*! * \brief ReduceRel Output type and shape relation evaluation function. @@ -381,7 +398,7 @@ values over a given axis. )code" TVM_ADD_FILELINE) .set_attrs_type() .set_support_level(4) - .add_type_rel("ArgReduce", ArgReduceRel) + .add_type_rel("ArgReduce", SingleElementArgReduceRel) .set_attr("FTVMCompute", ArgMaxCompute) .set_attr("TOpPattern", kCommReduce); @@ -397,7 +414,7 @@ values over a given axis. )code" TVM_ADD_FILELINE) .set_attrs_type() .set_support_level(4) - .add_type_rel("ArgReduce", ArgReduceRel) + .add_type_rel("ArgReduce", SingleElementArgReduceRel) .set_attr("FTVMCompute", ArgMinCompute) .set_attr("TOpPattern", kCommReduce); From d7a595f0eb064dda171a97bf48499c9af69b66fb Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 22:32:42 -0700 Subject: [PATCH 36/52] change attributes --- include/tvm/relay/attrs/reduce.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/relay/attrs/reduce.h b/include/tvm/relay/attrs/reduce.h index 44ea79fcb517..8c4794e9ac00 100644 --- a/include/tvm/relay/attrs/reduce.h +++ b/include/tvm/relay/attrs/reduce.h @@ -68,7 +68,7 @@ struct OneElementReduceAttrs : public tvm::AttrsNode { bool select_last_index; bool exclude; - TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") { + TVM_DECLARE_ATTRS(OneElementReduceAttrs, "relay.attrs.OneElementReduceAttrs") { TVM_ATTR_FIELD(axis) .set_default(NullValue>()) .describe(R"code(The axis or axes along which to perform the reduction. From 6b645dea3a5ceea167b2d55dd0ba35b2bae7b9b9 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Mon, 23 Aug 2021 22:41:50 -0700 Subject: [PATCH 37/52] remove old impl --- python/tvm/relay/frontend/onnx.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 3c027981ac31..1d9f1bc3ab28 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1811,15 +1811,6 @@ def _impl_v1(cls, inputs, attr, params): class ArgMax(OnnxOpConverter): """Operator converter for ArgMax.""" - @classmethod - def _impl_v1(cls, inputs, attr, params): - if "select_last_index" in attr: - raise NotImplementedError("select_last_index not supported in ArgMax") - axis = attr.get("axis", 0) - keepdims = attr.get("keepdims", True) - attr = {"axis": axis, "keepdims": keepdims} - return _op.cast(AttrCvt("argmax")(inputs, attr), "int64") - @classmethod def _impl_v13(cls, inputs, attr, params): axis = attr.get("axis", 0) @@ -1831,16 +1822,6 @@ def _impl_v13(cls, inputs, attr, params): class ArgMin(OnnxOpConverter): """Operator converter for ArgMin.""" - @classmethod - def _impl_v1(cls, inputs, attr, params): - if "select_last_index" in attr: - raise NotImplementedError("select_last_index not supported in ArgMin") - axis = attr.get("axis", 0) - keepdims = attr.get("keepdims", True) - attr = {"axis": axis, "keepdims": keepdims} - return _op.cast(AttrCvt("argmin")(inputs, attr), "int64") - - @classmethod def _impl_v13(cls, inputs, attr, params): axis = attr.get("axis", 0) From 0faf5b69767cde922392decb2b9c43862b74c299 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 22:49:50 -0700 Subject: [PATCH 38/52] register new attribute node --- src/relay/op/tensor/reduce.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 7881e03e8b37..75483ea8b015 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -38,6 +38,7 @@ namespace tvm { namespace relay { TVM_REGISTER_NODE_TYPE(ReduceAttrs); +TVM_REGISTER_NODE_TYPE(OneElementReduceAttrs); TVM_REGISTER_NODE_TYPE(VarianceAttrs); /*! From 96d85c2d930ca07f8d715d04215075d9c0621d47 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 22:53:36 -0700 Subject: [PATCH 39/52] clean up test --- tests/python/relay/test_op_level4.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index ec347224b0be..6415976bfd59 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -359,10 +359,7 @@ def get_test_case(shape, gt_func, test_argmin=False): for func, gt_func in funcs_and_gt_funcs: for shape in lengths: x_in = relay.var("x_in", shape=[shape]) - try: - output = func(x_in, select_last_index=True) - except: - breakpoint() + output = func(x_in, select_last_index=True) arr, ans = get_test_case(shape, gt_func, test_argmin=func == relay.argmin) mod = tvm.IRModule.from_expr(output) @@ -370,14 +367,7 @@ def get_test_case(shape, gt_func, test_argmin=False): op_res = relay.create_executor( "graph", mod=mod, device=dev, target=target ).evaluate()(arr) - print(target) - print(dev) - print(arr) - print(ans) - print(op_res) - print() - - raise ValueError("WHAT") + assert op_res.numpy().item() == ans def verify_mean_var_std(funcs, shape, axis, keepdims): From 8a6a4bcf42edebf5153a3c934da740009ca4b1ea Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 23:01:12 -0700 Subject: [PATCH 40/52] reformat --- tests/python/frontend/onnx/test_forward.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 67a94757a2d7..a1d821686ed5 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -235,8 +235,7 @@ def verify_with_ort( def quantize_and_verify_with_ort(onnx_model, input_names, input_shapes, target, dev): - from onnxruntime.quantization import (CalibrationDataReader, QuantType, - quantize_static) + from onnxruntime.quantization import CalibrationDataReader, QuantType, quantize_static input_arrays = [np.random.random(shape).astype("float32") for shape in input_shapes] From 29a2660e1d006efd8a6f315148cbf8abcbe74c65 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 23:03:39 -0700 Subject: [PATCH 41/52] reformat --- python/tvm/relay/frontend/onnx.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 1d9f1bc3ab28..8376d56fe089 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -35,9 +35,20 @@ from .. import random as _random from .. import ty as _ty from .. import vision as _vision -from .common import (AttrCvt, Renamer, fold_constant, get_name, get_relay_op, - gru_cell, infer_channels, infer_shape, infer_type, - infer_value, lstm_cell, new_var, unbind) +from .common import ( + AttrCvt, + Renamer, + fold_constant, + get_name, + get_relay_op, + infer_channels, + infer_shape, + infer_type, + infer_value, + lstm_cell, + new_var, + unbind, +) __all__ = ["from_onnx"] @@ -1819,6 +1830,7 @@ def _impl_v13(cls, inputs, attr, params): attr = {"axis": axis, "keepdims": keepdims, "select_last_index": select_last_index} return _op.cast(AttrCvt("argmax")(inputs, attr), "int64") + class ArgMin(OnnxOpConverter): """Operator converter for ArgMin.""" @@ -1830,6 +1842,7 @@ def _impl_v13(cls, inputs, attr, params): attr = {"axis": axis, "keepdims": keepdims, "select_last_index": select_last_index} return _op.cast(AttrCvt("argmin")(inputs, attr), "int64") + class Softmax(OnnxOpConverter): """Operator converter for Softmax.""" From 3a2a38d88ac2176a0f2615a80f840194afeff3af Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 23 Aug 2021 23:11:29 -0700 Subject: [PATCH 42/52] coolio --- python/tvm/relay/op/reduce.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py index dfc8e71026af..23accebfd0ec 100644 --- a/python/tvm/relay/op/reduce.py +++ b/python/tvm/relay/op/reduce.py @@ -46,8 +46,8 @@ def argmax(data, axis=None, keepdims=False, exclude=False, select_last_index=Fal NOT in axis instead. select_last_index : bool - Whether to select the last index or the first index if the max element appears in multiple indices, - default is False (first index). + Whether to select the last index or the first index if the max element appears in + multiple indices, default is False (first index). Returns ------- @@ -81,8 +81,8 @@ def argmin(data, axis=None, keepdims=False, exclude=False, select_last_index=Fal NOT in axis instead. select_last_index : bool - Whether to select the last index or the first index if the min element appears in multiple indices, - default is False (first index). + Whether to select the last index or the first index if the min element appears in + multiple indices, default is False (first index). Returns ------- From 296ac2e977605d7bc945df8a9ea6f905aadcc77b Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 24 Aug 2021 12:07:49 -0700 Subject: [PATCH 43/52] stable comparison --- include/tvm/topi/reduction.h | 45 +++++++++++++++++++++++++++++------- 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index 9bd863873531..e4318957160f 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -435,10 +435,24 @@ inline FCommReduce MakeArgminReducer(bool select_last_index = false) { // Create a Commutative Reducer with a comparison operation, and method to get the initial value. auto fcombine = [=](Array lhs, Array rhs) { Array result; - // Cast to resolve ambiguous operators - auto comparison = select_last_index ? PrimExpr(lhs[1]) < PrimExpr(rhs[1]) : lhs[1] <= rhs[1]; - result.push_back(tvm::tir::Select(comparison, lhs[0], rhs[0])); // idx - result.push_back(tvm::tir::Select(comparison, lhs[1], rhs[1])); // val + + // These variables compare the actual values of the array + auto is_smaller = lhs[1] < rhs[1]; + auto is_same = lhs[1] == rhs[1]; + + // This checks if the indices are correct for the reduction. E.g. for select_last_index + // it gives precedence for later indices of the same element and precedence for sooner + // indices if not select_last_index; + PrimExpr proper_index; + if (select_last_index) { + proper_index = lhs[0] > rhs[0]; + } else { + proper_index = lhs[0] < rhs[0]; + } + + PrimExpr update_index = is_smaller || (is_same && proper_index); + result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0])); // idx + result.push_back(tvm::tir::Select(is_smaller, lhs[1], rhs[1])); // val return result; }; auto fidentity = [&](std::vector types) { @@ -476,10 +490,25 @@ inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { // Create a Commutative Reducer with a comparison operation, and method to get the initial value. auto fcombine = [=](Array lhs, Array rhs) { Array result; - // Cast to resolve ambiguous operators - auto comparison = select_last_index ? PrimExpr(lhs[1]) > PrimExpr(rhs[1]) : lhs[1] >= rhs[1]; - result.push_back(tvm::tir::Select(comparison, lhs[0], rhs[0])); // idx - result.push_back(tvm::tir::Select(comparison, lhs[1], rhs[1])); // val + + // These variables compare the actual values of the array + auto is_bigger = lhs[1] > rhs[1]; + auto is_same = lhs[1] == rhs[1]; + + // This checks if the indices are correct for the reduction. E.g. for select_last_index + // it gives precedence for later indices of the same element and precedence for sooner + // indices if not select_last_index; + PrimExpr proper_index; + if (select_last_index) { + proper_index = lhs[0] > rhs[0]; + } else { + proper_index = lhs[0] < rhs[0]; + } + + PrimExpr update_index = is_bigger || (is_same && proper_index); + result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0])); // idx + result.push_back(tvm::tir::Select(is_bigger, lhs[1], rhs[1])); // val + LOG(WARNING) << result; return result; }; auto fidentity = [&](std::vector types) { From 12f7213248d5cf7083b49f0acaf6a5ab7f1bd078 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 24 Aug 2021 12:15:23 -0700 Subject: [PATCH 44/52] casts to avoid ambiguity --- include/tvm/topi/reduction.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index e4318957160f..d34afc8af909 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -445,9 +445,9 @@ inline FCommReduce MakeArgminReducer(bool select_last_index = false) { // indices if not select_last_index; PrimExpr proper_index; if (select_last_index) { - proper_index = lhs[0] > rhs[0]; + proper_index = PrimExpr(lhs[0]) > PrimExpr(rhs[0]); } else { - proper_index = lhs[0] < rhs[0]; + proper_index = PrimExpr(lhs[0]) < PrimExpr(rhs[0]); } PrimExpr update_index = is_smaller || (is_same && proper_index); @@ -500,9 +500,9 @@ inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { // indices if not select_last_index; PrimExpr proper_index; if (select_last_index) { - proper_index = lhs[0] > rhs[0]; + proper_index = PrimExpr(lhs[0]) > PrimExpr(rhs[0]); } else { - proper_index = lhs[0] < rhs[0]; + proper_index = PrimExpr(lhs[0]) < PrimExpr(rhs[0]); } PrimExpr update_index = is_bigger || (is_same && proper_index); From 20cdd361b8f5adffe2c203daceac5133d4d54a53 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Tue, 24 Aug 2021 12:24:25 -0700 Subject: [PATCH 45/52] casting more --- include/tvm/topi/reduction.h | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index d34afc8af909..caf327ea2dce 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -436,18 +436,24 @@ inline FCommReduce MakeArgminReducer(bool select_last_index = false) { auto fcombine = [=](Array lhs, Array rhs) { Array result; + // Casting to avoid operator ambiguity + PrimExpr lhs_idx = static_cast(lhs[0]); + PrimExpr rhs_idx = static_cast(rhs[0]); + PrimExpr lhs_val = static_cast(lhs[1]); + PrimExpr rhs_val = static_cast(rhs[1]); + // These variables compare the actual values of the array - auto is_smaller = lhs[1] < rhs[1]; - auto is_same = lhs[1] == rhs[1]; + auto is_smaller = lhs_val < rhs_val; + auto is_same = lhs_val == rhs_val; // This checks if the indices are correct for the reduction. E.g. for select_last_index // it gives precedence for later indices of the same element and precedence for sooner // indices if not select_last_index; PrimExpr proper_index; if (select_last_index) { - proper_index = PrimExpr(lhs[0]) > PrimExpr(rhs[0]); + proper_index = lhs_idx > rhs_idx; } else { - proper_index = PrimExpr(lhs[0]) < PrimExpr(rhs[0]); + proper_index = lhs_idx < rhs_idx; } PrimExpr update_index = is_smaller || (is_same && proper_index); @@ -491,18 +497,24 @@ inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { auto fcombine = [=](Array lhs, Array rhs) { Array result; + // Casting to avoid operator ambiguity + PrimExpr lhs_idx = static_cast(lhs[0]); + PrimExpr rhs_idx = static_cast(rhs[0]); + PrimExpr lhs_val = static_cast(lhs[1]); + PrimExpr rhs_val = static_cast(rhs[1]); + // These variables compare the actual values of the array - auto is_bigger = lhs[1] > rhs[1]; - auto is_same = lhs[1] == rhs[1]; + auto is_bigger = lhs_val > rhs_val; + auto is_same = lhs_val == rhs_val; // This checks if the indices are correct for the reduction. E.g. for select_last_index // it gives precedence for later indices of the same element and precedence for sooner // indices if not select_last_index; PrimExpr proper_index; if (select_last_index) { - proper_index = PrimExpr(lhs[0]) > PrimExpr(rhs[0]); + proper_index = lhs_idx > rhs_idx; } else { - proper_index = PrimExpr(lhs[0]) < PrimExpr(rhs[0]); + proper_index = lhs_idx < rhs_idx; } PrimExpr update_index = is_bigger || (is_same && proper_index); From 49b632208a7aa132fcd86afa019ef2ae29ae012b Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Wed, 25 Aug 2021 19:10:40 -0700 Subject: [PATCH 46/52] correct arg passing --- python/tvm/topi/reduction.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/python/tvm/topi/reduction.py b/python/tvm/topi/reduction.py index cba43297f293..45d07af577a3 100644 --- a/python/tvm/topi/reduction.py +++ b/python/tvm/topi/reduction.py @@ -167,7 +167,7 @@ def min(data, axis=None, keepdims=False): return cpp.min(data, axis, keepdims) -def argmax(data, axis=None, keepdims=False, exclude=False, select_last_index=False): +def argmax(data, axis=None, keepdims=False, select_last_index=False): """Returns the indices of the maximum values along an axis. Parameters @@ -185,14 +185,18 @@ def argmax(data, axis=None, keepdims=False, exclude=False, select_last_index=Fal with size one. With this option, the result will broadcast correctly against the input array. + select_last_index: bool + Whether to select the last index if the maximum element appears multiple times, else + select the first index. + Returns ------- ret : tvm.te.Tensor """ - return cpp.argmax(data, axis, keepdims, exclude=exclude, select_last_index=select_last_index) + return cpp.argmax(data, axis, keepdims, select_last_index) -def argmin(data, axis=None, keepdims=False, exclude=False, select_last_index=False): +def argmin(data, axis=None, keepdims=False, select_last_index=False): """Returns the indices of the minimum values along an axis. Parameters @@ -210,11 +214,15 @@ def argmin(data, axis=None, keepdims=False, exclude=False, select_last_index=Fal with size one. With this option, the result will broadcast correctly against the input array. + select_last_index: bool + Whether to select the last index if the minimum element appears multiple times, else + select the first index. + Returns ------- ret : tvm.te.Tensor """ - return cpp.argmin(data, axis, keepdims, exclude, select_last_index) + return cpp.argmin(data, axis, keepdims, select_last_index) def prod(data, axis=None, keepdims=False): From 8f37f89723f0c5e554ea523c8acfbdcd062b0124 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Fri, 27 Aug 2021 10:19:03 -0700 Subject: [PATCH 47/52] fix broken input --- python/tvm/relay/frontend/onnx.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 8376d56fe089..e32d6a46dc8a 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -41,6 +41,7 @@ fold_constant, get_name, get_relay_op, + gru_cell, infer_channels, infer_shape, infer_type, From 2db29ca43f8c38fa7e17c69537a94e6095ff63a9 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Mon, 30 Aug 2021 11:26:25 -0700 Subject: [PATCH 48/52] OneElementReduceAttrs-->ArgReduceAttrs" --- include/tvm/relay/attrs/reduce.h | 4 ++-- src/relay/op/tensor/reduce.cc | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/include/tvm/relay/attrs/reduce.h b/include/tvm/relay/attrs/reduce.h index 8c4794e9ac00..d91b3594b5a3 100644 --- a/include/tvm/relay/attrs/reduce.h +++ b/include/tvm/relay/attrs/reduce.h @@ -62,13 +62,13 @@ struct ReduceAttrs : public tvm::AttrsNode { }; /*! \brief Attributes for Reduce operators which reduce by finding a single element. E.g. argmin */ -struct OneElementReduceAttrs : public tvm::AttrsNode { +struct ArgReduceAttrs : public tvm::AttrsNode { Array axis; bool keepdims; bool select_last_index; bool exclude; - TVM_DECLARE_ATTRS(OneElementReduceAttrs, "relay.attrs.OneElementReduceAttrs") { + TVM_DECLARE_ATTRS(ArgReduceAttrs, "relay.attrs.ArgReduceAttrs") { TVM_ATTR_FIELD(axis) .set_default(NullValue>()) .describe(R"code(The axis or axes along which to perform the reduction. diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 75483ea8b015..c17df42f1844 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -38,7 +38,7 @@ namespace tvm { namespace relay { TVM_REGISTER_NODE_TYPE(ReduceAttrs); -TVM_REGISTER_NODE_TYPE(OneElementReduceAttrs); +TVM_REGISTER_NODE_TYPE(ArgReduceAttrs); TVM_REGISTER_NODE_TYPE(VarianceAttrs); /*! @@ -215,7 +215,7 @@ Array ReduceCompute(const Attrs& attrs, const Array& inp template Array OneElementReduceCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, F f) { - const OneElementReduceAttrs* param = attrs.as(); + const ArgReduceAttrs* param = attrs.as(); ICHECK(param != nullptr); if (inputs[0]->shape.size() == 0) { return {topi::identity(inputs[0])}; @@ -328,7 +328,7 @@ bool ArgReduceRel(const Array& types, int num_inputs, const Attrs& attrs, */ bool SingleElementArgReduceRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - return GenericReduceRel(types, num_inputs, attrs, reporter); + return GenericReduceRel(types, num_inputs, attrs, reporter); } /*! @@ -364,7 +364,7 @@ Expr MakeReduce(Expr data, Array axis, bool keepdims, bool exclude, Str Expr MakeOneElementReduce(Expr data, Array axis, bool keepdims, bool exclude, bool select_last_index, String op_name) { - auto attrs = make_object(); + auto attrs = make_object(); attrs->axis = std::move(axis); attrs->keepdims = keepdims; attrs->exclude = exclude; @@ -397,7 +397,7 @@ RELAY_REGISTER_ONE_ELEMENT_REDUCE_OP("argmax") values over a given axis. )code" TVM_ADD_FILELINE) - .set_attrs_type() + .set_attrs_type() .set_support_level(4) .add_type_rel("ArgReduce", SingleElementArgReduceRel) .set_attr("FTVMCompute", ArgMaxCompute) @@ -413,7 +413,7 @@ RELAY_REGISTER_ONE_ELEMENT_REDUCE_OP("argmin") values over a given axis. )code" TVM_ADD_FILELINE) - .set_attrs_type() + .set_attrs_type() .set_support_level(4) .add_type_rel("ArgReduce", SingleElementArgReduceRel) .set_attr("FTVMCompute", ArgMinCompute) From 40551902954b5ed4e695a4ca288e36037973235c Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Mon, 30 Aug 2021 11:37:24 -0700 Subject: [PATCH 49/52] reduce boilerplate --- src/relay/op/tensor/reduce.cc | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index c17df42f1844..8bef3e81c125 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -319,18 +319,6 @@ bool ArgReduceRel(const Array& types, int num_inputs, const Attrs& attrs, return GenericReduceRel(types, num_inputs, attrs, reporter); } -/*! - * \brief SingleElementArgReduceRel Output type and shape relation evaluation function. - * \param num_inputs Number of input types in the args. - * \param attrs The additional attributes of the operator. - * \param reporter The reporter to report solution to. - * \return false if This relation cannot be resolved. true if this relation has been resolved. - */ -bool SingleElementArgReduceRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { - return GenericReduceRel(types, num_inputs, attrs, reporter); -} - /*! * \brief ReduceRel Output type and shape relation evaluation function. * \param num_inputs Number of input types in the args. @@ -399,7 +387,7 @@ values over a given axis. )code" TVM_ADD_FILELINE) .set_attrs_type() .set_support_level(4) - .add_type_rel("ArgReduce", SingleElementArgReduceRel) + .add_type_rel("ArgReduce", GenericReduceRel) .set_attr("FTVMCompute", ArgMaxCompute) .set_attr("TOpPattern", kCommReduce); @@ -415,7 +403,7 @@ values over a given axis. )code" TVM_ADD_FILELINE) .set_attrs_type() .set_support_level(4) - .add_type_rel("ArgReduce", SingleElementArgReduceRel) + .add_type_rel("ArgReduce", GenericReduceRel) .set_attr("FTVMCompute", ArgMinCompute) .set_attr("TOpPattern", kCommReduce); From 1f56147c34d3b3f10a7ff4cbab64d726060efe6f Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Mon, 30 Aug 2021 12:42:23 -0700 Subject: [PATCH 50/52] change names --- src/relay/op/tensor/reduce.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 8bef3e81c125..693589fecfb4 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -213,8 +213,8 @@ Array ReduceCompute(const Attrs& attrs, const Array& inp } template -Array OneElementReduceCompute(const Attrs& attrs, const Array& inputs, - const Type& out_type, F f) { +Array ArgReduceCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type, F f) { const ArgReduceAttrs* param = attrs.as(); ICHECK(param != nullptr); if (inputs[0]->shape.size() == 0) { @@ -377,7 +377,7 @@ Expr MakeOneElementReduce(Expr data, Array axis, bool keepdims, bool ex Array ArgMaxCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return OneElementReduceCompute(attrs, inputs, out_type, topi::argmax); + return ArgReduceCompute(attrs, inputs, out_type, topi::argmax); } RELAY_REGISTER_ONE_ELEMENT_REDUCE_OP("argmax") @@ -393,7 +393,7 @@ values over a given axis. Array ArgMinCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return OneElementReduceCompute(attrs, inputs, out_type, topi::argmin); + return ArgReduceCompute(attrs, inputs, out_type, topi::argmin); } RELAY_REGISTER_ONE_ELEMENT_REDUCE_OP("argmin") From d4cbfcca3337c3488c405ea01c6955c3b4572f40 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Mon, 30 Aug 2021 12:43:16 -0700 Subject: [PATCH 51/52] remove log statement --- include/tvm/topi/reduction.h | 1 - 1 file changed, 1 deletion(-) diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index caf327ea2dce..d4e420d80b02 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -520,7 +520,6 @@ inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { PrimExpr update_index = is_bigger || (is_same && proper_index); result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0])); // idx result.push_back(tvm::tir::Select(is_bigger, lhs[1], rhs[1])); // val - LOG(WARNING) << result; return result; }; auto fidentity = [&](std::vector types) { From c5f308b290818496494924eab2da8c381ebe0cf3 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 30 Aug 2021 19:45:22 -0700 Subject: [PATCH 52/52] jostle ci