This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Embedding extraction #72

Merged · 8 commits · Mar 26, 2023
81 changes: 81 additions & 0 deletions llama-cli/src/main.rs
@@ -1,6 +1,7 @@
use std::{convert::Infallible, io::Write};

use cli_args::CLI_ARGS;
use llama_rs::EvaluateOutputRequest;
use llama_rs::{
    InferenceError, InferenceParameters, InferenceSessionParameters, InferenceSnapshot,
    ModelKVMemoryType, TokenBias, Vocabulary, EOD_TOKEN_ID,
@@ -98,6 +99,8 @@ fn main() {

    let args = &*CLI_ARGS;

    ad_hoc_test();

    let inference_params = InferenceParameters {
        n_threads: args.num_threads as i32,
        n_batch: args.batch_size,
@@ -314,3 +317,81 @@ fn main() {
        }
    }
}

fn ad_hoc_test() {
    let (mut model, vocab) = llama_rs::Model::load(&CLI_ARGS.model_path, 2048, |_| {}).unwrap();
    let mut session = model.start_session(InferenceSessionParameters {
        last_n_size: 64,
        memory_k_type: ModelKVMemoryType::Float32,
        memory_v_type: ModelKVMemoryType::Float32,
    });

    let inference_params = InferenceParameters {
        n_threads: 8,
        n_batch: 8,
        top_k: 40,
        top_p: 0.95,
        repeat_penalty: 1.30,
        temp: 0.80,
        bias_tokens: TokenBias::default(),
    };

    session.feed_prompt::<Infallible>(
        &model,
        &vocab,
        &inference_params,
        "My favourite animal is the ",
        |_| Ok(()),
    );

    let mut output_request = EvaluateOutputRequest {
        all_logits: None,
        embeddings: Some(Vec::new()),
    };

    let dog = model.tokenize(&vocab, "dog", false).unwrap();
    model.evaluate(&mut session, 8, &dog, &mut output_request);

    println!("Embeddings for 'dog' (1):");
    for x in output_request.embeddings.as_ref().unwrap() {
        // print!("{x:.2} ")
    }
    println!();

    let mut session2 = model.start_session(InferenceSessionParameters {
        last_n_size: 64,
        memory_k_type: ModelKVMemoryType::Float32,
        memory_v_type: ModelKVMemoryType::Float32,
    });
    session2.feed_prompt::<Infallible>(
        &model,
        &vocab,
        &inference_params,
        "I have just adopted a cute ",
        |_| Ok(()),
    );
    let mut output_request2 = EvaluateOutputRequest {
        all_logits: None,
        embeddings: Some(Vec::new()),
    };

    // Try other words: 'dog', 'cat', 'potato', '$' -> to see progressively lower dot product values.
    let dog2 = model.tokenize(&vocab, "dog", false).unwrap();
setzer22 (Collaborator, Author) commented:

What I'm doing here is feeding the following two sentences through the transformer:

  • "My favourite animal is the dog"
  • "I just adopted a cute dog"

Afterwards, I retrieve the embeddings for the last token (dog), and compute their similarity with a simple dot product.

Then, I tried changing the second sentence from 'dog' to 'cat', 'potato', '$' respectively, and the semantic similarity dropped accordingly, with $ ranking the lowest.
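
Side note: normalizing the two vectors turns the raw dot product into cosine similarity, which is easier to compare across prompts whose embeddings differ in magnitude. A minimal sketch of such a helper (not part of this PR; `a` and `b` stand for the two embedding vectors extracted above):

fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Dot product of the two embedding vectors.
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    // Euclidean norm of each vector.
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (norm_a * norm_b)
}

With unit-length vectors the raw dot product and cosine similarity coincide.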

hlhr202 (Contributor) commented:

@setzer22 Will feeding the prompt before eval give different embeddings compared to evaluating all the tokens together?

setzer22 (Collaborator, Author) commented on Mar 26, 2023:

@hlhr202 The embeddings wouldn't be affected, but you shouldn't call evaluate with the whole prompt like that, for a couple of reasons:

  • A call to evaluate runs all the tokens you give it as a single batch, which increases memory usage. For very long prompts, this could become very expensive.
  • The output will contain the embeddings for every token that you fed through eval, so you would be retrieving a lot more embedding data than for just the word "dog".

This is why the test code uses feed_prompt first, to set up the context, and then calls evaluate with a single token to retrieve the embeddings for a single word.
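
A rough sketch of that pattern as a reusable helper, assuming the API exactly as it appears in this diff (the helper itself, its name, and its signature are illustrative, not part of the PR):

use std::convert::Infallible;

use llama_rs::{
    EvaluateOutputRequest, InferenceParameters, InferenceSessionParameters, Model,
    ModelKVMemoryType, Vocabulary,
};

/// Feeds `context` cheaply via `feed_prompt`, then evaluates a single word so
/// the returned embeddings cover only that word's token(s).
fn embedding_for_word(
    model: &mut Model,
    vocab: &Vocabulary,
    params: &InferenceParameters,
    context: &str,
    word: &str,
) -> Vec<f32> {
    let mut session = model.start_session(InferenceSessionParameters {
        last_n_size: 64,
        memory_k_type: ModelKVMemoryType::Float32,
        memory_v_type: ModelKVMemoryType::Float32,
    });

    // Build up the session state in small batches; no embeddings are collected here.
    // Error handling is omitted in this sketch.
    let _ = session.feed_prompt::<Infallible>(model, vocab, params, context, |_| Ok(()));

    // Evaluate just the word of interest, requesting embeddings for it.
    let tokens = model.tokenize(vocab, word, false).unwrap();
    let mut output = EvaluateOutputRequest {
        all_logits: None,
        embeddings: Some(Vec::new()),
    };
    model.evaluate(&mut session, params.n_threads, &tokens, &mut output);

    output.embeddings.unwrap()
}

Called twice with the same word in different contexts, its outputs can be compared with a dot product exactly as in the test above.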

hlhr202 (Contributor) commented:

@setzer22 I think I understand your comments now. This means we can only extract embeddings for part of the words (which will also carry hidden information mixed in from the context of the whole sentence). That is a bit different from OpenAI's embedding function: as I understand it, OpenAI's embedding is for the whole sentence, yet it is still returned as a fixed-size tensor... though that is rather beyond my knowledge.

hlhr202 (Contributor) commented:

Well, I guess I might find a possible way to implement such a 'sentence embedding': I will try adding a special end token and extracting the hidden layer once the end token has been evaluated. Not sure whether it works, but it should be worth a try.
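
Another option that only needs what this PR already exposes would be to mean-pool the per-token embeddings into a single fixed-size vector. A minimal sketch under that assumption (the pooling helper is illustrative, not part of the PR; it assumes `per_token` holds contiguous blocks of n_embd values per token, matching the `n_batch * n_embd` shape documented on `EvaluateOutputRequest`):

/// Averages per-token embeddings into one n_embd-sized "sentence embedding".
/// `per_token` is expected to have length n_tokens * n_embd, token-major.
fn mean_pooled_embedding(per_token: &[f32], n_embd: usize) -> Vec<f32> {
    let n_tokens = per_token.len() / n_embd;
    assert!(n_tokens > 0, "need at least one token's embedding");

    let mut pooled = vec![0.0f32; n_embd];
    for token_embedding in per_token.chunks(n_embd) {
        for (acc, x) in pooled.iter_mut().zip(token_embedding) {
            *acc += x;
        }
    }
    for x in &mut pooled {
        *x /= n_tokens as f32;
    }
    pooled
}

Whether simple mean pooling gets close to purpose-built sentence-embedding models is an open question, but it does yield a fixed-size vector without changing the model.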

    model.evaluate(&mut session2, 8, &dog2, &mut output_request2);

    // Compute the dot product between embeddings from output_request and output_request2
    let mut dot_product = 0.0;
    for (x, y) in output_request
        .embeddings
        .as_ref()
        .unwrap()
        .iter()
        .zip(output_request2.embeddings.as_ref().unwrap())
    {
        dot_product += x * y;
    }

    println!("Dot product {dot_product}");

    std::process::exit(0);
}
59 changes: 56 additions & 3 deletions llama-rs/src/lib.rs
@@ -1,5 +1,6 @@
mod ggml;

use core::slice;
use std::{
    collections::{HashMap, VecDeque},
    fmt::Display,
@@ -396,6 +397,17 @@ pub enum InferenceError {
    UserCallback(Box<dyn std::error::Error>),
}

/// Used in a call to `evaluate` to request information from the transformer.
#[derive(Default)]
pub struct EvaluateOutputRequest {
    /// Returns all the logits for the provided batch of tokens.
    /// Output shape is n_batch * n_vocab.
    pub all_logits: Option<Vec<f32>>,
    /// Returns the embeddings for the provided batch of tokens.
    /// Output shape is n_batch * n_embd.
    pub embeddings: Option<Vec<f32>>,
}

/// NOTE: The original code relies on promotion rules and automatic casting from
/// int to float. What we do instead is use this macro to convert every term of
/// the multiplication to f64, which should have enough precision bits to hold
@@ -1065,11 +1077,15 @@ impl Model {
    }

    /// Evaluates the transformer.
    ///
    /// The provided `output_request` struct lets you specify which additional
    /// data you are interested in fetching from the transformer. Setting a
    /// field to a `Some` value will fill the corresponding `Vec` with the
    /// requested data.
    pub fn evaluate(
        &self,
        session: &mut InferenceSession,
        n_threads: i32,
        input_tokens: &[TokenId],
        output_request: &mut EvaluateOutputRequest,
    ) {
        let n = input_tokens.len();
        let n_past = session.n_past as i32;
@@ -1266,12 +1282,16 @@ impl Model {
            input_layer = current;
        }

        // Used at the end to optionally extract the embeddings.
        let embeddings_tensor;

        // norm
        {
            input_layer = ctx0.op_norm(&input_layer);

            // inpL = norm*inpL
            input_layer = ctx0.op_mul(&ctx0.op_repeat(&self.norm, &input_layer), &input_layer);
            embeddings_tensor = input_layer.share();
        }

        // lm_head
@@ -1296,6 +1316,30 @@
)
};

        // Extract logits
        if let Some(all_logits) = &mut output_request.all_logits {
            all_logits.resize(n_vocab as usize * n, 0.0);
            // SAFETY: Data can be read (properly aligned, initialized, data
            // will not be mutated or otherwise aliased while the slice lives),
            // and we're not reading past the end of the slice.
            assert_eq!(input_layer.nelements(), n_vocab * n as i32);
            let logits_slice = unsafe {
                slice::from_raw_parts(input_layer.data() as *const f32, n_vocab as usize * n)
            };
            all_logits.copy_from_slice(logits_slice);
        }

        // Extract embeddings
        if let Some(embeddings) = &mut output_request.embeddings {
            embeddings.resize(n_embd as usize * n, 0.0);
            // SAFETY: Same rationale as for the "Extract logits" section applies.
            assert_eq!(embeddings_tensor.nelements(), n_embd * n as i32);
            let embeddings_slice = unsafe {
                slice::from_raw_parts(embeddings_tensor.data() as *const f32, n_embd as usize * n)
            };
            embeddings.copy_from_slice(embeddings_slice);
        }

        // Adjust the required memory per token if we didn't know that already
        if session.mem_per_token == 0 {
            session.mem_per_token = ctx0.used_mem() / n;
@@ -1370,7 +1414,12 @@ impl InferenceSession {
        }

        for batch in prompt_tokens.chunks(8) {
            model.evaluate(self, params.n_threads, batch);
            model.evaluate(
                self,
                params.n_threads,
                batch,
                &mut EvaluateOutputRequest::default(),
            );
            for &tk in batch {
                // NOTE: No string ever tokenizes to the end of sentence. So we
                // can just return the id here.
@@ -1404,7 +1453,12 @@ impl InferenceSession {
        self.last_n_tokens.push_front(next_token);

        // Then, evaluate the network again to compute the new last_logits
        model.evaluate(self, params.n_threads, &[next_token]);
        model.evaluate(
            self,
            params.n_threads,
            &[next_token],
            &mut EvaluateOutputRequest::default(),
        );

        // Return the next token
        Ok(if next_token as TokenId == EOD_TOKEN_ID {
@@ -1471,7 +1525,6 @@ impl InferenceSession {
    /// ggml context. While the provided `InferenceSnapshotRef` object is alive,
    /// no other methods for this model object should be called.
    pub unsafe fn get_snapshot(&mut self) -> InferenceSnapshotRef<'_> {
        use core::slice;
        let memory_k = unsafe {
            slice::from_raw_parts(self.memory_k.data() as *mut u8, self.memory_k.nbytes())
        };