Skip to content

Commit

Permalink
feat(xtask): 添加 convert 命令,支持转换数据类型
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Aug 15, 2024
1 parent 27c325b commit 8fcde25
Show file tree
Hide file tree
Showing 15 changed files with 291 additions and 204 deletions.
1 change: 1 addition & 0 deletions .cargo/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ show = "xtask show"
split = "xtask split"
merge = "xtask merge"
filter = "xtask filter"
convert = "xtask convert"
63 changes: 63 additions & 0 deletions xtask/src/convert.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use crate::utils::{operate, show_file_info, Operator, OutputConfig, Shards};
use std::path::PathBuf;

// Command-line arguments of the `convert` subcommand.
//
// NOTE: with clap's derive API the `///` comments on the fields below are
// rendered as the user-visible `--help` text, so they must describe `convert`
// (they previously carried wording copied from the `split` command).
#[derive(Args, Default)]
pub struct ConvertArgs {
    /// File to convert
    file: PathBuf,
    /// Output directory for converted shards
    #[clap(long, short)]
    output_dir: Option<PathBuf>,
    // Each step is matched by prefix in `convert`: `filter-meta:`,
    // `filter-tensor:` or `cast:`; anything else makes the command panic.
    /// Operations to apply, separated by "->"
    #[clap(long)]
    ops: String,
    /// Max count of tensors per shard
    #[clap(long, short = 't')]
    max_tensors: Option<usize>,
    // Kept as a string here; `convert` parses it into the size type expected
    // by `OutputConfig::shard_max_file_size`.
    /// Max size in bytes per shard
    #[clap(long, short = 's')]
    max_bytes: Option<String>,
    /// If set, the first shard will not contain any tensor
    #[clap(long, short)]
    no_tensor_first: bool,
}

impl ConvertArgs {
    /// Runs the `convert` subcommand.
    ///
    /// Resolves the shards belonging to `file`, parses `ops` into a pipeline of
    /// operators, hands both to `operate` together with the output settings,
    /// and finally passes the resulting files to `show_file_info`.
    ///
    /// # Panics
    ///
    /// Panics if an operation in `ops` has an unknown prefix, if the current
    /// directory cannot be determined when no output directory is given, if
    /// `max_bytes` cannot be parsed, or if `operate` returns an error.
    pub fn convert(self) {
        let Self {
            file,
            output_dir,
            ops,
            max_tensors,
            max_bytes,
            no_tensor_first,
        } = self;

        let shards = Shards::from(file.as_path());

        // Inputs: every shard file discovered for the given path.
        let inputs = shards.iter_all();

        // Pipeline: one operator per "->"-separated step, selected by the text
        // in front of the first ':' (none of the prefixes contains a colon).
        let pipeline = ops.split("->").map(|step| {
            let step = step.trim();
            match step.split_once(':') {
                Some(("filter-meta", content)) => Operator::filter_meta_key(content),
                Some(("filter-tensor", content)) => Operator::filter_tensor_name(content),
                Some(("cast", content)) => Operator::cast(content),
                _ => panic!("Unsupported operation: {}", step),
            }
        });

        // Output: fall back to the working directory and to "no limit"
        // defaults for any option the user did not supply.
        let output = OutputConfig {
            dir: output_dir.unwrap_or_else(|| std::env::current_dir().unwrap()),
            name: shards.name.into(),
            shard_max_tensor_count: max_tensors.unwrap_or(usize::MAX),
            shard_max_file_size: max_bytes.map_or(Default::default(), |s| s.parse().unwrap()),
            shard_no_tensor_first: no_tensor_first,
        };

        let files = operate(inputs, pipeline, output).unwrap();

        show_file_info(&files);
    }
}
125 changes: 0 additions & 125 deletions xtask/src/convert/mod.rs

This file was deleted.

27 changes: 13 additions & 14 deletions xtask/src/filter.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
use crate::{
convert::{ConvertArgs, Operator},
file_info::show_file_info,
};
use crate::utils::{operate, show_file_info, Operator, OutputConfig};
use std::path::PathBuf;

#[derive(Args, Default)]
Expand All @@ -28,20 +25,22 @@ impl FilterArgs {
filter_tensor,
} = self;

let files = ConvertArgs {
output_name: file_path.file_stem().unwrap().to_str().unwrap().to_string() + ".part",
input_files: vec![file_path],
output_dir: output_dir.unwrap_or_else(|| std::env::current_dir().unwrap()),
operations: vec![
let files = operate(
[&file_path],
[
Operator::filter_meta_key(filter_meta),
Operator::filter_tensor_name(filter_tensor),
],
split_tensor_count: usize::MAX,
split_file_size: usize::MAX,
split_no_tensor_first: false,
}
.convert()
OutputConfig {
dir: output_dir.unwrap_or_else(|| std::env::current_dir().unwrap()),
name: file_path.file_stem().unwrap().to_str().unwrap().to_string() + ".part",
shard_max_tensor_count: usize::MAX,
shard_max_file_size: Default::default(),
shard_no_tensor_first: false,
},
)
.unwrap();

show_file_info(&files);
}
}
9 changes: 3 additions & 6 deletions xtask/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
mod convert;
mod file_info;
mod filter;
mod merge;
mod name_pattern;
mod shards;
mod show;
mod split;
mod utils;

#[macro_use]
extern crate clap;
Expand All @@ -18,6 +16,7 @@ fn main() {
Split(args) => args.split(),
Merge(args) => args.merge(),
Filter(args) => args.filter(),
Convert(args) => args.convert(),
}
}

Expand All @@ -35,7 +34,5 @@ enum Commands {
Split(split::SplitArgs),
Merge(merge::MergeArgs),
Filter(filter::FilterArgs),
Convert(convert::ConvertArgs),
}

const YES: &str = "✔️ ";
const ERR: &str = "❌ ";
24 changes: 13 additions & 11 deletions xtask/src/merge.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{convert::ConvertArgs, file_info::show_file_info, shards::Shards};
use crate::utils::{operate, show_file_info, OutputConfig, Shards};
use std::path::PathBuf;

#[derive(Args, Default)]
Expand All @@ -20,17 +20,19 @@ impl MergeArgs {
return;
}

let files = ConvertArgs {
input_files: shards.iter_all().collect(),
output_dir: output_dir.unwrap_or_else(|| std::env::current_dir().unwrap()),
output_name: shards.name.into(),
operations: Vec::new(),
split_tensor_count: usize::MAX,
split_file_size: usize::MAX,
split_no_tensor_first: false,
}
.convert()
let files = operate(
shards.iter_all(),
[],
OutputConfig {
dir: output_dir.unwrap_or_else(|| std::env::current_dir().unwrap()),
name: shards.name.into(),
shard_max_tensor_count: usize::MAX,
shard_max_file_size: Default::default(),
shard_no_tensor_first: false,
},
)
.unwrap();

show_file_info(&files);
}
}
5 changes: 4 additions & 1 deletion xtask/src/show.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{name_pattern::compile_patterns, shards::Shards, ERR, YES};
use crate::utils::{compile_patterns, Shards};
use ggus::{GGufFileHeader, GGufMetaDataValueType, GGufMetaKV, GGufReadError, GGufReader};
use indexmap::IndexMap;
use memmap2::Mmap;
Expand All @@ -9,6 +9,9 @@ use std::{
path::{Path, PathBuf},
};

const YES: &str = "✔️ ";
const ERR: &str = "❌ ";

#[derive(Args, Default)]
pub struct ShowArgs {
/// The file to show
Expand Down
37 changes: 13 additions & 24 deletions xtask/src/split.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{convert::ConvertArgs, file_info::show_file_info, shards::Shards};
use std::{path::PathBuf, str::from_utf8};
use crate::utils::{operate, show_file_info, OutputConfig, Shards};
use std::path::PathBuf;

#[derive(Args, Default)]
pub struct SplitArgs {
Expand Down Expand Up @@ -35,30 +35,19 @@ impl SplitArgs {
return;
}

fn parse_size_num(num: &[u8], k: usize) -> Option<usize> {
from_utf8(num).ok()?.parse().ok().map(|n: usize| n << k)
}

let files = ConvertArgs {
output_name: shards.name.into(),
input_files: vec![file],
output_dir: output_dir.unwrap_or_else(|| std::env::current_dir().unwrap()),
operations: Vec::new(),
split_tensor_count: max_tensors.unwrap_or(usize::MAX),
split_file_size: match max_bytes {
Some(s) => match s.trim().as_bytes() {
[num @ .., b'G'] => parse_size_num(num, 30),
[num @ .., b'M'] => parse_size_num(num, 20),
[num @ .., b'K'] => parse_size_num(num, 10),
num => parse_size_num(num, 0),
}
.unwrap_or_else(|| panic!("Invalid max bytes format: \"{s}\"")),
None => usize::MAX,
let files = operate(
[&file],
[],
OutputConfig {
dir: output_dir.unwrap_or_else(|| std::env::current_dir().unwrap()),
name: shards.name.into(),
shard_max_tensor_count: max_tensors.unwrap_or(usize::MAX),
shard_max_file_size: max_bytes.map_or(Default::default(), |s| s.parse().unwrap()),
shard_no_tensor_first: no_tensor_first,
},
split_no_tensor_first: no_tensor_first,
}
.convert()
)
.unwrap();

show_file_info(&files);
}
}
Loading

0 comments on commit 8fcde25

Please sign in to comment.