
Commit

[LLM Runtime] Unify KV_cache and Support Batch-dim Process in Beam Search (#583)
zhentaoyu authored Nov 10, 2023
1 parent abce937 commit 2246567
Showing 23 changed files with 1,057 additions and 676 deletions.
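
The change that threads through every file below is a new per-request input descriptor: model_eval and beam_search stop taking a bare token pointer plus n_past/n_total and instead accept model_input descriptors (a pointer plus count for model_eval, a vector for beam_search), one per batch row. The struct is defined elsewhere in the tree; the sketch below is a plausible reconstruction inferred only from the /*.field =*/ initializer comments in this diff, so the exact types and ordering are assumptions, not the authoritative definition.

#include <cstdint>

typedef int32_t model_token;  // token id type used by the runtime (width assumed)

// Reconstructed from the initializer comments in this commit; layout is assumed.
struct model_input {
  const model_token* tokens;  // token ids for this request
  uint32_t n_tokens;          // number of tokens to evaluate in this call
  uint32_t n_prompt_tokens;   // prompt length (passed as 0 throughout this diff)
  uint32_t n_past;            // tokens already present in the KV cache
  uint32_t n_total;           // total tokens processed so far
  int request_idx;            // batch/request slot, 0..batch_size-1
  int beam_idx;               // beam slot inside the request
  int padding_side;           // 0 = left padding everywhere in this diff
  uint32_t n_padding;         // pad tokens to skip for this row
};

Packing one descriptor per row is what lets a single model_eval or beam_search call carry a whole batch of sequences, in line with the commit's goal of unifying KV-cache handling across greedy and beam-search decoding.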
13 changes: 9 additions & 4 deletions intel_extension_for_transformers/llm/runtime/graph/__init__.py
@@ -97,7 +97,8 @@ def quant_model(self, model_name, model_path, out_path, **kwargs):

def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, **kwargs):
if self.model is None:
self.init_from_bin(self.model_type, self.bin_file, **kwargs)
self.init_from_bin(self.model_type, self.bin_file, batch_size=input_ids.shape[0],
**kwargs)
self.generate_round = 0
elif not interactive:
self.model.reinit()
@@ -107,12 +108,13 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa
if self.generate_round == 0 and not ignore_prompt:
ret = input_ids.tolist()

# TODO support multi batch
assert input_ids.shape[0] == 1, "Unsupport multi-batch input ids."
beam_search = False
if ("num_beams" in kwargs and kwargs["num_beams"] > 1) and not \
kwargs.get("do_sample", False):
beam_search = True
if not beam_search:
# TODO support multi batch
assert input_ids.shape[0] == 1, "Unsupport multi-batch input ids."
if streamer:
if beam_search:
print("ERROR, can not use streamer when use beam search for generation!")
@@ -130,7 +132,10 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa
ret[0].extend(out)
streamer.end()
else:
ret[0].extend(self.model.generate_tokens(input_ids = input_ids.tolist()[0]))
response = self.model.generate_tokens(input_ids = input_ids.tolist())
assert (len(ret) == len(response))
for i in range(len(response)):
ret[i].extend(response[i])

self.generate_round += 1
return ret
@@ -55,13 +55,13 @@ class Model {
~Model() {
if (ctx) model_free(ctx);
}
void init_model(const std::string& model_path, int n_predict, int batch_size, int ctx_size, int seed, int threads,
void init_model(const std::string& model_path, int n_predict, int n_batch, int ctx_size, int seed, int threads,
float repetition_penalty, int num_beams, bool do_sample, int top_k, float top_p, float temperature,
int min_new_tokens, float length_penalty, bool early_stopping, int n_keep, int n_discard,
bool shift_roped_k);
bool shift_roped_k, int batch_size, model_vocab::id pad_token);
void reinit();
std::vector<model_token> generate(const std::vector<model_token>& input_ids);
std::vector<model_token> generate_tokens(const std::vector<model_token>& input_ids);
std::vector<std::vector<model_token>> generate_tokens(const std::vector<std::vector<model_token>>& input_ids);
bool is_token_end() { return token_eos; }
static int quant_model(const std::string& model_path, const std::string& out_path, const std::string& weight_dtype,
const std::string& alg, int group_size, const std::string& scale_dtype,
@@ -86,28 +86,29 @@ class Model {

model_token post_process(float* logits);
model_token post_greedy_search(float* logits);
std::vector<model_token> post_beam_search(model_context* lctx, const int& n_predict, const model_token* tokens_inp,
const int& n_tokens, const int& n_threads);
std::vector<std::vector<model_token>> post_beam_search(model_context* lctx, const int& n_predict,
const std::vector<model_input>& inputs, const int& n_threads);
model_token post_sample_top_k_top_p_repeat(float* logits);
};

void Model::init_model(const std::string& model_path, int max_new_tokens, int batch_size, int ctx_size, int seed,
void Model::init_model(const std::string& model_path, int max_new_tokens, int n_batch, int ctx_size, int seed,
int threads, float repetition_penalty, int num_beams, bool do_sample, int top_k, float top_p,
float temperature, int min_new_tokens, float length_penalty, bool early_stopping, int n_keep,
int n_discard, bool shift_roped_k) {
int n_discard, bool shift_roped_k, int batch_size, model_vocab::id pad_token) {
#ifdef MODEL_NAME
params.model_name = MODEL_NAME;
#endif
params.model_arch = model_name_to_arch::init().find(params.model_name);
params.model = model_path;
params.n_predict = max_new_tokens;
params.n_batch = batch_size;
params.n_batch = n_batch;
params.n_ctx = ctx_size;
params.seed = seed;
params.n_threads = threads;
params.repeat_penalty = repetition_penalty;
params.beam_size = num_beams;
params.do_sample = do_sample;
params.batch_size = batch_size;
params.beam_search = (num_beams > 1 && !do_sample) ? true : false;
if (params.beam_search) {
params.memory_type = KV_MEM_TYPE_F16; // TODO NO MHA IN BEAM SEARCH
@@ -133,6 +134,7 @@ void Model::init_model(const std::string& model_path, int max_new_tokens, int ba
ctx->generation_conf.min_new_tokens = min_new_tokens;
ctx->generation_conf.length_penalty = length_penalty;
ctx->generation_conf.do_early_stopping = early_stopping;
if (pad_token != -1) ctx->vocab.pad_token_id = pad_token;
}

void Model::reinit() {
@@ -177,7 +179,18 @@ std::vector<model_token> Model::generate(const std::vector<model_token>& input_i
NE_ASSERT(("n_discard cannot be used with shift_roped_k!", n_discard == -1 || n_discard == 1));
}
}
model_eval(ctx, &curr_input_ids[0], curr_input_ids.size(), n_past, n_total, params.n_threads);
std::vector<model_input> inputs = {model_input{
/*.tokens =*/curr_input_ids.data(),
/*.n_tokens =*/(uint32_t)curr_input_ids.size(),
/*.n_prompt_tokens =*/0,
/*.n_past =*/(uint32_t)n_past,
/*.n_total =*/(uint32_t)n_total,
/*.request_idx =*/0,
/*.beam_idx =*/0,
/*.padding_side =*/0,
/*n_padding =*/0,
}};
model_eval(ctx, inputs.data(), inputs.size(), params.n_threads);
n_past += curr_input_ids.size();
n_total += curr_input_ids.size();

@@ -196,18 +209,52 @@ std::vector<model_token> Model::generate_tokens(const std::vector<model_token>&
return {next_token_id};
}

std::vector<model_token> Model::generate_tokens(const std::vector<model_token>& input_ids) {
std::vector<std::vector<model_token>> Model::generate_tokens(const std::vector<std::vector<model_token>>& input_ids) {
int n_remain = params.n_predict;
std::vector<model_token> output_ids;
std::vector<std::vector<model_token>> rets;

if (ctx->beam_search) {
MODEL_ASSERT(input_ids.size() == ctx->batch_size);
if (ctx->batch_size > 1 && ctx->vocab.pad_token_id == -1) {
fprintf(stderr, "\nERROR: please set pad_token for beam search multi-batch generation!\n");
return rets;
}
std::vector<model_input> inputs;
for (int bs = 0; bs < input_ids.size(); ++bs) {
uint32_t count = 0;
model_vocab::id pad_token_id = ctx->vocab.pad_token_id;
auto iter = std::find_if(input_ids[bs].begin(), input_ids[bs].end(),
[&pad_token_id](model_token t) { return (t != pad_token_id); });
if (iter == input_ids[bs].end()) fprintf(stderr, "\nERROR: there are all pad tokens in batch %d!\n", bs);
count = std::distance(input_ids[bs].begin(), iter);
inputs.push_back(model_input{
/*.tokens =*/input_ids[bs].data(),
/*.n_tokens =*/(uint32_t)input_ids[bs].size(),
/*.n_prompt_tokens =*/0,
/*.n_past =*/0,
/*.n_total =*/0,
/*.request_idx =*/bs,
/*.beam_idx =*/0,
/*.padding_side =*/0,
/*n_padding =*/count,
});
}
return post_beam_search(ctx, n_remain, inputs, params.n_threads);
}
if (input_ids.size() > 1) {
fprintf(stderr, "\nERROR: Only beam search supports multi-batch generation!\n");
return rets;
}

if (curr_input_ids.empty()) {
if (input_ids.size() > n_ctx - 4) {
if (input_ids[0].size() > n_ctx - 4) {
fprintf(stderr, "\n%s: Warning: prompt is too long (%d tokens, max %d), will be truncated\n", __func__,
input_ids.size(), n_ctx - 4);
input_ids[0].size(), n_ctx - 4);
curr_input_ids.resize(n_ctx - 4);
std::copy(input_ids.end() - n_ctx - 4, input_ids.end(), curr_input_ids.begin());
std::copy(input_ids[0].end() - n_ctx - 4, input_ids[0].end(), curr_input_ids.begin());
} else {
curr_input_ids = input_ids;
curr_input_ids = input_ids[0];
}
}

@@ -231,11 +278,18 @@ std::vector<model_token> Model::generate_tokens(const std::vector<model_token>&
NE_ASSERT(("n_discard cannot be used with shift_roped_k!", n_discard == -1 || n_discard == 1));
}
}
if (ctx->beam_search) {
output_ids = post_beam_search(ctx, n_remain, curr_input_ids.data(), curr_input_ids.size(), params.n_threads);
break;
}
model_eval(ctx, &curr_input_ids[0], curr_input_ids.size(), n_past, n_total, params.n_threads);
std::vector<model_input> inputs = {model_input{
/*.tokens =*/curr_input_ids.data(),
/*.n_tokens =*/(uint32_t)curr_input_ids.size(),
/*.n_prompt_tokens =*/0,
/*.n_past =*/(uint32_t)n_past,
/*.n_total =*/(uint32_t)n_total,
/*.request_idx =*/0,
/*.beam_idx =*/0,
/*.padding_side =*/0,
/*n_padding =*/0,
}};
model_eval(ctx, inputs.data(), inputs.size(), params.n_threads);
n_past += curr_input_ids.size();
n_total += curr_input_ids.size();

@@ -253,25 +307,25 @@ std::vector<model_token> Model::generate_tokens(const std::vector<model_token>&
break;
}
}

return output_ids;
rets.push_back(output_ids);
return rets;
}
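
For illustration, a minimal sketch of driving the new batched generate_tokens from C++. The token ids and pad id below are made up; it assumes a Model initialized with num_beams > 1, batch_size = 2 and a valid pad_token, since the checks above reject multi-batch input without beam search or without a pad token. Shorter prompts are left-padded so that the per-row n_padding can be derived with the find_if scan shown above.

// Hypothetical two-row batch; 50256 stands in for the tokenizer's pad id.
std::vector<std::vector<model_token>> input_ids = {
    {50256, 50256, 318, 428, 257},   // shorter prompt, left-padded with the pad id
    {1212, 318, 257, 1332, 6827},    // full-length prompt
};

// One generated sequence per batch row. An empty result signals a usage error
// (multi-batch without beam search, or no pad_token configured).
std::vector<std::vector<model_token>> outputs = model.generate_tokens(input_ids);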

model_token Model::post_greedy_search(float* logits) {
model_token id = std::max_element(logits, logits + n_vocab) - logits;
return id;
}

std::vector<model_token> Model::post_beam_search(model_context* lctx, const int& n_predict,
const model_token* tokens_inp, const int& n_tokens,
const int& n_threads) {
std::vector<std::vector<model_token>> Model::post_beam_search(model_context* lctx, const int& n_predict,
const std::vector<model_input>& inputs,
const int& n_threads) {
// TODO: to implement
static std::set<model_archs> supported_archs = {MODEL_GPTJ, MODEL_GPTNEOX};
if (supported_archs.count(params.model_arch) != 0) {
return beam_search(lctx, n_predict, tokens_inp, n_tokens, n_threads);
return beam_search(lctx, n_predict, inputs, n_threads);
} else {
fprintf(stderr, "\nERROR: this model does not support beam search generation!\n");
return std::vector<model_token>();
return std::vector<std::vector<model_token>>();
}
}

@@ -416,11 +470,12 @@ PYBIND11_MODULE(mistral_cpp, m)
py::class_<Model>(m, "Model", py::module_local())
.def(py::init())
.def("init_model", &Model::init_model, "initial model with model path and parameters", py::arg("model_path"),
py::arg("max_new_tokens") = -1, py::arg("batch_size") = 512, py::arg("ctx_size") = 512, py::arg("seed") = -1,
py::arg("max_new_tokens") = -1, py::arg("n_batch") = 512, py::arg("ctx_size") = 512, py::arg("seed") = -1,
py::arg("threads") = 8, py::arg("repetition_penalty") = 1.1f, py::arg("num_beams") = 1,
py::arg("do_sample") = false, py::arg("top_k") = 40, py::arg("top_p") = 0.95, py::arg("temperature") = 0.8,
py::arg("min_new_tokens") = 0, py::arg("length_penalty") = 1.0, py::arg("early_stopping") = false,
py::arg("n_keep") = 0, py::arg("n_discard") = -1, py::arg("shift_roped_k") = false)
py::arg("n_keep") = 0, py::arg("n_discard") = -1, py::arg("shift_roped_k") = false,
py::arg("batch_size") = 1, py::arg("pad_token") = -1)
.def("generate", &Model::generate, "Generate token with input ids", py::arg("input_ids"))
.def("generate_tokens", &Model::generate_tokens, "Generate tokens with input ids", py::arg("input_ids"))
.def_static("quant_model", &Model::quant_model, "Quantize model", py::arg("model_path"), py::arg("out_path"),
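With the renamed keyword, callers of the binding now distinguish the evaluation chunk size (n_batch, previously exposed as batch_size) from the number of concurrent sequences (batch_size), and can hand the vocabulary's pad id to the runtime. A hypothetical C++-side call with placeholder values, matching the parameter order of init_model above:

// Placeholder model path and pad id; every value here is illustrative only.
Model model;
model.init_model("ne_gptj_q4.bin",
                 /*max_new_tokens=*/128, /*n_batch=*/512, /*ctx_size=*/1024, /*seed=*/-1,
                 /*threads=*/8, /*repetition_penalty=*/1.1f, /*num_beams=*/4,
                 /*do_sample=*/false, /*top_k=*/40, /*top_p=*/0.95f, /*temperature=*/0.8f,
                 /*min_new_tokens=*/0, /*length_penalty=*/1.0f, /*early_stopping=*/false,
                 /*n_keep=*/0, /*n_discard=*/-1, /*shift_roped_k=*/false,
                 /*batch_size=*/2, /*pad_token=*/50256);  // pad_token defaults to -1 (unset)
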
@@ -149,14 +149,36 @@ int main(int argc, char** argv) {
if (params.mem_test) {
{
const std::vector<model_token> tmp(params.n_batch, ctx->vocab.bos_token_id);
model_eval(ctx, tmp.data(), tmp.size(), 0, 0, params.n_threads);
std::vector<model_input> inputs = {model_input{
/*.tokens =*/tmp.data(),
/*.n_tokens =*/(uint32_t)tmp.size(),
/*.n_prompt_tokens =*/0,
/*.n_past =*/0,
/*.n_total =*/0,
/*.request_idx =*/0,
/*.beam_idx =*/0,
/*.padding_side =*/0,
/*n_padding =*/0,
}};
model_eval(ctx, inputs.data(), inputs.size(), params.n_threads);
}

{
const std::vector<model_token> tmp = {
0,
};
model_eval(ctx, tmp.data(), tmp.size(), params.n_predict - 1, params.n_predict - 1, params.n_threads);
std::vector<model_input> inputs = {model_input{
/*.tokens =*/tmp.data(),
/*.n_tokens =*/(uint32_t)tmp.size(),
/*.n_prompt_tokens =*/0,
/*.n_past =*/(uint32_t)(params.n_predict - 1),
/*.n_total =*/(uint32_t)(params.n_predict - 1),
/*.request_idx =*/0,
/*.beam_idx =*/0,
/*.padding_side =*/0,
/*n_padding =*/0,
}};
model_eval(ctx, inputs.data(), inputs.size(), params.n_threads);
}

model_print_timings(ctx);
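
The nine-field model_input initializer now repeats at every single-sequence call site touched by this commit (main.cpp, the Python binding, the GPT-J wrappers). A small helper such as the hypothetical one below, which is not part of the change, could centralize that boilerplate:

// Hypothetical convenience helper; not part of this commit.
static model_input make_single_input(const model_token* tokens, uint32_t n_tokens,
                                     uint32_t n_past, uint32_t n_total) {
  return model_input{
      /*.tokens =*/tokens,
      /*.n_tokens =*/n_tokens,
      /*.n_prompt_tokens =*/0,
      /*.n_past =*/n_past,
      /*.n_total =*/n_total,
      /*.request_idx =*/0,
      /*.beam_idx =*/0,
      /*.padding_side =*/0,
      /*n_padding =*/0,
  };
}

// Usage at a call site:
//   model_input in = make_single_input(tmp.data(), tmp.size(), 0, 0);
//   model_eval(ctx, &in, 1, params.n_threads);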
@@ -436,7 +458,18 @@ int main(int argc, char** argv) {
if (n_eval > params.n_batch) {
n_eval = params.n_batch;
}
if (model_eval(ctx, &embd[i], n_eval, n_past, n_total, params.n_threads)) {
std::vector<model_input> inputs = {model_input{
/*.tokens =*/&embd[i],
/*.n_tokens =*/(uint32_t)n_eval,
/*.n_prompt_tokens =*/0,
/*.n_past =*/(uint32_t)n_past,
/*.n_total =*/(uint32_t)n_total,
/*.request_idx =*/0,
/*.beam_idx =*/0,
/*.padding_side =*/0,
/*n_padding =*/0,
}};
if (model_eval(ctx, inputs.data(), inputs.size(), params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
@@ -40,7 +40,16 @@ bool gptj_model_eval_ids(model_context* ctx, model_token* tokens, size_t n_eval,
return 1;
}

if (model_eval(ctx, tokens, n_eval, n_past, n_past, n_threads)) {
std::vector<model_input> inputs = {model_input{
/*.tokens =*/tokens,
/*.n_tokens =*/static_cast<uint32_t>(n_eval),
/*.n_prompt_tokens =*/0,
/*.n_past =*/static_cast<uint32_t>(n_past),
/*.n_total =*/static_cast<uint32_t>(n_past),
/*.request_idx =*/0,
/*.beam_idx =*/0,
}};
if (model_eval(ctx, inputs.data(), inputs.size(), n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
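
Unlike the other call sites in this commit, the initializer above stops after beam_idx. With C++ aggregate initialization the omitted trailing members (padding_side and n_padding) are value-initialized to zero, so the behaviour matches the sites that spell them out. A minimal illustration of the language rule, using a made-up struct:

struct demo { int a; int b; int c; };
demo d = {1, 2};  // d.c is value-initialized to 0
// Likewise, padding_side and n_padding default to 0 in the model_input above.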
@@ -98,7 +107,16 @@ int32_t* eval_gptj_ids(void* ctx, int32_t* embd_inp_ptr, int ind_size, int n_pre
bool do_beam_search = lctx->beam_search;

if (do_beam_search) {
res = beam_search(lctx, n_predict, embd_inp_ptr, ind_size, n_threads);
std::vector<model_input> inputs = {model_input{
/*.tokens =*/embd_inp_ptr,
/*.n_tokens =*/static_cast<uint32_t>(ind_size),
/*.n_prompt_tokens =*/0,
/*.n_past =*/0,
/*.n_total =*/0,
/*.request_idx =*/0,
/*.beam_idx =*/0,
}};
res = beam_search(lctx, n_predict, inputs, n_threads)[0];
} else {
std::vector<model_token> embd_inp(embd_inp_ptr, embd_inp_ptr + ind_size);
std::vector<model_token> embd;
@@ -157,7 +175,18 @@ char* eval_gptj_char(void* ctx, const char* prom, int n_predict, int top_k, floa

bool do_beam_search = lctx->beam_search;
if (do_beam_search) {
embd = beam_search(lctx, n_predict, embd_inp.data(), embd_inp.size(), N_threads);
std::vector<model_input> inputs = {model_input{
/*.tokens =*/embd_inp.data(),
/*.n_tokens =*/static_cast<uint32_t>(embd_inp.size()),
/*.n_prompt_tokens =*/0,
/*.n_past =*/0,
/*.n_total =*/0,
/*.request_idx =*/0,
/*.beam_idx =*/0,
/*.padding_side =*/0,
/*n_padding =*/0,
}};
embd = beam_search(lctx, n_predict, inputs, N_threads)[0];
for (auto id : embd_inp) {
res += model_token_to_str(lctx, id);
}
@@ -229,7 +258,7 @@ int main(int argc, char* argv[]) {
for (auto gptj_in_all : ctxs) {
auto res = eval_gptj_char(
gptj_in_all,
// "she opened the door and see",
//"she opened the door and see",
// "Once upon a time",
// "Tell me 10 things about jazz music",
// "A spaceship lands on the moon",
@@ -40,7 +40,7 @@
#define NE_FILE_VERSION 1

#define NE_MAX_DIMS 4
#define NE_MAX_NODES 4096
#define NE_MAX_NODES 8192
#define NE_MAX_PARAMS 256
#define NE_MAX_CONTEXTS 64
#define NE_MAX_OPT 4
