diff --git a/gensim/models/word2vec.py b/gensim/models/word2vec.py index c363ac42fa..9fb0a4a9bb 100755 --- a/gensim/models/word2vec.py +++ b/gensim/models/word2vec.py @@ -1819,24 +1819,20 @@ def predict_output_word(self, context_words_list, topn=10): if not hasattr(self.wv, 'vectors') or not hasattr(self, 'syn1neg'): raise RuntimeError("Parameters required for predicting the output words not found.") - - if all(isinstance(w, int) for w in context_words_list): - # then, indices were passed. Check they are valid - word2_indices = np.array(context_words_list) - if np.any(word2_indices < 0): - logger.warning("All input context word indices must be non-negative.") - return None - # take only the ones in the vocabulary - word2_indices = word2_indices[word2_indices < self.wv.vectors.shape[0]] - if word2_indices.size == 0: - logger.warning("All the input context words are out-of-vocabulary for the current model.") - return None - else: - # then, words were passed. Retrieve their indices - word2_indices = [self.wv.get_index(w) for w in context_words_list if w in self.wv] - if not word2_indices: - logger.warning("All the input context words are out-of-vocabulary for the current model.") - return None + + # Retrieve indices if words were passed as input, otherwise keep the input indices + # Remark : out-of-vocabulary words are discarded. + word2_indices = [] + max_index = self.wv.vectors.shape[0] + for w in context_words_list: + if w in self.wv: + word2_indices.append(self.wv.get_index(w)) + elif isinstance(w, int) and (w < max_index): + word2_indices.append(w) + + if not word2_indices: + logger.warning("All the input context words are out-of-vocabulary for the current model.") + return None l1 = np.sum(self.wv.vectors[word2_indices], axis=0) if self.cbow_mean: