Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Added api for getting/setting the kv_cache #685

Merged
merged 6 commits into from
Apr 2, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1668,6 +1668,33 @@ int llama_model_quantize(
return 0;
}

// Returns the KV cache that will contain the context for the
// ongoing prediction with the model.
const uint8_t * llama_get_kv_cache(struct llama_context * ctx) {
return ctx->model.kv_self.buf.data();
}

// Returns the size of the KV cache
size_t llama_get_kv_cache_size(struct llama_context * ctx) {
return ctx->model.kv_self.buf.size();
}

int llama_get_kv_cache_token_count(struct llama_context * ctx) {
return ctx->model.kv_self.n;
}

// Sets the KV cache containing the current context for the model
void llama_set_kv_cache(
struct llama_context * ctx,
const uint8_t * kv_cache,
size_t n_size,
int n_token_count) {
// Make sure we have the same kv cache setup
LLAMA_ASSERT(ctx->model.kv_self.buf.size() == n_size);
memcpy(ctx->model.kv_self.buf.data(), kv_cache, n_size);
ctx->model.kv_self.n = n_token_count;
}

int llama_eval(
struct llama_context * ctx,
const llama_token * tokens,
Expand Down
17 changes: 17 additions & 0 deletions llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,23 @@ extern "C" {
const char * fname_out,
int itype);

// Returns the KV cache that will contain the context for the
// ongoing prediction with the model.
LLAMA_API const uint8_t * llama_get_kv_cache(struct llama_context * ctx);

// Returns the size of the KV cache
LLAMA_API size_t llama_get_kv_cache_size(struct llama_context * ctx);

// Returns the number of tokens in the KV cache
LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx);

// Sets the KV cache containing the current context for the model
LLAMA_API void llama_set_kv_cache(
struct llama_context * ctx,
const uint8_t * kv_cache,
size_t n_size,
int n_token_count);

// Run the llama inference to obtain the logits and probabilities for the next token.
// tokens + n_tokens is the provided batch of new tokens to process
// n_past is the number of tokens to use from previous eval calls
Expand Down