Skip to content

Commit

Permalink
feat(xtask): 支持gpt2模型类型转换
Browse files Browse the repository at this point in the history
  • Loading branch information
onenewcode committed Jan 3, 2025
1 parent e6fbb39 commit fe6d513
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions xtask/src/utils/operator/cast.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{Content, DataPromise, Operator};
use super::{Content, DataPromise, Operator};
use ggus::{
ggml_quants::{bf16, f16, QuantExt, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q8_1},
DataFuture, GGmlType as Ty, GGufMetaMapExt,
Expand Down Expand Up @@ -28,15 +28,17 @@ impl Operator {
impl Content<'_> {
pub(super) fn cast(&mut self, types: HashMap<String, Ty>) {
match self.general_architecture().unwrap() {
"llama" => {
"llama" | "gpt2" => {
let [mat, embd, norm, else_] =
["mat", "embd", "norm", "else"].map(|name| types.get(name).copied());
self.cast_(mat, |name, shape| {
if matches!(name, "token_embd.weight" | "output.weight") {
embd
} else if name.ends_with("_norm.weight") {
} else if name.ends_with("_norm.weight") || name.ends_with("_norm.bias") {
norm
} else if shape.len() > 1 {
} else if shape.len() > 1
|| (name.ends_with(".bias") && !name.ends_with("_norm.bias"))
{
mat
} else {
else_
Expand Down

0 comments on commit fe6d513

Please sign in to comment.