Skip to content

Commit

Permalink
feat(show): 元信息和张量过滤功能升级到支持通配符
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Aug 2, 2024
1 parent 75d7d47 commit 2867693
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 47 deletions.
45 changes: 45 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions xtask/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ ggus = { version = "*", path = "../ggus" }
indexmap.workspace = true
clap = { version = "4.5", features = ["derive"] }
memmap2 = "0.9"
regex = "1.10"
12 changes: 1 addition & 11 deletions xtask/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@ mod file_info;
mod gguf_file;
mod loose_shards;
mod merge;
mod name_pattern;
mod show;
mod split;

#[macro_use]
extern crate clap;

use clap::Parser;
use std::collections::HashSet;

fn main() {
use Commands::*;
Expand Down Expand Up @@ -37,12 +36,3 @@ enum Commands {

const YES: &str = "✔️ ";
const ERR: &str = "❌ ";

/// Splits a comma-separated argument into a set of keys.
///
/// Every piece is trimmed of surrounding whitespace and pieces that end up
/// empty are dropped. Returns `None` when `arg` itself is `None`; the returned
/// keys borrow from the string held by `arg`.
fn split_keys(arg: &Option<String>) -> Option<HashSet<&str>> {
    let raw = arg.as_deref()?;
    let mut keys = HashSet::new();
    for piece in raw.split(',') {
        let key = piece.trim();
        if !key.is_empty() {
            keys.insert(key);
        }
    }
    Some(keys)
}
56 changes: 56 additions & 0 deletions xtask/src/name_pattern.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
use regex::Regex;
use std::{borrow::Cow, fmt, sync::OnceLock};

/// Compiles a filter string into a single [`Regex`].
///
/// The string is rendered to regex source text through the `Display`
/// implementation of `Patterns`, and that text is then compiled.
///
/// # Panics
///
/// Panics if the rendered text is rejected by [`Regex::new`].
pub fn compile_patterns(patterns: &str) -> Regex {
    let source = Patterns(patterns).to_string();
    Regex::new(&source).unwrap()
}

/// Borrowed filter string that may hold several patterns; its `Display`
/// implementation renders them as one regex alternation.
struct Patterns<'a>(&'a str);

impl fmt::Display for Patterns<'_> {
    /// Writes every pattern found in the wrapped string as regex source,
    /// joining consecutive patterns with `|`.
    ///
    /// A pattern is any maximal run of word characters, dots and asterisks;
    /// all other characters act as separators. Nothing is written when the
    /// string contains no such run.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Matches any combination of identifier characters, dots and asterisks.
        static TOKEN: OnceLock<Regex> = OnceLock::new();
        let token = TOKEN.get_or_init(|| Regex::new(r"[\w*.]+").unwrap());

        for (index, found) in token.find_iter(self.0).enumerate() {
            // Alternatives after the first are separated by `|`.
            if index > 0 {
                f.write_str("|")?;
            }
            write!(f, "{}", Pattern(found.as_str()))?;
        }
        Ok(())
    }
}

/// A single dot-separated wildcard pattern; its `Display` implementation
/// renders it as an anchored (`^…$`) regex.
struct Pattern<'a>(&'a str);

impl fmt::Display for Pattern<'_> {
    /// Writes the wrapped pattern as an anchored regex.
    ///
    /// The pattern is split on `.` and each segment is translated:
    /// - an empty segment becomes `\w+`;
    /// - a segment made only of `*` becomes `(\w+\.)*\w+`, i.e. one or more
    ///   dot-separated word runs;
    /// - any other segment is kept as is, with each run of `*` replaced by `\w*`.
    ///
    /// The translated segments are joined with an escaped dot and wrapped in
    /// `^` … `$`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Collapses any run of consecutive `*`.
        static STARS: OnceLock<Regex> = OnceLock::new();

        f.write_str("^")?;
        for (index, segment) in self.0.split('.').enumerate() {
            // Segments after the first are preceded by a literal dot.
            if index != 0 {
                f.write_str(r"\.")?;
            }
            let rendered = match segment {
                "" => Cow::Borrowed(r"\w+"),
                s if s.chars().all(|c| c == '*') => Cow::Borrowed(r"(\w+\.)*\w+"),
                s => STARS
                    .get_or_init(|| Regex::new(r"\*+").unwrap())
                    .replace_all(s, r"\w*"),
            };
            f.write_str(&rendered)?;
        }
        f.write_str("$")
    }
}
57 changes: 21 additions & 36 deletions xtask/src/show.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use crate::{loose_shards::LooseShards, split_keys, ERR, YES};
use crate::{loose_shards::LooseShards, name_pattern::compile_patterns, ERR, YES};
use ggus::{
GGufFileHeader, GGufMetaDataValueType, GGufMetaKV, GGufMetaKVPairs, GGufReadError, GGufReader,
GGufTensors,
};
use std::{collections::HashSet, fmt, fs::File, path::PathBuf};
use regex::Regex;
use std::{fmt, fs::File, path::PathBuf};

#[derive(Args, Default)]
pub struct ShowArgs {
Expand All @@ -15,12 +16,12 @@ pub struct ShowArgs {
/// How many elements to show in arrays, `all` for all elements
#[clap(long, short = 'n', default_value = "8")]
array_detail: String,
/// Meta to show (split with `,`)
#[clap(long, short = 'm')]
filter_meta: Option<String>,
/// Tensors to show (split with `,`)
#[clap(long, short = 't')]
filter_tensor: Option<String>,
/// Meta to show
#[clap(long, short = 'm', default_value = "*")]
filter_meta: String,
/// Tensors to show
#[clap(long, short = 't', default_value = "*")]
filter_tensor: String,
}

struct Failed;
Expand All @@ -41,8 +42,8 @@ impl ShowArgs {
.parse()
.expect("Invalid array detail, should be an integer or `all`"),
};
let filter_meta = split_keys(&filter_meta);
let filter_tensor = split_keys(&filter_tensor);
let filter_meta = compile_patterns(&filter_meta);
let filter_tensor = compile_patterns(&filter_tensor);

let files = if shards {
LooseShards::from(&*file)
Expand Down Expand Up @@ -155,19 +156,11 @@ fn show_header(header: &GGufFileHeader) -> Result<(), Failed> {
Ok(())
}

fn show_meta_kvs<'a>(
kvs: &GGufMetaKVPairs,
filter: &Option<HashSet<&'a str>>,
detail: usize,
) -> Result<(), Failed> {
let kvs = filter.as_ref().map_or_else(
|| kvs.kvs().collect::<Vec<_>>(),
|to_keep| {
kvs.kvs()
.filter(move |m| to_keep.contains(m.key()))
.collect::<Vec<_>>()
},
);
fn show_meta_kvs<'a>(kvs: &GGufMetaKVPairs, filter: &Regex, detail: usize) -> Result<(), Failed> {
let kvs = kvs
.kvs()
.filter(move |m| filter.is_match(m.key()))
.collect::<Vec<_>>();

if let Some(width) = kvs.iter().map(|kv| kv.key().len()).max() {
show_title("Meta KV");
Expand Down Expand Up @@ -278,19 +271,11 @@ fn fmt_meta_val<'a>(
Ok(())
}

fn show_tensors<'a>(
tensors: &GGufTensors,
filter: &Option<HashSet<&'a str>>,
) -> Result<(), Failed> {
let tensors = filter.as_ref().map_or_else(
|| tensors.iter().collect::<Vec<_>>(),
|to_keep| {
tensors
.iter()
.filter(move |t| to_keep.contains(t.name()))
.collect::<Vec<_>>()
},
);
fn show_tensors<'a>(tensors: &GGufTensors, filter: &Regex) -> Result<(), Failed> {
let tensors = tensors
.iter()
.filter(move |t| filter.is_match(t.name()))
.collect::<Vec<_>>();

if let Some(name_width) = tensors.iter().map(|t| t.name().len()).max() {
show_title("Tensors");
Expand Down

0 comments on commit 2867693

Please sign in to comment.