From 5098aca73dd25fe83dde1cab885570dbf0eebfed Mon Sep 17 00:00:00 2001 From: Katarzyna Mitrus Date: Mon, 15 Apr 2024 10:06:06 +0200 Subject: [PATCH] [Op] Add RMSNorm op core class (#23842) ### Details: - RMSNorm op core class - Registration in the opset and op check (conformance) test will be added in the next PRs Spec PR: - https://github.com/openvinotoolkit/openvino/pull/23569 ### Tickets: - 136261 --- src/core/include/openvino/op/ops.hpp | 1 + src/core/include/openvino/op/rms_norm.hpp | 57 ++++ .../include/rms_norm_shape_inference.hpp | 65 +++++ src/core/src/op/rms_norm.cpp | 97 +++++++ src/core/tests/type_prop/rms_norm.cpp | 273 ++++++++++++++++++ src/core/tests/visitors/op/rms_norm.cpp | 50 ++++ .../src/shape_inference/shape_inference.cpp | 2 + .../rms_norm_shape_inference_test.cpp | 137 +++++++++ 8 files changed, 682 insertions(+) create mode 100644 src/core/include/openvino/op/rms_norm.hpp create mode 100644 src/core/shape_inference/include/rms_norm_shape_inference.hpp create mode 100644 src/core/src/op/rms_norm.cpp create mode 100644 src/core/tests/type_prop/rms_norm.cpp create mode 100644 src/core/tests/visitors/op/rms_norm.cpp create mode 100644 src/plugins/intel_cpu/tests/unit/shape_inference_test/rms_norm_shape_inference_test.cpp diff --git a/src/core/include/openvino/op/ops.hpp b/src/core/include/openvino/op/ops.hpp index f6c91269215f8f..7a17f120f735a5 100644 --- a/src/core/include/openvino/op/ops.hpp +++ b/src/core/include/openvino/op/ops.hpp @@ -153,6 +153,7 @@ #include "openvino/op/result.hpp" #include "openvino/op/reverse.hpp" #include "openvino/op/reverse_sequence.hpp" +#include "openvino/op/rms_norm.hpp" #include "openvino/op/rnn_cell.hpp" #include "openvino/op/rnn_sequence.hpp" #include "openvino/op/roi_align.hpp" diff --git a/src/core/include/openvino/op/rms_norm.hpp b/src/core/include/openvino/op/rms_norm.hpp new file mode 100644 index 00000000000000..43bfd7e213bab0 --- /dev/null +++ b/src/core/include/openvino/op/rms_norm.hpp @@ -0,0 +1,57 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/op/op.hpp" + +namespace ov { +namespace op { +namespace v14 { +/// \brief Operator performing Root Mean Square Normalization +/// \ingroup ov_ops_cpp_api +class OPENVINO_API RMSNorm : public ov::op::Op { +public: + OPENVINO_OP("RMSNorm", "opset14", ov::op::Op); + + RMSNorm() = default; + /// \brief Constructs an RMSNorm operation without scaling. + /// + /// \param data Input tensor with data + /// \param axes Axes for reduce mean calculation + /// \param eps Epsilon for not dividing by zero while normalizing the value + /// \param compute_type Precision for the internal computation, if undefined it's the same as the input type + RMSNorm(const Output& data, + const Output& axes, + double epsilson, + const ov::element::Type& compute_type = ov::element::undefined); + + /// \brief Constructs an RMSNorm operation with scaling. + /// + /// \param data Input tensor with data + /// \param axes Axes for reduce mean calculation + /// \param scale Scale values for weight + /// \param eps Epsilon for not dividing by zero while normalizing the value + /// \param compute_type Precision for the internal computation, if undefined it's the same as the input type + RMSNorm(const Output& data, + const Output& axes, + const Output& scale, + double epsilson, + const ov::element::Type& compute_type = ov::element::undefined); + + bool visit_attributes(ov::AttributeVisitor& visitor) override; + void validate_and_infer_types() override; + std::shared_ptr clone_with_new_inputs(const ov::OutputVector& new_args) const override; + + double get_epsilon() const; + const ov::element::Type& get_compute_type() const; + +private: + double m_epsilon{0}; + ov::element::Type m_compute_type{ov::element::undefined}; +}; + +} // namespace v14 +} // namespace op +} // namespace ov diff --git a/src/core/shape_inference/include/rms_norm_shape_inference.hpp b/src/core/shape_inference/include/rms_norm_shape_inference.hpp new file mode 100644 index 00000000000000..bc03fe37f91f34 --- /dev/null +++ b/src/core/shape_inference/include/rms_norm_shape_inference.hpp @@ -0,0 +1,65 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/op/rms_norm.hpp" +#include "utils.hpp" + +namespace ov { +namespace op { +namespace v14 { +template > +std::vector shape_infer(const RMSNorm* op, + const std::vector& input_shapes, + const ITensorAccessor& tensor_accessor = make_tensor_accessor()) { + const auto inputs_count = input_shapes.size(); + const auto has_scale_input = inputs_count == 3; + NODE_SHAPE_INFER_CHECK(op, input_shapes, inputs_count == 2 || has_scale_input); + + const auto& data_shape = input_shapes[0]; + const auto& data_rank = data_shape.rank(); + const auto& axes_shape = input_shapes[1]; + const auto& axes_rank = axes_shape.rank(); + + NODE_SHAPE_INFER_CHECK(op, + input_shapes, + ov::util::is_rank_compatible_any_of(axes_rank, {0, 1}), + "Axes input must be a scalar or 1D input. Got: ", + axes_shape); + + // Further validation requires data rank to be static + if (data_rank.is_dynamic()) { + return {data_shape}; + } + + if (axes_shape.rank().is_static()) { + const bool has_axes_compatible = axes_shape.size() == 0 || axes_shape[0].is_dynamic() || + cmp::ge(data_rank.get_length(), axes_shape.get_shape()[0]); + NODE_SHAPE_INFER_CHECK(op, + input_shapes, + has_axes_compatible, + "Number of the axes can't be higher than the rank of the data shape."); + } + + if (has_scale_input) { // Validate scale input + TRShape scale_shape = input_shapes[2]; + const bool is_scale_shape_broadcastable = + TRShape::broadcast_merge_into(scale_shape, data_shape, ov::op::AutoBroadcastType::NUMPY); + NODE_SHAPE_INFER_CHECK(op, + input_shapes, + is_scale_shape_broadcastable, + "Scale input shape must be broadcastable to the shape of the data input."); + } + + // Axes values validation + if (const auto axes_val = ov::op::get_input_const_data_as(op, 1, tensor_accessor)) { + ov::util::normalize_axes(op, data_rank.get_length(), *axes_val); + } + + return {data_shape}; +} +} // namespace v14 +} // namespace op +} // namespace ov diff --git a/src/core/src/op/rms_norm.cpp b/src/core/src/op/rms_norm.cpp new file mode 100644 index 00000000000000..a249e86a6a207e --- /dev/null +++ b/src/core/src/op/rms_norm.cpp @@ -0,0 +1,97 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/op/rms_norm.hpp" + +#include "itt.hpp" +#include "openvino/core/validation_util.hpp" +#include "openvino/op/op.hpp" +#include "rms_norm_shape_inference.hpp" + +namespace ov { +namespace op { +namespace v14 { + +RMSNorm::RMSNorm(const Output& data, + const Output& axes, + double epsilson, + const ov::element::Type& compute_type) + : Op({data, axes}), + m_epsilon(epsilson), + m_compute_type(compute_type) { + constructor_validate_and_infer_types(); +} + +RMSNorm::RMSNorm(const Output& data, + const Output& axes, + const Output& scale, + double epsilson, + const ov::element::Type& compute_type) + : Op({data, axes, scale}), + m_epsilon(epsilson), + m_compute_type(compute_type) { + constructor_validate_and_infer_types(); +} + +bool RMSNorm::visit_attributes(ov::AttributeVisitor& visitor) { + OV_OP_SCOPE(v14_RMSNorm_visit_attributes); + visitor.on_attribute("epsilon", m_epsilon); + visitor.on_attribute("compute_type", m_compute_type); + return true; +} + +void RMSNorm::validate_and_infer_types() { + OV_OP_SCOPE(v14_RMSNorm_validate_and_infer_types); + + const auto& data_element_type = get_input_element_type(0); + const bool is_valid_data_type = data_element_type.is_dynamic() || data_element_type.is_real(); + NODE_VALIDATION_CHECK(this, + is_valid_data_type, + "The element type of the data tensor must be a floating point type. Got: ", + data_element_type); + + const auto& axes_element_type = get_input_element_type(1); + const bool is_valid_axes_type = + data_element_type.is_dynamic() || axes_element_type == element::i32 || axes_element_type == element::i64; + NODE_VALIDATION_CHECK(this, + is_valid_axes_type, + "The element type of the axes tensor must be i32 or i64 type. Got: ", + axes_element_type); + + if (get_input_size() > 2) { // Validate scale input type + + // Validate input types + auto merged_et = element::dynamic; + const auto& scale_element_type = get_input_element_type(2); + const bool is_scale_type_compatible = element::Type::merge(merged_et, data_element_type, scale_element_type); + NODE_VALIDATION_CHECK(this, + is_scale_type_compatible, + "Element type of the scale input must be the same as the data input type."); + } + + const auto output_shapes = shape_infer(this, ov::util::get_node_input_partial_shapes(*this)); + // Output type and shape is the same as the first input + set_output_type(0, data_element_type, output_shapes[0]); +} + +std::shared_ptr RMSNorm::clone_with_new_inputs(const ov::OutputVector& new_args) const { + OV_OP_SCOPE(v14_RMSNorm_clone_with_new_inputs); + check_new_args_count(this, new_args); + if (new_args.size() == 2) { + return std::make_shared(new_args.at(0), new_args.at(1), m_epsilon, m_compute_type); + } + return std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), m_epsilon, m_compute_type); +} + +double RMSNorm::get_epsilon() const { + return m_epsilon; +} + +const ov::element::Type& RMSNorm::get_compute_type() const { + return m_compute_type; +} + +} // namespace v14 +} // namespace op +} // namespace ov diff --git a/src/core/tests/type_prop/rms_norm.cpp b/src/core/tests/type_prop/rms_norm.cpp new file mode 100644 index 00000000000000..b24531a9c2cf23 --- /dev/null +++ b/src/core/tests/type_prop/rms_norm.cpp @@ -0,0 +1,273 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/op/rms_norm.hpp" + +#include + +#include "common_test_utils/test_assertions.hpp" +#include "common_test_utils/type_prop.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/subtract.hpp" + +namespace ov { +namespace test { + +using ov::op::v0::Constant; +using ov::op::v0::Parameter; +using testing::HasSubstr; + +class TypePropRMSNormTest : public TypePropOpTest { +public: + double eps = 1e-5; +}; + +TEST_F(TypePropRMSNormTest, default_ctor) { + const auto op = make_op(); + const auto data = std::make_shared(element::f16, PartialShape{2, 3, 8, 6}); + const auto axes = std::make_shared(element::i64, PartialShape{1}); + const auto scale = std::make_shared(element::f16, PartialShape{}); + + op->set_arguments(ov::OutputVector{data, axes, scale}); + op->validate_and_infer_types(); + + EXPECT_EQ(op->get_output_size(), 1); + EXPECT_EQ(op->get_input_size(), 3); + EXPECT_EQ(op->get_output_element_type(0), element::f16); + EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 8, 6})); +} + +TEST_F(TypePropRMSNormTest, no_scale_no_compute_type) { + const auto data = std::make_shared(element::f32, PartialShape{2, 3, 8, 6}); + const auto axes = std::make_shared(element::i32, PartialShape{1}); + + const auto op = make_op(data, axes, eps); + EXPECT_EQ(op->get_input_size(), 2); + EXPECT_EQ(op->get_output_size(), 1); + EXPECT_EQ(op->get_output_element_type(0), element::f32); + EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 8, 6})); + EXPECT_EQ(op->get_epsilon(), eps); +} + +TEST_F(TypePropRMSNormTest, scale_no_compute_type) { + const auto data = std::make_shared(element::f16, PartialShape{2, 3, 8, 6}); + const auto axes = std::make_shared(element::i32, PartialShape{1}); + const auto scale = std::make_shared(element::f16, PartialShape{}); + + const auto op = make_op(data, axes, scale, eps); + EXPECT_EQ(op->get_input_size(), 3); + EXPECT_EQ(op->get_output_size(), 1); + EXPECT_EQ(op->get_output_element_type(0), element::f16); + EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 8, 6})); + EXPECT_EQ(op->get_epsilon(), eps); +} + +TEST_F(TypePropRMSNormTest, scale_compute_type) { + const auto data = std::make_shared(element::f16, PartialShape{2, 3, 8, 6}); + const auto axes = std::make_shared(element::i32, PartialShape{1}); + const auto scale = std::make_shared(element::f16, PartialShape{}); + const auto compute_type = element::f32; + + const auto op = make_op(data, axes, scale, eps, compute_type); + EXPECT_EQ(op->get_input_size(), 3); + EXPECT_EQ(op->get_output_size(), 1); + EXPECT_EQ(op->get_output_element_type(0), element::f16); + EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 8, 6})); + EXPECT_EQ(op->get_epsilon(), eps); + EXPECT_EQ(op->get_compute_type(), compute_type); +} + +TEST_F(TypePropRMSNormTest, scale_compute_type_no_scale) { + const auto data = std::make_shared(element::f16, PartialShape{2, 3, 8, 6}); + const auto axes = std::make_shared(element::i32, PartialShape{1}); + const auto compute_type = element::f32; + + const auto op = make_op(data, axes, eps, compute_type); + EXPECT_EQ(op->get_output_size(), 1); + EXPECT_EQ(op->get_output_element_type(0), element::f16); + EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 8, 6})); +} + +TEST_F(TypePropRMSNormTest, dynamic_data_shape) { + const auto data = std::make_shared(element::f16, PartialShape{-1, {3, 4}, {8, -1}, 6}); + const auto axes = std::make_shared(element::i32, PartialShape{1}); + const auto scale = std::make_shared(element::f16, PartialShape{}); + const auto compute_type = element::f32; + + const auto op = make_op(data, axes, scale, eps, compute_type); + EXPECT_EQ(op->get_output_element_type(0), element::f16); + EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{-1, {3, 4}, {8, -1}, 6})); +} + +TEST_F(TypePropRMSNormTest, dynamic_data_shape_rank) { + const auto data = std::make_shared(element::f16, PartialShape::dynamic()); + const auto axes = std::make_shared(element::i32, PartialShape{1}); + const auto scale = std::make_shared(element::f16, PartialShape{}); + const auto compute_type = element::f32; + + const auto op = make_op(data, axes, scale, eps, compute_type); + EXPECT_EQ(op->get_output_element_type(0), element::f16); + EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape::dynamic())); +} + +TEST_F(TypePropRMSNormTest, propagate_symbols) { + auto data_shape = PartialShape{-1, {3, 4}, {8, -1}, 6}; + set_shape_symbols(data_shape); + const auto exp_symbols = get_shape_symbols(data_shape); + + const auto data = std::make_shared(element::f16, data_shape); + const auto axes = std::make_shared(element::i32, PartialShape{1}); + const auto scale = std::make_shared(element::f16, PartialShape{}); + const auto compute_type = element::f32; + + const auto op = make_op(data, axes, scale, eps, compute_type); + EXPECT_EQ(get_shape_symbols(op->get_output_partial_shape(0)), exp_symbols); +} + +TEST_F(TypePropRMSNormTest, incorrect_input_type) { + const auto data = std::make_shared(element::f16, PartialShape::dynamic()); + const auto axes = std::make_shared(element::i32, PartialShape{1}); + const auto scale = std::make_shared(element::f16, PartialShape{}); + const auto compute_type = element::f32; + { + const auto data_int = std::make_shared(element::i32, PartialShape::dynamic()); + OV_EXPECT_THROW(std::ignore = make_op(data_int, axes, scale, eps, compute_type), + ov::NodeValidationFailure, + HasSubstr("The element type of the data tensor must be a floating point type")); + } + { + const auto axes_float = std::make_shared(element::f32, PartialShape{1}); + OV_EXPECT_THROW(std::ignore = make_op(data, axes_float, scale, eps, compute_type), + ov::NodeValidationFailure, + HasSubstr("The element type of the axes tensor must be i32 or i64 type")); + } + { + const auto scale_incompatible = std::make_shared(element::f32, PartialShape{1}); + OV_EXPECT_THROW(std::ignore = make_op(data, axes, scale_incompatible, eps, compute_type), + ov::NodeValidationFailure, + HasSubstr("Element type of the scale input must be the same as the data input type")); + } +} + +TEST_F(TypePropRMSNormTest, incompatible_axes_shape) { + const auto data = std::make_shared(element::f16, PartialShape{2, 3, 8}); + const auto scale = std::make_shared(element::f16, PartialShape{}); + const auto compute_type = element::f32; + { + const auto axes = std::make_shared(element::i32, PartialShape{1, 2}); + OV_EXPECT_THROW(std::ignore = make_op(data, axes, scale, eps, compute_type), + ov::NodeValidationFailure, + HasSubstr("Axes input must be a scalar or 1D input. Got: [1,2]")); + } + { + const auto axes = std::make_shared(element::i32, PartialShape{4}); + OV_EXPECT_THROW(std::ignore = make_op(data, axes, scale, eps, compute_type), + ov::NodeValidationFailure, + HasSubstr("Number of the axes can't be higher than the rank of the data shape")); + } +} + +TEST_F(TypePropRMSNormTest, constant_axes_val_data_dyn_rank) { + const auto data = std::make_shared(element::f16, PartialShape::dynamic()); + const auto axes = std::make_shared(element::i32, Shape{}, 1); + const auto op = make_op(data, axes, eps); + + EXPECT_EQ(op->get_output_element_type(0), element::f16); + EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape::dynamic())); +} + +TEST_F(TypePropRMSNormTest, constant_axes_val_data_static_rank) { + const auto data = std::make_shared(element::f16, PartialShape{2, 3, 8}); + const auto axes = std::make_shared(element::i32, Shape{}, 1); + const auto op = make_op(data, axes, eps); + + EXPECT_EQ(op->get_output_element_type(0), element::f16); + EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 8})); +} + +TEST_F(TypePropRMSNormTest, axes_val_as_shape_of) { + const auto data = std::make_shared(element::f16, PartialShape{2, 3, 8}); + const auto data_rank = std::make_shared(std::make_shared(data)); + const auto axes = + std::make_shared(data_rank, std::make_shared(element::i64, Shape{}, 1)); + const auto op = make_op(data, axes, eps); + + EXPECT_EQ(op->get_output_element_type(0), element::f16); + EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 8})); +} + +TEST_F(TypePropRMSNormTest, incorrect_axes_val) { + const auto data = std::make_shared(element::f16, PartialShape{2, 3, 8}); + { + const auto axes = std::make_shared(element::i32, Shape{}, 3); + OV_EXPECT_THROW(std::ignore = make_op(data, axes, eps), + ov::NodeValidationFailure, + HasSubstr("Parameter axis 3 out of the tensor rank range [-3, 2]")); + } + { + const auto axes = std::make_shared(element::i32, Shape{}, -4); + OV_EXPECT_THROW(std::ignore = make_op(data, axes, eps), + ov::NodeValidationFailure, + HasSubstr("Parameter axis -4 out of the tensor rank range [-3, 2]")); + } +} + +using RMSNormTestParam = std::tuple; +class TypePropRMSNormTestP : public TypePropRMSNormTest, public testing::WithParamInterface { +protected: + void SetUp() override { + std::tie(shape_data, shape_scale) = GetParam(); + } + PartialShape shape_data, shape_scale; +}; + +INSTANTIATE_TEST_SUITE_P(type_prop_rms_scale_shape, + TypePropRMSNormTestP, + testing::Values(std::make_tuple(PartialShape{-1, 3, 1, 2}, PartialShape{-1}), + std::make_tuple(PartialShape{-1, 3, 1, 2}, PartialShape{}), + std::make_tuple(PartialShape{-1, 3, 1, 2}, PartialShape{1}), + std::make_tuple(PartialShape{-1, 3, 1, 2}, PartialShape{2}), + std::make_tuple(PartialShape{-1, 3, 1, 2}, PartialShape{1, 1}), + std::make_tuple(PartialShape{-1, 3, 1, 2}, PartialShape{1, 2}), + std::make_tuple(PartialShape{-1, 3, 1, 2}, PartialShape{3, 1, 2}), + std::make_tuple(PartialShape{-1, 4, 8, 6}, PartialShape{1, 4, 1, 1}), + std::make_tuple(PartialShape{2, 4, 8, 6}, PartialShape{2, 4, 8, 6}), + std::make_tuple(PartialShape{2, 4, 8, 6}, PartialShape{1, 4, 1, 1}), + std::make_tuple(PartialShape{2, 4, 8, 6}, PartialShape{1, 1, 1, 1}), + std::make_tuple(PartialShape{2, 4, 8, 6}, PartialShape::dynamic()), + std::make_tuple(PartialShape::dynamic(), PartialShape{1}), + std::make_tuple(PartialShape::dynamic(), PartialShape::dynamic())), + testing::PrintToStringParamName()); + +TEST_P(TypePropRMSNormTestP, scale_shape) { + const auto data = std::make_shared(element::f16, shape_data); + const auto axes = std::make_shared(element::i32, PartialShape{1}); + + const auto scale = std::make_shared(element::f16, shape_scale); + const auto op = make_op(data, axes, scale, eps); + + EXPECT_EQ(op->get_output_partial_shape(0), shape_data); +} + +TEST_F(TypePropRMSNormTest, scale_incompatible_shape) { + const auto data = std::make_shared(element::f16, PartialShape{-1, 3, 8, 6}); + const auto axes = std::make_shared(element::i32, PartialShape{1}); + const auto compute_type = element::f32; + { + const auto scale = std::make_shared(element::f16, PartialShape{8}); + OV_EXPECT_THROW(std::ignore = make_op(data, axes, scale, eps, compute_type), + ov::NodeValidationFailure, + HasSubstr("Scale input shape must be broadcastable to the shape of the data input")); + } + { + const auto scale = std::make_shared(element::f16, PartialShape{6, 1}); + OV_EXPECT_THROW(std::ignore = make_op(data, axes, scale, eps, compute_type), + ov::NodeValidationFailure, + HasSubstr("Scale input shape must be broadcastable to the shape of the data input")); + } +} + +} // namespace test +} // namespace ov diff --git a/src/core/tests/visitors/op/rms_norm.cpp b/src/core/tests/visitors/op/rms_norm.cpp new file mode 100644 index 00000000000000..ac0d191d1e6dfb --- /dev/null +++ b/src/core/tests/visitors/op/rms_norm.cpp @@ -0,0 +1,50 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/op/rms_norm.hpp" + +#include + +#include "visitors/visitors.hpp" + +using ov::PartialShape; +using ov::op::v0::Parameter; +using ov::test::NodeBuilder; + +TEST(attributes, rms_norm_v14_attr_comp_type_default) { + using ov::op::v14::RMSNorm; + NodeBuilder::opset().insert(); + + const auto data = std::make_shared(ov::element::f16, PartialShape{2, 3, 8, 6}); + const auto axes = std::make_shared(ov::element::i32, PartialShape{1}); + const auto eps = 1e-5f; + + const auto op = std::make_shared(data, axes, eps); + + NodeBuilder builder(op, {data, axes}); + auto g_op = ov::as_type_ptr(builder.create()); + + EXPECT_EQ(g_op->get_compute_type(), op->get_compute_type()); + EXPECT_EQ(g_op->get_output_element_type(0), op->get_output_element_type(0)); + EXPECT_EQ(g_op->get_output_partial_shape(0), op->get_output_partial_shape(0)); +} + +TEST(attributes, rms_norm_v14_attr_comp_type_custom) { + using ov::op::v14::RMSNorm; + NodeBuilder::opset().insert(); + + const auto data = std::make_shared(ov::element::f16, PartialShape{2, 3, 8, 6}); + const auto axes = std::make_shared(ov::element::i32, PartialShape{1}); + const auto eps = 1e-5f; + const auto compute_type = ov::element::f32; + + const auto op = std::make_shared(data, axes, eps, compute_type); + + NodeBuilder builder(op, {data, axes}); + auto g_op = ov::as_type_ptr(builder.create()); + + EXPECT_EQ(g_op->get_compute_type(), op->get_compute_type()); + EXPECT_EQ(g_op->get_output_element_type(0), op->get_output_element_type(0)); + EXPECT_EQ(g_op->get_output_partial_shape(0), op->get_output_partial_shape(0)); +} diff --git a/src/plugins/intel_cpu/src/shape_inference/shape_inference.cpp b/src/plugins/intel_cpu/src/shape_inference/shape_inference.cpp index 8f05876ce219b7..b3588f8bffbd47 100644 --- a/src/plugins/intel_cpu/src/shape_inference/shape_inference.cpp +++ b/src/plugins/intel_cpu/src/shape_inference/shape_inference.cpp @@ -90,6 +90,7 @@ #include "reshape_shape_inference.hpp" #include "reverse_sequence_shape_inference.hpp" #include "reverse_shape_inference.hpp" +#include "rms_norm_shape_inference.hpp" #include "rnn_cell_shape_inference.hpp" #include "rnn_sequence_shape_inference.hpp" #include "roi_align_shape_inference.hpp" @@ -399,6 +400,7 @@ using IStaticShapeInferFactory = template <> const IStaticShapeInferFactory::TRegistry IStaticShapeInferFactory::registry{ // opset14 + _OV_OP_SHAPE_INFER_MASK_REG(op::v14::RMSNorm, ShapeInferTA, util::bit::mask(1)), _OV_OP_SHAPE_INFER_MASK_REG(opset14::Inverse, ShapeInferTA, util::bit::mask()), // opset13 _OV_OP_SHAPE_INFER_MASK_REG(opset13::Multinomial, ShapeInferTA, util::bit::mask(1)), diff --git a/src/plugins/intel_cpu/tests/unit/shape_inference_test/rms_norm_shape_inference_test.cpp b/src/plugins/intel_cpu/tests/unit/shape_inference_test/rms_norm_shape_inference_test.cpp new file mode 100644 index 00000000000000..cb3f346ec98c6f --- /dev/null +++ b/src/plugins/intel_cpu/tests/unit/shape_inference_test/rms_norm_shape_inference_test.cpp @@ -0,0 +1,137 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "common_test_utils/test_assertions.hpp" +#include "utils.hpp" + +using namespace ov; +using namespace ov::intel_cpu; +using ov::op::v0::Constant; +using ov::op::v0::Parameter; +using testing::HasSubstr; + +TEST(StaticShapeInferenceTest, RMSNormStaticShapeInferenceTestDefaultCtor) { + const auto op = std::make_shared(); + const auto data = std::make_shared(element::f16, PartialShape::dynamic()); + const auto axes = std::make_shared(element::i64, PartialShape::dynamic()); + const auto scale = std::make_shared(element::f16, PartialShape::dynamic()); + + op->set_arguments(ov::OutputVector{data, axes, scale}); + + std::vector static_input_shapes = {StaticShape{2, 3, 8, 6}, StaticShape{1}, StaticShape{1}}; + int32_t axis_val = -1; + const auto const_data = std::unordered_map{{1, {element::i32, Shape{1}, &axis_val}}}; + const auto static_output_shapes = shape_inference(op.get(), static_input_shapes, const_data); + EXPECT_EQ(static_output_shapes[0], StaticShape({2, 3, 8, 6})); +} + +TEST(StaticShapeInferenceTest, RMSNormStaticShapeInferenceTest2ins) { + const auto data = std::make_shared(element::f32, PartialShape::dynamic()); + const auto axes = std::make_shared(element::i32, PartialShape::dynamic()); + const auto eps = 1e-5f; + + const auto op = std::make_shared(data, axes, eps); + + std::vector static_input_shapes = {StaticShape{2, 3, 8, 6}, StaticShape{1}}; + int32_t axis_val = -1; + const auto const_data = std::unordered_map{{1, {element::i32, Shape{1}, &axis_val}}}; + const auto static_output_shapes = shape_inference(op.get(), static_input_shapes, const_data); + EXPECT_EQ(static_output_shapes[0], StaticShape({2, 3, 8, 6})); +} + +TEST(StaticShapeInferenceTest, RMSNormStaticShapeInferenceTest3ins) { + const auto data = std::make_shared(element::f32, PartialShape::dynamic()); + const auto axes = std::make_shared(element::i32, PartialShape::dynamic()); + const auto scale = std::make_shared(element::f32, PartialShape::dynamic()); + const auto eps = 1e-5f; + + const auto op = std::make_shared(data, axes, scale, eps); + + std::vector static_input_shapes = {StaticShape{2, 3, 8, 6}, StaticShape{1}, StaticShape{1}}; + int32_t axis_val = -1; + const auto const_data = std::unordered_map{{1, {element::i32, Shape{1}, &axis_val}}}; + const auto static_output_shapes = shape_inference(op.get(), static_input_shapes, const_data); + EXPECT_EQ(static_output_shapes[0], StaticShape({2, 3, 8, 6})); +} + +TEST(StaticShapeInferenceTest, RMSNormIncorrectAxisValParam) { + const auto data = std::make_shared(element::f32, PartialShape::dynamic()); + const auto axes = std::make_shared(element::i32, PartialShape::dynamic()); + const auto eps = 1e-5f; + + const auto op = std::make_shared(data, axes, eps); + + std::vector static_input_shapes = {StaticShape{2, 3, 8, 6}, StaticShape{1}}; + int32_t axis_val = 5; + const auto const_data = std::unordered_map{{1, {element::i32, Shape{1}, &axis_val}}}; + + OV_EXPECT_THROW(shape_inference(op.get(), static_input_shapes, const_data), + NodeValidationFailure, + HasSubstr("Parameter axis 5 out of the tensor rank range [-4, 3]")); +} + +TEST(StaticShapeInferenceTest, RMSNormIncorrectAxisValConst) { + const auto data = std::make_shared(element::f32, PartialShape::dynamic()); + const auto axes = std::make_shared(element::i32, Shape{}, 5); + const auto eps = 1e-5f; + + const auto op = std::make_shared(data, axes, eps); + + std::vector static_input_shapes = {StaticShape{2, 3, 8, 6}, StaticShape{}}; + + OV_EXPECT_THROW(shape_inference(op.get(), static_input_shapes), + NodeValidationFailure, + HasSubstr("Parameter axis 5 out of the tensor rank range [-4, 3]")); +} + +TEST(StaticShapeInferenceTest, RMSNormIncorrectAxisShapeDim) { + const auto data = std::make_shared(element::f32, PartialShape::dynamic()); + const auto axes = std::make_shared(element::i32, PartialShape::dynamic()); + const auto eps = 1e-5f; + + const auto op = std::make_shared(data, axes, eps); + + std::vector static_input_shapes = {StaticShape{2, 3, 8, 6}, StaticShape{5}}; + int32_t axis_val = 5; + const auto const_data = std::unordered_map{{1, {element::i32, Shape{1}, &axis_val}}}; + + OV_EXPECT_THROW(shape_inference(op.get(), static_input_shapes, const_data), + NodeValidationFailure, + HasSubstr("Number of the axes can't be higher than the rank of the data shape")); +} + +TEST(StaticShapeInferenceTest, RMSNormIncorrectAxisShapeRank) { + const auto data = std::make_shared(element::f32, PartialShape::dynamic()); + const auto axes = std::make_shared(element::i32, PartialShape::dynamic()); + const auto eps = 1e-5f; + + const auto op = std::make_shared(data, axes, eps); + + std::vector static_input_shapes = {StaticShape{2, 3, 8, 6}, StaticShape{1, 5}}; + int32_t axis_val = 5; + const auto const_data = std::unordered_map{{1, {element::i32, Shape{1}, &axis_val}}}; + + OV_EXPECT_THROW(shape_inference(op.get(), static_input_shapes, const_data), + NodeValidationFailure, + HasSubstr("Axes input must be a scalar or 1D input. Got: {1,5}")); +} + +TEST(StaticShapeInferenceTest, RMSNormIncorrectScaleShape) { + const auto data = std::make_shared(element::f32, PartialShape::dynamic()); + const auto axes = std::make_shared(element::i32, PartialShape::dynamic()); + const auto scale = std::make_shared(element::f32, PartialShape::dynamic()); + const auto eps = 1e-5f; + + const auto op = std::make_shared(data, axes, scale, eps); + + std::vector static_input_shapes = {StaticShape{2, 3, 8, 6}, StaticShape{1}, StaticShape{6, 1}}; + int32_t axis_val = -1; + const auto const_data = std::unordered_map{{1, {element::i32, Shape{1}, &axis_val}}}; + + OV_EXPECT_THROW(shape_inference(op.get(), static_input_shapes, const_data), + NodeValidationFailure, + HasSubstr("Scale input shape must be broadcastable to the shape of the data input")); +}