remove code for no KV Cache path (pytorch#527)
mikekgfb authored and malfet committed Jul 17, 2024
1 parent af88c63 commit 092363f
Showing 2 changed files with 98 additions and 95 deletions.
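The diff below removes the #ifndef __KV_CACHE__ branch from forward() in runner/run.cpp, so the ExecuTorch (__ET_MODEL__) path now always builds a single-token tensor plus its position for each decoding step; most of the remaining hunks are formatting rewraps. As a minimal sketch of the pattern that survives the removal, assembled only from the ManagedTensor/EValue calls visible in the hunks below and not verified against any particular ExecuTorch release:

// Sketch: one decoding step on the KV-cache path (__ET_MODEL__).
// token_buffer and pos_buffer each hold a single int64_t value.
ManagedTensor tokens_managed(
    token_buffer, sizeof(int64_t), {1, 1}, ScalarType::Long);
ManagedTensor pos_managed(pos_buffer, sizeof(int64_t), {1}, ScalarType::Long);

std::vector<EValue> inputs;
inputs.push_back(EValue(tokens_managed.get_aliasing_tensor()));
inputs.push_back(EValue(pos_managed.get_aliasing_tensor()));
// The module is then run on (tokens, pos) and the logits for this single
// position are read back, as in the __ET_MODEL__ hunk below.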
191 changes: 98 additions & 93 deletions runner/run.cpp
@@ -1,18 +1,17 @@
/* Inference for Llama-2 Transformer model in pure C++ */
#include <cstdint>
#include <cstdlib>
#include <ctype.h>
#include <iterator>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <tokenizer.h>
#include <cstdint>
#include <cstdlib>
#include <iterator>
#include <string>


#ifdef DEBUG
#include <cassert>
#include <iostream>
@@ -167,22 +166,14 @@ float* forward(Transformer* transformer, int token, int pos) {
torch::Tensor pos_tensor = torch::from_blob(pos_buffer, {1}, torch::kLong);
std::vector<torch::Tensor> inputs{token_tensor, pos_tensor};

torch::Tensor result = transformer->runner->run(inputs)[0].to(torch::dtype(torch::kFloat32));
torch::Tensor result =
transformer->runner->run(inputs)[0].to(torch::dtype(torch::kFloat32));
auto logits = result[0].data_ptr();

#else // __ET_MODEL__
ManagedTensor pos_managed(pos_buffer, sizeof(int64_t), {1}, ScalarType::Long);
#ifndef __KV_CACHE__
// @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds
ManagedTensor tokens_managed(
&(s->toks[pos]),
/*ignored*/ sizeof(int64_t) * (pos + 1),
{1, 1},
ScalarType::Long);
#else // __KV_CACHE__
ManagedTensor tokens_managed(
token_buffer, sizeof(int64_t), {1, 1}, ScalarType::Long);
#endif
std::vector<EValue> inputs;
auto tmp1 = EValue(tokens_managed.get_aliasing_tensor());
auto tmp2 = EValue(pos_managed.get_aliasing_tensor());
@@ -491,9 +482,9 @@ void read_stdin(const char* guide, char* buffer, size_t bufsize) {
// is not safely implemented, it's more a proof of concept atm.

enum class ModelType {
unknown,
llama2,
llama3,
unknown,
llama2,
llama3,
};

ModelType get_model_type(Tokenizer* tokenizer) {
@@ -519,19 +510,27 @@ uint64_t get_eot_token(Tokenizer* tokenizer) {
return tokens[0];
}

fprintf(stderr, "No chat template implemnation for model type %d", model_type);
fprintf(
stderr, "No chat template implemnation for model type %d", model_type);
exit(EXIT_FAILURE);
}

std::vector<uint64_t> get_initial_prompt_tokens(const char* cli_system_prompt, const char* cli_user_prompt, Tokenizer* tokenizer) {
std::vector<uint64_t> get_initial_prompt_tokens(
const char* cli_system_prompt,
const char* cli_user_prompt,
Tokenizer* tokenizer) {
char system_prompt[512];
char user_prompt[512];
char rendered_prompt[512*2 + 200]; // the prompt template is ~170 characters. We use 200 to be safe.
char rendered_prompt[512 * 2 + 200]; // the prompt template is ~170
// characters. We use 200 to be safe.

if (cli_system_prompt != NULL) {
strcpy(system_prompt, cli_system_prompt);
} else {
read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
read_stdin(
"Enter system prompt (optional): ",
system_prompt,
sizeof(system_prompt));
}

if (cli_user_prompt != NULL) {
@@ -540,111 +539,114 @@ std::vector<uint64_t> get_initial_prompt_tokens(const char* cli_system_prompt, c
read_stdin("User: ", user_prompt, sizeof(user_prompt));
}

ModelType model_type = get_model_type(tokenizer);
std::vector<uint64_t> tokens;

switch (model_type) {
ModelType model_type = get_model_type(tokenizer);
std::vector<uint64_t> tokens;

switch (model_type) {
case ModelType::llama2:
if (system_prompt[0] != '\0') {
snprintf(
rendered_prompt,
sizeof(rendered_prompt)-1,
"[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]",
system_prompt,
user_prompt
);
rendered_prompt,
sizeof(rendered_prompt) - 1,
"[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]",
system_prompt,
user_prompt);
} else {
// const char prompt_template[] = ;
snprintf(
rendered_prompt,
sizeof(rendered_prompt)-1,
"[INST] %s [/INST]",
user_prompt
);
rendered_prompt,
sizeof(rendered_prompt) - 1,
"[INST] %s [/INST]",
user_prompt);
}

// We need to add BOS token here and not in template because llama2 tokenizer
// does not pattern match special tokens
// We need to add BOS token here and not in template because llama2
// tokenizer does not pattern match special tokens
tokens = tokenizer->encode(rendered_prompt, 1, 0);
break;

case ModelType::llama3:
if (system_prompt[0] != '\0') {
snprintf(
rendered_prompt,
sizeof(rendered_prompt)-1,
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
system_prompt,
user_prompt
);
rendered_prompt,
sizeof(rendered_prompt) - 1,
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
system_prompt,
user_prompt);
} else {
snprintf(
rendered_prompt,
sizeof(rendered_prompt)-1,
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
user_prompt
);
rendered_prompt,
sizeof(rendered_prompt) - 1,
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
user_prompt);
}
tokens = tokenizer->encode(rendered_prompt, 0, 0);
break;

default:
fprintf(stderr, "No chat template implemnation for model type %d", model_type);
fprintf(
stderr,
"No chat template implemnation for model type %d",
model_type);
exit(EXIT_FAILURE);
}
}

#ifdef DEBUG
std::cerr << "Start of rendered prompt:" << std::endl;
std::cerr << rendered_prompt;
std::cerr << "End of rendered prompt:" << std::endl;
std::cerr << "Encoded prompt: ";
for (int i = 0; i < tokens.size(); i++) {
std::cerr << tokens[i] << ", ";
}
std::cerr << std::endl << std::flush;
#endif
#ifdef DEBUG
std::cerr << "Start of rendered prompt:" << std::endl;
std::cerr << rendered_prompt;
std::cerr << "End of rendered prompt:" << std::endl;
std::cerr << "Encoded prompt: ";
for (int i = 0; i < tokens.size(); i++) {
std::cerr << tokens[i] << ", ";
}
std::cerr << std::endl << std::flush;
#endif

return tokens;
return tokens;
}

std::vector<uint64_t> get_next_user_prompt_tokens(Tokenizer* tokenizer) {
char user_prompt[512];
char rendered_prompt[512 + 150]; // the prompt template is ~100 characters. We use 150 to be safe.
char rendered_prompt[512 + 150]; // the prompt template is ~100 characters. We
// use 150 to be safe.

read_stdin("User: ", user_prompt, sizeof(user_prompt));

ModelType model_type = get_model_type(tokenizer);
std::vector<uint64_t> tokens;

switch (model_type) {

case ModelType::llama2:
// const char prompt_template[] = ;
snprintf(rendered_prompt, sizeof(rendered_prompt)-1, "[INST] %s [/INST]", user_prompt);
snprintf(
rendered_prompt,
sizeof(rendered_prompt) - 1,
"[INST] %s [/INST]",
user_prompt);

// We need to add BOS token here and not in template because llama2 tokenizer
// does not pattern match special tokens
tokens = tokenizer->encode(rendered_prompt, /*bos*/1, /*eos*/0);
// We need to add BOS token here and not in template because llama2
// tokenizer does not pattern match special tokens
tokens = tokenizer->encode(rendered_prompt, /*bos*/ 1, /*eos*/ 0);
break;

case ModelType::llama3:
snprintf(
rendered_prompt,
sizeof(rendered_prompt)-1,
"<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
user_prompt
);
rendered_prompt,
sizeof(rendered_prompt) - 1,
"<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
user_prompt);
tokens = tokenizer->encode(rendered_prompt, 0, 0);
break;

default:
fprintf(stderr, "No chat template implemnation for model type %d", model_type);
fprintf(
stderr,
"No chat template implemnation for model type %d",
model_type);
exit(EXIT_FAILURE);
}


#ifdef DEBUG
#ifdef DEBUG
std::cerr << "Start of rendered prompt:" << std::endl;
std::cerr << rendered_prompt;
std::cerr << "End of rendered prompt:" << std::endl;
@@ -653,20 +655,18 @@ std::vector<uint64_t> get_next_user_prompt_tokens(Tokenizer* tokenizer) {
std::cerr << tokens[i] << ", ";
}
std::cerr << std::endl << std::flush;
#endif
#endif

return tokens;
}


void chat(
Transformer* transformer,
Tokenizer* tokenizer,
Sampler* sampler,
const char* cli_user_prompt,
const char* cli_system_prompt,
int steps) {

const uint64_t EOT_TOKEN = get_eot_token(tokenizer);
int num_prompt_tokens = 0;
std::vector<uint64_t> prompt_tokens;
@@ -679,12 +679,12 @@ void chat(
int prev_token;
int pos = 0; // position in the sequence
while (pos < steps) {

// when it is the user's turn to contribute tokens to the dialog...
if (user_turn) {
// get the (optional) system prompt at position 0
if (pos == 0) {
prompt_tokens = get_initial_prompt_tokens(cli_system_prompt, cli_user_prompt, tokenizer);
prompt_tokens = get_initial_prompt_tokens(
cli_system_prompt, cli_user_prompt, tokenizer);
} else {
prompt_tokens = get_next_user_prompt_tokens(tokenizer);
}
@@ -711,12 +711,12 @@

// std::cout << "TOKEN: " << token << " NEXT: " << next << std::endl;


if ((user_idx >= num_prompt_tokens) && (token == EOT_TOKEN)) {
user_turn = 1;
}

if (user_idx >= num_prompt_tokens && token != EOT_TOKEN && next != EOT_TOKEN) {
if (user_idx >= num_prompt_tokens && token != EOT_TOKEN &&
next != EOT_TOKEN) {
std::string piece = tokenizer->decode(token, next);
safe_printf(piece.c_str()); // same as printf("%s", piece), but skips
// "unsafe" bytes
@@ -727,7 +727,6 @@
printf("\n");
}
pos++;

}
printf("\n");
}
@@ -752,7 +751,9 @@ void error_usage() {
fprintf(stderr, " -z <string> optional path to custom tokenizer\n");
fprintf(stderr, " -m <string> mode: generate|chat, default: generate\n");
fprintf(stderr, " -y <string> (optional) system prompt in chat mode\n");
fprintf(stderr, " -l <int> (optional) llama version (2 or 3). Defaults to 2.\n");
fprintf(
stderr,
" -l <int> (optional) llama version (2 or 3). Defaults to 2.\n");
exit(EXIT_FAILURE);
}
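For readers trying the flags documented above, a chat-mode invocation might look like the following line. It is purely illustrative: the binary name and the positional model path are assumptions (the full usage banner is not visible in this hunk), and only the -z, -m, -y, and -l flags are taken from the code above.

./run <model_path> -z <tokenizer_path> -m chat -l 3 -y "You are a helpful assistant."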

@@ -776,7 +777,8 @@ int main(int argc, char* argv[]) {
int llama_ver = 2;

#if defined(ET_USE_ADPATIVE_THREADS)
uint32_t num_performant_cores = torch::executorch::cpuinfo::get_num_performant_cores();
uint32_t num_performant_cores =
torch::executorch::cpuinfo::get_num_performant_cores();
if (num_performant_cores > 0) {
torch::executorch::threadpool::get_threadpool()->_unsafe_reset_threadpool(
num_performant_cores);
@@ -820,9 +822,8 @@ int main(int argc, char* argv[]) {
} else if (argv[i][1] == 'y') {
system_prompt = argv[i + 1];
} else if (argv[i][1] == 'l') {
llama_ver = atoi(argv[i+1]);
}
else {
llama_ver = atoi(argv[i + 1]);
} else {
error_usage();
}
}
@@ -837,7 +838,6 @@ int main(int argc, char* argv[]) {
if (steps < 0)
steps = 0;


if (vocab_size == -1) {
if (llama_ver == 2) {
vocab_size = 32000;
@@ -855,16 +855,21 @@

switch (llama_ver) {
case 2:
tokenizer = new BPETokenizer(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
tokenizer =
new BPETokenizer(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
tokenizer->load(tokenizer_path);
break;
case 3:
tokenizer = new Tiktoken(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
tokenizer =
new Tiktoken(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
tokenizer->load(tokenizer_path);
break;

default:
fprintf(stderr, "Cannot load tokenizer for unrecognized llama version %d", llama_ver);
fprintf(
stderr,
"Cannot load tokenizer for unrecognized llama version %d",
llama_ver);
exit(EXIT_FAILURE);
}
