import sys
import argparse

import torch
import torch.nn.functional as F

from transformers import BertTokenizerFast
from transformers import BertForSequenceClassification, AlbertForSequenceClassification
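
# A minimal usage sketch (the file names below are hypothetical; the checkpoint
# is assumed to be produced by the matching fine-tuning script):
#
#   $ head review.txt | python classify_plm.py --model_fn ./models/review.pth --gpu_id 0 --top_k 2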

def define_argparser():
    '''
    Define an argument parser for inference with a fine-tuned model.
    '''
    p = argparse.ArgumentParser()

    p.add_argument('--model_fn', required=True)
    p.add_argument('--gpu_id', type=int, default=-1)
    p.add_argument('--batch_size', type=int, default=256)
    p.add_argument('--top_k', type=int, default=1)

    config = p.parse_args()

    return config

def read_text():
    '''
    Read lines of text from standard input for inference.
    '''
    lines = []

    for line in sys.stdin:
        if line.strip() != '':
            lines += [line.strip()]

    return lines

def main(config):
    saved_data = torch.load(
        config.model_fn,
        map_location='cpu' if config.gpu_id < 0 else 'cuda:%d' % config.gpu_id
    )
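
    # The checkpoint is expected to hold the training configuration, the
    # fine-tuned weights, and the index-to-label mapping saved at training time.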
    train_config = saved_data['config']
    bert_best = saved_data['bert']
    index_to_label = saved_data['classes']

    lines = read_text()

    with torch.no_grad():
        # Declare the model and load the fine-tuned weights.
        tokenizer = BertTokenizerFast.from_pretrained(train_config.pretrained_model_name)
        model_loader = AlbertForSequenceClassification if train_config.use_albert else BertForSequenceClassification
        model = model_loader.from_pretrained(
            train_config.pretrained_model_name,
            num_labels=len(index_to_label)
        )
        model.load_state_dict(bert_best)

        if config.gpu_id >= 0:
            model.cuda(config.gpu_id)
        device = next(model.parameters()).device

        # Don't forget to turn on evaluation mode.
        model.eval()
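        # eval() disables dropout, so repeated runs on the same input are deterministic.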

        y_hats = []
        for idx in range(0, len(lines), config.batch_size):
            mini_batch = tokenizer(
                lines[idx:idx + config.batch_size],
                padding=True,
                truncation=True,
                return_tensors="pt",
            )

            x = mini_batch['input_ids']
            x = x.to(device)
            mask = mini_batch['attention_mask']
            mask = mask.to(device)
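            # |x| = |mask| = (batch_size, max_length_in_batch)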

            # Take feed-forward.
            y_hat = F.softmax(model(x, attention_mask=mask).logits, dim=-1)
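            # |y_hat| = (batch_size, n_classes)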

            y_hats += [y_hat]
        # Concatenate the mini-batch-wise results.
        y_hats = torch.cat(y_hats, dim=0)
        # |y_hats| = (len(lines), n_classes)

        probs, indices = y_hats.cpu().topk(config.top_k)
        # |indices| = (len(lines), top_k)
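
        # Write the top-k predicted labels and the input sentence, tab-separated.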
        for i in range(len(lines)):
            sys.stdout.write('%s\t%s\n' % (
                ' '.join([index_to_label[int(indices[i][j])] for j in range(config.top_k)]),
                lines[i]
            ))


if __name__ == '__main__':
    config = define_argparser()
    main(config)