Skip to content

Commit

Permalink
Word2Vec.predict_output_word: Changed handling of ints and strs, tryi…
Browse files Browse the repository at this point in the history
…ng to trying to make it more compact and versatile.
  • Loading branch information
M-Demay authored and Mathis committed May 25, 2021
1 parent 91d5dca commit 84258b4
Showing 1 changed file with 14 additions and 18 deletions.
32 changes: 14 additions & 18 deletions gensim/models/word2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1819,24 +1819,20 @@ def predict_output_word(self, context_words_list, topn=10):

if not hasattr(self.wv, 'vectors') or not hasattr(self, 'syn1neg'):
raise RuntimeError("Parameters required for predicting the output words not found.")

if all(isinstance(w, int) for w in context_words_list):
# then, indices were passed. Check they are valid
word2_indices = np.array(context_words_list)
if np.any(word2_indices < 0):
logger.warning("All input context word indices must be non-negative.")
return None
# take only the ones in the vocabulary
word2_indices = word2_indices[word2_indices < self.wv.vectors.shape[0]]
if word2_indices.size == 0:
logger.warning("All the input context words are out-of-vocabulary for the current model.")
return None
else:
# then, words were passed. Retrieve their indices
word2_indices = [self.wv.get_index(w) for w in context_words_list if w in self.wv]
if not word2_indices:
logger.warning("All the input context words are out-of-vocabulary for the current model.")
return None

# Retrieve indices if words were passed as input, otherwise keep the input indices
# Remark : out-of-vocabulary words are discarded.
word2_indices = []
max_index = self.wv.vectors.shape[0]
for w in context_words_list:
if w in self.wv:
word2_indices.append(self.wv.get_index(w))
elif isinstance(w, int) and (w < max_index):
word2_indices.append(w)

if not word2_indices:
logger.warning("All the input context words are out-of-vocabulary for the current model.")
return None

l1 = np.sum(self.wv.vectors[word2_indices], axis=0)
if self.cbow_mean:
Expand Down

0 comments on commit 84258b4

Please sign in to comment.