Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow exporting a view of the KV cache #4180

Merged
merged 6 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <regex>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <cinttypes>
Expand Down Expand Up @@ -1386,3 +1387,77 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
}

//
// KV cache utils
//

void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) {
static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+";

printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d",
view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);

llama_kv_cache_view_cell * c_curr = view.cells;
llama_seq_id * cs_curr = view.cells_sequences;

for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
if (i % row_size == 0) {
printf("\n%5d: ", i);
}
int seq_count = 0;
for (int j = 0; j < view.n_max_seq; j++) {
if (cs_curr[j] >= 0) { seq_count++; }
}
putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]);
}

printf("\n=== Done dumping\n");
}

void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";

printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n",
view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);

std::unordered_map<llama_seq_id, size_t> seqs;
llama_kv_cache_view_cell * c_curr = view.cells;
llama_seq_id * cs_curr = view.cells_sequences;

for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
for (int j = 0; j < view.n_max_seq; j++) {
if (cs_curr[j] < 0) { continue; }
if (seqs.find(cs_curr[j]) == seqs.end()) {
if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
seqs[cs_curr[j]] = seqs.size();
}
}
if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
}

printf("=== Sequence legend: ");
for (const auto & it : seqs) {
printf("%zu=%d, ", it.second, it.first);
}
printf("'+'=other sequence ids");

c_curr = view.cells;
cs_curr = view.cells_sequences;
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
if (i % row_size == 0) {
printf("\n%5d: ", i);
}
for (int j = 0; j < view.n_max_seq; j++) {
if (cs_curr[j] >= 0) {
const auto & it = seqs.find(cs_curr[j]);
putchar(it != seqs.end() ? int(slot_chars[it->second]) : '+');
} else {
putchar('.');
}
}
putchar(' ');
}

printf("\n=== Done dumping\n");
}
10 changes: 10 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,13 @@ std::string get_sortable_timestamp();
void dump_non_result_info_yaml(
FILE * stream, const gpt_params & params, const llama_context * lctx,
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);

//
// KV cache utils
//

// Dump the KV cache view with the number of sequences per cell.
void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);

// Dump the KV cache view showing individual sequences in each cell (long output).
void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
5 changes: 5 additions & 0 deletions examples/parallel/parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ int main(int argc, char ** argv) {
int32_t n_total_gen = 0;
int32_t n_cache_miss = 0;

struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, n_clients);

const auto t_main_start = ggml_time_us();

LOG_TEE("%s: Simulating parallel requests from clients:\n", __func__);
Expand Down Expand Up @@ -201,6 +203,9 @@ int main(int argc, char ** argv) {
LOG_TEE("Processing requests ...\n\n");

while (true) {
llama_kv_cache_view_update(ctx, &kvc_view);
dump_kv_cache_view_seqs(kvc_view, 40);
KerfuffleV2 marked this conversation as resolved.
Show resolved Hide resolved

llama_batch_clear(batch);

// decode any currently ongoing sequences
Expand Down
89 changes: 89 additions & 0 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8805,6 +8805,95 @@ int llama_model_apply_lora_from_file(const struct llama_model * model, const cha
}
}

struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq) {
struct llama_kv_cache_view result = {
/*.n_cells = */ 0,
/*.n_max_seq = */ n_max_seq,
/*.token_count = */ 0,
/*.used_cells = */ llama_get_kv_cache_used_cells(ctx),
/*.max_contiguous = */ 0,
/*.max_contiguous_idx = */ -1,
/*.cells = */ nullptr,
/*.cells_sequences = */ nullptr,
};
return result;
}

void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
if (view->cells != nullptr) {
free(view->cells);
view->cells = nullptr;
}
if (view->cells_sequences != nullptr) {
free(view->cells_sequences);
view->cells_sequences = nullptr;
}
}

void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) {
if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) {
view->n_cells = int32_t(ctx->kv_self.size);
void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
view->cells = (struct llama_kv_cache_view_cell *)p;
p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_max_seq * view->n_cells);
GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
view->cells_sequences = (llama_seq_id *)p;
}

const std::vector<llama_kv_cell> & kv_cells = ctx->kv_self.cells;
llama_kv_cache_view_cell * c_curr = view->cells;
llama_seq_id * cs_curr = view->cells_sequences;
int32_t used_cells = 0;
int32_t token_count = 0;
int32_t curr_contig_idx = -1;
uint32_t max_contig = 0;
int32_t max_contig_idx = -1;

for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_max_seq) {
const size_t curr_size = kv_cells[i].seq_id.size();
token_count += curr_size;
c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;

if (curr_size > 0) {
if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
max_contig = i - curr_contig_idx;
max_contig_idx = curr_contig_idx;
}
curr_contig_idx = -1;
} else if (curr_contig_idx < 0) {
curr_contig_idx = i;
}

int seq_idx = 0;
for (const llama_seq_id it : kv_cells[i].seq_id) {
if (seq_idx >= view->n_max_seq) {
break;
}
cs_curr[seq_idx] = it;
seq_idx++;
}
if (seq_idx != 0) {
used_cells++;
}
for (; seq_idx < view->n_max_seq; seq_idx++) {
cs_curr[seq_idx] = -1;
}
}
if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
max_contig_idx = curr_contig_idx;
max_contig = kv_cells.size() - curr_contig_idx;
}
view->max_contiguous = max_contig;
view->max_contiguous_idx = max_contig_idx;
view->token_count = token_count;
view->used_cells = used_cells;
if (uint32_t(used_cells) != ctx->kv_self.used) {
LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
__func__, ctx->kv_self.used, used_cells);
}
}

int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
int result = 0;

Expand Down
48 changes: 48 additions & 0 deletions llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,54 @@ extern "C" {
// KV cache
//

// Information associated with an individual cell in the KV cache view.
struct llama_kv_cache_view_cell {
// The position for this cell. Takes KV cache shifts into account.
// May be negative if the cell is not populated.
llama_pos pos;
};

// An updateable view of the KV cache.
struct llama_kv_cache_view {
// Number of KV cache cells. This will be the same as the context size.
int32_t n_cells;

// Maximum number of sequences that can exist in a cell. It's not an error
// if there are more sequences in a cell than this value, however they will
// not be visible in the view cells_sequences.
int32_t n_max_seq;

// Number of tokens in the cache. For example, if there are two populated
// cells, the first with 1 sequence id in it and the second with 2 sequence
// ids then you'll have 3 tokens.
int32_t token_count;

// Number of populated cache cells.
int32_t used_cells;

// Maximum contiguous empty slots in the cache.
int32_t max_contiguous;

// Index to the start of the max_contiguous slot range. Can be negative
// when cache is full.
int32_t max_contiguous_idx;

// Information for an individual cell.
struct llama_kv_cache_view_cell * cells;

// The sequences for each cell. There will be n_max_seq items per cell.
llama_seq_id * cells_sequences;
};

// Create an empty KV cache view.
LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq);

// Free a KV cache view.
LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);

// Update the KV cache view structure with the current state of the KV cache.
LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);

// Returns the number of tokens in the KV cache (slow, use only for debug)
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
Expand Down
Loading