Skip to content

Commit

Permalink
Handle dynamic rank in TSUnsqueezeBackward transformation (openvinoto…
Browse files Browse the repository at this point in the history
…olkit#26786)

### Details:
Handle dynamic rank in TSUnsqueezeBackward transformation

### Tickets:
 - *CVS-152373*
  • Loading branch information
itikhono authored Sep 30, 2024
1 parent 9b0d209 commit 2507d89
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,19 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
return false;
}
} else {
auto rank = main_node->get_output_partial_shape(0).rank();
non_negative_axes =
util::try_get_normalized_axis_vector(unsqueeze_axes->get_tensor_view(), rank, *main_node);
const auto& axes = unsqueeze_axes->cast_vector<int64_t>();
if (std::all_of(axes.begin(), axes.end(), [](int64_t axis) {
return axis >= 0;
})) {
non_negative_axes = std::vector<size_t>(axes.begin(), axes.end());
} else {
auto rank = main_node->get_output_partial_shape(0).rank();
if (rank.is_dynamic()) {
return false;
}
non_negative_axes =
util::try_get_normalized_axis_vector(unsqueeze_axes->get_tensor_view(), rank, *main_node);
}
}

auto transpose_order_values = transpose_order->cast_vector<size_t>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1636,6 +1636,47 @@ auto test_backward_reshape_unsqueeze = []() {
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeBackward,
TSTestFixture,
test_backward_reshape_unsqueeze());

auto test_backward_unsqueeze_dyn_rank = []() {
TestCase test_case;

// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, PartialShape::dynamic()),
constant<int64_t>(element::i32, {2}, {-1}),
};

auto dyn_transpose = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
OutputVector result = out_vec;
for (const auto& idx : idxs) {
const auto& out = out_vec[idx];

// fill the order const with the stub values {-1, -2}
auto order = make_shared<Constant>(element::i32, Shape{2}, vector<int64_t>{-1, -2});
auto transpose = make_shared<Transpose>(out, order);
result[idx] = transpose;
}
return result;
};

// Test model description:
test_case.model.main_op = {CREATE_BINARY_FACTORY(Unsqueeze)};
test_case.model.preprocess_outputs_of_main = {{dyn_transpose}, {{0}}};
test_case.model.model_template = create_model;

// Ref model description, the same as the original model, the transformation is not applied
// it's expected.
test_case.model_ref.main_op = {CREATE_BINARY_FACTORY(Unsqueeze)};
test_case.model_ref.preprocess_outputs_of_main = {{dyn_transpose}, {{0}}};
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};

INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeBackwardDynRank,
TSTestFixture,
test_backward_unsqueeze_dyn_rank());
} // namespace common
} // namespace testing
} // namespace transpose_sinking

0 comments on commit 2507d89

Please sign in to comment.