Skip to content

Commit

Permalink
feat(xtask): 张量排序功能支持任意模型
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Dec 25, 2024
1 parent 3fe675a commit 8adbb4e
Showing 1 changed file with 185 additions and 55 deletions.
240 changes: 185 additions & 55 deletions xtask/src/utils/operator/sort.rs
Original file line number Diff line number Diff line change
@@ -1,69 +1,199 @@
use super::Content;
use ggus::GGufMetaMapExt;
use itertools::Itertools;
use std::{collections::HashMap, sync::LazyLock};
use regex::Regex;
use std::{cmp::Ordering, collections::HashMap, sync::LazyLock};

impl Content<'_> {
pub(super) fn sort_tensors(&mut self) {
match self.general_architecture().unwrap() {
"llama" => {}
arch => todo!("unsupported architecture: {arch}"),
}
let tensors = std::mem::take(&mut self.tensors);
self.tensors = tensors
.into_iter()
.sorted_unstable_by_key(|(k, _)| rank(k).unwrap_or(usize::MAX))
.map(|(k, v)| (Name::new_key(&k), k, v))
.sorted_unstable_by(|(a, ..), (b, ..)| a.cmp(b))
.map(|(_, k, v)| (k, v))
.collect();
}
}

fn rank(name: &str) -> Option<usize> {
let (head, tail): (&str, usize);
if let Some(name) = name.strip_suffix(".weight") {
head = name;
tail = 0;
} else {
head = name.strip_suffix(".bias")?;
tail = 1;
};

static ORDER_MAP: LazyLock<HashMap<&str, usize>> = LazyLock::new(|| {
[
"token_embd",
"output_norm",
"output",
"attn_norm",
"attn_norm_2",
"attn_qkv",
"attn_q",
"attn_k",
"attn_v",
"attn_output",
"ffn_norm",
"ffn_gate_up",
"ffn_up",
"ffn_gate",
"ffn_down",
"ffn_up_exp",
"ffn_up_exps",
"ffn_gate_inp",
"ffn_gate_exp",
"ffn_gate_exps",
"ffn_down_exp",
"ffn_down_exps",
]
.iter()
.enumerate()
.map(|(i, s)| (*s, i))
.collect()
});

let head = match head.strip_prefix("blk.") {
Some(body) => {
let (blk, name) = body.split_once('.')?;
blk.parse::<usize>().unwrap() * ORDER_MAP.len() + *ORDER_MAP.get(name)?
const MID: &[&str] = &[
"attn_norm",
"attn_norm_2",
"ln1",
"attn_qkv",
"attn_q",
"attn_k",
"attn_v",
"attn.q",
"attn.k",
"attn.v",
"attn_output",
"attn_out",
"attn.out",
"ffn_norm",
"ln2",
"ffn_gate_up",
"ffn_gate",
"ffn_up",
"ffn_down",
"ffn_gate_inp",
"ffn_gate_exp",
"ffn_gate_exps",
"ffn_up_exp",
"ffn_up_exps",
"ffn_down_exp",
"ffn_down_exps",
];

const POST: &[&str] = &["weight", "bias"];

#[test]
fn test() {}

#[derive(PartialEq, Eq, Debug)]
struct Name<'a>(Pre<'a>, Mid<'a>, Post<'a>);

#[derive(PartialEq, Eq, Debug)]
struct Pre<'a>(Vec<PreSeg<'a>>);
#[derive(PartialEq, Eq, Debug)]
struct Mid<'a>(&'a str);
#[derive(PartialEq, Eq, Debug)]
struct Post<'a>(&'a str);

#[derive(PartialEq, Eq, Debug)]
enum PreSeg<'a> {
Str(&'a str),
Num(usize),
}

impl Name<'static> {
fn new_key(value: &str) -> Self {
static REGEX: LazyLock<Regex> = LazyLock::new(|| {
let mut mid = String::new();
for name in MID {
for c in name.chars() {
if c.is_ascii_alphanumeric() || c == '_' {
mid.push(c)
} else if c == '.' {
mid.push_str(r"\.")
} else {
panic!("invalid char: {c}")
}
}
mid.push('|')
}
mid.pop();
Regex::new(&mid).unwrap()
});

let value = unsafe {
std::str::from_utf8_unchecked(std::slice::from_raw_parts(value.as_ptr(), value.len()))
};
let (start, end) = REGEX
.find(value)
.map_or((value.len(), value.len()), |mid| (mid.start(), mid.end()));
let pre = value[..start]
.split('.')
.map(|s| s.parse::<usize>().map_or(PreSeg::Str(s), PreSeg::Num))
.collect();
let mid = &value[start..end];
let post = &value[end..];
Self(Pre(pre), Mid(mid), Post(post))
}
}

impl Ord for Name<'_> {
fn cmp(&self, other: &Self) -> Ordering {
use Ordering::Equal;
match self.0.cmp(&other.0) {
Equal => match self.1.cmp(&other.1) {
Equal => self.2.cmp(&other.2),
ord => ord,
},
ord => ord,
}
}
}
impl Ord for Pre<'_> {
fn cmp(&self, other: &Self) -> Ordering {
use Ordering::{Equal, Greater, Less};

for (a, b) in self.0.iter().zip(other.0.iter()) {
match (a, b) {
(PreSeg::Str(_), PreSeg::Num(_)) => return Less,
(PreSeg::Num(_), PreSeg::Str(_)) => return Greater,
(PreSeg::Str(a), PreSeg::Str(b)) => match a.cmp(b) {
Equal => {}
ord => return ord,
},
(PreSeg::Num(a), PreSeg::Num(b)) => match a.cmp(b) {
Equal => {}
ord => return ord,
},
}
}
None => *ORDER_MAP.get(head)?,
};
Some(head * 2 + tail)
self.0.len().cmp(&other.0.len())
}
}

impl Ord for Mid<'_> {
fn cmp(&self, other: &Self) -> Ordering {
static ORDER_MAP: LazyLock<HashMap<&str, usize>> =
LazyLock::new(|| MID.iter().enumerate().map(|(i, s)| (*s, i)).collect());
cmp_by_map(self.0, other.0, &ORDER_MAP)
}
}

impl Ord for Post<'_> {
fn cmp(&self, other: &Self) -> Ordering {
use Ordering::{Equal, Greater, Less};

static ORDER_MAP: LazyLock<HashMap<&str, usize>> =
LazyLock::new(|| POST.iter().enumerate().map(|(i, s)| (*s, i)).collect());
let mut a = self.0.split('.');
let mut b = other.0.split('.');
loop {
match (a.next(), b.next()) {
(Some(a), Some(b)) => match cmp_by_map(a, b, &ORDER_MAP) {
Equal => {}
ord => break ord,
},
(Some(_), None) => break Greater,
(None, Some(_)) => break Less,
(None, None) => break Equal,
}
}
}
}

fn cmp_by_map(a: &str, b: &str, map: &HashMap<&str, usize>) -> Ordering {
match (map.get(a), map.get(b)) {
(Some(_), None) => Ordering::Less,
(None, Some(_)) => Ordering::Greater,
(Some(a), Some(b)) => a.cmp(b),
(None, None) => a.cmp(b),
}
}

impl PartialOrd for Name<'_> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

impl PartialOrd for Pre<'_> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

impl PartialOrd for Mid<'_> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

impl PartialOrd for Post<'_> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

0 comments on commit 8adbb4e

Please sign in to comment.