-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Stateful to Stateless Transformation for LLMs (#25150)
Transformation that undoing make_stateful from optimum-intel. ### How to use in Python ```python import openvino as ov from openvino._offline_transformations import stateful_to_stateless_transformation core = ov.Core() model = core.read_model('your_chatty_stateful_model_right_from_vanilla_optimum_intel.xml') stateful_to_stateless_transformation(model) # use `model` ``` ### How to use in C++ ```c++ #include <openvino/openvino.hpp> #include <openvino/pass/stateful_to_stateless.hpp> int main() { auto core = ov::Core(); auto model = core.read_model("your_chatty_stateful_model_right_from_vanilla_optimum_intel.xml"); ov::pass::StatefulToStateless().run_on_model(model); // use `model` } ``` ### TODO - [x] Restore the original order of inputs/output (now they are not globally ordered, but kv inputs corresponds to kv outputs by indices with a proper offset). - [x] Restore the original names of inputs and outputs based on optimum-intel conventions in make_stateful.
- Loading branch information
Showing
7 changed files
with
279 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/pass/pass.hpp" | ||
|
||
namespace ov { | ||
namespace pass { | ||
/** | ||
* @brief The transformation converts KV cache state back to stateless form. | ||
* \ingroup ov_pass_cpp_api | ||
*/ | ||
class OPENVINO_API StatefulToStateless : public ModelPass { | ||
public: | ||
OPENVINO_RTTI("StatefulToStateless"); | ||
|
||
bool run_on_model(const std::shared_ptr<ov::Model>& model) override; | ||
}; | ||
} // namespace pass | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "openvino/pass/stateful_to_stateless.hpp" | ||
|
||
#include <regex> | ||
#include <string> | ||
|
||
#include "openvino/cc/pass/itt.hpp" | ||
#include "openvino/op/assign.hpp" | ||
#include "openvino/op/gather.hpp" | ||
#include "openvino/op/read_value.hpp" | ||
#include "openvino/pass/manager.hpp" | ||
#include "transformations/utils/utils.hpp" | ||
|
||
using namespace ov::op; | ||
|
||
namespace { | ||
|
||
std::shared_ptr<ov::Node> set_name(std::shared_ptr<ov::Node> node, const std::string& name) { | ||
// Set name for both node and output tensor (should be only one tensor, and any other names will be overriden by a | ||
// given single name) | ||
node->set_friendly_name(name); | ||
OPENVINO_ASSERT(node->get_output_size() == 1); | ||
node->get_output_tensor(0).set_names({name}); | ||
return node; | ||
} | ||
|
||
// Templated method that has the same effect as not templated `set_name` but saves Op type for convenient calls chaining | ||
template <typename T> | ||
inline std::shared_ptr<T> set_name(std::shared_ptr<T> node, const std::string& name) { | ||
set_name(std::dynamic_pointer_cast<ov::Node>(node), name); | ||
return node; | ||
} | ||
|
||
std::shared_ptr<v0::Parameter> get_parameter_by_tensor_name(const std::shared_ptr<ov::Model>& model, | ||
const std::string& name) { | ||
for (const auto& param : model->get_parameters()) { | ||
if (param->get_output_tensor(0).get_names().count(name)) | ||
return param; | ||
} | ||
return nullptr; // nullptr and return type are only difference from ov::Model::input(name) | ||
} | ||
|
||
struct Variable { | ||
struct Context { | ||
// to hold compiled once regex for all Variable instances | ||
const std::regex naming_convention = | ||
std::regex(R"((past_key_values\.(\d+)\.(key|value))(present\.(\d+)\.(key|value)))"); | ||
}; | ||
|
||
Variable(const Context& context, const std::string& variable_name) : variable_name(variable_name) { | ||
// Try to decode original naming of the corresponding input and output in the stateless model | ||
std::smatch match; | ||
if (std::regex_match(variable_name, match, context.naming_convention)) { | ||
input_name = match[1].str(); | ||
output_name = match[4].str(); | ||
auto input_index = match[2].str(); | ||
auto output_index = match[5].str(); | ||
if (input_index == output_index && input_index.length() <= std::numeric_limits<int>::digits10) { | ||
index = std::stoi(input_index) * 2 + int(match[3].str() == "value"); // order key before value | ||
} else { | ||
index = -1; | ||
} | ||
} else { | ||
// Variable name doesn't follow the expected naming convention. It doens't prevent forming | ||
// a correct stateless model but doesn't give a way to restore all names and inputs/outputs ordering | ||
// accurately. | ||
input_name = "input_restored." + variable_name; | ||
output_name = "output_restored." + variable_name; | ||
index = -1; | ||
} | ||
} | ||
|
||
int index; // layer index, -1 means the index isn't known | ||
std::string variable_name; // original variable_id | ||
std::string input_name; // restored name of input | ||
std::string output_name; // restored name of output | ||
}; | ||
|
||
typedef std::vector<Variable> Variables; | ||
|
||
void restore_kv_cache_order(Variables& variables, const std::unordered_map<std::string, size_t>& var_index_by_var_id) { | ||
// Try to restore variable order based on the known naming convention from optimum-intel. | ||
// If names are not satisfy the expected convention, fallback to use order based on var_index_by_var_id | ||
// Sort items that do satisfy the naming conventions before items that don't satisfy. | ||
|
||
std::stable_sort(variables.begin(), variables.end(), [&](const Variable& a, const Variable& b) { | ||
if (a.index >= 0 && b.index >= 0) { | ||
return a.index < b.index; | ||
} else if (a.index >= 0 && b.index < 0) { | ||
return true; | ||
} else if (a.index < 0 && b.index >= 0) { | ||
return false; | ||
} else { // a.index < 0 && b.index < 0 | ||
return var_index_by_var_id.at(a.variable_name) < var_index_by_var_id.at(b.variable_name); | ||
} | ||
}); | ||
} | ||
|
||
} // namespace | ||
|
||
bool ov::pass::StatefulToStateless::run_on_model(const std::shared_ptr<ov::Model>& model) { | ||
RUN_ON_MODEL_SCOPE(StatefulToStateless); | ||
|
||
auto beam_idx = get_parameter_by_tensor_name(model, "beam_idx"); | ||
Variables variables; // to collect variables corresponding to future_params | ||
variables.reserve(model->get_sinks().size()); | ||
Variable::Context context; | ||
std::unordered_map<std::string, std::shared_ptr<ov::Node>> | ||
future_params; // to collect nodes, each with a single output that will be replaced by new parameters | ||
if (beam_idx) { | ||
for (const ov::Input<ov::Node>& input : beam_idx->get_output_target_inputs(0)) { | ||
if (auto gather = std::dynamic_pointer_cast<op::util::GatherBase>(input.get_node()->shared_from_this())) { | ||
auto read_value = | ||
std::dynamic_pointer_cast<op::util::ReadValueBase>(gather->get_input_node_shared_ptr(0)); | ||
OPENVINO_ASSERT(read_value, | ||
"Unexpected model topology in StatefulToStateless: no ReadValue is found at the first " | ||
"input of Gather by `beam_idx` parameter"); | ||
auto variable_name = read_value->get_variable_id(); | ||
variables.push_back(Variable(context, variable_name)); | ||
future_params[variable_name] = gather; | ||
} | ||
} | ||
} else { | ||
OPENVINO_THROW( | ||
"Stateful models without `beam_idx` input are not supported in StatefulToStateless transformation"); | ||
} | ||
model->remove_parameter(beam_idx); | ||
|
||
typedef std::shared_ptr<op::util::AssignBase> PAssign; | ||
std::unordered_map<std::string, PAssign> assigns_by_var_id; | ||
std::unordered_map<std::string, size_t> assign_index_by_var_id; | ||
const auto& sinks = model->get_sinks(); | ||
for (size_t i = 0; i < sinks.size(); ++i) { | ||
if (auto assign = std::dynamic_pointer_cast<op::util::AssignBase>(sinks[i])) { | ||
const auto& var_id = assign->get_variable_id(); | ||
assigns_by_var_id[var_id] = assign; | ||
assign_index_by_var_id[var_id] = i; | ||
} | ||
} | ||
|
||
restore_kv_cache_order(variables, assign_index_by_var_id); | ||
|
||
ov::ParameterVector new_parameters; | ||
ov::ResultVector new_results; | ||
new_parameters.reserve(variables.size()); | ||
new_results.reserve(variables.size()); | ||
|
||
for (const auto& variable_id : variables) { | ||
auto future_param = future_params[variable_id.variable_name]; | ||
auto parameter = ::set_name(std::make_shared<v0::Parameter>(future_param->get_output_element_type(0), | ||
future_param->get_output_partial_shape(0)), | ||
variable_id.input_name); | ||
|
||
replace_node(future_param, parameter); | ||
|
||
auto assign = assigns_by_var_id[variable_id.variable_name]; | ||
auto result = ::set_name(std::make_shared<v0::Result>(assign->input_value(0)), variable_id.output_name); | ||
|
||
model->remove_sink(assign); // Don't do replace_node(assign, result)! It will lead to silently incorrect model. | ||
model->remove_variable(model->get_variable_by_id(variable_id.variable_name)); | ||
new_parameters.push_back(parameter); | ||
new_results.push_back(result); | ||
} | ||
|
||
model->add_parameters(new_parameters); | ||
model->add_results(new_results); | ||
|
||
return true; | ||
} |
6 changes: 6 additions & 0 deletions
6
tests/model_hub_tests/pytorch/models/tiny-set-stateful-models-precommit
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
hf-internal-testing/tiny-random-LlamaForCausalLM,https://huggingface.co/trl-internal-testing/tiny-random-LlamaForCausalLM | ||
hf-internal-testing/tiny-random-StableLmForCausalLM,https://huggingface.co/hf-internal-testing/tiny-random-StableLmForCausalLM | ||
hf-internal-testing/tiny-random-PhiForCausalLM,https://huggingface.co/hf-internal-testing/tiny-random-PhiForCausalLM | ||
hf-internal-testing/tiny-random-CodeGenForCausalLM,https://huggingface.co/hf-internal-testing/tiny-random-CodeGenForCausalLM | ||
hf-internal-testing/tiny-random-Starcoder2ForCausalLM,https://huggingface.co/hf-internal-testing/tiny-random-Starcoder2ForCausalLM | ||
hf-internal-testing/tiny-random-OPTForCausalLM,https://huggingface.co/hf-internal-testing/tiny-random-OPTForCausalLM |
58 changes: 58 additions & 0 deletions
58
tests/model_hub_tests/pytorch/test_stateful_to_stateless_transformation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Copyright (C) 2018-2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import openvino as ov | ||
from openvino._offline_transformations import stateful_to_stateless_transformation | ||
from optimum.intel import OVModelForCausalLM | ||
import models_hub_common.utils as utils | ||
import pytest | ||
import os | ||
|
||
def get_read_value_ops(model: ov.Model): | ||
return [op for op in model.get_ops() if op.get_type_name() == 'ReadValue'] | ||
|
||
def check_desc_tensors(tensors1, tensors2): | ||
# order of tensors may not match, comparing by the total amount and names | ||
assert len(tensors1) == len(tensors2) | ||
assert set(tuple(t.names) for t in tensors1) == set(tuple(t.names) for t in tensors2) | ||
for t1 in tensors1: | ||
t2_candidates = [t for t in tensors2 if t1.names & t.names] | ||
assert len(t2_candidates) == 1 | ||
t2 = t2_candidates[0] | ||
assert t1.names == t2.names | ||
assert t1.get_partial_shape() == t2.get_partial_shape() | ||
assert t1.get_element_type() == t2.get_element_type() | ||
|
||
def run_stateful_to_stateless_in_runtime(tmp_path, model_id, model_link): | ||
model = OVModelForCausalLM.from_pretrained(model_id, export=True, stateful=True, compile=False) | ||
assert len(model.model.get_sinks()), f"Input model is not in the expected stateful form because it doesn't have any sinks." | ||
assert len(get_read_value_ops(model.model)), f"Input model is not in the expected stateful form because it doesn't have any ReadValue operations." | ||
|
||
stateful_to_stateless_transformation(model.model) | ||
|
||
sink_ops = model.model.get_sinks() | ||
read_value_ops = get_read_value_ops(model.model) | ||
assert len(sink_ops) == 0, f"Expected stateless model, but there are sinks found: {sink_ops}" | ||
assert len(read_value_ops) == 0, f"Expected stateless model, but there are ReadValue operations found: {read_value_ops}" | ||
|
||
stateless_model = OVModelForCausalLM.from_pretrained(model_id, export=True, stateful=False, compile=False) | ||
|
||
print(model.model) | ||
print(stateless_model.model) | ||
check_desc_tensors(model.model.inputs, stateless_model.model.inputs) | ||
check_desc_tensors(model.model.outputs, stateless_model.model.outputs) | ||
|
||
core = ov.Core() | ||
core.compile_model(model.model, 'CPU') | ||
|
||
|
||
@pytest.mark.precommit | ||
@pytest.mark.parametrize("model_name, model_link, mark, reason", utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "tiny-set-stateful-models-precommit"))) | ||
def test_stateful_to_stateless_precommit(tmp_path, model_name, model_link, mark, reason, ie_device): | ||
assert mark is None or mark == 'skip' or mark == 'xfail', \ | ||
"Incorrect test case: {}, {}".format(model_name, model_link) | ||
if mark == 'skip': | ||
pytest.skip(reason) | ||
elif mark == 'xfail': | ||
pytest.xfail(reason) | ||
run_stateful_to_stateless_in_runtime(tmp_path, model_name, model_link) |