Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Op] Add RMSNorm op core class #23842

Merged
merged 37 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
4b8cabf
RMSNorm init
mitruska Apr 2, 2024
5d97e95
Add axes input
mitruska Apr 2, 2024
89f41df
Update clone inputs
mitruska Apr 3, 2024
5462fc5
Register shape infer for cpu
mitruska Apr 3, 2024
577d6d5
Type validation and more tests
mitruska Apr 3, 2024
1ed3951
Add scale shape validation
mitruska Apr 3, 2024
fd267fe
Add default ctor type
mitruska Apr 3, 2024
d34bef1
Add check and test for scale type
mitruska Apr 3, 2024
5134225
Add symbols test
mitruska Apr 3, 2024
f61acd6
Add visit attrubute tests
mitruska Apr 3, 2024
c7e6054
Merge remote-tracking branch 'upstream/master' into mitruska/rms_op_s…
mitruska Apr 3, 2024
7e1ec64
Merge remote-tracking branch 'upstream/master' into mitruska/rms_op_s…
mitruska Apr 8, 2024
c161e21
Update doxy comments
mitruska Apr 8, 2024
ece0b2a
Move shape validation to shape_infer
mitruska Apr 8, 2024
6e5092d
Update test message
mitruska Apr 8, 2024
3bda1ba
Use has_scale_input variable
mitruska Apr 8, 2024
c47adbb
Add output size checks
mitruska Apr 8, 2024
0f4080b
Check attrs in typeprop tests
mitruska Apr 8, 2024
d53b808
Add input size checks
mitruska Apr 8, 2024
fe079cb
Make scale input shape tests parametrized
mitruska Apr 8, 2024
7fe4341
Register shape_infer for rms_norm
mitruska Apr 9, 2024
1e4b4eb
Add more tests
mitruska Apr 9, 2024
9c4d57c
Use EXPECT_EQ
mitruska Apr 9, 2024
cceea95
Add ov test namespace
mitruska Apr 9, 2024
32a4959
Use TypePropRMSNormTest class
mitruska Apr 9, 2024
6644a6c
More tests and axis check fix
mitruska Apr 9, 2024
ef4ec33
Merge remote-tracking branch 'upstream/master' into mitruska/rms_op_s…
mitruska Apr 9, 2024
676496e
Merge branch 'master' into mitruska/rms_op_shell
mitruska Apr 11, 2024
75b6ce5
Update number of ops
mitruska Apr 11, 2024
4dcf25f
Remove from opset14
mitruska Apr 11, 2024
9104c62
fix opset usage
mitruska Apr 12, 2024
9da7372
Merge remote-tracking branch 'upstream/master' into mitruska/rms_op_s…
mitruska Apr 12, 2024
94d84e4
Update errors to NODE_SHAPE_INFER_CHECK
mitruska Apr 12, 2024
bc45554
Remove redundant header
mitruska Apr 12, 2024
c937e11
Update to use eps from common test class
mitruska Apr 12, 2024
59e350f
Update rank check
mitruska Apr 12, 2024
e1f0538
Merge remote-tracking branch 'upstream/master' into mitruska/rms_op_s…
mitruska Apr 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/core/include/openvino/op/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@
#include "openvino/op/result.hpp"
#include "openvino/op/reverse.hpp"
#include "openvino/op/reverse_sequence.hpp"
#include "openvino/op/rms_norm.hpp"
#include "openvino/op/rnn_cell.hpp"
#include "openvino/op/rnn_sequence.hpp"
#include "openvino/op/roi_align.hpp"
Expand Down
57 changes: 57 additions & 0 deletions src/core/include/openvino/op/rms_norm.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/op.hpp"

namespace ov {
namespace op {
namespace v14 {
/// \brief Operator performing Root Mean Square Normalization
///
/// The operation takes a data input, an axes input and an optional scale input, and carries
/// two attributes: `epsilon` and `compute_type`.
/// \ingroup ov_ops_cpp_api
class OPENVINO_API RMSNorm : public ov::op::Op {
public:
    OPENVINO_OP("RMSNorm", "opset14", ov::op::Op);

    RMSNorm() = default;

    /// \brief Constructs an RMSNorm operation without scaling.
    ///
    /// \param data Input tensor with data
    /// \param axes Axes for reduce mean calculation
    /// \param epsilon Epsilon for not dividing by zero while normalizing the value
    /// \param compute_type Precision for the internal computation, same as the input type by default
    RMSNorm(const Output<Node>& data,
            const Output<Node>& axes,
            double epsilon,
            const ov::element::Type& compute_type = ov::element::undefined);

    /// \brief Constructs an RMSNorm operation with scaling.
    ///
    /// \param data Input tensor with data
    /// \param axes Axes for reduce mean calculation
    /// \param scale Scale values for weight
    /// \param epsilon Epsilon for not dividing by zero while normalizing the value
    /// \param compute_type Precision for the internal computation, same as the input type by default
    RMSNorm(const Output<Node>& data,
            const Output<Node>& axes,
            const Output<Node>& scale,
            double epsilon,
            const ov::element::Type& compute_type = ov::element::undefined);

    /// \brief Visits the `epsilon` and `compute_type` attributes of the operation.
    bool visit_attributes(ov::AttributeVisitor& visitor) override;

    /// \brief Validates the inputs and sets the output element type and shape.
    void validate_and_infer_types() override;

    /// \brief Creates a copy of the operation connected to the given inputs (two or three outputs).
    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

    /// \return Value of the `epsilon` attribute.
    double get_epsilon() const;

    /// \return Value of the `compute_type` attribute.
    const ov::element::Type& get_compute_type() const;

private:
    double m_epsilon{0};
    ov::element::Type m_compute_type{ov::element::undefined};
};

} // namespace v14
} // namespace op
} // namespace ov
1 change: 1 addition & 0 deletions src/core/include/openvino/opsets/opset14_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,4 @@ _OPENVINO_OP_REG(FakeConvert, ov::op::v13)
// New operations added in opset14
_OPENVINO_OP_REG(ConvertPromoteTypes, ov::op::v14)
_OPENVINO_OP_REG(Inverse, ov::op::v14)
_OPENVINO_OP_REG(RMSNorm, ov::op::v14)
117 changes: 117 additions & 0 deletions src/core/src/op/rms_norm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/rms_norm.hpp"

#include "compare.hpp"
mitruska marked this conversation as resolved.
Show resolved Hide resolved
#include "itt.hpp"
#include "openvino/op/op.hpp"

namespace ov {
namespace op {
namespace v14 {

/// Constructs an RMSNorm node without a scale input; the node inputs are {data, axes}.
///
/// \param data Input tensor with data
/// \param axes Axes input
/// \param epsilon Value stored as the `epsilon` attribute
/// \param compute_type Value stored as the `compute_type` attribute
RMSNorm::RMSNorm(const Output<Node>& data,
                 const Output<Node>& axes,
                 double epsilon,
                 const ov::element::Type& compute_type)
    : Op({data, axes}),
      m_epsilon(epsilon),
      m_compute_type(compute_type) {
    // Validate the inputs and infer the output type right away.
    constructor_validate_and_infer_types();
}

/// Constructs an RMSNorm node with a scale input; the node inputs are {data, axes, scale}.
///
/// \param data Input tensor with data
/// \param axes Axes input
/// \param scale Scale input
/// \param epsilon Value stored as the `epsilon` attribute
/// \param compute_type Value stored as the `compute_type` attribute
RMSNorm::RMSNorm(const Output<Node>& data,
                 const Output<Node>& axes,
                 const Output<Node>& scale,
                 double epsilon,
                 const ov::element::Type& compute_type)
    : Op({data, axes, scale}),
      m_epsilon(epsilon),
      m_compute_type(compute_type) {
    // Validate the inputs and infer the output type right away.
    constructor_validate_and_infer_types();
}

// Exposes the operation attributes to the given visitor under the names
// "epsilon" and "compute_type", in that order. Always returns true.
bool RMSNorm::visit_attributes(ov::AttributeVisitor& visitor) {
OV_OP_SCOPE(v14_RMSNorm_visit_attributes);
visitor.on_attribute("epsilon", m_epsilon);
visitor.on_attribute("compute_type", m_compute_type);
return true;
}

void RMSNorm::validate_and_infer_types() {
OV_OP_SCOPE(v14_RMSNorm_validate_and_infer_types);

const auto& data_element_type = get_input_element_type(0);
const bool is_valid_data_type = data_element_type.is_dynamic() || data_element_type.is_real();
praasz marked this conversation as resolved.
Show resolved Hide resolved
NODE_VALIDATION_CHECK(this,
is_valid_data_type,
"The element type of the data tensor must be a floating point type. Got: ",
data_element_type);

const auto& axes_element_type = get_input_element_type(1);
const bool is_valid_axes_type =
data_element_type.is_dynamic() || axes_element_type == element::i32 || axes_element_type == element::i64;
NODE_VALIDATION_CHECK(this,
is_valid_axes_type,
"The element type of the axes tensor must be i32 or i64 type. Got: ",
axes_element_type);

const auto& data_shape = get_input_partial_shape(0);
const auto& axes_shape = get_input_partial_shape(1);
if (axes_shape.rank().is_static()) {
NODE_VALIDATION_CHECK(this,
axes_shape.size() == 1,
"Expected 1D tensor for the 'axes' input. Got: ",
axes_shape);

const auto data_rank = data_shape.rank();
const bool has_axes_compatible = data_rank.is_dynamic() || axes_shape[0].is_dynamic() ||
cmp::ge(data_rank.get_length(), axes_shape.get_shape()[0]);
NODE_VALIDATION_CHECK(this,
has_axes_compatible,
"Number of the axes can't be higher than the rank of the data shape.");
}
praasz marked this conversation as resolved.
Show resolved Hide resolved

if (get_input_size() > 2) { // Validate scale input
auto scale_shape = get_input_partial_shape(2);
const bool is_scale_shape_broadcastable =
PartialShape::broadcast_merge_into(scale_shape, data_shape, ov::op::AutoBroadcastType::NUMPY);
NODE_VALIDATION_CHECK(this,
is_scale_shape_broadcastable,
"Scale input shape must be broadcastable to the shape of the data input.");
mitruska marked this conversation as resolved.
Show resolved Hide resolved

// Validate input types and save result for output type
auto merged_et = element::dynamic;
const auto& scale_element_type = get_input_element_type(2);
const bool is_scale_type_compatible = element::Type::merge(merged_et, data_element_type, scale_element_type);
NODE_VALIDATION_CHECK(this,
is_scale_type_compatible,
"Element type of the scale input must be the same as the data input type.");
}

// Output type and shape is the same as the first input
set_output_type(0, data_element_type, get_input_partial_shape(0));
}

/// Creates a new RMSNorm node connected to `new_args`, carrying over the epsilon and
/// compute_type attributes of this node.
///
/// Exactly two arguments produce the variant without scale; otherwise the third argument is
/// passed as the scale input.
std::shared_ptr<Node> RMSNorm::clone_with_new_inputs(const ov::OutputVector& new_args) const {
    OV_OP_SCOPE(v14_RMSNorm_clone_with_new_inputs);
    check_new_args_count(this, new_args);

    const auto& data = new_args.at(0);
    const auto& axes = new_args.at(1);

    const bool with_scale = new_args.size() != 2;
    if (with_scale) {
        return std::make_shared<RMSNorm>(data, axes, new_args.at(2), m_epsilon, m_compute_type);
    }
    return std::make_shared<RMSNorm>(data, axes, m_epsilon, m_compute_type);
}

// Returns the value of the epsilon attribute stored in m_epsilon.
double RMSNorm::get_epsilon() const {
return m_epsilon;
}

/// Returns a reference to the compute_type attribute stored in m_compute_type.
const ov::element::Type& RMSNorm::get_compute_type() const {
    return m_compute_type;
}

} // namespace v14
} // namespace op
} // namespace ov
Loading
Loading