-
Notifications
You must be signed in to change notification settings - Fork 5.5k
/
Copy pathlr_scheduler.py
93 lines (80 loc) · 3.15 KB
/
lr_scheduler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
from torch.optim.lr_scheduler import _LRScheduler
class PolyScheduler(_LRScheduler):
    """Linear warmup followed by polynomial decay of the learning rate.

    Every parameter group of the optimizer receives the same learning rate,
    derived from ``base_lr``; the per-group rates configured on the optimizer
    are not consulted.

    Args:
        optimizer: Optimizer whose parameter groups are scheduled.
        base_lr: Learning rate reached at the end of warmup.
        max_steps: Step index at which the decayed rate reaches zero.
        warmup_steps: Number of steps of the linear ramp from 0 to
            ``base_lr``.
        last_epoch: Forwarded unchanged to the base scheduler.
    """

    def __init__(self,
                 optimizer,
                 base_lr,
                 max_steps,
                 warmup_steps,
                 last_epoch=-1):
        self.base_lr = base_lr
        # Only returned by get_lr() while last_epoch is still -1.
        self.warmup_lr_init = 0.0001
        self.max_steps: int = max_steps
        self.warmup_steps: int = warmup_steps
        # Exponent of the polynomial decay.
        self.power = 2
        # The attributes above must be set first: the base constructor
        # performs an initial step, which calls get_lr().
        # The positional ``verbose`` flag is intentionally not passed: it is
        # deprecated in newer PyTorch releases and its old default was False.
        super().__init__(optimizer, last_epoch)

    def get_warmup_lr(self):
        """Return the linearly ramped warmup rate for every parameter group."""
        alpha = float(self.last_epoch) / float(self.warmup_steps)
        _lr = self.base_lr * alpha
        return [_lr for _ in self.optimizer.param_groups]

    def get_lr(self):
        """Return the learning rate of each parameter group for the current step.

        Raises:
            ZeroDivisionError: If the decay phase is entered while
                ``max_steps == warmup_steps``.
        """
        if self.last_epoch == -1:
            return [self.warmup_lr_init for _ in self.optimizer.param_groups]
        if self.last_epoch < self.warmup_steps:
            return self.get_warmup_lr()

        progress = (float(self.last_epoch - self.warmup_steps) /
                    float(self.max_steps - self.warmup_steps))
        # Clamp the progress: without it, stepping past max_steps makes
        # (1 - progress) negative and the even power turns that into a
        # learning rate that grows again instead of staying at zero.
        progress = min(progress, 1.0)
        alpha = pow(1 - progress, self.power)
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]
class StepScheduler(_LRScheduler):
    """Linear warmup followed by step-wise decay of the learning rate.

    After warmup the rate is ``base_lr`` multiplied by 0.1 once for every
    entry of ``lr_steps`` that is less than or equal to the current step.
    Every parameter group of the optimizer receives the same learning rate;
    the per-group rates configured on the optimizer are not consulted.

    Args:
        optimizer: Optimizer whose parameter groups are scheduled.
        base_lr: Learning rate reached at the end of warmup.
        lr_steps: Iterable of step indices at which the rate is divided
            by 10.
        warmup_steps: Number of steps of the linear ramp from 0 to
            ``base_lr``.
        last_epoch: Forwarded unchanged to the base scheduler.
    """

    def __init__(self,
                 optimizer,
                 base_lr,
                 lr_steps,
                 warmup_steps,
                 last_epoch=-1):
        self.base_lr = base_lr
        # Only returned by get_lr() while last_epoch is still -1.
        self.warmup_lr_init = 0.0001
        self.lr_steps = lr_steps
        self.warmup_steps: int = warmup_steps
        # The attributes above must be set first: the base constructor
        # performs an initial step, which calls get_lr().
        # The positional ``verbose`` flag is intentionally not passed: it is
        # deprecated in newer PyTorch releases and its old default was False.
        super().__init__(optimizer, last_epoch)

    def get_warmup_lr(self):
        """Return the linearly ramped warmup rate for every parameter group."""
        alpha = float(self.last_epoch) / float(self.warmup_steps)
        _lr = self.base_lr * alpha
        return [_lr for _ in self.optimizer.param_groups]

    def get_lr(self):
        """Return the learning rate of each parameter group for the current step."""
        if self.last_epoch == -1:
            return [self.warmup_lr_init for _ in self.optimizer.param_groups]
        if self.last_epoch < self.warmup_steps:
            return self.get_warmup_lr()

        # Count the milestones already reached; each one scales the rate by 0.1.
        passed = sum(1 for m in self.lr_steps if m <= self.last_epoch)
        alpha = 0.1 ** passed
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]
def get_scheduler(opt, cfg):
    """Build the learning-rate scheduler selected by ``cfg`` for ``opt``.

    Selection order:
      1. ``cfg.lr_func`` is not None -> ``LambdaLR`` driven by that callable.
      2. ``cfg.lr_steps`` is None    -> ``PolyScheduler``.
      3. otherwise                   -> ``StepScheduler``.

    Args:
        opt: Optimizer to attach the scheduler to.
        cfg: Configuration object exposing ``lr_func`` and, for the
            built-in schedules, ``lr_steps``, ``lr``, ``warmup_steps``
            and (polynomial schedule only) ``total_steps``.

    Returns:
        The constructed scheduler instance.
    """
    custom_fn = cfg.lr_func
    if custom_fn is not None:
        return torch.optim.lr_scheduler.LambdaLR(
            optimizer=opt, lr_lambda=custom_fn)

    # cfg.warmup_steps / cfg.total_steps are consumed exactly as given; no
    # conversion from epochs or batch sizes happens here.
    milestones = cfg.lr_steps
    if milestones is None:
        return PolyScheduler(
            optimizer=opt,
            base_lr=cfg.lr,
            max_steps=cfg.total_steps,
            warmup_steps=cfg.warmup_steps,
        )

    return StepScheduler(
        optimizer=opt,
        base_lr=cfg.lr,
        lr_steps=milestones,
        warmup_steps=cfg.warmup_steps,
    )