From 5a119fb2498f798571d58b0cb21bb8ede8bcf271 Mon Sep 17 00:00:00 2001
From: Maxim Vafin <maxim.vafin@intel.com>
Date: Wed, 14 Aug 2024 19:10:21 +0200
Subject: [PATCH] [PT FE] Use GroupNormalization for aten::group_norm instead
 of MVN (#26063)

### Details:
 - *Use GroupNormalization for `aten::group_norm` instead of MVN*

### Tickets:
 - *ticket-id*
---
 src/frontends/pytorch/src/op/group_norm.cpp | 51 ++++++++++-----------
 1 file changed, 24 insertions(+), 27 deletions(-)

diff --git a/src/frontends/pytorch/src/op/group_norm.cpp b/src/frontends/pytorch/src/op/group_norm.cpp
index 2ab51533c42288..42daa0f63f20d7 100644
--- a/src/frontends/pytorch/src/op/group_norm.cpp
+++ b/src/frontends/pytorch/src/op/group_norm.cpp
@@ -3,15 +3,13 @@
 //
 
 #include "openvino/frontend/pytorch/node_context.hpp"
-#include "openvino/op/add.hpp"
+#include "openvino/op/broadcast.hpp"
 #include "openvino/op/constant.hpp"
-#include "openvino/op/multiply.hpp"
-#include "openvino/op/mvn.hpp"
-#include "openvino/op/range.hpp"
-#include "openvino/op/reshape.hpp"
-#include "openvino/op/subtract.hpp"
+#include "openvino/op/convert_like.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/op/group_normalization.hpp"
+#include "openvino/op/shape_of.hpp"
 #include "openvino/op/unsqueeze.hpp"
-#include "openvino/op/util/framework_node.hpp"
 #include "utils.hpp"
 
 namespace ov {
@@ -33,30 +31,29 @@ OutputVector translate_group_norm_common(const NodeContext& context,
     auto data = context.get_input(0);
     auto num_groups = context.const_input<int64_t>(group_idx);
     // input 2 - weights and input 3 - bias are optional without default value, we handle them later
-    auto eps = static_cast<float>(context.const_input<double>(eps_idx));
-    Output<Node> input_shape;
-    Output<Node> input_rank;
-    std::tie(input_shape, input_rank) = get_shape_rank(context, data, true, element::i32);
-    auto scalar_one = context.mark_node(v0::Constant::create(element::i32, {}, {1}));
-    auto shape = context.mark_node(
-        std::make_shared<v0::Constant>(element::i32, Shape({3}), std::vector<int64_t>{0, num_groups, -1}));
-    auto reshaped_input = context.mark_node(std::make_shared<v1::Reshape>(data, shape, true));
-    auto reduction_axes = context.mark_node(v0::Constant::create(element::i32, Shape({1}), std::vector<int64_t>(1, 2)));
-    auto reshaped_norm = context.mark_node(
-        std::make_shared<v6::MVN>(reshaped_input, reduction_axes, true, eps, MVNEpsMode::INSIDE_SQRT));
-    auto norm = context.mark_node(std::make_shared<v1::Reshape>(reshaped_norm, input_shape, true));
-    auto skip_last = context.mark_node(std::make_shared<v1::Subtract>(input_rank, scalar_one));
-    auto axes = context.mark_node(std::make_shared<v4::Range>(scalar_one, skip_last, scalar_one, element::i32));
+    auto eps = context.const_input<double>(eps_idx);
+
+    auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
+    auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
+    auto shape = context.mark_node(std::make_shared<v3::ShapeOf>(data, element::i32));
+    auto channels = context.mark_node(std::make_shared<v8::Gather>(shape, one, zero));
+    channels = context.mark_node(std::make_shared<v0::Unsqueeze>(channels, zero));
+
+    Output<Node> scale;
     if (!context.input_is_none(weights_idx)) {
-        auto weights = context.get_input(static_cast<int>(weights_idx));
-        weights = context.mark_node(std::make_shared<v0::Unsqueeze>(weights, axes));
-        norm = context.mark_node(std::make_shared<v1::Multiply>(norm, weights));
+        scale = context.get_input(static_cast<int>(weights_idx));
+    } else {
+        scale = context.mark_node(std::make_shared<v3::Broadcast>(one, channels));
+        scale = context.mark_node(std::make_shared<v1::ConvertLike>(scale, data));
     }
+    Output<Node> bias;
     if (!context.input_is_none(bias_idx)) {
-        auto bias = context.get_input(static_cast<int>(bias_idx));
-        bias = context.mark_node(std::make_shared<v0::Unsqueeze>(bias, axes));
-        norm = context.mark_node(std::make_shared<v1::Add>(norm, bias));
+        bias = context.get_input(static_cast<int>(bias_idx));
+    } else {
+        bias = context.mark_node(std::make_shared<v3::Broadcast>(zero, channels));
+        bias = context.mark_node(std::make_shared<v1::ConvertLike>(bias, data));
     }
+    auto norm = context.mark_node(std::make_shared<v12::GroupNormalization>(data, scale, bias, num_groups, eps));
     // Input with index 5 is flag "cudnn_enabled" we can ignore it
     return {norm};
 };