diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index 56558272f2a3d..a36f8e6c71cfb 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -108,6 +108,11 @@ This level enables additional math and transform operators.
    tvm.relay.where
    tvm.relay.argmax
    tvm.relay.argmin
+   tvm.relay.sum
+   tvm.relay.max
+   tvm.relay.min
+   tvm.relay.mean
+   tvm.relay.prod


 **Level 5: Vision/Image Operators**
@@ -187,6 +192,11 @@ Level 4 Definitions
 .. autofunction:: tvm.relay.where
 .. autofunction:: tvm.relay.argmax
 .. autofunction:: tvm.relay.argmin
+.. autofunction:: tvm.relay.sum
+.. autofunction:: tvm.relay.max
+.. autofunction:: tvm.relay.min
+.. autofunction:: tvm.relay.mean
+.. autofunction:: tvm.relay.prod


 Level 5 Definitions
diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py
index a2a4519512eac..8f6da630312c7 100644
--- a/python/tvm/relay/op/reduce.py
+++ b/python/tvm/relay/op/reduce.py
@@ -62,3 +62,196 @@ def argmin(data, axis=None, keepdims=False, exclude=False):
     """
     return _make.argmin(data, axis, keepdims, exclude)
+
+
+def sum(data, axis=None, keepdims=False, exclude=False):
+    """Computes the sum of array elements over given axes.
+
+    Example::
+
+        data = [[[1,2],[2,3],[1,3]],
+                [[1,4],[4,3],[5,2]],
+                [[7,1],[7,2],[7,3]]]
+
+        sum(data, axis=1)
+        [[ 4. 8.]
+         [ 10. 9.]
+         [ 21. 6.]]
+
+        sum(data, axis=[1,2])
+        [ 12. 19. 27.]
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    axis : None or int or tuple of int
+        Axis or axes along which a sum is performed.
+        The default, axis=None, will sum all of the elements of
+        the input array. If axis is negative it counts from the last to the first axis.
+
+    keepdims : bool
+        If this is set to True, the axes which are reduced are left in the result as dimensions
+        with size one.
+        With this option, the result will broadcast correctly against the input array.
+
+    exclude : bool
+        If `exclude` is true, reduction will be performed on the axes that are
+        NOT in axis instead.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+
+    return _make.sum(data, axis, keepdims, exclude)
+
+
+def max(data, axis=None, keepdims=False, exclude=False):
+    """Computes the max of array elements over given axes.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    axis : None or int or tuple of int
+        Axis or axes along which a max operation is performed.
+        The default, axis=None, will find the maximum element of all of the elements of
+        the input array. If axis is negative it counts from the last to the first axis.
+
+    keepdims : bool
+        If this is set to True, the axes which are reduced are left in the result as dimensions
+        with size one.
+        With this option, the result will broadcast correctly against the input array.
+
+    exclude : bool
+        If `exclude` is true, reduction will be performed on the axes that are
+        NOT in axis instead.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+
+    return _make.max(data, axis, keepdims, exclude)
+
+
+def min(data, axis=None, keepdims=False, exclude=False):
+    """Computes the min of array elements over given axes.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    axis : None or int or tuple of int
+        Axis or axes along which a min operation is performed.
+        The default, axis=None, will find the minimum element of all of the elements of
+        the input array. If axis is negative it counts from the last to the first axis.
+
+    keepdims : bool
+        If this is set to True, the axes which are reduced are left in the result as dimensions
+        with size one.
+        With this option, the result will broadcast correctly against the input array.
+
+    exclude : bool
+        If `exclude` is true, reduction will be performed on the axes that are
+        NOT in axis instead.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+
+    return _make.min(data, axis, keepdims, exclude)
+
+
+def mean(data, axis=None, keepdims=False, exclude=False):
+    """Computes the mean of array elements over given axes.
+
+    Example::
+
+        data = [[[1,2],[2,3],[1,3]],
+                [[1,4],[4,3],[5,2]],
+                [[7,1],[7,2],[7,3]]]
+
+        mean(data)
+        [3.22]
+
+        mean(data, axis=[1,2])
+        [ 2. 3.16666667 4.5]
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    axis : None or int or tuple of int
+        Axis or axes along which a mean operation is performed.
+        The default, axis=None, will compute the mean of all of the elements of
+        the input array. If axis is negative it counts from the last to the first axis.
+
+    keepdims : bool
+        If this is set to True, the axes which are reduced are left in the result as dimensions
+        with size one.
+        With this option, the result will broadcast correctly against the input array.
+
+    exclude : bool
+        If `exclude` is true, reduction will be performed on the axes that are
+        NOT in axis instead.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+
+    return _make.mean(data, axis, keepdims, exclude)
+
+
+def prod(data, axis=None, keepdims=False, exclude=False):
+    """Computes the product of array elements over given axes.
+
+    Example::
+
+        data = [[[1,2],[2,3],[1,3]],
+                [[1,4],[4,3],[5,2]],
+                [[7,1],[7,2],[7,3]]]
+
+        prod(data)
+        [35562240]
+
+        prod(data, axis=[1,2])
+        [ 36 480 2058]
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    axis : None or int or tuple of int
+        Axis or axes along which a product is performed.
+        The default, axis=None, will compute the product of all of the elements of
+        the input array. If axis is negative it counts from the last to the first axis.
+
+    keepdims : bool
+        If this is set to True, the axes which are reduced are left in the result as dimensions
+        with size one.
+        With this option, the result will broadcast correctly against the input array.
+
+    exclude : bool
+        If `exclude` is true, reduction will be performed on the axes that are
+        NOT in axis instead.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+
+    return _make.prod(data, axis, keepdims, exclude)
diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc
index d2ec246886336..9394a51bccca2 100644
--- a/src/relay/op/tensor/reduce.cc
+++ b/src/relay/op/tensor/reduce.cc
@@ -172,6 +172,34 @@ bool ArgReduceRel(const Array<Type>& types,
   return true;
 }
+/*!
+* \brief ReduceRel Output type and shape relation evaluation function.
+* \param num_inputs Number of input types in the args.
+* \param attrs The additional attributes of the operator.
+* \param reporter The reporter to report solution to.
+* \return false if this relation cannot be resolved, true if it has been resolved.
+*/
+bool ReduceRel(const Array<Type>& types,
+               int num_inputs,
+               const Attrs& attrs,
+               const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+  CHECK(static_cast<int>(data->shape.size()) != 0);
+  std::vector<IndexExpr> in_shape;
+  for (auto i : data->shape) {
+    in_shape.push_back(i);
+  }
+
+  const ReduceAttrs* param = attrs.as<ReduceAttrs>();
+  CHECK(param != nullptr);
+
+  // assign output type and shape
+  auto oshape = ReduceShapeImpl(in_shape, param, reporter);
+  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  return true;
+}

 #define RELAY_REGISTER_REDUCE_OP(OpName) \
   TVM_REGISTER_API("relay.op._make." OpName) \
@@ -213,5 +241,88 @@ values over a given axis.
 .set_support_level(4)
 .add_type_rel("ArgReduce", ArgReduceRel);
+
+RELAY_REGISTER_REDUCE_OP("sum")
+.describe(R"code(Computes the sum of array elements over given axes.
+
+Example::
+
+  data = [[[1,2],[2,3],[1,3]],
+          [[1,4],[4,3],[5,2]],
+          [[7,1],[7,2],[7,3]]]
+
+  sum(data, axis=1)
+  [[ 4. 8.]
+   [ 10. 9.]
+   [ 21. 6.]]
+
+  sum(data, axis=[1,2])
+  [ 12. 19. 27.]
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_support_level(4)
+.add_type_rel("Reduce", ReduceRel);
+
+
+RELAY_REGISTER_REDUCE_OP("max")
+.describe(R"code(Computes the max of array elements over given axes.
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_support_level(4)
+.add_type_rel("Reduce", ReduceRel);
+
+
+RELAY_REGISTER_REDUCE_OP("min")
+.describe(R"code(Computes the min of array elements over given axes.
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_support_level(4)
+.add_type_rel("Reduce", ReduceRel);
+
+
+RELAY_REGISTER_REDUCE_OP("mean")
+.describe(R"code(Computes the mean of array elements over given axes.
+
+Example::
+
+  data = [[[1,2],[2,3],[1,3]],
+          [[1,4],[4,3],[5,2]],
+          [[7,1],[7,2],[7,3]]]
+
+  mean(data)
+  [3.22]
+
+  mean(data, axis=[1,2])
+  [ 2. 3.16666667 4.5]
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_support_level(4)
+.add_type_rel("Reduce", ReduceRel);
+
+
+RELAY_REGISTER_REDUCE_OP("prod")
+.describe(R"code(Computes the product of array elements over given axes.
+
+Example::
+
+  data = [[[1,2],[2,3],[1,3]],
+          [[1,4],[4,3],[5,2]],
+          [[7,1],[7,2],[7,3]]]
+
+  prod(data)
+  [35562240]
+
+  prod(data, axis=[1,2])
+  [ 36 480 2058]
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_support_level(4)
+.add_type_rel("Reduce", ReduceRel);
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py
index dea300422e45a..a53002099fd49 100644
--- a/tests/python/relay/test_op_level4.py
+++ b/tests/python/relay/test_op_level4.py
@@ -93,93 +93,6 @@ def test_binary_broadcast():
     ftype = func.checked_type
     assert ftype.ret_type == relay.TensorType((5, 10, 4), "int32")

-def test_argmax():
-    ib = relay.ir_builder.IRBuilder()
-    n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
-    x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.argmax(x, axis=(1,)))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((n, h, w), "int32")
-
-    ib = relay.ir_builder.IRBuilder()
-    n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
-    x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.argmax(x, axis=(2,), keepdims=True))
-    ib.ret(func)
-
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((n, c , 1, w), "int32")
-
-    ib = relay.ir_builder.IRBuilder()
-    n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
-    x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.argmax(x, axis=(2,), keepdims=True, exclude=True))
-    ib.ret(func)
-
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((1, 1 , h, 1), "int32")
-
-def test_argmin():
-    ib = relay.ir_builder.IRBuilder()
-    n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
-    x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.argmax(x, axis=(1,)))
-    ib.ret(func)
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((n, h, w), "int32")
-
-    ib = relay.ir_builder.IRBuilder()
-    n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
-    x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.argmin(x, axis=(2,), keepdims=True))
-    ib.ret(func)
-
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((n, c , 1, w), "int32")
-
-    ib = relay.ir_builder.IRBuilder()
-    n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
-    x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.argmin(x, axis=(2,), keepdims=True, exclude=True))
-    ib.ret(func)
-
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((1, 1 , h, 1), "int32")
-
-    ib = relay.ir_builder.IRBuilder()
-    n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
-    x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.argmin(x, axis=(2,1), keepdims=True, exclude=True))
-    ib.ret(func)
-
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((1, c , h, 1), "int32")
-
-    ib = relay.ir_builder.IRBuilder()
-    n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
-    x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32"))
-    with ib.function(x) as func:
-        ib.ret(relay.argmin(x, axis=None, keepdims=True, exclude=True))
-    ib.ret(func)
-
-    func = relay.ir_pass.infer_type(ib.env, func.to_func())
-    ftype = func.checked_type
-    assert ftype.ret_type == relay.ty.TensorType((1, 1 , 1, 1), "int32")

 def test_where():
     ib = relay.ir_builder.IRBuilder()
@@ -194,6 +107,38 @@ def test_where():
     assert ftype.ret_type == relay.TensorType((3, 4), "float32")


+def verify_reduce(test_func, data, axis, keepdims, exclude, output):
+    ib = relay.ir_builder.IRBuilder()
+    x = ib.param("x", relay.ty.TensorType(data, "float32"))
+    with ib.function(x) as func:
+        ib.ret(test_func(x, axis, keepdims, exclude))
+    ib.ret(func)
+    func = relay.ir_pass.infer_type(ib.env, func.to_func())
+    ftype = func.checked_type
+    out_type = "int32" if test_func in [relay.argmin, relay.argmax] else "float32"
+    assert ftype.ret_type == relay.ty.TensorType(output, out_type)
+
+def test_reduce():
+    d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
+    for func in [relay.sum,
+                 relay.max,
+                 relay.min,
+                 relay.mean,
+                 relay.prod,
+                 relay.argmin,
+                 relay.argmax]:
+        verify_reduce(func, (d1, d2, d3, d4), (2,), True, False, (d1, d2, 1, d4))
+        verify_reduce(func, (2, 3, 4), (1,), True, False, (2, 1, 4))
+        verify_reduce(func, (2, 3, 4), (0, 1, 2), False, False, ())
+        verify_reduce(func, (4, 4, 3), None, True, False, (1, 1, 1))
+        verify_reduce(func, (4, 4, 3), None, False, False, ())
+        verify_reduce(func, (4, 4, 3), (0, 2), False, False, (4,))
+        verify_reduce(func, (128, 24, 128), (0, 1), False, False, (128,))
+        verify_reduce(func, (128, 24, 128), (0, 2), False, False, (24,))
+        verify_reduce(func, (128, 24, 128), (0, 1), True, False, (1, 1, 128))
+        verify_reduce(func, (128, 24, 128), (0, 2), True, False, (1, 24, 1))
+        verify_reduce(func, (128, 24, 128), None, True, False, (1, 1, 1))
+
 if __name__ == "__main__":
     test_binary_op()
     test_binary_broadcast_op()
@@ -201,5 +146,4 @@ def test_where():
     test_binary_broadcast()
     test_where()
     test_multibox_prior()
-    test_argmax()
-    test_argmin()
+    test_reduce()
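
For reference, the new operators follow the same pattern as the updated tests: build a function with the ir_builder, apply one of the reduce ops, and run type inference. A minimal sketch, assuming the relay.ir_builder / relay.ir_pass API used in test_op_level4.py at this revision::

    import tvm
    from tvm import relay

    ib = relay.ir_builder.IRBuilder()
    n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
    x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
    with ib.function(x) as func:
        # Reduce over the channel axis; keepdims=True keeps a size-1 axis in place.
        ib.ret(relay.mean(x, axis=(1,), keepdims=True))
    ib.ret(func)

    # Type inference should report an (n, 1, h, w) float32 return type,
    # matching the keepdims cases checked by test_reduce above.
    func = relay.ir_pass.infer_type(ib.env, func.to_func())
    print(func.checked_type.ret_type)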