optimize com.microsoft.MatMulNbits operator (#28504)
This PR optimizes the ONNX frontend
com.microsoft.MatMulNbits operator.

With these changes:
1. Constant folding, which used 75 GB of memory for the phi3 INT4 model and
200+ GB for the llama3 INT4 model, is disabled.
2. The oneDNN matmul primitives are triggered, which greatly benefits GPU
performance.

We tested these changes along with PR #28163 and confirmed that the
phi3/llama3 INT4 models run well on LNL.

---------

Co-authored-by: Yu, Zijun <zijun.yu@intel.com>
bopeng1234 and wine99 authored Feb 18, 2025
1 parent 047976e commit 68ecdfb
Showing 1 changed file with 71 additions and 112 deletions.
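For context, below is a minimal self-contained sketch (not part of this commit; sizes, values, and the helper name make_int4_matmul_sketch are made up for illustration) of the kind of weight-decompression subgraph the rewritten converter builds for the 4-bit case: the quantized weights stay packed as a u4 constant and are decompressed with Convert -> Subtract -> Multiply -> Reshape in front of a MatMul, so no full-precision weight tensor is constant-folded and, per the commit message, the pattern can be mapped to oneDNN matmul primitives on GPU.

#include <cstdint>
#include <memory>
#include <vector>

#include "openvino/core/model.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/subtract.hpp"

// Builds a tiny stand-alone model with the same decompression pattern:
// u4 weights -> Convert -> Subtract(zero_point) -> Multiply(scales) -> Reshape -> MatMul.
std::shared_ptr<ov::Model> make_int4_matmul_sketch() {
    const size_t N = 8, K = 64, block_size = 32;
    const size_t n_blocks_per_col = K / block_size;  // K divides evenly here, so no trailing Slice is needed

    // Activations are f16; the batch dimension is left dynamic.
    auto a = std::make_shared<ov::op::v0::Parameter>(
        ov::element::f16,
        ov::PartialShape{ov::Dimension::dynamic(), ov::Dimension(static_cast<int64_t>(K))});

    // Quantized weights stay packed as a u4 constant (all values set to 7 just for illustration),
    // so no full-precision weight tensor is ever constant-folded.
    auto b_q = ov::op::v0::Constant::create(ov::element::u4,
                                            ov::Shape{N, n_blocks_per_col, block_size},
                                            std::vector<uint8_t>(N * n_blocks_per_col * block_size, 7));
    auto zp = ov::op::v0::Constant::create(ov::element::u4, ov::Shape{1}, std::vector<uint8_t>{8});
    auto scales = ov::op::v0::Constant::create(ov::element::f16,
                                               ov::Shape{N, n_blocks_per_col, 1},
                                               std::vector<float>(N * n_blocks_per_col, 0.05f));

    // Decompression math runs in f16: (b - zero_point) * scale.
    auto b_f16 = std::make_shared<ov::op::v0::Convert>(b_q, ov::element::f16);
    auto zp_f16 = std::make_shared<ov::op::v0::Convert>(zp, ov::element::f16);
    auto b_sub = std::make_shared<ov::op::v1::Subtract>(b_f16, zp_f16);
    auto b_scaled = std::make_shared<ov::op::v1::Multiply>(b_sub, scales);

    // Collapse [N, n_blocks_per_col, block_size] -> [N, K], then MatMul with B transposed.
    auto shape_nk = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{2}, std::vector<int32_t>{0, -1});
    auto b_2d = std::make_shared<ov::op::v1::Reshape>(b_scaled, shape_nk, true);
    auto mm = std::make_shared<ov::op::v0::MatMul>(a, b_2d, false, true);

    return std::make_shared<ov::Model>(ov::OutputVector{mm}, ov::ParameterVector{a});
}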
183 changes: 71 additions & 112 deletions src/frontends/onnx/frontend/src/op/com.microsoft/matmulnbits.cpp
@@ -7,15 +7,13 @@
#include "exceptions.hpp"
#include "openvino/frontend/exception.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/transpose.hpp"
#include "utils/common.hpp"
#include "utils/reshape.hpp"

@@ -111,142 +109,103 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) {
bias.get_partial_shape());
}

ov::Output<ov::Node> mm_output;
{
const auto b_const = ov::as_type_ptr<v0::Constant>(b_quantized.get_node_shared_ptr());

ov::Output<ov::Node> casted_b;
ov::Shape casted_b_shape;
ov::Output<ov::Node> default_zp;
// Casting/converting data of source constant.
// For further calculations (sub and/or multiply) we need to reshape it from [N][n_blocks_per_col][blob_size *
// X] to [N * n_blocks_per_col][blob_size * X] (where X is amount of values in 1 byte) because scale and
// zero_point are represented as: ...with shape like: [N * n_blocks_per_col]...
// For further calculations (sub and/or multiply) we need to reshape
// b -> [N][n_blocks_per_col][block_size]
switch (bits) {
case 2:
casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size * 4)};
casted_b_shape = ov::Shape{static_cast<size_t>(N),
static_cast<size_t>(n_blocks_per_col),
static_cast<size_t>(blob_size * 4)};
casted_b = std::make_shared<v0::Constant>(ov::element::u2, casted_b_shape, b_const->get_data_ptr());
if (a.get_element_type() != ov::element::dynamic) {
default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 2);
} else {
default_zp =
std::make_shared<v1::ConvertLike>(a,
std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 2.f));
}
default_zp = std::make_shared<v0::Constant>(ov::element::u2, Shape{1}, 2);
break;
case 4:
casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size * 2)};
casted_b_shape = ov::Shape{static_cast<size_t>(N),
static_cast<size_t>(n_blocks_per_col),
static_cast<size_t>(blob_size * 2)};
casted_b = std::make_shared<v0::Constant>(ov::element::u4, casted_b_shape, b_const->get_data_ptr());
if (a.get_element_type() != ov::element::dynamic) {
default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 8);
} else {
default_zp =
std::make_shared<v1::ConvertLike>(a,
std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 8.f));
}
default_zp = std::make_shared<v0::Constant>(ov::element::u4, Shape{1}, 8);
break;
case 8:
casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size)};
casted_b = op::util::reshape(b_const, casted_b_shape);
if (a.get_element_type() != ov::element::dynamic) {
default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 128);
} else {
default_zp =
std::make_shared<v1::ConvertLike>(a,
std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 128.f));
}
casted_b_shape = ov::Shape{static_cast<size_t>(N),
static_cast<size_t>(n_blocks_per_col),
static_cast<size_t>(blob_size)};
casted_b = std::make_shared<v0::Constant>(ov::element::u8, casted_b_shape, b_const->get_data_ptr());
default_zp = std::make_shared<v0::Constant>(ov::element::u8, Shape{1}, 128);
break;
default:
FRONT_END_THROW("Unsupported bits count");
break;
}

if (!zero_points.get_node_shared_ptr()) {
zero_points = default_zp;
} else {
// https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.MatMulNBits
// according to the link, zero points are:
// Constrain quantized zero point types to uint8/int32/float16/float.
// Input zero_points is stored as uint8_t or same as type(A). It has the same packing method as input B
zero_points =
op::util::reshape(zero_points,
ov::Shape{static_cast<size_t>(N), static_cast<size_t>(n_blocks_per_col), 1});
}

// Possible issue with slice implementation, had to move conversion before slice, instead of slicing uint4
// TODO: Ticket
const auto converted_b = std::make_shared<v1::ConvertLike>(casted_b, a);
// Comment: the issue is still there, so b needs to be converted to fp16 first.

// TODO: Need to collect performance data in case constant folding is applied. Possible some perf/mem-gap

// Simple case
if (n_blocks_per_col == 1) {
// Removing unused items in case block is bigger than column count
// For example, if data is (uint8)[1,2,3,4,5,6] then block will be (uint8)[1,2,3,4,5,6,0,0,0,0,0,0,0,0,0,0].
// And last zeros are unused.
const auto zero_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 0);
const auto one_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
const auto elements_const =
std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, static_cast<int32_t>(K));
const auto axis_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
const auto slice_b =
std::make_shared<v8::Slice>(converted_b, zero_const, elements_const, one_const, axis_const);

// Transpose matrix
const auto transposed_shape =
std::make_shared<v0::Constant>(ov::element::i64, Shape{2}, std::vector<int64_t>{1, 0});
const auto transposed_b = std::make_shared<v1::Transpose>(slice_b, transposed_shape);

// If no zero-points provided - we generate default, depends on data size
if (!zero_points.get_node_shared_ptr()) {
zero_points = default_zp;
}
const auto sub_b = std::make_shared<v1::Subtract>(transposed_b, zero_points);

// Scaling
const auto scaled_b = std::make_shared<v1::Multiply>(sub_b, scales);

// Adding bias if required
if (!bias.get_node_shared_ptr()) {
b = scaled_b;
} else {
b = std::make_shared<v1::Add>(scaled_b, bias);
}
// Comment: with this latest code the graph triggers the oneDNN kernel directly,
// using the u2/u4/u8 weights as the kernel's input, so const folding is no longer applied.

// use fp16 for compute

// convert b to fp16
auto converted_b = std::make_shared<v0::Convert>(casted_b, a.get_element_type());
auto converted_zero_points = std::make_shared<v0::Convert>(zero_points, a.get_element_type());

// sub and scale
const auto sub_b = std::make_shared<v1::Subtract>(converted_b, converted_zero_points);
const auto scales_fp16 = std::make_shared<v0::Convert>(scales, a.get_element_type());
const auto scales_reshaped =
op::util::reshape(scales_fp16, ov::Shape{static_cast<size_t>(N), static_cast<size_t>(n_blocks_per_col), 1});
const auto scaled_b = std::make_shared<v1::Multiply>(sub_b, scales_reshaped);

// reshape b to [N, K]
auto shape_b = v0::Constant::create(ov::element::i32, ov::Shape{2}, {0, -1});
auto reshaped_b = std::make_shared<v1::Reshape>(scaled_b, shape_b, true);

// if n_blocks_per_col * blob_size * X != K,
// we need to slice it down to K
// to produce b = [N, K]
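// Illustrative example (made-up numbers, not from this PR): with K = 100 and block_size = 32,
// n_blocks_per_col = ceil(100 / 32) = 4, so reshaped_b has 4 * 32 = 128 columns and the Slice
// below trims the 28 padded columns, leaving b = [N, 100].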
const bool slice_needed = (K % block_size != 0);
if (slice_needed) {
const auto zero = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 0);
const auto one = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
const auto elements = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, static_cast<int32_t>(K));
const auto axis = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
b = std::make_shared<v8::Slice>(reshaped_b, zero, elements, one, axis);
} else {
// Transpose matrix. Quantized B matrix is transposed and has a shape [N,K].
// To apply further operations on it which operand's shape is [N] we do this
// transpose to have a matrix [K,N]...
const auto transposed_shape =
std::make_shared<v0::Constant>(ov::element::i64, Shape{2}, std::vector<int64_t>{1, 0});
ov::Output<ov::Node> transposed_b = std::make_shared<v1::Transpose>(converted_b, transposed_shape);

// If no zero-points provided - we generate default, depends on data size
if (!zero_points.get_node_shared_ptr()) {
zero_points = default_zp;
}
const auto sub_b = std::make_shared<v1::Subtract>(transposed_b, zero_points);

// Scaling
const auto scaled_b = std::make_shared<v1::Multiply>(sub_b, scales);

// Transpose again to make reshaping and slicing
transposed_b = std::make_shared<v1::Transpose>(scaled_b, transposed_shape);

const auto reshaped_b =
op::util::reshape(transposed_b,
ov::Shape{static_cast<size_t>(casted_b_shape[0] / n_blocks_per_col),
static_cast<size_t>(casted_b_shape[1] * n_blocks_per_col)});

// Removing unused items in case block is bigger than column count (see description for
// Slice above)
const auto zero_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 0);
const auto one_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
const auto elements_const =
std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, static_cast<int32_t>(K));
const auto axis_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
const auto slice_b =
std::make_shared<v8::Slice>(reshaped_b, zero_const, elements_const, one_const, axis_const);

// Adding bias if required
if (!bias.get_node_shared_ptr()) {
return {std::make_shared<v0::MatMul>(a, slice_b, false, true)};
} else {
// Transpose again
transposed_b = std::make_shared<v1::Transpose>(slice_b, transposed_shape);

b = std::make_shared<v1::Add>(transposed_b, bias);
}
b = reshaped_b;
}

// mm = matmul(a,b)
mm_output = std::make_shared<v0::MatMul>(a, b, false, true);
}

return {std::make_shared<v0::MatMul>(a, b)};
if (bias.get_node_shared_ptr()) {
return {std::make_shared<v1::Add>(mm_output, bias)};
} else {
return {mm_output};
}
}

ONNX_OP("MatMulNBits", OPSET_SINCE(1), com_microsoft::opset_1::matmulnbits, MICROSOFT_DOMAIN);
@@ -255,4 +214,4 @@ ONNX_OP("MatMulNBits", OPSET_SINCE(1), com_microsoft::opset_1::matmulnbits, MICR
} // namespace com_microsoft
} // namespace onnx
} // namespace frontend
} // namespace ov
} // namespace ov
