diff --git a/src/frontends/pytorch/src/op/search_sorted.cpp b/src/frontends/pytorch/src/op/search_sorted.cpp new file mode 100644 index 00000000000000..ca9f6b49ff7bf9 --- /dev/null +++ b/src/frontends/pytorch/src/op/search_sorted.cpp @@ -0,0 +1,34 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/op/search_sorted.hpp" + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +using namespace ov::op; + +OutputVector translate_search_sorted(const NodeContext& context) { + num_inputs_check(context, 2, 5); + Output<Node> sorted; + Output<Node> values; + std::tie(sorted, values) = get_inputs_with_promoted_types(context, 0, 1); + const bool out_int32 = context.const_input<bool>(2); + PYTORCH_OP_CONVERSION_CHECK(out_int32 == false, "aten::searchsorted(out_int32=true) unsupported"); + const bool right_mode = context.const_input<bool>(3); + PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(4), "aten::searchsorted(side) unsupported"); + PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(5), "aten::searchsorted(out) unsupported"); + PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(6), "aten::searchsorted(sorter) unsupported"); + auto op = context.mark_node(std::make_shared<v15::SearchSorted>(sorted, values, right_mode)); + return {op}; +}; +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov \ No newline at end of file diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 195977432e40e5..66c76e33032ef6 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -300,6 +300,7 @@ OP_CONVERTER(translate_reshape_fx); OP_CONVERTER(translate_rsub_fx); OP_CONVERTER(translate_scalar_tensor_fx); OP_CONVERTER(translate_scaled_dot_product_attention_fx); +OP_CONVERTER(translate_search_sorted); OP_CONVERTER(translate_select_scatter_fx); 
OP_CONVERTER(translate_slice_fx); OP_CONVERTER(translate_slice_scatter_fx); @@ -617,6 +618,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() { {"aten::rsqrt", op::optional_out<op::translate_rsqrt, 1>}, {"aten::rsqrt_", op::inplace_op<op::translate_rsqrt>}, {"aten::rsub", op::translate_rsub}, + {"aten::searchsorted", op::translate_search_sorted}, {"aten::ScalarImplicit", op::skip_node}, {"aten::scaled_dot_product_attention", op::translate_scaled_dot_product_attention}, {"aten::scatter", op::translate_scatter}, diff --git a/tests/layer_tests/pytorch_tests/test_search_sorted.py b/tests/layer_tests/pytorch_tests/test_search_sorted.py new file mode 100644 index 00000000000000..645033e2ee260b --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_search_sorted.py @@ -0,0 +1,47 @@ +# Copyright (C) 2018-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from pytorch_layer_test_class import PytorchLayerTest +import numpy as np + + +class TestSearchSorted(PytorchLayerTest): + def _prepare_input(self): + return (np.array(self.sorted).astype(self.sorted_type),np.array(self.values).astype(self.values_type)) + + def create_model(self, right_mode): + import torch + + class aten_searchsorted(torch.nn.Module): + def __init__(self, right_mode): + super(aten_searchsorted, self).__init__() + self.right_mode = right_mode + + def forward(self, sorted, values): + return torch.searchsorted(sorted, values, right=self.right_mode) + + ref_net = None + + return aten_searchsorted(right_mode), ref_net, "aten::searchsorted" + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize(("sorted", "values"), [ + ([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]], [[3, 6, 9], [3, 6, 9]]), + ([1, 3, 5, 7, 9], [[3, 6, 9],[0, 5, 20]]), + ([4091, 4092], [[4091, 4092]]), # fp16 cannot exactly represent 4091 number + ([1.23, 2.99], [[1.355, 2.9991]]) + ]) + @pytest.mark.parametrize("right_mode", [False, True]) + @pytest.mark.parametrize("sorted_type", [np.float32, np.float16, np.int8]) + 
@pytest.mark.parametrize("values_type", [np.float16, np.int32, np.int64]) + def test_searchsorted(self, sorted, values, right_mode, sorted_type, values_type, ie_device, precision, ir_version): + self.sorted = sorted + self.values = values + self.sorted_type = sorted_type + self.values_type = values_type + if ie_device == "CPU" and sorted_type == np.float16 and sorted == [4091, 4092]: + pytest.skip(reason="CPU plugin by default converts fp16 to fp32, if that happens the test will fail for those problematic values") + self._test(*self.create_model(right_mode), ie_device, precision, ir_version)