Skip to content

Commit

Permalink
NPUW: Fix DCOFF patterns and handling for NNCF 2.12's GPTQs (openvinotoolkit#26955)
Browse files Browse the repository at this point in the history

### Details:
 - *item1*
 - *...*

### Tickets:
 - *ticket-id*
  • Loading branch information
dmatveev authored Oct 9, 2024
1 parent 8517c36 commit 5a92485
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -619,20 +619,20 @@ DCOFFPassReshape3::DCOFFPassReshape3(DCOffMode dcoff_mode, ov::element::Type dco
register_matcher(std::make_shared<opp::Matcher>(cvt, "TagDCOFFPassReshape3"), std::move(callback));
}

// Pattern: i4 Phi-3 4SymW16A
// Pattern: i4 group-quant
//
//
// "tensor" "scale" > "tensor"
// Param:A Param:C > Param:A
// i4 f16|f32 > f16
// : : > :
// V : > V
// Convert : > Convert
// f16|f32 : > f32
// : : >
// V V >
// Multiply >
// f16|f32 >
// Convert : > [ Convert]
// f16|f32 : > [ f32 ]
// : : > :
// V V > V
// Multiply > Reshape
// f16|f32 > f16|f32
// : >
// : >
// Reshape >
Expand All @@ -657,6 +657,8 @@ DCOFFPassReshape4::DCOFFPassReshape4(DCOffMode dcoff_mode, ov::element::Type dco
auto matched_paramA = std::static_pointer_cast<ov::op::v0::Parameter>(matched_nodeA);
auto matched_paramC = std::static_pointer_cast<ov::op::v0::Parameter>(matched_nodeC);

auto matched_out_mulply = node_to_output.at(mulply);

if (ov::element::i4 == matched_paramA->get_element_type() &&
(ov::element::f16 == matched_paramC->get_element_type() ||
ov::element::f32 == matched_paramC->get_element_type())) {
Expand All @@ -669,31 +671,17 @@ DCOFFPassReshape4::DCOFFPassReshape4(DCOffMode dcoff_mode, ov::element::Type dco
LOG_DEBUG("Matched: " << matched_paramC << " - parameter to remove...");
LOG_BLOCK();

// Extra transformation here:
// - remove Multiply + Intermediate Convert
// - mark paramC for removal.
// Convert will be reconnected to paramA directly.

// Record mapping from the Scale coeff parameter to the Real weight parameter
pref.get().scales[matched_paramC] = matched_paramA;

// Disconnect Multiply and Convert from their outputs
auto matched_mulply = node_to_output.at(mulply).get_node_shared_ptr();
auto matched_convrt = node_to_output.at(cvtA).get_node_shared_ptr();
auto drop_outputs = [](std::shared_ptr<ov::Node> node) {
for (auto&& node_outputs : node->outputs()) {
for (auto&& node_reader_port : node_outputs.get_target_inputs()) {
node_outputs.remove_target_input(node_reader_port);
}
}
};
LOG_DEBUG("Dropping the connections...");
drop_outputs(matched_mulply);
drop_outputs(matched_convrt);
std::shared_ptr<ov::Node> new_rshp_in = matched_paramA;
if (matched_out_mulply.get_element_type() == ov::element::f32) {
new_rshp_in = std::make_shared<ov::op::v0::Convert>(matched_paramA, ov::element::f32);
}

LOG_DEBUG("Reconnecting the Root...");
auto matched_reshape = node_to_output.at(reshape).get_node_shared_ptr();
matched_reshape->input(0).replace_source_output(matched_paramA);
matched_reshape->input(0).replace_source_output(new_rshp_in);
}
LOG_DEBUG("Done");
}
Expand Down
144 changes: 129 additions & 15 deletions src/plugins/intel_npu/src/plugin/npuw/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,115 @@ void unpack_i4f16(const ov::SoPtr<ov::ITensor>& from,
}
}

// Unpack a 3-D i4 tensor into f16, applying per-channel f32 scales.
// The scale tensor has shape [C, 1, W] (asserted below) and is broadcast
// over the middle ("z"/H) axis: out[c, h, w] = from[c, h, w] * scale[c, 0, w].
// The result is written as f16 into `to`, which has the same shape as `from`.
// Work is optionally partitioned across the C axis per `unpack_options`.
void unpack_i4f16_z(const ov::SoPtr<ov::ITensor>& from,
                    const ov::SoPtr<ov::ITensor>& scale,
                    const ov::SoPtr<ov::ITensor>& to,
                    const ov::npuw::util::UnpackOptions& unpack_options) {
    // The kernel walks raw pointers, so all tensors must be dense.
    NPUW_ASSERT(from->is_continuous());
    NPUW_ASSERT(scale->is_continuous());
    NPUW_ASSERT(to->is_continuous());
    NPUW_ASSERT(from->get_size() == to->get_size());

    // The innermost loop processes W in chunks of 64 i4 elements.
    const auto& from_shape = from->get_shape();
    NPUW_ASSERT(from_shape.back() % 64 == 0);

    // Scale must be [C, 1, W]: one f32 scale per (c, w) pair, shared across H.
    const auto& scale_shape = scale->get_shape();
    NPUW_ASSERT(scale_shape.size() == 3);
    NPUW_ASSERT(scale_shape[0] == from_shape[0]);
    NPUW_ASSERT(scale_shape[2] == from_shape[2]);
    NPUW_ASSERT(scale_shape[1] == 1);

    const auto scale_elem_type = scale->get_element_type();
    NPUW_ASSERT(scale_elem_type == ov::element::f32);

    // Per inner iteration the kernel:
    // - reads 256 bits (= 32 bytes, = 64 i4 elements) of weights,
    // - reads 64 f32 scales,
    // - writes 1024 bits (= 128 bytes, = 64 f16 elements),
    // which translates to (from->get_size() / 64) iterations in total.
    // The conversion itself goes i4 -> i8 -> (scaled) f16 via the AVX2
    // helpers avx2_i4toi8 / avx2_i8tof16 defined elsewhere in this file.

    const size_t C = from_shape[from_shape.size() - 3];
    const size_t H = from_shape[from_shape.size() - 2];
    const size_t W = from_shape[from_shape.size() - 1];

    const int8_t* const pSrc = static_cast<int8_t*>(from->data());  // 2 x i4 elements
    const float* const pScl = static_cast<float*>(scale->data());   // 1 x f32 element
    int16_t* pDst = static_cast<int16_t*>(to->data());              // 1 x f16 element

    // Processes C-rows [job_index*stride, job_index*stride + stride) — one job's share.
    auto unpack_body = [&](size_t job_index, size_t stride) {
        size_t start_c = job_index * stride;
        size_t end_c = std::min(C, start_c + stride);  // clamp: last job may be short (or empty)

        for (size_t c = start_c; c < end_c; ++c) {
            for (size_t h = 0; h < H; ++h) {
                for (size_t w = 0; w < W; w += 64) {
                    // 64 i4 elements = 32 bytes, hence the /2 on the element offset.
                    const int8_t* pSrc_iter = pSrc + (w + W * h + W * H * c) / 2;
                    __m256i vinput = _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(pSrc_iter));
                    __m256i vout0, vout1;
                    // Widen 64 packed i4 values into 64 i8 values (two 32-byte vectors).
                    avx2_i4toi8(vinput, &vout0, &vout1);
                    // Spill to a scratch buffer so the i8 data can be re-read
                    // as eight 8-element lanes below.
                    int8_t tmp[64];  // FIXME: Avoid it
                    __m256i* tmpv0 = reinterpret_cast<__m256i*>(tmp);
                    __m256i* tmpv1 = reinterpret_cast<__m256i*>(tmp + 32);
                    _mm256_storeu_si256(tmpv0, vout0);
                    _mm256_storeu_si256(tmpv1, vout1);
                    // Reload as 8 vectors of 8 i8 values each (high 64 bits zeroed).
                    __m128i i8vecs[8] = {
                        _mm_loadl_epi64(reinterpret_cast<__m128i*>(tmp)),
                        _mm_loadl_epi64(reinterpret_cast<__m128i*>(tmp + 8)),
                        _mm_loadl_epi64(reinterpret_cast<__m128i*>(tmp + 16)),
                        _mm_loadl_epi64(reinterpret_cast<__m128i*>(tmp + 24)),
                        _mm_loadl_epi64(reinterpret_cast<__m128i*>(tmp + 32)),
                        _mm_loadl_epi64(reinterpret_cast<__m128i*>(tmp + 40)),
                        _mm_loadl_epi64(reinterpret_cast<__m128i*>(tmp + 48)),
                        _mm_loadl_epi64(reinterpret_cast<__m128i*>(tmp + 56)),
                    };

                    // Scales are indexed by (c, w) only — no h term — i.e. the
                    // [C, 1, W] scale tensor is broadcast along the H axis.
                    const float* pScl_iter = pScl + w + W * c;
                    __m256 svalVec[8];
                    for (int i = 0; i < 8; ++i) {
                        svalVec[i] = _mm256_loadu_ps(pScl_iter + i * 8);
                    }

                    // Scale each 8-element i8 lane and convert to 8 f16 values.
                    __m128i vresults[8] = {avx2_i8tof16(i8vecs[0], svalVec[0]),
                                           avx2_i8tof16(i8vecs[1], svalVec[1]),
                                           avx2_i8tof16(i8vecs[2], svalVec[2]),
                                           avx2_i8tof16(i8vecs[3], svalVec[3]),
                                           avx2_i8tof16(i8vecs[4], svalVec[4]),
                                           avx2_i8tof16(i8vecs[5], svalVec[5]),
                                           avx2_i8tof16(i8vecs[6], svalVec[6]),
                                           avx2_i8tof16(i8vecs[7], svalVec[7])};

                    // Store 64 f16 results contiguously at the matching output offset.
                    int16_t* pDst_iter = pDst + w + W * h + W * H * c;
                    for (int i = 0; i < 8; ++i) {
                        _mm_storeu_si128(reinterpret_cast<__m128i*>(pDst_iter + i * 8), vresults[i]);
                    }
                }
            }
        }
    };

    // Default: a single job covering all C rows.
    size_t stride = C;
    size_t num_jobs = 1;

    if (unpack_options.nPartitions) {
        if (unpack_options.bStrictPartitioning) {
            // Exactly nPartitions jobs; trailing jobs may end up empty
            // (unpack_body clamps end_c to C).
            stride = (C + unpack_options.nPartitions - 1) / unpack_options.nPartitions;
            num_jobs = unpack_options.nPartitions;
        } else {
            // At least one row per job; job count follows from the stride.
            stride = std::max<size_t>(1, C / unpack_options.nPartitions);
            num_jobs = (C + stride - 1) / stride;
        }
    }

    if (unpack_options.bUseOvParallelFor) {
        ov::parallel_for(num_jobs, [&](size_t job_index) {
            unpack_body(job_index, stride);
        });
    } else {
        for (size_t job_index = 0; job_index < num_jobs; ++job_index) {
            unpack_body(job_index, stride);
        }
    }
}

void unpack_u4f16(const ov::SoPtr<ov::ITensor>& from,
const ov::SoPtr<ov::ITensor>& to,
const ov::npuw::util::UnpackOptions& unpack_options) {
Expand Down Expand Up @@ -1328,22 +1437,27 @@ void ov::npuw::util::unpack(const ov::SoPtr<ov::ITensor>& from,
// This is in fact a weight decompression procedure
const auto type_from = from->get_element_type();
const auto type_to = to->get_element_type();
namespace ove = ov::element;
#define CAST(x) static_cast<int>((x).operator ove::Type_t())
#define PAIR(f, t) (CAST(f) << 16 | CAST(t))
#define HNDL(f, t) \
case PAIR(ove::f, ove::t): \
unpack_##f##t(from, scale, to, unpack_options); \
break;
switch (PAIR(type_from, type_to)) {
HNDL(i4, f16);
HNDL(i8, f16);
default:
OPENVINO_THROW("Unknown unpack/scale combination ", type_from, " -> ", type_to);
NPUW_ASSERT(type_to == ov::element::f16);

const auto& from_shape = from->get_shape();
const auto& scale_shape = scale->get_shape();

if (type_from == ov::element::i4) {
if (from_shape.size() == 3) {
if (scale_shape[2] == from_shape[2]) {
unpack_i4f16_z(from, scale, to, unpack_options);
} else {
unpack_i4f16(from, scale, to, unpack_options);
}
} else {
NPUW_ASSERT(from_shape.size() == 2);
unpack_i4f16(from, scale, to, unpack_options);
}
} else if (type_from == ov::element::i8) {
unpack_i8f16(from, scale, to, unpack_options);
} else {
NPUW_ASSERT(false && "Unsupported combination");
}
#undef HNDL
#undef PAIR
#undef CAST
}

void ov::npuw::util::unpack(const ov::SoPtr<ov::ITensor>& from,
Expand Down

0 comments on commit 5a92485

Please sign in to comment.