diff --git a/ggml-sys/ggml/ggml.c b/ggml-sys/ggml/ggml.c index 44a8655d..38541fe1 100644 --- a/ggml-sys/ggml/ggml.c +++ b/ggml-sys/ggml/ggml.c @@ -2644,7 +2644,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "FLASH_FF", }; -static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35"); +static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -2688,7 +2688,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "flash_ff(x)", }; -static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35"); +static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36"); // // ggml object @@ -4709,6 +4709,37 @@ struct ggml_tensor * ggml_rope( return result; } + +// ggml_alibi +struct ggml_tensor * ggml_alibi( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_head) { + GGML_ASSERT(n_past >= 0); + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + // TODO: when implement backward, fix this: + //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); + ((int32_t *) b->data)[0] = n_past; + ((int32_t *) b->data)[1] = n_head; + + result->op = GGML_OP_ALIBI; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + // ggml_conv_1d_1s struct ggml_tensor * ggml_conv_1d_1s( @@ -7192,6 +7223,163 @@ static void ggml_compute_forward_soft_max( } } +// ggml_compute_forward_alibi + +static void ggml_compute_forward_alibi_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(src1->type == GGML_TYPE_I32); + assert(ggml_nelements(src1) == 3); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n_past = ((int32_t *) src1->data)[0]; + const int n_head = ((int32_t *) src1->data)[1]; + const int mode = ((int32_t *) src1->data)[2]; + + const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 + const int ne1 = src0->ne[1]; // seq_len_without_past + const int ne2 = src0->ne[2]; // n_head -> this is k + const int ne3 = src0->ne[3]; // 1 -> bsz + + const int n = ggml_nrows(src0); + const int ne2_ne3 = n/ne1; // ne2*ne3 + + const int nb0 = src0->nb[0]; + const int nb1 = src0->nb[1]; + const int nb2 = src0->nb[2]; + const int nb3 = src0->nb[3]; + + + // printf("\nne0: %d, ne1: %d, ne2: %d, ne3: %d", ne0, ne1, ne2, ne3); + // printf("\nn_past = %d, ne2 = %d", n_past, ne2); + + assert(nb0 == sizeof(float)); + assert(ne1+n_past == ne0); + + // add alibi to src0 (KQ_scaled) + const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); + const float m0 = pow(2.0, -8.0 / n_heads_log2_floor); + const float m1 = pow(2.0, -4.0 / n_heads_log2_floor); + + for (int i = 0; i < ne0; i++) { + for (int j = 0; j < ne1; j++) { + for (int k = 0; k < ne2_ne3; k++) { + float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2); + float * dst_data = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2); + + // TODO: k*nb2 or k*nb3 + + float m_k; + if (k < n_heads_log2_floor) { + m_k = pow(m0, k + 1); + } else { + m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1); + } + //TODO: optimize + dst_data[0] = (j+1) * m_k + src[0]; + } + } + } + +} + + +static void 
ggml_compute_forward_alibi_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(src1->type == GGML_TYPE_I32); + assert(ggml_nelements(src1) == 3); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n_past = ((int32_t *) src1->data)[0]; + const int n_head = ((int32_t *) src1->data)[1]; + const int mode = ((int32_t *) src1->data)[2]; + + const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 + const int ne1 = src0->ne[1]; // seq_len_without_past + const int ne2 = src0->ne[2]; // n_head -> this is k + const int ne3 = src0->ne[3]; // 1 -> bsz + + const int n = ggml_nrows(src0); + const int ne2_ne3 = n/ne1; // ne2*ne3 + + const int nb0 = src0->nb[0]; + const int nb1 = src0->nb[1]; + const int nb2 = src0->nb[2]; + const int nb3 = src0->nb[3]; + + + // printf("\nne0: %d, ne1: %d, ne2: %d, ne3: %d", ne0, ne1, ne2, ne3); + // printf("\nn_past = %d, ne2 = %d", n_past, ne2); + + assert(nb0 == sizeof(float)); + assert(ne1+n_past == ne0); + + // add alibi to src0 (KQ_scaled) + const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); + const ggml_fp16_t m0 = pow(2.0, -8.0 / n_heads_log2_floor); + const ggml_fp16_t m1 = pow(2.0, -4.0 / n_heads_log2_floor); + + for (int i = 0; i < ne0; i++) { + for (int j = 0; j < ne1; j++) { + for (int k = 0; k < ne2_ne3; k++) { + ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2); + + // TODO: k*nb2 or k*nb3 + + ggml_fp16_t m_k; + if (k < n_heads_log2_floor) { + m_k = pow(m0, k + 1); + } else { + m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1); + } + //TODO: optimize + dst_data[0] = (j+1) * m_k + src[0]; + } + } + } + +} + +static void ggml_compute_forward_alibi( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_alibi_f16(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_alibi_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_rope static void ggml_compute_forward_rope_f32( @@ -8691,6 +8879,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor); } break; + case GGML_OP_ALIBI: + { + ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor); + } break; case GGML_OP_CONV_1D_1S: { ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor); @@ -8881,6 +9073,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_ALIBI: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_SILU: { GGML_ASSERT(false); // TODO: not implemented @@ -9387,6 +9583,10 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { node->n_tasks = 1; } break; + case GGML_OP_ALIBI: + { + node->n_tasks = 1; //TODO + } break; case GGML_OP_CONV_1D_1S: case GGML_OP_CONV_1D_2S: { diff --git a/ggml-sys/ggml/ggml.h 
b/ggml-sys/ggml/ggml.h index ad962b10..107868d6 100644 --- a/ggml-sys/ggml/ggml.h +++ b/ggml-sys/ggml/ggml.h @@ -244,6 +244,7 @@ enum ggml_op { GGML_OP_DIAG_MASK_INF, GGML_OP_SOFT_MAX, GGML_OP_ROPE, + GGML_OP_ALIBI, GGML_OP_CONV_1D_1S, GGML_OP_CONV_1D_2S, @@ -599,6 +600,16 @@ struct ggml_tensor * ggml_rope( int n_dims, int mode); +// alibi position embedding +// in-place, returns view(a) +struct ggml_tensor * ggml_alibi( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_head); + + + // padding = 1 // TODO: we don't support extra parameters for now // that's why we are hard-coding the stride, padding, and dilation diff --git a/ggml-sys/src/lib.rs b/ggml-sys/src/lib.rs index 3640a6af..12d7a3e4 100644 --- a/ggml-sys/src/lib.rs +++ b/ggml-sys/src/lib.rs @@ -899,6 +899,17 @@ extern "C" { filename: *const ::std::os::raw::c_char, ); } + + +extern "C" { + pub fn ggml_alibi( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + n_past: ::std::os::raw::c_int, + n_head: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} + pub const ggml_opt_type_GGML_OPT_ADAM: ggml_opt_type = 0; pub const ggml_opt_type_GGML_OPT_LBFGS: ggml_opt_type = 1; pub type ggml_opt_type = ::std::os::raw::c_uint; diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index 9b36bd7d..9093aac9 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -203,6 +203,28 @@ impl Context { self.new_tensor_raw(tensor) } + /// Creates a 2D view over `a`. + pub fn op_view_2d( + &self, + a: &Tensor, + ne0: usize, + ne1: usize, + nb1: usize, + offset: usize, + ) -> Tensor { + let tensor = unsafe { + ggml_sys::ggml_view_2d( + self.ptr.as_ptr(), + a.ptr.as_ptr(), + usize_to_i64(ne0), + usize_to_i64(ne1), + nb1, + offset, + ) + }; + self.new_tensor_raw(tensor) + } + /// Copies `a` to `b` and returns `b`. 
pub fn op_cpy(&self, a: &Tensor, b: &Tensor) -> Tensor { let tensor = @@ -271,6 +293,26 @@ impl Context { pub fn used_mem(&self) -> usize { unsafe { ggml_sys::ggml_used_mem(self.ptr.as_ptr()) } } + + /// TODO: something something + pub fn op_alibi(&self, a: &Tensor, n_past: usize, n_head: usize) -> Tensor { + let tensor = unsafe { + ggml_sys::ggml_alibi( + self.ptr.as_ptr(), + a.ptr.as_ptr(), + usize_to_i32(n_past), + usize_to_i32(n_head), + ) + }; + + self.new_tensor_raw(tensor) + } + + /// Gaussian Error Linear Units + pub fn op_gelu(&self, a: &Tensor) -> Tensor { + let tensor = unsafe { ggml_sys::ggml_gelu(self.ptr.as_ptr(), a.ptr.as_ptr()) }; + self.new_tensor_raw(tensor) + } } impl Drop for Context { diff --git a/llama-cli/src/cli_args.rs b/llama-cli/src/cli_args.rs index e6f19529..3b64ee46 100644 --- a/llama-cli/src/cli_args.rs +++ b/llama-cli/src/cli_args.rs @@ -1,8 +1,7 @@ -use std::path::PathBuf; - use clap::Parser; -use llama_rs::TokenBias; +use llama_rs::common::token::TokenBias; use once_cell::sync::Lazy; +use std::path::PathBuf; #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] @@ -29,6 +28,10 @@ pub struct Args { #[arg(long, short = 'R', default_value_t = false)] pub repl: bool, + /// Run in bloom mode + #[arg(long, short = 'B', default_value_t = false)] + pub bloom: bool, + /// Sets the number of threads to use #[arg(long, short = 't')] pub num_threads: Option, diff --git a/llama-cli/src/main.rs b/llama-cli/src/main.rs index c8da7d04..b0194783 100644 --- a/llama-cli/src/main.rs +++ b/llama-cli/src/main.rs @@ -1,45 +1,239 @@ use std::{convert::Infallible, io::Write, path::Path}; use cli_args::CLI_ARGS; + use llama_rs::{ - InferenceError, InferenceParameters, InferenceSession, InferenceSessionParameters, Model, - ModelKVMemoryType, TokenBias, Vocabulary, EOT_TOKEN_ID, + common::{inference::*, load::*, model::*, token::*, vocabulary::*}, + models::{bloom::BLOOM, llama::Llama}, }; + use rand::{thread_rng, SeedableRng}; use rustyline::error::ReadlineError; mod cli_args; +fn load_snapshot_from_disk(model: &Llama, path: &Path) -> InferenceSession { + let snapshot = snapshot::load_from_disk(path); + match snapshot.and_then(|snapshot| model.session_from_snapshot(snapshot)) { + Ok(session) => { + log::info!("Loaded inference session from {path:?}"); + session + } + Err(err) => { + eprintln!("Could not load inference session. Error: {err}"); + std::process::exit(1); + } + } +} + +fn load_llama(model_path: &String, num_ctx_tokens: usize) -> Result<(Llama, Vocabulary), String> { + let (model, vocab) = + Llama::load( + model_path, + num_ctx_tokens, + |progress| match progress { + LoadProgress::HyperparametersLoaded(hparams) => { + log::debug!("Loaded HyperParams {hparams:#?}") + } + LoadProgress::BadToken { index } => { + log::info!("Warning: Bad token in vocab at index {index}") + } + LoadProgress::ContextSize { bytes } => log::info!( + "ggml ctx size = {:.2} MB\n", + bytes as f64 / (1024.0 * 1024.0) + ), + LoadProgress::MemorySize { bytes, n_mem } => log::info!( + "Memory size: {} MB {}", + bytes as f32 / 1024.0 / 1024.0, + n_mem + ), + LoadProgress::PartLoading { + file, + current_part, + total_parts, + } => log::info!( + "Loading model part {}/{} from '{}'\n", + current_part, + total_parts, + file.to_string_lossy(), + ), + LoadProgress::PartTensorLoaded { + current_tensor, + tensor_count, + .. 
+ } => { + if current_tensor % 8 == 0 { + log::info!("Loaded tensor {current_tensor}/{tensor_count}"); + } + } + LoadProgress::PartLoaded { + file, + byte_size, + tensor_count, + } => { + log::info!("Loading of '{}' complete", file.to_string_lossy()); + log::info!( + "Model size = {:.2} MB / num tensors = {}", + byte_size as f64 / 1024.0 / 1024.0, + tensor_count + ); + } + }, + ) + .expect("Could not load model"); + + log::info!("Model fully loaded!"); + Ok((model, vocab)) +} + +fn load_bloom(model_path: &String, num_ctx_tokens: usize) -> Result<(BLOOM, Vocabulary), String> { + let (model, vocab) = + BLOOM::load( + model_path, + num_ctx_tokens, + |progress| match progress { + LoadProgress::HyperparametersLoaded(hparams) => { + log::debug!("Loaded HyperParams {hparams:#?}") + } + LoadProgress::BadToken { index } => { + log::info!("Warning: Bad token in vocab at index {index}") + } + LoadProgress::ContextSize { bytes } => log::info!( + "ggml ctx size = {:.2} MB\n", + bytes as f64 / (1024.0 * 1024.0) + ), + LoadProgress::MemorySize { bytes, n_mem } => log::info!( + "Memory size: {} MB {}", + bytes as f32 / 1024.0 / 1024.0, + n_mem + ), + LoadProgress::PartLoading { + file, + current_part, + total_parts, + } => log::info!( + "Loading model part {}/{} from '{}'\n", + current_part, + total_parts, + file.to_string_lossy(), + ), + LoadProgress::PartTensorLoaded { + current_tensor, + tensor_count, + .. + } => { + if current_tensor % 8 == 0 { + log::info!("Loaded tensor {current_tensor}/{tensor_count}"); + } + } + LoadProgress::PartLoaded { + file, + byte_size, + tensor_count, + } => { + log::info!("Loading of '{}' complete", file.to_string_lossy()); + log::info!( + "Model size = {:.2} MB / num tensors = {}", + byte_size as f64 / 1024.0 / 1024.0, + tensor_count + ); + } + }, + ) + .expect("Could not load model"); + + log::info!("Model fully loaded!"); + Ok((model, vocab)) +} + +fn bloom_mode( + prompt: &str, + model: &BLOOM, + vocab: &Vocabulary, + params: &InferenceParameters, + session_params: &InferenceSessionParameters, +) { + let mut rl = rustyline::DefaultEditor::new().unwrap(); + loop { + let readline = rl.readline(">> "); + match readline { + Ok(line) => { + let mut session = model.start_session(*session_params); + let prompt = prompt.replace("$PROMPT", &line); + let mut rng = thread_rng(); + + let mut sp = spinners::Spinner::new(spinners::Spinners::Dots2, "".to_string()); + if let Err(InferenceError::ContextFull) = + session + .feed_prompt::(model, vocab, params, &prompt, model.hparams.n_ctx, |_| Ok(())) + { + log::error!("Prompt exceeds context window length.") + }; + sp.stop(); + + let res = session.inference_with_prompt::( + model, + vocab, + params, + "", + CLI_ARGS.num_predict, + &mut rng, + model.hparams.n_ctx, + |tk| { + print!("{tk}"); + std::io::stdout().flush().unwrap(); + Ok(()) + }, + ); + println!(); + + if let Err(InferenceError::ContextFull) = res { + log::error!("Reply exceeds context window length"); + } + } + Err(ReadlineError::Eof) | Err(ReadlineError::Interrupted) => { + break; + } + Err(err) => { + log::error!("{err}"); + } + } + } +} + fn repl_mode( - raw_prompt: &str, - model: &llama_rs::Model, - vocab: &llama_rs::Vocabulary, + prompt: &str, + model: &Llama, + vocab: &Vocabulary, params: &InferenceParameters, - mut session: InferenceSession, + session_params: &InferenceSessionParameters, ) { let mut rl = rustyline::DefaultEditor::new().unwrap(); loop { let readline = rl.readline(">> "); match readline { Ok(line) => { - let prompt = process_prompt(raw_prompt, &line); + 
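+                // A fresh inference session is started for every REPL line, so no
+                // context is carried over between prompts.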
let mut session = model.start_session(*session_params); + let prompt = prompt.replace("$PROMPT", &line); let mut rng = thread_rng(); let mut sp = spinners::Spinner::new(spinners::Spinners::Dots2, "".to_string()); if let Err(InferenceError::ContextFull) = - session.feed_prompt::(model, vocab, params, &prompt, |_| Ok(())) + session + .feed_prompt::(model, vocab, params, &prompt, model.hparams.n_ctx, |_| Ok(())) { log::error!("Prompt exceeds context window length.") }; sp.stop(); - let res = session.inference_with_prompt::( + let res = session.inference_with_prompt::( model, vocab, params, "", CLI_ARGS.num_predict, &mut rng, + model.hparams.n_ctx, |tk| { print!("{tk}"); std::io::stdout().flush().unwrap(); @@ -96,16 +290,16 @@ fn main() { let args = &*CLI_ARGS; - let inference_params = InferenceParameters { + let inference_params: InferenceParameters = InferenceParameters { n_threads: args.num_threads(), n_batch: args.batch_size, top_k: args.top_k, top_p: args.top_p, repeat_penalty: args.repeat_penalty, - temperature: args.temp, + temp: args.temp, bias_tokens: args.token_bias.clone().unwrap_or_else(|| { if args.ignore_eos { - TokenBias::new(vec![(EOT_TOKEN_ID, -1.0)]) + TokenBias::new(vec![(EOD_TOKEN_ID, -1.0)]) } else { TokenBias::default() } @@ -126,7 +320,7 @@ fn main() { } }; - let raw_prompt = if let Some(path) = &args.prompt_file { + let prompt = if let Some(path) = &args.prompt_file { match std::fs::read_to_string(path) { Ok(mut prompt) => { // Strip off the last character if it's exactly newline. Also strip off a single @@ -153,59 +347,7 @@ fn main() { std::process::exit(1); }; - let (model, vocab) = llama_rs::Model::load(&args.model_path, args.num_ctx_tokens, |progress| { - use llama_rs::LoadProgress; - match progress { - LoadProgress::HyperparametersLoaded(hparams) => { - log::debug!("Loaded hyperparameters {hparams:#?}") - } - LoadProgress::BadToken { index } => { - log::info!("Warning: Bad token in vocab at index {index}") - } - LoadProgress::ContextSize { bytes } => log::info!( - "ggml ctx size = {:.2} MB\n", - bytes as f64 / (1024.0 * 1024.0) - ), - LoadProgress::PartLoading { - file, - current_part, - total_parts, - } => { - let current_part = current_part + 1; - log::info!( - "Loading model part {}/{} from '{}'\n", - current_part, - total_parts, - file.to_string_lossy(), - ) - } - LoadProgress::PartTensorLoaded { - current_tensor, - tensor_count, - .. - } => { - let current_tensor = current_tensor + 1; - if current_tensor % 8 == 0 { - log::info!("Loaded tensor {current_tensor}/{tensor_count}"); - } - } - LoadProgress::PartLoaded { - file, - byte_size, - tensor_count, - } => { - log::info!("Loading of '{}' complete", file.to_string_lossy()); - log::info!( - "Model size = {:.2} MB / num tensors = {}", - byte_size as f64 / 1024.0 / 1024.0, - tensor_count - ); - } - } - }) - .expect("Could not load model"); - - log::info!("Model fully loaded!"); + // load llama model let mut rng = if let Some(seed) = CLI_ARGS.seed { rand::rngs::StdRng::seed_from_u64(seed) @@ -213,34 +355,35 @@ fn main() { rand::rngs::StdRng::from_entropy() }; - let (mut session, session_loaded) = { - fn load_snapshot_from_disk(model: &Model, path: &Path) -> InferenceSession { - let snapshot = snapshot::load_from_disk(path); - match snapshot.and_then(|snapshot| model.session_from_snapshot(snapshot)) { - Ok(session) => { - log::info!("Loaded inference session from {path:?}"); - session - } - Err(err) => { - eprintln!("Could not load inference session. 
Error: {err}"); - std::process::exit(1); - } - } - } - - match (&args.persist_session, &args.load_session) { - (Some(path), _) if path.exists() => (load_snapshot_from_disk(&model, path), true), - (_, Some(path)) => (load_snapshot_from_disk(&model, path), true), - _ => (model.start_session(inference_session_params), false), - } - }; - if args.repl { - repl_mode(&raw_prompt, &model, &vocab, &inference_params, session); + if args.bloom { + let (model, vocab) = load_bloom(&args.model_path, args.num_ctx_tokens).unwrap(); + bloom_mode( + &prompt, + &model, + &vocab, + &inference_params, + &inference_session_params, + ); + } else if args.repl { + let (model, vocab) = load_llama(&args.model_path, args.num_ctx_tokens).unwrap(); + repl_mode( + &prompt, + &model, + &vocab, + &inference_params, + &inference_session_params, + ); } else { - let prompt = match (&args.prompt_file, &args.prompt) { - (Some(_), Some(prompt)) => process_prompt(&raw_prompt, prompt), - _ => raw_prompt, + let (model, vocab) = load_llama(&args.model_path, args.num_ctx_tokens).unwrap(); + + + let (mut session, session_loaded) = { + match (&args.persist_session, &args.load_session) { + (Some(path), _) if path.exists() => (load_snapshot_from_disk(&model, path), true), + (_, Some(path)) => (load_snapshot_from_disk(&model, path), true), + _ => (model.start_session(inference_session_params), false), + } }; if args.dump_prompt_tokens { @@ -257,13 +400,14 @@ fn main() { inference_params }; - let res = session.inference_with_prompt::( + let res = session.inference_with_prompt::( &model, &vocab, &inference_params, &prompt, args.num_predict, &mut rng, + model.hparams.n_ctx, |t| { print!("{t}"); std::io::stdout().flush().unwrap(); @@ -275,14 +419,13 @@ fn main() { match res { Ok(_) => (), - Err(llama_rs::InferenceError::ContextFull) => { + Err(InferenceError::ContextFull) => { log::warn!("Context window full, stopping inference.") } - Err(llama_rs::InferenceError::TokenizationFailed) => { + Err(InferenceError::TokenizationFailed) => { log::error!("Failed to tokenize initial prompt."); } - Err(llama_rs::InferenceError::UserCallback(_)) - | Err(llama_rs::InferenceError::EndOfText) => unreachable!("cannot fail"), + Err(InferenceError::UserCallback(_)) => unreachable!("cannot fail"), } if let Some(session_path) = args.save_session.as_ref().or(args.persist_session.as_ref()) { @@ -304,12 +447,14 @@ fn main() { } mod snapshot { - use llama_rs::{InferenceSnapshot, InferenceSnapshotRef, SnapshotError}; use std::{ fs::File, io::{BufReader, BufWriter}, path::Path, }; + + use llama_rs::common::inference::{InferenceSnapshot, InferenceSnapshotRef, SnapshotError}; + use zstd::zstd_safe::CompressionLevel; const SNAPSHOT_COMPRESSION_LEVEL: CompressionLevel = 1; @@ -333,7 +478,3 @@ mod snapshot { snap.write(&mut writer) } } - -fn process_prompt(raw_prompt: &str, prompt: &str) -> String { - raw_prompt.replace("{{PROMPT}}", prompt) -} diff --git a/llama-rs/src/common/helpers.rs b/llama-rs/src/common/helpers.rs new file mode 100644 index 00000000..8570431f --- /dev/null +++ b/llama-rs/src/common/helpers.rs @@ -0,0 +1,42 @@ +use super::load::LoadError; +use std::{ + fs::File, + io::BufReader, + io::{BufRead, Read}, +}; + +pub fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { + let mut bytes = [0u8; N]; + reader + .read_exact(&mut bytes) + .map_err(|e| LoadError::ReadExactFailed { + source: e, + bytes: N, + })?; + Ok(bytes) +} + +pub fn read_i32(reader: &mut impl BufRead) -> Result { + Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) +} + 
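+// Illustrative sketch only (not part of this patch): these little-endian helpers
+// are meant to be composed when parsing a ggml-style model header. The field
+// names below are hypothetical.
+//
+//     fn read_example_header(reader: &mut impl BufRead) -> Result<(i32, u32), LoadError> {
+//         let magic = read_i32(reader)?;    // 4-byte magic number
+//         let n_vocab = read_u32(reader)?;  // first hyperparameter
+//         Ok((magic, n_vocab))
+//     }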
+pub fn read_u32(reader: &mut impl BufRead) -> Result { + Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +pub fn read_f32(reader: &mut impl BufRead) -> Result { + Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +/// Helper function. Reads a string from the buffer and returns it. +pub fn read_string(reader: &mut BufReader, len: usize) -> Result { + let mut buf = vec![0; len]; + reader + .read_exact(&mut buf) + .map_err(|e| LoadError::ReadExactFailed { + source: e, + bytes: buf.len(), + })?; + let s = String::from_utf8(buf)?; + Ok(s) +} diff --git a/llama-rs/src/common/inference.rs b/llama-rs/src/common/inference.rs new file mode 100644 index 00000000..719191d6 --- /dev/null +++ b/llama-rs/src/common/inference.rs @@ -0,0 +1,369 @@ +use super::model::{EvaluateOutputRequest, Model}; +use super::token::{OutputToken, TokenBias, TokenId, EOD_TOKEN_ID}; +use super::vocabulary::Vocabulary; +use core::slice; +use std::fmt::Display; +use std::time::SystemTime; +use thiserror::Error; + +/// An inference session represents the state of the text generation. This holds +/// the full context window, as long as several additional parameters used +/// during sampling. +pub struct InferenceSession { + // Must be kept alive for the model + pub _session_ctx: ggml::Context, + + // Parameters for the session. + pub params: InferenceSessionParameters, + + pub memory_k: ggml::Tensor, + pub memory_v: ggml::Tensor, + + /// How many tokens have been fed into the model's working memory so far. + pub n_past: usize, + + /// How much memory is required per token for the temporary context used + /// during inference. + pub mem_per_token: usize, + + /// All tokens generated by this inference session + pub tokens: Vec, + + /// The logits that were last predicted by the network. Zeroed out otherwise. + pub last_logits: Vec, +} +impl InferenceSession { + pub fn repetition_penalty_tokens(&self) -> &[TokenId] { + &self.tokens[self + .tokens + .len() + .saturating_sub(self.params.repetition_penalty_last_n)..] + } +} + +#[derive(serde::Serialize, Clone, PartialEq)] +/// A serializable snapshot of the inference process. Can be saved to disk. +// Keep in sync with [InferenceSession] and [InferenceSnapshot] +pub struct InferenceSnapshotRef<'a> { + /// How many tokens have been stored in the memory so far. + pub npast: usize, + // Parameters associated with the saved inference session. + pub session_params: InferenceSessionParameters, + /// All tokens generated by this inference session + pub tokens: Vec, + /// The vector of logits that was produced after the last inference + pub logits: Vec, + /// The contents of the 'key' memory tensor + #[serde(with = "serde_bytes")] + pub memory_k: &'a [u8], + /// The contents of the 'value' memory tensor + #[serde(with = "serde_bytes")] + pub memory_v: &'a [u8], +} + +/// A serializable snapshot of the inference process. Can be restored by calling +/// `Model::restore_from_snapshot`. +#[derive(serde::Deserialize, Clone, PartialEq)] +// Keep in sync with [InferenceSession] and [InferenceSnapshotRef] +pub struct InferenceSnapshot { + /// How many tokens have been stored in the memory so far. + pub npast: usize, + // Parameters associated with the saved inference session. 
+    pub session_params: InferenceSessionParameters,
+    /// All tokens generated by this inference session
+    pub tokens: Vec<TokenId>,
+    /// The vector of logits that was produced after the last inference
+    pub last_logits: Vec<f32>,
+    /// The contents of the 'key' memory tensor
+    #[serde(with = "serde_bytes")]
+    pub memory_k: Vec<u8>,
+    /// The contents of the 'value' memory tensor
+    #[serde(with = "serde_bytes")]
+    pub memory_v: Vec<u8>,
+}
+
+// Allowed types for the model memory K/V tensors.
+#[derive(Clone, Copy, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
+pub enum ModelKVMemoryType {
+    Float16,
+    Float32,
+}
+
+impl From<ModelKVMemoryType> for u32 {
+    fn from(value: ModelKVMemoryType) -> Self {
+        match value {
+            ModelKVMemoryType::Float16 => ggml::TYPE_F16,
+            ModelKVMemoryType::Float32 => ggml::TYPE_F32,
+        }
+    }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
+// Parameters for an inference session.
+pub struct InferenceSessionParameters {
+    pub repetition_penalty_last_n: usize,
+    pub memory_k_type: ModelKVMemoryType,
+    pub memory_v_type: ModelKVMemoryType,
+}
+
+impl Default for InferenceSessionParameters {
+    fn default() -> Self {
+        Self {
+            repetition_penalty_last_n: 512,
+            memory_k_type: ModelKVMemoryType::Float32,
+            memory_v_type: ModelKVMemoryType::Float32,
+        }
+    }
+}
+
+#[derive(Clone, Debug, PartialEq)]
+/// The parameters that drive text generation.
+pub struct InferenceParameters {
+    pub n_threads: usize,
+    pub n_batch: usize,
+    pub top_k: usize,
+    pub top_p: f32,
+    pub repeat_penalty: f32,
+    pub temp: f32,
+    pub bias_tokens: TokenBias,
+    pub play_back_previous_tokens: bool,
+    pub increased_determinism: bool,
+}
+
+impl Default for InferenceParameters {
+    fn default() -> Self {
+        Self {
+            n_threads: 8,
+            n_batch: 8,
+            top_k: 40,
+            top_p: 0.95,
+            repeat_penalty: 1.30,
+            temp: 0.80,
+            bias_tokens: TokenBias::default(),
+            play_back_previous_tokens: false,
+            increased_determinism: true,
+        }
+    }
+}
+
+pub struct InferenceStats {
+    pub feed_prompt_duration: std::time::Duration,
+    pub prompt_tokens: usize,
+    pub predict_duration: std::time::Duration,
+    pub predict_tokens: usize,
+}
+
+impl Default for InferenceStats {
+    fn default() -> Self {
+        Self {
+            feed_prompt_duration: std::time::Duration::from_secs(0),
+            prompt_tokens: 0,
+            predict_duration: std::time::Duration::from_secs(0),
+            predict_tokens: 0,
+        }
+    }
+}
+
+impl Display for InferenceStats {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(
+            f,
+            "feed_prompt_duration: {}ms\nprompt_tokens: {}\npredict_duration: {}ms\npredict_tokens: {}\nper_token_duration: {:.3}ms",
+            self.feed_prompt_duration.as_millis(),
+            self.prompt_tokens,
+            self.predict_duration.as_millis(),
+            self.predict_tokens,
+            (self.predict_duration.as_millis() as f64) / (self.predict_tokens as f64),
+        )
+    }
+}
+
+impl InferenceSession {
+    pub fn feed_prompt<E: std::error::Error + 'static, M: Model>(
+        &mut self,
+        model: &M,
+        vocab: &Vocabulary,
+        params: &InferenceParameters,
+        prompt: &str,
+        n_ctx: usize,
+        callback: impl Fn(OutputToken) -> Result<(), E>,
+    ) -> Result<(), InferenceError> {
+        let beginning_of_sentence = self.n_past == 0;
+        let prompt_tokens = model.tokenize(vocab, prompt, beginning_of_sentence)?;
+
+        if self.n_past + prompt_tokens.len() >= n_ctx {
+            return Err(InferenceError::ContextFull);
+        }
+
+        for batch in prompt_tokens.chunks(8) {
+            model.evaluate(self, params, batch, &mut EvaluateOutputRequest::default());
+            for &tk in batch {
+                // NOTE: No string ever tokenizes to the end of sentence. So we
+                // can just return the id here.
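+                // If the callback returns an error, feeding stops here and the error
+                // is surfaced to the caller as InferenceError::UserCallback.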
+ if let Err(e) = callback(OutputToken::Token(&vocab.id_to_token[tk as usize])) { + return Err(InferenceError::UserCallback(Box::new(e))); + } + + // Update the tokens for this session + self.tokens.push(tk); + } + } + + Ok(()) + } + + pub fn infer_next_token<'v, M: Model>( + &mut self, + model: &M, + vocab: &'v Vocabulary, + params: &InferenceParameters, + rng: &mut impl rand::Rng, + n_ctx: usize, + ) -> Result, InferenceError> { + if self.n_past + 1 >= n_ctx { + return Err(InferenceError::ContextFull); + } + + // First, sample the next token, using the stored last_logits; + let next_token = model.sample_top_p_top_k(self, params, rng); + + // Update the tokens for this session + self.tokens.push(next_token); + + // Then, evaluate the network again to compute the new last_logits + model.evaluate( + self, + params, + &[next_token], + &mut EvaluateOutputRequest::default(), + ); + + // Return the next token + Ok(if next_token as TokenId == EOD_TOKEN_ID { + OutputToken::EndOfText + } else { + OutputToken::Token(&vocab.id_to_token[next_token as usize]) + }) + } + + // todo: see if we can reduce the arguments here somehow - consolidate model and vocab maybe? + /// Helper function to run inference with this session and the given model and vocabulary. + /// + /// Note that this will "play back" all existing tokens in the session. If this is not desired + /// behaviour, consider implementing your own inference loop to customize the behavior. + #[allow(clippy::too_many_arguments)] + pub fn inference_with_prompt( + &mut self, + model: &M, + vocab: &Vocabulary, + params: &InferenceParameters, + prompt: &str, + maximum_token_count: Option, + rng: &mut impl rand::Rng, + n_ctx: usize, + callback: impl Fn(OutputToken) -> Result<(), E>, + ) -> Result { + let maximum_token_count = maximum_token_count.unwrap_or(usize::MAX); + if params.play_back_previous_tokens { + // "Play back" the existing tokens, so that loading from an inference snapshot works + // as expected. + for token_id in &self.tokens { + let token = OutputToken::from_id(vocab, *token_id); + if let Err(e) = callback(token) { + return Err(InferenceError::UserCallback(Box::new(e))); + } + } + } + + let mut stats = InferenceStats::default(); + + let start_at = SystemTime::now(); + + // Feed the initial prompt through the transformer, to update its + // context window with new data. + self.feed_prompt(model, vocab, params, prompt, n_ctx, |tk| callback(tk))?; + stats.feed_prompt_duration = start_at.elapsed().unwrap(); + stats.prompt_tokens = self.n_past; + + // After the prompt is consumed, sample tokens by repeatedly calling + // `infer_next_token`. We generate tokens until the model returns an + // EndOfText token, or we run out of space in the context window, + // or we reach the specified limit. + let mut tokens_processed = 0; + while tokens_processed < maximum_token_count { + let token = self.infer_next_token(model, vocab, params, rng, n_ctx)?; + + if let Err(e) = callback(token) { + return Err(InferenceError::UserCallback(Box::new(e))); + } + + tokens_processed += 1; + + if let OutputToken::EndOfText = token { + break; + } + } + stats.predict_duration = start_at.elapsed().unwrap(); + stats.predict_tokens = self.n_past; + + Ok(stats) + } + + /// Obtains a serializable snapshot of the current inference status. This + /// can be used to cache the state of the model and store them into a file. + /// + /// # Safety + /// + /// This function provides raw access to the underlying memory owned by the + /// ggml context. 
While the provided `InferenceSnapshotRef` object is alive, + /// no other methods for this model object should be called. + pub unsafe fn get_snapshot(&mut self) -> InferenceSnapshotRef<'_> { + let memory_k = unsafe { + slice::from_raw_parts(self.memory_k.data() as *mut u8, self.memory_k.nbytes()) + }; + let memory_v = unsafe { + slice::from_raw_parts(self.memory_v.data() as *mut u8, self.memory_v.nbytes()) + }; + + InferenceSnapshotRef { + npast: self.n_past, + session_params: self.params, + tokens: self.tokens.clone(), + logits: self.last_logits.clone(), + memory_k, + memory_v, + } + } +} + +impl<'a> InferenceSnapshotRef<'a> { + pub fn write(&self, writer: &mut impl std::io::Write) -> Result<(), SnapshotError> { + Ok(bincode::serialize_into(writer, &self)?) + } +} + +impl InferenceSnapshot { + pub fn read(reader: &mut impl std::io::Read) -> Result { + Ok(bincode::deserialize_from(reader)?) + } +} + +#[derive(Error, Debug)] +pub enum InferenceError { + #[error("an invalid token was encountered during tokenization")] + TokenizationFailed, + #[error("the context window is full")] + ContextFull, + #[error("the user-specified callback returned an error")] + UserCallback(Box), +} + +#[derive(Error, Debug)] +pub enum SnapshotError { + #[error("I/O error while reading or writing snapshot")] + IO(#[from] std::io::Error), + #[error("error during snapshot serialization")] + Serialization(#[from] bincode::Error), + #[error("could not read snapshot due to size mismatch (self={self_size}, input={input_size})")] + MemorySizeMismatch { self_size: usize, input_size: usize }, +} diff --git a/llama-rs/src/common/load.rs b/llama-rs/src/common/load.rs new file mode 100644 index 00000000..4d4164f3 --- /dev/null +++ b/llama-rs/src/common/load.rs @@ -0,0 +1,72 @@ +use std::path::{Path, PathBuf}; +use thiserror::Error; + +/// Each variant represents a step within the process of loading the model. +/// These can be used to report progress to the user. 
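+/// The type parameter `T` carries the model-specific hyperparameters reported
+/// by the `HyperparametersLoaded` variant.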
+#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug)]
+pub enum LoadProgress<T> {
+    HyperparametersLoaded(T),
+    BadToken {
+        index: usize,
+    },
+    ContextSize {
+        bytes: usize,
+    },
+    MemorySize {
+        bytes: usize,
+        n_mem: usize,
+    },
+    PartLoading {
+        file: Box<Path>,
+        current_part: usize,
+        total_parts: usize,
+    },
+    PartTensorLoaded {
+        file: Box<Path>,
+        current_tensor: usize,
+        tensor_count: usize,
+    },
+    PartLoaded {
+        file: Box<Path>,
+        byte_size: usize,
+        tensor_count: usize,
+    },
+}
+
+#[derive(Error, Debug)]
+pub enum LoadError {
+    #[error("could not open file {path:?}")]
+    OpenFileFailed {
+        source: std::io::Error,
+        path: PathBuf,
+    },
+    #[error("no parent path for {path:?}")]
+    NoParentPath { path: PathBuf },
+    #[error("unable to read exactly {bytes} bytes")]
+    ReadExactFailed {
+        source: std::io::Error,
+        bytes: usize,
+    },
+    #[error("non-specific I/O error")]
+    IO(#[from] std::io::Error),
+
+    #[error("could not convert bytes to a UTF-8 string")]
+    InvalidUtf8(#[from] std::string::FromUtf8Error),
+    #[error("invalid integer conversion")]
+    InvalidIntegerConversion(#[from] std::num::TryFromIntError),
+
+    #[error("unversioned magic number, regenerate your ggml models")]
+    UnversionedMagic,
+    #[error("invalid magic number for {path:?}")]
+    InvalidMagic { path: PathBuf },
+    #[error("invalid file format version {value}")]
+    InvalidFormatVersion { value: u32 },
+    #[error("invalid value {value} for `f16` in hyperparameters")]
+    HyperparametersF16Invalid { value: u32 },
+    #[error("unknown tensor `{tensor_name}` in {path:?}")]
+    UnknownTensor { tensor_name: String, path: PathBuf },
+    #[error("the tensor `{tensor_name}` has the wrong size in {path:?}")]
+    TensorWrongSize { tensor_name: String, path: PathBuf },
+    #[error("invalid ftype {ftype} in {path:?}")]
+    InvalidFtype { ftype: u32, path: PathBuf },
+}
diff --git a/llama-rs/src/common/macros.rs b/llama-rs/src/common/macros.rs
new file mode 100644
index 00000000..1acce873
--- /dev/null
+++ b/llama-rs/src/common/macros.rs
@@ -0,0 +1,13 @@
+/// NOTE: The original code relies on promotion rules and automatic casts from
+/// int to float. What we do instead is use this macro to convert every term of
+/// the multiplication to f64, which should have enough precision bits to hold
+/// the final value, then cast the result back to an integer. I have observed a
+/// discrepancy between the ctx_size found using this code and the one in
+/// llama.cpp: the Rust number ends up being slightly lower, but no "out of
+/// memory" errors are reported by ggml.
+#[macro_export]
+macro_rules!
mulf { + ($term:expr, $($terms:expr),*) => { + (($term as f64) $(* ($terms as f64))*) as u64 + }; +} diff --git a/llama-rs/src/common/mod.rs b/llama-rs/src/common/mod.rs new file mode 100644 index 00000000..7f1c2c8b --- /dev/null +++ b/llama-rs/src/common/mod.rs @@ -0,0 +1,8 @@ +pub mod inference; +pub mod load; +#[macro_use] +pub mod macros; +pub mod helpers; +pub mod model; +pub mod token; +pub mod vocabulary; diff --git a/llama-rs/src/common/model.rs b/llama-rs/src/common/model.rs new file mode 100644 index 00000000..91f6f830 --- /dev/null +++ b/llama-rs/src/common/model.rs @@ -0,0 +1,182 @@ +use super::inference::{ + InferenceError, InferenceParameters, InferenceSession, InferenceSessionParameters, + InferenceSnapshot, SnapshotError, +}; +use super::load::{LoadError, LoadProgress}; +use super::token::TokenId; +use super::vocabulary::Vocabulary; +use partial_sort::PartialSort; +use rand::{distributions::WeightedIndex, prelude::Distribution}; +use std::path::Path; + +/// Used in a call to `evaluate` to request information from the transformer. +#[derive(Default)] +pub struct EvaluateOutputRequest { + /// Returns all the logits for the provided batch of tokens. + /// Output shape is n_batch * n_vocab + pub all_logits: Option>, + /// Returns the embeddings for the provided batch of tokens + /// Output shape is n_batch * n_embd + pub embeddings: Option>, +} + +pub trait Model { + type Weights; + type HP; + + fn load( + path: impl AsRef, + n_ctx: usize, + load_progress_callback: impl Fn(LoadProgress), + ) -> Result<(Self::Weights, Vocabulary), LoadError>; + + /// Starts a new `InferenceSession` for this model. + fn start_session(&self, params: InferenceSessionParameters) -> InferenceSession; + + fn sample_top_p_top_k( + &self, + session: &InferenceSession, + params: &InferenceParameters, + rng: &mut impl rand::Rng, + ) -> TokenId { + let logits = &session.last_logits; + // println!("{:?}", logits); + let n_logits = logits.len(); + let mut logits_id = Vec::<(f32, TokenId)>::with_capacity(n_logits); + + { + let scale = 1.0 / params.temp; + for (i, &logit) in logits.iter().enumerate() { + let tid = i as TokenId; + + // repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) + // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main + let val = if let Some(logit_override) = params.bias_tokens.get(tid) { + logit_override + } else if session + .repetition_penalty_tokens() + .contains(&(i as TokenId)) + { + // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability + if logits[i] < 0.0 { + logit * scale * params.repeat_penalty + } else { + logit * scale / params.repeat_penalty + } + } else { + logit * scale + }; + logits_id.push((val, tid)); + } + } + + // find the top K tokens + { + logits_id.partial_sort(params.top_k, |a, b| { + // Sort descending + b.0.total_cmp(&a.0) + }); + logits_id.truncate(params.top_k); + } + + let maxl = logits_id + .iter() + .map(|x| x.0) + .max_by(f32::total_cmp) + .unwrap(); + + // compute probs for the top K tokens + let mut probs: Vec = logits_id + .iter() + .copied() + .map(|(k, _)| (k - maxl).exp()) + .collect(); + let sum: f32 = probs.iter().copied().sum(); + + // Normalize the probs + for p in probs.iter_mut() { + *p /= sum; + } + + // Top p sampling + if params.top_p < 1.0 { + let mut cumsum = 0.0; + for i in 0..probs.len() { + cumsum += probs[i]; + if cumsum >= params.top_p { + probs.truncate(i + 1); + logits_id.truncate(i + 1); + break; + } + } + + cumsum = 1.0 / cumsum; + for p 
in probs.iter_mut() { + *p *= cumsum; + } + } + + let dist = WeightedIndex::new(&probs).expect("WeightedIndex error"); + let idx = dist.sample(rng); + + logits_id[idx].1 + } + + /// Evaluates the transformer. + /// + /// The provided `output_request` struct lets you specify which additional + /// data you are interested in fetching from the transformer. Setting a + /// field to a `Some` value will clear and fill the provided vector with + /// data. The provided vector will be resized to the exact output size. + fn evaluate( + &self, + session: &mut InferenceSession, + params: &InferenceParameters, + input_tokens: &[TokenId], + output_request: &mut EvaluateOutputRequest, + ); + + fn tokenize( + &self, + vocab: &Vocabulary, + text: &str, + bos: bool, + ) -> Result, InferenceError> { + Ok(vocab + .tokenize(text, bos)? + .iter() + .map(|(_, tid)| *tid) + .collect::>()) + } + + /// Hydrates a previously obtained InferenceSnapshot for this model + fn session_from_snapshot( + &self, + snapshot: InferenceSnapshot, + ) -> Result { + let mut session = self.start_session(snapshot.session_params); + + if session.memory_k.nbytes() != snapshot.memory_k.len() + || session.memory_v.nbytes() != snapshot.memory_v.len() + { + return Err(SnapshotError::MemorySizeMismatch { + self_size: session.memory_k.nbytes() + session.memory_v.nbytes(), + input_size: snapshot.memory_k.len() + snapshot.memory_v.len(), + }); + } + + // SAFETY: We have exclusive access to Session, which means no one else + // should be touching the context's memory. We can write to it because + // we already checked the size. + unsafe { + session.memory_k.write_data(&snapshot.memory_k); + session.memory_v.write_data(&snapshot.memory_v); + } + + session.n_past = snapshot.npast; + session.tokens = snapshot.tokens; + session.last_logits = snapshot.last_logits; + + Ok(session) + } +} diff --git a/llama-rs/src/common/token.rs b/llama-rs/src/common/token.rs new file mode 100644 index 00000000..a04e5d60 --- /dev/null +++ b/llama-rs/src/common/token.rs @@ -0,0 +1,86 @@ +use super::vocabulary::Vocabulary; +use std::fmt::Display; +use std::str::FromStr; + +pub const EOD_TOKEN_ID: TokenId = 2; // Hardcoded (for now?) 
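+// Illustrative example (mirrors the `--ignore-eos` handling in llama-cli):
+// biasing the end-of-document token to -1.0 effectively disables sampling it.
+//
+//     let bias = TokenBias::new(vec![(EOD_TOKEN_ID, -1.0)]);
+//     assert_eq!(bias.get(EOD_TOKEN_ID), Some(-1.0));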
+
+pub type TokenId = i32;
+pub type Token = String;
+pub type TokenScore = f32;
+
+#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub enum OutputToken<'a> {
+    Token(&'a str),
+    EndOfText,
+}
+impl<'a> OutputToken<'a> {
+    pub fn from_id(vocab: &'a Vocabulary, id: TokenId) -> Self {
+        if id == EOD_TOKEN_ID {
+            Self::EndOfText
+        } else {
+            Self::Token(&vocab.id_to_token[id as usize])
+        }
+    }
+}
+impl Display for OutputToken<'_> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "{}",
+            match self {
+                OutputToken::Token(t) => *t,
+                OutputToken::EndOfText => "[end of text]",
+            }
+        )
+    }
+}
+
+#[derive(Default, Clone, Debug, PartialEq)]
+pub struct TokenBias(Vec<(TokenId, f32)>);
+
+impl TokenBias {
+    pub fn new(mut v: Vec<(TokenId, f32)>) -> Self {
+        v.sort_by_cached_key(|(tid, _)| *tid);
+        v.dedup_by_key(|(tid, _)| *tid);
+        Self(v)
+    }
+
+    pub fn get(&self, tid: TokenId) -> Option<f32> {
+        self.0
+            .binary_search_by_key(&tid, |(tid, _)| *tid)
+            .map(|idx| self.0[idx].1)
+            .ok()
+    }
+}
+
+impl FromStr for TokenBias {
+    type Err = String;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let x = s
+            .split(',')
+            .map(|kv| {
+                let (k, v) = kv
+                    .trim()
+                    .split_once('=')
+                    .ok_or_else(|| "Missing '=' in bias item".to_owned())?;
+                let tid: TokenId = k
+                    .trim()
+                    .parse()
+                    .map_err(|e: std::num::ParseIntError| e.to_string())?;
+                let bias: f32 = v
+                    .trim()
+                    .parse()
+                    .map_err(|e: std::num::ParseFloatError| e.to_string())?;
+                Result::<_, String>::Ok((tid, bias))
+            })
+            .collect::<Result<_, _>>()?;
+        Ok(TokenBias::new(x))
+    }
+}
+
+impl Display for TokenBias {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{:?}", self.0)
+    }
+}
diff --git a/llama-rs/src/common/vocabulary.rs b/llama-rs/src/common/vocabulary.rs
new file mode 100644
index 00000000..e7a437c2
--- /dev/null
+++ b/llama-rs/src/common/vocabulary.rs
@@ -0,0 +1,78 @@
+use super::inference::InferenceError;
+use super::token::{Token, TokenId, TokenScore};
+use std::collections::HashMap;
+
+pub struct Vocabulary {
+    /// Maps every integer (index) token id to its corresponding token
+    pub id_to_token: Vec<Token>,
+
+    /// Maps every integer (index) token id to corresponding score
+    #[allow(dead_code)]
+    pub id_to_token_score: Vec<TokenScore>,
+
+    /// Maps a token to a token id
+    pub token_to_id: HashMap<Token, TokenId>,
+
+    /// The longest token in this vocabulary
+    pub max_token_length: usize,
+}
+
+impl Vocabulary {
+    // SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece
+    pub fn tokenize<'a>(
+        &'a self,
+        text: &str,
+        bos: bool,
+    ) -> Result<Vec<(&'a str, TokenId)>, InferenceError> {
+        let len = text.len();
+
+        let mut score = vec![0usize; len + 1];
+        let mut prev = vec![TokenId::default(); len + 1];
+
+        for i in 0..len {
+            let max_len = (len - i).min(self.max_token_length);
+            for sub_len in 1..=max_len {
+                let sub = &text.as_bytes()[i..i + sub_len];
+                let Ok(sub) = std::str::from_utf8(sub) else { continue; };
+                let token = self.token_to_id.get(sub);
+
+                if let Some(token) = token {
+                    let token_score = sub.len() * sub.len();
+                    let local_score = score[i] + token_score;
+                    let next = i + sub_len;
+
+                    if score[next] < local_score {
+                        score[next] = local_score;
+                        prev[next] = *token;
+                    }
+                }
+            }
+        }
+
+        // Backward pass
+        let mut res = vec![];
+        let mut i = len;
+        while i > 0 {
+            let token_id = prev[i];
+            if token_id == 0 {
+                return Err(InferenceError::TokenizationFailed);
+            }
+            let token = self.id_to_token[token_id as usize].as_str();
+            res.push((token, token_id));
+            i -= token.len();
+        }
+
+        if bos {
+            //
TODO: replace with vocab.bos + res.push(("", 1)); + } + + // Pieces are in reverse order so correct that + res.reverse(); + + Ok(res) + } + fn token(&self, idx: usize) -> &str { + &self.id_to_token[idx] + } +} diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 111a1c56..a3f6fe72 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -1,1776 +1,2 @@ -#![deny(missing_docs)] - -//! LLaMA-rs is a Rust port of the llama.cpp project. This allows running inference for Facebook's LLaMA model on a CPU with good performance using full precision, f16 or 4-bit quantized versions of the model. - -use core::slice; -use std::{ - collections::HashMap, - fmt::Display, - io::{BufRead, Read, Seek, SeekFrom}, - path::{Path, PathBuf}, - str::FromStr, - time, -}; - -use thiserror::Error; - -use partial_sort::PartialSort; -use rand::{distributions::WeightedIndex, prelude::Distribution}; - -/// The end of text token. -pub const EOT_TOKEN_ID: TokenId = 2; // Hardcoded (for now?) - -/// The hyperparameters of the model. -#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord)] -pub struct Hyperparameters { - n_vocab: usize, - n_ctx: usize, - n_embd: usize, - n_mult: usize, - n_head: usize, - n_layer: usize, - n_rot: usize, - f16_: u32, -} - -struct Layer { - attention_norm: ggml::Tensor, - - wq: ggml::Tensor, - wk: ggml::Tensor, - wv: ggml::Tensor, - wo: ggml::Tensor, - - // normalization - ffn_norm: ggml::Tensor, - - // ff - w1: ggml::Tensor, - w2: ggml::Tensor, - w3: ggml::Tensor, -} - -/// The weights for the LLaMA model. All the mutable state is split into a -/// separate struct `InferenceSession`. -pub struct Model { - hparams: Hyperparameters, - - tok_embeddings: ggml::Tensor, - - norm: ggml::Tensor, - output: ggml::Tensor, - - layers: Vec, - - tensors: HashMap, - - // Must be kept alive for the model - _context: ggml::Context, -} - -/// An inference session represents the state of the text generation. This holds -/// the full context window, as long as several additional parameters used -/// during sampling. -pub struct InferenceSession { - // Must be kept alive for the model - _session_ctx: ggml::Context, - - // Original size of the memory used to create this context. - memory_size: usize, - - // Parameters for the session. - params: InferenceSessionParameters, - - memory_k: ggml::Tensor, - memory_v: ggml::Tensor, - - /// How many tokens have been fed into the model's working memory so far. - n_past: usize, - - /// How much memory is required per token for the temporary context used - /// during inference. - mem_per_token: usize, - - /// All tokens generated by this inference session - tokens: Vec, - - /// The logits that were last predicted by the network. Zeroed out otherwise. - last_logits: Vec, -} -impl InferenceSession { - fn repetition_penalty_tokens(&self) -> &[TokenId] { - &self.tokens[self - .tokens - .len() - .saturating_sub(self.params.repetition_penalty_last_n)..] 
- } -} -impl Clone for InferenceSession { - fn clone(&self) -> Self { - let context = ggml::Context::init(self.memory_size); - let memory_k = context.new_tensor_1d(self.memory_k.get_type(), self.memory_k.nelements()); - let memory_v = context.new_tensor_1d(self.memory_v.get_type(), self.memory_v.nelements()); - - Self { - _session_ctx: context, - memory_size: self.memory_size, - params: self.params, - memory_k, - memory_v, - n_past: self.n_past, - mem_per_token: self.mem_per_token, - tokens: self.tokens.clone(), - last_logits: self.last_logits.clone(), - } - } -} - -#[derive(serde::Serialize, Clone, PartialEq)] -/// A serializable snapshot of the inference process. Can be saved to disk. -// Keep in sync with [InferenceSession] and [InferenceSnapshot] -pub struct InferenceSnapshotRef<'a> { - /// How many tokens have been stored in the memory so far. - pub npast: usize, - /// Parameters associated with the saved inference session. - pub session_params: InferenceSessionParameters, - /// All tokens generated by this inference session - pub tokens: Vec, - /// The vector of logits that was produced after the last inference - pub logits: Vec, - /// The contents of the 'key' memory tensor - #[serde(with = "serde_bytes")] - pub memory_k: &'a [u8], - /// The contents of the 'value' memory tensor - #[serde(with = "serde_bytes")] - pub memory_v: &'a [u8], -} - -/// A serializable snapshot of the inference process. Can be restored by calling -/// `Model::restore_from_snapshot`. -#[derive(serde::Deserialize, Clone, PartialEq)] -// Keep in sync with [InferenceSession] and [InferenceSnapshotRef] -pub struct InferenceSnapshot { - /// How many tokens have been stored in the memory so far. - pub npast: usize, - /// Parameters associated with the saved inference session. - pub session_params: InferenceSessionParameters, - /// All tokens generated by this inference session - pub tokens: Vec, - /// The vector of logits that was produced after the last inference - pub last_logits: Vec, - /// The contents of the 'key' memory tensor - #[serde(with = "serde_bytes")] - pub memory_k: Vec, - /// The contents of the 'value' memory tensor - #[serde(with = "serde_bytes")] - pub memory_v: Vec, -} - -/// Allowed types for the model memory K/V tensors. -#[derive(Clone, Copy, Debug, PartialEq, serde::Serialize, serde::Deserialize)] -pub enum ModelKVMemoryType { - /// 16-bit float. - Float16, - /// 32-bit float. - Float32, -} -impl From for u32 { - fn from(value: ModelKVMemoryType) -> Self { - match value { - ModelKVMemoryType::Float16 => ggml::TYPE_F16, - ModelKVMemoryType::Float32 => ggml::TYPE_F32, - } - } -} - -#[derive(Clone, Copy, Debug, PartialEq, serde::Serialize, serde::Deserialize)] -/// Parameters for an inference session. -pub struct InferenceSessionParameters { - /// The number of tokens to consider for the repetition penalty. - pub repetition_penalty_last_n: usize, - /// The type of the memory K tensor. - pub memory_k_type: ModelKVMemoryType, - /// The type of the memory V tensor. - pub memory_v_type: ModelKVMemoryType, -} - -impl Default for InferenceSessionParameters { - fn default() -> Self { - Self { - repetition_penalty_last_n: 512, - memory_k_type: ModelKVMemoryType::Float32, - memory_v_type: ModelKVMemoryType::Float32, - } - } -} - -#[derive(Clone, Debug, PartialEq)] -/// The parameters that drive text generation. -pub struct InferenceParameters { - /// The number of threads to use. - pub n_threads: usize, - /// [InferenceSession::feed_prompt] processes the prompt in batches of tokens. 
- /// This controls how large an individual batch is. - pub n_batch: usize, - /// Top-K: The top K words by score are kept during sampling. - pub top_k: usize, - /// Top-p: The cumulative probability after which no more words are kept for sampling. - pub top_p: f32, - /// The penalty for repeating tokens. Higher values make the generation less - /// likely to get into a loop, but may harm results when repetitive outputs - /// are desired. - pub repeat_penalty: f32, - /// Temperature used for sampling. - pub temperature: f32, - /// A list of tokens to bias against in the process of generation. - pub bias_tokens: TokenBias, - /// Whether or not previous tokens should be played back in [InferenceSession::inference_with_prompt]. - pub play_back_previous_tokens: bool, - /// If set, the inference process will behave more deterministically at the potential cost of performance. - /// - /// Note that this does not guarantee full determinism. When run on the same machine with the same parameters, - /// seed, and this set, inference should be identical, but this is not guaranteed to hold across machines. - pub increased_determinism: bool, -} - -impl Default for InferenceParameters { - fn default() -> Self { - Self { - n_threads: 8, - n_batch: 8, - top_k: 40, - top_p: 0.95, - repeat_penalty: 1.30, - temperature: 0.80, - bias_tokens: TokenBias::default(), - play_back_previous_tokens: false, - increased_determinism: true, - } - } -} - -/// Statistics about the inference process. -pub struct InferenceStats { - /// How long it took to feed the prompt. - pub feed_prompt_duration: std::time::Duration, - /// How many tokens the prompt was. - pub prompt_tokens: usize, - /// How long it took to predict new tokens. - pub predict_duration: std::time::Duration, - /// The number of predicted tokens. - pub predict_tokens: usize, -} - -impl Default for InferenceStats { - fn default() -> Self { - Self { - feed_prompt_duration: std::time::Duration::from_secs(0), - prompt_tokens: 0, - predict_duration: std::time::Duration::from_secs(0), - predict_tokens: 0, - } - } -} - -impl Display for InferenceStats { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!( - f, - "feed_prompt_duration: {}ms\nprompt_tokens: {}\npredict_duration: {}ms\npredict_tokens: {}\nper_token_duration: {:.3}ms", - self.feed_prompt_duration.as_millis(), - self.prompt_tokens, - self.predict_duration.as_millis(), - self.predict_tokens, - (self.predict_duration.as_millis() as f64) / (self.predict_tokens as f64), - ) - } -} - -type TokenId = i32; -type Token = String; -type TokenScore = f32; - -/// The vocabulary used by a model. -pub struct Vocabulary { - /// Maps every integer (index) token id to its corresponding token - id_to_token: Vec, - - /// Maps every integer (index) token id to corresponding score - #[allow(dead_code)] - id_to_token_score: Vec, - - /// Maps a token to a token id - token_to_id: HashMap, - - /// The longest token in this vocabulary - max_token_length: usize, -} -impl Vocabulary { - fn token(&self, idx: usize) -> &str { - &self.id_to_token[idx] - } -} - -#[derive(Default, Clone, Debug, PartialEq)] -/// A list of tokens to bias during the process of inferencing. -/// -/// When a biased token is encountered, the bias will be used -/// instead of the inferred logit during the sampling process. -/// -/// This can be used to disable the generation of responses -/// with specific tokens by setting their corresponding bias -/// to -1.0. 
-pub struct TokenBias(Vec<(TokenId, f32)>); - -impl TokenBias { - /// Create a [TokenBias] from an existing `Vec`. - pub fn new(mut v: Vec<(TokenId, f32)>) -> Self { - v.sort_by_cached_key(|(tid, _)| *tid); - v.dedup_by_key(|(tid, _)| *tid); - Self(v) - } - - /// Retrieves the bias for a given token, if available. - pub fn get(&self, tid: TokenId) -> Option { - self.0 - .binary_search_by_key(&tid, |(tid, _)| *tid) - .map(|idx| self.0[idx].1) - .ok() - } -} - -impl FromStr for TokenBias { - type Err = String; - - /// A comma separated list of token biases. The list should be in the format - /// "TID=BIAS,TID=BIAS" where TID is an integer token ID and BIAS is a - /// floating point number. - /// For example, "1=-1.0,2=-1.0" sets the bias for token IDs 1 - /// (start of document) and 2 (end of document) to -1.0 which effectively - /// disables the model from generating responses containing those token IDs. - fn from_str(s: &str) -> Result { - let x = s - .split(',') - .map(|kv| { - let (k, v) = kv - .trim() - .split_once('=') - .ok_or_else(|| "Missing '=' in bias item".to_owned())?; - let tid: TokenId = k - .trim() - .parse() - .map_err(|e: std::num::ParseIntError| e.to_string())?; - let bias: f32 = v - .trim() - .parse() - .map_err(|e: std::num::ParseFloatError| e.to_string())?; - Result::<_, String>::Ok((tid, bias)) - }) - .collect::>()?; - Ok(TokenBias::new(x)) - } -} - -impl std::fmt::Display for TokenBias { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self.0) - } -} - -/// Each variant represents a step within the process of loading the model. -/// These can be used to report progress to the user. -#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug)] -pub enum LoadProgress<'a> { - /// The hyperparameters have been loaded from the model. - HyperparametersLoaded(&'a Hyperparameters), - /// A bad token was encountered during the loading process. - /// - /// This can be ignored, but invalid tokens will be replaced with - /// the `�` character. - BadToken { - /// The index within the vocabulary. - index: usize, - }, - /// The context has been created. - ContextSize { - /// The size of the context. - bytes: usize, - }, - /// A part of the model is being loaded. - PartLoading { - /// The path to the model part. - file: &'a Path, - /// The current part (0-indexed). - current_part: usize, - /// The number of total parts. - total_parts: usize, - }, - /// A tensor from the current part has been loaded. - PartTensorLoaded { - /// The path to the model part. - file: &'a Path, - /// The current tensor (0-indexed). - current_tensor: usize, - /// The number of total tensors. - tensor_count: usize, - }, - /// A model part has finished fully loading. - PartLoaded { - /// The path to the model part. - file: &'a Path, - /// The number of bytes in the part. - byte_size: usize, - /// The number of tensors in the part. - tensor_count: usize, - }, -} - -#[derive(Error, Debug)] -/// Errors encountered during the loading process. -pub enum LoadError { - #[error("could not open file {path:?}")] - /// A file failed to open. - OpenFileFailed { - /// The original error. - source: std::io::Error, - /// The path that failed. - path: PathBuf, - }, - #[error("no parent path for {path:?}")] - /// There is no parent path for a given path. - NoParentPath { - /// The path without a parent. - path: PathBuf, - }, - #[error("unable to read exactly {bytes} bytes")] - /// Reading exactly `bytes` from a file failed. - ReadExactFailed { - /// The original error. 
- source: std::io::Error, - /// The number of bytes that were attempted to be read. - bytes: usize, - }, - #[error("non-specific I/O error")] - /// A non-specific IO error. - IO(#[from] std::io::Error), - #[error("could not convert bytes to a UTF-8 string")] - /// One of the strings encountered was not valid UTF-8. - InvalidUtf8(#[from] std::string::FromUtf8Error), - #[error("invalid integer conversion")] - /// One of the integers encountered could not be converted to a more appropriate type. - InvalidIntegerConversion(#[from] std::num::TryFromIntError), - #[error("invalid magic number for {path:?}")] - /// An invalid magic number was encountered during the loading process. - InvalidMagic { - /// The path that failed. - path: PathBuf, - }, - #[error("invalid file format version {value}")] - /// The version of the format is not supported by this version of `llama-rs`. - InvalidFormatVersion { - /// The version that was encountered. - value: u32, - }, - #[error("invalid value {ftype} for `f16` in hyperparameters")] - /// The `f16` hyperparameter had an invalid value. - HyperparametersF16Invalid { - /// The format type that was encountered. - ftype: u32, - }, - #[error("unknown tensor `{tensor_name}` in {path:?}")] - /// The tensor `tensor_name` was encountered during the loading of `path`, but was not seen during - /// the model prelude. - UnknownTensor { - /// The name of the tensor. - tensor_name: String, - /// The path that failed. - path: PathBuf, - }, - #[error("the tensor `{tensor_name}` has the wrong size in {path:?}")] - /// The tensor `tensor_name` did not match its expected size. - TensorWrongSize { - /// The name of the tensor. - tensor_name: String, - /// The path that failed. - path: PathBuf, - }, - /// The tensor `tensor_name` did not have the expected format type. - #[error("invalid ftype {ftype} for tensor `{tensor_name}` in {path:?}")] - InvalidFtype { - /// The name of the tensor. - tensor_name: String, - /// The format type that was encountered. - ftype: u32, - /// The path that failed. - path: PathBuf, - }, -} - -#[derive(Error, Debug)] -/// Errors encountered during the snapshot process. -pub enum SnapshotError { - /// Arbitrary I/O error. - #[error("I/O error while reading or writing snapshot")] - IO(#[from] std::io::Error), - /// Error during the serialization process. - #[error("error during snapshot serialization")] - Serialization(#[from] bincode::Error), - /// Mismatch between the snapshotted memory and the in-memory memory. - #[error("could not read snapshot due to size mismatch (self={self_size}, input={input_size})")] - MemorySizeMismatch { - /// The size of the session memory in memory. - self_size: usize, - /// The size of the session memory in snapshot. - input_size: usize, - }, -} - -#[derive(Error, Debug)] -/// Errors encountered during the inferencep rocess. -pub enum InferenceError { - #[error("an invalid token was encountered during tokenization")] - /// During tokenization, one of the produced tokens was invalid / zero. - TokenizationFailed, - #[error("the context window is full")] - /// The context window for the model is full. - ContextFull, - #[error("reached end of text")] - /// The model has produced an end of text token, signalling that it thinks that the text should end here. - /// - /// Note that this error *can* be ignored and inference can continue, but the results are not guaranteed to be sensical. - EndOfText, - #[error("the user-specified callback returned an error")] - /// The user-specified callback returned an error. 
- UserCallback(Box), -} - -/// Used in a call to `evaluate` to request information from the transformer. -#[derive(Default)] -pub struct EvaluateOutputRequest { - /// Returns all the logits for the provided batch of tokens. - /// Output shape is n_batch * n_vocab - pub all_logits: Option>, - /// Returns the embeddings for the provided batch of tokens - /// Output shape is n_batch * n_embd - pub embeddings: Option>, -} - -/// NOTE: The original code relies in promotion rules and automatic cast between -/// int to float. What we do instead is use this macro to convert every term of -/// the multiplication to f64, which should have enough precision bits to hold -/// the final value, then cast to usize. I have observed a discrepancy between -/// the ctx_size found using this code, and the one in llama.cpp. The number for -/// rust ends up being slightly lower, but no "out of memory" errors are -/// reported by ggml. -macro_rules! mulf { - ($term:expr, $($terms:expr),*) => { - usize::try_from((($term as f64) $(* ($terms as f64))*) as u64).unwrap() - }; -} - -impl Model { - /// Load the model from `path` with `n_context_tokens` context tokens. - /// - /// The status of the loading process will be reported through `load_progress_callback`. - pub fn load( - path: impl AsRef, - n_context_tokens: usize, - load_progress_callback: impl Fn(LoadProgress), - ) -> Result<(Model, Vocabulary), LoadError> { - use std::fs::File; - use std::io::BufReader; - - let main_path = path.as_ref(); - - let mut reader = - BufReader::new( - File::open(main_path).map_err(|e| LoadError::OpenFileFailed { - source: e, - path: main_path.to_owned(), - })?, - ); - - fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { - let mut bytes = [0u8; N]; - reader - .read_exact(&mut bytes) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: N, - })?; - Ok(bytes) - } - - fn read_i32(reader: &mut impl BufRead) -> Result { - Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - fn read_u32(reader: &mut impl BufRead) -> Result { - Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - fn read_f32(reader: &mut impl BufRead) -> Result { - Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - /// Helper function. Reads a string from the buffer and returns it. - fn read_string(reader: &mut BufReader, len: usize) -> Result { - let mut buf = vec![0; len]; - reader - .read_exact(&mut buf) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: buf.len(), - })?; - let s = String::from_utf8(buf)?; - Ok(s) - } - - // Verify magic - let is_legacy_model: bool = match read_u32(&mut reader)? { - ggml::FILE_MAGIC => false, - ggml::FILE_MAGIC_UNVERSIONED => true, - _ => { - return Err(LoadError::InvalidMagic { - path: main_path.to_owned(), - }) - } - }; - - // Load format version - if !is_legacy_model { - #[allow(unused_variables)] - let version: u32 = match read_u32(&mut reader)? { - ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, - version => return Err(LoadError::InvalidFormatVersion { value: version }), - }; - } - - // ================= - // Load hyper params - // ================= - - // NOTE: Field order matters! Data is laid out in the file exactly - // in this order. 
- let hparams = Hyperparameters { - n_vocab: read_i32(&mut reader)?.try_into()?, - n_ctx: n_context_tokens, - n_embd: read_i32(&mut reader)?.try_into()?, - n_mult: read_i32(&mut reader)?.try_into()?, - n_head: read_i32(&mut reader)?.try_into()?, - n_layer: read_i32(&mut reader)?.try_into()?, - n_rot: read_i32(&mut reader)?.try_into()?, - f16_: read_i32(&mut reader)?.try_into()?, - }; - - let n_ff = - ((2 * (4 * hparams.n_embd) / 3 + hparams.n_mult - 1) / hparams.n_mult) * hparams.n_mult; - - load_progress_callback(LoadProgress::HyperparametersLoaded(&hparams)); - - // =============== - // Load vocabulary - // =============== - let vocab = { - let mut id_to_token = vec![]; - let mut id_to_token_score = vec![]; - let mut token_to_id = HashMap::new(); - let mut max_token_length = 0; - - for i in 0..hparams.n_vocab { - let len = read_i32(&mut reader)?; - if let Ok(word) = read_string(&mut reader, len as usize) { - max_token_length = max_token_length.max(word.len()); - id_to_token.push(word.clone()); - token_to_id.insert(word, TokenId::try_from(i)?); - } else { - load_progress_callback(LoadProgress::BadToken { index: i }); - id_to_token.push("�".to_string()); - } - - // Token score, currently unused - if !is_legacy_model { - if let Ok(score) = read_f32(&mut reader) { - id_to_token_score.push(score); - } - } else { - // Legacy model, set empty score - id_to_token_score.push(0.); - } - } - - Vocabulary { - id_to_token, - id_to_token_score, - token_to_id, - max_token_length, - } - }; - - // for the big tensors, we have the option to store the data in 16-bit - // floats or quantized in order to save memory and also to speed up the - // computation - let wtype = match hparams.f16_ { - 0 => ggml::TYPE_F32, - 1 => ggml::TYPE_F16, - 2 => ggml::TYPE_Q4_0, - 3 => ggml::TYPE_Q4_1, - invalid => return Err(LoadError::HyperparametersF16Invalid { ftype: invalid }), - }; - - let n_embd = hparams.n_embd; - let n_layer = hparams.n_layer; - let n_vocab = hparams.n_vocab; - - let ctx_size = { - // Use 64-bit math to prevent overflow. 
- let mut ctx_size: usize = 0; - - ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // tok_embeddings - - ctx_size += mulf!(n_embd, ggml::type_sizef(ggml::TYPE_F32)); // norm - - ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // output - - ctx_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::TYPE_F32)); // attention_norm - - ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wq - ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wk - ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wv - ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wo - - ctx_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::TYPE_F32)); // ffn_norm - - ctx_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w1 - ctx_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w2 - ctx_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w3 - - ctx_size += (5 + 10 * n_layer) * 256; // object overhead - - load_progress_callback(LoadProgress::ContextSize { bytes: ctx_size }); - - ctx_size - }; - - // Initialize the context - let context = ggml::Context::init(ctx_size); - - let model = { - let mut tensors = HashMap::new(); - - let tok_embeddings = context.new_tensor_2d(wtype, n_embd, n_vocab); - let norm = context.new_tensor_1d(ggml::TYPE_F32, n_embd); - let output = context.new_tensor_2d(wtype, n_embd, n_vocab); - - tensors.insert("tok_embeddings.weight".to_owned(), tok_embeddings.share()); - tensors.insert("norm.weight".to_owned(), norm.share()); - tensors.insert("output.weight".to_owned(), output.share()); - - let mut layers = Vec::new(); - for i in 0..n_layer { - let layer = Layer { - attention_norm: context.new_tensor_1d(ggml::TYPE_F32, n_embd), - wq: context.new_tensor_2d(wtype, n_embd, n_embd), - wk: context.new_tensor_2d(wtype, n_embd, n_embd), - wv: context.new_tensor_2d(wtype, n_embd, n_embd), - wo: context.new_tensor_2d(wtype, n_embd, n_embd), - ffn_norm: context.new_tensor_1d(ggml::TYPE_F32, n_embd), - w1: context.new_tensor_2d(wtype, n_embd, n_ff), - w2: context.new_tensor_2d(wtype, n_ff, n_embd), - w3: context.new_tensor_2d(wtype, n_embd, n_ff), - }; - - tensors.insert( - format!("layers.{i}.attention_norm.weight"), - layer.attention_norm.share(), - ); - - tensors.insert(format!("layers.{i}.attention.wq.weight"), layer.wq.share()); - tensors.insert(format!("layers.{i}.attention.wk.weight"), layer.wk.share()); - tensors.insert(format!("layers.{i}.attention.wv.weight"), layer.wv.share()); - tensors.insert(format!("layers.{i}.attention.wo.weight"), layer.wo.share()); - - tensors.insert( - format!("layers.{i}.ffn_norm.weight"), - layer.ffn_norm.share(), - ); - - tensors.insert( - format!("layers.{i}.feed_forward.w1.weight"), - layer.w1.share(), - ); - tensors.insert( - format!("layers.{i}.feed_forward.w2.weight"), - layer.w2.share(), - ); - tensors.insert( - format!("layers.{i}.feed_forward.w3.weight"), - layer.w3.share(), - ); - - layers.push(layer); - } - - Model { - hparams, - tok_embeddings, - norm, - output, - layers, - tensors, - _context: context, - } - }; - - // Close the file, but keep its offset. That way we know how to skip the - // metadata when loading the parts. 
- let file_offset = reader.stream_position()?; - drop(reader); - - let paths = { - let main_filename = main_path.file_name().and_then(|p| p.to_str()); - - let mut paths: Vec = - std::fs::read_dir(main_path.parent().ok_or_else(|| LoadError::NoParentPath { - path: main_path.to_owned(), - })?)? - .filter_map(Result::ok) - .map(|de| de.path()) - .filter(|p| { - p.file_name() - .and_then(|p| p.to_str()) - .zip(main_filename) - .map(|(part_filename, main_filename)| { - part_filename.starts_with(main_filename) - }) - .unwrap_or(false) - }) - .collect(); - paths.sort(); - paths - }; - - let n_parts = paths.len(); - - for (i, part_path) in paths.into_iter().enumerate() { - let part_id = i; - - load_progress_callback(LoadProgress::PartLoading { - file: &part_path, - current_part: i, - total_parts: n_parts, - }); - - let mut part_reader = BufReader::new(File::open(&part_path)?); - - // Skip metadata - part_reader.seek(SeekFrom::Start(file_offset))?; - - let mut total_size = 0; - let mut n_tensors = 0; - - // Load weights - loop { - // NOTE: Implementation from #![feature(buf_read_has_data_left)] - let is_eof = part_reader.fill_buf().map(|b| b.is_empty())?; - - if is_eof { - break; - } - - let n_dims = usize::try_from(read_i32(&mut part_reader)?)?; - let length = read_i32(&mut part_reader)?; - let ftype = read_u32(&mut part_reader)?; - - let mut nelements = 1; - let mut ne = [1i64, 1i64]; - - #[allow(clippy::needless_range_loop)] - for i in 0..n_dims { - ne[i] = read_i32(&mut part_reader)? as i64; - nelements *= usize::try_from(ne[i])?; - } - - let tensor_name = read_string(&mut part_reader, length as usize)?; - - let Some(tensor) = model.tensors.get(&tensor_name) - else { - return Err(LoadError::UnknownTensor { tensor_name, path: part_path }); - }; - - // split_type = 0: split by columns - // split_type = 1: split by rows - // - // split_type = 0: - // regex: - // - tok_embeddings.* - // - layers.*.attention.wo.weight - // - layers.*.feed_forward.w2.weight - - // split_type = 1: - // regex: - // - output.* - // - layers.*.attention.wq.weight - // - layers.*.attention.wk.weight - // - layers.*.attention.wv.weight - // - layers.*.feed_forward.w1.weight - // - layers.*.feed_forward.w3.weight - #[allow(clippy::if_same_then_else)] - let split_type = if tensor_name.contains("tok_embeddings") { - 0 - } else if tensor_name.contains("layers") { - if tensor_name.contains("attention.wo.weight") { - 0 - } else if tensor_name.contains("feed_forward.w2.weight") { - 0 - } else { - 1 - } - } else if tensor_name.contains("output") { - 1 - } else { - 0 - }; - - if n_dims == 1 { - if tensor.nelements() != nelements { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if tensor.nelements() / n_parts != nelements { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - if n_dims == 1 { - if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] != ne[1] { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if split_type == 0 { - if tensor.get_ne()[0] / i64::try_from(n_parts)? != ne[0] - || tensor.get_ne()[1] != ne[1] - { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if tensor.get_ne()[0] != ne[0] - || tensor.get_ne()[1] / i64::try_from(n_parts)? 
!= ne[1] - { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - let bpe = match ftype { - 0 => ggml::type_size(ggml::TYPE_F32), - 1 => ggml::type_size(ggml::TYPE_F16), - 2 => { - assert_eq!(ne[0] % 64, 0); - ggml::type_size(ggml::TYPE_Q4_0) - } - 3 => { - assert_eq!(ne[0] % 64, 0); - ggml::type_size(ggml::TYPE_Q4_1) - } - _ => { - return Err(LoadError::InvalidFtype { - tensor_name, - ftype, - path: part_path, - }) - } - }; - - if n_dims == 1 || n_parts == 1 { - if (nelements * bpe) / ggml::blck_size(tensor.get_type()) != tensor.nbytes() { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - if part_id == 0 { - // SAFETY: yolo, same as original code - let slice = unsafe { - let data = tensor.data(); - std::slice::from_raw_parts_mut(data as *mut u8, tensor.nbytes()) - }; - part_reader.read_exact(slice)?; - } else { - part_reader.seek(SeekFrom::Current(tensor.nbytes() as i64))?; - } - - total_size += tensor.nbytes(); - } else { - if (nelements * bpe) / ggml::blck_size(tensor.get_type()) - != tensor.nbytes() / n_parts - { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - if split_type == 0 { - let np0 = ne[0]; - let row_size = (usize::try_from(tensor.get_ne()[0])? - / ggml::blck_size(tensor.get_type())) - * ggml::type_size(tensor.get_type()); - - assert_eq!(row_size, tensor.get_nb()[1]); - - for i1 in 0..ne[1] { - let offset_row = i1 as usize * row_size; - let offset = offset_row - + ((part_id * np0 as usize) / ggml::blck_size(tensor.get_type())) - * ggml::type_size(tensor.get_type()); - // SAFETY: yolo, same as original code - unsafe { - let ptr = tensor.data().add(offset); - let slice = std::slice::from_raw_parts_mut( - ptr as *mut u8, - row_size / n_parts, - ); - part_reader.read_exact(slice)?; - } - } - } else { - let np1 = ne[1]; - let row_size = (usize::try_from(tensor.get_ne()[0])? - / ggml::blck_size(tensor.get_type())) - * ggml::type_size(tensor.get_type()); - - for i1 in 0..ne[1] { - let offset_row = (i1 as usize + part_id * np1 as usize) * row_size; - // SAFETY: yolo, same as original code - unsafe { - let ptr = tensor.data().add(offset_row); - let slice = - std::slice::from_raw_parts_mut(ptr as *mut u8, row_size); - part_reader.read_exact(slice)?; - } - } - } - - total_size += tensor.nbytes() / n_parts; - } - - n_tensors += 1; - load_progress_callback(LoadProgress::PartTensorLoaded { - file: &part_path, - current_tensor: n_tensors.try_into()?, - tensor_count: model.tensors.len(), - }); - } - - load_progress_callback(LoadProgress::PartLoaded { - file: &part_path, - byte_size: total_size, - tensor_count: n_tensors.try_into()?, - }); - } - - Ok((model, vocab)) - } - - /// Starts a new `InferenceSession` for this model. - pub fn start_session(&self, params: InferenceSessionParameters) -> InferenceSession { - let Hyperparameters { - n_ctx, - n_embd, - n_layer, - n_vocab, - .. 
- } = self.hparams; - - let ctx_size = { - let mut ctx_size = 0; - ctx_size += mulf!( - n_ctx, - n_layer, - n_embd, - ggml::type_sizef(params.memory_k_type.into()) - ); // memory_k - ctx_size += mulf!( - n_ctx, - n_layer, - n_embd, - ggml::type_sizef(params.memory_v_type.into()) - ); // memory_v - ctx_size += (5 + 10 * n_layer) * 256; // object overhead - ctx_size - }; - - let session_ctx = ggml::Context::init(ctx_size); - - // Initialize key + value memory tensors - let n_mem = n_layer * n_ctx; - let n_elements = n_embd * n_mem; - let memory_k = session_ctx.new_tensor_1d(params.memory_k_type.into(), n_elements); - let memory_v = session_ctx.new_tensor_1d(params.memory_v_type.into(), n_elements); - - InferenceSession { - _session_ctx: session_ctx, - memory_size: ctx_size, - params, - memory_k, - memory_v, - n_past: 0, - mem_per_token: 0, - tokens: vec![], - last_logits: vec![0.0; n_vocab], - } - } - - /// Evaluates the transformer. - /// - /// The provided `output_request` struct lets you specify which additional - /// data you are interested in fetching from the transformer. Setting a - /// field to a `Some` value will clear and fill the provided vector with - /// data. The provided vector will be resized to the exact output size. - pub fn evaluate( - &self, - session: &mut InferenceSession, - params: &InferenceParameters, - input_tokens: &[TokenId], - output_request: &mut EvaluateOutputRequest, - ) { - let n = input_tokens.len(); - let n_past = session.n_past; - let n_threads = params.n_threads; - let increased_determinism = params.increased_determinism; - - let Hyperparameters { - n_vocab, - n_ctx, - n_embd, - n_mult: _, - n_head, - n_layer, - n_rot, - f16_: _, - } = self.hparams; - - // For the first run, we need to guess a maximum buffer size so we can measure - // the actual memory consumption of the temporary ggml context. - let mut buf_size = 1024 * 1024 * 1024; - if session.mem_per_token > 0 && session.mem_per_token * n > buf_size { - // add 10% to account for ggml object overhead - buf_size = (1.1f64 * session.mem_per_token as f64 * n as f64) as usize; - }; - let ctx0 = ggml::Context::init(buf_size); - - let mut gf = ggml::ComputationGraph::new(n_threads); - - let embd = ctx0.new_tensor_1d(ggml::TYPE_I32, n); - unsafe { embd.write_data(bytemuck::cast_slice(input_tokens)) }; - - let mut input_layer = ctx0.op_get_rows(&self.tok_embeddings, &embd); - - // Defined here to avoid repetition and creating a binding inside nested loops. - // See the call site below for more context. 
- let vtrans_fun = |il: usize| -> ggml::Tensor { - ctx0.op_permute( - &ctx0.op_reshape_3d( - &ctx0.op_view_1d( - &session.memory_v, - (n_past + n) * n_embd, - il * n_ctx * session.memory_v.element_size() * n_embd, - ), - n_embd / n_head, - n_head, - n_past + n, - ), - 1, - 2, - 0, - 3, - ) - }; - - for il in 0..n_layer { - let input_self_attention = input_layer.share(); - let mut current: ggml::Tensor; - - // norm - { - current = ctx0.op_rms_norm(&input_layer); - - // cur = attention_norm * cur - current = ctx0.op_mul( - &ctx0.op_repeat(&self.layers[il].attention_norm, ¤t), - ¤t, - ); - } - - // self-attention - { - let q_current = ctx0.op_mul_mat(&self.layers[il].wq, ¤t); - let k_current = ctx0.op_mul_mat(&self.layers[il].wk, ¤t); - let v_current = ctx0.op_mul_mat(&self.layers[il].wv, ¤t); - - // store key and value to memory - if n >= 1 { - let k = ctx0.op_view_1d( - &session.memory_k, - n * n_embd, - (session.memory_k.element_size() * n_embd) * (il * n_ctx + n_past), - ); - - let v = ctx0.op_view_1d( - &session.memory_v, - n * n_embd, - (session.memory_v.element_size() * n_embd) * (il * n_ctx + n_past), - ); - - gf.build_forward_expand(&ctx0.op_cpy(&k_current, &k)); - gf.build_forward_expand(&ctx0.op_cpy(&v_current, &v)); - } - - // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) - let q = ctx0.op_permute( - &ctx0.op_rope( - &ctx0.op_cpy( - &q_current, - &ctx0.new_tensor_3d(ggml::TYPE_F32, n_embd / n_head, n_head, n), - ), - n_past, - n_rot, - 0, - ), - 0, - 2, - 1, - 3, - ); - - // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) - let k = ctx0.op_permute( - &ctx0.op_rope( - &ctx0.op_reshape_3d( - &ctx0.op_view_1d( - &session.memory_k, - (n_past + n) * n_embd, - il * n_ctx * session.memory_k.element_size() * n_embd, - ), - n_embd / n_head, - n_head, - n_past + n, - ), - n_past, - n_rot, - 1, - ), - 0, - 2, - 1, - 3, - ); - - // K * Q - let k_q = ctx0.op_mul_mat(&k, &q); - - // KQ_scaled = KQ / sqrt(n_embd/n_head) - let k_q_scaled = ctx0.op_scale( - &k_q, - &ctx0.new_f32(1.0 / f32::sqrt(n_embd as f32 / n_head as f32)), - ); - - // KQ_masked = mask_past(KQ_scaled) - let k_q_masked = ctx0.op_diag_mask_inf(&k_q_scaled, n_past); - - // KQ = soft_max(KQ_masked) - let k_q_soft_max = ctx0.op_soft_max(&k_q_masked); - - // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() - let v_transposed = { - if !increased_determinism { - vtrans_fun(il) - } else { - ctx0.op_cpy( - &vtrans_fun(il), - &ctx0.new_tensor_3d( - ggml::TYPE_F32, - n_past + n, - n_embd / n_head, - n_head, - ), - ) - } - }; - - // KQV = transpose(V) * KQ_soft_max - let k_q_v = ctx0.op_mul_mat(&v_transposed, &k_q_soft_max); - - // KQV_merged = KQV.permute(0, 2, 1, 3) - let k_q_v_merged = ctx0.op_permute(&k_q_v, 0, 2, 1, 3); - - // cur = KQV_merged.contiguous().view(n_embd, N) - current = ctx0.op_cpy( - &k_q_v_merged, - &ctx0.new_tensor_2d(ggml::TYPE_F32, n_embd, n), - ); - - // projection (no bias) - current = ctx0.op_mul_mat(&self.layers[il].wo, ¤t); - } - - let input_feed_forward = ctx0.op_add(¤t, &input_self_attention); - - // feed-forward network - { - // norm - { - current = ctx0.op_rms_norm(&input_feed_forward); - - // cur = ffn_norm*cur - current = ctx0.op_mul( - &ctx0.op_repeat(&self.layers[il].ffn_norm, ¤t), - ¤t, - ); - } - - let tmp = ctx0.op_mul_mat(&self.layers[il].w3, ¤t); - - current = ctx0.op_mul_mat(&self.layers[il].w1, ¤t); - - // SILU activation - current = ctx0.op_silu(¤t); - - current = ctx0.op_mul(¤t, &tmp); - - current = 
ctx0.op_mul_mat(&self.layers[il].w2, ¤t); - } - - current = ctx0.op_add(¤t, &input_feed_forward); - - // input for next layer - input_layer = current; - } - - // Used at the end to optionally extract the embeddings. - let embeddings_tensor; - - // norm - { - input_layer = ctx0.op_rms_norm(&input_layer); - - // inpL = norm*inpL - input_layer = ctx0.op_mul(&ctx0.op_repeat(&self.norm, &input_layer), &input_layer); - embeddings_tensor = input_layer.share(); - } - - // lm_head - { - input_layer = ctx0.op_mul_mat(&self.output, &input_layer); - } - - // logits -> probs - // inpL = ctx0.op_soft_max(&inpL); - - // run the computation - gf.build_forward_expand(&input_layer); - ctx0.graph_compute(&mut gf); - - // return result for just the last token - // SAFETY: yolo - assert_eq!(session.last_logits.len(), n_vocab); - unsafe { - input_layer.read_data( - n_vocab * (n - 1) * std::mem::size_of::(), - bytemuck::cast_slice_mut(&mut session.last_logits), - ) - }; - - // Extract logits - if let Some(all_logits) = &mut output_request.all_logits { - all_logits.resize(n_vocab * n, 0.0); - // SAFETY: Tensor data can be read (properly aligned, initialized, - // data will not be mutated or otherwise aliased during the copy), - // and we're not reading past the end of the tensor data. - assert_eq!(input_layer.nelements(), n_vocab * n); - unsafe { - input_layer.read_data(0, bytemuck::cast_slice_mut(all_logits)); - } - } - - // Extract embeddings - if let Some(embeddings) = &mut output_request.embeddings { - embeddings.resize(n_embd * n, 0.0); - // SAFETY: Same rationale as for the "Extract logits" section applies. - assert_eq!(embeddings_tensor.nelements(), n_embd * n); - unsafe { - embeddings_tensor.read_data(0, bytemuck::cast_slice_mut(embeddings)); - } - } - - // Adjust the required memory per token if we didn't know that already - if session.mem_per_token == 0 { - session.mem_per_token = ctx0.used_mem() / n; - } - - // Adjust n_past to new length. - session.n_past += input_tokens.len(); - } - - /// Hydrates a previously obtained InferenceSnapshot for this model - pub fn session_from_snapshot( - &self, - snapshot: InferenceSnapshot, - ) -> Result { - let mut session = self.start_session(snapshot.session_params); - - if session.memory_k.nbytes() != snapshot.memory_k.len() - || session.memory_v.nbytes() != snapshot.memory_v.len() - { - return Err(SnapshotError::MemorySizeMismatch { - self_size: session.memory_k.nbytes() + session.memory_v.nbytes(), - input_size: snapshot.memory_k.len() + snapshot.memory_v.len(), - }); - } - - // SAFETY: We have exclusive access to Session, which means no one else - // should be touching the context's memory. We can write to it because - // we already checked the size. - unsafe { - session.memory_k.write_data(&snapshot.memory_k); - session.memory_v.write_data(&snapshot.memory_v); - } - - session.n_past = snapshot.npast; - session.tokens = snapshot.tokens; - session.last_logits = snapshot.last_logits; - - Ok(session) - } -} - -impl InferenceSession { - /// Feed a prompt to the model for this session. - pub fn feed_prompt( - &mut self, - model: &Model, - vocab: &Vocabulary, - params: &InferenceParameters, - prompt: &str, - callback: impl Fn(&str) -> Result<(), E>, - ) -> Result<(), InferenceError> { - let beginning_of_sentence = self.n_past == 0; - let prompt_tokens: Vec = vocab - .tokenize(prompt, beginning_of_sentence)? 
- .iter() - .map(|(_, tok)| *tok) - .collect(); - - if self.n_past + prompt_tokens.len() >= model.hparams.n_ctx { - return Err(InferenceError::ContextFull); - } - - for batch in prompt_tokens.chunks(params.n_batch) { - model.evaluate(self, params, batch, &mut EvaluateOutputRequest::default()); - for &tk in batch { - // NOTE: No string ever tokenizes to the end of sentence. So we - // can just return the id here. - if let Err(e) = callback(vocab.token(tk as usize)) { - return Err(InferenceError::UserCallback(Box::new(e))); - } - - // Update the tokens for this session - self.tokens.push(tk); - } - } - - Ok(()) - } - - /// Infer the next token for this session. - pub fn infer_next_token<'v>( - &mut self, - model: &Model, - vocab: &'v Vocabulary, - params: &InferenceParameters, - rng: &mut impl rand::Rng, - ) -> Result<&'v str, InferenceError> { - if self.n_past + 1 >= model.hparams.n_ctx { - return Err(InferenceError::ContextFull); - } - - // First, sample the next token, using the stored last_logits; - let next_token = self.sample_top_p_top_k(params, rng); - - // Update the tokens for this session - self.tokens.push(next_token); - - // Then, evaluate the network again to compute the new last_logits - model.evaluate( - self, - params, - &[next_token], - &mut EvaluateOutputRequest::default(), - ); - - // Return the next token - if next_token as TokenId == EOT_TOKEN_ID { - Err(InferenceError::EndOfText) - } else { - Ok(vocab.token(next_token as usize)) - } - } - - // todo: see if we can reduce the arguments here somehow - consolidate model and vocab maybe? - /// Helper function to run inference with this session and the given model and vocabulary. - /// The `callback` is called with each new token until inference is complete. - /// - /// If `params.play_back_previous_tokens` is specified, this will "play back" all existing tokens in the session. - #[allow(clippy::too_many_arguments)] - pub fn inference_with_prompt( - &mut self, - model: &Model, - vocab: &Vocabulary, - params: &InferenceParameters, - prompt: &str, - maximum_token_count: Option, - rng: &mut impl rand::Rng, - callback: impl Fn(&str) -> Result<(), E>, - ) -> Result { - let maximum_token_count = maximum_token_count.unwrap_or(usize::MAX); - if params.play_back_previous_tokens { - // "Play back" the existing tokens, so that loading from an inference snapshot works - // as expected. - for token_id in &self.tokens { - if let Err(e) = callback(vocab.token(*token_id as usize)) { - return Err(InferenceError::UserCallback(Box::new(e))); - } - } - } - - let mut stats = InferenceStats::default(); - - let start_at = time::SystemTime::now(); - - // Feed the initial prompt through the transformer, to update its - // context window with new data. - self.feed_prompt(model, vocab, params, prompt, |tk| callback(tk))?; - stats.feed_prompt_duration = start_at.elapsed().unwrap(); - stats.prompt_tokens = self.n_past; - - // After the prompt is consumed, sample tokens by repeatedly calling - // `infer_next_token`. We generate tokens until the model returns an - // EndOfText token, or we run out of space in the context window, - // or we reach the specified limit. 
- let mut tokens_processed = 0; - while tokens_processed < maximum_token_count { - let token = match self.infer_next_token(model, vocab, params, rng) { - Ok(token) => token, - Err(InferenceError::EndOfText) => break, - Err(e) => return Err(e), - }; - - if let Err(e) = callback(token) { - return Err(InferenceError::UserCallback(Box::new(e))); - } - - tokens_processed += 1; - } - stats.predict_duration = start_at.elapsed().unwrap(); - stats.predict_tokens = self.n_past; - - Ok(stats) - } - - /// Sample a token using Top-P/Top-K sampling and the last logits from this session. - pub fn sample_top_p_top_k( - &self, - params: &InferenceParameters, - rng: &mut impl rand::Rng, - ) -> TokenId { - let logits = &self.last_logits; - let n_logits = logits.len(); - let mut logits_id = Vec::<(f32, TokenId)>::with_capacity(n_logits); - - { - let scale = 1.0 / params.temperature; - for (i, &logit) in logits.iter().enumerate() { - let tid = i as TokenId; - - let val = if let Some(logit_override) = params.bias_tokens.get(tid) { - logit_override - } else if self.repetition_penalty_tokens().contains(&(i as TokenId)) { - // repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) - // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main - - // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability - if logits[i] < 0.0 { - logit * scale * params.repeat_penalty - } else { - logit * scale / params.repeat_penalty - } - } else { - logit * scale - }; - logits_id.push((val, tid)); - } - } - - // find the top K tokens - { - logits_id.partial_sort(params.top_k, |a, b| { - // Sort descending - b.0.total_cmp(&a.0) - }); - logits_id.truncate(params.top_k); - } - - let maxl = logits_id - .iter() - .map(|x| x.0) - .max_by(f32::total_cmp) - .unwrap(); - - // compute probs for the top K tokens - let mut probs: Vec = logits_id - .iter() - .copied() - .map(|(k, _)| (k - maxl).exp()) - .collect(); - let sum: f32 = probs.iter().copied().sum(); - - // Normalize the probs - for p in probs.iter_mut() { - *p /= sum; - } - - // Top p sampling - if params.top_p < 1.0 { - let mut cumsum = 0.0; - for i in 0..probs.len() { - cumsum += probs[i]; - if cumsum >= params.top_p { - probs.truncate(i + 1); - logits_id.truncate(i + 1); - break; - } - } - - cumsum = 1.0 / cumsum; - for p in probs.iter_mut() { - *p *= cumsum; - } - } - - let dist = WeightedIndex::new(&probs).expect("WeightedIndex error"); - let idx = dist.sample(rng); - - logits_id[idx].1 - } - - /// Obtains a serializable snapshot of the current inference status. This - /// can be used to cache the state of the model and store them into a file. - /// - /// # Safety - /// - /// This function provides raw access to the underlying memory owned by the - /// ggml context. While the provided `InferenceSnapshotRef` object is alive, - /// no other methods for this model object should be called. - pub unsafe fn get_snapshot(&mut self) -> InferenceSnapshotRef<'_> { - let memory_k = unsafe { - slice::from_raw_parts(self.memory_k.data() as *mut u8, self.memory_k.nbytes()) - }; - let memory_v = unsafe { - slice::from_raw_parts(self.memory_v.data() as *mut u8, self.memory_v.nbytes()) - }; - - InferenceSnapshotRef { - npast: self.n_past, - session_params: self.params, - tokens: self.tokens.clone(), - logits: self.last_logits.clone(), - memory_k, - memory_v, - } - } -} - -impl<'a> InferenceSnapshotRef<'a> { - /// Write this snapshot to the given writer. 
- pub fn write(&self, writer: &mut impl std::io::Write) -> Result<(), SnapshotError> { - Ok(bincode::serialize_into(writer, &self)?) - } -} - -impl InferenceSnapshot { - /// Read a snapshot from the given reader. - pub fn read(reader: &mut impl std::io::Read) -> Result { - Ok(bincode::deserialize_from(reader)?) - } -} - -impl Vocabulary { - // SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece - /// Tokenize a `text` with this vocabulary. - /// - /// `bos` controls whether a beginning-of-string token should be inserted. - pub fn tokenize<'a>( - &'a self, - text: &str, - bos: bool, - ) -> Result, InferenceError> { - let len = text.len(); - - let mut score = vec![0usize; len + 1]; - let mut prev = vec![TokenId::default(); len + 1]; - - for i in 0..len { - let max_len = (len - i).min(self.max_token_length); - for sub_len in 1..=max_len { - let sub = &text.as_bytes()[i..i + sub_len]; - let Ok(sub) = std::str::from_utf8(sub) else { continue; }; - let token = self.token_to_id.get(sub); - - if let Some(token) = token { - let token_score = sub.len() * sub.len(); - let local_score = score[i] + token_score; - let next = i + sub_len; - - if score[next] < local_score { - score[next] = local_score; - prev[next] = *token; - } - } - } - } - - // Backward pass - let mut res = vec![]; - let mut i = len; - while i > 0 { - let token_id = prev[i]; - if token_id == 0 { - return Err(InferenceError::TokenizationFailed); - } - let token = self.id_to_token[token_id as usize].as_str(); - res.push((token, token_id)); - i -= token.len(); - } - - if bos { - // TODO: replace with vocab.bos - res.push(("", 1)); - } - - // Pieces are in reverse order so correct that - res.reverse(); - - Ok(res) - } -} +pub mod common; +pub mod models; diff --git a/llama-rs/src/models/bloom.rs b/llama-rs/src/models/bloom.rs new file mode 100644 index 00000000..af87bfc1 --- /dev/null +++ b/llama-rs/src/models/bloom.rs @@ -0,0 +1,979 @@ +use std::{ + collections::HashMap, + fs::File, + io::BufReader, + io::{BufRead, Read, Seek, SeekFrom}, + path::{Path, PathBuf}, +}; + +use crate::common::{helpers::*, inference::*, load::*, model::*, token::*, vocabulary::*}; +use crate::mulf; + +// NOTE: Field order matters! Data is laid out in the file exactly +// in this order. +#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +pub struct Hyperparameters { + pub n_vocab: usize, + pub n_ctx: usize, + pub n_embd: usize, + pub n_mult: usize, + pub n_head: usize, + pub n_layer: usize, + pub f16_: u32, +} + +// default +pub struct Layer { + pub attention_norm: ggml::Tensor, + pub attention_norm_b: ggml::Tensor, + pub wo: ggml::Tensor, + pub wo_b: ggml::Tensor, + pub query_key_value: ggml::Tensor, + pub query_key_value_b: ggml::Tensor, + // normalization + pub ffn_norm: ggml::Tensor, + pub ffn_norm_b: ggml::Tensor, + // ff + pub w1: ggml::Tensor, + pub w1_b: ggml::Tensor, + pub w2: ggml::Tensor, + pub w2_b: ggml::Tensor, +} + +/// The weights for the BLOOM model. All the mutable state is split into a +/// separate struct `InferenceSession`. 
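+///
+/// Compared to the LLaMA graph, the differences visible further down in
+/// `evaluate` are: LayerNorm with bias terms (`op_norm` plus the `*_b`
+/// tensors) instead of RMSNorm, a single fused `query_key_value` projection,
+/// a GELU feed-forward, and ALiBi position biases via `op_alibi` in place of
+/// RoPE.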
+pub struct BLOOM { + pub hparams: Hyperparameters, + pub tok_embeddings: ggml::Tensor, + pub norm: ggml::Tensor, + pub norm_b: ggml::Tensor, + pub output_norm: ggml::Tensor, + pub output_norm_b: ggml::Tensor, + pub output: ggml::Tensor, + pub layers: Vec, + pub tensors: HashMap, + // Must be kept alive for the model + pub _context: ggml::Context, +} + +impl Model for BLOOM { + type Weights = BLOOM; + type HP = Hyperparameters; + + fn load( + path: impl AsRef, + n_ctx: usize, + load_progress_callback: impl Fn(LoadProgress), + ) -> Result<(Self::Weights, Vocabulary), LoadError> { + // Load model + + let main_path = path.as_ref(); + + let mut reader = + BufReader::new( + File::open(main_path).map_err(|e| LoadError::OpenFileFailed { + source: e, + path: main_path.to_owned(), + })?, + ); + + // Verify magic + let is_legacy_model: bool = match read_u32(&mut reader)? { + ggml::FILE_MAGIC => false, + ggml::FILE_MAGIC_UNVERSIONED => true, + _ => { + return Err(LoadError::InvalidMagic { + path: main_path.to_owned(), + }) + } + }; + + // Load format version + if !is_legacy_model { + #[allow(unused_variables)] + let version: u32 = match read_u32(&mut reader)? { + ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, + version => return Err(LoadError::InvalidFormatVersion { value: version }), + }; + } + + // ================= + // Load hyper params + // ================= + + // NOTE: Field order matters! Data is laid out in the file exactly + // in this order. + let hparams = Hyperparameters { + n_vocab: read_i32(&mut reader)?.try_into()?, + n_ctx, + n_embd: read_i32(&mut reader)?.try_into()?, + n_mult: read_i32(&mut reader)?.try_into()?, + n_head: read_i32(&mut reader)?.try_into()?, + n_layer: read_i32(&mut reader)?.try_into()?, + f16_: read_i32(&mut reader)?.try_into()?, + }; + + let n_ff = ((4 * hparams.n_embd + hparams.n_mult - 1) / hparams.n_mult) * hparams.n_mult; + + load_progress_callback(LoadProgress::HyperparametersLoaded(hparams)); + + // =============== + // Load vocabulary + // =============== + let vocab = { + let mut id_to_token = vec![]; + let mut id_to_token_score = vec![]; + let mut token_to_id = HashMap::new(); + let mut max_token_length = 0; + + for i in 0..hparams.n_vocab { + let len = read_i32(&mut reader)?; + if let Ok(word) = read_string(&mut reader, len as usize) { + max_token_length = max_token_length.max(word.len()); + id_to_token.push(word.clone()); + token_to_id.insert(word, TokenId::try_from(i)?); + } else { + load_progress_callback(LoadProgress::BadToken { index: i }); + id_to_token.push("�".to_string()); + } + + // Token score, currently unused + if !is_legacy_model { + if let Ok(score) = read_f32(&mut reader) { + id_to_token_score.push(score); + } + } else { + // Legacy model, set empty score + id_to_token_score.push(0.); + } + } + + Vocabulary { + id_to_token, + id_to_token_score, + token_to_id, + max_token_length, + } + }; + + // for the big tensors, we have the option to store the data in 16-bit + // floats or quantized in order to save memory and also to speed up the + // computation + let wtype = match hparams.f16_ { + 0 => ggml::TYPE_F32, + 1 => ggml::TYPE_F16, + 2 => ggml::TYPE_Q4_0, + 3 => ggml::TYPE_Q4_1, + invalid => return Err(LoadError::HyperparametersF16Invalid { value: invalid }), + }; + + let n_embd = hparams.n_embd; + let n_layer = hparams.n_layer; + let n_vocab = hparams.n_vocab; + + let ctx_size = { + // Use 64-bit math to prevent overflow. 
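+            // A note on the arithmetic below: each weight tensor contributes
+            // roughly `nelements * type_sizef(type)` bytes, and the trailing
+            // `(5 + 10 * n_layer) * 256` term approximates ggml's per-object
+            // bookkeeping overhead, following the same accounting as the
+            // LLaMA loader.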
+ let n_embd = n_embd as u64; + let n_layer = n_layer as u64; + let n_vocab = n_vocab as u64; + let n_ff = n_ff as u64; + + let mut ctx_size: u64 = 0; + + ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // tok_embeddings + + ctx_size += mulf!(n_embd, ggml::type_sizef(ggml::TYPE_F32)); // norm + ctx_size += mulf!(n_embd, ggml::type_sizef(ggml::TYPE_F32)); // norm_b + + ctx_size += mulf!(n_embd, ggml::type_sizef(ggml::TYPE_F32)); // output_norm + ctx_size += mulf!(n_embd, ggml::type_sizef(ggml::TYPE_F32)); // output_norm_b + + ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // output + + ctx_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::TYPE_F32)); // attention_norm + ctx_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::TYPE_F32)); // attention_norm_b + + ctx_size += mulf!(n_layer, 3, n_embd, n_embd, ggml::type_sizef(wtype)); //query_key_value + ctx_size += mulf!(n_layer, 3, n_embd, ggml::type_sizef(ggml::TYPE_F32)); //query_key_value_b + + ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wo + ctx_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::TYPE_F32)); // wo_b + + ctx_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::TYPE_F32)); // ffn_norm + ctx_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::TYPE_F32)); // ffn_norm_b + + ctx_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w1 + ctx_size += mulf!(n_layer, n_ff, ggml::type_sizef(ggml::TYPE_F32)); // w1_b + + ctx_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w2 + ctx_size += mulf!(n_layer, n_ff, ggml::type_sizef(ggml::TYPE_F32)); // w2_b + + ctx_size += (5 + 10 * n_layer) * 256; // object overhead + + load_progress_callback(LoadProgress::ContextSize { + bytes: ctx_size.try_into()?, + }); + + ctx_size + }; + + // Initialize the context + let context = ggml::Context::init(ctx_size as usize); + + let model = { + let mut tensors = HashMap::new(); + + let tok_embeddings = context.new_tensor_2d(wtype, n_embd, n_vocab); + + let norm = context.new_tensor_1d(ggml::TYPE_F32, n_embd); + let norm_b = context.new_tensor_1d(ggml::TYPE_F32, n_embd); + + let output_norm = context.new_tensor_1d(ggml::TYPE_F32, n_embd); + let output_norm_b = context.new_tensor_1d(ggml::TYPE_F32, n_embd); + + let output = context.new_tensor_2d(wtype, n_embd, n_vocab); + + tensors.insert("tok_embeddings.weight".to_owned(), tok_embeddings.share()); + + tensors.insert("norm.weight".to_owned(), norm.share()); + tensors.insert("norm.bias".to_owned(), norm_b.share()); + + tensors.insert("output_norm.weight".to_owned(), output_norm.share()); + tensors.insert("output_norm.bias".to_owned(), output_norm_b.share()); + + tensors.insert("output.weight".to_owned(), output.share()); + + let mut layers = Vec::new(); + for i in 0..n_layer { + let layer = Layer { + attention_norm: context.new_tensor_1d(ggml::TYPE_F32, n_embd), + attention_norm_b: context.new_tensor_1d(ggml::TYPE_F32, n_embd), + + query_key_value: context.new_tensor_2d(wtype, n_embd, 3 * n_embd), + query_key_value_b: context.new_tensor_1d(ggml::TYPE_F32, 3 * n_embd), + + wo: context.new_tensor_2d(wtype, n_embd, n_embd), + wo_b: context.new_tensor_1d(ggml::TYPE_F32, n_embd), + + ffn_norm: context.new_tensor_1d(ggml::TYPE_F32, n_embd), + ffn_norm_b: context.new_tensor_1d(ggml::TYPE_F32, n_embd), + + w1: context.new_tensor_2d(wtype, n_embd, n_ff), + w1_b: context.new_tensor_1d(ggml::TYPE_F32, n_ff), + w2: context.new_tensor_2d(wtype, n_ff, n_embd), + w2_b: context.new_tensor_1d(ggml::TYPE_F32, n_embd), + }; + + 
tensors.insert( + format!("layers.{i}.attention_norm.weight"), + layer.attention_norm.share(), + ); + + tensors.insert( + format!("layers.{i}.attention_norm.bias"), + layer.attention_norm_b.share(), + ); + + tensors.insert( + format!("layers.{i}.attention.query_key_value.weight"), + layer.query_key_value.share(), + ); + tensors.insert( + format!("layers.{i}.attention.query_key_value.bias"), + layer.query_key_value_b.share(), + ); + + tensors.insert(format!("layers.{i}.attention.wo.weight"), layer.wo.share()); + tensors.insert(format!("layers.{i}.attention.wo.bias"), layer.wo_b.share()); + + tensors.insert( + format!("layers.{i}.ffn_norm.weight"), + layer.ffn_norm.share(), + ); + tensors.insert( + format!("layers.{i}.ffn_norm.bias"), + layer.ffn_norm_b.share(), + ); + + tensors.insert( + format!("layers.{i}.feed_forward.w1.weight"), + layer.w1.share(), + ); + tensors.insert( + format!("layers.{i}.feed_forward.w1.bias"), + layer.w1_b.share(), + ); + tensors.insert( + format!("layers.{i}.feed_forward.w2.weight"), + layer.w2.share(), + ); + tensors.insert( + format!("layers.{i}.feed_forward.w2.bias"), + layer.w2_b.share(), + ); + + layers.push(layer); + } + + BLOOM { + hparams, + tok_embeddings, + norm, + norm_b, + output_norm, + output_norm_b, + output, + layers, + tensors, + _context: context, + } + }; + + // Close the file, but keep its offset. That way we know how to skip the + // metadata when loading the parts. + let file_offset = reader.stream_position()?; + drop(reader); + + let paths = { + let main_filename = main_path.file_name().and_then(|p| p.to_str()); + + let mut paths: Vec = + std::fs::read_dir(main_path.parent().ok_or_else(|| LoadError::NoParentPath { + path: main_path.to_owned(), + })?)? + .filter_map(Result::ok) + .map(|de| de.path()) + .filter(|p| { + p.file_name() + .and_then(|p| p.to_str()) + .zip(main_filename) + .map(|(part_filename, main_filename)| { + part_filename.starts_with(main_filename) + }) + .unwrap_or(false) + }) + .collect(); + paths.sort(); + paths + }; + + let n_parts = paths.len(); + + for (i, part_path) in paths.into_iter().enumerate() { + let part_id = i; + + load_progress_callback(LoadProgress::PartLoading { + file: part_path.to_path_buf().into(), + current_part: i + 1, + total_parts: n_parts, + }); + + let mut part_reader = BufReader::new(File::open(&part_path)?); + + // Skip metadata + part_reader.seek(SeekFrom::Start(file_offset))?; + + let mut total_size = 0; + let mut n_tensors = 0; + + // Load weights + loop { + // NOTE: Implementation from #![feature(buf_read_has_data_left)] + let is_eof = part_reader.fill_buf().map(|b| b.is_empty())?; + + if is_eof { + break; + } + + let n_dims = usize::try_from(read_i32(&mut part_reader)?)?; + let length = read_i32(&mut part_reader)?; + let ftype = read_u32(&mut part_reader)?; + + let mut nelements: usize = 1; + let mut ne = [1i64, 1i64]; + + #[allow(clippy::needless_range_loop)] + for i in 0..n_dims { + ne[i] = read_i32(&mut part_reader)? 
as i64; + nelements *= usize::try_from(ne[i])?; + } + + let tensor_name = read_string(&mut part_reader, length as usize)?; + + let Some(tensor) = model.tensors.get(&tensor_name) + else { + return Err(LoadError::UnknownTensor { tensor_name, path: part_path.to_path_buf() }); + }; + + // split_type = 0: split by columns + // split_type = 1: split by rows + // + // split_type = 0: + // regex: + // - tok_embeddings.* + // - layers.*.attention.wo.weight + // - layers.*.feed_forward.w2.weight + + // split_type = 1: + // regex: + // - output.* + // - layers.*.attention.wq.weight + // - layers.*.attention.wk.weight + // - layers.*.attention.wv.weight + // - layers.*.feed_forward.w1.weight + // - layers.*.feed_forward.w3.weight + #[allow(clippy::if_same_then_else)] + let split_type = if tensor_name.contains("tok_embeddings") { + 0 + } else if tensor_name.contains("layers") { + if tensor_name.contains("attention.wo.weight") { + 0 + } else if tensor_name.contains("feed_forward.w2.weight") { + 0 + } else { + 1 + } + } else if tensor_name.contains("output") { + 1 + } else { + 0 + }; + + if n_dims == 1 { + if tensor.nelements() != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path.to_path_buf(), + }); + } + } else if tensor.nelements() / n_parts != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path.to_path_buf(), + }); + } + + if n_dims == 1 { + if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] != ne[1] { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path.to_path_buf(), + }); + } + } else if split_type == 0 { + if tensor.get_ne()[0] / i64::try_from(n_parts)? != ne[0] + || tensor.get_ne()[1] != ne[1] + { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path.to_path_buf(), + }); + } + } else if tensor.get_ne()[0] != ne[0] + || tensor.get_ne()[1] / i64::try_from(n_parts)? != ne[1] + { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path.to_path_buf(), + }); + } + + let bpe = match ftype { + 0 => ggml::type_size(ggml::TYPE_F32), + 1 => ggml::type_size(ggml::TYPE_F16), + 2 => { + assert_eq!(ne[0] % 64, 0); + ggml::type_size(ggml::TYPE_Q4_0) + } + 3 => { + assert_eq!(ne[0] % 64, 0); + ggml::type_size(ggml::TYPE_Q4_1) + } + _ => { + return Err(LoadError::InvalidFtype { + ftype, + path: part_path.to_path_buf(), + }) + } + }; + + if n_dims == 1 || n_parts == 1 { + if (nelements as usize * bpe) / ggml::blck_size(tensor.get_type()) as usize + != tensor.nbytes() + { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path.to_path_buf(), + }); + } + + if part_id == 0 { + // SAFETY: yolo, same as original code + let slice = unsafe { + let data = tensor.data(); + std::slice::from_raw_parts_mut(data as *mut u8, tensor.nbytes()) + }; + part_reader.read_exact(slice)?; + } else { + part_reader.seek(SeekFrom::Current(tensor.nbytes() as i64))?; + } + + total_size += tensor.nbytes(); + } else { + if (nelements as usize * bpe) / ggml::blck_size(tensor.get_type()) as usize + != tensor.nbytes() / n_parts + { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path.to_path_buf(), + }); + } + + if split_type == 0 { + let np0 = ne[0]; + let row_size = (usize::try_from(tensor.get_ne()[0])? 
+ / ggml::blck_size(tensor.get_type())) + * ggml::type_size(tensor.get_type()); + + assert_eq!(row_size, tensor.get_nb()[1]); + + for i1 in 0..ne[1] { + let offset_row = i1 as usize * row_size; + let offset = offset_row + + ((part_id * np0 as usize) + / ggml::blck_size(tensor.get_type()) as usize) + * ggml::type_size(tensor.get_type()); + // SAFETY: yolo, same as original code + unsafe { + let ptr = tensor.data().add(offset); + let slice = std::slice::from_raw_parts_mut( + ptr as *mut u8, + row_size / n_parts, + ); + part_reader.read_exact(slice)?; + } + } + } else { + let np1 = ne[1]; + + let row_size = (usize::try_from(tensor.get_ne()[0])? + / ggml::blck_size(tensor.get_type())) + * ggml::type_size(tensor.get_type()); + + for i1 in 0..ne[1] { + let offset_row = (i1 as usize + part_id * np1 as usize) * row_size; + // SAFETY: yolo, same as original code + unsafe { + let ptr = tensor.data().add(offset_row); + let slice = + std::slice::from_raw_parts_mut(ptr as *mut u8, row_size); + part_reader.read_exact(slice)?; + } + } + } + + total_size += tensor.nbytes() / n_parts; + } + + n_tensors += 1; + load_progress_callback(LoadProgress::PartTensorLoaded { + file: part_path.clone().into_boxed_path(), + current_tensor: n_tensors.try_into()?, + tensor_count: model.tensors.len(), + }); + } + + load_progress_callback(LoadProgress::PartLoaded { + file: part_path.into_boxed_path(), + byte_size: total_size, + tensor_count: n_tensors.try_into()?, + }); + } + + Ok((model, vocab)) + } + + /// Starts a new `InferenceSession` for this model. + fn start_session(&self, params: InferenceSessionParameters) -> InferenceSession { + let Hyperparameters { + n_ctx, + n_embd, + n_layer, + n_vocab, + .. + } = self.hparams; + + let ctx_size = { + let mut ctx_size = 0; + ctx_size += mulf!( + n_ctx, + n_layer, + n_embd, + ggml::type_sizef(params.memory_k_type.into()) + ); // memory_k + ctx_size += mulf!( + n_ctx, + n_layer, + n_embd, + ggml::type_sizef(params.memory_v_type.into()) + ); // memory_v + ctx_size += (5 + 10 * n_layer as u64) * 256; // object overhead + ctx_size + }; + + let session_ctx = ggml::Context::init(ctx_size as usize); + + // Initialize key + value memory tensors + let n_mem = n_layer * n_ctx; + let n_elements = n_embd * n_mem; + let memory_k = session_ctx.new_tensor_1d(params.memory_k_type.into(), n_elements); + let memory_v = session_ctx.new_tensor_1d(params.memory_v_type.into(), n_elements); + + InferenceSession { + _session_ctx: session_ctx, + params, + memory_k, + memory_v, + n_past: 0, + mem_per_token: 0, + tokens: vec![], + last_logits: vec![0.0; n_vocab as usize], + } + } + + /// Evaluates the transformer. + /// + /// The provided `output_request` struct lets you specify which additional + /// data you are interested in fetching from the transformer. Setting a + /// field to a `Some` value will clear and fill the provided vector with + /// data. The provided vector will be resized to the exact output size. + fn evaluate( + &self, + session: &mut InferenceSession, + params: &InferenceParameters, + input_tokens: &[TokenId], + output_request: &mut EvaluateOutputRequest, + ) { + let n = input_tokens.len(); + let n_past = session.n_past; + let n_threads = params.n_threads; + let increased_determinism = params.increased_determinism; + + let Hyperparameters { + n_vocab, + n_ctx, + n_embd, + n_mult: _, + n_head, + n_layer, + f16_: _, + } = self.hparams; + + // For the first run, we need to guess a maximum buffer size so we can measure + // the actual memory consumption of the temporary ggml context. 
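+        // (The sizing heuristic is carried over from the LLaMA path: start
+        // from a 1 GiB scratch buffer, and once `mem_per_token` has been
+        // measured on the first call, size later buffers at about
+        // 1.1 * mem_per_token * n to leave headroom for ggml object overhead.)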
+        let mut buf_size = 1024 * 1024 * 1024;
+        if session.mem_per_token > 0 && session.mem_per_token * n > buf_size {
+            // add 10% to account for ggml object overhead
+            buf_size = (1.1f64 * session.mem_per_token as f64 * n as f64) as usize;
+        };
+        let ctx0 = ggml::Context::init(buf_size);
+
+        // TODO: REMAKE THIS AFTER CHECKING GGML GRAPH
+        let mut gf = ggml::ComputationGraph::new(n_threads);
+
+        let embd = ctx0.new_tensor_1d(ggml::TYPE_I32, n);
+        unsafe { embd.write_data(bytemuck::cast_slice(input_tokens)) };
+
+        let mut input_layer = ctx0.op_get_rows(&self.tok_embeddings, &embd);
+
+        // TODO: word embeddings norm
+        {
+            input_layer = ctx0.op_norm(&input_layer);
+            input_layer = ctx0.op_mul(&ctx0.op_repeat(&self.norm, &input_layer), &input_layer);
+            input_layer = ctx0.op_add(&ctx0.op_repeat(&self.norm_b, &input_layer), &input_layer);
+        }
+
+        // Defined here to avoid repetition and creating a binding inside nested loops.
+        // See the call site below for more context.
+        let vtrans_fun = |il: usize| -> ggml::Tensor {
+            ctx0.op_permute(
+                &ctx0.op_reshape_3d(
+                    &ctx0.op_view_1d(
+                        &session.memory_v,
+                        (n_past + n) * n_embd,
+                        il * n_ctx as usize * session.memory_v.element_size() * n_embd as usize,
+                    ),
+                    n_embd / n_head,
+                    n_head,
+                    n_past + n,
+                ),
+                1,
+                2,
+                0,
+                3,
+            )
+        };
+
+        for il in 0..n_layer as usize {
+            let input_self_attention = input_layer.share();
+            let mut current: ggml::Tensor;
+
+            // norm
+            {
+                current = ctx0.op_norm(&input_layer);
+
+                // cur = attention_norm * cur
+                current = ctx0.op_mul(
+                    &ctx0.op_repeat(&self.layers[il].attention_norm, &current),
+                    &current,
+                );
+                current = ctx0.op_add(
+                    &ctx0.op_repeat(&self.layers[il].attention_norm_b, &current),
+                    &current,
+                );
+            }
+
+            // attention
+            {
+                current = ctx0.op_mul_mat(&self.layers[il].query_key_value, &current);
+                current = ctx0.op_add(
+                    &ctx0.op_repeat(&self.layers[il].query_key_value_b, &current),
+                    &current,
+                );
+            }
+
+            // self-attention
+            {
+                let nb = current.get_nb()[1];
+                let q_current = ctx0.op_view_2d(
+                    &current, n_embd, n, nb,
+                    //0 * std::mem::size_of::<f32>() * n_embd as usize,
+                    0,
+                );
+                let k_current =
+                    ctx0.op_view_2d(&current, n_embd, n, nb, std::mem::size_of::<f32>() * n_embd);
+                let v_current = ctx0.op_view_2d(
+                    &current,
+                    n_embd,
+                    n,
+                    nb,
+                    2 * std::mem::size_of::<f32>() * n_embd,
+                );
+
+                // store key and value to memory
+                if n >= 1 {
+                    let k = ctx0.op_view_1d(
+                        &session.memory_k,
+                        n * n_embd,
+                        (session.memory_k.element_size() * n_embd as usize)
+                            * (il * n_ctx as usize + n_past as usize),
+                    );
+
+                    let v = ctx0.op_view_1d(
+                        &session.memory_v,
+                        n * n_embd,
+                        (session.memory_v.element_size() * n_embd as usize)
+                            * (il * n_ctx as usize + n_past as usize),
+                    );
+
+                    gf.build_forward_expand(&ctx0.op_cpy(&k_current, &k));
+                    gf.build_forward_expand(&ctx0.op_cpy(&v_current, &v));
+                }
+
+                // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
+                let q = ctx0.op_permute(
+                    &ctx0.op_cpy(
+                        &q_current,
+                        &ctx0.new_tensor_3d(ggml::TYPE_F32, n_embd / n_head, n_head, n),
+                    ),
+                    0,
+                    2,
+                    1,
+                    3,
+                );
+
+                // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
+                let k = ctx0.op_permute(
+                    &ctx0.op_reshape_3d(
+                        &ctx0.op_view_1d(
+                            &session.memory_k,
+                            (n_past + n) * n_embd,
+                            il * n_ctx * session.memory_k.element_size() * n_embd,
+                        ),
+                        n_embd / n_head,
+                        n_head,
+                        n_past + n,
+                    ),
+                    0,
+                    2,
+                    1,
+                    3,
+                );
+
+                // K * Q
+                let k_q = ctx0.op_mul_mat(&k, &q);
+
+                // KQ_scaled = KQ / sqrt(n_embd/n_head)
+                let k_q_scaled = ctx0.op_scale(
+                    &k_q,
+                    &ctx0.new_f32(1.0 / f32::sqrt(n_embd as f32 / n_head as f32)),
+                );
+
+                // alibi
+                // KQ_scaled_alibi = KQ_scaled + alibi_bias
+                // TODO: op_alibi function
+                let k_q_scaled_alibi = ctx0.op_alibi(&k_q_scaled, n_past, n_head);
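The `op_alibi` call above lowers to the new `GGML_OP_ALIBI` operator and adds a per-head linear bias to the scaled attention scores before masking, which is what lets the BLOOM graph skip rotary embeddings entirely. As a rough illustration of the slope schedule involved (a sketch of the standard ALiBi formula; `alibi_slopes` is a hypothetical helper for exposition, not part of this change):

// Sketch only: the standard ALiBi slope per attention head. Heads up to the
// largest power of two below n_head follow a geometric series starting at
// 2^(-8 / floor); any remaining heads interleave a second series starting
// at 2^(-4 / floor).
fn alibi_slopes(n_head: usize) -> Vec<f32> {
    let n_floor = 1usize << (n_head as f64).log2().floor() as u32;
    let m0 = 2f32.powf(-8.0 / n_floor as f32);
    let m1 = 2f32.powf(-4.0 / n_floor as f32);
    (0..n_head)
        .map(|k| {
            if k < n_floor {
                m0.powi(k as i32 + 1)
            } else {
                m1.powi(2 * (k - n_floor) as i32 + 1)
            }
        })
        .collect()
}

Each attention score in head k is then shifted by a position-dependent multiple of slopes[k] before the softmax; the actual bias is applied inside the ggml kernel.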
+                // KQ_masked = mask_past(KQ_scaled)
+                let k_q_masked = ctx0.op_diag_mask_inf(&k_q_scaled_alibi, n_past);
+
+                // KQ = soft_max(KQ_masked)
+                let k_q_soft_max = ctx0.op_soft_max(&k_q_masked);
+
+                // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
+                let v_transposed = {
+                    if !increased_determinism {
+                        vtrans_fun(il)
+                    } else {
+                        ctx0.op_cpy(
+                            &vtrans_fun(il),
+                            &ctx0.new_tensor_3d(
+                                ggml::TYPE_F32,
+                                n_past + n,
+                                n_embd / n_head,
+                                n_head,
+                            ),
+                        )
+                    }
+                };
+
+                // KQV = transpose(V) * KQ_soft_max
+                let k_q_v = ctx0.op_mul_mat(&v_transposed, &k_q_soft_max);
+
+                // KQV_merged = KQV.permute(0, 2, 1, 3)
+                let k_q_v_merged = ctx0.op_permute(&k_q_v, 0, 2, 1, 3);
+
+                // cur = KQV_merged.contiguous().view(n_embd, N)
+                current = ctx0.op_cpy(
+                    &k_q_v_merged,
+                    &ctx0.new_tensor_2d(ggml::TYPE_F32, n_embd, n),
+                );
+
+                // projection
+                current = ctx0.op_mul_mat(&self.layers[il].wo, &current);
+                current = ctx0.op_add(&ctx0.op_repeat(&self.layers[il].wo_b, &current), &current);
+            }
+
+            let input_feed_forward = ctx0.op_add(&current, &input_self_attention);
+
+            // feed-forward network
+            {
+                // norm
+                {
+                    current = ctx0.op_norm(&input_feed_forward);
+
+                    // cur = ffn_norm*cur + ffn_norm_b
+                    current = ctx0.op_mul(
+                        &ctx0.op_repeat(&self.layers[il].ffn_norm, &current),
+                        &current,
+                    );
+
+                    current = ctx0.op_add(
+                        &ctx0.op_repeat(&self.layers[il].ffn_norm_b, &current),
+                        &current,
+                    );
+                }
+
+                current = ctx0.op_mul_mat(&self.layers[il].w1, &current);
+
+                current = ctx0.op_add(&ctx0.op_repeat(&self.layers[il].w1_b, &current), &current);
+
+                // GELU activation (BLOOM uses GELU here, unlike LLaMA's SiLU)
+                current = ctx0.op_gelu(&current);
+
+                current = ctx0.op_mul_mat(&self.layers[il].w2, &current);
+
+                current = ctx0.op_add(&ctx0.op_repeat(&self.layers[il].w2_b, &current), &current);
+            }
+
+            current = ctx0.op_add(&current, &input_feed_forward);
+
+            // input for next layer
+            input_layer = current;
+        }
+
+        // Used at the end to optionally extract the embeddings.
+        let embeddings_tensor;
+
+        // norm
+        {
+            input_layer = ctx0.op_norm(&input_layer);
+
+            // inpL = norm*inpL
+            input_layer = ctx0.op_mul(
+                &ctx0.op_repeat(&self.output_norm, &input_layer),
+                &input_layer,
+            );
+
+            input_layer = ctx0.op_add(
+                &ctx0.op_repeat(&self.output_norm_b, &input_layer),
+                &input_layer,
+            );
+
+            embeddings_tensor = input_layer.share(); // TODO: check if this is still necessary (not in the BLOOM C implementation)
+        }
+
+        // lm_head
+        {
+            input_layer = ctx0.op_mul_mat(&self.output, &input_layer);
+        }
+
+        // logits -> probs
+        // inpL = ctx0.op_soft_max(&inpL);
+
+        // run the computation
+        gf.build_forward_expand(&input_layer);
+        ctx0.graph_compute(&mut gf);
+
+        // return result for just the last token
+        // SAFETY: yolo
+        assert_eq!(session.last_logits.len(), n_vocab as usize);
+        unsafe {
+            input_layer.read_data(
+                n_vocab as usize * (n - 1) * std::mem::size_of::<f32>(),
+                bytemuck::cast_slice_mut(&mut session.last_logits),
+            )
+        };
+
+        // Extract logits
+        if let Some(all_logits) = &mut output_request.all_logits {
+            all_logits.resize(n_vocab as usize * n, 0.0);
+            // SAFETY: Tensor data can be read (properly aligned, initialized,
+            // data will not be mutated or otherwise aliased during the copy),
+            // and we're not reading past the end of the tensor data.
+            assert_eq!(input_layer.nelements(), n_vocab * n);
+            unsafe {
+                input_layer.read_data(0, bytemuck::cast_slice_mut(all_logits));
+            }
+        }
+
+        // Extract embeddings
+        if let Some(embeddings) = &mut output_request.embeddings {
+            embeddings.resize(n_embd as usize * n, 0.0);
+            // SAFETY: Same rationale as for the "Extract logits" section applies.
+            assert_eq!(embeddings_tensor.nelements(), n_embd * n);
+            unsafe {
+                embeddings_tensor.read_data(0, bytemuck::cast_slice_mut(embeddings));
+            }
+        }
+
+        // Adjust the required memory per token if we didn't know that already
+        if session.mem_per_token == 0 {
+            session.mem_per_token = ctx0.used_mem() / n;
+        }
+
+        // Adjust n_past to new length.
+        session.n_past += input_tokens.len();
+    }
+}
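Both bloom.rs above and the llama.rs module that follows implement the same `Model` trait surface (`load`, `start_session`, `evaluate`), so callers drive them identically. A minimal usage sketch, assuming the crate re-exports these modules publicly and that the parameter structs implement `Default` (neither is shown in this diff):

// Hypothetical driver; module paths and the Default impls are assumptions, not part of this PR.
use std::path::Path;

use llama_rs::common::inference::{
    EvaluateOutputRequest, InferenceParameters, InferenceSessionParameters,
};
use llama_rs::common::load::LoadError;
use llama_rs::common::model::Model;
use llama_rs::common::token::TokenId;
use llama_rs::models::llama::Llama;

fn run_prompt(path: &Path, prompt: &[TokenId]) -> Result<Vec<f32>, LoadError> {
    // `load` yields the weights plus the vocabulary; 512 is an arbitrary context length.
    let (model, _vocab) = Llama::load(path, 512, |_progress| {})?;

    // The session owns the KV cache, the token history and the last logits.
    let mut session = model.start_session(InferenceSessionParameters::default());

    // Feed the whole prompt in one call and request no optional outputs.
    let mut output = EvaluateOutputRequest::default();
    model.evaluate(&mut session, &InferenceParameters::default(), prompt, &mut output);

    // Logits for the final prompt token, ready for sampling.
    Ok(session.last_logits.clone())
}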
diff --git a/llama-rs/src/models/llama.rs b/llama-rs/src/models/llama.rs
new file mode 100644
index 00000000..b467e72c
--- /dev/null
+++ b/llama-rs/src/models/llama.rs
@@ -0,0 +1,878 @@
+use std::{
+    collections::HashMap,
+    fs::File,
+    io::{BufRead, BufReader, Read, Seek, SeekFrom},
+    path::{Path, PathBuf},
+};
+
+use crate::common::{helpers::*, inference::*, load::*, model::*, token::*, vocabulary::*};
+use crate::mulf;
+
+// NOTE: Field order matters! Data is laid out in the file exactly
+// in this order.
+#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
+pub struct Hyperparameters {
+    pub n_vocab: usize,
+    pub n_ctx: usize,
+    pub n_embd: usize,
+    pub n_mult: usize,
+    pub n_head: usize,
+    pub n_layer: usize,
+    pub n_rot: usize,
+    pub f16_: u32,
+}
+
+// default
+pub struct Layer {
+    pub attention_norm: ggml::Tensor,
+    pub wq: ggml::Tensor,
+    pub wk: ggml::Tensor,
+    pub wv: ggml::Tensor,
+    pub wo: ggml::Tensor,
+    // normalization
+    pub ffn_norm: ggml::Tensor,
+    // ff
+    pub w1: ggml::Tensor,
+    pub w2: ggml::Tensor,
+    pub w3: ggml::Tensor,
+}
+
+/// The weights for the LLaMA model. All the mutable state is split into a
+/// separate struct `InferenceSession`.
+pub struct Llama {
+    pub hparams: Hyperparameters,
+    pub tok_embeddings: ggml::Tensor,
+    pub norm: ggml::Tensor,
+    pub output: ggml::Tensor,
+    pub layers: Vec<Layer>,
+    pub tensors: HashMap<String, ggml::Tensor>,
+    // Must be kept alive for the model
+    pub _context: ggml::Context,
+}
+
+impl Model for Llama {
+    type Weights = Llama;
+    type HP = Hyperparameters;
+
+    fn load(
+        path: impl AsRef<Path>,
+        n_ctx: usize,
+        load_progress_callback: impl Fn(LoadProgress),
+    ) -> Result<(Self::Weights, Vocabulary), LoadError> {
+        let main_path = path.as_ref();
+
+        let mut reader =
+            BufReader::new(
+                File::open(main_path).map_err(|e| LoadError::OpenFileFailed {
+                    source: e,
+                    path: main_path.to_owned(),
+                })?,
+            );
+
+        // Verify magic
+        let is_legacy_model: bool = match read_u32(&mut reader)? {
+            ggml::FILE_MAGIC => false,
+            ggml::FILE_MAGIC_UNVERSIONED => true,
+            _ => {
+                return Err(LoadError::InvalidMagic {
+                    path: main_path.to_owned(),
+                })
+            }
+        };
+
+        // Load format version
+        if !is_legacy_model {
+            #[allow(unused_variables)]
+            let version: u32 = match read_u32(&mut reader)? {
+                ggml::FORMAT_VERSION => ggml::FORMAT_VERSION,
+                version => return Err(LoadError::InvalidFormatVersion { value: version }),
+            };
+        }
+
+        // =================
+        // Load hyper params
+        // =================
+
+        // NOTE: Field order matters! Data is laid out in the file exactly
+        // in this order.
+ let hparams = Hyperparameters { + n_vocab: read_i32(&mut reader)?.try_into()?, + n_ctx, + n_embd: read_i32(&mut reader)?.try_into()?, + n_mult: read_i32(&mut reader)?.try_into()?, + n_head: read_i32(&mut reader)?.try_into()?, + n_layer: read_i32(&mut reader)?.try_into()?, + n_rot: read_i32(&mut reader)?.try_into()?, + f16_: read_i32(&mut reader)?.try_into()?, + }; + + let n_ff = + ((2 * (4 * hparams.n_embd) / 3 + hparams.n_mult - 1) / hparams.n_mult) * hparams.n_mult; + + load_progress_callback(LoadProgress::HyperparametersLoaded(hparams)); + + // =============== + // Load vocabulary + // =============== + let vocab = { + let mut id_to_token = vec![]; + let mut id_to_token_score = vec![]; + let mut token_to_id = HashMap::new(); + let mut max_token_length = 0; + + for i in 0..hparams.n_vocab { + let len = read_i32(&mut reader)?; + if let Ok(word) = read_string(&mut reader, len as usize) { + max_token_length = max_token_length.max(word.len()); + id_to_token.push(word.clone()); + token_to_id.insert(word, TokenId::try_from(i)?); + } else { + load_progress_callback(LoadProgress::BadToken { index: i }); + id_to_token.push("�".to_string()); + } + + // Token score, currently unused + if !is_legacy_model { + if let Ok(score) = read_f32(&mut reader) { + id_to_token_score.push(score); + } + } else { + // Legacy model, set empty score + id_to_token_score.push(0.); + } + } + + Vocabulary { + id_to_token, + id_to_token_score, + token_to_id, + max_token_length, + } + }; + + // for the big tensors, we have the option to store the data in 16-bit + // floats or quantized in order to save memory and also to speed up the + // computation + let wtype = match hparams.f16_ { + 0 => ggml::TYPE_F32, + 1 => ggml::TYPE_F16, + 2 => ggml::TYPE_Q4_0, + 3 => ggml::TYPE_Q4_1, + invalid => return Err(LoadError::HyperparametersF16Invalid { value: invalid }), + }; + + let n_embd = hparams.n_embd; + let n_layer = hparams.n_layer; + let n_vocab = hparams.n_vocab; + + let ctx_size = { + // Use 64-bit math to prevent overflow. 
+ let n_embd = n_embd as u64; + let n_layer = n_layer as u64; + let n_vocab = n_vocab as u64; + let n_ff = n_ff as u64; + + let mut ctx_size: u64 = 0; + + ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // tok_embeddings + + ctx_size += mulf!(n_embd, ggml::type_sizef(ggml::TYPE_F32)); // norm + + ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // output + + ctx_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::TYPE_F32)); // attention_norm + + ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wq + ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wk + ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wv + ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wo + + ctx_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::TYPE_F32)); // ffn_norm + + ctx_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w1 + ctx_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w2 + ctx_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w3 + + ctx_size += (5 + 10 * n_layer) * 256; // object overhead + + load_progress_callback(LoadProgress::ContextSize { + bytes: ctx_size.try_into()?, + }); + + ctx_size + }; + + // Initialize the context + let context = ggml::Context::init(ctx_size as usize); + + let model = { + let mut tensors = HashMap::new(); + + let tok_embeddings = context.new_tensor_2d(wtype, n_embd, n_vocab); + let norm = context.new_tensor_1d(ggml::TYPE_F32, n_embd); + let output = context.new_tensor_2d(wtype, n_embd, n_vocab); + + tensors.insert("tok_embeddings.weight".to_owned(), tok_embeddings.share()); + tensors.insert("norm.weight".to_owned(), norm.share()); + tensors.insert("output.weight".to_owned(), output.share()); + + let mut layers = Vec::new(); + for i in 0..n_layer { + let layer = Layer { + attention_norm: context.new_tensor_1d(ggml::TYPE_F32, n_embd), + wq: context.new_tensor_2d(wtype, n_embd, n_embd), + wk: context.new_tensor_2d(wtype, n_embd, n_embd), + wv: context.new_tensor_2d(wtype, n_embd, n_embd), + wo: context.new_tensor_2d(wtype, n_embd, n_embd), + ffn_norm: context.new_tensor_1d(ggml::TYPE_F32, n_embd), + w1: context.new_tensor_2d(wtype, n_embd, n_ff), + w2: context.new_tensor_2d(wtype, n_ff, n_embd), + w3: context.new_tensor_2d(wtype, n_embd, n_ff), + }; + + tensors.insert( + format!("layers.{i}.attention_norm.weight"), + layer.attention_norm.share(), + ); + + tensors.insert(format!("layers.{i}.attention.wq.weight"), layer.wq.share()); + tensors.insert(format!("layers.{i}.attention.wk.weight"), layer.wk.share()); + tensors.insert(format!("layers.{i}.attention.wv.weight"), layer.wv.share()); + tensors.insert(format!("layers.{i}.attention.wo.weight"), layer.wo.share()); + + tensors.insert( + format!("layers.{i}.ffn_norm.weight"), + layer.ffn_norm.share(), + ); + + tensors.insert( + format!("layers.{i}.feed_forward.w1.weight"), + layer.w1.share(), + ); + tensors.insert( + format!("layers.{i}.feed_forward.w2.weight"), + layer.w2.share(), + ); + tensors.insert( + format!("layers.{i}.feed_forward.w3.weight"), + layer.w3.share(), + ); + + layers.push(layer); + } + + Llama { + hparams, + tok_embeddings, + norm, + output, + layers, + tensors, + _context: context, + } + }; + + // Close the file, but keep its offset. That way we know how to skip the + // metadata when loading the parts. 
+ let file_offset = reader.stream_position()?; + drop(reader); + + let paths = { + let main_filename = main_path.file_name().and_then(|p| p.to_str()); + + let mut paths: Vec = + std::fs::read_dir(main_path.parent().ok_or_else(|| LoadError::NoParentPath { + path: main_path.to_owned(), + })?)? + .filter_map(Result::ok) + .map(|de| de.path()) + .filter(|p| { + p.file_name() + .and_then(|p| p.to_str()) + .zip(main_filename) + .map(|(part_filename, main_filename)| { + part_filename.starts_with(main_filename) + }) + .unwrap_or(false) + }) + .collect(); + paths.sort(); + paths + }; + + let n_parts = paths.len(); + + for (i, part_path) in paths.into_iter().enumerate() { + let part_id = i; + + load_progress_callback(LoadProgress::PartLoading { + file: part_path.clone().into_boxed_path(), + current_part: i + 1, + total_parts: n_parts, + }); + + let mut part_reader = BufReader::new(File::open(&part_path)?); + + // Skip metadata + part_reader.seek(SeekFrom::Start(file_offset))?; + + let mut total_size = 0; + let mut n_tensors = 0; + + // Load weights + loop { + // NOTE: Implementation from #![feature(buf_read_has_data_left)] + let is_eof = part_reader.fill_buf().map(|b| b.is_empty())?; + + if is_eof { + break; + } + + let n_dims = usize::try_from(read_i32(&mut part_reader)?)?; + let length = read_i32(&mut part_reader)?; + let ftype = read_u32(&mut part_reader)?; + + let mut nelements: usize = 1; + let mut ne = [1i64, 1i64]; + + #[allow(clippy::needless_range_loop)] + for i in 0..n_dims { + ne[i] = read_i32(&mut part_reader)? as i64; + nelements *= usize::try_from(ne[i])?; + } + + let tensor_name = read_string(&mut part_reader, length as usize)?; + + let Some(tensor) = model.tensors.get(&tensor_name) + else { + return Err(LoadError::UnknownTensor { tensor_name, path: part_path }); + }; + + // split_type = 0: split by columns + // split_type = 1: split by rows + // + // split_type = 0: + // regex: + // - tok_embeddings.* + // - layers.*.attention.wo.weight + // - layers.*.feed_forward.w2.weight + + // split_type = 1: + // regex: + // - output.* + // - layers.*.attention.wq.weight + // - layers.*.attention.wk.weight + // - layers.*.attention.wv.weight + // - layers.*.feed_forward.w1.weight + // - layers.*.feed_forward.w3.weight + #[allow(clippy::if_same_then_else)] + let split_type = if tensor_name.contains("tok_embeddings") { + 0 + } else if tensor_name.contains("layers") { + if tensor_name.contains("attention.wo.weight") { + 0 + } else if tensor_name.contains("feed_forward.w2.weight") { + 0 + } else { + 1 + } + } else if tensor_name.contains("output") { + 1 + } else { + 0 + }; + + if n_dims == 1 { + if tensor.nelements() != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + } else if tensor.nelements() / n_parts != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + + if n_dims == 1 { + if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] != ne[1] { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + } else if split_type == 0 { + if tensor.get_ne()[0] / i64::try_from(n_parts)? != ne[0] + || tensor.get_ne()[1] != ne[1] + { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + } else if tensor.get_ne()[0] != ne[0] + || tensor.get_ne()[1] / i64::try_from(n_parts)? 
!= ne[1] + { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + + let bpe = match ftype { + 0 => ggml::type_size(ggml::TYPE_F32), + 1 => ggml::type_size(ggml::TYPE_F16), + 2 => { + assert_eq!(ne[0] % 64, 0); + ggml::type_size(ggml::TYPE_Q4_0) + } + 3 => { + assert_eq!(ne[0] % 64, 0); + ggml::type_size(ggml::TYPE_Q4_1) + } + _ => { + return Err(LoadError::InvalidFtype { + ftype, + path: part_path, + }) + } + }; + + if n_dims == 1 || n_parts == 1 { + if (nelements as usize * bpe) / ggml::blck_size(tensor.get_type()) as usize + != tensor.nbytes() + { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + + if part_id == 0 { + // SAFETY: yolo, same as original code + let slice = unsafe { + let data = tensor.data(); + std::slice::from_raw_parts_mut(data as *mut u8, tensor.nbytes()) + }; + part_reader.read_exact(slice)?; + } else { + part_reader.seek(SeekFrom::Current(tensor.nbytes() as i64))?; + } + + total_size += tensor.nbytes(); + } else { + if (nelements as usize * bpe) / ggml::blck_size(tensor.get_type()) as usize + != tensor.nbytes() / n_parts + { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + + if split_type == 0 { + let np0 = ne[0]; + let row_size = (usize::try_from(tensor.get_ne()[0])? + / ggml::blck_size(tensor.get_type())) + * ggml::type_size(tensor.get_type()); + + assert_eq!(row_size, tensor.get_nb()[1]); + + for i1 in 0..ne[1] { + let offset_row = i1 as usize * row_size; + let offset = offset_row + + ((part_id * np0 as usize) + / ggml::blck_size(tensor.get_type()) as usize) + * ggml::type_size(tensor.get_type()); + // SAFETY: yolo, same as original code + unsafe { + let ptr = tensor.data().add(offset); + let slice = std::slice::from_raw_parts_mut( + ptr as *mut u8, + row_size / n_parts, + ); + part_reader.read_exact(slice)?; + } + } + } else { + let np1 = ne[1]; + let row_size = (usize::try_from(tensor.get_ne()[0])? + / ggml::blck_size(tensor.get_type())) + * ggml::type_size(tensor.get_type()); + + for i1 in 0..ne[1] { + let offset_row = (i1 as usize + part_id * np1 as usize) * row_size; + // SAFETY: yolo, same as original code + unsafe { + let ptr = tensor.data().add(offset_row); + let slice = + std::slice::from_raw_parts_mut(ptr as *mut u8, row_size); + part_reader.read_exact(slice)?; + } + } + } + + total_size += tensor.nbytes() / n_parts; + } + + n_tensors += 1; + load_progress_callback(LoadProgress::PartTensorLoaded { + file: part_path.clone().into_boxed_path(), + current_tensor: n_tensors.try_into()?, + tensor_count: model.tensors.len(), + }); + } + + load_progress_callback(LoadProgress::PartLoaded { + file: part_path.into_boxed_path(), + byte_size: total_size, + tensor_count: n_tensors.try_into()?, + }); + } + + Ok((model, vocab)) + } + + fn start_session(&self, params: InferenceSessionParameters) -> InferenceSession { + let Hyperparameters { + n_ctx, + n_embd, + n_layer, + n_vocab, + .. 
+ } = self.hparams; + + let ctx_size = { + let mut ctx_size = 0; + ctx_size += mulf!( + n_ctx, + n_layer, + n_embd, + ggml::type_sizef(params.memory_k_type.into()) + ); // memory_k + ctx_size += mulf!( + n_ctx, + n_layer, + n_embd, + ggml::type_sizef(params.memory_v_type.into()) + ); // memory_v + ctx_size += (5 + 10 * n_layer as u64) * 256; // object overhead + ctx_size + }; + + let session_ctx = ggml::Context::init(ctx_size as usize); + + // Initialize key + value memory tensors + let n_mem = n_layer * n_ctx; + let n_elements = n_embd * n_mem; + let memory_k = session_ctx.new_tensor_1d(params.memory_k_type.into(), n_elements); + let memory_v = session_ctx.new_tensor_1d(params.memory_v_type.into(), n_elements); + + InferenceSession { + _session_ctx: session_ctx, + params, + memory_k, + memory_v, + n_past: 0, + mem_per_token: 0, + tokens: vec![], + last_logits: vec![0.0; n_vocab as usize], + } + } + + /// Evaluates the transformer. + /// + /// The provided `output_request` struct lets you specify which additional + /// data you are interested in fetching from the transformer. Setting a + /// field to a `Some` value will clear and fill the provided vector with + /// data. The provided vector will be resized to the exact output size. + fn evaluate( + &self, + session: &mut InferenceSession, + params: &InferenceParameters, + input_tokens: &[TokenId], + output_request: &mut EvaluateOutputRequest, + ) { + let n = input_tokens.len(); + let n_past = session.n_past; + let n_threads = params.n_threads; + let increased_determinism = params.increased_determinism; + + let Hyperparameters { + n_vocab, + n_ctx, + n_embd, + n_mult: _, + n_head, + n_layer, + n_rot, + f16_: _, + } = self.hparams; + + // For the first run, we need to guess a maximum buffer size so we can measure + // the actual memory consumption of the temporary ggml context. + let mut buf_size = 1024 * 1024 * 1024; + if session.mem_per_token > 0 && session.mem_per_token * n > buf_size { + // add 10% to account for ggml object overhead + buf_size = (1.1f64 * session.mem_per_token as f64 * n as f64) as usize; + }; + let ctx0 = ggml::Context::init(buf_size); + + let mut gf = ggml::ComputationGraph::new(n_threads); + + let embd = ctx0.new_tensor_1d(ggml::TYPE_I32, n); + unsafe { embd.write_data(bytemuck::cast_slice(input_tokens)) }; + + let mut input_layer = ctx0.op_get_rows(&self.tok_embeddings, &embd); + + // Defined here to avoid repetition and creating a binding inside nested loops. + // See the call site below for more context. 
+        let vtrans_fun = |il: usize| -> ggml::Tensor {
+            ctx0.op_permute(
+                &ctx0.op_reshape_3d(
+                    &ctx0.op_view_1d(
+                        &session.memory_v,
+                        (n_past + n) * n_embd,
+                        il * n_ctx as usize * session.memory_v.element_size() * n_embd as usize,
+                    ),
+                    n_embd / n_head,
+                    n_head,
+                    n_past + n,
+                ),
+                1,
+                2,
+                0,
+                3,
+            )
+        };
+
+        for il in 0..n_layer as usize {
+            let input_self_attention = input_layer.share();
+            let mut current: ggml::Tensor;
+
+            // norm
+            {
+                current = ctx0.op_norm(&input_layer);
+
+                // cur = attention_norm * cur
+                current = ctx0.op_mul(
+                    &ctx0.op_repeat(&self.layers[il].attention_norm, &current),
+                    &current,
+                );
+            }
+
+            // self-attention
+            {
+                let q_current = ctx0.op_mul_mat(&self.layers[il].wq, &current);
+                let k_current = ctx0.op_mul_mat(&self.layers[il].wk, &current);
+                let v_current = ctx0.op_mul_mat(&self.layers[il].wv, &current);
+
+                // store key and value to memory
+                if n >= 1 {
+                    let k = ctx0.op_view_1d(
+                        &session.memory_k,
+                        n * n_embd,
+                        (session.memory_k.element_size() * n_embd as usize)
+                            * (il * n_ctx as usize + n_past as usize),
+                    );
+
+                    let v = ctx0.op_view_1d(
+                        &session.memory_v,
+                        n * n_embd,
+                        (session.memory_v.element_size() * n_embd as usize)
+                            * (il * n_ctx as usize + n_past as usize),
+                    );
+
+                    gf.build_forward_expand(&ctx0.op_cpy(&k_current, &k));
+                    gf.build_forward_expand(&ctx0.op_cpy(&v_current, &v));
+                }
+
+                // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
+                let q = ctx0.op_permute(
+                    &ctx0.op_rope(
+                        &ctx0.op_cpy(
+                            &q_current,
+                            &ctx0.new_tensor_3d(ggml::TYPE_F32, n_embd / n_head, n_head, n),
+                        ),
+                        n_past,
+                        n_rot,
+                        0,
+                    ),
+                    0,
+                    2,
+                    1,
+                    3,
+                );
+
+                // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
+                let k = ctx0.op_permute(
+                    &ctx0.op_rope(
+                        &ctx0.op_reshape_3d(
+                            &ctx0.op_view_1d(
+                                &session.memory_k,
+                                (n_past + n) * n_embd,
+                                il * n_ctx as usize
+                                    * session.memory_k.element_size()
+                                    * n_embd as usize,
+                            ),
+                            n_embd / n_head,
+                            n_head,
+                            n_past + n,
+                        ),
+                        n_past,
+                        n_rot,
+                        1,
+                    ),
+                    0,
+                    2,
+                    1,
+                    3,
+                );
+
+                // K * Q
+                let k_q = ctx0.op_mul_mat(&k, &q);
+
+                // KQ_scaled = KQ / sqrt(n_embd/n_head)
+                let k_q_scaled = ctx0.op_scale(
+                    &k_q,
+                    &ctx0.new_f32(1.0 / f32::sqrt(n_embd as f32 / n_head as f32)),
+                );
+
+                // KQ_masked = mask_past(KQ_scaled)
+                let k_q_masked = ctx0.op_diag_mask_inf(&k_q_scaled, n_past);
+
+                // KQ = soft_max(KQ_masked)
+                let k_q_soft_max = ctx0.op_soft_max(&k_q_masked);
+
+                // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
+                let v_transposed = {
+                    if !increased_determinism {
+                        vtrans_fun(il)
+                    } else {
+                        ctx0.op_cpy(
+                            &vtrans_fun(il),
+                            &ctx0.new_tensor_3d(
+                                ggml::TYPE_F32,
+                                n_past + n,
+                                n_embd / n_head,
+                                n_head,
+                            ),
+                        )
+                    }
+                };
+
+                // KQV = transpose(V) * KQ_soft_max
+                let k_q_v = ctx0.op_mul_mat(&v_transposed, &k_q_soft_max);
+
+                // KQV_merged = KQV.permute(0, 2, 1, 3)
+                let k_q_v_merged = ctx0.op_permute(&k_q_v, 0, 2, 1, 3);
+
+                // cur = KQV_merged.contiguous().view(n_embd, N)
+                current = ctx0.op_cpy(
+                    &k_q_v_merged,
+                    &ctx0.new_tensor_2d(ggml::TYPE_F32, n_embd, n),
+                );
+
+                // projection (no bias)
+                current = ctx0.op_mul_mat(&self.layers[il].wo, &current);
+            }
+
+            let input_feed_forward = ctx0.op_add(&current, &input_self_attention);
+
+            // feed-forward network
+            {
+                // norm
+                {
+                    current = ctx0.op_norm(&input_feed_forward);
+
+                    // cur = ffn_norm*cur
+                    current = ctx0.op_mul(
+                        &ctx0.op_repeat(&self.layers[il].ffn_norm, &current),
+                        &current,
+                    );
+                }
+
+                let tmp = ctx0.op_mul_mat(&self.layers[il].w3, &current);
+
+                current = ctx0.op_mul_mat(&self.layers[il].w1, &current);
+
+                // SILU activation
+                current = ctx0.op_silu(&current);
+
+                current = ctx0.op_mul(&current, &tmp);
+
+                current = ctx0.op_mul_mat(&self.layers[il].w2, &current);
+            }
+
+            current = ctx0.op_add(&current, &input_feed_forward);
+
+            // input for next layer
+            input_layer = current;
+        }
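For reference, the feed-forward block assembled in the loop above computes `w2 · (silu(w1 · x) ⊙ (w3 · x))`, i.e. a SwiGLU-style gated MLP: `w3` produces the gate input (`tmp`), `w1` feeds the SiLU, and the two are multiplied elementwise before the down-projection. A plain-Rust sketch of that arithmetic on dense row-major matrices (illustrative only, not the ggml execution path):

// silu(v) = v * sigmoid(v) = v / (1 + e^(-v))
fn silu(v: f32) -> f32 {
    v / (1.0 + (-v).exp())
}

// Dense matrix-vector product; `m` is row-major, each row is one output feature.
fn matvec(m: &[Vec<f32>], x: &[f32]) -> Vec<f32> {
    m.iter()
        .map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum())
        .collect()
}

// out = w2 · (silu(w1 · x) ⊙ (w3 · x))
fn feed_forward(w1: &[Vec<f32>], w2: &[Vec<f32>], w3: &[Vec<f32>], x: &[f32]) -> Vec<f32> {
    let gate: Vec<f32> = matvec(w1, x).into_iter().map(silu).collect();
    let up = matvec(w3, x);
    let hidden: Vec<f32> = gate.iter().zip(&up).map(|(g, u)| g * u).collect();
    matvec(w2, &hidden)
}

This mirrors the ggml graph: `tmp` plays the role of `matvec(w3, x)`, and `current` carries `silu(w1 · x)` going into the elementwise `op_mul`.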
+
+        // Used at the end to optionally extract the embeddings.
+        let embeddings_tensor;
+
+        // norm
+        {
+            input_layer = ctx0.op_norm(&input_layer);
+
+            // inpL = norm*inpL
+            input_layer = ctx0.op_mul(&ctx0.op_repeat(&self.norm, &input_layer), &input_layer);
+            embeddings_tensor = input_layer.share();
+        }
+
+        // lm_head
+        {
+            input_layer = ctx0.op_mul_mat(&self.output, &input_layer);
+        }
+
+        // logits -> probs
+        // inpL = ctx0.op_soft_max(&inpL);
+
+        // run the computation
+        gf.build_forward_expand(&input_layer);
+        ctx0.graph_compute(&mut gf);
+
+        // return result for just the last token
+        // SAFETY: yolo
+        assert_eq!(session.last_logits.len(), n_vocab as usize);
+        unsafe {
+            input_layer.read_data(
+                n_vocab as usize * (n - 1) * std::mem::size_of::<f32>(),
+                bytemuck::cast_slice_mut(&mut session.last_logits),
+            )
+        };
+
+        // Extract logits
+        if let Some(all_logits) = &mut output_request.all_logits {
+            all_logits.resize(n_vocab as usize * n, 0.0);
+            // SAFETY: Tensor data can be read (properly aligned, initialized,
+            // data will not be mutated or otherwise aliased during the copy),
+            // and we're not reading past the end of the tensor data.
+            assert_eq!(input_layer.nelements(), n_vocab * n);
+            unsafe {
+                input_layer.read_data(0, bytemuck::cast_slice_mut(all_logits));
+            }
+        }
+
+        // Extract embeddings
+        if let Some(embeddings) = &mut output_request.embeddings {
+            embeddings.resize(n_embd as usize * n, 0.0);
+            // SAFETY: Same rationale as for the "Extract logits" section applies.
+            assert_eq!(embeddings_tensor.nelements(), n_embd * n);
+            unsafe {
+                embeddings_tensor.read_data(0, bytemuck::cast_slice_mut(embeddings));
+            }
+        }
+
+        // Adjust the required memory per token if we didn't know that already
+        if session.mem_per_token == 0 {
+            session.mem_per_token = ctx0.used_mem() / n;
+        }
+
+        // Adjust n_past to new length.
+        session.n_past += input_tokens.len();
+    }
+}
diff --git a/llama-rs/src/models/mod.rs b/llama-rs/src/models/mod.rs
new file mode 100644
index 00000000..9793a8a1
--- /dev/null
+++ b/llama-rs/src/models/mod.rs
@@ -0,0 +1,2 @@
+pub mod bloom;
+pub mod llama;
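One last note on the key/value cache layout shared by both `evaluate` implementations: `memory_k` and `memory_v` are flat tensors of `n_layer * n_ctx * n_embd` elements (matching the `ctx_size` accounting in `start_session`), each layer owns one contiguous `n_ctx * n_embd` region, and new tokens are appended at position `n_past` within that region. The byte offset handed to `op_view_1d` reduces to the following arithmetic (a restatement for clarity, using hypothetical helper names):

// Byte offset of the first cache element written for this step's tokens:
// skip `il` whole layers, then `n_past` already-cached positions, each
// position being `n_embd` elements of `element_size` bytes.
fn kv_write_offset_bytes(
    element_size: usize, // bytes per cache element (f16 or f32)
    n_embd: usize,
    n_ctx: usize,
    il: usize,     // layer index
    n_past: usize, // tokens already in the cache
) -> usize {
    (element_size * n_embd) * (il * n_ctx + n_past)
}

// Number of elements covered by the view when writing `n` new tokens.
fn kv_write_len_elements(n: usize, n_embd: usize) -> usize {
    n * n_embd
}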