From 3c945fbde1db815d086bd9544b5e562510706461 Mon Sep 17 00:00:00 2001 From: Mateusz Tabaka Date: Wed, 1 Nov 2023 15:02:36 +0100 Subject: [PATCH] Handle Reshape's special zero in SimplifySecondInputOfReshape (#20785) * Handle Reshape's special zero in SimplifySecondInputOfReshape SimplifySecondInputOfReshape detects ShapeOf->Gather->Concat subgraphs on Reshape's second input and replaces ShapeOf->Gather with a Constant with zero(s). Currently it works only with Reshapes that have special_zero set to true, but it can work for Reshapes with special_zero == false if non-Gather inputs to Concat are Constants and don't contain any zero. Ticket: CVS-123434 * fix no default output --- .../simplify_shape_of_sub_graph.cpp | 31 ++++++++---- .../simplify_second_input_of_reshape_test.cpp | 50 +++++++++++++++++++ 2 files changed, 70 insertions(+), 11 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp b/src/common/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp index 5ef33a33326e00..7facf950ee7bd4 100644 --- a/src/common/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp @@ -200,7 +200,7 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() { matcher_pass_callback callback = [=](Matcher& m) { auto node = m.get_match_root(); const auto reshape = as_type_ptr(node); - if (!reshape || reshape->get_special_zero() == false) { + if (!reshape) { return false; } @@ -219,7 +219,7 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() { auto check_shape_of_gather = [&](const std::shared_ptr& gather) { auto shape_of = gather->get_input_node_shared_ptr(0); - if (!is_type(shape_of) && !is_type(shape_of)) { + if (!is_type(shape_of)) { return false; } return shape_of->input_value(0) == data; @@ -237,16 +237,15 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() { gather_dims_expected_location += concat_input_shape[0]; }; + bool special_zero = reshape->get_special_zero(); + // We need this check to avoid sequences shapeOf -> gather -> concat // that change the arrangement of dimensions in the reshape pattern for (auto& concat_input : new_concat_inputs) { - if (const auto gather = as_type_ptr(concat_input.get_node_shared_ptr())) { - auto indices_constant = as_type_ptr(gather->get_input_node_shared_ptr(1)); - if (!indices_constant || !check_shape_of_gather(gather)) { - update_expected_gather_location(gather); - continue; - } - + auto node = concat_input.get_node_shared_ptr(); + if (ov::is_type(node) && + ov::is_type(node->get_input_node_shared_ptr(1)) && check_shape_of_gather(node)) { + auto indices_constant = as_type_ptr(node->get_input_node_shared_ptr(1)); bool gather_can_be_fused = true; const auto indices = indices_constant->cast_vector(); for (size_t i = 0; i < indices.size(); ++i) { @@ -258,11 +257,21 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() { if (gather_can_be_fused) { const size_t num_of_unchanged_dimensions = indices.size(); - const auto subgraph_et = gather->get_input_element_type(0); + const auto subgraph_et = node->get_input_element_type(0); concat_input = v0::Constant::create(subgraph_et, Shape{num_of_unchanged_dimensions}, {0}); gather_folded = true; } } else { + if (!special_zero) { + // If special zero is false - check if other inputs to Concat are Constants. + // If any of those Constants contain zero - return false. + auto constant = as_type_ptr(node); + if (!constant) + return false; + auto values = constant->cast_vector(); + if (std::find(values.begin(), values.end(), 0) != values.end()) + return false; + } update_expected_gather_location(concat_input); } } @@ -275,7 +284,7 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() { new_concat->set_friendly_name(concat->get_friendly_name()); copy_runtime_info(concat, new_concat); - const auto new_reshape = reshape->clone_with_new_inputs({reshape->input_value(0), new_concat}); + const auto new_reshape = std::make_shared(reshape->input_value(0), new_concat, true); new_reshape->set_friendly_name(reshape->get_friendly_name()); copy_runtime_info(reshape, new_reshape); diff --git a/src/common/transformations/tests/common_optimizations/simplify_second_input_of_reshape_test.cpp b/src/common/transformations/tests/common_optimizations/simplify_second_input_of_reshape_test.cpp index 7431174daaa0ae..cd8ca2f1f0a640 100644 --- a/src/common/transformations/tests/common_optimizations/simplify_second_input_of_reshape_test.cpp +++ b/src/common/transformations/tests/common_optimizations/simplify_second_input_of_reshape_test.cpp @@ -611,3 +611,53 @@ TEST_F(TransformationTestsF, SimplifySecondInputOfReshapeTest21) { } comparator.enable(FunctionsComparator::CONST_VALUES); } + +TEST_F(TransformationTestsF, SimplifySecondInputOfReshapeTestFalseSpecialZero) { + PartialShape data_shape{1, 128, 12, 64}; + { + auto data = std::make_shared(element::f32, data_shape); + + auto shape_of = std::make_shared(data); + auto gather_op = gather(shape_of, std::vector{0, 1}); + auto constant = opset7::Constant::create(element::i64, Shape{1}, {768}); + auto concat = std::make_shared(OutputVector{gather_op, constant}, -1); + + auto reshape = std::make_shared(data, concat, false); + model = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + + manager.register_pass(); + } + { + auto data = std::make_shared(element::f32, data_shape); + auto reshape_pattern = opset7::Constant::create(element::i64, Shape{3}, {0, 0, 768}); + auto reshape = std::make_shared(data, reshape_pattern, true); + model_ref = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + } + comparator.enable(FunctionsComparator::ATTRIBUTES); + comparator.enable(FunctionsComparator::CONST_VALUES); +} + +TEST_F(TransformationTestsF, SimplifySecondInputOfReshapeTestFalseSpecialZeroZeroDim) { + PartialShape data_shape{1, 0, 12, 64}; + { + auto data = std::make_shared(element::f32, data_shape); + + auto shape_of = std::make_shared(data); + auto gather_op = gather(shape_of, std::vector{0, 1}); + auto constant = opset7::Constant::create(element::i64, Shape{1}, {768}); + auto concat = std::make_shared(OutputVector{gather_op, constant}, -1); + + auto reshape = std::make_shared(data, concat, false); + model = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + + manager.register_pass(); + } + { + auto data = std::make_shared(element::f32, data_shape); + auto reshape_pattern = opset7::Constant::create(element::i64, Shape{3}, {0, 0, 768}); + auto reshape = std::make_shared(data, reshape_pattern, true); + model_ref = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + } + comparator.enable(FunctionsComparator::ATTRIBUTES); + comparator.enable(FunctionsComparator::CONST_VALUES); +}