convert_back_params.py
from jax import Array
import torch
import torch.nn as tnn
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel as LlamaModelPt
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaMLP, LlamaRMSNorm

from ..array_utils import jax2pt
from ..llama import Llama, LlamaModel
from ..llama.attention import Attention
from ..llama.decoder_block import DecoderBlock
from ..tree_utils import unstack_leaves

def convert_back_embedding(x: Array) -> tnn.Embedding:
    # Rebuild a PyTorch embedding layer from the JAX embedding table of shape (num_embeddings, embedding_dim).
    with torch.no_grad():
        embedding = tnn.Embedding(*x.shape)  # type: ignore
        embedding.weight = tnn.Parameter(jax2pt(x))
        return embedding

def convert_back_norm(x: Array, *, config: LlamaConfig) -> LlamaRMSNorm:
    d_model = config.hidden_size
    rms_norm_eps = config.rms_norm_eps
    with torch.no_grad():
        llama_rms_norm = LlamaRMSNorm(d_model, eps=rms_norm_eps)
        llama_rms_norm.weight = tnn.Parameter(jax2pt(x))
        return llama_rms_norm

def convert_back_proj(x: Array) -> tnn.Linear:
    with torch.no_grad():
        linear = tnn.Linear(*x.shape, bias=False)  # type: ignore
        linear.weight = tnn.Parameter(jax2pt(x).T)
        return linear

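# Note on weight layout: torch.nn.Linear stores its weight as (out_features, in_features),
# while the JAX-side projection arrays here appear to use (in_features, out_features),
# hence the transpose above. A minimal round-trip sketch, assuming a hypothetical
# (in, out)-shaped JAX array `w` (not a name defined in this module):
#
#     linear = convert_back_proj(w)                 # linear.weight has shape (out, in)
#     assert linear.weight.shape == w.shape[::-1]
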
def convert_back_q_proj(x: Array, *, config: LlamaConfig) -> tnn.Linear:
    d_model = config.hidden_size
    n_rep_kv = config.num_attention_heads // config.num_key_value_heads
    n_heads_kv = config.num_key_value_heads
    d_k = config.hidden_size // config.num_attention_heads
    in_features = d_model
    out_features = n_rep_kv * n_heads_kv * d_k
    with torch.no_grad():
        linear = tnn.Linear(in_features, out_features, bias=False)
        linear.weight = tnn.Parameter(jax2pt(x).reshape(in_features, out_features).T)
        return linear

def convert_back_k_proj(x: Array, *, config: LlamaConfig) -> tnn.Linear:
    d_model = config.hidden_size
    n_heads_kv = config.num_key_value_heads
    d_k = config.hidden_size // config.num_attention_heads
    in_features = d_model
    out_features = n_heads_kv * d_k
    with torch.no_grad():
        linear = tnn.Linear(in_features, out_features, bias=False)
        linear.weight = tnn.Parameter(jax2pt(x).reshape(in_features, out_features).T)
        return linear

def convert_back_v_proj(x: Array, *, config: LlamaConfig) -> tnn.Linear:
    d_model = config.hidden_size
    n_heads_kv = config.num_key_value_heads
    d_v = config.hidden_size // config.num_attention_heads
    in_features = d_model
    out_features = n_heads_kv * d_v
    with torch.no_grad():
        linear = tnn.Linear(in_features, out_features, bias=False)
        linear.weight = tnn.Parameter(jax2pt(x).reshape(in_features, out_features).T)
        return linear

def convert_back_out_proj(x: Array, *, config: LlamaConfig) -> tnn.Linear:
    d_model = config.hidden_size
    n_rep_kv = config.num_attention_heads // config.num_key_value_heads
    n_heads_kv = config.num_key_value_heads
    d_v = config.hidden_size // config.num_attention_heads
    in_features = n_rep_kv * n_heads_kv * d_v
    out_features = d_model
    with torch.no_grad():
        linear = tnn.Linear(in_features, out_features, bias=False)  # type: ignore
        linear.weight = tnn.Parameter(jax2pt(x).reshape(in_features, out_features).T)
        return linear

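# The q/k/v/out conversions above assume the JAX-side attention parameters are stored
# head-split (e.g. q_proj roughly as (d_model, n_rep_kv, n_heads_kv, d_k) and out_proj
# roughly as (n_rep_kv, n_heads_kv, d_v, d_model)); the exact axis layout is defined by
# the JAX attention module, not here. Flattening with .reshape(in_features, out_features)
# and transposing yields the fused (out_features, in_features) weights expected by the
# Hugging Face LlamaAttention projections.
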
def convert_back_attention(x: Attention, *, config: LlamaConfig) -> LlamaAttention:
    with torch.no_grad():
        llama_attention = LlamaAttention(config=config)
        llama_attention.q_proj = convert_back_q_proj(x.q_proj, config=config)
        llama_attention.k_proj = convert_back_k_proj(x.k_proj, config=config)
        llama_attention.v_proj = convert_back_v_proj(x.v_proj, config=config)
        llama_attention.o_proj = convert_back_out_proj(x.out_proj, config=config)
        return llama_attention

def convert_back_mlp(gate_proj: Array, up_proj: Array, down_proj: Array, *, config: LlamaConfig) -> LlamaMLP:
    with torch.no_grad():
        llama_mlp = LlamaMLP(config=config)
        llama_mlp.gate_proj = convert_back_proj(gate_proj)
        llama_mlp.up_proj = convert_back_proj(up_proj)
        llama_mlp.down_proj = convert_back_proj(down_proj)
        return llama_mlp

def convert_back_decoder_block(x: DecoderBlock, *, config: LlamaConfig) -> LlamaDecoderLayer:
    with torch.no_grad():
        llama_decoder_layer = LlamaDecoderLayer(config=config)
        llama_decoder_layer.self_attn = convert_back_attention(x.attention, config=config)
        llama_decoder_layer.mlp = convert_back_mlp(x.gate_proj, x.up_proj, x.down_proj, config=config)
        llama_decoder_layer.input_layernorm = convert_back_norm(x.input_norm, config=config)
        llama_decoder_layer.post_attention_layernorm = convert_back_norm(x.post_attn_norm, config=config)
        return llama_decoder_layer

def convert_back_llama_model(x: LlamaModel, *, config: LlamaConfig) -> LlamaModelPt:
    with torch.no_grad():
        llama_model = LlamaModelPt(config=config)
        llama_model.embed_tokens = convert_back_embedding(x.embedding)
        # unstack_leaves splits the stacked per-layer parameters into one pytree per decoder layer.
        llama_model.layers = tnn.ModuleList([convert_back_decoder_block(decoder_block, config=config) for decoder_block in unstack_leaves(x.decoder)])
        llama_model.norm = convert_back_norm(x.norm, config=config)
        return llama_model

def convert_back_llama(x: Llama, *, config: LlamaConfig) -> LlamaForCausalLM:
    # Top-level conversion: rebuild a Hugging Face LlamaForCausalLM from the JAX Llama parameters.
    with torch.no_grad():
        llama = LlamaForCausalLM(config=config)
        llama.model = convert_back_llama_model(x.model, config=config)
        llama.lm_head = convert_back_proj(x.lm_head)
        return llama

# from pathlib import Path; import sys; sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
# from lib.proc_init_utils import initialise_cpu; initialise_cpu()
# model_pt = LlamaForCausalLM.from_pretrained('/dev/shm/llama-weights/llama2-7B')
# config = LlamaConfig.from_pretrained('/dev/shm/llama-weights/llama2-7B')
# from lib.param_utils.convert_params import convert_proj
# assert torch.equal(convert_back_proj(convert_proj(model_pt.lm_head)).weight, model_pt.lm_head.weight)
# assert torch.equal(convert_back_proj(convert_proj(model_pt.model.layers[0].self_attn.q_proj)).weight, model_pt.model.layers[0].self_attn.q_proj.weight)
# assert torch.equal(convert_back_proj(convert_proj(model_pt.model.layers[0].self_attn.k_proj)).weight, model_pt.model.layers[0].self_attn.k_proj.weight)
# assert torch.equal(convert_back_proj(convert_proj(model_pt.model.layers[0].self_attn.v_proj)).weight, model_pt.model.layers[0].self_attn.v_proj.weight)
# assert torch.equal(convert_back_proj(convert_proj(model_pt.model.layers[0].self_attn.o_proj)).weight, model_pt.model.layers[0].self_attn.o_proj.weight)
# model_pt.model.norm.weight
# model_pt.model.embed_tokens.weight
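
# A minimal usage sketch, kept commented out in the same spirit as the checks above.
# The output path and the `params` variable (trained JAX Llama parameters) are
# placeholders, not names defined by this module:
# config = LlamaConfig.from_pretrained('/dev/shm/llama-weights/llama2-7B')
# model_hf = convert_back_llama(params, config=config)  # params: a Llama pytree as imported from ..llama
# model_hf.save_pretrained('/dev/shm/llama-weights/llama2-7B-converted-back')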