Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle Reshape's special zero in SimplifySecondInputOfReshape #20785

Merged
merged 2 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
matcher_pass_callback callback = [=](Matcher& m) {
auto node = m.get_match_root();
const auto reshape = as_type_ptr<v1::Reshape>(node);
if (!reshape || reshape->get_special_zero() == false) {
if (!reshape) {
return false;
}

Expand All @@ -219,7 +219,7 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {

auto check_shape_of_gather = [&](const std::shared_ptr<Node>& gather) {
auto shape_of = gather->get_input_node_shared_ptr(0);
if (!is_type<v3::ShapeOf>(shape_of) && !is_type<v0::ShapeOf>(shape_of)) {
if (!is_type<op::util::ShapeOfBase>(shape_of)) {
return false;
}
return shape_of->input_value(0) == data;
Expand All @@ -237,16 +237,15 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
gather_dims_expected_location += concat_input_shape[0];
};

bool special_zero = reshape->get_special_zero();

// We need this check to avoid sequences shapeOf -> gather -> concat
// that change the arrangement of dimensions in the reshape pattern
for (auto& concat_input : new_concat_inputs) {
if (const auto gather = as_type_ptr<op::util::GatherBase>(concat_input.get_node_shared_ptr())) {
auto indices_constant = as_type_ptr<v0::Constant>(gather->get_input_node_shared_ptr(1));
if (!indices_constant || !check_shape_of_gather(gather)) {
update_expected_gather_location(gather);
continue;
}

auto node = concat_input.get_node_shared_ptr();
if (ov::is_type<op::util::GatherBase>(node) &&
ov::is_type<v0::Constant>(node->get_input_node_shared_ptr(1)) && check_shape_of_gather(node)) {
auto indices_constant = as_type_ptr<v0::Constant>(node->get_input_node_shared_ptr(1));
bool gather_can_be_fused = true;
const auto indices = indices_constant->cast_vector<std::int64_t>();
for (size_t i = 0; i < indices.size(); ++i) {
Expand All @@ -258,11 +257,21 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {

if (gather_can_be_fused) {
const size_t num_of_unchanged_dimensions = indices.size();
const auto subgraph_et = gather->get_input_element_type(0);
const auto subgraph_et = node->get_input_element_type(0);
concat_input = v0::Constant::create(subgraph_et, Shape{num_of_unchanged_dimensions}, {0});
gather_folded = true;
}
} else {
if (!special_zero) {
// If special zero is false - check if other inputs to Concat are Constants.
// If any of those Constants contain zero - return false.
auto constant = as_type_ptr<v0::Constant>(node);
if (!constant)
return false;
auto values = constant->cast_vector<int64_t>();
if (std::find(values.begin(), values.end(), 0) != values.end())
return false;
}
update_expected_gather_location(concat_input);
}
}
Expand All @@ -275,7 +284,7 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
new_concat->set_friendly_name(concat->get_friendly_name());
copy_runtime_info(concat, new_concat);

const auto new_reshape = reshape->clone_with_new_inputs({reshape->input_value(0), new_concat});
const auto new_reshape = std::make_shared<v1::Reshape>(reshape->input_value(0), new_concat, true);
new_reshape->set_friendly_name(reshape->get_friendly_name());

copy_runtime_info(reshape, new_reshape);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -611,3 +611,53 @@ TEST_F(TransformationTestsF, SimplifySecondInputOfReshapeTest21) {
}
comparator.enable(FunctionsComparator::CONST_VALUES);
}

TEST_F(TransformationTestsF, SimplifySecondInputOfReshapeTestFalseSpecialZero) {
PartialShape data_shape{1, 128, 12, 64};
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);

auto shape_of = std::make_shared<opset7::ShapeOf>(data);
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
auto constant = opset7::Constant::create(element::i64, Shape{1}, {768});
auto concat = std::make_shared<opset7::Concat>(OutputVector{gather_op, constant}, -1);

auto reshape = std::make_shared<opset7::Reshape>(data, concat, false);
model = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});

manager.register_pass<ov::pass::SimplifySecondInputOfReshape>();
}
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{3}, {0, 0, 768});
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
model_ref = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
}
comparator.enable(FunctionsComparator::ATTRIBUTES);
comparator.enable(FunctionsComparator::CONST_VALUES);
}

TEST_F(TransformationTestsF, SimplifySecondInputOfReshapeTestFalseSpecialZeroZeroDim) {
PartialShape data_shape{1, 0, 12, 64};
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);

auto shape_of = std::make_shared<opset7::ShapeOf>(data);
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
auto constant = opset7::Constant::create(element::i64, Shape{1}, {768});
auto concat = std::make_shared<opset7::Concat>(OutputVector{gather_op, constant}, -1);

auto reshape = std::make_shared<opset7::Reshape>(data, concat, false);
model = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});

manager.register_pass<ov::pass::SimplifySecondInputOfReshape>();
}
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{3}, {0, 0, 768});
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
model_ref = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
}
comparator.enable(FunctionsComparator::ATTRIBUTES);
comparator.enable(FunctionsComparator::CONST_VALUES);
}
Copy link
Contributor

@slyalin slyalin Nov 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use existing test and specialize them with special_zero=false? What's so special for these two cases?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe we can improve this point in the next PR