This repository has been archived by the owner on Jul 24, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
/
models.py
65 lines (55 loc) · 1.68 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from unet import UNet
def get_unet():
    """Construct the project's UNet with a fixed hyper-parameter set.

    The network is 3-dimensional (``dimensions=3``). It takes a single input
    channel and predicts two output classes. It uses three encoding blocks,
    batch normalization, PReLU activations, and Monte-Carlo dropout of 0.5
    (regular dropout is disabled).

    Returns:
        A newly created ``UNet`` instance.
    """
    unet_config = dict(
        in_channels=1,
        out_classes=2,
        dimensions=3,
        num_encoding_blocks=3,
        out_channels_first_layer=8,
        normalization='batch',
        pooling_type='max',
        padding=True,
        padding_mode='replicate',
        residual=False,
        initial_dilation=1,
        activation='PReLU',
        upsampling_type='linear',
        dropout=0,
        monte_carlo_dropout=0.5,
    )
    return UNet(**unet_config)
def freeze(module):
    """Turn off gradient tracking for every parameter yielded by ``module``.

    Args:
        module: Any object with a ``parameters()`` method returning objects
            that carry a ``requires_grad`` attribute (presumably a
            ``torch.nn.Module``).
    """
    params = module.parameters()
    for weight in params:
        weight.requires_grad = False
def unfreeze(module):
    """Turn gradient tracking back on for every parameter yielded by ``module``.

    Args:
        module: Any object with a ``parameters()`` method returning objects
            that carry a ``requires_grad`` attribute (presumably a
            ``torch.nn.Module``).
    """
    params = module.parameters()
    for weight in params:
        weight.requires_grad = True
def freeze_layers(model, num_layers):
    """Freeze the output-most stages of the network.

    Stages are counted from the output side:

    * ``num_layers >= 1``: the classifier's ``conv_layer`` is frozen.
    * ``num_layers >= 2``: the last ``num_layers - 1`` decoder blocks are also
      frozen (2 -> last block, 3 -> last two blocks).

    Args:
        model: The network, optionally wrapped in an object that exposes the
            real model as ``.module`` (e.g. ``DataParallel``).
        num_layers: ``None``, 1, 2 or 3. ``None`` leaves the model untouched.
    """
    if num_layers is None:
        return
    # DataParallel keeps the wrapped network under `.module`.
    net = getattr(model, 'module', model)
    if num_layers >= 1:
        freeze(net.classifier.conv_layer)
    if num_layers >= 2:
        # 1 - num_layers == -(num_layers - 1): start of the trailing blocks.
        for block in net.decoder.decoding_blocks[1 - num_layers:]:
            freeze(block)
def freeze_except(model, num_layers):
    """Freeze the whole model, then keep only the output-most stages trainable.

    Args:
        model: The network, optionally wrapped in an object that exposes the
            real model as ``.module`` (e.g. ``DataParallel``).
        num_layers: ``None`` or an int (the original design uses 1, 2 or 3).

            * ``None``: nothing happens and no parameter is touched.
            * ``1``: every parameter is frozen except the classifier.
            * ``k >= 2``: in addition, the last ``k - 1`` decoder blocks stay
              trainable (2 -> last block, 3 -> last two blocks).
            * ``<= 0``: the entire model is frozen.
    """
    if num_layers is None:
        return
    if hasattr(model, 'module'):  # DataParallel
        model = model.module
    freeze(model)
    if num_layers > 0:
        unfreeze(model.classifier)
    if num_layers > 1:
        # Keep the last (num_layers - 1) decoder blocks trainable, using the
        # same slice as freeze_layers. A larger count therefore never leaves
        # fewer blocks unfrozen. A special case for num_layers == 3 alone
        # would unfreeze only one block for num_layers >= 4.
        for block in model.decoder.decoding_blocks[-(num_layers - 1):]:
            unfreeze(block)