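"""Query expansion (QE) for queries that mention multiple entities.

Given a set of entities, this module finds real-world events related to them
and uses those events (optionally together with per-year temporal embedding
models) to produce scored expansion terms for the query.
"""
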
import json
import logging
from enum import Enum, auto
from functools import lru_cache
import numpy as np
from nltk import PerceptronTagger, WordNetLemmatizer
from sqlitedict import SqliteDict
import models_manager
import utils
from word_onthology import EventDetector
class MultipleEntitiesQEMethod(Enum):
    none = auto()
    events = auto()
    events_temporal = auto()
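# `none` disables multi-entity expansion and `events` expands with terms
# related to events detected for the entities; `events_temporal` presumably
# pairs the same pipeline with a multi-year models_manager, since the code
# below only checks for the 'events' name prefix and switches behavior on how
# many yearly models exist (see find_words_based_on_events).
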
class QEMultipleEntities:
    def __init__(
        self,
        qe_single_entity,
        two_entities_qe_method=None,
        k=10,
        onthology=None,
        event_detector=None,
        min_score=None,
        candidates_lambda=None,
    ):
        self.qe_single_entity = qe_single_entity
        self.two_entities_qe_method = two_entities_qe_method
        self.k = k
        # prefer the ontology's models manager when an ontology is given
        self.models_manager = (
            onthology.models_manager
            if onthology is not None
            else qe_single_entity.models_manager
        )
        # the "global" model is the static (year-independent) model
        self.global_model = (
            self.models_manager[models_manager.STATIC_YEAR]
            if self.models_manager is not None
            else qe_single_entity.global_model
        )
        self.onthology = onthology
        with open('data/events_since1980.json', encoding='utf-8') as f:
            self.events = json.load(f)
        with open('data/event_id_name_since1980.json', encoding='utf-8') as f:
            self.event_id_name = json.load(f)
        self.pos_tagger = PerceptronTagger()
        self.wnl = WordNetLemmatizer()
        self.lemmatize = lru_cache()(utils.lemmatize_word)
        self.event_to_top_tfidf = SqliteDict(
            'data/event_to_top_tfidf_100.sqlite', flag='r'
        )
        self.event_detector = (
            event_detector if event_detector else EventDetector.WikipediaFrequency
        )
        self.min_score = min_score
        self.candidates_lambda = (
            candidates_lambda if candidates_lambda is not None else 0.6
        )
    def _expand_entities(
        self, entities, two_entities_qe_method=None, lemmatize=True, k=None
    ):
        """
        Apply QE with the selected method and return a dict mapping each
        expansion term to its score
        """
        if k is None:
            k = self.k
        if lemmatize:
            entities = [
                self.lemmatize(entity, self.wnl, self.pos_tagger)
                for entity in entities
            ]
        entities = [entity.replace('_', ' ').title() for entity in entities]
        qe_method = (
            two_entities_qe_method
            if two_entities_qe_method is not None
            else self.two_entities_qe_method
        )
        if qe_method == MultipleEntitiesQEMethod.none:
            expansions = {}
        elif qe_method.name.startswith('events'):
            expansions = self.expand_using_events(entities, k=k)
        else:
            raise ValueError('Unknown QE method: {}'.format(qe_method))
        if expansions:
            # strip the entity prefix and normalize the scores
            expansions = utils.normalize(
                {
                    exp.replace('ENTITY/', '').replace('_', ' ').lower(): score
                    for exp, score in expansions.items()
                }
            )
            expansions = {
                utils.tokenize(exp, to_str=True): score
                for exp, score in expansions.items()
            }
            expansions = {exp: score for exp, score in expansions.items() if exp}
        return expansions
    def expand(self, entities, k=None):
        expansions = self._expand_entities(entities, k=k)
        return expansions
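    # Hypothetical usage sketch (the entity names are illustrative, and the
    # single-entity QE object and ontology come from project-specific setup
    # outside this file):
    #   qe = QEMultipleEntities(qe_single_entity,
    #                           MultipleEntitiesQEMethod.events,
    #                           onthology=onthology)
    #   qe.expand(['barack_obama', 'ukraine'])
    #   # -> e.g. {'sanction': 1.0, 'crimea': 0.8}  (normalized scores)
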
    def filter_and_take_top_words(self, word_score, entities, k=None):
        if k is None:
            k = self.k
        # sort alphabetically so shorter variants tend to come first
        word_score.sort(key=lambda tup: tup[1])
        filtered_word_score = []
        for (word_i, word, score) in word_score:
            word_lower = word.lower()
            # skip words that are contained in one of the query entities
            if entities is not None and any(
                word_lower in entity.lower() for entity in entities
            ):
                continue
            # if a similar (shorter) word was already selected, keep the
            # shorter option and boost its score instead of adding this one
            if any(
                prev_word.lower() in word_lower
                for (_, prev_word, _) in filtered_word_score
            ):
                filtered_word_score = [
                    (prev_i, prev_word, prev_score * 2)
                    if prev_word.lower() in word_lower
                    else (prev_i, prev_word, prev_score)
                    for (prev_i, prev_word, prev_score) in filtered_word_score
                ]
                continue
            filtered_word_score.append((word_i, word, score))
        word_score = utils.get_top_items(
            filtered_word_score, k, sort_by_index=2
        )  # sort by score, descending
        return word_score
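    # Example of the dedup/boost above (illustrative words and scores): if
    # 'bank' was already selected with score 0.4 and 'banking' arrives next,
    # 'banking' is skipped and 'bank' is boosted to 0.8, preferring the
    # shorter surface form while keeping the evidence from its longer variant.
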
    def _calc_temporal_relevance(self, word, event_year, event_neighbors):
        years_to_calc = [event_year - 1, event_year + 1]
        if not all(
            year in self.models_manager.get_all_years() for year in years_to_calc
        ):
            return np.nan
        scores = []
        for neighbor in event_neighbors:
            hist = {
                year: self.models_manager[year].similarity(word, neighbor)
                for year in years_to_calc
            }
            # skip this neighbor if the word doesn't have a meaningful
            # similarity for both years around the event's time
            # (note: a membership test like `np.nan in hist.values()` is
            # unreliable because NaN never compares equal to itself)
            if any(
                val is None or val == 0 or np.isnan(val) for val in hist.values()
            ):
                continue
            score = hist[event_year + 1] / hist[event_year - 1]
            scores.append(score)
        return np.mean(scores) if scores else np.nan
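    # Illustrative example: if a word's similarity to an event neighbor is
    # 0.2 the year before the event and 0.6 the year after, the ratio is
    # 0.6 / 0.2 = 3.0, i.e. the word became markedly more associated after
    # the event; the returned value averages this ratio over all neighbors.
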
    def _find_candidate_words_for_event(self, event, model, entities, query_avg):
        candidates = set()
        # take words with a high TF-IDF score in the event's page
        top_tfidf = self.event_to_top_tfidf.get(event)
        if top_tfidf:
            candidates.update(
                model.get_key(w)
                for w in list(top_tfidf)[: int(self.candidates_lambda * self.k)]
                if w in model
            )
        else:
            logging.warning(f'the event {event} does not exist in the TF-IDF model')
        # interpolate with words similar to the query
        topn = (
            self.k
            if not candidates
            else int((1 - self.candidates_lambda) * self.k) + len(entities)
        )
        candidates.update(
            w
            for w, p in model.most_similar(
                [query_avg],
                topn=topn,
                filter_func=lambda word: word in model and not model.is_entity(word),
            )
        )
        candidates = candidates - set(entities)
        return list(candidates)
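    # Illustrative split: with k=10 and the default candidates_lambda=0.6,
    # up to int(0.6 * 10) = 6 candidates come from the event page's top
    # TF-IDF terms and int(0.4 * 10) = 4 from embedding similarity to the
    # query (plus len(entities) extra, presumably so enough candidates
    # survive the entity-removal step at the end).
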
    def find_words_based_on_events(
        self, entities, year_event_scores, max_words_per_event
    ):
        use_temporal_models = len(self.models_manager.year_to_model) > 1
        query_avg = (
            self.global_model.get_average_vector(entities, require_all=True)
            if not use_temporal_models
            else None
        )
        # look for relevant words (based on the events)
        event_word_score = {}
        for year, event_scores in year_event_scores.items():
            model = (
                self.models_manager[year] if use_temporal_models else self.global_model
            )
            if not model or any(entity not in model for entity in entities):
                # fallback: use the global model
                model = self.global_model
                if not model or any(entity not in model for entity in entities):
                    continue
            if use_temporal_models:
                query_avg = model.get_average_vector(entities, require_all=True)
            for event, score in event_scores:
                if event not in model:
                    # logging.info(f'Ditching {event} because it doesn\'t exist in the model')
                    continue
                candidates = self._find_candidate_words_for_event(
                    event, model, entities, query_avg
                )
                if not candidates:
                    logging.info(f'Event "{event}" has no candidates')
                    continue
                # semantic similarity of each candidate word and the query
                query_similarities = model.similarities(query_avg, candidates)
                # similarity with this event
                event_similarities = model.similarities(event, candidates)
                # similarity of the event and the query (duplicated over all candidates)
                event_query_similarities = np.array(
                    [model.similarity(event, query_avg)] * len(candidates)
                )
                all_scores = [
                    query_similarities * 3,
                    event_similarities,
                    event_query_similarities,
                ]
                top_tfidf = self.event_to_top_tfidf.get(event)
                if top_tfidf:
                    # TF-IDF of each candidate in this event's page
                    tf_idf_scores = [
                        top_tfidf.get(model.get_word(candidate), 0)
                        for candidate in candidates
                    ]
                    all_scores.append(np.array(tf_idf_scores))
                    if use_temporal_models:
                        # score of relevance to the event's neighbors
                        event_neighbors = [
                            w for w in list(top_tfidf)[:5] if w in model
                        ]
                        neighbors_scores = [
                            self._calc_temporal_relevance(
                                candidate, year, event_neighbors
                            )
                            for candidate in candidates
                        ]
                        neighbors_scores = utils.normalize(neighbors_scores)
                        all_scores.append(np.array(neighbors_scores))
                final_scores = np.nanmean(all_scores, axis=0)
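                # Worked example (illustrative numbers): for a candidate with
                # query similarity 0.5 (tripled to 1.5), event similarity 0.4,
                # event-query similarity 0.3 and TF-IDF 0.2, the final score
                # is nanmean([1.5, 0.4, 0.3, 0.2]) = 0.6, so closeness to the
                # query dominates; NaN temporal scores are simply ignored.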
                # take the words with top scores
                max_words_margin = int(3 * max_words_per_event)
                positive_word_indices = utils.argpartition(
                    final_scores, max_words_margin
                )
                word_score = [
                    (
                        word_i,
                        model.get_word(candidates[word_i]),
                        round(float(final_scores[word_i]), 2),
                    )
                    for word_i in positive_word_indices
                ]
                word_score = self.filter_and_take_top_words(
                    word_score, entities, k=max_words_per_event
                )
                event_word_score[event] = word_score
        if not event_word_score:
            logging.info('* No candidate expansions were found')
            return None
        # aggregate all words of all events and sort them
        word_score = [
            word_score
            for word_score_list in event_word_score.values()
            for word_score in word_score_list
        ]  # flatten the list
        word_score = self.filter_and_take_top_words(word_score, entities, k=self.k)
        return word_score
    def expand_using_events(self, entities, k=None):
        if k is None:
            k = self.k
        event_score_threshold = 0.004
        # tokenize and remove stopwords
        entities = utils.tokenize(' '.join(entities), remove_stopwords=True)
        key_years = self.models_manager.get_all_years()
        # for each of the detected years, look for relevant events
        year_event_score = self.onthology.find_events_for_entities(
            entities,
            self.event_detector,
            event_score_threshold,
            key_years,
            min_score=self.min_score,
        )
        if not year_event_score:  # no events were found
            logging.info('* No events were found beyond the threshold')
            return None
        # find relevant words based on the events
        event_scores = [
            (event, score)
            for year, event_scores in year_event_score.items()
            for (event, score) in event_scores
        ]
        if not event_scores:  # no events were found for any year
            logging.info('** No events were found beyond the threshold')
            return None
        max_words_per_event = max(10, k // len(event_scores) * 2)
        word_scores = self.find_words_based_on_events(
            entities, year_event_score, max_words_per_event
        )
        # use the top k words as the query expansions
        expansions = (
            {word: score for (word_i, word, score) in word_scores[:k]}
            if word_scores
            else {}
        )
        return expansions