-
Notifications
You must be signed in to change notification settings - Fork 5
/
main.py
39 lines (26 loc) · 1.13 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import sys
import os
import numpy as np
import cv2
import argparse
import torch
# Base directory used to build the default output paths below.
root = './'
# Defaults for --ckpt_dir / --log_dir.
ckpt_dir = os.path.join(root, 'ckpt')
log_dir = os.path.join(root, 'log')

# Step 2. Train the network.
# Command-line options. When run as a script, the parsed namespace is passed
# to `trainer.Trainer`.
parser = argparse.ArgumentParser()
parser.add_argument('--nepoch', type=int, default=20, help='number of epochs to train for')
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
# NOTE(review): the previous help text ("number of frames extracted from each
# video") did not match this option's name; it presumably sets the number of
# data-loading worker processes -- confirm against trainer.py.
parser.add_argument('--num_workers', type=int, default=16, help='number of data-loading worker processes')
parser.add_argument('--lr', type=float, default=0.00001, help='learning rate')
parser.add_argument('--name', type=str, default='test', help='experiment name for log')
parser.add_argument('--config_path', type=str, default='configs/yaml/oliver.yaml', help='path to the YAML config file')
parser.add_argument('--ckpt_dir', type=str, default=ckpt_dir, help='checkpoint directory')
parser.add_argument('--log_dir', type=str, default=log_dir, help='log directory')
parser.add_argument('--ckpt_epoch_freq', type=int, default=2, help='checkpoint frequency, in epochs')
# An empty string (the default) means no checkpoint path was given.
parser.add_argument('--load_ckpt_path', type=str, default='', help='checkpoint path to load')
if __name__ == "__main__":
    # Imported only when executed as a script, so importing this module
    # (e.g. to reuse `parser`) does not load the trainer module.
    from trainer import Trainer

    # Parse the command line, build the trainer from the options, then run it.
    opt_parser = parser.parse_args()
    model = Trainer(opt_parser)
    model.run()