-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathmain_SSRT.domainnet.py
56 lines (44 loc) · 1.63 KB
/
main_SSRT.domainnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from trainer.train import train_main
import time
import socket
import os
# Launch SSRT training on every ordered (source, target) pair of DomainNet
# domains, skipping same-domain pairs (6 * 5 = 30 runs, executed sequentially).

DOMAINS = ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch']

# Banner passed to train_main for each run: the first placeholder receives the
# joined argument list, the last two receive host name and process id.
HEADER_TEMPLATE = '''
++++++++++++++++++++++++++++++++++++++++++++++++
{}
++++++++++++++++++++++++++++++++++++++++++++++++
@{}:{}
'''


def build_args(src, tgt, timestamp):
    """Return the command-line style argument list for one src -> tgt run.

    Args:
        src: Source domain name; selects ``data/<src>_train.txt``.
        tgt: Target domain name; selects ``data/<tgt>_train.txt`` for
            adaptation and ``data/<tgt>_test.txt`` for evaluation.
        timestamp: Run-group timestamp string, shared by all runs launched
            from one invocation of ``main``.

    Returns:
        A list of ``--key=value`` strings in a fixed order.
    """
    return ['--model=SSRT',
            '--base_net=vit_base_patch16_224',
            '--gpu=0',
            '--timestamp={}'.format(timestamp),
            '--dataset=DomainNet',
            '--source_path=data/{}_train.txt'.format(src),
            '--target_path=data/{}_train.txt'.format(tgt),
            '--test_path=data/{}_test.txt'.format(tgt),
            '--batch_size=32',
            '--lr=0.004',
            '--train_epoch=40',
            '--save_epoch=40',
            '--eval_epoch=5',
            '--iters_per_epoch=1000',
            '--sr_loss_weight=0.2',
            '--sr_alpha=0.3',
            '--sr_layers=[0,4,8]',
            '--sr_epsilon=0.4',
            # NOTE(review): booleans are passed as the strings 'True'/'False';
            # presumably train_main parses them -- confirm in trainer.train.
            '--use_safe_training=True',
            '--adap_adjust_T=1000',
            '--adap_adjust_L=4',
            '--use_tensorboard=False',
            '--tensorboard_dir=tbs/SSRT',
            '--use_file_logger=True',
            '--log_dir=logs/SSRT']


def build_header(args, host_name, pid):
    """Render the run banner: one argument per line, then ``@host:pid``."""
    return HEADER_TEMPLATE.format('\n\t\t'.join(args), host_name, pid)


def main():
    """Run SSRT training for every ordered pair of distinct DomainNet domains."""
    # Computed once so every run of this invocation shares the same timestamp.
    timestamp = time.strftime("%Y-%m-%d_%H.%M.%S", time.localtime())
    host_name = socket.gethostname()
    pid = os.getpid()
    for src in DOMAINS:
        for tgt in DOMAINS:
            if src == tgt:
                continue  # adaptation needs distinct source and target domains
            args = build_args(src, tgt, timestamp)
            train_main(args, build_header(args, host_name, pid))


if __name__ == '__main__':
    main()