Skip to content

Commit

Permalink
Stateful to Stateless Transformation for LLMs (#25150)
Browse files Browse the repository at this point in the history
Transformation that undoes `make_stateful` from optimum-intel.

### How to use in Python
```python
import openvino as ov
from openvino._offline_transformations import stateful_to_stateless_transformation
core = ov.Core()
model = core.read_model('your_chatty_stateful_model_right_from_vanilla_optimum_intel.xml')
stateful_to_stateless_transformation(model)
# use `model`
```

### How to use in C++
```c++
#include <openvino/openvino.hpp>
#include <openvino/pass/stateful_to_stateless.hpp>

int main() {
    auto core = ov::Core();
    auto model = core.read_model("your_chatty_stateful_model_right_from_vanilla_optimum_intel.xml");
    ov::pass::StatefulToStateless().run_on_model(model);
    // use `model`
}
```

### TODO

- [x] Restore the original order of inputs/outputs (now they are not
globally ordered, but kv inputs correspond to kv outputs by indices
with a proper offset).
- [x] Restore the original names of inputs and outputs based on
optimum-intel conventions in make_stateful.
  • Loading branch information
slyalin authored Jul 4, 2024
1 parent c901a26 commit 2a9af43
Show file tree
Hide file tree
Showing 7 changed files with 279 additions and 0 deletions.
10 changes: 10 additions & 0 deletions .github/workflows/job_pytorch_models_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,16 @@ jobs:
USE_SYSTEM_CACHE: False
OP_REPORT_FILE: ${{ env.INSTALL_TEST_DIR }}/TEST-torch_unsupported_ops.log

# Runs the StatefulToStateless pytest suite from the model hub tests and writes a self-contained HTML report.
# TYPE is passed to pytest as the marker expression: 'nightly' when the workflow event is 'schedule',
# 'precommit' otherwise.
- name: StatefulToStateless Test
if: always()
run: |
export PYTHONPATH=${MODEL_HUB_TESTS_INSTALL_DIR}:$PYTHONPATH
python3 -m pytest ${MODEL_HUB_TESTS_INSTALL_DIR}/pytorch/test_stateful_to_stateless_transformation.py -m ${TYPE} --html=${INSTALL_TEST_DIR}/TEST-torch_stateful_to_stateless_tests.html --self-contained-html -v --tb=short
env:
TYPE: ${{ inputs.event == 'schedule' && 'nightly' || 'precommit'}}
TEST_DEVICE: CPU
USE_SYSTEM_CACHE: False

- name: Reformat unsupported ops file
if: '!cancelled()'
run: |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
from openvino._pyopenvino._offline_transformations import compress_quantize_weights_transformation
from openvino._pyopenvino._offline_transformations import convert_sequence_to_tensor_iterator_transformation
from openvino._pyopenvino._offline_transformations import paged_attention_transformation
from openvino._pyopenvino._offline_transformations import stateful_to_stateless_transformation
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <openvino/pass/make_stateful.hpp>
#include <openvino/pass/sdpa_to_paged_attention.hpp>
#include <openvino/pass/serialize.hpp>
#include <openvino/pass/stateful_to_stateless.hpp>
#include <pruning.hpp>
#include <transformations/common_optimizations/compress_float_constants.hpp>
#include <transformations/common_optimizations/fused_names_cleanup.hpp>
Expand Down Expand Up @@ -137,4 +138,13 @@ void regmodule_offline_transformations(py::module m) {
manager.run_passes(model);
},
py::arg("model"));

// Python binding `stateful_to_stateless_transformation(model)`: registers ov::pass::StatefulToStateless in a
// fresh pass manager and runs it on the given model. The lambda returns nothing, so the result of the
// transformation is observed through the `model` object that was passed in.
m_offline_transformations.def(
"stateful_to_stateless_transformation",
[](std::shared_ptr<ov::Model> model) {
ov::pass::Manager manager;
manager.register_pass<ov::pass::StatefulToStateless>();
manager.run_passes(model);
},
py::arg("model"));
}
22 changes: 22 additions & 0 deletions src/core/include/openvino/pass/stateful_to_stateless.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/pass.hpp"

namespace ov {
namespace pass {
/**
* @brief The transformation converts KV cache state back to stateless form.
*
* It is meant to undo the effect of `make_stateful` from optimum-intel: state variables that hold the
* KV cache become regular model inputs and outputs again, and the `beam_idx` input is removed.
* \ingroup ov_pass_cpp_api
*/
class OPENVINO_API StatefulToStateless : public ModelPass {
public:
OPENVINO_RTTI("StatefulToStateless");

/**
* @brief Applies the transformation to the given model; the model object itself is modified.
* @param model Stateful model to convert. It is expected to have a parameter whose output tensor is
*              named `beam_idx`; the implementation throws when there is no such parameter.
* @return true when the transformation has been applied.
*/
bool run_on_model(const std::shared_ptr<ov::Model>& model) override;
};
} // namespace pass
} // namespace ov
172 changes: 172 additions & 0 deletions src/core/src/pass/stateful_to_stateless.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/pass/stateful_to_stateless.hpp"

#include <algorithm>
#include <limits>
#include <memory>
#include <regex>
#include <string>
#include <unordered_map>
#include <vector>

#include "openvino/cc/pass/itt.hpp"
#include "openvino/op/assign.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/read_value.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/utils/utils.hpp"

using namespace ov::op;

namespace {

// Gives `node` the friendly name `name` and makes `name` the only name of its output tensor (any names the
// tensor had before are overridden). The node must be non-null and have exactly one output.
// Returns the same node so that calls can be chained.
std::shared_ptr<ov::Node> set_name(std::shared_ptr<ov::Node> node, const std::string& name) {
    // Check the preconditions first, so that a failed check leaves the node untouched instead of half-renamed.
    OPENVINO_ASSERT(node != nullptr, "set_name expects a non-null node");
    OPENVINO_ASSERT(node->get_output_size() == 1,
                    "set_name expects a node with exactly one output, got ",
                    node->get_output_size());
    node->set_friendly_name(name);
    node->get_output_tensor(0).set_names({name});
    return node;
}

// Templated method that has the same effect as not templated `set_name` but saves Op type for convenient calls chaining
template <typename T>
inline std::shared_ptr<T> set_name(std::shared_ptr<T> node, const std::string& name) {
set_name(std::dynamic_pointer_cast<ov::Node>(node), name);
return node;
}

// Finds the model Parameter whose output tensor carries the given name.
// It differs from ov::Model::input(name) only in the return type and in returning nullptr when nothing matches.
std::shared_ptr<v0::Parameter> get_parameter_by_tensor_name(const std::shared_ptr<ov::Model>& model,
                                                            const std::string& name) {
    std::shared_ptr<v0::Parameter> found;  // stays null when no parameter carries the name
    for (const auto& candidate : model->get_parameters()) {
        const auto& tensor_names = candidate->get_output_tensor(0).get_names();
        if (tensor_names.find(name) != tensor_names.end()) {
            found = candidate;
            break;
        }
    }
    return found;
}

// Describes one state variable of the stateful model together with the names of the input and output that
// replace it in the stateless model, and an ordering index decoded from the variable name when possible.
struct Variable {
    struct Context {
        // to hold compiled once regex for all Variable instances
        // Groups: 1 - whole input name, 2 - input layer index, 3 - input kind (key|value),
        //         4 - whole output name, 5 - output layer index, 6 - output kind (key|value).
        const std::regex naming_convention =
            std::regex(R"((past_key_values\.(\d+)\.(key|value))(present\.(\d+)\.(key|value)))");
    };

    // Decodes `variable_name` with the naming convention held by `context`.
    // When the name matches, the input/output names are taken from the two halves of the variable name.
    // Otherwise synthetic "input_restored."/"output_restored." names are produced and the index stays unknown.
    Variable(const Context& context, const std::string& variable_name) : variable_name(variable_name) {
        // Try to decode original naming of the corresponding input and output in the stateless model
        std::smatch match;
        if (std::regex_match(variable_name, match, context.naming_convention)) {
            input_name = match[1].str();
            output_name = match[4].str();
            const auto input_index = match[2].str();
            const auto output_index = match[5].str();
            const auto input_kind = match[3].str();
            const auto output_kind = match[6].str();
            // Both halves of the name have to describe the same layer and the same kind; otherwise the position
            // is ambiguous and two different variables could end up with the same index.
            const bool halves_agree = input_index == output_index && input_kind == output_kind;
            // At most digits10 decimal digits always fit into int, including the multiplication by 2 below.
            const bool fits_into_int =
                input_index.length() <= static_cast<size_t>(std::numeric_limits<int>::digits10);
            if (halves_agree && fits_into_int) {
                index = std::stoi(input_index) * 2 + static_cast<int>(input_kind == "value");  // order key before value
            }
        } else {
            // Variable name doesn't follow the expected naming convention. It doesn't prevent forming
            // a correct stateless model but doesn't give a way to restore all names and inputs/outputs ordering
            // accurately.
            input_name = "input_restored." + variable_name;
            output_name = "output_restored." + variable_name;
        }
    }

    int index = -1;             // layer index, -1 means the index isn't known
    std::string variable_name;  // original variable_id
    std::string input_name;     // restored name of input
    std::string output_name;    // restored name of output
};

using Variables = std::vector<Variable>;

// Reorders `variables` in place, trying to reproduce the optimum-intel input/output order.
// Variables with a known index (index >= 0) come first, ordered by that index. Variables with an unknown
// index follow, ordered by the position recorded for their variable id in `var_index_by_var_id`.
void restore_kv_cache_order(Variables& variables, const std::unordered_map<std::string, size_t>& var_index_by_var_id) {
    const auto comes_before = [&](const Variable& lhs, const Variable& rhs) {
        const bool lhs_known = lhs.index >= 0;
        const bool rhs_known = rhs.index >= 0;
        if (lhs_known != rhs_known) {
            // exactly one index is known: the item that satisfies the naming convention goes first
            return lhs_known;
        }
        if (lhs_known) {
            return lhs.index < rhs.index;
        }
        // neither index is known: fall back to the order given by var_index_by_var_id
        return var_index_by_var_id.at(lhs.variable_name) < var_index_by_var_id.at(rhs.variable_name);
    };

    std::stable_sort(variables.begin(), variables.end(), comes_before);
}

} // namespace

// Converts the stateful `model` to stateless form:
//  - every Gather that takes `beam_idx` as input and a ReadValue as its first input is replaced by a new Parameter;
//  - the Assign sink of the same variable is replaced by a new Result fed by the Assign's input;
//  - the variable and the `beam_idx` parameter are removed from the model.
// New parameters/results are appended to the model in the order produced by restore_kv_cache_order.
// Throws when the model has no `beam_idx` parameter, when a Gather by `beam_idx` is not fed by a ReadValue,
// or when a collected variable has no Assign sink. Returns true once the model has been rewritten.
bool ov::pass::StatefulToStateless::run_on_model(const std::shared_ptr<ov::Model>& model) {
    RUN_ON_MODEL_SCOPE(StatefulToStateless);

    auto beam_idx = get_parameter_by_tensor_name(model, "beam_idx");
    if (!beam_idx) {
        OPENVINO_THROW(
            "Stateful models without `beam_idx` input are not supported in StatefulToStateless transformation");
    }

    Variables variables;  // to collect variables corresponding to future_params
    variables.reserve(model->get_sinks().size());
    Variable::Context context;
    std::unordered_map<std::string, std::shared_ptr<ov::Node>>
        future_params;  // to collect nodes, each with a single output that will be replaced by new parameters
    for (const ov::Input<ov::Node>& input : beam_idx->get_output_target_inputs(0)) {
        if (auto gather = std::dynamic_pointer_cast<op::util::GatherBase>(input.get_node()->shared_from_this())) {
            auto read_value =
                std::dynamic_pointer_cast<op::util::ReadValueBase>(gather->get_input_node_shared_ptr(0));
            OPENVINO_ASSERT(read_value,
                            "Unexpected model topology in StatefulToStateless: no ReadValue is found at the first "
                            "input of Gather by `beam_idx` parameter");
            auto variable_name = read_value->get_variable_id();
            variables.push_back(Variable(context, variable_name));
            future_params[variable_name] = gather;
        }
    }

    using PAssign = std::shared_ptr<op::util::AssignBase>;
    std::unordered_map<std::string, PAssign> assigns_by_var_id;
    std::unordered_map<std::string, size_t> assign_index_by_var_id;
    const auto& sinks = model->get_sinks();
    for (size_t i = 0; i < sinks.size(); ++i) {
        if (auto assign = std::dynamic_pointer_cast<op::util::AssignBase>(sinks[i])) {
            const auto& var_id = assign->get_variable_id();
            assigns_by_var_id[var_id] = assign;
            assign_index_by_var_id[var_id] = i;
        }
    }

    // Every collected variable needs an Assign sink: it provides the value for the new Result and the fallback
    // position used for ordering. Check this before the model is modified, so a malformed model is reported
    // with a clear error instead of a null pointer dereference or an opaque std::out_of_range.
    for (const auto& variable : variables) {
        OPENVINO_ASSERT(assigns_by_var_id.count(variable.variable_name) != 0,
                        "Unexpected model topology in StatefulToStateless: no Assign is found for variable `",
                        variable.variable_name,
                        "`");
    }

    model->remove_parameter(beam_idx);

    restore_kv_cache_order(variables, assign_index_by_var_id);

    ov::ParameterVector new_parameters;
    ov::ResultVector new_results;
    new_parameters.reserve(variables.size());
    new_results.reserve(variables.size());

    for (const auto& variable : variables) {
        auto future_param = future_params.at(variable.variable_name);
        auto parameter = ::set_name(std::make_shared<v0::Parameter>(future_param->get_output_element_type(0),
                                                                    future_param->get_output_partial_shape(0)),
                                    variable.input_name);

        replace_node(future_param, parameter);

        auto assign = assigns_by_var_id.at(variable.variable_name);
        auto result = ::set_name(std::make_shared<v0::Result>(assign->input_value(0)), variable.output_name);

        model->remove_sink(assign);  // Don't do replace_node(assign, result)! It will lead to silently incorrect model.
        model->remove_variable(model->get_variable_by_id(variable.variable_name));
        new_parameters.push_back(parameter);
        new_results.push_back(result);
    }

    model->add_parameters(new_parameters);
    model->add_results(new_results);

    return true;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
hf-internal-testing/tiny-random-LlamaForCausalLM,https://huggingface.co/hf-internal-testing/tiny-random-LlamaForCausalLM
hf-internal-testing/tiny-random-StableLmForCausalLM,https://huggingface.co/hf-internal-testing/tiny-random-StableLmForCausalLM
hf-internal-testing/tiny-random-PhiForCausalLM,https://huggingface.co/hf-internal-testing/tiny-random-PhiForCausalLM
hf-internal-testing/tiny-random-CodeGenForCausalLM,https://huggingface.co/hf-internal-testing/tiny-random-CodeGenForCausalLM
hf-internal-testing/tiny-random-Starcoder2ForCausalLM,https://huggingface.co/hf-internal-testing/tiny-random-Starcoder2ForCausalLM
hf-internal-testing/tiny-random-OPTForCausalLM,https://huggingface.co/hf-internal-testing/tiny-random-OPTForCausalLM
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import openvino as ov
from openvino._offline_transformations import stateful_to_stateless_transformation
from optimum.intel import OVModelForCausalLM
import models_hub_common.utils as utils
import pytest
import os

def get_read_value_ops(model: ov.Model):
    """Collect every operation of `model` whose type name is 'ReadValue'."""
    read_value_ops = []
    for candidate in model.get_ops():
        if candidate.get_type_name() == 'ReadValue':
            read_value_ops.append(candidate)
    return read_value_ops

def check_desc_tensors(tensors1, tensors2):
    """Assert that two collections of tensor descriptors describe the same tensors.

    Each item is expected to expose `names` (a set), `get_partial_shape()` and `get_element_type()`.
    The order of the tensors may differ between the collections; tensors are paired by their names and
    every pair must agree in names, partial shape and element type. Raises AssertionError on any mismatch.
    """
    # order of tensors may not match, comparing by the total amount and names
    assert len(tensors1) == len(tensors2)
    # frozenset rather than tuple: the iteration order of a set is not guaranteed to be the same for two
    # equal sets, so tuples built from equal name sets could compare unequal
    assert {frozenset(t.names) for t in tensors1} == {frozenset(t.names) for t in tensors2}
    for t1 in tensors1:
        t2_candidates = [t for t in tensors2 if t1.names & t.names]
        assert len(t2_candidates) == 1
        t2 = t2_candidates[0]
        assert t1.names == t2.names
        assert t1.get_partial_shape() == t2.get_partial_shape()
        assert t1.get_element_type() == t2.get_element_type()

def run_stateful_to_stateless_in_runtime(tmp_path, model_id, model_link):
    """Export `model_id` as a stateful model, make it stateless and compare it with a stateless export.

    The stateful export must contain sinks and ReadValue operations; after the transformation it must
    contain neither. Its inputs and outputs are then compared with those of a model exported directly in
    stateless form, and the transformed model is compiled for CPU.
    `tmp_path` and `model_link` are accepted but not used here.
    """
    stateful = OVModelForCausalLM.from_pretrained(model_id, export=True, stateful=True, compile=False)
    converted = stateful.model
    assert len(converted.get_sinks()), f"Input model is not in the expected stateful form because it doesn't have any sinks."
    assert len(get_read_value_ops(converted)), f"Input model is not in the expected stateful form because it doesn't have any ReadValue operations."

    stateful_to_stateless_transformation(converted)

    sink_ops = converted.get_sinks()
    read_value_ops = get_read_value_ops(converted)
    assert len(sink_ops) == 0, f"Expected stateless model, but there are sinks found: {sink_ops}"
    assert len(read_value_ops) == 0, f"Expected stateless model, but there are ReadValue operations found: {read_value_ops}"

    # reference: the same model exported directly in stateless form
    reference = OVModelForCausalLM.from_pretrained(model_id, export=True, stateful=False, compile=False).model

    print(converted)
    print(reference)
    check_desc_tensors(converted.inputs, reference.inputs)
    check_desc_tensors(converted.outputs, reference.outputs)

    ov.Core().compile_model(converted, 'CPU')


@pytest.mark.precommit
@pytest.mark.parametrize("model_name, model_link, mark, reason", utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "tiny-set-stateful-models-precommit")))
def test_stateful_to_stateless_precommit(tmp_path, model_name, model_link, mark, reason, ie_device):
    """Precommit check of the stateful-to-stateless transformation for one entry of the models list.

    `mark` may be None, 'skip' or 'xfail'; for the latter two the test is skipped or xfailed with `reason`.
    """
    assert mark in (None, 'skip', 'xfail'), \
        "Incorrect test case: {}, {}".format(model_name, model_link)
    if mark == 'skip':
        pytest.skip(reason)
    if mark == 'xfail':
        pytest.xfail(reason)
    run_stateful_to_stateless_in_runtime(tmp_path, model_name, model_link)

0 comments on commit 2a9af43

Please sign in to comment.