-
Notifications
You must be signed in to change notification settings - Fork 9
/
main.py
74 lines (66 loc) · 2.32 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
from utils import get_config, get_log_dir, str2bool
from data_loader import get_loader
from train import Trainer
import warnings
from tensorboardX import SummaryWriter
# Install a global "ignore" filter so Python warnings raised from this point
# on are not shown.
warnings.filterwarnings('ignore')
#from visualizer import Visualizer
# NOTE(review): `resume` is not referenced anywhere in the visible part of
# this file — presumably a leftover resume/checkpoint path; confirm that no
# other module reads it before removing.
resume = ''
if __name__ == "__main__":
    # Script entry point: parse the command-line options, attach the project
    # configuration, build the data loader and the Trainer, then run either
    # training or testing depending on --mode.
    import argparse

    parser = argparse.ArgumentParser()

    # Parameters to set
    parser.add_argument('--mode',
                        type=str,
                        default='train',
                        choices=['train', 'test'])
    parser.add_argument("--gpu_id", type=int, default=-1)
    parser.add_argument('--dname',
                        type=str,
                        default='lm',
                        choices=['lm', 'ycb'])
    parser.add_argument("--root_dataset",
                        type=str,
                        default='./datasets/LINEMOD')
    parser.add_argument("--resume_train", type=str2bool, default=False)
    parser.add_argument("--optim",
                        type=str,
                        default='Adam',
                        choices=['Adam', 'SGD'])
    # NOTE(review): batch_size is parsed as a string; presumably converted by
    # the consumer — confirm against data_loader / Trainer before changing.
    parser.add_argument("--batch_size",
                        type=str,
                        default='4')
    parser.add_argument("--class_name",
                        type=str,
                        default='ape')
    parser.add_argument("--initial_lr",
                        type=float,
                        default=1e-4)
    # Kept as a string: it is concatenated into the log directory name below.
    parser.add_argument("--kpt_num",
                        type=str,
                        default='1')
    parser.add_argument('--model_dir',
                        type=str,
                        default='ckpts/')
    # Boolean flags go through str2bool (same converter as --resume_train).
    # With type=bool, argparse calls bool() on the raw string, so any
    # non-empty value -- including "False" -- would be parsed as True.
    parser.add_argument('--demo_mode',
                        type=str2bool,
                        default=False)
    parser.add_argument('--test_occ',
                        type=str2bool,
                        default=False)

    opts = parser.parse_args()

    # get_config() returns an indexable result; the element at index 1 is
    # attached to the options so downstream code can read it as opts.cfg.
    cfg = get_config()[1]
    opts.cfg = cfg

    if opts.mode == 'train':
        # Training writes its logs under a directory derived from the
        # dataset name, the class name and the keypoint number.
        opts.out = get_log_dir(
            opts.dname + '/' + opts.class_name + 'Kp' + opts.kpt_num, cfg)
        print('Output logs: ', opts.out)
        vis = SummaryWriter(logdir=opts.out + '/tbLog/')
    else:
        # No TensorBoard writer is created in test mode; an empty list is
        # passed to Trainer in its place.
        vis = []

    data = get_loader(opts)
    trainer = Trainer(data, opts, vis)
    if opts.mode == 'test':
        trainer.Test()
    else:
        trainer.Train()