inference.py
"""
This code runs a fine-tuned BERT model to estimate the similarity between two arguments.
In this example, we include arguments from three topics (courtesy to www.procon.org): Zoos, Vegetarianism, Climate Change
Each argument is compared against each other argument. Arguments from the same topic, e.g. on zoos, should be ranked higher
than arguments from different topics.
Usage: python inference.py
"""
from pytorch_pretrained_bert.tokenization import BertTokenizer
import torch
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from train import InputExample, convert_examples_to_features
from SigmoidBERT import SigmoidBERT
import argparse
import pandas as pd
from tqdm import tqdm
import pickle
from collections import defaultdict
parser = argparse.ArgumentParser()
parser.add_argument("--input_text_file_csv", type=str,
                    help='csv file containing the input text')
parser.add_argument("--input_file_csv_text_field", type=str,
                    help='column containing the text to be analyzed')
parser.add_argument("--output_files_prefix", type=str, default=False,
                    help='file prefix to dump output as pickled dict files')
args = parser.parse_args()
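# Illustrative invocations (file and column names are placeholders, not part of the repository):
#   python inference.py
#       -> scores the built-in demo arguments defined below
#   python inference.py --input_text_file_csv arguments.tsv \
#       --input_file_csv_text_field sentence --output_files_prefix results
#       -> scores every pair of sentences from the given column and pickles the results
# Note: the input file is read with sep='\t', so it is expected to be tab-separated despite the "csv" name.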
# See the README.md for where to download the pre-trained models
model_path = 'bert_output/ukp_aspects_all'  # ukp_aspects_all model: trained on all topics of the UKP Aspects corpus
#model_path = 'bert_output/misra_all'  # misra_all model: trained on all 3 topics from Misra et al., 2016
max_seq_length = 64
eval_batch_size = 8
if args.input_text_file_csv:
    input_df = pd.read_csv(args.input_text_file_csv, sep='\t')
    arguments = list(input_df[args.input_file_csv_text_field])
    arguments = list(set(arguments))  # remove repetitions
else:
    arguments = ['Zoos save species from extinction and other dangers.',
                 'Zoos produce helpful scientific research.',
                 'Zoos are detrimental to animals\' physical health.',
                 'Zoo confinement is psychologically damaging to animals.',
                 'Eating meat is not cruel or unethical; it is a natural part of the cycle of life. ',
                 'It is cruel and unethical to kill animals for food when vegetarian options are available',
                 'Overwhelming scientific consensus says human activity is primarily responsible for global climate change.',
                 'Rising levels of human-produced gases released into the atmosphere create a greenhouse effect that traps heat and causes global warming.'
                 ]
# Compare every argument with every other argument
input_examples = []
output_examples = []
for i in tqdm(range(0, len(arguments) - 1)):
    for j in range(i + 1, len(arguments)):
        input_examples.append(InputExample(text_a=arguments[i], text_b=arguments[j], label=-1))
        output_examples.append([arguments[i], arguments[j]])
tokenizer = BertTokenizer.from_pretrained(model_path, do_lower_case=True)
eval_features = convert_examples_to_features(input_examples, max_seq_length, tokenizer)
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids)
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=eval_batch_size)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SigmoidBERT.from_pretrained(model_path)
model.to(device)
model.eval()
predicted_logits = []
with torch.no_grad():
    for input_ids, input_mask, segment_ids in tqdm(eval_dataloader):
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)
        logits = model(input_ids, segment_ids, input_mask)
        logits = logits.detach().cpu().numpy()
        predicted_logits.extend(logits[:, 0])
for idx in range(len(predicted_logits)):
    output_examples[idx].append(predicted_logits[idx])
# Sort by similarity
output_examples = sorted(output_examples, key=lambda x: x[2], reverse=True)
if args.output_files_prefix:
    unique_sentences = list(set(arguments))
    sentence_to_id = defaultdict(dict)
    id_to_sentence = defaultdict(dict)
    for ctr, unique_sentence in enumerate(unique_sentences):
        sentence_to_id[unique_sentence] = ctr
        id_to_sentence[ctr] = unique_sentence
    sim_scores_dict = defaultdict(dict)
    for sent_a, sent_b, sim_score in output_examples:
        sent_a_id = sentence_to_id[sent_a]
        sent_b_id = sentence_to_id[sent_b]
        sim_scores_dict[sent_a_id][sent_b_id] = sim_score
    sentence_to_id_file = args.output_files_prefix + '_sentence_to_id.pickle'
    id_to_sentence_file = args.output_files_prefix + '_id_to_sentence_file.pickle'
    sim_scores_dict_file = args.output_files_prefix + '_sim_scores_dict_file.pickle'
    with open(sentence_to_id_file, 'wb') as handle:
        pickle.dump(sentence_to_id, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(id_to_sentence_file, 'wb') as handle:
        pickle.dump(id_to_sentence, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(sim_scores_dict_file, 'wb') as handle:
        pickle.dump(sim_scores_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
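    # Illustrative sketch of loading the dumps back in a later session
    # (assumes --output_files_prefix results was used; the prefix is a placeholder):
    #   import pickle
    #   with open('results_sim_scores_dict_file.pickle', 'rb') as handle:
    #       sim_scores_dict = pickle.load(handle)
    #   # sim_scores_dict[sent_a_id][sent_b_id] -> predicted similarity for that sentence pair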
if len(output_examples) > 10:
    # Show only the 5 most similar and the 5 least similar pairs
    output_examples = output_examples[:5] + output_examples[-5:]
print("Predicted similarities (sorted by similarity):")
for idx in range(len(output_examples)):
    example = output_examples[idx]
    print("Sentence A:", example[0])
    print("Sentence B:", example[1])
    print("Similarity:", example[2])
    print("")