diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/_algorithm.py index ace3de5fff067..19162a1083955 100644 --- a/python/tvm/relay/op/_algorithm.py +++ b/python/tvm/relay/op/_algorithm.py @@ -45,6 +45,7 @@ register_strategy("searchsorted", strategy.searchsorted_strategy) register_pattern("searchsorted", OpPattern.OPAQUE) + @script def _topk_shape_func_input_shape(data_shape, k, axis): ndim = data_shape.shape[0] diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index f9a4051f6aeb0..dba40b2f6f342 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -563,10 +563,12 @@ class SparseConv2DAttrs(Attrs): class TopkAttrs(Attrs): """Attributes used in topk operators""" + @tvm._ffi.register_object("relay.attrs.SearchSortedAttrs") class SearchSortedAttrs(Attrs): """Attributes used in searchsorted operators""" + @tvm._ffi.register_object("relay.attrs.TupleGetItemAttrs") class TupleGetItemAttrs(Attrs): """Attributes used in tuple item access operators""" diff --git a/python/tvm/topi/cuda/searchsorted.py b/python/tvm/topi/cuda/searchsorted.py index d84a1941369fc..286112ed1776b 100644 --- a/python/tvm/topi/cuda/searchsorted.py +++ b/python/tvm/topi/cuda/searchsorted.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name """searchsorted operator for GPU""" import tvm from tvm import te diff --git a/python/tvm/topi/searchsorted.py b/python/tvm/topi/searchsorted.py index 3a984a60ecb5f..02b7d23621092 100644 --- a/python/tvm/topi/searchsorted.py +++ b/python/tvm/topi/searchsorted.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name """searchsorted operator""" from . import utils from . import te @@ -22,8 +23,7 @@ def binary_search( - ib, sequence_offset, search_range, index, sorted_sequence, values, out_indices, - side, out_dtype + ib, sequence_offset, search_range, index, sorted_sequence, values, out_indices, side, out_dtype ): """Common IR generator for CPU and GPU searchsorted.""" lo = ib.allocate(out_dtype, (1,), name="lo", scope="local") diff --git a/python/tvm/topi/testing/searchsorted.py b/python/tvm/topi/testing/searchsorted.py index e939fc8a0f7d8..1f4300b834110 100644 --- a/python/tvm/topi/testing/searchsorted.py +++ b/python/tvm/topi/testing/searchsorted.py @@ -1,7 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The reference implementation of searchsorted in Numpy.""" import numpy as np def searchsorted_ref(sorted_sequence, values, side, out_dtype): + """Run Numpy searchsorted on 1-D or N-D sorted_sequence.""" if len(sorted_sequence.shape) == 1 and len(values.shape) > 1: sorted_sequence_2d = np.tile(sorted_sequence, (np.prod(values.shape[:-1]), 1)) else: