-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathldm.py
62 lines (45 loc) · 2.05 KB
/
ldm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
from utils import *
# Parse command-line arguments.
args = get_args_ldm()

# Restrict PyTorch to the requested GPU(s). This is done before the
# dataset/trainer modules are imported below, presumably because they import
# torch and CUDA only honours CUDA_VISIBLE_DEVICES if it is set before CUDA
# is initialised — keep this ordering.
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, args.gpu))

# Create the output directory if needed. exist_ok=True replaces the
# check-then-create pattern, which could race with another process creating
# the same directory between the check and the makedirs call.
os.makedirs(args.save_dir, exist_ok=True)

# Intentionally late imports: they must run after CUDA_VISIBLE_DEVICES is set.
from dataset import *
from trainer import *
def run(args):
    """Build the dataset/trainer pair selected by ``args.option`` and train it.

    Args:
        args: Parsed command-line namespace. This function reads ``option``,
            ``data``, ``list``, ``data_aug``, ``train_nepoch``,
            ``test_nepoch`` and ``save_nepoch``. The whole namespace is also
            forwarded to the dataset and trainer constructors.

    Raises:
        ValueError: If ``args.option`` is not one of ``surfpos``, ``surfz``,
            ``edgepos`` or ``edgez``.
    """
    # Each option pairs a dataset class with its trainer class. Using branches
    # (instead of an eagerly built dict) means only the selected pair's names
    # are looked up.
    if args.option == 'surfpos':
        data_cls, trainer_cls = SurfPosData, SurfPosTrainer
    elif args.option == 'surfz':
        data_cls, trainer_cls = SurfZData, SurfZTrainer
    elif args.option == 'edgepos':
        data_cls, trainer_cls = EdgePosData, EdgePosTrainer
    elif args.option == 'edgez':
        data_cls, trainer_cls = EdgeZData, EdgeZTrainer
    else:
        # An explicit raise rather than `assert False`: asserts are stripped
        # under `python -O`, which would let execution continue with the
        # trainer unbound.
        raise ValueError(
            'please choose between [surfpos, surfz, edgepos, edgez], '
            f'got {args.option!r}'
        )

    # The training split gets aug=args.data_aug; the validation split always
    # gets aug=False.
    train_dataset = data_cls(args.data, args.list, validate=False, aug=args.data_aug, args=args)
    val_dataset = data_cls(args.data, args.list, validate=True, aug=False, args=args)
    ldm = trainer_cls(args, train_dataset, val_dataset)

    print('Start training...')
    # Main training loop.
    for _ in range(args.train_nepoch):
        ldm.train_one_epoch()
        # ldm.epoch is not incremented here; presumably train_one_epoch()
        # advances it. Validate and checkpoint on their respective intervals.
        if ldm.epoch % args.test_nepoch == 0:
            ldm.test_val()
        if ldm.epoch % args.save_nepoch == 0:
            ldm.save_model()
if __name__ == '__main__':
    # Runs only when the file is executed as a script, not when imported.
    # `args` is the module-level namespace returned by get_args_ldm().
    run(args)