Add Group Query Attention support with OV base OPs #28163

Open
wants to merge 16 commits into master

Conversation

sgbihu

@sgbihu sgbihu commented Dec 20, 2024

Details:

  • Try to enable LLMs based on onnxruntime. (Phi3 and Llama3 work on CPU; Phi3 also works with iGPU.)

Test script

import onnxruntime as rt
import os
import numpy as np
import time

import onnxruntime.tools.add_openvino_win_libs as utils
utils.add_openvino_libs_to_path()
from transformers import PreTrainedTokenizerFast


test_lama3 = False
test_phi3 = True
if test_phi3:
    modelPath = os.path.join('D:\\', 'models', 'llm', 'Phi-3-mini-4k-instruct-onnx', 'model.onnx')
    tokenizerPath = os.path.join('D:\\', 'models', 'llm', 'Phi-3-mini-4k-instruct-onnx', 'tokenizer.json')

if test_lama3:
    modelPath = os.path.join('D:\\', 'models', 'llm', 'llama3.1-8B-instruct-onnx', 'model.onnx')

so = rt.SessionOptions()
# so.log_severity_level = 3

# sess = rt.InferenceSession(modelPath, so, providers=['CPUExecutionProvider'])
sess = rt.InferenceSession(modelPath, so, providers=['OpenVINOExecutionProvider'], provider_options=[{'device_type' : "CPU", 'cache_dir': "cache"}])
# sess = rt.InferenceSession(modelPath, so, providers=['OpenVINOExecutionProvider'], provider_options=[{'device_type' : "CPU"}])
# sess = rt.InferenceSession(modelPath, so, providers=['OpenVINOExecutionProvider'], provider_options=[{'device_type' : "NPU"}])
tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizerPath)

# print(sess.get_device())
# for name in sess.get_inputs():
#     print(f"Name: {name.name}, Shape: {name.shape}, Type: {name.type}")
outputs = sess.get_outputs()
output_names = list(map(lambda output: output.name, outputs))


# Assuming the model has 32 layers and each layer has a key and value state
# Phi3
def get_phi3_param():
    num_layers = 32
    batch_size = 1
    num_heads = 32
    sequence_length = 2048
    hidden_size = 96
    return num_layers, batch_size, num_heads, sequence_length, hidden_size

# Llama3
def get_llama3_param():
    num_layers = 32
    batch_size = 1
    num_heads = 8
    sequence_length = 2048
    hidden_size = 128
    return num_layers, batch_size, num_heads, sequence_length, hidden_size

if test_phi3:
    num_layers, batch_size, num_heads, sequence_length, hidden_size = get_phi3_param()

if test_lama3:
    num_layers, batch_size, num_heads, sequence_length, hidden_size = get_llama3_param()

# Initialize past_key_values with zeros
cpu_array = np.zeros((batch_size, num_heads, sequence_length, hidden_size), dtype=np.float32)

# print("Output names: ", outputs[0].type.data)

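# Pre-allocate an OrtValue for every model output (resolving the symbolic dims) and bind it,
# so the present KV-cache buffers can later be rebound as past_key_values inputs.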
def create_present_state_binding(binding, outputs):
    outputMap={}
    for output in outputs:
        shapes = []
        for item in output.shape:
            if isinstance(item, str):
                if 'batch_size' in item:
                    shapes.append(batch_size)
                elif 'sequence_length' in item:
                    if output.name == 'logits':
                        shapes.append(len(inputToken))
                    else:
                        shapes.append(sequence_length)
                elif 'hidden_size' in item:
                    shapes.append(hidden_size)
                elif 'num_heads' in item:
                    shapes.append(num_heads)
                else:
                    raise ValueError(f"Unknown dimension: {item}")
            else:
                shapes.append(item)
            
        present_state = rt.OrtValue.ortvalue_from_shape_and_type(shapes, np.float32)
        binding.bind_ortvalue_output(output.name, present_state)
        outputMap[output.name] = present_state
    return outputMap

def rebind_inputs(lastOutput, binding):
    for index in range(num_layers):
        binding.bind_ortvalue_input(f'past_key_values.{index}.key', lastOutput[f'present.{index}.key'])
        binding.bind_ortvalue_input(f'past_key_values.{index}.value', lastOutput[f'present.{index}.value'])
    return binding

def init_input_with_binding(binding):
    for index in range(num_layers):
        key_state = rt.OrtValue.ortvalue_from_numpy(cpu_array)
        value_state = rt.OrtValue.ortvalue_from_numpy(cpu_array)
        binding.bind_ortvalue_input(f'past_key_values.{index}.key', key_state)
        binding.bind_ortvalue_input(f'past_key_values.{index}.value', value_state)
    return binding

def reinit_input_bindings(bindings, lastOutput):
    newOutput = create_present_state_binding(bindings, lastOutput)
    binding = rebind_inputs(lastOutput, bindings)
    return binding, newOutput

def create_numpy_inputs(inputToken):
    tokenLen = len(inputToken)
    npinput_ids = np.array([inputToken], dtype=np.int64)
    npattention_mask = np.array([[1] * (tokenLen)], dtype=np.int64)
    return npinput_ids, npattention_mask


def init_ortinput(inputToken):
    flattened_past_key_values = {}
    for index in range(num_layers):
        key_state = rt.OrtValue.ortvalue_from_numpy(cpu_array)
        value_state = rt.OrtValue.ortvalue_from_numpy(cpu_array)
        flattened_past_key_values[f'past_key_values.{index}.key'] = key_state
        flattened_past_key_values[f'past_key_values.{index}.value'] = value_state
    ids, mask = create_numpy_inputs(inputToken)
    flattened_past_key_values['input_ids'] = rt.OrtValue.ortvalue_from_numpy(ids)
    flattened_past_key_values['attention_mask'] = rt.OrtValue.ortvalue_from_numpy(mask)
    return flattened_past_key_values

def init_npinput(inputToken):
    flattened_past_key_values = {}
    for index in range(num_layers):
        key_state = np.zeros((batch_size, num_heads, sequence_length, hidden_size), dtype=np.float32)
        value_state = np.zeros((batch_size, num_heads, sequence_length, hidden_size), dtype=np.float32)
        flattened_past_key_values[f'past_key_values.{index}.key'] = key_state
        flattened_past_key_values[f'past_key_values.{index}.value'] = value_state
    flattened_past_key_values['input_ids'], flattened_past_key_values['attention_mask'] = create_numpy_inputs(inputToken)
    return flattened_past_key_values

def init_bindinginput(inputToken):
    binding = sess.io_binding()
    binding = init_input_with_binding(binding)
    
    ids, mask = create_numpy_inputs(inputToken)
    binding.bind_ortvalue_input('attention_mask', rt.OrtValue.ortvalue_from_numpy(mask))
    binding.bind_ortvalue_input('input_ids', rt.OrtValue.ortvalue_from_numpy(ids))
    return binding


# Question
# The Sun is yellow because

# Phi3
if test_phi3:
    # 450 8991 5692
    # inputToken = [32010, 29871, 13]
    inputToken = [32010, 29871, 13, 1576, 8991, 338, 13328, 1363, 29871, 32007, 13, 32001]
    # inputToken = [32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010]
# Llama3
if test_lama3:
    # 315 1202 7479
    inputToken = [128000, 27, 91, 882, 91, 397, 791, 8219, 374, 14071, 1606, 83739, 408, 91, 397, 27, 91, 78191, 91, 29]
    # inputToken = [315]
history_tokens = inputToken

flattened_past_key_values = init_npinput(inputToken)

# flattened_past_key_values = init_ortinput(inputToken)

# binding = init_bindinginput(inputToken)
# lastoutput = create_present_state_binding(binding, outputs)

lastTokenLen = len(inputToken)


# roption = rt.RunOptions()
# roption.add_run_config_entry("gpu_graph_id", "-1")

before = time.time()
results = sess.run(output_names, flattened_past_key_values)
# results = sess.run_with_iobinding(binding)
# results = sess.run_with_ort_values(output_names, flattened_past_key_values)
after = time.time()
print("Time cost in ms: ", (after - before) * 1000)

# print(np.argmax(results[0].numpy(), axis=-1)[-1])
print(np.argmax(results[0], axis=-1)[-1])

# print(results[0])
# print(output_names[1])
# print(results[1][0][0][0])
# print(results[1][0][0][1])
# print(results[1][0][0][2])
# # print(results[1][0][0][14])
# # print(results[1])
# print(output_names[2])
# # print(results[2])
# print(results[2][0][0][0])
# print(results[2][0][0][1])
# print(results[2][0][0][2])
# print(results[2][0][0][14])
# inputToken.append(450)

# rebind_inputs(lastOutput, binding)

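# Copy every 'present.*' output of the last run back to the matching
# 'past_key_values.*' input so the next run sees the updated KV cache.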
def update_kvcache(inputsMap, results):
    for index in range(len(output_names)):
        if not output_names[index].startswith('present'):
            continue
        # print(f'{output_names[index]}: {results[index].shape}')
        outputname = output_names[index]
        inputname = outputname.replace('present', 'past_key_values')
        inputsMap[inputname] = results[index]
    return inputsMap
# lastOutput = create_present_state_binding(binding, sess.get_outputs())

# flattened_past_key_values = update_kvcache(flattened_past_key_values, results)

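# Second run: feed the present KV cache back as past_key_values and pass only the new token.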
for index in range(len(output_names)):
    if not output_names[index].startswith('present'):
        continue
    # print(f'{output_names[index]}: {results[index].shape}')
    outputname = output_names[index]
    inputname = outputname.replace('present', 'past_key_values')
    flattened_past_key_values[inputname] = results[index]
if test_phi3:
    inputToken = [450]

if test_lama3:
    inputToken = [315]
history_tokens += inputToken

npinput_ids = np.array([inputToken], dtype=np.int64)
npattention_mask = np.array([[1] * (lastTokenLen+1)], dtype=np.int64)
print(f"lastTokenLen:{lastTokenLen}")

# attention_mask = rt.OrtValue.ortvalue_from_numpy(npattention_mask)
# input_ids = rt.OrtValue.ortvalue_from_numpy(npinput_ids)
# binding.bind_ortvalue_input(f'attention_mask', attention_mask)
# binding.bind_ortvalue_input(f'input_ids', input_ids)
# flattened_past_key_values[f'attention_mask'].update_inplace(npattention_mask)
# flattened_past_key_values[f'input_ids'].update_inplace(npinput_ids)
# flattened_past_key_values[f'attention_mask'] = attention_mask
# flattened_past_key_values[f'input_ids'] = input_ids
flattened_past_key_values['attention_mask'] = npattention_mask
flattened_past_key_values['input_ids'] = npinput_ids
# print(flattened_past_key_values)

before = time.time()
results = sess.run(output_names, flattened_past_key_values)
# results = sess.run_with_iobinding(binding)
# results = sess.run_with_ort_values(output_names, flattened_past_key_values)
after = time.time()
print("Time cost in ms: ", (after - before) * 1000)

# Results:  [np.int32(450), np.int32(8991), np.int32(5692), np.int32(13328), np.int32(304), np.int32(502), np.int32(19434), np.int32(2861), np.int32(304), np.int32(9596), np.int32(280), np.int32(1141), np.int32(14801), np.int32(292), np.int32(29889), np.int32(1932), np.int32(6575), np.int32(4366), np.int32(14517), np.int32(1549), np.int32(278), np.int32(11563), np.int32(29915), np.int32(29879), np.int32(25005), np.int32(29892), np.int32(278), np.int32(20511), np.int32(7254), np.int32(281), np.int32(6447), np.int32(1477), np.int32(29879), np.int32(526), np.int32(29574), np.int32(297), np.int32(599), np.int32(18112), np.int32(491), np.int32(278), np.int32(330), np.int32(2129), np.int32(322), np.int32(17105), np.int32(297), np.int32(278), np.int32(4799), np.int32(29889), np.int32(910), np.int32(14801), np.int32(292), np.int32(9946), np.int32(278), np.int32(14744), np.int32(304), np.int32(1106), np.int32(7254), np.int32(29889), np.int32(2398), np.int32(29892), np.int32(278), np.int32(5520), np.int32(2654), np.int32(322), np.int32(13328), np.int32(281), np.int32(6447), np.int32(1477), np.int32(29879), np.int32(1209), np.int32(1549), np.int32(278), np.int32(25005), np.int32(901), np.int32(5948), np.int32(322), np.int32(526), np.int32(3109), np.int32(29574), np.int32(29889), np.int32(1932), np.int32(591), np.int32(1106), np.int32(472), np.int32(278), np.int32(8991), np.int32(29892), np.int32(591), np.int32(1074), np.int32(372), np.int32(408), np.int32(263), np.int32(13328), np.int32(470), np.int32(24841), np.int32(8086), np.int32(1363), np.int32(278), np.int32(7254), np.int32(3578), np.int32(338), np.int32(29574), np.int32(714), np.int32(310), np.int32(1749), np.int32(1196), np.int32(310), np.int32(11126), np.int32(29892), np.int32(322), np.int32(278), np.int32(9886), np.int32(3578), np.int32(393), np.int32(22170), np.int32(1749), np.int32(5076), np.int32(338), np.int32(758), np.int32(24130), np.int32(10835), np.int32(13328), np.int32(322), np.int32(2654), np.int32(29889), np.int32(32000)]
# index = 0
# for result in results:
#     print(f'{output_names[index]}: {result.shape}, {result.dtype}')
#     index += 1
print(np.argmax(results[0], axis=-1)[-1])
# print(np.argmax(results[0].numpy(), axis=-1)[-1])


# golden results
# Time cost in ms:  1255.2332878112793
# [30751    13    13  1494  1731   263 29889   372    13 24380    13   450]
# lastTokenLen:12
# Time cost in ms:  1006.781816482544
# [8991]

last_generated_token = np.argmax(results[0], axis=-1)[-1][-1]
history_tokens.append(last_generated_token)
NUM_INFERENCE = 15
for i in range(NUM_INFERENCE):
    # update KV cache
    for index in range(len(output_names)):
        if not output_names[index].startswith('present'):
            continue
        # print(f'{output_names[index]}: {results[index].shape}')
        outputname = output_names[index]
        inputname = outputname.replace('present', 'past_key_values')
        flattened_past_key_values[inputname] = results[index]

    # update input token
    flattened_past_key_values['input_ids'] = np.array([[last_generated_token]], dtype=np.int64)
    flattened_past_key_values['attention_mask'] = np.array([[1] * len(history_tokens)], dtype=np.int64)

    before = time.time()
    results = sess.run(output_names, flattened_past_key_values)
    after = time.time()
    print("Time cost in ms: ", (after - before) * 1000)

    last_generated_token = np.argmax(results[0], axis=-1)[-1][-1]
    history_tokens.append(last_generated_token)

print(tokenizer.decode(history_tokens))

Tickets:

  • related to 155287, 157123

@github-actions github-actions bot added the category: ONNX FE OpenVINO ONNX FrontEnd label Dec 20, 2024
@sys-openvino-ci sys-openvino-ci added the ExternalIntelPR External contributor from Intel label Dec 20, 2024
@slyalin

slyalin commented Dec 20, 2024

How is it related to #27648?

@github-actions github-actions bot added category: Core OpenVINO Core (aka ngraph) category: GPU OpenVINO GPU plugin category: CPU OpenVINO CPU plugin category: transformations OpenVINO Runtime library - Transformations labels Jan 16, 2025
@github-actions github-actions bot added category: CPP API OpenVINO CPP API bindings and removed category: GPU OpenVINO GPU plugin category: CPU OpenVINO CPU plugin labels Jan 21, 2025
@wine99 wine99 force-pushed the gqa_enabling branch 2 times, most recently from f4770e0 to 911691b Compare January 26, 2025 02:32
@wine99

wine99 commented Feb 5, 2025

@slyalin we have relocated the transformation code from the ONNX frontend to the plugin transformation passes as detailed in #27648. Could you please review and provide feedback? Currently, the GQA node is defined in opset15, which likely needs to be updated.

@sgbihu sgbihu marked this pull request as ready for review February 6, 2025 13:03
@sgbihu sgbihu requested review from a team as code owners February 6, 2025 13:03
@sgbihu sgbihu requested review from itikhono and removed request for a team February 6, 2025 13:03
@mlukasze

mlukasze commented Feb 6, 2025

hey @sgbihu
please resolve conflicts so that CI can be triggered

@wine99 wine99 requested a review from a team as a code owner February 7, 2025 03:00
@wine99 wine99 requested review from ilya-lavrenov and removed request for a team February 7, 2025 03:00
@mitruska mitruska left a comment

From my point of view, the main things to address are the correct place for the GQA op (opset or dev_api) and the way the "Null" node represents optional inputs.
Details are in the comments on the related parts of the code.

@slyalin Could you please share your opinion from the architecture side and as an author of the original:

Comment on lines 12 to 15
// This is an experimental operation that is implemented in the plugins.
class OPENVINO_API GroupQueryAttention : public Op {
public:
OPENVINO_OP("GroupQueryAttention", "opset15");

The opset15 has already been released and closed; new operations should be added to opset16 or (as this is marked as "experimental") to the dev_api (like src/core/dev_api/openvino/op/group_query_attention.hpp in:

Is it possible to have GQA in the dev_api? Or is it required to have the GQA operation in an ov "opset" (with IR deserialization enabled) to make this solution work with onnxruntime and the attached example script?


We added it to opset15 because we saw some errors when implementing this at the beginning (related to IR deserialization, IIRC, though I did not understand the errors). But I have just tried again and found that having GQA in the dev_api also works. If that is preferred, we can move GQA (and maybe Null) to the dev_api.

Comment on lines 11 to 24
namespace v15 {

/// \brief Represents a missing optional input or output of an ONNX node
///
/// Some ONNX operators have inputs or outputs that are marked as optional,
/// which means that a referring node MAY forgo providing values for such inputs
/// or computing these outputs.
/// An empty string is used in place of a name of such input or output.
///
/// More:
/// https://github.com/onnx/onnx/blob/master/docs/IR.md#optional-inputs-and-outputs
class OPENVINO_API Null : public Op {
public:
OPENVINO_OP("Null", "opset15", op::Op);
@mitruska mitruska Feb 10, 2025

I understand the motivation behind the "Null" node (as GQA has several optional inputs mixed with required ones). Actually, this is a copy of the class used within the OV ONNX Frontend scope. If we follow this approach, we should agree on and provide a more generic meaning for the Null op, not only for ONNX.

On the other hand @slyalin proposed an empty Constant to represent the "Null" Node (group_query_attention.cpp#L72-L74)

Output<Node> GroupQueryAttentionExtension::null() {
    return v0::Constant::create(element::f32, Shape{0}, {});  // particular type and shape do not matter
}

@slyalin Could you please share your view on the preferred approach here?
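
For reference, a minimal sketch of how a consumer could detect a missing optional input under either convention (the v15::Null op and the empty-Constant placeholder are both taken from this thread; the helper name is hypothetical):

bool is_missing_optional_input(const ov::Output<ov::Node>& input) {
    // Option 1: a dedicated Null op marks the absent input explicitly
    if (ov::is_type<ov::op::v15::Null>(input.get_node()))
        return true;
    // Option 2: an empty Constant stands in for the absent input
    const auto constant = ov::as_type<ov::op::v0::Constant>(input.get_node());
    return constant != nullptr && ov::shape_size(constant->get_shape()) == 0;
}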

@wine99 wine99 Feb 13, 2025

The QKV unpacking logic has been moved to the frontend. Null is no longer needed for this PR and has been removed.

@@ -1682,3 +1682,415 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_com_microsoft_qlinear_mul) {
test_case.add_expected_output<int8_t>(Shape{2, 2}, expected_output);
test_case.run();
}

OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary) {

I was trying to run those tests locally with the latest version of this PR and all of them failed with the following error:

C++ exception with description "Exception from src/core/src/runtime/tensor.cpp:96:
Exception from src/inference/src/dev/make_tensor.cpp:66:

Tensor data with element type f32, is not representable as pointer to i32

It seems to be related to the total_sequence_length input and Null node:

test_case.add_input<int>(total_sequence_length);
// total_sequence_length is not used currently in OV GQA
ov_op_inputs[6] = std::make_shared<v15::Null>();

Please try to reproduce and fix.

@wine99 wine99 Feb 11, 2025

We changed this input to Null because we ran into some problems in NPUW. I have changed the tests to exclude this input; the tests should work now.

Comment on lines 44 to 52
PartialShape input_shape = get_input_partial_shape(0);
Dimension batch_size = input_shape[0];
Dimension sequence_len = input_shape[1];
Dimension head_size;
if (Null::is_null(input_value(1)) && Null::is_null(input_value(2))) {
head_size = get_head_size(input_shape, m_num_heads, m_kv_num_heads);
} else {
head_size = input_shape[2].get_length() / m_num_heads;
}
@mitruska mitruska Feb 10, 2025

The .get_length() call will throw if it is called on a dynamic dimension, and a static dimension is not ensured here (nor in the get_head_size function). A check for input_shape[2].is_static() should be added, or dynamic dimension division can simply be applied to set the bounds with the ov::util::dim::floor_div util from src/core/shape_inference/include/dimension_util.hpp.

*Additional note: shape inference is currently usually implemented in a "shape_infer" function (like src/core/shape_inference/include/scaled_dot_product_attention_shape_inference.hpp) so that it can be reused by plugins for dynamic cases, but that can be considered for a further refactor out of this PR's scope.
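
A minimal sketch of the suggested guard (member and helper names are taken from the snippet above, not from a final implementation):

Dimension head_size;
if (Null::is_null(input_value(1)) && Null::is_null(input_value(2))) {
    head_size = get_head_size(input_shape, m_num_heads, m_kv_num_heads);
} else if (input_shape[2].is_static()) {
    head_size = input_shape[2].get_length() / m_num_heads;
} else {
    // alternatively, divide the dynamic dimension bounds directly,
    // e.g. with ov::util::dim::floor_div from dimension_util.hpp
    head_size = Dimension::dynamic();
}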


Added a static check since the last dim should be known.

@gkrivor gkrivor left a comment

Hi, thank you for the contribution! I've left a few comments from the frontend perspective :)

@@ -0,0 +1,24 @@
// Copyright (C) 2018-2024 Intel Corporation

Outdated copyright years in the new files


namespace opset_1 {
ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {
const auto onnx_op_inputs = node.get_ov_inputs();

Please add common::default_op_checks, like here


Added

OutputVector ov_op_inputs;
ov_op_inputs.reserve(onnx_op_inputs.size());
for (const auto& input : onnx_op_inputs) {
ov_op_inputs.push_back(ov::op::util::is_null(input) ? std::make_shared<v15::Null>() : input);

Usually we place default-value pre-calculation in frontend-specific code. It might also be helpful for frontend-specific optimizations in the future.

I'm not sure I get why you can't calculate the "default values" here instead of doing it in the transformation pipeline?

It isn't strictly necessary on the frontend side, but as I understand it, it would allow removing op::Null from the implementation.


Done. QKV unpacking logic is now in the frontend.

const auto kv_num_heads = node.get_attribute_value<int64_t>("kv_num_heads");
const auto scale = node.get_attribute_value<float>("scale", 0.0f);
const auto do_rotary = node.get_attribute_value<int64_t>("do_rotary", 0);
const auto rotary_interleaved = node.get_attribute_value<float>("rotary_interleaved", 0.0f);

The docs say it should be an int; why is it a float here?


I believe this was a mistake. It has been changed now.
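
Presumably the corrected line just reads the attribute as an integer, matching the other attributes above (a sketch, not the exact diff):

const auto rotary_interleaved = node.get_attribute_value<int64_t>("rotary_interleaved", 0);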

public:
OPENVINO_MATCHER_PASS_RTTI("GroupQueryAttentionDecomposition");
GroupQueryAttentionDecomposition();
ov::OutputVector decompose(std::shared_ptr<ov::op::v15::GroupQueryAttention> node);

minor: do we need this method as public?


Done. Made it private

github-merge-queue bot pushed a commit that referenced this pull request Feb 18, 2025
This PR does some optimization work on the ONNX frontend
com.microsoft.MatMulNbits operators.

With these changes:
1. Constant folding is disabled, which would otherwise use 75 GB for the Phi3 INT4 model and
200+ GB for the Llama3 INT4 model.
2. oneDNN MatMul primitives are triggered, which greatly benefits GPU
performance.

We tested these changes along with another PR, #28163, and confirmed
that the Phi3/Llama3 INT4 models run well on LNL.

---------

Co-authored-by: Yu, Zijun <zijun.yu@intel.com>
@github-actions github-actions bot removed the category: CPP API OpenVINO CPP API bindings label Feb 19, 2025
@mitruska

The related onnx tests are failing and need to be fixed:

[ RUN      ] IE_CPU.onnx_model_gqa_past_0_input_1_rotary
Segmentation fault (core dumped) 

The Interpreter (Template plugin) tests require a reference implementation for ScaledDotProductAttention, which is not enabled yet; so once the CPU tests work, the Interpreter tests could eventually be excluded (in frontends/onnx/tests/runtime/interpreter/unit_test.manifest).

[ RUN      ] INTERPRETER.onnx_model_gqa_past_0_input_1_rotary
Interpreter backend doesn't implement evaluate method for OP ScaledDotProductAttention

Please take a look (CI logs).

@wine99

wine99 commented Feb 20, 2025

The failed CI checks are:

  • ONNX frontend tests on Windows, but based on the CI log, all tests passed
  • TensorFlow layer tests on Ubuntu, which failed while fetching the OpenVINO tarball

@mitruska

build_jenkins

@mitruska mitruska left a comment

In my opinion we could merge this work and apply any further improvements as smaller changes if needed.

Comment on lines +93 to +102
const auto q_shape = std::make_shared<v3::ShapeOf>(Q);
const auto current_sequence_length = detail::get_dimensions(q_shape, {2});
auto head_size_node = v0::Constant::create(ov::element::i64, ov::Shape{}, {node->get_head_size()});

auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1});
auto one_without_shape = v0::Constant::create(ov::element::i64, ov::Shape{}, {1});
auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2});
auto seqlens_elemi64 = std::make_shared<v0::Convert>(seqlens_k, ov::element::i64);
auto real_seqlens = std::make_shared<v1::Add>(seqlens_elemi64, one);

The current approach within the transformations is to add every node to the NodeRegistry, like:

auto q_shape = register_new_node<v3::ShapeOf>(query, element::i32);
auto k_shape = register_new_node<v3::ShapeOf>(key, element::i32);
auto minus_one = register_new_node(v0::Constant::create(element::i32, Shape{}, {-1}));
auto minus_two = register_new_node(v0::Constant::create(element::i32, Shape{}, {-2}));

It is used to copy runtime info before replacement:
auto result = register_new_node<v0::MatMul>(scaled_atten, value);
result->set_friendly_name(node->get_friendly_name());
copy_runtime_info(node, get_new_nodes());
return result;

cc: @itikhono
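
Applied to the snippet above, the first few lines would look roughly like this (a sketch only, assuming the decomposition runs inside a MatcherPass where register_new_node is available):

const auto q_shape = register_new_node<v3::ShapeOf>(Q);
auto head_size_node = register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{}, {node->get_head_size()}));
auto one = register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}));
auto seqlens_elemi64 = register_new_node<v0::Convert>(seqlens_k, ov::element::i64);
auto real_seqlens = register_new_node<v1::Add>(seqlens_elemi64, one);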


Should I make this change in this PR?


I've already approved, so I won't force this change in this PR (but it should be applied as a follow-up).
@itikhono Do you consider it a blocker?

Comment on lines +11 to +12
// This is an experimental operation that is implemented in the plugins.
class OPENVINO_API GroupQueryAttention : public Op {

Is there any plugin able to support this GroupQueryAttention class right now, or is the decomposition to ScaledDotProductAttention always needed and applied?


AFAIK no plugin has GQA kernels, so the decomposition is always needed. @sgbihu

@wine99 wine99 mentioned this pull request Mar 4, 2025