Add time to first token for llama runner (#2141)
Summary:

Add time to first generated token & other features



Since we're now measuring time to first token, the token rate is reported both for the inference as a whole and for the generation phase on its own. The intervals (see the sketch after this list):

* Model Load Time - a timer around `ET_CHECK_OK_OR_RETURN_ERROR(load());`
* Total inference time - from immediately after model load until the end of the inference loop
  * First token time - from immediately after model load until the first generated (not prompt) token is printed
    * Prompt eval - (comparable to llama.cpp's prompt_eval_time) prompt array allocation and tokenization; ends right before the inference loop starts
  * Remaining tokens - from immediately after the first token is emitted until the end of the inference loop
  * Net eval time - (comparable to llama.cpp's eval_time) total time spent generating tokens
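Schematically, the intervals nest like this (a sketch of the relationships; net eval is derived exactly as in printReport below, and model load is timed separately around load(), before start):

    start                                                          end
      |--- prompt_eval ---|                                         |
      |--------- first_token ---------|                             |
                                      |----- remaining_tokens ------|

    net_eval_time = first_token + remaining_tokens - prompt_eval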

To implement:
* Sample time - amount of time spent sampling per token (present in llama.cpp)

Differential Revision: D54223564
Varun Puri authored and facebook-github-bot committed Mar 6, 2024
1 parent 0570294 commit 095524e
Showing 2 changed files with 82 additions and 14 deletions.
82 changes: 68 additions & 14 deletions examples/models/llama2/runner/runner.cpp
@@ -161,8 +161,14 @@ Error Runner::generate(
// Prepare the inputs.
// Use ones-initialized inputs.
ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
timers_.model_load = util::time_in_ms();
ET_CHECK_OK_OR_RETURN_ERROR(load());
timers_.model_load = util::time_in_ms() - timers_.model_load;

// First token time only measures the time it takes to encode the prompt and
// return a response token.

timers_.start = util::time_in_ms();
shouldStop_ = false;

// encode the (string) prompt into tokens sequence
@@ -173,12 +179,14 @@ Error Runner::generate(
// Set the sequence length to the max seq length if not provided
seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? seq_len : max_seq_len_;


tokenizer_->encode(
prompt.c_str(),
n_bos_,
append_eos_ ? n_eos_ : 0,
prompt_tokens,
&num_prompt_tokens);

for (int i = 0; i < num_prompt_tokens; i++) {
ET_LOG(Info, "prompt_tokens[%d]: %d", i, prompt_tokens[i]);
}
@@ -192,8 +200,6 @@ Error Runner::generate(
"Sequence length exceeded - please increase the seq_len value passed to generate()");

// start the main loop
long start = 0; // used to time our code, only initialized after first iteration
int next; // will store the next token in the sequence
int64_t pos = num_prompt_tokens - 1; // position in the sequence
int token = prompt_tokens[pos]; // prefill starts from 0 to num_prompt_tokens
@@ -255,6 +261,7 @@ Error Runner::generate(
tokenizer_->decode(prompt_tokens[i - 1], prompt_tokens[i])));
}
}

// create a 1xN int tensor with next as value
while (pos < seq_len) {
// ET_LOG(Info, "Generating step %d...", pos);
@@ -290,7 +297,12 @@ Error Runner::generate(
outputs.size() > 0,
"Expecting output to have at least one evalue. Got %zu",
outputs.size());

if (pos == num_prompt_tokens) {
// The first generated token was emitted on the previous step: record time
// to first token and start the clock for the remaining tokens.
timers_.first_token = util::time_in_ms() - timers_.start;
timers_.remaining_tokens = util::time_in_ms();
} else if (pos == num_prompt_tokens - 1) {
// The model has just consumed the last prompt token: prompt eval ends here.
timers_.prompt_eval = util::time_in_ms() - timers_.start;
}
int32_t next_tok;
exec_aten::Tensor logits_tensor = outputs.at(logits_index).toTensor();

@@ -342,6 +354,7 @@ Error Runner::generate(
if (pos >= num_prompt_tokens && next == eos_id_) {
eos_counter++;
if (eos_counter == n_eos_) {
printf("\n");
ET_LOG(Info, "Reached to the end of generation");
break;
}
@@ -351,10 +364,6 @@

token = next;

// init the timer here because the first iteration can be slower
if (start == 0) {
start = util::time_in_ms();
}
if (use_kv_cache_) {
// outputs: [k_cache, v_cache, logits, k_cache, v_cache]
memcpy(
@@ -367,23 +376,68 @@
v_data.size());
}
}
timers_.remaining_tokens = util::time_in_ms() - timers_.remaining_tokens;
timers_.end = util::time_in_ms();
printf("\n");

if (pos == seq_len) {
ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
}
// report achieved tok/s (pos-1 because the timer starts after first
// iteration)
if (pos >= 1) {
long end = util::time_in_ms();
ET_LOG(
Info, "Achieved tok/s: %f\n", (pos - 1) / (double)(end - start) * 1000);
}

printReport(num_prompt_tokens, pos - num_prompt_tokens);

delete[] prompt_tokens;
return Error::Ok;
}

void Runner::printReport(
int64_t num_prompt_tokens,
int64_t num_generated_tokens) {
printf("\n");
double net_eval_time =
(double)(timers_.first_token + timers_.remaining_tokens - timers_.prompt_eval);

ET_LOG(
Info,
"\tPrompt Tokens: %ld Generated Tokens: %ld",
num_prompt_tokens,
num_generated_tokens);

ET_LOG(
Info,
"\tModel Load Time:\t\t%f (seconds)",
((double)(timers_.model_load) / 1000));
ET_LOG(
Info,
"\tTotal inference time:\t\t%f (seconds)\t\t Token Rate: \t%f (tokens/second)",
(double)(timers_.end - timers_.start) / 1000,
(num_generated_tokens) / (double)(timers_.end - timers_.start) * 1000);
ET_LOG(
Info,
"\t\tTime to first token:\t%f (seconds)",
((double)(timers_.first_token) / 1000));
ET_LOG(
Info,
"\t\t\tPrompt eval:\t%f (seconds)\t\t Token Rate: \t%f (tokens/second)",
((double)(timers_.prompt_eval) / 1000),
(num_prompt_tokens) / (double)(timers_.prompt_eval) * 1000);

ET_LOG(
Info,
"\t\tRemaining %ld tokens:\t%f (seconds)\t\t Token Rate: \t%f (tokens/second)",
num_generated_tokens - 1,
(double)(timers_.remaining_tokens) / 1000,
(num_generated_tokens - 1) / (double)(timers_.remaining_tokens) * 1000);
ET_LOG(
Info,
"\t\tNet evaluation time:\t%f (seconds)\t\t Token Rate: \t%f (tokens/second)",
(net_eval_time / 1000),
(num_generated_tokens) / net_eval_time * 1000);
}

void Runner::stop() {
shouldStop_ = true;
}
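To experiment with the same measurement pattern outside the runner, here is a minimal self-contained sketch. std::chrono stands in for util::time_in_ms (assumed to return wall-clock milliseconds), a sleep stands in for one inference step, and the prompt/sequence lengths are hypothetical:

    #include <chrono>
    #include <cstdio>
    #include <thread>

    // Stand-in for util::time_in_ms(): wall-clock milliseconds.
    static long time_in_ms() {
      return static_cast<long>(
          std::chrono::duration_cast<std::chrono::milliseconds>(
              std::chrono::steady_clock::now().time_since_epoch())
              .count());
    }

    int main() {
      const int num_prompt_tokens = 4; // hypothetical prompt length
      const int seq_len = 12;          // hypothetical total sequence length
      long start = time_in_ms();
      long prompt_eval = 0;
      long first_token = 0;
      long remaining_tokens = 0;
      for (int pos = num_prompt_tokens - 1; pos < seq_len; ++pos) {
        // A sleep stands in for one inference step.
        std::this_thread::sleep_for(std::chrono::milliseconds(10));
        if (pos == num_prompt_tokens) {
          // The first generated token was emitted on the previous step.
          first_token = time_in_ms() - start;
          remaining_tokens = time_in_ms();
        } else if (pos == num_prompt_tokens - 1) {
          // The last prompt token was just consumed; prompt eval ends here.
          prompt_eval = time_in_ms() - start;
        }
      }
      remaining_tokens = time_in_ms() - remaining_tokens;
      long net_eval = first_token + remaining_tokens - prompt_eval;
      printf(
          "prompt_eval=%ld ms, first_token=%ld ms, net_eval=%ld ms\n",
          prompt_eval,
          first_token,
          net_eval);
      return 0;
    }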
14 changes: 14 additions & 0 deletions examples/models/llama2/runner/runner.h
@@ -63,6 +63,20 @@ class Runner {
std::unique_ptr<Tokenizer> tokenizer_;
std::unique_ptr<Sampler> sampler_;
bool shouldStop_{false};

// All values are wall-clock milliseconds, as returned by util::time_in_ms().
struct timers {
long start;
long model_load;
long prompt_eval;
long eval;
long first_token;
long remaining_tokens;
long end;
};
timers timers_;
void printReport(int64_t num_prompt_tokens, int64_t num_generated_tokens);
};

} // namespace torch::executor
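A quick sanity check on the unit handling in printReport: every timer holds milliseconds, so each rate is tokens / elapsed_ms * 1000, and the remaining-tokens rate divides by num_generated_tokens - 1 because the first generated token is already accounted for under time to first token. With purely hypothetical numbers:

    // Hypothetical values, just to illustrate the unit conversion:
    long first_token = 450;       // ms: prompt eval + first decode step
    long remaining_tokens = 4550; // ms: rest of the generation loop
    long prompt_eval = 150;       // ms
    int64_t num_generated_tokens = 100;
    double net_eval_time =
        (double)(first_token + remaining_tokens - prompt_eval); // 4850 ms
    double tok_per_sec =
        num_generated_tokens / net_eval_time * 1000; // ~20.6 tokens/second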
