-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Op] Add RMSNorm op core class (#23842)
### Details: - RMSNorm op core class - Registration in the opset and op check (conformance) test will be added in the next PRs Spec PR: - #23569 ### Tickets: - 136261
- Loading branch information
Showing
8 changed files
with
682 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/op/op.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace v14 { | ||
/// \brief Operator performing Root Mean Square Normalization | ||
/// \ingroup ov_ops_cpp_api | ||
class OPENVINO_API RMSNorm : public ov::op::Op { | ||
public: | ||
OPENVINO_OP("RMSNorm", "opset14", ov::op::Op); | ||
|
||
RMSNorm() = default; | ||
/// \brief Constructs an RMSNorm operation without scaling. | ||
/// | ||
/// \param data Input tensor with data | ||
/// \param axes Axes for reduce mean calculation | ||
/// \param eps Epsilon for not dividing by zero while normalizing the value | ||
/// \param compute_type Precision for the internal computation, if undefined it's the same as the input type | ||
RMSNorm(const Output<Node>& data, | ||
const Output<Node>& axes, | ||
double epsilson, | ||
const ov::element::Type& compute_type = ov::element::undefined); | ||
|
||
/// \brief Constructs an RMSNorm operation with scaling. | ||
/// | ||
/// \param data Input tensor with data | ||
/// \param axes Axes for reduce mean calculation | ||
/// \param scale Scale values for weight | ||
/// \param eps Epsilon for not dividing by zero while normalizing the value | ||
/// \param compute_type Precision for the internal computation, if undefined it's the same as the input type | ||
RMSNorm(const Output<Node>& data, | ||
const Output<Node>& axes, | ||
const Output<Node>& scale, | ||
double epsilson, | ||
const ov::element::Type& compute_type = ov::element::undefined); | ||
|
||
bool visit_attributes(ov::AttributeVisitor& visitor) override; | ||
void validate_and_infer_types() override; | ||
std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override; | ||
|
||
double get_epsilon() const; | ||
const ov::element::Type& get_compute_type() const; | ||
|
||
private: | ||
double m_epsilon{0}; | ||
ov::element::Type m_compute_type{ov::element::undefined}; | ||
}; | ||
|
||
} // namespace v14 | ||
} // namespace op | ||
} // namespace ov |
65 changes: 65 additions & 0 deletions
65
src/core/shape_inference/include/rms_norm_shape_inference.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/op/rms_norm.hpp" | ||
#include "utils.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace v14 { | ||
template <class TShape, class TRShape = result_shape_t<TShape>> | ||
std::vector<TRShape> shape_infer(const RMSNorm* op, | ||
const std::vector<TShape>& input_shapes, | ||
const ITensorAccessor& tensor_accessor = make_tensor_accessor()) { | ||
const auto inputs_count = input_shapes.size(); | ||
const auto has_scale_input = inputs_count == 3; | ||
NODE_SHAPE_INFER_CHECK(op, input_shapes, inputs_count == 2 || has_scale_input); | ||
|
||
const auto& data_shape = input_shapes[0]; | ||
const auto& data_rank = data_shape.rank(); | ||
const auto& axes_shape = input_shapes[1]; | ||
const auto& axes_rank = axes_shape.rank(); | ||
|
||
NODE_SHAPE_INFER_CHECK(op, | ||
input_shapes, | ||
ov::util::is_rank_compatible_any_of(axes_rank, {0, 1}), | ||
"Axes input must be a scalar or 1D input. Got: ", | ||
axes_shape); | ||
|
||
// Further validation requires data rank to be static | ||
if (data_rank.is_dynamic()) { | ||
return {data_shape}; | ||
} | ||
|
||
if (axes_shape.rank().is_static()) { | ||
const bool has_axes_compatible = axes_shape.size() == 0 || axes_shape[0].is_dynamic() || | ||
cmp::ge(data_rank.get_length(), axes_shape.get_shape()[0]); | ||
NODE_SHAPE_INFER_CHECK(op, | ||
input_shapes, | ||
has_axes_compatible, | ||
"Number of the axes can't be higher than the rank of the data shape."); | ||
} | ||
|
||
if (has_scale_input) { // Validate scale input | ||
TRShape scale_shape = input_shapes[2]; | ||
const bool is_scale_shape_broadcastable = | ||
TRShape::broadcast_merge_into(scale_shape, data_shape, ov::op::AutoBroadcastType::NUMPY); | ||
NODE_SHAPE_INFER_CHECK(op, | ||
input_shapes, | ||
is_scale_shape_broadcastable, | ||
"Scale input shape must be broadcastable to the shape of the data input."); | ||
} | ||
|
||
// Axes values validation | ||
if (const auto axes_val = ov::op::get_input_const_data_as<TRShape, int64_t>(op, 1, tensor_accessor)) { | ||
ov::util::normalize_axes(op, data_rank.get_length(), *axes_val); | ||
} | ||
|
||
return {data_shape}; | ||
} | ||
} // namespace v14 | ||
} // namespace op | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "openvino/op/rms_norm.hpp" | ||
|
||
#include "itt.hpp" | ||
#include "openvino/core/validation_util.hpp" | ||
#include "openvino/op/op.hpp" | ||
#include "rms_norm_shape_inference.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace v14 { | ||
|
||
RMSNorm::RMSNorm(const Output<Node>& data, | ||
const Output<Node>& axes, | ||
double epsilson, | ||
const ov::element::Type& compute_type) | ||
: Op({data, axes}), | ||
m_epsilon(epsilson), | ||
m_compute_type(compute_type) { | ||
constructor_validate_and_infer_types(); | ||
} | ||
|
||
RMSNorm::RMSNorm(const Output<Node>& data, | ||
const Output<Node>& axes, | ||
const Output<Node>& scale, | ||
double epsilson, | ||
const ov::element::Type& compute_type) | ||
: Op({data, axes, scale}), | ||
m_epsilon(epsilson), | ||
m_compute_type(compute_type) { | ||
constructor_validate_and_infer_types(); | ||
} | ||
|
||
bool RMSNorm::visit_attributes(ov::AttributeVisitor& visitor) { | ||
OV_OP_SCOPE(v14_RMSNorm_visit_attributes); | ||
visitor.on_attribute("epsilon", m_epsilon); | ||
visitor.on_attribute("compute_type", m_compute_type); | ||
return true; | ||
} | ||
|
||
void RMSNorm::validate_and_infer_types() { | ||
OV_OP_SCOPE(v14_RMSNorm_validate_and_infer_types); | ||
|
||
const auto& data_element_type = get_input_element_type(0); | ||
const bool is_valid_data_type = data_element_type.is_dynamic() || data_element_type.is_real(); | ||
NODE_VALIDATION_CHECK(this, | ||
is_valid_data_type, | ||
"The element type of the data tensor must be a floating point type. Got: ", | ||
data_element_type); | ||
|
||
const auto& axes_element_type = get_input_element_type(1); | ||
const bool is_valid_axes_type = | ||
data_element_type.is_dynamic() || axes_element_type == element::i32 || axes_element_type == element::i64; | ||
NODE_VALIDATION_CHECK(this, | ||
is_valid_axes_type, | ||
"The element type of the axes tensor must be i32 or i64 type. Got: ", | ||
axes_element_type); | ||
|
||
if (get_input_size() > 2) { // Validate scale input type | ||
|
||
// Validate input types | ||
auto merged_et = element::dynamic; | ||
const auto& scale_element_type = get_input_element_type(2); | ||
const bool is_scale_type_compatible = element::Type::merge(merged_et, data_element_type, scale_element_type); | ||
NODE_VALIDATION_CHECK(this, | ||
is_scale_type_compatible, | ||
"Element type of the scale input must be the same as the data input type."); | ||
} | ||
|
||
const auto output_shapes = shape_infer(this, ov::util::get_node_input_partial_shapes(*this)); | ||
// Output type and shape is the same as the first input | ||
set_output_type(0, data_element_type, output_shapes[0]); | ||
} | ||
|
||
std::shared_ptr<Node> RMSNorm::clone_with_new_inputs(const ov::OutputVector& new_args) const { | ||
OV_OP_SCOPE(v14_RMSNorm_clone_with_new_inputs); | ||
check_new_args_count(this, new_args); | ||
if (new_args.size() == 2) { | ||
return std::make_shared<RMSNorm>(new_args.at(0), new_args.at(1), m_epsilon, m_compute_type); | ||
} | ||
return std::make_shared<RMSNorm>(new_args.at(0), new_args.at(1), new_args.at(2), m_epsilon, m_compute_type); | ||
} | ||
|
||
double RMSNorm::get_epsilon() const { | ||
return m_epsilon; | ||
} | ||
|
||
const ov::element::Type& RMSNorm::get_compute_type() const { | ||
return m_compute_type; | ||
} | ||
|
||
} // namespace v14 | ||
} // namespace op | ||
} // namespace ov |
Oops, something went wrong.