Skip to content

Commit

Permalink
[Relay, TOPI] Add searchsorted op (#9184)
Browse files Browse the repository at this point in the history
* Add relay definition

* 1D cpu test working

* multi dim working

* gpu version working

* check shape in type rel

* support side

* use target specfic max threads

* add relay boilerplate

* relay test working

* cleanup topi test

* fix test

* add torch converter

* handle other cases

* more topi test

* support torch bucketize

* update doc

* fix tests

* fix lint

* rebase fix

* make the test case smaller

* add tests for edge cases

* replace "side" attribute with boolean "right"

* add more descrition to binear_search IR gen params

* return index from binary_search rather than update inplace

* remove unused argument

* format fix
  • Loading branch information
masahi authored Oct 20, 2021
1 parent 3f064b6 commit 9cf0245
Show file tree
Hide file tree
Showing 19 changed files with 619 additions and 2 deletions.
16 changes: 16 additions & 0 deletions include/tvm/relay/attrs/algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,22 @@ struct TopKAttrs : public tvm::AttrsNode<TopKAttrs> {
}
};

struct SearchSortedAttrs : public tvm::AttrsNode<SearchSortedAttrs> {
bool right;
DataType dtype;

TVM_DECLARE_ATTRS(SearchSortedAttrs, "relay.attrs.SearchSortedAttrs") {
TVM_ATTR_FIELD(right).set_default(false).describe(
"Controls which index is returned if a value lands exactly on one of sorted values. If "
" false, the index of the first suitable location found is given. If true, return the "
"last such index. If there is no suitable index, return either 0 or N (where N is the "
"size of the innermost dimension).");
TVM_ATTR_FIELD(dtype)
.set_default(DataType::Int(32))
.describe("Data type of the output indices.");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ALGORITHM_H_
22 changes: 22 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2774,6 +2774,26 @@ def all_any_common(self, op, inputs, input_types):
inp = inputs[0]
return op(inp, axis=dim, keepdims=keepdim)

def searchsorted_common(self, sorted_sequence, values, out_int32, right):
dtype = "int32" if out_int32 else "int64"
values_shape = _infer_shape(values)

if len(values_shape) == 0:
values = _op.expand_dims(values, 0)

out = _op.searchsorted(sorted_sequence, values, right=right, dtype=dtype)

if len(values_shape) == 0:
return _op.squeeze(out)

return out

def searchsorted(self, inputs, input_types):
return self.searchsorted_common(*inputs)

def bucketize(self, inputs, input_types):
return self.searchsorted_common(inputs[1], inputs[0], inputs[2], inputs[3])

# Operator mappings
def create_convert_map(self):
self.convert_map = {
Expand Down Expand Up @@ -2999,6 +3019,8 @@ def create_convert_map(self):
"aten::lstm": self.lstm,
"aten::all": functools.partial(self.all_any_common, _op.all),
"aten::any": functools.partial(self.all_any_common, _op.any),
"aten::searchsorted": self.searchsorted,
"aten::bucketize": self.bucketize,
}

def update_convert_map(self, custom_map):
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/relay/op/_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@
register_strategy("topk", strategy.topk_strategy)
register_pattern("topk", OpPattern.OPAQUE)

# searchsorted
register_strategy("searchsorted", strategy.searchsorted_strategy)
register_pattern("searchsorted", OpPattern.OPAQUE)


@script
def _topk_shape_func_input_shape(data_shape, k, axis):
Expand Down
34 changes: 34 additions & 0 deletions python/tvm/relay/op/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,37 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"):
if ret_type == "both":
return TupleWrapper(out, 2)
return out


def searchsorted(sorted_sequence, values, right=False, dtype="int32"):
"""Find indices where elements should be inserted to maintain order.
If `sorted_sequence` is N-dimensional, the innermost dimension of
`values` are searched in the corresponding dimension of `sorted_sequence`.
Parameters
----------
sorted_sequence : relay.Expr
N-D or 1-D Tensor, containing monotonically increasing sequence
on the innermost dimension.
values : relay.Expr
N-D Tensor containing the search values. When `sorted_sequence` is 1-D,
the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence`
and `values` must be the same, and outer N-1 axes must have the same size.
right : bool, optional
Controls which index is returned if a value lands exactly on one of sorted values. If
False, the index of the first suitable location found is given. If true, return the
last such index. If there is no suitable index, return either 0 or N (where N is the
size of the innermost dimension).
dtype : string, optional
The data type of the output indices.
Returns
-------
indices : relay.Expr
Tensor with same shape as values, representing the indices of
elements of `values` if they are inserted in `sorted_sequence`.
"""
return _make.searchsorted(sorted_sequence, values, right, dtype)
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,11 @@ class TopkAttrs(Attrs):
"""Attributes used in topk operators"""


@tvm._ffi.register_object("relay.attrs.SearchSortedAttrs")
class SearchSortedAttrs(Attrs):
"""Attributes used in searchsorted operators"""


@tvm._ffi.register_object("relay.attrs.TupleGetItemAttrs")
class TupleGetItemAttrs(Attrs):
"""Attributes used in tuple item access operators"""
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,18 @@ def topk_strategy_cuda(attrs, inputs, out_type, target):
return strategy


@searchsorted_strategy.register(["cuda", "gpu"])
def searchsorted_strategy_cuda(attrs, inputs, out_type, target):
"""searchsorted cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_searchsorted(topi.cuda.searchsorted),
wrap_topi_schedule(topi.cuda.schedule_extern),
name="searchsorted.cuda",
)
return strategy


@multibox_prior_strategy.register(["cuda", "gpu"])
def multibox_prior_strategy_cuda(attrs, inputs, out_type, target):
"""multibox_prior cuda strategy"""
Expand Down
25 changes: 25 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,6 +1002,31 @@ def topk_strategy(attrs, inputs, out_type, target):
return strategy


# searchsorted
def wrap_compute_searchsorted(topi_compute):
"""Wrap searchsorted compute"""

def _compute_searchsorted(attrs, inputs, out_type):
right = attrs.right
dtype = attrs.dtype
return [topi_compute(inputs[0], inputs[1], right, dtype)]

return _compute_searchsorted


# searchsorted_strategy
@override_native_generic_func("searchsorted_strategy")
def searchsorted_strategy(attrs, inputs, out_type, target):
"""searchsorted generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_searchsorted(topi.searchsorted),
wrap_topi_schedule(topi.generic.schedule_extern),
name="searchsorted.generic",
)
return strategy


# multibox_prior
def wrap_compute_multibox_prior(topi_compute):
"""Wrap multibox_prior compute"""
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from .scan import *
from .einsum import *
from .unique import *
from .searchsorted import *
from . import generic
from . import nn
from . import x86
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,4 @@
from .sparse_reshape import *
from .transform import *
from .unique import *
from .searchsorted import *
102 changes: 102 additions & 0 deletions python/tvm/topi/cuda/searchsorted.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""searchsorted operator for GPU"""
import tvm
from tvm import te
from .. import utils
from ..searchsorted import binary_search


def searchsorted(sorted_sequence, values, right, out_dtype="int64"):
"""Find indices where elements should be inserted to maintain order.
If `sorted_sequence` is N-dimensional, the innermost dimension of
`values` are searched in the corresponding dimension of `sorted_sequence`.
Parameters
----------
sorted_sequence : te.Tensor
N-D or 1-D Tensor, containing monotonically increasing sequence
on the innermost dimension.
values : te.Tensor
N-D Tensor containing the search values. When `sorted_sequence` is 1-D,
the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence`
and `values` must be the same, and outer N-1 axes must have the same size.
right : bool, optional
Controls which index is returned if a value lands exactly on one of sorted values. If
False, the index of the first suitable location found is given. If true, return the
last such index. If there is no suitable index, return either 0 or N (where N is the
size of the innermost dimension).
dtype : string, optional
The data type of the output indices.
Returns
-------
indices : te.Tensor
Tensor with same shape as values, representing the indices of
elements of `values` if they are inserted in `sorted_sequence`.
"""

def ir(sorted_sequence, values, indices):
ib = tvm.tir.ir_builder.create()
sorted_sequence_shape = sorted_sequence.shape
values_shape = values.shape
num_search = utils.prod(values_shape)
search_range = sorted_sequence_shape[-1]

sorted_sequence = ib.buffer_ptr(sorted_sequence)
values = ib.buffer_ptr(values)
indices = ib.buffer_ptr(indices)

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
bx = te.thread_axis("blockIdx.x")
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(
bx, "thread_extent", tvm.tir.indexdiv(num_search + max_threads - 1, max_threads)
)
ib.scope_attr(tx, "thread_extent", max_threads)
tid = bx * max_threads + tx

with ib.if_scope(tid < num_search):
if len(sorted_sequence_shape) == 1:
sequence_offset = 0
else:
sequence_id = tid // values_shape[-1]
sequence_offset = sequence_id * search_range

indices[tid] = binary_search(
ib,
sequence_offset,
search_range,
sorted_sequence,
values[tid],
right,
out_dtype,
)

return ib.get()

return te.extern(
values.shape,
[sorted_sequence, values],
lambda ins, outs: ir(ins[0], ins[1], outs[0]),
name="searchsorted",
dtype=out_dtype,
)
Loading

0 comments on commit 9cf0245

Please sign in to comment.