This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Commit

Merge pull request #41 from philpax/session-caching-cli
Session caching CLI
philpax authored Mar 26, 2023
2 parents 91afe67 + 796054f commit 08b875c
Showing 7 changed files with 271 additions and 161 deletions.
59 changes: 59 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

8 changes: 4 additions & 4 deletions README.md
@@ -45,13 +45,13 @@ Some additional things to try:
- Use `--help` to see a list of available options.
- If you have the [alpaca-lora](https://github.com/tloen/alpaca-lora) weights,
try `--repl` mode! `cargo run --release -- -m <path>/ggml-alpaca-7b-q4.bin
-f examples/alpaca_prompt.txt --repl`.
-f examples/alpaca_prompt.txt --repl`.

![Gif showcasing alpaca repl mode](./doc/resources/alpaca_repl_screencap.gif)

- Prompt files can be precomputed to speed up processing using the
`--cache-prompt` and `--restore-prompt` flags so you can save processing time
for lengthy prompts.
- Sessions can be loaded (`--load-session`) or saved (`--save-session`) to file. To automatically load
and save the same session, use `--persist-session`. This can be used to cache prompts to reduce load
time, too:

![Gif showcasing prompt caching](./doc/resources/prompt_caching_screencap.gif)

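For context, a typical invocation combining these flags might look like `cargo run --release -- -m <path>/ggml-alpaca-7b-q4.bin -f examples/alpaca_prompt.txt --persist-session <path>/session.bin` (the session path is illustrative): the first run processes the prompt and saves the resulting session, and subsequent runs with the same flag restore that state from disk rather than starting from scratch.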
8 changes: 4 additions & 4 deletions llama-cli/Cargo.toml
@@ -9,12 +9,12 @@ edition = "2021"
clap = { version = "4.1.8", features = ["derive"] }
env_logger = "0.10.0"
log = "0.4"
once_cell = "1.17.1"
num_cpus = "1.15.0"
once_cell = "1.17.1"
rustyline = "11.0.0"
spinners = "4.1.0"
zstd = { version = "0.12", default-features = false }

llama-rs = { path = "../llama-rs" }

rand = { workspace = true }

rustyline = "11.0.0"
spinners = "4.1.0"
26 changes: 19 additions & 7 deletions llama-cli/src/cli_args.rs
@@ -1,3 +1,5 @@
use std::path::PathBuf;

use clap::Parser;
use llama_rs::TokenBias;
use once_cell::sync::Lazy;
@@ -58,20 +60,30 @@ pub struct Args {
#[arg(long, default_value_t = 40)]
pub top_k: usize,

/// Top-p: The cummulative probability after which no more words are kept
/// Top-p: The cumulative probability after which no more words are kept
/// for sampling.
#[arg(long, default_value_t = 0.95)]
pub top_p: f32,

/// Stores a cached prompt at the given path. The same prompt can then be
/// loaded from disk using --restore-prompt
/// Saves an inference session at the given path. The same session can then be
/// loaded from disk using `--load-session`.
///
/// Use this with `-n 0` to save just the prompt
#[arg(long, default_value = None)]
pub save_session: Option<PathBuf>,

/// Loads a saved inference session from the given path, previously saved using
/// `--save-session`
#[arg(long, default_value = None)]
pub cache_prompt: Option<String>,
pub load_session: Option<PathBuf>,

/// Restores a cached prompt at the given path, previously using
/// --cache-prompt
/// Loads an inference session from the given path if present, and then saves
/// the result to the same path after inference is completed.
///
/// Equivalent to `--load-session` and `--save-session` with the same path,
/// but will not error if the path does not exist
#[arg(long, default_value = None)]
pub restore_prompt: Option<String>,
pub persist_session: Option<PathBuf>,

/// Specifies the seed to use during sampling. Note that, depending on
/// hardware, the same seed may lead to different results on two separate
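To make the precedence of the three new flags concrete, here is a minimal sketch; `resolve_session_paths` is a hypothetical helper written for this note, not code from the PR, and it only restates the behaviour described in the doc comments above.

```rust
use std::path::PathBuf;

/// Hypothetical helper: fold the three session flags into an effective
/// load path and save path.
fn resolve_session_paths(
    load_session: Option<PathBuf>,
    save_session: Option<PathBuf>,
    persist_session: Option<PathBuf>,
) -> (Option<PathBuf>, Option<PathBuf>) {
    // `--persist-session` only loads if the file already exists; otherwise
    // fall back to `--load-session`.
    let load = persist_session
        .clone()
        .filter(|path| path.exists())
        .or(load_session);
    // `--save-session` wins for the save path; otherwise save back to the
    // `--persist-session` path.
    let save = save_session.or(persist_session);
    (load, save)
}
```

In the `main.rs` changes below, the same decision is expressed as a `match` over `(&args.persist_session, &args.load_session)` when loading, plus `args.save_session.as_ref().or(args.persist_session.as_ref())` at save time.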
139 changes: 80 additions & 59 deletions llama-cli/src/main.rs
@@ -1,12 +1,11 @@
use std::{convert::Infallible, io::Write};
use std::{convert::Infallible, io::Write, path::Path};

use cli_args::CLI_ARGS;
use llama_rs::{
InferenceError, InferenceParameters, InferenceSessionParameters, InferenceSnapshot,
InferenceError, InferenceParameters, InferenceSession, InferenceSessionParameters, Model,
ModelKVMemoryType, TokenBias, Vocabulary, EOD_TOKEN_ID,
};
use rand::thread_rng;
use rand::SeedableRng;
use rand::{thread_rng, SeedableRng};
use rustyline::error::ReadlineError;

mod cli_args;
@@ -112,6 +111,7 @@ fn main() {
TokenBias::default()
}
}),
play_back_previous_tokens: false,
};
let inference_session_params = {
let mem_typ = if args.float16 {
@@ -122,7 +122,7 @@
InferenceSessionParameters {
memory_k_type: mem_typ,
memory_v_type: mem_typ,
last_n_size: args.repeat_last_n,
repetition_penalty_last_n: args.repeat_last_n,
}
};

Expand Down Expand Up @@ -153,7 +153,7 @@ fn main() {
std::process::exit(1);
};

let (mut model, vocab) =
let (model, vocab) =
llama_rs::Model::load(&args.model_path, args.num_ctx_tokens as i32, |progress| {
use llama_rs::LoadProgress;
match progress {
@@ -220,20 +220,26 @@ fn main() {
rand::rngs::StdRng::from_entropy()
};

let mut session = if let Some(restore_path) = &args.restore_prompt {
let snapshot = InferenceSnapshot::load_from_disk(restore_path);
match snapshot.and_then(|snapshot| model.session_from_snapshot(snapshot)) {
Ok(session) => {
log::info!("Restored cached memory from {restore_path}");
session
}
Err(err) => {
log::error!("{err}");
std::process::exit(1);
let (mut session, session_loaded) = {
fn load_snapshot_from_disk(model: &Model, path: &Path) -> InferenceSession {
let snapshot = snapshot::load_from_disk(path);
match snapshot.and_then(|snapshot| model.session_from_snapshot(snapshot)) {
Ok(session) => {
log::info!("Loaded inference session from {path:?}");
session
}
Err(err) => {
eprintln!("Could not load inference session. Error: {err}");
std::process::exit(1);
}
}
}
} else {
model.start_session(inference_session_params)

match (&args.persist_session, &args.load_session) {
(Some(path), _) if path.exists() => (load_snapshot_from_disk(&model, path), true),
(_, Some(path)) => (load_snapshot_from_disk(&model, path), true),
_ => (model.start_session(inference_session_params), false),
}
};

if args.repl {
@@ -244,46 +250,16 @@ fn main() {
&inference_params,
&inference_session_params,
);
} else if let Some(cache_path) = &args.cache_prompt {
let res =
session.feed_prompt::<Infallible>(&model, &vocab, &inference_params, &prompt, |t| {
print!("{t}");
std::io::stdout().flush().unwrap();

Ok(())
});

println!();

match res {
Ok(_) => (),
Err(InferenceError::ContextFull) => {
log::warn!(
"Context is not large enough to fit the prompt. Saving intermediate state."
);
}
Err(llama_rs::InferenceError::TokenizationFailed) => {
log::error!("Failed to tokenize initial prompt. Exiting.");
return;
} else {
let inference_params = if session_loaded {
InferenceParameters {
play_back_previous_tokens: true,
..inference_params
}
Err(llama_rs::InferenceError::UserCallback(_)) => unreachable!("cannot fail"),
}
} else {
inference_params
};

// Write the memory to the cache file
// SAFETY: no other model functions used inside the block
unsafe {
let memory = session.get_snapshot();
match memory.write_to_disk(cache_path) {
Ok(_) => {
log::info!("Successfully written prompt cache to {cache_path}");
}
Err(err) => {
eprintln!("Could not restore prompt. Error: {err}");
std::process::exit(1);
}
}
}
} else {
let res = session.inference_with_prompt::<Infallible>(
&model,
&vocab,
@@ -301,9 +277,7 @@ fn main() {
println!();

match res {
Ok(stats) => {
println!("{}", stats);
}
Ok(_) => (),
Err(llama_rs::InferenceError::ContextFull) => {
log::warn!("Context window full, stopping inference.")
}
Expand All @@ -312,5 +286,52 @@ fn main() {
}
Err(llama_rs::InferenceError::UserCallback(_)) => unreachable!("cannot fail"),
}

if let Some(session_path) = args.save_session.as_ref().or(args.persist_session.as_ref()) {
// Write the memory to the cache file
// SAFETY: no other model functions used inside the block
unsafe {
match snapshot::write_to_disk(&session.get_snapshot(), session_path) {
Ok(_) => {
log::info!("Successfully wrote session to {session_path:?}");
}
Err(err) => {
log::error!("Could not write session at {session_path:?}: {err}");
std::process::exit(1);
}
}
}
}
}
}

mod snapshot {
use llama_rs::{InferenceSnapshot, InferenceSnapshotRef, SnapshotError};
use std::{
fs::File,
io::{BufReader, BufWriter},
path::Path,
};
use zstd::zstd_safe::CompressionLevel;

const SNAPSHOT_COMPRESSION_LEVEL: CompressionLevel = 1;

pub fn load_from_disk(path: impl AsRef<Path>) -> Result<InferenceSnapshot, SnapshotError> {
let mut reader =
zstd::stream::read::Decoder::new(BufReader::new(File::open(path.as_ref())?))?;
InferenceSnapshot::read(&mut reader)
}

pub fn write_to_disk(
snap: &InferenceSnapshotRef<'_>,
path: impl AsRef<Path>,
) -> Result<(), SnapshotError> {
let mut writer = zstd::stream::write::Encoder::new(
BufWriter::new(File::create(path.as_ref())?),
SNAPSHOT_COMPRESSION_LEVEL,
)?
.auto_finish();

snap.write(&mut writer)
}
}
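As a usage sketch, the helpers in `mod snapshot` combine with the model API roughly as follows; this is a hypothetical standalone function, not code from the PR, and the path parameter is illustrative.

```rust
use std::path::Path;

use llama_rs::{InferenceSession, Model, SnapshotError};

/// Hypothetical round trip: persist a session to a compressed snapshot on
/// disk, then rebuild a session from it.
fn save_and_restore(
    model: &Model,
    session: &mut InferenceSession,
    path: &Path, // e.g. the value passed to --save-session or --persist-session
) -> Result<InferenceSession, SnapshotError> {
    // SAFETY: as in main(), no other model functions are used while the
    // snapshot, which borrows the session's memory, is alive.
    unsafe {
        snapshot::write_to_disk(&session.get_snapshot(), path)?;
    }
    let snapshot = snapshot::load_from_disk(path)?;
    model.session_from_snapshot(snapshot)
}
```

Compression level 1 is a fast zstd setting, which fits the goal of making session saves and restores cheap relative to recomputing a lengthy prompt.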
1 change: 1 addition & 0 deletions llama-rs/Cargo.toml
@@ -14,4 +14,5 @@ thiserror = "1.0"

rand = { workspace = true }
serde = { version = "1.0.156", features = ["derive"] }
serde_bytes = "0.11"
bincode = "1.3.3"
