Skip to content

Commit

Permalink
[Transformations] Add SliceScatter-15 decomposition transformation (o…
Browse files Browse the repository at this point in the history
…penvinotoolkit#27136)

### Details:
- *Add SliceScatter-15 decomposition transformation for unsupported
plugins*
 - *...*

### Tickets:
 - *CVS-151158*

---------

Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
  • Loading branch information
mmikolajcz and mlukasze authored Oct 24, 2024
1 parent 97b20c9 commit cbeb131
Show file tree
Hide file tree
Showing 6 changed files with 254 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API ConvertSliceScatter;

} // namespace pass
} // namespace ov

class ov::pass::ConvertSliceScatter : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ConvertSliceScatter", "0");
ConvertSliceScatter();
};
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
#include "transformations/op_conversions/convert_scatter_elements_update12_downgrade.hpp"
#include "transformations/op_conversions/convert_scatter_nd_update15_downgrade.hpp"
#include "transformations/op_conversions/convert_slice_to_strided_slice.hpp"
#include "transformations/op_conversions/convert_slicescatter.hpp"
#include "transformations/op_conversions/convert_softmax_downgrade.hpp"
#include "transformations/op_conversions/convert_softmax_upgrade.hpp"
#include "transformations/op_conversions/convert_space_to_depth.hpp"
Expand Down Expand Up @@ -233,6 +234,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model
REGISTER_PASS(manager, ConvertEmbeddingBagOffsets15ToEmbeddingBagOffsetsSum3)
REGISTER_PASS(manager, ConvertEmbeddingBagPacked15ToEmbeddingBagPackedSum3)
REGISTER_PASS(manager, ConvertScatterNDUpdate15ToScatterNDUpdate3)
REGISTER_PASS(manager, ConvertSliceScatter)

auto fq_fusions = manager.register_pass<GraphRewrite>();
ADD_MATCHER(fq_fusions, FakeQuantizeMulFusion)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/op_conversions/convert_slicescatter.hpp"

#include <memory>
#include <vector>

#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reduce_prod.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/scatter_nd_update.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/slice_scatter.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

ov::pass::ConvertSliceScatter::ConvertSliceScatter() {
MATCHER_SCOPE(ConvertSliceScatter);

const auto& slicescatter = pattern::wrap_type<ov::op::v15::SliceScatter>();

const matcher_pass_callback callback = [this](pattern::Matcher& m) {
const auto& slice_node = ov::as_type_ptr<ov::op::v15::SliceScatter>(m.get_match_root());
if (!slice_node || transformation_callback(slice_node)) {
return false;
}
NodeRegistry node_registry;
const auto& const_0 = node_registry.make<ov::op::v0::Constant>(ov::element::i64, Shape{}, 0);
const auto& const_1 = node_registry.make<ov::op::v0::Constant>(ov::element::i64, Shape{}, 1);
const auto& const_1d_neg_1 =
node_registry.make<ov::op::v0::Constant>(ov::element::i64, Shape{1}, std::vector<int64_t>{-1});
const auto& const_scatter_indices_shape =
node_registry.make<ov::op::v0::Constant>(ov::element::i64, Shape{2}, std::vector<int64_t>{-1, 1});
const auto& data_shape = node_registry.make<ov::op::v3::ShapeOf>(slice_node->input_value(0), ov::element::i64);
const auto& num_elements_data = node_registry.make<ov::op::v1::ReduceProd>(data_shape, const_0, false);
const auto& data_indices_flatten =
node_registry.make<ov::op::v4::Range>(const_0, num_elements_data, const_1, ov::element::i64);
const auto& full_data_indices =
node_registry.make<ov::op::v1::Reshape>(data_indices_flatten, data_shape, false);
std::shared_ptr<ov::op::v8::Slice> slice_indices;
if (slice_node->get_input_size() == 5) {
slice_indices = node_registry.make<ov::op::v8::Slice>(full_data_indices,
slice_node->input_value(2),
slice_node->input_value(3),
slice_node->input_value(4));
} else {
slice_indices = node_registry.make<ov::op::v8::Slice>(full_data_indices,
slice_node->input_value(2),
slice_node->input_value(3),
slice_node->input_value(4),
slice_node->input_value(5));
}
const auto& slice_indices_flatten =
node_registry.make<ov::op::v1::Reshape>(slice_indices, const_scatter_indices_shape, false);
const auto& updates_flatten =
node_registry.make<ov::op::v1::Reshape>(slice_node->input_value(1), const_1d_neg_1, false);
const auto& data_flatten =
node_registry.make<ov::op::v1::Reshape>(slice_node->input_value(0), const_1d_neg_1, false);
const auto& output_flatten =
node_registry.make<ov::op::v3::ScatterNDUpdate>(data_flatten, slice_indices_flatten, updates_flatten);
const auto& output = node_registry.make<ov::op::v1::Reshape>(output_flatten, data_shape, false);

output->set_friendly_name(slice_node->get_friendly_name());
copy_runtime_info(slice_node, node_registry.get());
replace_node(slice_node, output);

return true;
};

const auto& m = std::make_shared<pattern::Matcher>(slicescatter, matcher_name);
this->register_matcher(m, callback);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include <memory>

#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/opsets/opset15.hpp"
#include "openvino/opsets/opset8.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/op_conversions/convert_slicescatter.hpp"
#include "transformations/utils/utils.hpp"
namespace {
class ConvertSliceScatterTest : public TransformationTestsF, public testing::WithParamInterface<ov::NodeVector> {
private:
void SetUp() override {
TransformationTestsF::SetUp();
const auto& inputs = GetParam();
manager.register_pass<ov::pass::ConvertSliceScatter>();
model = create_v15_model(inputs);
model_ref = create_decomposed_model(inputs);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::NAMES);
}

protected:
std::shared_ptr<ov::Model> create_v15_model(ov::NodeVector inputs) {
const auto& data = inputs.at(0);
const auto& updates = inputs.at(1);
const auto& start = inputs.at(2);
const auto& stop = inputs.at(3);
const auto& step = inputs.at(4);
ov::ParameterVector params{};
for (const auto& inp : inputs) {
const auto& param = ov::as_type_ptr<ov::op::v0::Parameter>(inp);
if (param) {
params.push_back(param);
}
}
std::shared_ptr<ov::opset15::SliceScatter> slicescatter;
if (inputs.size() == 5) {
slicescatter = std::make_shared<ov::opset15::SliceScatter>(data, updates, start, stop, step);
} else {
slicescatter = std::make_shared<ov::opset15::SliceScatter>(data, updates, start, stop, step, inputs.at(5));
}
slicescatter->set_friendly_name("slicescatter15");
return std::make_shared<ov::Model>(slicescatter->outputs(), params);
}

std::shared_ptr<ov::Model> create_decomposed_model(ov::NodeVector inputs) {
const auto& data = inputs.at(0);
const auto& updates = inputs.at(1);
const auto& start = inputs.at(2);
const auto& stop = inputs.at(3);
const auto& step = inputs.at(4);
ov::ParameterVector params{};
for (const auto& inp : inputs) {
const auto& param = ov::as_type_ptr<ov::op::v0::Parameter>(inp);
if (param) {
params.push_back(param);
}
}
const auto& const_0 = ov::op::v0::Constant::create(ov::element::i64, {}, {0});
const auto& const_1 = ov::op::v0::Constant::create(ov::element::i64, {}, {1});
const auto& const_1d_neg_1 = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
const auto& const_scatter_indices_shape = ov::op::v0::Constant::create(ov::element::i64, {2}, {-1, 1});
const auto& data_shape = std::make_shared<ov::opset8::ShapeOf>(data, ov::element::i64);
const auto& num_elements_data = std::make_shared<ov::opset8::ReduceProd>(data_shape, const_0, false);
const auto& data_indices_flatten =
std::make_shared<ov::opset8::Range>(const_0, num_elements_data, const_1, ov::element::i64);
const auto& full_data_indices = std::make_shared<ov::opset8::Reshape>(data_indices_flatten, data_shape, false);
std::shared_ptr<ov::opset8::Slice> slice_indices;
if (inputs.size() == 5) {
slice_indices = std::make_shared<ov::opset8::Slice>(full_data_indices, start, stop, step);
} else {
slice_indices = std::make_shared<ov::opset8::Slice>(full_data_indices, start, stop, step, inputs.at(5));
}
const auto& slice_indices_flatten =
std::make_shared<ov::opset8::Reshape>(slice_indices, const_scatter_indices_shape, false);
const auto& updates_flatten = std::make_shared<ov::opset8::Reshape>(updates, const_1d_neg_1, false);
const auto& data_flatten = std::make_shared<ov::opset8::Reshape>(data, const_1d_neg_1, false);
const auto& output_flatten =
std::make_shared<ov::opset8::ScatterNDUpdate>(data_flatten, slice_indices_flatten, updates_flatten);
const auto& slicescatter = std::make_shared<ov::opset8::Reshape>(output_flatten, data_shape, false);
slicescatter->set_friendly_name("slicescatter15");
return std::make_shared<ov::Model>(slicescatter->outputs(), params);
}
};

INSTANTIATE_TEST_SUITE_P(
ConvertSliceScatterDecomposition,
ConvertSliceScatterTest,
testing::Values(
ov::NodeVector{
std::make_shared<ov::opset15::Parameter>(ov::element::f32, ov::Shape{256, 10, 15}),
std::make_shared<ov::opset15::Parameter>(ov::element::f32, ov::Shape{4, 7, 2}),
ov::op::v0::Constant::create(ov::element::i32, {3}, {2, -15, 25}),
ov::op::v0::Constant::create(ov::element::i32, {3}, {9, 7, -3}),
ov::op::v0::Constant::create(ov::element::i32, {3}, {2, 1, -1}),
ov::op::v0::Constant::create(ov::element::i32, {3}, {0, 1, -1}),
},
ov::NodeVector{
std::make_shared<ov::opset15::Parameter>(ov::element::f32, ov::Shape{256, 10, 15}),
std::make_shared<ov::opset15::Parameter>(ov::element::f32, ov::Shape{4, 7, 2}),
ov::op::v0::Constant::create(ov::element::i32, {3}, {2, -15, 25}),
ov::op::v0::Constant::create(ov::element::i32, {3}, {9, 7, -3}),
ov::op::v0::Constant::create(ov::element::i32, {3}, {2, 1, -1}),
},
ov::NodeVector{
std::make_shared<ov::opset15::Parameter>(ov::element::i32, ov::PartialShape::dynamic()),
std::make_shared<ov::opset15::Parameter>(ov::element::i32, ov::PartialShape::dynamic()),
ov::op::v0::Constant::create(ov::element::i32, {3}, {2, -15, 25}),
ov::op::v0::Constant::create(ov::element::i32, {3}, {9, 7, -3}),
ov::op::v0::Constant::create(ov::element::i32, {3}, {2, 1, -1}),
ov::op::v0::Constant::create(ov::element::i32, {3}, {0, 1, -1}),
},
ov::NodeVector{
std::make_shared<ov::opset15::Parameter>(ov::element::i32, ov::PartialShape::dynamic()),
std::make_shared<ov::opset15::Parameter>(ov::element::i32, ov::PartialShape::dynamic()),
ov::op::v0::Constant::create(ov::element::i32, {3}, {2, -15, 25}),
ov::op::v0::Constant::create(ov::element::i32, {3}, {9, 7, -3}),
ov::op::v0::Constant::create(ov::element::i32, {3}, {2, 1, -1}),
},
ov::NodeVector{
std::make_shared<ov::opset15::Parameter>(ov::element::i32, ov::PartialShape::dynamic()),
std::make_shared<ov::opset15::Parameter>(ov::element::i32, ov::PartialShape::dynamic()),
std::make_shared<ov::opset15::Parameter>(ov::element::i32, ov::PartialShape::dynamic()),
std::make_shared<ov::opset15::Parameter>(ov::element::i32, ov::PartialShape::dynamic()),
std::make_shared<ov::opset15::Parameter>(ov::element::i32, ov::PartialShape::dynamic()),
std::make_shared<ov::opset15::Parameter>(ov::element::i32, ov::PartialShape::dynamic()),
},
ov::NodeVector{
std::make_shared<ov::opset15::Parameter>(ov::element::i32, ov::PartialShape::dynamic()),
std::make_shared<ov::opset15::Parameter>(ov::element::i32, ov::PartialShape::dynamic()),
std::make_shared<ov::opset15::Parameter>(ov::element::i32, ov::PartialShape::dynamic()),
std::make_shared<ov::opset15::Parameter>(ov::element::i32, ov::PartialShape::dynamic()),
std::make_shared<ov::opset15::Parameter>(ov::element::i32, ov::PartialShape::dynamic()),
}));
TEST_P(ConvertSliceScatterTest, CompareFunctions) {}

} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
#include "transformations/op_conversions/convert_scatter_nd_update15_downgrade.hpp"
#include "transformations/op_conversions/convert_sequences_to_tensor_iterator.hpp"
#include "transformations/op_conversions/convert_shuffle_channels3.hpp"
#include "transformations/op_conversions/convert_slicescatter.hpp"
#include "transformations/op_conversions/convert_slice_to_strided_slice.hpp"
#include "transformations/op_conversions/convert_space_to_batch.hpp"
#include "transformations/op_conversions/convert_space_to_depth.hpp"
Expand Down Expand Up @@ -656,6 +657,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
CPU_DISABLE_PASS_COMMON(manager, ov::pass::HSwishDecomposition);
CPU_DISABLE_PASS_COMMON(manager, ov::pass::MatMulConstTransposesExtraction);
CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertScatterNDUpdate15ToScatterNDUpdate3);
CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertSliceScatter);
CPU_DISABLE_PASS_X64(manager, ov::pass::HSigmoidDecomposition);

CPU_DISABLE_PASS_X64(manager, ov::pass::ReduceL1Decomposition);
Expand Down
12 changes: 6 additions & 6 deletions src/plugins/template/backend/ops/scatter_nd_update.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@ bool evaluate(const std::shared_ptr<ov::op::v3::ScatterNDUpdate>& op,
inputs[1].data<const int32_t>(),
inputs[2].data<const T>(),
outputs[0].data<T>(),
op->get_input_shape(0),
op->get_input_shape(1),
op->get_input_shape(2));
inputs[0].get_shape(),
inputs[1].get_shape(),
inputs[2].get_shape());
} else if (idxType == ov::element::i64) {
ov::reference::scatterNdUpdate<T, int64_t>(inputs[0].data<const T>(),
inputs[1].data<const int64_t>(),
inputs[2].data<const T>(),
outputs[0].data<T>(),
op->get_input_shape(0),
op->get_input_shape(1),
op->get_input_shape(2));
inputs[0].get_shape(),
inputs[1].get_shape(),
inputs[2].get_shape());
} else {
OPENVINO_THROW("ScatterNDUpdate layer support only i32 and i64 'indices' input precision!");
}
Expand Down

0 comments on commit cbeb131

Please sign in to comment.