-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsubmission.py
70 lines (57 loc) · 1.82 KB
/
submission.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import logging
import os
import sys
import pandas as pd
import torch
from torch import nn
from torch.backends import cudnn
import network
from core.loader import get_test_provider
# Configure the root logger once at import time: records of level INFO and
# above go to stdout, rendered as "[LEVEL]: message".
FORMAT = '[%(levelname)s]: %(message)s'
logging.basicConfig(stream=sys.stdout, format=FORMAT, level=logging.INFO)
def submission(args):
    """Predict a class for every test sample and write ``submission.csv``.

    Args:
        args: parsed command-line namespace. Reads ``bs`` (batch size passed
            to ``get_test_provider``), ``model_path`` (checkpoint file whose
            ``'state_dict'`` entry is loaded into ``network.ResNet18`` with 10
            classes) and ``use_gpu`` (whether to move model and inputs to CUDA).

    Writes ``submission.csv`` in the current directory. It has an ``id``
    column (the per-sample ids yielded by the test loader) and a ``label``
    column (the predicted class index mapped through ``id_to_class``).
    """
    test_loader, id_to_class = get_test_provider(args.bs)

    net = network.ResNet18(num_classes=10)
    # Deserialize onto the CPU first so a checkpoint saved from a GPU run can
    # also be restored on a CPU-only machine; the model is moved to the GPU
    # below only when requested.
    checkpoint = torch.load(args.model_path, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])
    net = nn.DataParallel(net)
    net.eval()
    if args.use_gpu:
        net = net.cuda()

    pred_labels = []
    indices = []
    # Inference only: disable autograd once for the whole pass rather than
    # re-entering the context for every batch.
    with torch.no_grad():
        for data, fname in test_loader:
            if args.use_gpu:
                data = data.cuda()
            scores = net(data)
            # Index of the highest score in each row = predicted class id.
            pred_labels.extend(scores.argmax(dim=1).cpu().numpy())
            # NOTE(review): assumes the loader yields sample ids as a CPU
            # tensor (``.numpy()`` is called on it) — confirm in core.loader.
            indices.extend(fname.numpy())

    df = pd.DataFrame({'id': indices, 'label': pred_labels})
    df['label'] = df['label'].apply(lambda x: id_to_class[x])
    df.to_csv('submission.csv', index=False)
def _str2bool(value):
    """Parse a command-line string into a bool for argparse.

    ``type=bool`` is wrong for argparse: ``bool('False')`` is ``True``, so any
    non-empty value would enable the option. This accepts the common
    spellings case-insensitively and treats an empty string as ``False``
    (as ``bool('')`` did). It raises ``argparse.ArgumentTypeError`` for
    anything else, so argparse reports a clean usage error.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y'):
        return True
    if text in ('', '0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def main():
    """Parse command-line options and write predictions for the test set."""
    parser = argparse.ArgumentParser(description='cifar10 model testing')
    parser.add_argument('--model-path', type=str, default='checkpoints/model_best.pth.tar',
                        help='path to the trained model checkpoint')
    parser.add_argument('--bs', type=int, default=128, help='testing batch size')
    # Default to the GPU only when CUDA is actually available, so the script
    # does not fail on CPU-only machines; `--use-gpu false` now really
    # disables it.
    parser.add_argument('--use-gpu', type=_str2bool, default=torch.cuda.is_available(),
                        help='whether to run inference on the GPU (true/false)')
    args = parser.parse_args()
    # Let cuDNN pick the fastest convolution algorithms for the fixed input size.
    cudnn.benchmark = True
    submission(args)


if __name__ == '__main__':
    main()