diff --git a/src/common/transformations/include/transformations/op_conversions/convert_squeeze15_downgrade.hpp b/src/common/transformations/include/transformations/op_conversions/convert_squeeze15_downgrade.hpp new file mode 100644 index 00000000000000..c2ebfbc0f3138b --- /dev/null +++ b/src/common/transformations/include/transformations/op_conversions/convert_squeeze15_downgrade.hpp @@ -0,0 +1,23 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/matcher_pass.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { +/** + * @ingroup ov_transformation_common_api + * @brief Converts Squeeze v15 to Squeeze v0. + */ +class TRANSFORMATIONS_API ConvertSqueeze15ToSqueeze0 : public MatcherPass { +public: + OPENVINO_RTTI("ConvertSqueeze15ToSqueeze0", "0"); + ConvertSqueeze15ToSqueeze0(); +}; + +} // namespace pass +} // namespace ov diff --git a/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp index 9d46b583a828f2..37ee2d12d9aebb 100644 --- a/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp @@ -98,6 +98,7 @@ #include "transformations/op_conversions/convert_softmax_downgrade.hpp" #include "transformations/op_conversions/convert_softmax_upgrade.hpp" #include "transformations/op_conversions/convert_space_to_depth.hpp" +#include "transformations/op_conversions/convert_squeeze15_downgrade.hpp" #include "transformations/op_conversions/convert_subtract.hpp" #include "transformations/op_conversions/convert_topk11_downgrade.hpp" #include "transformations/op_conversions/convert_xor_to_logical_xor.hpp" @@ -235,6 +236,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr(); ADD_MATCHER(fq_fusions, FakeQuantizeMulFusion) diff --git a/src/common/transformations/src/transformations/op_conversions/convert_squeeze15_downgrade.cpp b/src/common/transformations/src/transformations/op_conversions/convert_squeeze15_downgrade.cpp new file mode 100644 index 00000000000000..50701d3d6acd56 --- /dev/null +++ b/src/common/transformations/src/transformations/op_conversions/convert_squeeze15_downgrade.cpp @@ -0,0 +1,40 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/op_conversions/convert_squeeze15_downgrade.hpp" + +#include "itt.hpp" +#include "openvino/core/rt_info.hpp" +#include "openvino/op/squeeze.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/utils/utils.hpp" + +ov::pass::ConvertSqueeze15ToSqueeze0::ConvertSqueeze15ToSqueeze0() { + MATCHER_SCOPE(ConvertSqueeze15ToSqueeze0); + + const auto& squeeze_v15_pattern = pattern::wrap_type(); + + const matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](pattern::Matcher& m) { + const auto& squeeze_v15 = ov::as_type_ptr(m.get_match_root()); + if (!squeeze_v15 || transformation_callback(squeeze_v15)) { + return false; + } + std::shared_ptr squeeze_v0; + if (squeeze_v15->get_input_size() == 1) { + squeeze_v0 = std::make_shared(squeeze_v15->input_value(0)); + } else if (squeeze_v15->get_input_size() == 2 && !squeeze_v15->get_allow_axis_skip()) { + squeeze_v0 = std::make_shared(squeeze_v15->input_value(0), squeeze_v15->input_value(1)); + } else { + return false; + } + squeeze_v0->set_friendly_name(squeeze_v15->get_friendly_name()); + copy_runtime_info(squeeze_v15, squeeze_v0); + replace_node(squeeze_v15, squeeze_v0); + + return true; + }; + + auto m = std::make_shared(squeeze_v15_pattern, matcher_name); + register_matcher(m, callback); +} diff --git a/src/common/transformations/tests/op_conversions/convert_squeeze15_downgrade_test.cpp b/src/common/transformations/tests/op_conversions/convert_squeeze15_downgrade_test.cpp new file mode 100644 index 00000000000000..f3d90ab2c748bd --- /dev/null +++ b/src/common/transformations/tests/op_conversions/convert_squeeze15_downgrade_test.cpp @@ -0,0 +1,112 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/op_conversions/convert_squeeze15_downgrade.hpp" + +#include + +#include + +#include "common_test_utils/ov_test_utils.hpp" +#include "openvino/opsets/opset1.hpp" +#include "openvino/opsets/opset15.hpp" +#include "openvino/pass/manager.hpp" +#include "transformations/utils/utils.hpp" +using namespace ov; +using namespace testing; + +namespace { + +enum class IndicesMode { NONE, CONST, PARAM }; + +std::shared_ptr create_v15_model(const IndicesMode indices_mode, + const std::vector indices_const_val, + const bool allow_axis_skip) { + const PartialShape data_shape{-1, {2, 5}, 1, {1, 5}, 4}; + const auto& data = std::make_shared(ov::element::f32, data_shape); + ov::ParameterVector params = {data}; + std::shared_ptr squeeze; + if (indices_mode == IndicesMode::NONE) { + squeeze = std::make_shared(data, allow_axis_skip); + } else if (indices_mode == IndicesMode::PARAM) { + const auto& indices = + std::make_shared(ov::element::i32, PartialShape({data_shape.rank()})); + params.push_back(indices); + squeeze = std::make_shared(data, indices, allow_axis_skip); + } else if (indices_mode == IndicesMode::CONST) { + const auto& indices = + ov::opset15::Constant::create(ov::element::i32, Shape({indices_const_val.size()}), indices_const_val); + squeeze = std::make_shared(data, indices, allow_axis_skip); + } + squeeze->set_friendly_name("squeeze15"); + return std::make_shared(squeeze->outputs(), params); +} + +std::shared_ptr create_v1_model(const IndicesMode indices_mode, const std::vector indices_const_val) { + const PartialShape data_shape{-1, {2, 5}, 1, {1, 5}, 4}; + const auto& data = std::make_shared(ov::element::f32, data_shape); + ov::ParameterVector params = {data}; + std::shared_ptr squeeze; + if (indices_mode == IndicesMode::NONE) { + squeeze = std::make_shared(data); + } else if (indices_mode == IndicesMode::PARAM) { + const auto& indices = + std::make_shared(ov::element::i32, PartialShape({data_shape.rank()})); + params.push_back(indices); + squeeze = std::make_shared(data, indices); + } else if (indices_mode == IndicesMode::CONST) { + const auto& indices = + ov::opset1::Constant::create(ov::element::i32, Shape({indices_const_val.size()}), indices_const_val); + squeeze = std::make_shared(data, indices); + } + squeeze->set_friendly_name("squeeze15"); + return std::make_shared(squeeze->outputs(), params); +} + +} // namespace + +TEST_F(TransformationTestsF, ConvertSqueeze15ToSqueeze1_no_indices_no_skip) { + manager.register_pass(); + model = create_v15_model(IndicesMode::NONE, {}, false); + model_ref = create_v1_model(IndicesMode::NONE, {}); + EXPECT_EQ(model->output(0).get_partial_shape(), model_ref->output(0).get_partial_shape()); + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); + comparator.enable(FunctionsComparator::CmpValues::NAMES); +} + +TEST_F(TransformationTestsF, ConvertSqueeze15ToSqueeze1_no_indices_skip) { + manager.register_pass(); + model = create_v15_model(IndicesMode::NONE, {}, true); + model_ref = create_v1_model(IndicesMode::NONE, {}); + EXPECT_EQ(model->output(0).get_partial_shape(), model_ref->output(0).get_partial_shape()); + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); + comparator.enable(FunctionsComparator::CmpValues::NAMES); +} + +TEST_F(TransformationTestsF, ConvertSqueeze15ToSqueeze1_const_indices_no_skip) { + manager.register_pass(); + model = create_v15_model(IndicesMode::CONST, {0, -4, 3}, false); + model_ref = create_v1_model(IndicesMode::CONST, {0, -4, 3}); + EXPECT_EQ(model->output(0).get_partial_shape(), model_ref->output(0).get_partial_shape()); + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); + comparator.enable(FunctionsComparator::CmpValues::NAMES); +} + +TEST_F(TransformationTestsF, ConvertSqueeze15ToSqueeze1_dynamic_indices_no_skip) { + manager.register_pass(); + model = create_v15_model(IndicesMode::PARAM, {}, false); + model_ref = create_v1_model(IndicesMode::PARAM, {}); + EXPECT_EQ(model->output(0).get_partial_shape(), model_ref->output(0).get_partial_shape()); + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); + comparator.enable(FunctionsComparator::CmpValues::NAMES); +} + +TEST_F(TransformationTestsF, ConvertSqueeze15ToSqueeze1_unsupported_skip) { + manager.register_pass(); + model = create_v15_model(IndicesMode::PARAM, {}, true); +}