Skip to content

Commit

Permalink
Handle Reshape's special zero in SimplifySecondInputOfReshape (openvi…
Browse files Browse the repository at this point in the history
…notoolkit#20785)

* Handle Reshape's special zero in SimplifySecondInputOfReshape

SimplifySecondInputOfReshape detects ShapeOf->Gather->Concat
subgraphs on Reshape's second input and replaces ShapeOf->Gather
with a Constant with zero(s). Currently it works only with Reshapes
that have special_zero set to true, but it can work for Reshapes
with special_zero == false if non-Gather inputs to Concat are Constants
and don't contain any zero.

Ticket: CVS-123434

* fix no default output
  • Loading branch information
mateusztabaka authored and alvoron committed Nov 6, 2023
1 parent 52b73ef commit 3c945fb
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
matcher_pass_callback callback = [=](Matcher& m) {
auto node = m.get_match_root();
const auto reshape = as_type_ptr<v1::Reshape>(node);
if (!reshape || reshape->get_special_zero() == false) {
if (!reshape) {
return false;
}

Expand All @@ -219,7 +219,7 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {

auto check_shape_of_gather = [&](const std::shared_ptr<Node>& gather) {
auto shape_of = gather->get_input_node_shared_ptr(0);
if (!is_type<v3::ShapeOf>(shape_of) && !is_type<v0::ShapeOf>(shape_of)) {
if (!is_type<op::util::ShapeOfBase>(shape_of)) {
return false;
}
return shape_of->input_value(0) == data;
Expand All @@ -237,16 +237,15 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
gather_dims_expected_location += concat_input_shape[0];
};

bool special_zero = reshape->get_special_zero();

// We need this check to avoid sequences shapeOf -> gather -> concat
// that change the arrangement of dimensions in the reshape pattern
for (auto& concat_input : new_concat_inputs) {
if (const auto gather = as_type_ptr<op::util::GatherBase>(concat_input.get_node_shared_ptr())) {
auto indices_constant = as_type_ptr<v0::Constant>(gather->get_input_node_shared_ptr(1));
if (!indices_constant || !check_shape_of_gather(gather)) {
update_expected_gather_location(gather);
continue;
}

auto node = concat_input.get_node_shared_ptr();
if (ov::is_type<op::util::GatherBase>(node) &&
ov::is_type<v0::Constant>(node->get_input_node_shared_ptr(1)) && check_shape_of_gather(node)) {
auto indices_constant = as_type_ptr<v0::Constant>(node->get_input_node_shared_ptr(1));
bool gather_can_be_fused = true;
const auto indices = indices_constant->cast_vector<std::int64_t>();
for (size_t i = 0; i < indices.size(); ++i) {
Expand All @@ -258,11 +257,21 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {

if (gather_can_be_fused) {
const size_t num_of_unchanged_dimensions = indices.size();
const auto subgraph_et = gather->get_input_element_type(0);
const auto subgraph_et = node->get_input_element_type(0);
concat_input = v0::Constant::create(subgraph_et, Shape{num_of_unchanged_dimensions}, {0});
gather_folded = true;
}
} else {
if (!special_zero) {
// If special zero is false - check if other inputs to Concat are Constants.
// If any of those Constants contain zero - return false.
auto constant = as_type_ptr<v0::Constant>(node);
if (!constant)
return false;
auto values = constant->cast_vector<int64_t>();
if (std::find(values.begin(), values.end(), 0) != values.end())
return false;
}
update_expected_gather_location(concat_input);
}
}
Expand All @@ -275,7 +284,7 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
new_concat->set_friendly_name(concat->get_friendly_name());
copy_runtime_info(concat, new_concat);

const auto new_reshape = reshape->clone_with_new_inputs({reshape->input_value(0), new_concat});
const auto new_reshape = std::make_shared<v1::Reshape>(reshape->input_value(0), new_concat, true);
new_reshape->set_friendly_name(reshape->get_friendly_name());

copy_runtime_info(reshape, new_reshape);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -611,3 +611,53 @@ TEST_F(TransformationTestsF, SimplifySecondInputOfReshapeTest21) {
}
comparator.enable(FunctionsComparator::CONST_VALUES);
}

TEST_F(TransformationTestsF, SimplifySecondInputOfReshapeTestFalseSpecialZero) {
PartialShape data_shape{1, 128, 12, 64};
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);

auto shape_of = std::make_shared<opset7::ShapeOf>(data);
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
auto constant = opset7::Constant::create(element::i64, Shape{1}, {768});
auto concat = std::make_shared<opset7::Concat>(OutputVector{gather_op, constant}, -1);

auto reshape = std::make_shared<opset7::Reshape>(data, concat, false);
model = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});

manager.register_pass<ov::pass::SimplifySecondInputOfReshape>();
}
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{3}, {0, 0, 768});
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
model_ref = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
}
comparator.enable(FunctionsComparator::ATTRIBUTES);
comparator.enable(FunctionsComparator::CONST_VALUES);
}

TEST_F(TransformationTestsF, SimplifySecondInputOfReshapeTestFalseSpecialZeroZeroDim) {
PartialShape data_shape{1, 0, 12, 64};
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);

auto shape_of = std::make_shared<opset7::ShapeOf>(data);
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
auto constant = opset7::Constant::create(element::i64, Shape{1}, {768});
auto concat = std::make_shared<opset7::Concat>(OutputVector{gather_op, constant}, -1);

auto reshape = std::make_shared<opset7::Reshape>(data, concat, false);
model = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});

manager.register_pass<ov::pass::SimplifySecondInputOfReshape>();
}
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{3}, {0, 0, 768});
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
model_ref = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
}
comparator.enable(FunctionsComparator::ATTRIBUTES);
comparator.enable(FunctionsComparator::CONST_VALUES);
}

0 comments on commit 3c945fb

Please sign in to comment.