import os, time, gc, json, pickle, argparse, math
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.distributed as dist
import torch.multiprocessing as mp
import numpy as np
from data.util import *
import copy


def num_params(model):
    """Count the trainable parameters of a model."""
    return sum(np.prod(p.size()) for p in model.parameters() if p.requires_grad)
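
# Example usage (a sketch, not part of the original file):
#   model = nn.Linear(10, 10)
#   print(f'{num_params(model):,} trainable parameters')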


def init_para_frompretrained(m, pm, share_para=False):
    """Initialize model `m` from a pretrained GPT-2-style model `pm`.

    Token and position embeddings are always tied to the pretrained weights.
    Per-block weights are tied when `share_para` is True, otherwise copied
    with `copy.copy`.
    """
    m.wte.weight = pm.wte.weight
    m.wpe.weight = pm.wpe.weight
    for i in range(min(len(m.h), len(pm.h))):
        m.h[i].ln_1.weight = pm.h[i].ln_1.weight if share_para else copy.copy(pm.h[i].ln_1.weight)
        m.h[i].ln_1.bias = pm.h[i].ln_1.bias if share_para else copy.copy(pm.h[i].ln_1.bias)
        m.h[i].attn.c_attn.weight = pm.h[i].attn.c_attn.weight if share_para else copy.copy(pm.h[i].attn.c_attn.weight)
        m.h[i].attn.c_attn.bias = pm.h[i].attn.c_attn.bias if share_para else copy.copy(pm.h[i].attn.c_attn.bias)
        m.h[i].attn.c_proj.weight = pm.h[i].attn.c_proj.weight if share_para else copy.copy(pm.h[i].attn.c_proj.weight)
        m.h[i].attn.c_proj.bias = pm.h[i].attn.c_proj.bias if share_para else copy.copy(pm.h[i].attn.c_proj.bias)
        m.h[i].ln_2.weight = pm.h[i].ln_2.weight if share_para else copy.copy(pm.h[i].ln_2.weight)
        m.h[i].ln_2.bias = pm.h[i].ln_2.bias if share_para else copy.copy(pm.h[i].ln_2.bias)
        m.h[i].mlp.c_fc.weight = pm.h[i].mlp.c_fc.weight if share_para else copy.copy(pm.h[i].mlp.c_fc.weight)
        m.h[i].mlp.c_fc.bias = pm.h[i].mlp.c_fc.bias if share_para else copy.copy(pm.h[i].mlp.c_fc.bias)
        m.h[i].mlp.c_proj.weight = pm.h[i].mlp.c_proj.weight if share_para else copy.copy(pm.h[i].mlp.c_proj.weight)
        m.h[i].mlp.c_proj.bias = pm.h[i].mlp.c_proj.bias if share_para else copy.copy(pm.h[i].mlp.c_proj.bias)
    m.ln_f.weight = pm.ln_f.weight if share_para else copy.copy(pm.ln_f.weight)
    m.ln_f.bias = pm.ln_f.bias if share_para else copy.copy(pm.ln_f.bias)
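
# Illustrative usage (a sketch, not part of the original file): the attribute
# names above (wte, wpe, h, ln_f) match the HuggingFace `transformers`
# GPT2Model layout, so a smaller decoder could be seeded from pretrained GPT-2
# roughly as follows:
#   from transformers import GPT2Model, GPT2Config
#   pretrained = GPT2Model.from_pretrained('gpt2')
#   decoder = GPT2Model(GPT2Config(n_layer=6))
#   init_para_frompretrained(decoder, pretrained, share_para=True)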


def switch_schedule(schedule, mult, switch):
    """Apply LR multiplier `mult` before iteration `switch`."""
    def f(e):
        s = schedule(e)
        if e < switch:
            return s * mult
        return s
    return f


def linear_schedule(args):
    """Linear warmup for `args.warmup` steps, then linear decay to 0 at `args.iterations`."""
    def f(e):
        if e <= args.warmup:
            return e / args.warmup
        return max((e - args.iterations) / (args.warmup - args.iterations), 0)
    return f
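
# Illustrative usage (a sketch, not part of the original file): both schedules
# return multiplier functions of the step count, suitable for
# torch.optim.lr_scheduler.LambdaLR. `args.warmup` and `args.iterations` are
# assumed to be measured in optimizer steps.
#   optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
#   lr_lambda = switch_schedule(linear_schedule(args), mult=0.1, switch=1000)
#   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
#   for step in range(args.iterations):
#       ...  # forward/backward pass
#       optimizer.step()
#       scheduler.step()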