-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
97 lines (64 loc) · 3.15 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import subprocess
# Process-environment tweaks, applied before torch is imported further down.
# NOTE: PYTHONIOENCODING is read when an interpreter starts, so setting it here
# affects child Python processes spawned later, not this process's own stdio.
# CUDA_VISIBLE_DEVICES restricts the run to GPU 0; it only takes effect if set
# before CUDA is initialised.
os.environ.update({
    'PYTHONIOENCODING': 'utf-8',
    'CUDA_VISIBLE_DEVICES': '0',
})
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
# Seed every random number generator the script relies on (Python, NumPy,
# PyTorch CPU and CUDA) so that later random operations are reproducible.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Seed all visible CUDA devices rather than only the current one; this call is
# a silent no-op when CUDA is unavailable.
torch.cuda.manual_seed_all(seed)
# Deterministic cuDNN kernels alone are not enough: the cuDNN autotuner must
# also be off, otherwise it may select different algorithms between runs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
from utils.parser import get_parser
from utils.logger import get_logger
# Parse command-line options; `option.name` identifies this run and is used
# as the per-run sub-directory for every artefact written below.
parser = get_parser()
option = parser.parse_args()

# Output layout: result/{logs,save,sample,result}/<run name>/
root_path = 'result'
logs_folder = os.path.join(root_path, 'logs', option.name)
save_folder = os.path.join(root_path, 'save', option.name)
sample_folder = os.path.join(root_path, 'sample', option.name)
result_folder = os.path.join(root_path, 'result', option.name)

# Create the folders in-process rather than via `mkdir -p` through a shell.
# The run name comes from the command line and would otherwise be spliced
# unquoted into a shell command. Spaces or metacharacters would then create
# wrong paths or run arbitrary commands, and shell failures went unnoticed.
# os.makedirs is also portable and raises if a directory cannot be created.
for _folder in (logs_folder, save_folder, sample_folder, result_folder):
    os.makedirs(_folder, exist_ok=True)

logs_path = os.path.join(logs_folder, 'main.log')
save_path = os.path.join(save_folder, 'best.pth')  # best-validation checkpoint
logger = get_logger(option.name, logs_path)
from loaders.loader1 import get_loader as get_loader1
from modules.module1 import get_module as get_module1
from utils.misc import train, valid, save_checkpoint, load_checkpoint, save_sample
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

logger.info('prepare loader')
train_loader, valid_loader, test_loader = get_loader1(option)
# Target padding index (trg_pad_idx), taken from option.num_max; the loss below
# ignores targets equal to it.
# NOTE(review): presumably real pointer targets are 0..num_max-1; confirm
# against the loader.
pad_idx = option.num_max

logger.info('prepare module')
pointer_network = get_module1(option).to(device)

logger.info('prepare envs')
optimizer = optim.Adam(
    pointer_network.parameters(),
    lr=option.lr,
    weight_decay=option.wd,
)
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
# Coverage settings: `is_coverage` enables it, `st_coverage` is the first epoch
# at which the training loop turns it on.
is_coverage = option.is_coverage
st_coverage = option.st_coverage
logger.info('start training!')
best_valid_loss = float('inf')
for epoch in range(1, option.num_epochs + 1):
    # Coverage is active only when enabled and once epoch `st_coverage` is
    # reached; training and validation use the same flag for this epoch.
    use_coverage = is_coverage and st_coverage <= epoch
    train_info = train(pointer_network, train_loader, criterion, optimizer, device, pad_idx, use_coverage)
    valid_info = valid(pointer_network, valid_loader, criterion, optimizer, device, pad_idx, use_coverage)
    logger.info(
        '[Epoch %d] Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f' %
        (epoch, train_info['loss'], train_info['acc'], valid_info['loss'], valid_info['acc'])
    )
    # Keep only the best model by validation loss: checkpoint it and dump its
    # validation samples.
    # NOTE: a NaN loss never compares as an improvement. If every epoch yields
    # NaN, no checkpoint is written for the test phase to load.
    if valid_info['loss'] < best_valid_loss:
        best_valid_loss = valid_info['loss']
        save_checkpoint(save_path, pointer_network, optimizer, epoch)
        save_sample(sample_folder, valid_info['outputs'], valid_info['targets'], valid_info['maskeds'])
logger.info('start testing!')
# Reload the best-validation weights and optimizer state written during
# training. The return value is presumably the epoch the checkpoint was saved
# at; it is not used further here.
cur_epoch = load_checkpoint(save_path, pointer_network, optimizer)
# NOTE(review): this pass enables coverage whenever `is_coverage` is set,
# whereas training/validation only enabled it from epoch `st_coverage` on.
# If the restored checkpoint predates `st_coverage`, the two settings differ.
# Confirm this is intended (alternative: `is_coverage and st_coverage <= cur_epoch`).
test_info = valid(pointer_network, test_loader, criterion, optimizer, device, pad_idx, is_coverage)
test_loss, test_acc = test_info['loss'], test_info['acc']
logger.info('Test Loss: %f, Test Acc: %f' % (test_loss, test_acc))
save_sample(result_folder, *(test_info[key] for key in ('outputs', 'targets', 'maskeds')))