diff --git a/src/common/low_precision_transformations/include/low_precision/fold_convert.hpp b/src/common/low_precision_transformations/include/low_precision/fold_convert.hpp index 640cdda59e6947..d772121a841610 100644 --- a/src/common/low_precision_transformations/include/low_precision/fold_convert.hpp +++ b/src/common/low_precision_transformations/include/low_precision/fold_convert.hpp @@ -15,6 +15,7 @@ namespace low_precision { /** * @ingroup ie_transformation_common_api * @brief FoldConvertTransformation evaluates Convert operation on Subtract constant subgraph. + * Important notice: this transformation ignores DisableConstantFolding runtime attribute. * * For more details about the transformation, refer to * [FoldConvertTransformation](@ref openvino_docs_OV_UG_lpt_FoldConvertTransformation) page diff --git a/src/plugins/intel_cpu/src/graph_optimizer.cpp b/src/plugins/intel_cpu/src/graph_optimizer.cpp index 5bf9fb6247a4fa..6b7eb0ddf11d59 100644 --- a/src/plugins/intel_cpu/src/graph_optimizer.cpp +++ b/src/plugins/intel_cpu/src/graph_optimizer.cpp @@ -325,16 +325,22 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) { const auto mulParent = multiplyNode->getParentEdgesAtPort(0)[0]->getParent(); const bool withSubtract = mulParent->getAlgorithm() == Algorithm::EltwiseSubtract; - NodePtr subtractNode, subtractConstNode; + NodePtr subtractNode, subtractConvertNode, subtractConstNode; if (withSubtract) { subtractNode = mulParent; if (!expectedNode(subtractNode, Type::Eltwise)) continue; - subtractConstNode = subtractNode->getParentEdgesAtPort(1)[0]->getParent(); + auto subtractParent = subtractNode->getParentEdgesAtPort(1)[0]->getParent(); + if (expectedNode(subtractParent, Type::Convert)) { + subtractConvertNode = subtractParent; + subtractParent = subtractConvertNode->getParentEdgesAtPort(0)[0]->getParent(); + } + subtractConstNode = subtractParent; if (!expectedNode(subtractConstNode, Type::Input)) continue; } + const bool withSubtractConvert = subtractConvertNode != nullptr; const bool withPowerStatic = mulParent->getAlgorithm() == Algorithm::EltwisePowerStatic; NodePtr powerStaticNode; if (withPowerStatic) { @@ -364,12 +370,6 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) { continue; // Precision limitations - if (multiplyConstNode->getOriginalOutputPrecisionAtPort(0) != Precision::FP32) - continue; - if (withSubtract && subtractConstNode->getOriginalOutputPrecisionAtPort(0) != Precision::FP32) - continue; - if (withPowerStatic && powerStaticNode->getOriginalOutputPrecisionAtPort(0) != Precision::FP32) - continue; if (supportedDataPrecisions.find(fcNode->getOriginalInputPrecisionAtPort(0)) == supportedDataPrecisions.end()) continue; if (supportedWeightsPrecisions.find(weightsNode->getOriginalOutputPrecisionAtPort(0)) == supportedWeightsPrecisions.end()) @@ -403,9 +403,17 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) { decompressionConstShape = withTranspose ? VectorDims{N, 1, O} : VectorDims{O, N, 1}; groupNum = N; } - if (multiplyConstNode->getOutputShapeAtPort(0).getDims() != decompressionConstShape) + + auto check_decompression_shape = [&decompressionConstShape](const VectorDims& shape_to_check) { + if (shape_to_check.size() > decompressionConstShape.size()) + return false; + const auto comparison_start_pos = decompressionConstShape.size() - shape_to_check.size(); + // in case of different ranks shapes are compared taking into account ranks numpy broadcasting + return std::equal(shape_to_check.begin(), shape_to_check.end(), decompressionConstShape.begin() + comparison_start_pos); + }; + if (!check_decompression_shape(multiplyConstNode->getOutputShapeAtPort(0).getDims())) continue; - if (withSubtract && subtractConstNode->getOutputShapeAtPort(0).getDims() != decompressionConstShape) + if (withSubtract && !check_decompression_shape(subtractConstNode->getOutputShapeAtPort(0).getDims())) continue; // HW specific shape limitations @@ -460,6 +468,11 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) { fcNode->addOriginalLayer(multiplyNode->getOriginalLayers()); fcNode->addOriginalLayer(convertNode->getOriginalLayers()); + if (withSubtractConvert) { + fcNode->addOriginalLayer(subtractConvertNode->getOriginalLayers()); + auto subtractConvertEdge = subtractConvertNode->getChildEdges()[0].lock(); + graph.RemoveEdge(subtractConvertEdge); + } if (withSubtract) { fcNode->addOriginalLayer(subtractNode->getOriginalLayers()); auto subtractConstEdge = subtractConstNode->getChildEdges()[0].lock(); @@ -473,6 +486,8 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) { graph.RemoveEdge(multiplyConstEdge); graph.DropNode(convertNode); + if (withSubtractConvert) + graph.DropNode(subtractConvertNode); if (withSubtract) graph.DropNode(subtractNode); if (withPowerStatic) diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/move_fc_reshape_to_weights.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/move_fc_reshape_to_weights.cpp index 017d06fb57fedc..e9dbd8bb613759 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/move_fc_reshape_to_weights.cpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/move_fc_reshape_to_weights.cpp @@ -21,21 +21,27 @@ ov::intel_cpu::MoveFCReshapeToWeights::MoveFCReshapeToWeights() { MATCHER_SCOPE(MoveFCReshapeToWeights); using namespace ov::pass::pattern; auto weights_m = wrap_type(consumers_count(1)); - auto convert_m = wrap_type({weights_m}); + auto convert_m = wrap_type({weights_m}, consumers_count(1)); + + auto one_consumer_rank_equals = [](const ov::Dimension& expected_rank) { + return [=](ov::Output output) -> bool { + return consumers_count(1)(output) && rank_equals(expected_rank)(output); + }; + }; auto sub_const_m = wrap_type(consumers_count(1)); - auto subtract_m = wrap_type({convert_m, sub_const_m}); + auto subtract_wo_convert_m = wrap_type({convert_m, sub_const_m}, consumers_count(1)); + auto sub_convert = wrap_type({sub_const_m}, consumers_count(1)); + auto subtract_w_convert_m = wrap_type({convert_m, sub_convert}, consumers_count(1)); + auto subtract_m = std::make_shared(OutputVector{subtract_wo_convert_m, subtract_w_convert_m}); auto mul_const_m = wrap_type(consumers_count(1)); - auto mul_with_sub_m = wrap_type({subtract_m, mul_const_m}, rank_equals(3)); - auto mul_no_sub_m = wrap_type({convert_m, mul_const_m}, rank_equals(3)); + auto mul_with_sub_m = wrap_type({subtract_m, mul_const_m}, one_consumer_rank_equals(3)); + auto mul_no_sub_m = wrap_type({convert_m, mul_const_m}, one_consumer_rank_equals(3)); auto mul_m = std::make_shared(OutputVector{mul_with_sub_m, mul_no_sub_m}); - auto one_consumer_rank_2 = [](const ov::Output& out) { - return consumers_count(1)(out) && rank_equals(2)(out); - }; auto reshape_const_m = wrap_type(consumers_count(1)); - auto reshape_m = wrap_type({mul_m, reshape_const_m}, one_consumer_rank_2); + auto reshape_m = wrap_type({mul_m, reshape_const_m}, one_consumer_rank_equals(2)); auto transpose_const_m = wrap_type(); auto transpose_m = wrap_type({reshape_m, transpose_const_m}); @@ -58,21 +64,24 @@ ov::intel_cpu::MoveFCReshapeToWeights::MoveFCReshapeToWeights() { const auto& fc_input_shape = fully_connected->get_input_shape(1); const auto reshape = with_transpose ? weights_path->get_input_node_shared_ptr(0) : weights_path; - auto check_decompression_const = [&](const std::shared_ptr& node) { - if (!ov::is_type(node)) - return false; + auto check_decompression_shape = [&](const std::shared_ptr& node) { ov::Shape expected_shape(3, 1); const size_t out_channels_idx = with_transpose ? 2 : 1; expected_shape[out_channels_idx] = fc_input_shape[0]; - return node->get_output_shape(0) == expected_shape; + const auto& node_shape = node->get_output_shape(0); + if (node_shape.size() > expected_shape.size()) + return false; + + const auto comparison_start_pos = expected_shape.size() - node_shape.size(); + return std::equal(node_shape.begin(), node_shape.end(), expected_shape.begin() + comparison_start_pos); }; const auto mul = reshape->get_input_node_shared_ptr(0); - if (!check_decompression_const(mul->get_input_node_shared_ptr(1))) + if (!check_decompression_shape(mul->get_input_node_shared_ptr(1))) return false; const auto mul_parent = mul->get_input_node_shared_ptr(0); const bool with_subtract = ov::is_type(mul_parent); - if (with_subtract && !check_decompression_const(mul_parent->get_input_node_shared_ptr(1))) + if (with_subtract && !check_decompression_shape(mul_parent->get_input_node_shared_ptr(1))) return false; const auto convert = with_subtract ? mul_parent->get_input_node_shared_ptr(0) : mul_parent; @@ -83,22 +92,29 @@ ov::intel_cpu::MoveFCReshapeToWeights::MoveFCReshapeToWeights() { if (weights->get_output_shape(0) != expected_weights_shape) return false; - auto squeeze_constant = [](const std::shared_ptr& node) { + auto squeeze_constant = [&](const std::shared_ptr& node) { const auto constant = ov::as_type_ptr(node); + OPENVINO_ASSERT(constant, "squeeze_constant is called for non constant node"); auto shape = constant->get_shape(); - shape.erase(shape.begin()); - const auto new_constant = std::make_shared(*constant, shape); - ov::replace_node(constant, new_constant); - ov::copy_runtime_info(constant, new_constant); - new_constant->set_friendly_name(constant->get_friendly_name()); + if (shape.size() - fc_input_shape.size() == 1) { + shape.erase(shape.begin()); + const auto new_constant = std::make_shared(*constant, shape); + ov::replace_node(constant, new_constant); + ov::copy_runtime_info(constant, new_constant); + new_constant->set_friendly_name(constant->get_friendly_name()); + } }; // We can remove 3D->2D reshape if we manually reshape all constants in the weights subgraph ov::replace_output_update_name(reshape->output(0), reshape->input_value(0)); squeeze_constant(mul->get_input_node_shared_ptr(1)); squeeze_constant(weights); - if (with_subtract) - squeeze_constant(mul_parent->get_input_node_shared_ptr(1)); + if (with_subtract) { + auto sub_const = mul_parent->get_input_node_shared_ptr(1); + if (ov::is_type(sub_const)) + sub_const = sub_const->get_input_node_shared_ptr(0); + squeeze_constant(sub_const); + } return true; }; diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/move_fc_reshape_to_weights.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/move_fc_reshape_to_weights.hpp index ffe7e7a6f44e34..7856f04022332b 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/move_fc_reshape_to_weights.hpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/move_fc_reshape_to_weights.hpp @@ -13,9 +13,9 @@ namespace intel_cpu { * This transformation is applied to the FC with compressed 3D u8 weights. It moves Reshape at the weights path to the constants * in order to constant fold the Reshape node. * Example: - * Weights(3D) Weights(2D) - * | | - * Convert Subtract_const(3D) Convert Subtract_const(2D) + * Weights(3D) Subtract_const(3D) Weights(2D) Subtract_const(2D) + * | / | / + * Convert Subtract_convert(opt) Convert Subtract_convert(opt) * | / | / * Subtract(opt) Subtract(opt) * | Multiply_const(3D) ====> | Multiply_const(2D) diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index c16cfc2d648768..fad9321b340de9 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -87,10 +87,12 @@ #include "low_precision/add.hpp" #include "low_precision/convert_subtract_constant.hpp" #include "low_precision/convolution_backprop_data.hpp" +#include "low_precision/fold_convert.hpp" +#include "low_precision/fuse_convert.hpp" #include "low_precision/group_convolution.hpp" #include "low_precision/multiply_to_group_convolution.hpp" -#include "low_precision/recurrent_cell.hpp" #include "low_precision/network_helper.hpp" +#include "low_precision/recurrent_cell.hpp" #include "low_precision/rt_info/bias_attribute.hpp" #include "transformations/low_precision/mark_dequantization_subgraph.hpp" @@ -130,6 +132,35 @@ namespace intel_cpu { using const_node_ptr = const std::shared_ptr; +bool Transformations::is_decompression_multiply(const_node_ptr& node) const { + auto get_single_consumer = [](const_node_ptr& node) -> std::shared_ptr { + const auto consumers = node->get_output_target_inputs(0); + if (consumers.size() != 1) + return nullptr; + return consumers.begin()->get_node()->shared_from_this(); + }; + + auto consumer = get_single_consumer(node); + if (!consumer) + return false; + + if (ov::is_type(consumer)) { + return true; + } else if (ov::is_type(consumer)) { + consumer = get_single_consumer(consumer); + if (consumer != nullptr && ov::is_type(consumer)) { + return true; + } + } + if (consumer != nullptr && ov::is_type(consumer)) { + consumer = get_single_consumer(consumer); + if (consumer != nullptr && ov::is_type(consumer)) { + return true; + } + } + return false; +} + bool Transformations::fuse_type_to_convert(const std::shared_ptr& node, const precisions_map& precisions) { auto convert = ov::as_type_ptr(node); if (!convert) @@ -206,59 +237,35 @@ void Transformations::CpuSpecificOpSet(void) { void Transformations::PreLpt(const std::vector& defaultPrecisions, const bool isLegacyApi) { CPU_DEBUG_CAP_TRANSFORMATION_SCOPE(this, PreLpt); + // Decompression handling related transformations must be run separately from common preLPT pipeline + // since there is used the same transformations as in LPT related transformations, but with the specific settings. + // This must be done in order to keep compressed MatMul weights with decompression operations as is + ov::pass::Manager decompression_handling_manager; + decompression_handling_manager.set_per_pass_validation(false); + CPU_REGISTER_PASS_COMMON(decompression_handling_manager, ov::pass::InitNodeInfo); + CPU_REGISTER_PASS_COMMON(decompression_handling_manager, ov::pass::MarkShapeOfSubgraphs); + // We need to fuse Transpose to MatMul to have a simpler callback for the next transformation + CPU_REGISTER_PASS_X64(decompression_handling_manager, ov::pass::TransposeMatMul); + ov::element::TypeVector decompression_precisions{ov::element::u8}; + // We don't have BF16/FP16 FullyConnected kernels to work with 4bits compressed weights + // Convert node doesn't support 4bit precisions -> fallback on constant folding + if (inferencePrecision == ov::element::f32) { + decompression_precisions.push_back(ov::element::u4); + decompression_precisions.push_back(ov::element::i4); + decompression_precisions.push_back(ov::element::nf4); + } + // Ticket 124834: set fold_subtract_const to false when cpu_convert supports i4/u4/nf4 precisions + CPU_REGISTER_PASS_X64(decompression_handling_manager, ov::pass::MarkDequantizationSubgraph, decompression_precisions, true); + CPU_SET_CALLBACK_X64(decompression_handling_manager, [&](const_node_ptr &node) -> bool { + return !is_decompression_multiply(node); + }, ov::pass::MarkDequantizationSubgraph); + decompression_handling_manager.run_passes(model); + ov::pass::Manager manager; manager.set_per_pass_validation(false); - CPU_REGISTER_PASS_COMMON(manager, ov::pass::InitNodeInfo); - CPU_REGISTER_PASS_COMMON(manager, ov::pass::MarkShapeOfSubgraphs); - const bool useLpt = !defaultPrecisions.empty(); - if (useLpt) { + if (useLpt) CPU_REGISTER_PASS_COMMON(manager, ov::pass::MarkDequantizationSubgraph, defaultPrecisions); - } else { - // We need to fuse Transpose to MatMul to have a simpler callback for the next transformation - CPU_REGISTER_PASS_COMMON(manager, ov::pass::TransposeMatMul); - ov::element::TypeVector decompression_precisions{ - ov::element::u8 - }; - // We don't have BF16/FP16 FullyConnected kernels to work with 4bits compressed weights - // Convert node doesn't support 4bit precisions -> fallback on constant folding - if (inferencePrecision == ov::element::f32) { - decompression_precisions.push_back(ov::element::u4); - decompression_precisions.push_back(ov::element::i4); - decompression_precisions.push_back(ov::element::nf4); - } - // MarkDequantizationSubgraph is used even in non-LPT pipeline on X64 platforms - // in order to keep compressed MatMul weights with decompression operations as is - CPU_REGISTER_PASS_X64(manager, ov::pass::MarkDequantizationSubgraph, decompression_precisions, true); - CPU_SET_CALLBACK_X64(manager, [](const_node_ptr &node) -> bool { - auto get_single_consumer = [](const_node_ptr &node) -> std::shared_ptr { - const auto consumers = node->get_output_target_inputs(0); - if (consumers.size() != 1) - return nullptr; - return consumers.begin()->get_node()->shared_from_this(); - }; - - auto consumer = get_single_consumer(node); - if (!consumer) - return true; - - if (ov::is_type(consumer)) { - return false; - } else if (ov::is_type(consumer)) { - consumer = get_single_consumer(consumer); - if (consumer != nullptr && ov::is_type(consumer)) { - return false; - } - } - if (consumer != nullptr && ov::is_type(consumer)) { - consumer = get_single_consumer(consumer); - if (consumer != nullptr && ov::is_type(consumer)) { - return false; - } - } - return true; - }, ov::pass::MarkDequantizationSubgraph); - } auto get_convert_precisions = [&]() { precisions_map map = { @@ -565,32 +572,47 @@ void Transformations::Lpt(const bool hasINT16orINT32Levels, const std::vector bool { - if (const auto mulitply = std::dynamic_pointer_cast(node)) { - return !MultiplyToGroupConvolutionTransformation::canBeTransformedToGroupConvolution(mulitply); - } - return false; - }, - ov::pass::low_precision::MarkupPrecisions); - CPU_SET_CALLBACK_COMMON(lptManager, - [&defaultPrecisions](const_node_ptr& node) -> bool { - return LayerTransformation::isAsymmetricQuantization(node, defaultPrecisions) || - WeightableLayerTransformation::isAsymmetricOnWeights(node, defaultPrecisions); - }, - ov::pass::low_precision::ConvolutionBackpropDataTransformation); - lptManager.get_pass_config()->set_callback( - [](const_node_ptr& node) -> bool { - return ov::marked_as_bias(node); - }); + CPU_SET_CALLBACK_COMMON(lptManager, [](const_node_ptr& node) -> bool { + return ov::is_type(node) && + !MultiplyToGroupConvolutionTransformation::canBeTransformedToGroupConvolution(node); + }, MarkupPrecisions); + CPU_SET_CALLBACK_COMMON(lptManager, [&defaultPrecisions](const_node_ptr& node) -> bool { + return LayerTransformation::isAsymmetricQuantization(node, defaultPrecisions) || + WeightableLayerTransformation::isAsymmetricOnWeights(node, defaultPrecisions); + }, ConvolutionBackpropDataTransformation); + CPU_SET_CALLBACK_COMMON(lptManager, [](const_node_ptr& node) -> bool { + return ov::marked_as_bias(node); + }, AddTransformation); + + CPU_SET_CALLBACK_X64(lptManager, [&](const_node_ptr& node) -> bool { + const auto& consumers = node->get_output_target_inputs(0); + if (consumers.size() == 1) { + const auto consumer = consumers.begin()->get_node()->shared_from_this(); + return ov::is_type(consumer) && is_decompression_multiply(consumer); + } + return false; + }, FoldConvertTransformation); + + CPU_SET_CALLBACK_X64(lptManager, [&](const_node_ptr& node) -> bool { + if (ov::is_type(node)) { + return ov::is_type(node) && is_decompression_multiply(node); + } else if (ov::is_type(node)) { + const auto& consumers = node->get_output_target_inputs(0); + if (consumers.size() == 1) { + const auto consumer = consumers.begin()->get_node()->shared_from_this(); + return ov::is_type(consumer) && is_decompression_multiply(consumer); + } + } + return false; + }, FuseConvertTransformation); - CPU_DISABLE_PASS_ARM(lptManager, ov::pass::low_precision::RecurrentCellTransformation); - CPU_DISABLE_PASS_COMMON(lptManager, ov::pass::low_precision::MultiplyToGroupConvolutionTransformation); + CPU_DISABLE_PASS_ARM(lptManager, RecurrentCellTransformation); + CPU_DISABLE_PASS_COMMON(lptManager, MultiplyToGroupConvolutionTransformation); lptManager.run_passes(model); } diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.h b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.h index c1ef6a8b138951..a8824a2bf6afdd 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.h +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.h @@ -58,6 +58,8 @@ class Transformations { void PostSnippets(void); + bool is_decompression_multiply(const std::shared_ptr& node) const; + static bool fuse_type_to_convert(const std::shared_ptr& node, const precisions_map& precisions); }; diff --git a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/matmul_weights_decompression.cpp b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/matmul_weights_decompression.cpp index 35eb91d3fff04e..3df4f06e31b1bb 100644 --- a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/matmul_weights_decompression.cpp +++ b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/matmul_weights_decompression.cpp @@ -331,9 +331,10 @@ const std::vector input_shapes_amx = { {{{}, {{11, 339, 577}}}, {577, 335}}, {{{}, {{1, 1, 256}}}, {256, 128}, 64ul}, }; -const std::vector fusingParamsSet { +const std::vector fusing_params { emptyFusingSpec, fusingBias, + fusingFakeQuantizePerTensorRelu }; INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_basic, @@ -345,7 +346,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_basic, ::testing::Values(true), ::testing::Values(true), ::testing::ValuesIn(filterAdditionalConfigBasic()), - ::testing::ValuesIn(fusingParamsSet), + ::testing::ValuesIn(fusing_params), ::testing::Values(true)), MatmulWeightsDecompression::getTestCaseName); @@ -358,7 +359,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_amx, ::testing::Values(true), ::testing::Values(true), ::testing::ValuesIn(filterAdditionalConfigAMX()), - ::testing::ValuesIn(fusingParamsSet), + ::testing::ValuesIn(fusing_params), ::testing::Values(true)), MatmulWeightsDecompression::getTestCaseName); diff --git a/src/plugins/intel_cpu/tests/unit/ngraph_transformations/move_fc_reshape_to_weights.cpp b/src/plugins/intel_cpu/tests/unit/ngraph_transformations/move_fc_reshape_to_weights.cpp index bcd8efdd5c0ead..0ba159565a1f90 100644 --- a/src/plugins/intel_cpu/tests/unit/ngraph_transformations/move_fc_reshape_to_weights.cpp +++ b/src/plugins/intel_cpu/tests/unit/ngraph_transformations/move_fc_reshape_to_weights.cpp @@ -20,45 +20,70 @@ using namespace testing; using namespace ov::intel_cpu; +enum class ZeroPointType { NO_ZP, ZP_WEIGHTS_PRC, ZP_DECOMPRESSION_PRC }; +inline std::ostream& operator<<(std::ostream& os, ZeroPointType type) { + switch (type) { + case ZeroPointType::NO_ZP: + os << "NO_ZP"; + break; + case ZeroPointType::ZP_WEIGHTS_PRC: + os << "ZP_WEIGHTS_PRC"; + break; + case ZeroPointType::ZP_DECOMPRESSION_PRC: + os << "ZP_DECOMPRESSION_PRC"; + break; + default: + OPENVINO_THROW("Unknown ZeroPointType"); + } + return os; +} + using MoveFCReshapeToWeightsParams = std::tuple, // data_shape - weights_shape bool, // add transpose - bool>; // add subtract + ZeroPointType>; class MoveFCReshapeToWeightsTests : public TransformationTestsF, public WithParamInterface { public: static std::string getTestCaseName(testing::TestParamInfo obj) { std::pair input_shapes; bool add_transpose; - bool add_subtract; - std::tie(input_shapes, add_transpose, add_subtract) = obj.param; + ZeroPointType zp_type; + std::tie(input_shapes, add_transpose, zp_type) = obj.param; std::ostringstream result; result << "Input_shape=(" << input_shapes.first << ")_Weights_shape=(" << input_shapes.second - << ")_add_transpose=" << add_transpose << "_add_subtract=" << add_subtract; + << ")_add_transpose=" << add_transpose << "_zp_type=" << zp_type; return result.str(); } static std::shared_ptr initModel(const ov::PartialShape& data_shape, const ov::Shape& weights_shape, const bool add_transpose, - const bool add_subtract, + const ZeroPointType zp_type, const bool add_reshape) { - auto data = std::make_shared(ov::element::f32, data_shape); + const auto decompression_prc = ov::element::f32; + const auto weights_prc = ov::element::u8; + auto data = std::make_shared(decompression_prc, data_shape); auto transposed_shape = weights_shape; if (add_transpose) std::swap(*(transposed_shape.rbegin() + 1), *transposed_shape.rbegin()); - std::shared_ptr weights_path = ov::opset1::Constant::create(ov::element::u8, transposed_shape, {1}); - weights_path = std::make_shared(weights_path, ov::element::f32); + std::shared_ptr weights_path = ov::opset1::Constant::create(weights_prc, transposed_shape, {1}); + weights_path = std::make_shared(weights_path, decompression_prc); ov::Shape decompression_shape(weights_shape.size(), 1); const size_t n_idx = add_transpose ? transposed_shape.size() - 1 : transposed_shape.size() - 2; decompression_shape[n_idx] = transposed_shape[n_idx]; - if (add_subtract) { - auto sub_const = ov::opset1::Constant::create(ov::element::f32, decompression_shape, {1}); + if (zp_type == ZeroPointType::ZP_DECOMPRESSION_PRC) { + auto sub_const = ov::opset1::Constant::create(weights_prc, decompression_shape, {1}); + auto sub_convert = std::make_shared(sub_const, decompression_prc); + weights_path = std::make_shared(weights_path, sub_convert); + } else if (zp_type == ZeroPointType::ZP_WEIGHTS_PRC) { + auto sub_const = ov::opset1::Constant::create(decompression_prc, decompression_shape, {1}); weights_path = std::make_shared(weights_path, sub_const); } - auto mul_const = ov::opset1::Constant::create(ov::element::f32, decompression_shape, {1}); + + auto mul_const = ov::opset1::Constant::create(decompression_prc, decompression_shape, {1}); weights_path = std::make_shared(weights_path, mul_const); if (add_reshape) { @@ -80,13 +105,13 @@ class MoveFCReshapeToWeightsTests : public TransformationTestsF, public WithPara TransformationTestsF::SetUp(); std::pair input_shapes; bool add_transpose; - bool add_subtract; - std::tie(input_shapes, add_transpose, add_subtract) = this->GetParam(); + ZeroPointType zp_type; + std::tie(input_shapes, add_transpose, zp_type) = this->GetParam(); ov::Shape ref_weights_shape = input_shapes.second; ref_weights_shape.erase(ref_weights_shape.begin()); - model = initModel(input_shapes.first, input_shapes.second, add_transpose, add_subtract, true); - model_ref = initModel(input_shapes.first, ref_weights_shape, add_transpose, add_subtract, false); + model = initModel(input_shapes.first, input_shapes.second, add_transpose, zp_type, true); + model_ref = initModel(input_shapes.first, ref_weights_shape, add_transpose, zp_type, false); manager.register_pass(); } }; @@ -97,11 +122,15 @@ const std::vector> input_shapes_wo_transp {{-1, -1, -1}, {1, 4, 3}} }; const std::vector add_transpose = {false, true}; -const std::vector add_subtract = {false, true}; +const std::vector zp_types = { + ZeroPointType::NO_ZP, + ZeroPointType::ZP_DECOMPRESSION_PRC, + ZeroPointType::ZP_WEIGHTS_PRC +}; INSTANTIATE_TEST_SUITE_P(TransformationTests_wo_transpose, MoveFCReshapeToWeightsTests, ::testing::Combine( ::testing::ValuesIn(input_shapes_wo_transpose), ::testing::ValuesIn(add_transpose), - ::testing::ValuesIn(add_subtract)), + ::testing::ValuesIn(zp_types)), MoveFCReshapeToWeightsTests::getTestCaseName);