diff --git a/inference-engine/src/transformations/src/transformations/op_conversions/gather_normalize_negative_indices.cpp b/inference-engine/src/transformations/src/transformations/op_conversions/gather_normalize_negative_indices.cpp index 86713451869345..ad16993c98703d 100644 --- a/inference-engine/src/transformations/src/transformations/op_conversions/gather_normalize_negative_indices.cpp +++ b/inference-engine/src/transformations/src/transformations/op_conversions/gather_normalize_negative_indices.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include #include "itt.hpp" @@ -18,11 +19,11 @@ ngraph::pass::GatherNegativeConstIndicesNormalize::GatherNegativeConstIndicesNor auto data_input = ngraph::pattern::any_input(pattern::has_static_rank()); auto axis_input = ngraph::pattern::wrap_type(); auto indices_input = ngraph::pattern::wrap_type(); - auto gather_node = std::make_shared(data_input, indices_input, axis_input); + auto gather_node = ngraph::pattern::wrap_type({data_input, indices_input, axis_input}); ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { auto& pattern_to_output = m.get_pattern_value_map(); - auto gather = std::dynamic_pointer_cast(pattern_to_output.at(gather_node).get_node_shared_ptr()); + auto gather = pattern_to_output.at(gather_node).get_node_shared_ptr(); auto data = pattern_to_output.at(data_input); auto axis_constant = std::dynamic_pointer_cast(pattern_to_output.at(axis_input).get_node_shared_ptr()); auto indices_constant = std::dynamic_pointer_cast(pattern_to_output.at(indices_input).get_node_shared_ptr()); @@ -62,12 +63,12 @@ ngraph::pass::GatherNegativeConstIndicesNormalize::GatherNegativeConstIndicesNor auto input_gather = std::make_shared(shape_of, ngraph::opset7::Constant::create(input_type, Shape{}, {axis_value}), ngraph::opset7::Constant::create(input_type, Shape{}, {0})); - auto add = std::make_shared(input_gather, indices_constant); - auto gather_new = gather_node->copy_with_new_inputs({data, add, axis_constant}); - gather_new->set_friendly_name(gather->get_friendly_name()); + std::shared_ptr add = std::make_shared(input_gather, indices_constant); + if (auto folded_const = ngraph::get_constant_from_source(add)) + add = folded_const; + gather->input(1).replace_source_output(add); - ngraph::copy_runtime_info(gather, {shape_of, input_gather, add, gather_new}); - ngraph::replace_node(gather, gather_new); + ngraph::copy_runtime_info(gather, {shape_of, input_gather, add}); return true; }; diff --git a/inference-engine/tests/functional/inference_engine/transformations/gather_normalize_negative_indices_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/gather_normalize_negative_indices_test.cpp index ec6c4204a9b8a0..1e600d5f300893 100644 --- a/inference-engine/tests/functional/inference_engine/transformations/gather_normalize_negative_indices_test.cpp +++ b/inference-engine/tests/functional/inference_engine/transformations/gather_normalize_negative_indices_test.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -46,7 +47,10 @@ TEST(TransformationTests, GatherNegativeIndicesNormalize) { auto input_gather = std::make_shared(shape_of, ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {1}), ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {0})); auto add = std::make_shared(input_gather, indices); - auto gather = std::make_shared(data, add, axis); + auto const_add = ngraph::get_constant_from_source(add); + if (const_add == nullptr) + throw ngraph::ngraph_error("indices should've been constant folded"); + auto gather = std::make_shared(data, const_add, axis); f_ref = std::make_shared(ngraph::NodeVector{gather}, ngraph::ParameterVector{data}); } @@ -84,7 +88,10 @@ TEST(TransformationTests, GatherNegativeIndicesNormalize_neg_axis) { auto input_gather = std::make_shared(shape_of, ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {1}), ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {0})); auto add = std::make_shared(input_gather, indices); - auto gather = std::make_shared(data, add, axis); + auto const_add = ngraph::get_constant_from_source(add); + if (const_add == nullptr) + throw ngraph::ngraph_error("indices should've been constant folded"); + auto gather = std::make_shared(data, const_add, axis); f_ref = std::make_shared(ngraph::NodeVector{gather}, ngraph::ParameterVector{data}); } @@ -122,7 +129,10 @@ TEST(TransformationTests, GatherNegativeIndicesNormalize_dif_input_types) { auto input_gather = std::make_shared(shape_of, ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {1}), ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {0})); auto add = std::make_shared(input_gather, indices); - auto gather = std::make_shared(data, add, axis); + auto const_add = ngraph::get_constant_from_source(add); + if (const_add == nullptr) + throw ngraph::ngraph_error("indices should've been constant folded"); + auto gather = std::make_shared(data, const_add, axis); f_ref = std::make_shared(ngraph::NodeVector{gather}, ngraph::ParameterVector{data}); }