diff --git a/include/tvm/relay/attrs/algorithm.h b/include/tvm/relay/attrs/algorithm.h
index 83b4ddaead43..3652a09e9168 100644
--- a/include/tvm/relay/attrs/algorithm.h
+++ b/include/tvm/relay/attrs/algorithm.h
@@ -76,6 +76,22 @@ struct TopKAttrs : public tvm::AttrsNode<TopKAttrs> {
   }
 };
 
+struct SearchSortedAttrs : public tvm::AttrsNode<SearchSortedAttrs> {
+  bool right;
+  DataType dtype;
+
+  TVM_DECLARE_ATTRS(SearchSortedAttrs, "relay.attrs.SearchSortedAttrs") {
+    TVM_ATTR_FIELD(right).set_default(false).describe(
+        "Controls which index is returned if a value lands exactly on one of the sorted values. "
+        "If false, the index of the first suitable location found is given. If true, return the "
+        "last such index. If there is no suitable index, return either 0 or N (where N is the "
+        "size of the innermost dimension).");
+    TVM_ATTR_FIELD(dtype)
+        .set_default(DataType::Int(32))
+        .describe("Data type of the output indices.");
+  }
+};
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_ALGORITHM_H_
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 76cd0455661b..3fc202a7cc91 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2774,6 +2774,26 @@ def all_any_common(self, op, inputs, input_types):
         inp = inputs[0]
         return op(inp, axis=dim, keepdims=keepdim)
 
+    def searchsorted_common(self, sorted_sequence, values, out_int32, right):
+        dtype = "int32" if out_int32 else "int64"
+        values_shape = _infer_shape(values)
+
+        if len(values_shape) == 0:
+            values = _op.expand_dims(values, 0)
+
+        out = _op.searchsorted(sorted_sequence, values, right=right, dtype=dtype)
+
+        if len(values_shape) == 0:
+            return _op.squeeze(out)
+
+        return out
+
+    def searchsorted(self, inputs, input_types):
+        return self.searchsorted_common(*inputs)
+
+    def bucketize(self, inputs, input_types):
+        return self.searchsorted_common(inputs[1], inputs[0], inputs[2], inputs[3])
+
     # Operator mappings
     def create_convert_map(self):
         self.convert_map = {
@@ -2999,6 +3019,8 @@ def create_convert_map(self):
             "aten::lstm": self.lstm,
             "aten::all": functools.partial(self.all_any_common, _op.all),
             "aten::any": functools.partial(self.all_any_common, _op.any),
+            "aten::searchsorted": self.searchsorted,
+            "aten::bucketize": self.bucketize,
         }
 
     def update_convert_map(self, custom_map):
diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/_algorithm.py
index 817f96b696df..19162a108395 100644
--- a/python/tvm/relay/op/_algorithm.py
+++ b/python/tvm/relay/op/_algorithm.py
@@ -41,6 +41,10 @@
 register_strategy("topk", strategy.topk_strategy)
 register_pattern("topk", OpPattern.OPAQUE)
 
+# searchsorted
+register_strategy("searchsorted", strategy.searchsorted_strategy)
+register_pattern("searchsorted", OpPattern.OPAQUE)
+
 
 @script
 def _topk_shape_func_input_shape(data_shape, k, axis):
diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py
index 119936f632f8..809a9061ade0 100644
--- a/python/tvm/relay/op/algorithm.py
+++ b/python/tvm/relay/op/algorithm.py
@@ -115,3 +115,37 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"):
     if ret_type == "both":
         return TupleWrapper(out, 2)
     return out
+
+
+def searchsorted(sorted_sequence, values, right=False, dtype="int32"):
+    """Find indices where elements should be inserted to maintain order.
+    If `sorted_sequence` is N-dimensional, the innermost dimension of
+    `values` is searched in the corresponding dimension of `sorted_sequence`.
+
+    Parameters
+    ----------
+    sorted_sequence : relay.Expr
+        N-D or 1-D Tensor, containing a monotonically increasing sequence
+        on the innermost dimension.
+
+    values : relay.Expr
+        N-D Tensor containing the search values. When `sorted_sequence` is 1-D,
+        the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence`
+        and `values` must be the same, and the outer N-1 axes must have the same size.
+
+    right : bool, optional
+        Controls which index is returned if a value lands exactly on one of the sorted values.
+        If False, the index of the first suitable location found is given. If True, return the
+        last such index. If there is no suitable index, return either 0 or N (where N is the
+        size of the innermost dimension).
+
+    dtype : string, optional
+        The data type of the output indices.
+
+    Returns
+    -------
+    indices : relay.Expr
+        Tensor with same shape as values, representing the indices of
+        elements of `values` if they are inserted in `sorted_sequence`.
+    """
+    return _make.searchsorted(sorted_sequence, values, right, dtype)
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index 8fd46817b817..dba40b2f6f34 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -564,6 +564,11 @@ class TopkAttrs(Attrs):
     """Attributes used in topk operators"""
 
 
+@tvm._ffi.register_object("relay.attrs.SearchSortedAttrs")
+class SearchSortedAttrs(Attrs):
+    """Attributes used in searchsorted operators"""
+
+
 @tvm._ffi.register_object("relay.attrs.TupleGetItemAttrs")
 class TupleGetItemAttrs(Attrs):
     """Attributes used in tuple item access operators"""
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index da7cbd5cec10..5f24dbda9d35 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -1022,6 +1022,18 @@ def topk_strategy_cuda(attrs, inputs, out_type, target):
     return strategy
 
 
+@searchsorted_strategy.register(["cuda", "gpu"])
+def searchsorted_strategy_cuda(attrs, inputs, out_type, target):
+    """searchsorted cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_searchsorted(topi.cuda.searchsorted),
+        wrap_topi_schedule(topi.cuda.schedule_extern),
+        name="searchsorted.cuda",
+    )
+    return strategy
+
+
 @multibox_prior_strategy.register(["cuda", "gpu"])
 def multibox_prior_strategy_cuda(attrs, inputs, out_type, target):
     """multibox_prior cuda strategy"""
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index d021b5d9d84d..777f17ba6084 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1002,6 +1002,31 @@ def topk_strategy(attrs, inputs, out_type, target):
     return strategy
 
 
+# searchsorted
+def wrap_compute_searchsorted(topi_compute):
+    """Wrap searchsorted compute"""
+
+    def _compute_searchsorted(attrs, inputs, out_type):
+        right = attrs.right
+        dtype = attrs.dtype
+        return [topi_compute(inputs[0], inputs[1], right, dtype)]
+
+    return _compute_searchsorted
+
+
+# searchsorted_strategy
+@override_native_generic_func("searchsorted_strategy")
+def searchsorted_strategy(attrs, inputs, out_type, target):
+    """searchsorted generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_searchsorted(topi.searchsorted),
+        wrap_topi_schedule(topi.generic.schedule_extern),
+        name="searchsorted.generic",
+    )
+    return strategy
+
+
 # multibox_prior
 def wrap_compute_multibox_prior(topi_compute):
     """Wrap multibox_prior compute"""
compute""" diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py index 6b22cf13f5b9..e243d6ee3bc7 100644 --- a/python/tvm/topi/__init__.py +++ b/python/tvm/topi/__init__.py @@ -45,6 +45,7 @@ from .scan import * from .einsum import * from .unique import * +from .searchsorted import * from . import generic from . import nn from . import x86 diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py index 21ddf57ca1d0..88d306761310 100644 --- a/python/tvm/topi/cuda/__init__.py +++ b/python/tvm/topi/cuda/__init__.py @@ -59,3 +59,4 @@ from .sparse_reshape import * from .transform import * from .unique import * +from .searchsorted import * diff --git a/python/tvm/topi/cuda/searchsorted.py b/python/tvm/topi/cuda/searchsorted.py new file mode 100644 index 000000000000..1c39ccaa8632 --- /dev/null +++ b/python/tvm/topi/cuda/searchsorted.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""searchsorted operator for GPU""" +import tvm +from tvm import te +from .. import utils +from ..searchsorted import binary_search + + +def searchsorted(sorted_sequence, values, right, out_dtype="int64"): + """Find indices where elements should be inserted to maintain order. + If `sorted_sequence` is N-dimensional, the innermost dimension of + `values` are searched in the corresponding dimension of `sorted_sequence`. + + Parameters + ---------- + sorted_sequence : te.Tensor + N-D or 1-D Tensor, containing monotonically increasing sequence + on the innermost dimension. + + values : te.Tensor + N-D Tensor containing the search values. When `sorted_sequence` is 1-D, + the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence` + and `values` must be the same, and outer N-1 axes must have the same size. + + right : bool, optional + Controls which index is returned if a value lands exactly on one of sorted values. If + False, the index of the first suitable location found is given. If true, return the + last such index. If there is no suitable index, return either 0 or N (where N is the + size of the innermost dimension). + + dtype : string, optional + The data type of the output indices. + + Returns + ------- + indices : te.Tensor + Tensor with same shape as values, representing the indices of + elements of `values` if they are inserted in `sorted_sequence`. 
+ """ + + def ir(sorted_sequence, values, indices): + ib = tvm.tir.ir_builder.create() + sorted_sequence_shape = sorted_sequence.shape + values_shape = values.shape + num_search = utils.prod(values_shape) + search_range = sorted_sequence_shape[-1] + + sorted_sequence = ib.buffer_ptr(sorted_sequence) + values = ib.buffer_ptr(values) + indices = ib.buffer_ptr(indices) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + ib.scope_attr( + bx, "thread_extent", tvm.tir.indexdiv(num_search + max_threads - 1, max_threads) + ) + ib.scope_attr(tx, "thread_extent", max_threads) + tid = bx * max_threads + tx + + with ib.if_scope(tid < num_search): + if len(sorted_sequence_shape) == 1: + sequence_offset = 0 + else: + sequence_id = tid // values_shape[-1] + sequence_offset = sequence_id * search_range + + indices[tid] = binary_search( + ib, + sequence_offset, + search_range, + sorted_sequence, + values[tid], + right, + out_dtype, + ) + + return ib.get() + + return te.extern( + values.shape, + [sorted_sequence, values], + lambda ins, outs: ir(ins[0], ins[1], outs[0]), + name="searchsorted", + dtype=out_dtype, + ) diff --git a/python/tvm/topi/searchsorted.py b/python/tvm/topi/searchsorted.py new file mode 100644 index 000000000000..28ffd170c955 --- /dev/null +++ b/python/tvm/topi/searchsorted.py @@ -0,0 +1,127 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""searchsorted operator""" +from . import utils +from . import te +from ..tir import ir_builder +from .math import cast + + +def binary_search(ib, sequence_offset, search_range, sorted_sequence, value, right, out_dtype): + """Common IR generator for binary search used by CPU and GPU backends. + + `sorted_sequence` is a N-D Buffer whose innermost dimension we want to search for `value`, + and `search_range` is the size of the innermost dimension. `sequence_offset` is + a 1-D linearlized offset specifying which of innermost sequences to search. + + So the search for `value` is performed over + `sorted_sequence[sequence_offset:(sequence_offset + search_range)]`. + Note that we index N-D Buffer by 1-D linearlized indices. 
+
+    """
+    lo = ib.allocate(out_dtype, (1,), name="lo", scope="local")
+    hi = ib.allocate(out_dtype, (1,), name="hi", scope="local")
+
+    lo[0] = cast(0, out_dtype)
+    hi[0] = cast(search_range, out_dtype)
+
+    # Reference: pytorch/aten/src/ATen/native/cuda/Bucketization.cu
+    def condition(current_val, target_val):
+        if right:
+            return current_val <= target_val
+        return current_val < target_val
+
+    with ib.while_loop(lo[0] < hi[0]):
+        mid = lo[0] + (hi[0] - lo[0] >> 1)
+        with ib.if_scope(condition(sorted_sequence[sequence_offset + mid], value)):
+            lo[0] = mid + 1
+        with ib.else_scope():
+            hi[0] = mid
+
+    return lo[0]
+
+
+def searchsorted(sorted_sequence, values, right=False, out_dtype="int64"):
+    """Find indices where elements should be inserted to maintain order.
+    If `sorted_sequence` is N-dimensional, the innermost dimension of
+    `values` is searched in the corresponding dimension of `sorted_sequence`.
+
+    Parameters
+    ----------
+    sorted_sequence : te.Tensor
+        N-D or 1-D Tensor, containing a monotonically increasing sequence
+        on the innermost dimension.
+
+    values : te.Tensor
+        N-D Tensor containing the search values. When `sorted_sequence` is 1-D,
+        the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence`
+        and `values` must be the same, and the outer N-1 axes must have the same size.
+
+    right : bool, optional
+        Controls which index is returned if a value lands exactly on one of the sorted values.
+        If False, the index of the first suitable location found is given. If True, return the
+        last such index. If there is no suitable index, return either 0 or N (where N is the
+        size of the innermost dimension).
+
+    out_dtype : string, optional
+        The data type of the output indices.
+
+    Returns
+    -------
+    indices : te.Tensor
+        Tensor with same shape as values, representing the indices of
+        elements of `values` if they are inserted in `sorted_sequence`.
+    """
+
+    def ir(sorted_sequence, values, indices):
+        ib = ir_builder.create()
+        sorted_sequence_shape = sorted_sequence.shape
+        values_shape = values.shape
+        num_search = utils.prod(values_shape)
+        search_range = sorted_sequence_shape[-1]
+
+        sorted_sequence = ib.buffer_ptr(sorted_sequence)
+        values = ib.buffer_ptr(values)
+        indices = ib.buffer_ptr(indices)
+
+        with ib.for_range(0, num_search, name="i", kind="parallel") as i:
+            if len(sorted_sequence_shape) == 1:
+                sequence_offset = 0
+            else:
+                sequence_id = i // values_shape[-1]
+                sequence_offset = sequence_id * search_range
+
+            indices[i] = binary_search(
+                ib,
+                sequence_offset,
+                search_range,
+                sorted_sequence,
+                values[i],
+                right,
+                out_dtype,
+            )
+
+        return ib.get()
+
+    return te.extern(
+        values.shape,
+        [sorted_sequence, values],
+        lambda ins, outs: ir(ins[0], ins[1], outs[0]),
+        name="searchsorted",
+        dtype=out_dtype,
+    )
diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py
index d10c49f5c084..2d7d0a4b9e11 100644
--- a/python/tvm/topi/testing/__init__.py
+++ b/python/tvm/topi/testing/__init__.py
@@ -73,3 +73,4 @@
 from .batch_to_space_nd import batch_to_space_nd_python
 from .nll_loss import nll_loss
 from .dense import dense
+from .searchsorted import searchsorted_ref
diff --git a/python/tvm/topi/testing/searchsorted.py b/python/tvm/topi/testing/searchsorted.py
new file mode 100644
index 000000000000..10762600992d
--- /dev/null
+++ b/python/tvm/topi/testing/searchsorted.py
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The reference implementation of searchsorted in Numpy."""
+import numpy as np
+
+
+def searchsorted_ref(sorted_sequence, values, right, out_dtype):
+    """Run Numpy searchsorted on 1-D or N-D sorted_sequence."""
+    side = "right" if right else "left"
+    if len(sorted_sequence.shape) == 1 and len(values.shape) > 1:
+        sorted_sequence_2d = np.tile(sorted_sequence, (np.prod(values.shape[:-1]), 1))
+    else:
+        sorted_sequence_2d = np.reshape(sorted_sequence, (-1, sorted_sequence.shape[-1]))
+
+    values_2d = np.reshape(values, (-1, values.shape[-1]))
+    indices = np.zeros(values_2d.shape, dtype=out_dtype)
+
+    for i in range(indices.shape[0]):
+        indices[i] = np.searchsorted(sorted_sequence_2d[i], values_2d[i], side=side)
+
+    return np.reshape(indices, values.shape)
diff --git a/src/relay/op/algorithm/searchsorted.cc b/src/relay/op/algorithm/searchsorted.cc
new file mode 100644
index 000000000000..be5921311660
--- /dev/null
+++ b/src/relay/op/algorithm/searchsorted.cc
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file searchsorted.cc
+ * \brief SearchSorted operators
+ */
+#include <tvm/relay/attrs/algorithm.h>
+#include <tvm/relay/op.h>
+#include 
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_NODE_TYPE(SearchSortedAttrs);
+
+bool SearchSortedRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                     const TypeReporter& reporter) {
+  const SearchSortedAttrs* param = attrs.as<SearchSortedAttrs>();
+  ICHECK_EQ(types.size(), 3);
+  const auto* sorted_sequence = types[0].as<TensorTypeNode>();
+  const auto* values = types[1].as<TensorTypeNode>();
+  ICHECK(sorted_sequence) << "Expects TensorType in the first input";
+  ICHECK(values) << "Expects TensorType in the second input";
+  ICHECK_GT(values->shape.size(), 0) << "The rank of `values` must be greater than zero";
+
+  if (sorted_sequence->shape.size() > 1) {
+    ICHECK_EQ(sorted_sequence->shape.size(), values->shape.size())
+        << "Ranks of `sorted_sequence` and `values` must be the same if `sorted_sequence` is "
+           "multi-dimensional.";
+
+    for (size_t i = 0; i < values->shape.size() - 1; ++i) {
+      if (sorted_sequence->shape[i].as<IntImmNode>() && values->shape[i].as<IntImmNode>()) {
+        ICHECK_EQ(sorted_sequence->shape[i].as<IntImmNode>()->value,
+                  values->shape[i].as<IntImmNode>()->value)
+            << "`sorted_sequence` and `values` do not have the same shape along outer axes";
+      }
+    }
+  }
+
+  reporter->Assign(types[2], TensorType(values->shape, param->dtype));
+  return true;
+}
+
+Expr MakeSearchSorted(Expr sorted_sequence, Expr values, Bool right, DataType dtype) {
+  auto attrs = make_object<SearchSortedAttrs>();
+  static const Op& op = Op::Get("searchsorted");
+  attrs->dtype = dtype;
+  attrs->right = right;
+  return Call(op, {sorted_sequence, values}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.searchsorted").set_body_typed(MakeSearchSorted);
+
+RELAY_REGISTER_OP("searchsorted")
+    .describe(
+        R"doc(Find indices where elements should be inserted to maintain order.
+If `sorted_sequence` is N-dimensional, the innermost dimension of
+`values` is searched in the corresponding dimension of `sorted_sequence`.
+)doc" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .set_attrs_type<SearchSortedAttrs>()
+    .add_argument("sorted_sequence", "Tensor",
+                  "Monotonically increasing sequence on the innermost dimension.")
+    .add_argument("values", "Tensor", "Values to search for.")
+    .set_support_level(6)
+    .add_type_rel("SearchSorted", SearchSortedRel);
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index aa164b03a2a7..657dc121961c 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -48,7 +48,7 @@ class ProducerToBufferTransformer : public StmtExprMutator {
   const std::unordered_map<te::Tensor, Buffer>& tensor2buffers_;
 };
 
-/*! \brief Helper data structural to store informations. */
+/*! \brief Helper data structure to store information. */
 struct CreateFuncInfo {
   /*! \brief The Tensor arg_list. */
   Array<te::Tensor> arg_list;
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 3a3889d5cfb7..0031f4143fab 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3962,5 +3962,35 @@ def test_fn(f, dim=None, keepdim=False):
         verify_model(test_fn(f, 0, keepdim=True), [torch.rand(4, 2).bool()])
 
 
+@tvm.testing.uses_gpu
+def test_searchsorted():
+    def test_fn(out_int32=False, right=False):
+        return lambda x, y: torch.searchsorted(x, y, out_int32=out_int32, right=right)
+
+    sorted_sequence = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]])
+    values = torch.tensor([[3, 6, 9], [3, 6, 9]])
+    verify_model(test_fn(), [sorted_sequence, values])
+    verify_model(test_fn(out_int32=True), [sorted_sequence[0], values[0]])
+    verify_model(test_fn(right=True), [sorted_sequence, values])
+
+    sorted_sequence_1d = torch.tensor([1, 3, 5, 7, 9])
+    values = torch.tensor([[3, 6, 9], [4, 2, 7]])
+    verify_model(test_fn(), [sorted_sequence_1d, values])
+
+    verify_model(test_fn(), [sorted_sequence_1d, torch.tensor(6)])
+
+
+@tvm.testing.uses_gpu
+def test_bucketize():
+    def test_fn(out_int32=False, right=False):
+        return lambda x, y: torch.bucketize(x, y, out_int32=out_int32, right=right)
+
+    boundaries = torch.tensor([1, 3, 5, 7, 9])
+    values = torch.tensor([3, 6, 9])
+
+    verify_model(test_fn(), [values, boundaries])
+    verify_model(test_fn(out_int32=True, right=True), [values, boundaries])
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py
index eb4eee379b08..c968c5a7f19f 100644
--- a/tests/python/relay/test_op_level5.py
+++ b/tests/python/relay/test_op_level5.py
@@ -773,7 +773,6 @@ def verify_roi_align(
             mode=mode,
         )
         for target, dev in tvm.testing.enabled_targets():
-            print("test on", target)
             op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
                 np_data, np_rois
             )
diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py
index ea640c62dfeb..48c58dc2dc33 100644
--- a/tests/python/relay/test_op_level6.py
+++ b/tests/python/relay/test_op_level6.py
@@ -20,6 +20,7 @@
 import numpy as np
 import tvm
 from tvm import relay
+from tvm.topi.testing import searchsorted_ref
 import tvm.testing
 
 
@@ -149,5 +150,28 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype, in_dtype="float32"):
                 verify_topk(k, axis, ret_type, False, "int64", "float16")
 
 
+@tvm.testing.uses_gpu
+def test_searchsorted():
+    def verify_searchsorted(right, dtype):
+        shape = (8, 9, 10)
+        values_shape = shape[:-1] + (10,)
+        sorted_sequence = relay.var("sorted_sequence", relay.TensorType(shape, "float32"))
+        values = relay.var("values", relay.TensorType(values_shape, "float32"))
+        out = relay.searchsorted(sorted_sequence, values, right, dtype)
+        func = relay.Function([sorted_sequence, values], out)
+        sorted_sequence_np = np.sort(np.random.randn(*shape).astype("float32"), axis=-1)
+        values_np = np.random.randn(*values_shape).astype("float32")
+        np_indices = searchsorted_ref(sorted_sequence_np, values_np, right, dtype)
+
+        for target, dev in tvm.testing.enabled_targets():
+            op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
+                sorted_sequence_np, values_np
+            )
+            np.testing.assert_equal(op_res.numpy(), np_indices)
+
+    verify_searchsorted(False, "int32")
+    verify_searchsorted(True, "int64")
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/topi/python/test_topi_searchsorted.py b/tests/python/topi/python/test_topi_searchsorted.py
new file mode 100644
index 000000000000..7b3976b7eb74
--- /dev/null
+++ b/tests/python/topi/python/test_topi_searchsorted.py
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import tvm
+import tvm.testing
+import tvm.topi.testing
+from tvm.topi.testing import searchsorted_ref
+from tvm import te, topi
+
+topi_funcs = {"generic": topi.searchsorted, "cuda": topi.cuda.searchsorted}
+
+
+def get_implementations():
+    topi_func_generic = topi_funcs["generic"]
+    topi_func_cuda = topi_funcs["cuda"]
+
+    return {
+        "generic": (
+            lambda x, y, side, out_dtype: topi_func_generic(x, y, side, out_dtype),
+            topi.generic.schedule_extern,
+        ),
+        "cuda": (
+            lambda x, y, side, out_dtype: topi_func_cuda(x, y, side, out_dtype),
+            topi.cuda.schedule_extern,
+        ),
+        "vulkan": (
+            lambda x, y, side, out_dtype: topi_func_cuda(x, y, side, out_dtype),
+            topi.cuda.schedule_extern,
+        ),
+    }
+
+
+@tvm.testing.parametrize_targets
+def test_searchsorted(dev, target):
+    def verify_with_input(sorted_sequence_np, values_np, right):
+        sorted_sequence = te.placeholder(sorted_sequence_np.shape, dtype="float32")
+        values = te.placeholder(values_np.shape, dtype="float32")
+        out_dtype = "int32"
+        implementations = get_implementations()
+        fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
+
+        with tvm.target.Target(target):
+            indices = fcompute(sorted_sequence, values, right, out_dtype)
+            s = fschedule([indices])
+
+        func = tvm.build(s, [sorted_sequence, values, indices], target=target)
+        dev = tvm.device(target, 0)
+
+        a = tvm.nd.array(sorted_sequence_np, dev)
+        b = tvm.nd.array(values_np, dev)
+        c = tvm.nd.array(np.zeros(values_np.shape, dtype=indices.dtype), dev)
+        func(a, b, c)
+        ref = searchsorted_ref(sorted_sequence_np, values_np, right, out_dtype)
+        np.testing.assert_equal(c.numpy(), ref)
+
+    def verify(sequence_len, num_search, outer_axes, right, sorted_sequence_1d=False):
+        if sorted_sequence_1d:
+            sorted_sequence_shape = (sequence_len,)
+        else:
+            sorted_sequence_shape = outer_axes + (sequence_len,)
+        values_shape = outer_axes + (num_search,)
+
+        verify_with_input(
+            np.sort(np.random.randn(*sorted_sequence_shape).astype("float32"), axis=-1),
+            np.random.randn(*values_shape).astype("float32"),
+            right,
+        )
+
+    verify(1024, 1000, (10, 5, 3), False)
+    verify(999, 2000, (10, 5, 3), True)
+    verify(1000, 1000, (), False)
+    verify(2001, 100, (500,), True)
+    verify(2001, 100, (500,), False, sorted_sequence_1d=True)
+
+    # Check edge cases
+    for right in [True, False]:
+        sorted_sequence = np.array([1, 2, 3, 4, 5], dtype="float32")
+        verify_with_input(sorted_sequence, np.array([6], dtype="float32"), right)
+        verify_with_input(sorted_sequence, np.array([0], dtype="float32"), right)
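
Usage note: the sketch below shows one way to exercise the new relay.searchsorted op end to end, mirroring the Relay test added above. It is illustrative only and not part of the patch; the tiny 1-D shapes, the input values, and the variable names are assumptions for the example.

import numpy as np
import tvm
from tvm import relay

# Build a small Relay function around the new op (left insertion points, int64 indices).
sorted_sequence = relay.var("sorted_sequence", relay.TensorType((5,), "float32"))
values = relay.var("values", relay.TensorType((3,), "float32"))
out = relay.searchsorted(sorted_sequence, values, right=False, dtype="int64")
func = relay.Function([sorted_sequence, values], out)

# Run on CPU with the graph executor, the same way the added tests do.
seq_np = np.array([1, 3, 5, 7, 9], dtype="float32")
vals_np = np.array([3, 6, 9], dtype="float32")
result = relay.create_executor("graph", target="llvm").evaluate(func)(seq_np, vals_np)
print(result.numpy())  # should print [1 3 4]: the left insertion index of each value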