This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Copy v_transposed like llama.cpp (#68)
KerfuffleV2 authored Mar 26, 2023
1 parent e7e7e8a commit b103dcd
Showing 2 changed files with 45 additions and 19 deletions.
1 change: 1 addition & 0 deletions llama-cli/src/main.rs
@@ -112,6 +112,7 @@ fn main() {
             }
         }),
         play_back_previous_tokens: false,
+        ..Default::default()
     };
     let inference_session_params = {
         let mem_typ = if args.float16 {
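Usage note (not part of the diff): after this change, llama-cli names only the fields it overrides and lets `..Default::default()` supply the rest, including the new `increased_determinism` flag added in the library below. A minimal sketch, assuming `llama_rs` exports `InferenceParameters` and `TokenBias` from its crate root; the field values themselves are illustrative only.

    // Sketch only: assumes these types are public crate-root exports of `llama_rs`.
    use llama_rs::{InferenceParameters, TokenBias};

    fn main() {
        let inference_params = InferenceParameters {
            n_threads: 4, // illustrative value
            temp: 0.80,
            bias_tokens: TokenBias::default(),
            play_back_previous_tokens: false,
            // Every field not listed here, including the new `increased_determinism`,
            // is filled in from Default, which turns the flag on.
            ..Default::default()
        };
        assert!(inference_params.increased_determinism);
    }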
63 changes: 44 additions & 19 deletions llama-rs/src/lib.rs
@@ -174,6 +174,7 @@ impl Default for InferenceSessionParameters {
     }
 }

+#[derive(Clone, Debug, PartialEq)]
 /// The parameters that drive text generation.
 pub struct InferenceParameters {
     pub n_threads: i32,
@@ -184,6 +185,7 @@ pub struct InferenceParameters {
     pub temp: f32,
     pub bias_tokens: TokenBias,
     pub play_back_previous_tokens: bool,
+    pub increased_determinism: bool,
 }

 impl Default for InferenceParameters {
@@ -197,6 +199,7 @@ impl Default for InferenceParameters {
             temp: 0.80,
             bias_tokens: TokenBias::default(),
             play_back_previous_tokens: false,
+            increased_determinism: true,
         }
     }
 }
@@ -1094,11 +1097,13 @@ impl Model {
     pub fn evaluate(
         &self,
         session: &mut InferenceSession,
-        n_threads: i32,
+        params: &InferenceParameters,
         input_tokens: &[TokenId],
     ) {
         let n = input_tokens.len();
         let n_past = session.n_past as i32;
+        let n_threads = params.n_threads;
+        let increased_determinism = params.increased_determinism;

         let Hyperparameters {
             n_vocab,
@@ -1127,6 +1132,27 @@

         let mut input_layer = ctx0.op_get_rows(&self.tok_embeddings, &embd);

+        // Defined here to avoid repetition and creating a binding inside nested loops.
+        // See the call site below for more context.
+        let vtrans_fun = |il: usize| -> ggml::Tensor {
+            ctx0.op_permute(
+                &ctx0.op_reshape_3d(
+                    &ctx0.op_view_1d(
+                        &session.memory_v,
+                        (n_past + n as i32) * n_embd,
+                        il * n_ctx as usize * session.memory_v.element_size() * n_embd as usize,
+                    ),
+                    n_embd / n_head,
+                    n_head,
+                    n_past + n as i32,
+                ),
+                1,
+                2,
+                0,
+                3,
+            )
+        };
+
         for il in 0..n_layer as usize {
             let input_self_attention = input_layer.share();
             let mut current: ggml::Tensor;
@@ -1226,22 +1252,21 @@
             let k_q_soft_max = ctx0.op_soft_max(&k_q_masked);

             // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
-            let v_transposed = ctx0.op_permute(
-                &ctx0.op_reshape_3d(
-                    &ctx0.op_view_1d(
-                        &session.memory_v,
-                        (n_past + n as i32) * n_embd,
-                        il * n_ctx as usize * session.memory_v.element_size() * n_embd as usize,
-                    ),
-                    n_embd / n_head,
-                    n_head,
-                    n_past + n as i32,
-                ),
-                1,
-                2,
-                0,
-                3,
-            );
+            let v_transposed = {
+                if !increased_determinism {
+                    vtrans_fun(il)
+                } else {
+                    ctx0.op_cpy(
+                        &vtrans_fun(il),
+                        &ctx0.new_tensor_3d(
+                            ggml::TYPE_F32,
+                            n_past + n as i32,
+                            n_embd / n_head,
+                            n_head,
+                        ),
+                    )
+                }
+            };

             // KQV = transpose(V) * KQ_soft_max
             let k_q_v = ctx0.op_mul_mat(&v_transposed, &k_q_soft_max);
@@ -1393,7 +1418,7 @@ impl InferenceSession {
         }

         for batch in prompt_tokens.chunks(8) {
-            model.evaluate(self, params.n_threads, batch);
+            model.evaluate(self, params, batch);
             for &tk in batch {
                 // NOTE: No string ever tokenizes to the end of sentence. So we
                 // can just return the id here.
@@ -1427,7 +1452,7 @@
         self.tokens.push(next_token);

         // Then, evaluate the network again to compute the new last_logits
-        model.evaluate(self, params.n_threads, &[next_token]);
+        model.evaluate(self, params, &[next_token]);

         // Return the next token
         Ok(if next_token as TokenId == EOD_TOKEN_ID {
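Editor's sketch (not part of this commit) of the updated call pattern: `evaluate` now takes the whole `InferenceParameters` struct instead of a bare `n_threads`, and the extra `op_cpy` of `v_transposed` can be skipped by turning `increased_determinism` off. The `run_prompt` wrapper and the assumption that `Model`, `InferenceSession`, and `TokenId` are public crate-root exports are illustrative; only the `evaluate` signature comes from the diff above.

    // Sketch only: `Model`, `InferenceSession`, and `TokenId` are assumed to be
    // public exports of the `llama_rs` crate.
    use llama_rs::{InferenceParameters, InferenceSession, Model, TokenId};

    fn run_prompt(model: &Model, session: &mut InferenceSession, tokens: &[TokenId]) {
        let params = InferenceParameters {
            // `false` restores the old behaviour: v_transposed stays a strided view
            // of memory_v instead of first being copied into a contiguous f32 tensor.
            increased_determinism: false,
            ..Default::default()
        };
        // New signature: the parameter struct replaces the old bare `n_threads: i32`.
        model.evaluate(session, &params, tokens);
    }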
