diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.cpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.cpp
index ffbece94b04176..4928be783317fd 100644
--- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.cpp
+++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.cpp
@@ -619,7 +619,7 @@ DCOFFPassReshape3::DCOFFPassReshape3(DCOffMode dcoff_mode, ov::element::Type dco
     register_matcher(std::make_shared<opp::Matcher>(cvt, "TagDCOFFPassReshape3"), std::move(callback));
 }
 
-// Pattern: i4 Phi-3 4SymW16A
+// Pattern: i4 group-quant
 //
 //
 //   "tensor"      "scale"         >            "tensor"
@@ -627,12 +627,12 @@ DCOFFPassReshape3::DCOFFPassReshape3(DCOffMode dcoff_mode, ov::element::Type dco
 //      i4         f16|f32         >              f16
 //       :            :            >               :
 //       V            :            >               V
-//    Convert         :            >            Convert
-//    f16|f32         :            >              f32
-//       :            :            >
-//       V            V            >
-//       Multiply                  >
-//       f16|f32                   >
+//    Convert         :            >           [ Convert]
+//    f16|f32         :            >           [   f32  ]
+//       :            :            >               :
+//       V            V            >               V
+//       Multiply                  >            Reshape
+//       f16|f32                   >            f16|f32
 //          :                      >
 //          :                      >
 //       Reshape                   >
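For orientation: the left column of the diagram is the subgraph this pass matches, the right column is what remains after the callback rewires it. A minimal sketch of how such a chain is typically declared with the OpenVINO pattern API, mirroring the handles the callback below uses (paramA, paramC, cvtA, mulply, reshape); the exact declarations in dcoff.cpp may differ:

    namespace opp = ov::pass::pattern;

    auto paramA  = opp::wrap_type<ov::op::v0::Parameter>();               // "tensor": i4 weights
    auto paramC  = opp::wrap_type<ov::op::v0::Parameter>();               // "scale":  f16|f32
    auto cvtA    = opp::wrap_type<ov::op::v0::Convert>({paramA});
    auto mulply  = opp::wrap_type<ov::op::v1::Multiply>({cvtA, paramC});
    auto reshape = opp::wrap_type<ov::op::v1::Reshape>({mulply, opp::any_input()});

The pass then registers a matcher over this pattern (presumably rooted at reshape), and the callback retrieves the matched nodes through node_to_output, as in the hunks below.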
@@ -657,6 +657,8 @@ DCOFFPassReshape4::DCOFFPassReshape4(DCOffMode dcoff_mode, ov::element::Type dco
         auto matched_paramA = std::static_pointer_cast<ov::op::v0::Parameter>(matched_nodeA);
         auto matched_paramC = std::static_pointer_cast<ov::op::v0::Parameter>(matched_nodeC);
 
+        auto matched_out_mulply = node_to_output.at(mulply);
+
         if (ov::element::i4 == matched_paramA->get_element_type() &&
             (ov::element::f16 == matched_paramC->get_element_type() ||
              ov::element::f32 == matched_paramC->get_element_type())) {
@@ -669,31 +671,17 @@ DCOFFPassReshape4::DCOFFPassReshape4(DCOffMode dcoff_mode, ov::element::Type dco
             LOG_DEBUG("Matched: " << matched_paramC << " - parameter to remove...");
             LOG_BLOCK();
 
-            // Extra transformation here:
-            // - remove Multiply + Intermediate Convert
-            // - mark paramC for removal.
-            // Convert will be reconnected to paramA directly.
-
             // Record mapping from the Scale coeff parameter to the Real weight parameter
             pref.get().scales[matched_paramC] = matched_paramA;
 
-            // Disconnect Multiply and Convert from their outputs
-            auto matched_mulply = node_to_output.at(mulply).get_node_shared_ptr();
-            auto matched_convrt = node_to_output.at(cvtA).get_node_shared_ptr();
-            auto drop_outputs = [](std::shared_ptr<ov::Node> node) {
-                for (auto&& node_outputs : node->outputs()) {
-                    for (auto&& node_reader_port : node_outputs.get_target_inputs()) {
-                        node_outputs.remove_target_input(node_reader_port);
-                    }
-                }
-            };
-            LOG_DEBUG("Dropping the connections...");
-            drop_outputs(matched_mulply);
-            drop_outputs(matched_convrt);
+            std::shared_ptr<ov::Node> new_rshp_in = matched_paramA;
+            if (matched_out_mulply.get_element_type() == ov::element::f32) {
+                new_rshp_in = std::make_shared<ov::op::v0::Convert>(matched_paramA, ov::element::f32);
+            }
 
             LOG_DEBUG("Reconnecting the Root...");
             auto matched_reshape = node_to_output.at(reshape).get_node_shared_ptr();
-            matched_reshape->input(0).replace_source_output(matched_paramA);
+            matched_reshape->input(0).replace_source_output(new_rshp_in);
         }
         LOG_DEBUG("Done");
     }
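The net effect of the dcoff.cpp change: the scale Multiply is folded out of the function body, the Reshape is fed by the weight Parameter directly (through a fresh f32 Convert only when the Multiply used to produce f32), and dequantization is deferred to weight-import time via the scales map recorded above. A sketch of how that map is meant to be consumed, with lookup_closure() and opts as hypothetical stand-ins for the surrounding NPUW plumbing (illustrative, not the actual import code):

    // For every recorded (scale parameter -> weight parameter) pair, fuse the
    // i4 payload with its scales into one f16 tensor via the unpack() below.
    for (auto&& kv : pref.get().scales) {
        const auto& scale_param  = kv.first;   // detached from the graph
        const auto& weight_param = kv.second;  // stays as the Reshape input
        ov::SoPtr<ov::ITensor> w = lookup_closure(weight_param);  // hypothetical helper
        ov::SoPtr<ov::ITensor> s = lookup_closure(scale_param);   // hypothetical helper
        ov::Tensor dst(ov::element::f16, w->get_shape());
        ov::npuw::util::unpack(w, s, ov::get_tensor_impl(dst), opts);
    }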
diff --git a/src/plugins/intel_npu/src/plugin/npuw/util.cpp b/src/plugins/intel_npu/src/plugin/npuw/util.cpp
index abb7e5841a95b7..d83a521fb29496 100644
--- a/src/plugins/intel_npu/src/plugin/npuw/util.cpp
+++ b/src/plugins/intel_npu/src/plugin/npuw/util.cpp
@@ -644,6 +644,115 @@ void unpack_i4f16(const ov::SoPtr<ov::ITensor>& from,
     }
 }
 
+void unpack_i4f16_z(const ov::SoPtr<ov::ITensor>& from,
+                    const ov::SoPtr<ov::ITensor>& scale,
+                    const ov::SoPtr<ov::ITensor>& to,
+                    const ov::npuw::util::UnpackOptions& unpack_options) {
+    NPUW_ASSERT(from->is_continuous());
+    NPUW_ASSERT(scale->is_continuous());
+    NPUW_ASSERT(to->is_continuous());
+    NPUW_ASSERT(from->get_size() == to->get_size());
+
+    const auto& from_shape = from->get_shape();
+    NPUW_ASSERT(from_shape.back() % 64 == 0);
+
+    const auto& scale_shape = scale->get_shape();
+    NPUW_ASSERT(scale_shape.size() == 3);
+    NPUW_ASSERT(scale_shape[0] == from_shape[0]);
+    NPUW_ASSERT(scale_shape[2] == from_shape[2]);
+    NPUW_ASSERT(scale_shape[1] == 1);
+
+    const auto scale_elem_type = scale->get_element_type();
+    NPUW_ASSERT(scale_elem_type == ov::element::f32);
+
+    // This conversion combines i4tof32 and f32tof16. Here we
+    // - read 256 bits (= 32 bytes, = 64 i4 elements)
+    // - write 1024 bits (= 128 bytes, = 64 f16 elements)
+    // on every iteration, which translates to (from->size() / 64) iterations
+
+    const size_t C = from_shape[from_shape.size() - 3];
+    const size_t H = from_shape[from_shape.size() - 2];
+    const size_t W = from_shape[from_shape.size() - 1];
+
+    const int8_t* const pSrc = static_cast<const int8_t*>(from->data());  // 2 x i4 elements
+    const float* const pScl = static_cast<const float*>(scale->data());  // 1 x f32 element
+    int16_t* pDst = static_cast<int16_t*>(to->data());                   // 1 x f16 element
+
+    auto unpack_body = [&](size_t job_index, size_t stride) {
+        size_t start_c = job_index * stride;
+        size_t end_c = std::min(C, start_c + stride);
+
+        for (size_t c = start_c; c < end_c; ++c) {
+            for (size_t h = 0; h < H; ++h) {
+                for (size_t w = 0; w < W; w += 64) {
+                    const int8_t* pSrc_iter = pSrc + (w + W * h + W * H * c) / 2;
+                    __m256i vinput = _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(pSrc_iter));
+                    __m256i vout0, vout1;
+                    avx2_i4toi8(vinput, &vout0, &vout1);
+                    int8_t tmp[64];  // FIXME: Avoid it
+                    __m256i* tmpv0 = reinterpret_cast<__m256i*>(tmp);
+                    __m256i* tmpv1 = reinterpret_cast<__m256i*>(tmp + 32);
+                    _mm256_storeu_si256(tmpv0, vout0);
+                    _mm256_storeu_si256(tmpv1, vout1);
+                    __m128i i8vecs[8] = {
+                        _mm_loadl_epi64(reinterpret_cast<__m128i*>(tmp)),
+                        _mm_loadl_epi64(reinterpret_cast<__m128i*>(tmp + 8)),
+                        _mm_loadl_epi64(reinterpret_cast<__m128i*>(tmp + 16)),
+                        _mm_loadl_epi64(reinterpret_cast<__m128i*>(tmp + 24)),
+                        _mm_loadl_epi64(reinterpret_cast<__m128i*>(tmp + 32)),
+                        _mm_loadl_epi64(reinterpret_cast<__m128i*>(tmp + 40)),
+                        _mm_loadl_epi64(reinterpret_cast<__m128i*>(tmp + 48)),
+                        _mm_loadl_epi64(reinterpret_cast<__m128i*>(tmp + 56)),
+                    };
+
+                    const float* pScl_iter = pScl + w + W * c;
+                    __m256 svalVec[8];
+                    for (int i = 0; i < 8; ++i) {
+                        svalVec[i] = _mm256_loadu_ps(pScl_iter + i * 8);
+                    }
+
+                    __m128i vresults[8] = {avx2_i8tof16(i8vecs[0], svalVec[0]),
+                                           avx2_i8tof16(i8vecs[1], svalVec[1]),
+                                           avx2_i8tof16(i8vecs[2], svalVec[2]),
+                                           avx2_i8tof16(i8vecs[3], svalVec[3]),
+                                           avx2_i8tof16(i8vecs[4], svalVec[4]),
+                                           avx2_i8tof16(i8vecs[5], svalVec[5]),
+                                           avx2_i8tof16(i8vecs[6], svalVec[6]),
+                                           avx2_i8tof16(i8vecs[7], svalVec[7])};
+
+                    int16_t* pDst_iter = pDst + w + W * h + W * H * c;
+                    for (int i = 0; i < 8; ++i) {
+                        _mm_storeu_si128(reinterpret_cast<__m128i*>(pDst_iter + i * 8), vresults[i]);
+                    }
+                }
+            }
+        }
+    };
+
+    size_t stride = C;
+    size_t num_jobs = 1;
+
+    if (unpack_options.nPartitions) {
+        if (unpack_options.bStrictPartitioning) {
+            stride = (C + unpack_options.nPartitions - 1) / unpack_options.nPartitions;
+            num_jobs = unpack_options.nPartitions;
+        } else {
+            stride = std::max<size_t>(1, C / unpack_options.nPartitions);
+            num_jobs = (C + stride - 1) / stride;
+        }
+    }
+
+    if (unpack_options.bUseOvParallelFor) {
+        ov::parallel_for(num_jobs, [&](size_t job_index) {
+            unpack_body(job_index, stride);
+        });
+    } else {
+        for (size_t job_index = 0; job_index < num_jobs; ++job_index) {
+            unpack_body(job_index, stride);
+        }
+    }
+}
+
 void unpack_u4f16(const ov::SoPtr<ov::ITensor>& from,
                   const ov::SoPtr<ov::ITensor>& to,
                   const ov::npuw::util::UnpackOptions& unpack_options) {
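To make the vectorized kernel above easier to verify, here is a plain scalar restatement of what unpack_i4f16_z computes, derived from the indexing in unpack_body rather than taken from the repository. The key point is that the scale index drops h: one f32 scale per (c, w) pair, shared across the whole H axis (the "z" layout), whereas the unpack_i4f16 path is used when scales are per-row instead. The nibble order and f32_to_f16() are assumptions; the real order is defined by avx2_i4toi8:

    void unpack_i4f16_z_ref(const int8_t* pSrc, const float* pScl, uint16_t* pDst,
                            size_t C, size_t H, size_t W) {
        for (size_t c = 0; c < C; ++c) {
            for (size_t h = 0; h < H; ++h) {
                for (size_t w = 0; w < W; ++w) {
                    const size_t idx = w + W * h + W * H * c;  // same flat layout as above
                    const int8_t packed = pSrc[idx / 2];       // two i4 values per byte
                    // Assumed: even elements in the low nibble, both sign-extended to int8
                    const int8_t lo = static_cast<int8_t>(static_cast<int8_t>(packed << 4) >> 4);
                    const int8_t hi = static_cast<int8_t>(packed >> 4);
                    const int8_t q = (idx % 2 == 0) ? lo : hi;
                    // Scale is indexed without h: shared across the H axis
                    pDst[idx] = f32_to_f16(static_cast<float>(q) * pScl[w + W * c]);  // hypothetical helper
                }
            }
        }
    }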
@@ -1328,22 +1437,27 @@ void ov::npuw::util::unpack(const ov::SoPtr<ov::ITensor>& from,
     // This is in fact a weight decompression procedure
     const auto type_from = from->get_element_type();
     const auto type_to = to->get_element_type();
-    namespace ove = ov::element;
-#define CAST(x) static_cast<int>((x).operator ove::Type_t())
-#define PAIR(f, t) (CAST(f) << 16 | CAST(t))
-#define HNDL(f, t)                                      \
-    case PAIR(ove::f, ove::t):                          \
-        unpack_##f##t(from, scale, to, unpack_options); \
-        break;
-    switch (PAIR(type_from, type_to)) {
-        HNDL(i4, f16);
-        HNDL(i8, f16);
-    default:
-        OPENVINO_THROW("Unknown unpack/scale combination ", type_from, " -> ", type_to);
+    NPUW_ASSERT(type_to == ov::element::f16);
+
+    const auto& from_shape = from->get_shape();
+    const auto& scale_shape = scale->get_shape();
+
+    if (type_from == ov::element::i4) {
+        if (from_shape.size() == 3) {
+            if (scale_shape[2] == from_shape[2]) {
+                unpack_i4f16_z(from, scale, to, unpack_options);
+            } else {
+                unpack_i4f16(from, scale, to, unpack_options);
+            }
+        } else {
+            NPUW_ASSERT(from_shape.size() == 2);
+            unpack_i4f16(from, scale, to, unpack_options);
+        }
+    } else if (type_from == ov::element::i8) {
+        unpack_i8f16(from, scale, to, unpack_options);
+    } else {
+        NPUW_ASSERT(false && "Unsupported combination");
     }
-#undef HNDL
-#undef PAIR
-#undef CAST
 }
 
 void ov::npuw::util::unpack(const ov::SoPtr<ov::ITensor>& from,
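Finally, a hedged usage sketch of the new dispatch. NPUW itself passes closure tensors here; this uses the public developer API to build standalone ones, and the UnpackOptions field order is assumed from the code above:

    #include <openvino/runtime/make_tensor.hpp>  // ov::get_tensor_impl (developer API)

    ov::Tensor weight(ov::element::i4, {32, 4, 64});    // C=32, H=4, W=64 (W % 64 == 0)
    ov::Tensor scale(ov::element::f32, {32, 1, 64});    // last dims match -> unpack_i4f16_z
    ov::Tensor result(ov::element::f16, {32, 4, 64});

    ov::npuw::util::UnpackOptions opts{/*bUseOvParallelFor=*/true,
                                       /*nPartitions=*/8,
                                       /*bStrictPartitioning=*/false};
    ov::npuw::util::unpack(ov::get_tensor_impl(weight),
                           ov::get_tensor_impl(scale),
                           ov::get_tensor_impl(result),
                           opts);

With a per-row scale such as {32, 4, 1}, whose last dimension does not match the weight's, the same call would take the row-wise unpack_i4f16 path instead.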