Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Commit

Permalink
[llama-loader] Support non-copy loader
Browse files: browse the repository at this point in the history
  • Loading branch information
iacore committed Apr 8, 2023
1 parent 8390593 commit 15fe19b
Showing 1 changed file with 33 additions and 19 deletions.
52 changes: 33 additions & 19 deletions llama-loader/src/lib.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -90,23 +90,31 @@ pub struct TensorInfo {
pub dims: [usize; 2],
pub n_elements: usize,
pub ftype: ElementType,
/// start of tensor - start of file
pub start_offset: u64,
}

#[allow(unused_variables)]
pub trait LoadHandler<T> {
fn cb_container_type(&mut self, model_type: ContainerType) -> ControlFlow<T> {
fn got_container_type(&mut self, model_type: ContainerType) -> ControlFlow<T> {
ControlFlow::Continue(())
}

fn cb_hyper_parameters(&mut self, hparams: LlamaHyperparameters) -> ControlFlow<T> {
fn got_hyper_parameters(&mut self, hparams: LlamaHyperparameters) -> ControlFlow<T> {
ControlFlow::Continue(())
}

fn cb_vocab_token(&mut self, i: usize, token: Vec<u8>, score: f32) -> ControlFlow<T> {
fn got_vocab_token(&mut self, i: usize, token: Vec<u8>, score: f32) -> ControlFlow<T> {
ControlFlow::Continue(())
}

fn tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow<T, &mut [u8]>;
/// # Returns
///
/// `None` to skip copying
/// `Some(buf)` to provide a buffer for copying weights into
fn get_tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow<T, Option<&mut [u8]>> {
ControlFlow::Continue(None)
}
}

fn retchk<A, B>(model_type: ControlFlow<A, B>) -> Result<B, LoadError<A>> {
Expand All @@ -127,7 +135,7 @@ pub fn load_model_from_reader<T>(
ggml::FILE_MAGIC_UNVERSIONED => ContainerType::GGML,
magic => return Err(LoadError::InvalidMagic(magic)),
};
retchk(handler.cb_container_type(container_type))?;
retchk(handler.got_container_type(container_type))?;

// Load format version
match container_type {
Expand All @@ -154,7 +162,7 @@ pub fn load_model_from_reader<T>(
tensor_element_type: decode_element_type_res(read_i32(&mut reader)?)?,
};
let n_vocab = hparams.n_vocab;
retchk(handler.cb_hyper_parameters(hparams))?;
retchk(handler.got_hyper_parameters(hparams))?;

// Load vocabulary
for i in 0..n_vocab {
Expand All @@ -167,7 +175,7 @@ pub fn load_model_from_reader<T>(
0.
}
};
retchk(handler.cb_vocab_token(i, token, token_score))?;
retchk(handler.got_vocab_token(i, token, token_score))?;
}

// Load tensor data
Expand Down Expand Up @@ -224,28 +232,34 @@ fn load_weights_ggjt<T>(
_ => {}
}

// load tensor weights
let offset_curr = reader.stream_position()?;
let offset_aligned: u64 = (offset_curr + 31) & !31;

let tensor_info = TensorInfo {
name,
dims,
n_dims,
n_elements,
ftype,
start_offset: offset_aligned
};

// load tensor weights
let offset_curr = reader.stream_position()?;
let offset_aligned: u64 = (offset_curr + 31) & !31;
reader.seek(SeekFrom::Start(offset_aligned))?;


let type_size = ggml::type_size(ftype);
let buf = retchk(handler.tensor_buffer(tensor_info))?;
let buf_len = buf.len();
if !(buf_len == type_size * n_elements) {
return Err(LoadError::InvariantBroken(format!(
"{buf_len} == {type_size} * {n_elements}"
)));
if let Some(buf) = retchk(handler.get_tensor_buffer(tensor_info))? {
reader.seek(SeekFrom::Start(offset_aligned))?;
let buf_len = buf.len();
if !(buf_len == type_size * n_elements) {
return Err(LoadError::InvariantBroken(format!(
"{buf_len} == {type_size} * {n_elements}"
)));
}
reader.read_exact(buf)?;
} else {
// skip if no buffer is given
reader.seek(SeekFrom::Start(offset_aligned + (type_size * n_elements) as u64))?;
}
reader.read_exact(buf)?;
}

Ok(())
Expand Down

0 comments on commit 15fe19b

Please sign in to comment.