class Config:
    def __init__(self):
        self.batch_per_gpu = 32
        self.num_gpu = 1
        self.resize = (224, 224)
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.optimizer = {
            'type': 'SGD',
            'params': {
                'momentum': 0.9
            },
            'learning_rate': {
                'head_lr': 1e-3,
                'backbone_lr': 1e-6
            }
        }
        self.scheduler = {
            'type': 'linear',
            'params': {
                'warmup_ratio': 0.03
            }
        }
        self.do_eval = True
        # self.eval_step = 2000  # Only used when training is iteration-based (not epoch-based)
        # self.iterations = "8k"  # Only used when training is iteration-based (not epoch-based)
        self.num_train_epoch = 3
        self.model = {
            'backbone': 'dinov2_l',  # 'dinov2_s', 'dinov2_b', 'dinov2_l', 'dinov2_g', 'siglip_384'
            'head': 'single',  # 'single', 'mlp'
            # 'hidden_dims': [512, 256],  # Only used when head is 'mlp'
            'num_classes': 3,
            'freeze_backbone': False
        }
        self.loss = {
            'loss_type': 'CE_loss',  # 'CE_loss', 'class_balanced_CE_loss', 'Focal_loss', 'class_balanced_Focal_loss'
            # 'beta': 0.99,  # Only used for the class-balanced losses
            # 'gamma': 0.5   # Only used for the focal losses
        }
        self.dataset = {
            'train': {
                'data_root': '/path/to/train',
                # rare_class_sampling is used when training is iteration-based.
                # It makes training sample rare classes more frequently; however,
                # it risks never seeing all of the data. It might be useful when
                # your task is multi-label classification.
                # 'rare_class_sampling': {
                #     'class_temp': 0.1
                # }
            },
            'eval': {
                'data_root': '/path/to/eval',
            }
        }
        self.max_checkpoint = 1

    def get_cfg(self):
        return vars(self)
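

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original config). It
# assumes a hypothetical `model` object exposing `head` and `backbone`
# submodules, and shows how the optimizer settings above could be mapped to
# torch.optim.SGD parameter groups so the classification head and the
# pre-trained backbone train at different learning rates.
# ---------------------------------------------------------------------------
def build_sgd_optimizer(model, cfg):
    # Local import so config.py itself stays dependency-free.
    import torch

    lr_cfg = cfg['optimizer']['learning_rate']
    param_groups = [
        {'params': model.head.parameters(), 'lr': lr_cfg['head_lr']},
        {'params': model.backbone.parameters(), 'lr': lr_cfg['backbone_lr']},
    ]
    return torch.optim.SGD(param_groups,
                           momentum=cfg['optimizer']['params']['momentum'])


# Example (assuming `model` exists):
#   cfg = Config().get_cfg()          # plain dict via vars(self)
#   optimizer = build_sgd_optimizer(model, cfg)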