Skip to content

Commit

Permalink
Normalize_L2 relax constant input restriction (#17568)
Browse files Browse the repository at this point in the history
* Normalize_L2 relax constant input restriction

* Fix warning treated as error during windows build
  • Loading branch information
Evgenya Stepyreva authored May 17, 2023
1 parent a880cba commit 293fccc
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 12 deletions.
2 changes: 0 additions & 2 deletions src/core/src/op/normalize_l2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ void op::v0::NormalizeL2::validate_and_infer_types() {
const auto& input_rank = input_pshape.rank();
const auto& axes_rank = axes_pshape.rank();

NODE_VALIDATION_CHECK(this, has_and_set_equal_bounds(input_value(1)), "Input axes must be Constant type");

if (axes_rank.is_static()) {
NODE_VALIDATION_CHECK(this,
axes_rank.get_length() <= 1,
Expand Down
11 changes: 1 addition & 10 deletions src/core/tests/type_prop/normalize_l2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,7 @@ TEST(type_prop, normalize_l2_axes_input_not_constant) {
auto axes = make_shared<op::Parameter>(element::u64, Shape{1});
float eps{1e-6f};
auto eps_mode = op::EpsMode::ADD;

try {
auto normalize = make_shared<op::v0::NormalizeL2>(data, axes, eps, eps_mode);
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input tensor rank.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input axes must be Constant type"));
} catch (...) {
FAIL() << "Deduced type check failed for unexpected reason";
}
ASSERT_NO_THROW(auto op = make_shared<op::v0::NormalizeL2>(data, axes, eps, eps_mode));
}

TEST(type_prop, normalize_l2_invalid_axes_rank) {
Expand Down

0 comments on commit 293fccc

Please sign in to comment.