Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Oct 4, 2021
1 parent 7233cce commit 8bb70f2
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 2 deletions.
1 change: 1 addition & 0 deletions python/tvm/relay/op/_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
register_strategy("searchsorted", strategy.searchsorted_strategy)
register_pattern("searchsorted", OpPattern.OPAQUE)


@script
def _topk_shape_func_input_shape(data_shape, k, axis):
ndim = data_shape.shape[0]
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,10 +563,12 @@ class SparseConv2DAttrs(Attrs):
class TopkAttrs(Attrs):
"""Attributes used in topk operators"""


@tvm._ffi.register_object("relay.attrs.SearchSortedAttrs")
class SearchSortedAttrs(Attrs):
"""Attributes used in searchsorted operators"""


@tvm._ffi.register_object("relay.attrs.TupleGetItemAttrs")
class TupleGetItemAttrs(Attrs):
"""Attributes used in tuple item access operators"""
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/searchsorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""searchsorted operator for GPU"""
import tvm
from tvm import te
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/topi/searchsorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""searchsorted operator"""
from . import utils
from . import te
Expand All @@ -22,8 +23,7 @@


def binary_search(
ib, sequence_offset, search_range, index, sorted_sequence, values, out_indices,
side, out_dtype
ib, sequence_offset, search_range, index, sorted_sequence, values, out_indices, side, out_dtype
):
"""Common IR generator for CPU and GPU searchsorted."""
lo = ib.allocate(out_dtype, (1,), name="lo", scope="local")
Expand Down
18 changes: 18 additions & 0 deletions python/tvm/topi/testing/searchsorted.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,25 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The reference implementation of searchsorted in Numpy."""
import numpy as np


def searchsorted_ref(sorted_sequence, values, side, out_dtype):
"""Run Numpy searchsorted on 1-D or N-D sorted_sequence."""
if len(sorted_sequence.shape) == 1 and len(values.shape) > 1:
sorted_sequence_2d = np.tile(sorted_sequence, (np.prod(values.shape[:-1]), 1))
else:
Expand Down

0 comments on commit 8bb70f2

Please sign in to comment.