-
Notifications
You must be signed in to change notification settings - Fork 2.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Pytorch fronted]: Added support for Search Sorted op #26976
Changes from 78 commits
187f1c4
ccb2134
ef5d369
c87ed2f
5987b73
ddee155
58cfcdd
7663026
60984a6
420b709
60eff69
a33b7f9
cd3873d
6b80399
f7b82e0
cb6531a
80b6c25
cf87e7a
4778d91
669d921
ae194d1
0b05df1
738a52b
1939b8f
07ad6db
94aff1e
ea85a2f
ad0b980
f6ba78f
0ac600b
ca60a93
d3fc978
1d98baa
be1c9b4
6f9bec3
c041426
32e9285
25aa185
5990ba9
39cdc36
954f211
580c4ab
6bb89b2
1dd4c67
f3d747b
70be606
83ec200
5aa918c
f3f6c3f
08db8b6
8843e65
d3ac6e0
4ae2598
f173059
d6c3a8f
53d8c6d
7b9c698
b8f9ee5
46ed119
17b6e6f
394e474
ae6a83d
2ed58e1
f138792
ed5d4d8
c9bb4b1
f4b0601
bc4c902
5c20b4c
5a9d440
106c5c3
df1a322
16a4032
a36442a
2f10f4e
3782d0f
dc92c5c
3b056b7
4cee5bd
9765fa3
9b782f3
6b9e286
8b4824c
71ee38a
b179570
d61f3dd
fbc8f32
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "openvino/op/search_sorted.hpp" | ||
|
||
#include "openvino/frontend/pytorch/node_context.hpp" | ||
#include "utils.hpp" | ||
|
||
namespace ov { | ||
namespace frontend { | ||
namespace pytorch { | ||
namespace op { | ||
|
||
using namespace ov::op; | ||
|
||
OutputVector translate_search_sorted(const NodeContext& context) { | ||
num_inputs_check(context, 2, 5); | ||
Output<Node> sorted; | ||
Output<Node> values; | ||
std::tie(sorted, values) = get_inputs_with_promoted_types(context, 0, 1); | ||
const bool out_int32 = context.const_input<bool>(2); | ||
PYTORCH_OP_CONVERSION_CHECK(out_int32 == false, "aten::searchsorted(out_int32=true) unsupported"); | ||
const bool right_mode = context.const_input<bool>(3); | ||
PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(4), "aten::searchsorted(side) unsupported"); | ||
PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(5), "aten::searchsorted(sorter) unsupported"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What index does There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed. |
||
auto op = context.mark_node(std::make_shared<ov::op::v15::SearchSorted>(sorted, values, right_mode)); | ||
return {op}; | ||
}; | ||
} // namespace op | ||
} // namespace pytorch | ||
} // namespace frontend | ||
} // namespace ov |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Copyright (C) 2018-2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
|
||
from pytorch_layer_test_class import PytorchLayerTest | ||
import numpy as np | ||
|
||
|
||
class TestSearchSorted(PytorchLayerTest): | ||
def _prepare_input(self): | ||
return (np.array(self.sorted).astype(self.sorted_type),np.array(self.values).astype(self.values_type)) | ||
|
||
def create_model(self, right_mode): | ||
import torch | ||
|
||
class aten_searchsorted(torch.nn.Module): | ||
def __init__(self, right_mode): | ||
super(aten_searchsorted, self).__init__() | ||
self.right_mode = right_mode | ||
|
||
def forward(self, sorted, values): | ||
return torch.searchsorted(sorted, values, right=self.right_mode) | ||
|
||
ref_net = None | ||
|
||
return aten_searchsorted(right_mode), ref_net, "aten::searchsorted" | ||
|
||
@pytest.mark.nightly | ||
@pytest.mark.precommit | ||
@pytest.mark.parametrize(("sorted", "values"), [([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]], [[3, 6, 9], [3, 6, 9]]),([1, 3, 5, 7, 9], [[3, 6, 9],[0, 5, 20]])]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given this op contains type promotion and input conversion done internally in torch may affect out values, maybe test data should include also few handpicked values with fractional part just to be sure? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point - will prepare a malicious case just for that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FIxed. |
||
@pytest.mark.parametrize("right_mode", [False, True]) | ||
@pytest.mark.parametrize("sorted_type", [np.float32, np.int32, np.int8]) | ||
@pytest.mark.parametrize("values_type", [np.float16, np.int32, np.int64]) | ||
def test_searchsorted(self, sorted, values, right_mode, sorted_type, values_type, ie_device, precision, ir_version): | ||
self.sorted = sorted | ||
self.values = values | ||
self.sorted_type = sorted_type | ||
self.values_type = values_type | ||
self._test(*self.create_model(right_mode), ie_device, precision, ir_version) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it not just the output type of the operation? Can it be supported by just inserting
Convert
to i32?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is a part of aten::searchsorted API and it is(that flag) not supported at the moment as it is not needed.