-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
65 lines (49 loc) · 2.02 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from models import resnet
import paddle.fluid as fluid
def get_module_name(name):
    """Return the first meaningful component of a dotted parameter name.

    ``name`` is split on ``'.'``. A leading ``'module'`` component is skipped
    (presumably the prefix a data-parallel wrapper adds -- confirm against
    callers). Then, if the component at the resulting position is
    ``'features'``, it is skipped as well. The component at the final
    position is returned.

    Examples: ``'conv1.weight'`` -> ``'conv1'``,
    ``'module.features.norm5.weight'`` -> ``'norm5'``.

    Raises:
        IndexError: if ``name`` runs out of components after skipping,
            e.g. ``'module'`` or ``'features'`` on their own.
    """
    components = name.split('.')
    start = 1 if components[0] == 'module' else 0
    if components[start] == 'features':
        start += 1
    return components[start]
def get_fine_tuning_parameters(model, ft_begin_module):
    """Select the parameters to hand to the optimizer when fine-tuning.

    Args:
        model: object exposing ``parameters()`` and ``named_parameters()``
            (presumably a ``fluid.dygraph.Layer``).
        ft_begin_module: sub-module name, as produced by ``get_module_name``,
            from which training should start. Falsy means "train everything".

    Returns:
        ``model.parameters()`` unchanged when ``ft_begin_module`` is falsy.
        Otherwise, a list with every parameter yielded by
        ``model.named_parameters()`` starting at the first one whose
        ``get_module_name`` equals ``ft_begin_module``. Parameters before that
        point are not included.

    NOTE(review): if no parameter name matches ``ft_begin_module``, the
    returned list is empty. Consider validating the name upstream.
    """
    if not ft_begin_module:
        return model.parameters()

    trainable = []
    reached = False
    for param_name, param in model.named_parameters():
        # Evaluated for every parameter (not short-circuited once reached),
        # so any error raised by get_module_name still surfaces.
        is_start = ft_begin_module == get_module_name(param_name)
        reached = reached or is_start
        if reached:
            trainable.append(param)
    return trainable
def generate_model(opt):
    """Build the network selected by ``opt.model``.

    Args:
        opt: options namespace. ``opt.model`` selects the architecture. For
            ``'resnet'`` this function also reads ``model_depth``,
            ``n_classes``, ``n_input_channels``, ``resnet_shortcut``,
            ``conv1_t_size``, ``conv1_t_stride``, ``no_max_pool`` and
            ``resnet_widen_factor``.

    Returns:
        The model returned by ``resnet.generate_model``.

    Raises:
        ValueError: if ``opt.model`` is not a recognised architecture name.
        NotImplementedError: if ``opt.model`` is recognised but has no
            builder in this module (only ``'resnet'`` is wired up).
    """
    known_models = ('resnet', 'resnet2p1d', 'preresnet', 'wideresnet',
                    'resnext', 'densenet')
    # An explicit raise instead of ``assert``: asserts are stripped under -O.
    if opt.model not in known_models:
        raise ValueError('unknown model {!r}; expected one of {}'.format(
            opt.model, ', '.join(known_models)))

    if opt.model != 'resnet':
        # These names used to pass the check and then crash with
        # UnboundLocalError at ``return model``; fail with a clear message.
        raise NotImplementedError(
            'model {!r} is not available in this port; only '
            "'resnet' is supported".format(opt.model))

    return resnet.generate_model(model_depth=opt.model_depth,
                                 n_classes=opt.n_classes,
                                 n_input_channels=opt.n_input_channels,
                                 shortcut_type=opt.resnet_shortcut,
                                 conv1_t_size=opt.conv1_t_size,
                                 conv1_t_stride=opt.conv1_t_stride,
                                 no_max_pool=opt.no_max_pool,
                                 widen_factor=opt.resnet_widen_factor)
def load_pretrained_model(model, pretrain_path, model_name, n_finetune_classes):
    """Load pretrained weights into ``model`` and attach a fresh output head.

    If ``pretrain_path`` is falsy, ``model`` is returned untouched. Otherwise,
    inside ``fluid.dygraph.guard()``, the following steps run in order:

    1. The parameter dict is read with ``fluid.dygraph.load_dygraph``. The
       second value it returns (optimizer state) is discarded.
    2. The dict is applied with ``model.set_dict``.
    3. A new ``fluid.dygraph.Linear(model.last_feature_size,
       n_finetune_classes, act='softmax')`` replaces ``model.classifier``
       when ``model_name == 'densenet'``, or ``model.fc`` otherwise.

    ``last_feature_size`` must be provided by the model class (not defined
    in this module).

    NOTE(review): the head applies softmax itself. If the training loss
    also applies softmax (e.g. a softmax-with-cross-entropy loss), it would
    be applied twice -- confirm against the training loop.
    NOTE(review): the head is built inside a new dygraph guard. Confirm
    this is the intended scope when the caller already runs under its own
    guard.

    Returns:
        The same ``model`` object (mutated in place when weights are loaded).
    """
    if not pretrain_path:
        return model

    print('loading pretrained model {}'.format(pretrain_path))
    with fluid.dygraph.guard():
        state_dict, _ = fluid.dygraph.load_dygraph(str(pretrain_path))
        model.set_dict(state_dict)
        new_head = fluid.dygraph.Linear(model.last_feature_size,
                                        n_finetune_classes, act='softmax')
        if model_name == 'densenet':
            model.classifier = new_head
        else:
            model.fc = new_head
    return model