Skip to content

Commit

Permalink
[Op]: Added SearchSorted op def. (openvinotoolkit#26904)
Browse files Browse the repository at this point in the history
### Details:
 - Added SearchSorted op def with unittests.

### Tickets:
 - *CVS-154141*

NOTE: Depends on openvinotoolkit#26887

---------

Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
Co-authored-by: Pawel Raasz <pawel.raasz@intel.com>
  • Loading branch information
3 people authored Oct 9, 2024
1 parent 73e9c6c commit 8517c36
Show file tree
Hide file tree
Showing 6 changed files with 272 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/core/include/openvino/op/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@
#include "openvino/op/scatter_elements_update.hpp"
#include "openvino/op/scatter_nd_update.hpp"
#include "openvino/op/scatter_update.hpp"
#include "openvino/op/search_sorted.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/selu.hpp"
#include "openvino/op/shape_of.hpp"
Expand Down
46 changes: 46 additions & 0 deletions src/core/include/openvino/op/search_sorted.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/op.hpp"

namespace ov {
namespace op {
namespace v15 {
/// \brief SearchSorted operation.
///
/// \ingroup ov_ops_cpp_api
class OPENVINO_API SearchSorted : public Op {
public:
OPENVINO_OP("SearchSorted", "opset15", Op);

SearchSorted() = default;
/// \brief Constructs a SearchSorted operation.
/// \param sorted_sequence Sorted sequence to search in.
/// \param values Values to search indexs for.
/// \param right_mode If False, return the first suitable index that is found for given value. If True, return
/// the last such index.
SearchSorted(const Output<Node>& sorted_sequence, const Output<Node>& values, bool right_mode = false);

void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

bool get_right_mode() const {
return m_right_mode;
}

void set_right_mode(bool right_mode) {
m_right_mode = right_mode;
}

bool validate() const;

private:
bool m_right_mode{};
};
} // namespace v15
} // namespace op
} // namespace ov
35 changes: 35 additions & 0 deletions src/core/shape_inference/include/search_sorted_shape_inference.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/search_sorted.hpp"
#include "utils.hpp"

namespace ov {
namespace op {
namespace v15 {
template <class TShape, class TRShape = result_shape_t<TShape>>
std::vector<TRShape> shape_infer(const SearchSorted* op, const std::vector<TShape>& input_shapes) {
// [HACK]: By convention, shape_infer should also perform node validation..
op->validate();
const auto& sorted_shape = input_shapes[0];
const auto& values_shape = input_shapes[1];
auto output_shape = values_shape;
TShape::merge_into(output_shape, sorted_shape);

if (output_shape.rank().is_static()) {
auto last_it = output_shape.end() - 1;
if (values_shape.rank().is_static()) {
*last_it = *(input_shapes[1].end() - 1);
} else {
*last_it = Dimension::dynamic();
}
}

return {std::move(output_shape)};
}
} // namespace v15
} // namespace op
} // namespace ov
66 changes: 66 additions & 0 deletions src/core/src/op/search_sorted.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <openvino/op/search_sorted.hpp>

#include "itt.hpp"
#include "openvino/core/validation_util.hpp"
#include "search_sorted_shape_inference.hpp"

namespace ov {
namespace op {
namespace v15 {

SearchSorted::SearchSorted(const Output<Node>& sorted_sequence, const Output<Node>& values, bool right_mode)
: Op({sorted_sequence, values}),
m_right_mode(right_mode) {
constructor_validate_and_infer_types();
}

bool SearchSorted::validate() const {
NODE_VALIDATION_CHECK(this, get_input_size() == 2);
NODE_VALIDATION_CHECK(this,
get_input_element_type(0) == get_input_element_type(1),
"Sorted sequence and values must have the same element type.");

const auto& sorted_shape = get_input_partial_shape(0);
const auto& values_shape = get_input_partial_shape(1);

if (sorted_shape.rank().is_static() && values_shape.rank().is_static() && sorted_shape.rank().get_length() > 1) {
NODE_VALIDATION_CHECK(this,
sorted_shape.rank().get_length() == values_shape.rank().get_length(),
"Sorted sequence and values have different ranks.");

for (int64_t i = 0; i < sorted_shape.rank().get_length() - 1; ++i) {
NODE_VALIDATION_CHECK(this,
sorted_shape[i].compatible(values_shape[i]),
"Sorted sequence and values has different ",
i,
" dimension.");
}
}

return true;
}

void SearchSorted::validate_and_infer_types() {
OV_OP_SCOPE(v15_SearchSorted_validate_and_infer_types);
const auto& output_shapes = shape_infer(this, ov::util::get_node_input_partial_shapes(*this));
set_output_type(0, ov::element::i64, output_shapes[0]);
}

bool SearchSorted::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v15_SearchSorted_visit_attributes);
visitor.on_attribute("right_mode", m_right_mode);
return true;
}

std::shared_ptr<Node> SearchSorted::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v15_SearchSorted_clone_with_new_inputs);
check_new_args_count(this, new_args);
return std::make_shared<SearchSorted>(new_args.at(0), new_args.at(1), get_right_mode());
}
} // namespace v15
} // namespace op
} // namespace ov
94 changes: 94 additions & 0 deletions src/core/tests/type_prop/search_sorted.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/search_sorted.hpp"

#include "common_test_utils/test_assertions.hpp"
#include "common_test_utils/type_prop.hpp"

using namespace std;
using namespace ov;

#define EXPECT_THROW_SUBSTRING(SORTED, VALUES, SUBSTRING) \
OV_EXPECT_THROW_HAS_SUBSTRING(std::ignore = make_shared<op::v15::SearchSorted>(SORTED, VALUES), \
NodeValidationFailure, \
SUBSTRING);

static void PerformShapeTest(const PartialShape& sorted_shape,
const PartialShape& values_shape,
const PartialShape& expected_output_shape) {
auto sorted = make_shared<op::v0::Parameter>(element::i32, sorted_shape);
auto values = make_shared<op::v0::Parameter>(element::i32, values_shape);
auto search_sorted_op = make_shared<op::v15::SearchSorted>(sorted, values);
EXPECT_EQ(search_sorted_op->get_element_type(), element::i64);
EXPECT_EQ(search_sorted_op->get_output_partial_shape(0), expected_output_shape);
}

TEST(type_prop, search_sorted_shape_infer_equal_inputs) {
PerformShapeTest({1, 3, 6}, {1, 3, 6}, {1, 3, 6});
}

TEST(type_prop, search_sorted_shape_infer_sorted_dynamic) {
PerformShapeTest(PartialShape::dynamic(), {1, 3, 6}, {1, 3, 6});
}

TEST(type_prop, search_sorted_shape_infer_values_dynamic) {
PerformShapeTest({1, 3, 7, 5}, PartialShape::dynamic(), {1, 3, 7, -1});
}

TEST(type_prop, search_sorted_shape_infer_different_last_dim) {
PerformShapeTest({1, 3, 7, 100}, {1, 3, 7, 10}, {1, 3, 7, 10});
}

TEST(type_prop, search_sorted_shape_infer_sorted_1d) {
PerformShapeTest({5}, {2, 3}, {2, 3});
}

TEST(type_prop, search_sorted_shape_infer_sorted_and_values_1d) {
PerformShapeTest({5}, {20}, {20});
}

TEST(type_prop, search_sorted_shape_infer_sorted_1d_values_dynamic) {
PerformShapeTest({8}, {-1, -1, 3}, {-1, -1, 3});
}

TEST(type_prop, search_sorted_shape_infer_both_dynamic_1) {
PerformShapeTest({1, -1, 7, -1}, {-1, 3, -1, 10}, {1, 3, 7, 10});
}

TEST(type_prop, search_sorted_shape_infer_both_dynamic_2) {
PerformShapeTest({1, -1, 7, 50}, {-1, 3, -1, -1}, {1, 3, 7, -1});
}

TEST(type_prop, search_sorted_shape_infer_both_dynamic_3) {
PerformShapeTest(PartialShape::dynamic(), PartialShape::dynamic(), PartialShape::dynamic());
}

TEST(type_prop, search_sorted_shape_infer_both_dynamic_4) {
PerformShapeTest({-1, -1, 50}, {-1, -1, 20}, {-1, -1, 20});
}

TEST(type_prop, search_sorted_shape_infer_both_dynamic_5) {
PerformShapeTest({-1}, {-1, -1, 3}, {-1, -1, 3});
}

TEST(type_prop, search_sorted_shape_infer_different_types) {
auto sorted = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1, 3, 6});
auto values = make_shared<ov::op::v0::Parameter>(element::i32, Shape{1, 3, 6});
EXPECT_THROW_SUBSTRING(values, sorted, std::string("must have the same element type"));
}

TEST(type_prop, search_sorted_shape_infer_wrong_rank) {
auto sorted = make_shared<ov::op::v0::Parameter>(element::i32, Shape{1, 1, 3, 6});
auto values = make_shared<ov::op::v0::Parameter>(element::i32, Shape{1, 3, 6});
EXPECT_THROW_SUBSTRING(sorted, values, std::string("Sorted sequence and values have different ranks"));
}

TEST(type_prop, search_sorted_shape_infer_wrong_dim) {
auto sorted = make_shared<ov::op::v0::Parameter>(element::i32, Shape{1, 1, 3, 6});
auto values = make_shared<ov::op::v0::Parameter>(element::i32, Shape{1, 1, 5, 6});
EXPECT_THROW_SUBSTRING(sorted, values, std::string(" different 2 dimension."));
}

#undef EXPECT_THROW_SUBSTRING
30 changes: 30 additions & 0 deletions src/core/tests/visitors/op/sorted_search.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include "openvino/op/search_sorted.hpp"
#include "visitors/visitors.hpp"

using namespace std;
using namespace ov;
using ov::test::NodeBuilder;

TEST(attributes, search_sorted_op) {
using TOp = ov::op::v15::SearchSorted;
NodeBuilder::opset().insert<TOp>();
auto sorted = make_shared<ov::op::v0::Parameter>(element::i32, Shape{2, 3, 50, 50});
auto values = make_shared<ov::op::v0::Parameter>(element::i32, Shape{2, 3, 50, 50});

auto op = make_shared<TOp>(sorted, values);
NodeBuilder builder(op, {sorted, values});
auto g_op = ov::as_type_ptr<TOp>(builder.create());

// attribute count
const auto expected_attr_count = 1;
EXPECT_EQ(builder.get_value_map_size(), expected_attr_count);

// space_to_depth attributes
EXPECT_EQ(g_op->get_right_mode(), op->get_right_mode());
}

0 comments on commit 8517c36

Please sign in to comment.