[Op] Add RMSNorm op core class (openvinotoolkit#23842)
### Details:
 - RMSNorm op core class
 - Registration in the opset and the op check (conformance) test will be added in follow-up PRs

Spec PR: 
- openvinotoolkit#23569

### Tickets:
 - 136261
mitruska authored and alvoron committed Apr 29, 2024
1 parent 1978fff commit ddc44a2
Showing 8 changed files with 682 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/core/include/openvino/op/ops.hpp
@@ -153,6 +153,7 @@
#include "openvino/op/result.hpp"
#include "openvino/op/reverse.hpp"
#include "openvino/op/reverse_sequence.hpp"
#include "openvino/op/rms_norm.hpp"
#include "openvino/op/rnn_cell.hpp"
#include "openvino/op/rnn_sequence.hpp"
#include "openvino/op/roi_align.hpp"
57 changes: 57 additions & 0 deletions src/core/include/openvino/op/rms_norm.hpp
@@ -0,0 +1,57 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/op.hpp"

namespace ov {
namespace op {
namespace v14 {
/// \brief Operator performing Root Mean Square Normalization
/// \ingroup ov_ops_cpp_api
class OPENVINO_API RMSNorm : public ov::op::Op {
public:
OPENVINO_OP("RMSNorm", "opset14", ov::op::Op);

RMSNorm() = default;
/// \brief Constructs an RMSNorm operation without scaling.
///
/// \param data Input tensor with data
/// \param axes Axes for reduce mean calculation
/// \param epsilon Epsilon value added to avoid division by zero when normalizing
/// \param compute_type Precision for the internal computation; if undefined, the input element type is used
RMSNorm(const Output<Node>& data,
const Output<Node>& axes,
double epsilon,
const ov::element::Type& compute_type = ov::element::undefined);

/// \brief Constructs an RMSNorm operation with scaling.
///
/// \param data Input tensor with data
/// \param axes Axes for reduce mean calculation
/// \param scale Scale values (weights) applied to the normalized output
/// \param epsilon Epsilon value added to avoid division by zero when normalizing
/// \param compute_type Precision for the internal computation; if undefined, the input element type is used
RMSNorm(const Output<Node>& data,
const Output<Node>& axes,
const Output<Node>& scale,
double epsilon,
const ov::element::Type& compute_type = ov::element::undefined);

bool visit_attributes(ov::AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

double get_epsilon() const;
const ov::element::Type& get_compute_type() const;

private:
double m_epsilon{0};
ov::element::Type m_compute_type{ov::element::undefined};
};

} // namespace v14
} // namespace op
} // namespace ov
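
Since the opset registration lands only in a follow-up PR (per the commit description), the op can currently be built directly from the class. Below is a minimal usage sketch, not part of this commit; the `[batch, seq, 768]` shape, the `-1` axis, and the `1e-6` epsilon are illustrative assumptions:

```cpp
#include <memory>
#include <vector>

#include "openvino/op/constant.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/rms_norm.hpp"

int main() {
    // Hypothetical [batch, seq, hidden] input normalized over the last axis.
    const auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{-1, -1, 768});
    const auto axes = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {-1});
    const auto scale = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{768}, std::vector<float>(768, 1.0f));

    // Two-input variant (no scaling) and three-input variant (with scale).
    const auto rms = std::make_shared<ov::op::v14::RMSNorm>(data, axes, 1e-6);
    const auto rms_scaled = std::make_shared<ov::op::v14::RMSNorm>(data, axes, scale, 1e-6, ov::element::f32);
    return 0;
}
```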
65 changes: 65 additions & 0 deletions src/core/shape_inference/include/rms_norm_shape_inference.hpp
@@ -0,0 +1,65 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/rms_norm.hpp"
#include "utils.hpp"

namespace ov {
namespace op {
namespace v14 {
template <class TShape, class TRShape = result_shape_t<TShape>>
std::vector<TRShape> shape_infer(const RMSNorm* op,
const std::vector<TShape>& input_shapes,
const ITensorAccessor& tensor_accessor = make_tensor_accessor()) {
const auto inputs_count = input_shapes.size();
const auto has_scale_input = inputs_count == 3;
NODE_SHAPE_INFER_CHECK(op, input_shapes, inputs_count == 2 || has_scale_input);

const auto& data_shape = input_shapes[0];
const auto& data_rank = data_shape.rank();
const auto& axes_shape = input_shapes[1];
const auto& axes_rank = axes_shape.rank();

NODE_SHAPE_INFER_CHECK(op,
input_shapes,
ov::util::is_rank_compatible_any_of(axes_rank, {0, 1}),
"Axes input must be a scalar or 1D input. Got: ",
axes_shape);

// Further validation requires data rank to be static
if (data_rank.is_dynamic()) {
return {data_shape};
}

if (axes_shape.rank().is_static()) {
const bool has_axes_compatible = axes_shape.size() == 0 || axes_shape[0].is_dynamic() ||
cmp::ge(data_rank.get_length(), axes_shape.get_shape()[0]);
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
has_axes_compatible,
"Number of the axes can't be higher than the rank of the data shape.");
}

if (has_scale_input) { // Validate scale input
TRShape scale_shape = input_shapes[2];
const bool is_scale_shape_broadcastable =
TRShape::broadcast_merge_into(scale_shape, data_shape, ov::op::AutoBroadcastType::NUMPY);
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
is_scale_shape_broadcastable,
"Scale input shape must be broadcastable to the shape of the data input.");
}

// Axes values validation
if (const auto axes_val = ov::op::get_input_const_data_as<TRShape, int64_t>(op, 1, tensor_accessor)) {
ov::util::normalize_axes(op, data_rank.get_length(), *axes_val);
}

return {data_shape};
}
} // namespace v14
} // namespace op
} // namespace ov
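
A rough sketch, not from this PR, of how the new `shape_infer` helper could be exercised: with a data shape of static rank and a 1-D axes input, the inferred output shape simply mirrors the data shape. The shapes are assumptions for illustration, and `rms_norm_shape_inference.hpp` is an internal header, so the include assumes the core shape_inference include directory is on the path:

```cpp
#include <memory>
#include <vector>

#include "openvino/core/except.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/rms_norm.hpp"
#include "rms_norm_shape_inference.hpp"

void check_rms_norm_shape_inference() {
    // Build the op with dynamic batch/sequence dims and an illustrative hidden size of 768.
    const auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{-1, -1, 768});
    const auto axes = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{1});
    const auto op = std::make_shared<ov::op::v14::RMSNorm>(data, axes, 1e-6);

    // Shape inference returns the data shape unchanged.
    const std::vector<ov::PartialShape> input_shapes{{-1, -1, 768}, {1}};
    const auto output_shapes = ov::op::v14::shape_infer(op.get(), input_shapes);
    OPENVINO_ASSERT(output_shapes[0] == ov::PartialShape({-1, -1, 768}));
}
```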
97 changes: 97 additions & 0 deletions src/core/src/op/rms_norm.cpp
@@ -0,0 +1,97 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/rms_norm.hpp"

#include "itt.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/op.hpp"
#include "rms_norm_shape_inference.hpp"

namespace ov {
namespace op {
namespace v14 {

RMSNorm::RMSNorm(const Output<Node>& data,
const Output<Node>& axes,
double epsilon,
const ov::element::Type& compute_type)
: Op({data, axes}),
m_epsilon(epsilon),
m_compute_type(compute_type) {
constructor_validate_and_infer_types();
}

RMSNorm::RMSNorm(const Output<Node>& data,
const Output<Node>& axes,
const Output<Node>& scale,
double epsilon,
const ov::element::Type& compute_type)
: Op({data, axes, scale}),
m_epsilon(epsilon),
m_compute_type(compute_type) {
constructor_validate_and_infer_types();
}

bool RMSNorm::visit_attributes(ov::AttributeVisitor& visitor) {
OV_OP_SCOPE(v14_RMSNorm_visit_attributes);
visitor.on_attribute("epsilon", m_epsilon);
visitor.on_attribute("compute_type", m_compute_type);
return true;
}

void RMSNorm::validate_and_infer_types() {
OV_OP_SCOPE(v14_RMSNorm_validate_and_infer_types);

const auto& data_element_type = get_input_element_type(0);
const bool is_valid_data_type = data_element_type.is_dynamic() || data_element_type.is_real();
NODE_VALIDATION_CHECK(this,
is_valid_data_type,
"The element type of the data tensor must be a floating point type. Got: ",
data_element_type);

const auto& axes_element_type = get_input_element_type(1);
const bool is_valid_axes_type =
axes_element_type.is_dynamic() || axes_element_type == element::i32 || axes_element_type == element::i64;
NODE_VALIDATION_CHECK(this,
is_valid_axes_type,
"The element type of the axes tensor must be i32 or i64 type. Got: ",
axes_element_type);

if (get_input_size() > 2) { // Validate scale input type
auto merged_et = element::dynamic;
const auto& scale_element_type = get_input_element_type(2);
const bool is_scale_type_compatible = element::Type::merge(merged_et, data_element_type, scale_element_type);
NODE_VALIDATION_CHECK(this,
is_scale_type_compatible,
"Element type of the scale input must be the same as the data input type.");
}

const auto output_shapes = shape_infer(this, ov::util::get_node_input_partial_shapes(*this));
// Output type and shape are the same as the data input's
set_output_type(0, data_element_type, output_shapes[0]);
}

std::shared_ptr<Node> RMSNorm::clone_with_new_inputs(const ov::OutputVector& new_args) const {
OV_OP_SCOPE(v14_RMSNorm_clone_with_new_inputs);
check_new_args_count(this, new_args);
if (new_args.size() == 2) {
return std::make_shared<RMSNorm>(new_args.at(0), new_args.at(1), m_epsilon, m_compute_type);
}
return std::make_shared<RMSNorm>(new_args.at(0), new_args.at(1), new_args.at(2), m_epsilon, m_compute_type);
}

double RMSNorm::get_epsilon() const {
return m_epsilon;
}

const ov::element::Type& RMSNorm::get_compute_type() const {
return m_compute_type;
}

} // namespace v14
} // namespace op
} // namespace ov
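
The core class above carries only the attributes and the type/shape validation; the computation itself is defined in the spec PR. As a reference, a minimal scalar sketch (an illustration, not OpenVINO code) of the formula the op represents, for a 1-D input normalized over its single axis with an optional per-element scale:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// y[i] = x[i] / sqrt(mean(x^2) + epsilon) * scale[i]
std::vector<float> rms_norm_reference(const std::vector<float>& x,
                                      const std::vector<float>& scale,
                                      double epsilon) {
    // Root mean square of the input, with epsilon guarding against division by zero.
    double sum_of_squares = 0.0;
    for (const float v : x)
        sum_of_squares += static_cast<double>(v) * v;
    const double rms = std::sqrt(sum_of_squares / static_cast<double>(x.size()) + epsilon);

    // Normalize and apply the per-element scale (weights).
    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i)
        y[i] = static_cast<float>(x[i] / rms) * scale[i];
    return y;
}
```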