diff --git a/gensim/models/keyedvectors.py b/gensim/models/keyedvectors.py
index b35a974f4f..fa1999aa8d 100644
--- a/gensim/models/keyedvectors.py
+++ b/gensim/models/keyedvectors.py
@@ -50,7 +50,8 @@
 
 And on analogies::
 
-    >>> word_vectors.accuracy(os.path.join(module_path, 'test_data', 'questions-words.txt'))
+    >>> word_vectors.evaluate_word_analogies(os.path.join(module_path, 'test_data', 'questions-words.txt'))[0]
+    0.58
 
 and so on.
 
@@ -850,6 +851,139 @@ def n_similarity(self, ws1, ws2):
         v2 = [self[word] for word in ws2]
         return dot(matutils.unitvec(array(v1).mean(axis=0)), matutils.unitvec(array(v2).mean(axis=0)))
 
+    @staticmethod
+    def _log_evaluate_word_analogies(section):
+        """Calculate score by section, helper for
+        :meth:`~gensim.models.keyedvectors.WordEmbeddingsKeyedVectors.evaluate_word_analogies`.
+
+        Parameters
+        ----------
+        section : dict of (str, (str, str, str, str))
+            Section given from evaluation.
+
+        Returns
+        -------
+        float
+            Accuracy score.
+
+        """
+        correct, incorrect = len(section['correct']), len(section['incorrect'])
+        if correct + incorrect > 0:
+            score = correct / (correct + incorrect)
+            logger.info("%s: %.1f%% (%i/%i)", section['section'], 100.0 * score, correct, correct + incorrect)
+            return score
+
+    def evaluate_word_analogies(self, analogies, restrict_vocab=300000, case_insensitive=True, dummy4unknown=False):
+        """Compute performance of the model on an analogy test set.
+
+        This is the modern variant of :meth:`~gensim.models.keyedvectors.WordEmbeddingsKeyedVectors.accuracy`, see
+        `discussion on GitHub #1935 <https://github.com/RaRe-Technologies/gensim/pull/1935>`_.
+
+        The accuracy is reported (printed to log and returned as a score) for each section separately,
+        plus there's one aggregate summary at the end.
+
+        This method corresponds to the `compute-accuracy` script of the original C word2vec.
+        See also `Analogy (State of the art) <https://aclweb.org/aclwiki/Analogy_(State_of_the_art)>`_.
+
+        Parameters
+        ----------
+        analogies : str
+            Path to file, where lines are 4-tuples of words, split into sections by ": SECTION NAME" lines.
+            See `gensim/test/test_data/questions-words.txt` as example.
+        restrict_vocab : int, optional
+            Ignore all 4-tuples containing a word not in the first `restrict_vocab` words.
+            This may be meaningful if you've sorted the model vocabulary by descending frequency (which is standard
+            in modern word embedding models).
+        case_insensitive : bool, optional
+            If True - convert all words to their uppercase form before evaluating the performance.
+            Useful to handle case-mismatch between training tokens and words in the test set.
+            In case of multiple case variants of a single word, the vector for the first occurrence
+            (also the most frequent if vocabulary is sorted) is taken.
+        dummy4unknown : bool, optional
+            If True - produce zero accuracies for 4-tuples with out-of-vocabulary words.
+            Otherwise, these tuples are skipped entirely and not used in the evaluation.
+
+        Returns
+        -------
+        (float, list of dict of (str, (str, str, str, str)))
+            Overall evaluation score and full lists of correct and incorrect predictions, divided by sections.
+ + """ + ok_vocab = [(w, self.vocab[w]) for w in self.index2word[:restrict_vocab]] + ok_vocab = {w.upper(): v for w, v in reversed(ok_vocab)} if case_insensitive else dict(ok_vocab) + oov = 0 + logger.info("Evaluating word analogies for top %i words in the model on %s", restrict_vocab, analogies) + sections, section = [], None + quadruplets_no = 0 + for line_no, line in enumerate(utils.smart_open(analogies)): + line = utils.to_unicode(line) + if line.startswith(': '): + # a new section starts => store the old section + if section: + sections.append(section) + self._log_evaluate_word_analogies(section) + section = {'section': line.lstrip(': ').strip(), 'correct': [], 'incorrect': []} + else: + if not section: + raise ValueError("Missing section header before line #%i in %s" % (line_no, analogies)) + try: + if case_insensitive: + a, b, c, expected = [word.upper() for word in line.split()] + else: + a, b, c, expected = [word for word in line.split()] + except ValueError: + logger.info("Skipping invalid line #%i in %s", line_no, analogies) + continue + quadruplets_no += 1 + if a not in ok_vocab or b not in ok_vocab or c not in ok_vocab or expected not in ok_vocab: + oov += 1 + if dummy4unknown: + logger.debug('Zero accuracy for line #%d with OOV words: %s', line_no, line.strip()) + section['incorrect'].append((a, b, c, expected)) + else: + logger.debug("Skipping line #%i with OOV words: %s", line_no, line.strip()) + continue + original_vocab = self.vocab + self.vocab = ok_vocab + ignore = {a, b, c} # input words to be ignored + predicted = None + # find the most likely prediction using 3CosAdd (vector offset) method + # TODO: implement 3CosMul and set-based methods for solving analogies + sims = self.most_similar(positive=[b, c], negative=[a], topn=5, restrict_vocab=restrict_vocab) + self.vocab = original_vocab + for element in sims: + predicted = element[0].upper() if case_insensitive else element[0] + if predicted in ok_vocab and predicted not in ignore: + if predicted != expected: + logger.debug("%s: expected %s, predicted %s", line.strip(), expected, predicted) + break + if predicted == expected: + section['correct'].append((a, b, c, expected)) + else: + section['incorrect'].append((a, b, c, expected)) + if section: + # store the last section, too + sections.append(section) + self._log_evaluate_word_analogies(section) + + total = { + 'section': 'Total accuracy', + 'correct': sum((s['correct'] for s in sections), []), + 'incorrect': sum((s['incorrect'] for s in sections), []), + } + + oov_ratio = float(oov) / quadruplets_no * 100 + logger.info('Quadruplets with out-of-vocabulary words: %.1f%%', oov_ratio) + if not dummy4unknown: + logger.info( + 'NB: analogies containing OOV words were skipped from evaluation! ' + 'To change this behavior, use "dummy4unknown=True"' + ) + analogies_score = self._log_evaluate_word_analogies(total) + sections.append(total) + # Return the overall score and the full lists of correct and incorrect analogies + return analogies_score, sections + @staticmethod def log_accuracy(section): correct, incorrect = len(section['correct']), len(section['incorrect']) @@ -859,6 +993,7 @@ def log_accuracy(section): section['section'], 100.0 * correct / (correct + incorrect), correct, correct + incorrect ) + @deprecated("Method will be removed in 4.0.0, use self.evaluate_word_analogies() instead") def accuracy(self, questions, restrict_vocab=30000, most_similar=most_similar, case_insensitive=True): """ Compute accuracy of the model. 
@@ -881,7 +1016,6 @@ def accuracy(self, questions, restrict_vocab=30000, most_similar=most_similar, c
             occurrence (also the most frequent if vocabulary is sorted) is taken.
 
         This method corresponds to the `compute-accuracy` script of the original C word2vec.
-
         """
         ok_vocab = [(w, self.vocab[w]) for w in self.index2word[:restrict_vocab]]
         ok_vocab = {w.upper(): v for w, v in reversed(ok_vocab)} if case_insensitive else dict(ok_vocab)
@@ -898,19 +1032,18 @@ def accuracy(self, questions, restrict_vocab=30000, most_similar=most_similar, c
                 section = {'section': line.lstrip(': ').strip(), 'correct': [], 'incorrect': []}
             else:
                 if not section:
-                    raise ValueError("missing section header before line #%i in %s" % (line_no, questions))
+                    raise ValueError("Missing section header before line #%i in %s" % (line_no, questions))
                 try:
                     if case_insensitive:
                         a, b, c, expected = [word.upper() for word in line.split()]
                     else:
                         a, b, c, expected = [word for word in line.split()]
                 except ValueError:
-                    logger.info("skipping invalid line #%i in %s", line_no, questions)
+                    logger.info("Skipping invalid line #%i in %s", line_no, questions)
                     continue
                 if a not in ok_vocab or b not in ok_vocab or c not in ok_vocab or expected not in ok_vocab:
-                    logger.debug("skipping line #%i with OOV words: %s", line_no, line.strip())
+                    logger.debug("Skipping line #%i with OOV words: %s", line_no, line.strip())
                     continue
-
                 original_vocab = self.vocab
                 self.vocab = ok_vocab
                 ignore = {a, b, c}  # input words to be ignored
diff --git a/gensim/models/word2vec.py b/gensim/models/word2vec.py
index f51b4cd25f..7ca85a6340 100755
--- a/gensim/models/word2vec.py
+++ b/gensim/models/word2vec.py
@@ -74,7 +74,8 @@
 
 And on analogies::
 
-    >>> model.wv.accuracy(os.path.join(module_path, 'test_data', 'questions-words.txt'))
+    >>> model.wv.evaluate_word_analogies(os.path.join(module_path, 'test_data', 'questions-words.txt'))[0]
+    0.58
 
 and so on.
 
@@ -896,7 +897,7 @@ def reset_from(self, other_model):
     def log_accuracy(section):
         return Word2VecKeyedVectors.log_accuracy(section)
 
-    @deprecated("Method will be removed in 4.0.0, use self.wv.accuracy() instead")
+    @deprecated("Method will be removed in 4.0.0, use self.wv.evaluate_word_analogies() instead")
     def accuracy(self, questions, restrict_vocab=30000, most_similar=None, case_insensitive=True):
         most_similar = most_similar or Word2VecKeyedVectors.most_similar
         return self.wv.accuracy(questions, restrict_vocab, most_similar, case_insensitive)
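A minimal usage sketch of the new API, not part of the patch itself. The `vectors.kv` path is hypothetical (any vectors saved via `KeyedVectors.save()` will do); the analogy file is the standard `questions-words.txt` shipped in `gensim/test/test_data`:

```python
import logging
import os

from gensim.models import KeyedVectors

logging.basicConfig(level=logging.INFO)  # per-section scores are emitted via logger.info

# 'vectors.kv' is a hypothetical path to pre-trained vectors saved with KeyedVectors.save()
word_vectors = KeyedVectors.load('vectors.kv')

analogies_path = os.path.join('gensim', 'test', 'test_data', 'questions-words.txt')
score, sections = word_vectors.evaluate_word_analogies(analogies_path)

print('Overall analogy accuracy: %.1f%%' % (100.0 * score))
for section in sections:  # the final entry is the aggregate 'Total accuracy' section
    correct, incorrect = len(section['correct']), len(section['incorrect'])
    if correct + incorrect > 0:
        print('%s: %i/%i correct' % (section['section'], correct, correct + incorrect))
```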
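Migration from the deprecated `accuracy()` is mostly mechanical, but two differences in this diff are easy to miss: the old method returned only the per-section results, while `evaluate_word_analogies()` returns a `(score, sections)` tuple (hence the `[0]` in the updated doctests), and the default `restrict_vocab` grew from 30000 to 300000. A before/after sketch:

```python
# before (deprecated, scheduled for removal in 4.0.0) -- returns just the section dicts:
sections = word_vectors.accuracy('questions-words.txt')

# after -- returns the overall score plus the section dicts:
score, sections = word_vectors.evaluate_word_analogies('questions-words.txt')

# to keep the old evaluation scope, pass the old default explicitly,
# since the new method defaults to restrict_vocab=300000 rather than 30000:
score, sections = word_vectors.evaluate_word_analogies('questions-words.txt', restrict_vocab=30000)
```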