# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/computer_vision_fine_tuning.py
from typing import Generator, Optional

import torch
from torch.nn import Module
from torch.optim.optimizer import Optimizer
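
# BatchNorm layers get special treatment throughout: even when the rest of
# a backbone is frozen, it is often useful to leave them in training mode so
# their running statistics can adapt to the new data distribution.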
BN_TYPES = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)


# --- Utility functions ---
def _make_trainable(module: Module) -> None:
    """Unfreezes a given module.

    Args:
        module: The module to unfreeze
    """
    for param in module.parameters():
        param.requires_grad = True
    module.train()


def _recursive_freeze(module: Module,
                      train_bn: bool = True) -> None:
    """Freezes the layers of a given module.

    Args:
        module: The module to freeze
        train_bn: If True, leave the BatchNorm layers in training mode
    """
    children = list(module.children())
    if not children:
        if not (isinstance(module, BN_TYPES) and train_bn):
            for param in module.parameters():
                param.requires_grad = False
            module.eval()
        else:
            # Make the BN layers trainable
            _make_trainable(module)
    else:
        for child in children:
            _recursive_freeze(module=child, train_bn=train_bn)


def freeze(module: Module,
           n: Optional[int] = None,
           train_bn: bool = True) -> None:
    """Freezes a module's first n top-level children (all of them if n is None).

    Args:
        module: The module to freeze (at least partially)
        n: Number of top-level children to freeze, counted from the start.
            If None, all the children of the given module will be frozen.
        train_bn: If True, leave the BatchNorm layers in training mode
    """
    children = list(module.children())
    n_max = len(children) if n is None else int(n)

    # Freeze the first n_max children; explicitly keep the rest trainable.
    for child in children[:n_max]:
        _recursive_freeze(module=child, train_bn=train_bn)
    for child in children[n_max:]:
        _make_trainable(module=child)


def filter_params(module: Module,
                  train_bn: bool = True) -> Generator:
    """Yields the trainable parameters of a given module.

    When train_bn is True, the parameters of BatchNorm layers are skipped:
    those layers stay in training mode but are not handed to the optimizer.

    Args:
        module: A given module
        train_bn: If True, leave the BatchNorm layers in training mode

    Returns:
        Generator
    """
    children = list(module.children())
    if not children:
        if not (isinstance(module, BN_TYPES) and train_bn):
            for param in module.parameters():
                if param.requires_grad:
                    yield param
    else:
        for child in children:
            yield from filter_params(module=child, train_bn=train_bn)


def _unfreeze_and_add_param_group(module: Module,
                                  optimizer: Optimizer,
                                  lr: Optional[float] = None,
                                  train_bn: bool = True):
    """Unfreezes a module and adds its parameters to an optimizer."""
    _make_trainable(module)
    params_lr = optimizer.param_groups[0]['lr'] if lr is None else float(lr)
    # The freshly unfrozen parameters join as a new param group at one tenth
    # of the reference learning rate, a common gradual fine-tuning heuristic.
    optimizer.add_param_group(
        {'params': filter_params(module=module, train_bn=train_bn),
         'lr': params_lr / 10.,
         })


def unfreeze(module: Module,
             optimizer: Optimizer,
             lr: Optional[float] = None,
             train_bn: bool = True,
             start_n: Optional[int] = None,
             end_n: Optional[int] = None):
    """Unfreezes the children of a module in the range [start_n, end_n) and
    registers their parameters with the optimizer.

    Args:
        module: The module to (partially) unfreeze
        optimizer: The optimizer that receives the new param groups
        lr: Reference learning rate; each new param group is added at one
            tenth of it. If None, the learning rate of the optimizer's first
            param group is used as the reference.
        train_bn: If True, leave the BatchNorm layers in training mode
        start_n: Index of the first child to unfreeze. If None, defaults to 0.
        end_n: Index one past the last child to unfreeze. If None, defaults
            to the number of children.
    """
    children = list(module.children())
    start_n = 0 if start_n is None else int(start_n)
    end_n = len(children) if end_n is None else int(end_n)
    for child in children[start_n:end_n]:
        _unfreeze_and_add_param_group(module=child, optimizer=optimizer,
                                      lr=lr, train_bn=train_bn)
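

# --- Usage sketch (illustrative addition, not part of the original file) ---
# A minimal, hypothetical fine-tuning flow showing how freeze, filter_params
# and unfreeze fit together. The two-stage Sequential below merely stands in
# for a pretrained backbone plus a new head; any nn.Module with top-level
# children works the same way.
if __name__ == '__main__':
    import torch.nn as nn

    model = nn.Sequential(
        nn.Sequential(            # children[0]: stand-in "backbone"
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        ),
        nn.Linear(8, 10),         # children[1]: new "head"
    )

    # Phase 1: freeze the backbone (BN stays in train mode) and optimize
    # only the parameters that are still trainable, i.e. the head.
    freeze(model, n=1, train_bn=True)
    optimizer = torch.optim.SGD(filter_params(model, train_bn=True), lr=1e-2)

    # ... train the head for a few epochs, then ...

    # Phase 2: unfreeze the backbone; its parameters join the optimizer as a
    # new param group at one tenth of the current learning rate. Note that,
    # per filter_params, BN parameters stay in train mode but are never added.
    unfreeze(model, optimizer, train_bn=True, start_n=0, end_n=1)

    x = torch.randn(4, 3, 32, 32)
    loss = model(x).sum()
    loss.backward()
    optimizer.step()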