-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
19 changed files
with
1,429 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (C) 2010 Radim Rehurek <radimrehurek@seznam.cz> | ||
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html | ||
|
||
""" | ||
Module for calculating topic coherence in python. This is the implementation of | ||
the four stage topic coherence pipeline from the paper [1]. | ||
The four stage pipeline is basically: | ||
Segmentation -> Probability Estimation -> Confirmation Measure -> Aggregation. | ||
Implementing this pipeline allows the user to, in essence, "make" a | ||
coherence measure of their choice by choosing a method for each stage of the pipeline. | ||
[1] Michael Roeder, Andreas Both and Alexander Hinneburg. Exploring the space of topic | ||
coherence measures. http://svn.aksw.org/papers/2015/WSDM_Topic_Evaluation/public.pdf. | ||
""" | ||
|
||
import logging | ||
|
||
from gensim import interfaces | ||
from gensim.topic_coherence import (segmentation, probability_estimation, | ||
direct_confirmation_measure, indirect_confirmation_measure, | ||
aggregation) | ||
from gensim.corpora import Dictionary | ||
from gensim.matutils import argsort | ||
from gensim.utils import is_corpus | ||
from gensim.models.ldamodel import LdaModel | ||
from gensim.models.wrappers import LdaVowpalWabbit | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class CoherenceModel(interfaces.TransformationABC):
    """
    Objects of this class allow for building and maintaining a model for topic
    coherence.

    The main methods are:

    1. constructor, which initializes the four stage pipeline by accepting a coherence measure,
    2. the ``get_coherence()`` method, which returns the topic coherence.

    >>> cm = CoherenceModel(model=tm, corpus=corpus, coherence='u_mass')  # tm is the trained topic model
    >>> cm.get_coherence()

    Model persistency is achieved via its load/save methods.
    """
    def __init__(self, model, texts=None, corpus=None, dictionary=None, coherence='c_v'):
        """
        Args:
        ----
        model : Pre-trained topic model.
        texts : Tokenized texts. Needed for coherence models that use sliding window based probability estimator.
        corpus : Gensim document corpus.
        dictionary : Gensim dictionary mapping of id to word, used to create the corpus. When it is omitted,
                     'u_mass' with a corpus falls back to ``model.id2word``; every texts-based setup builds a
                     new ``Dictionary`` from `texts`.
        coherence : Coherence measure to be used. Supported values are:
                    u_mass
                    c_v

        Raises:
        ------
        ValueError : If neither `texts` nor `corpus` is given, if the inputs required by the chosen
                     measure are missing, or if `coherence` is not a supported value.
        """
        if texts is None and corpus is None:
            raise ValueError("One of texts or corpus has to be provided.")

        if coherence == 'u_mass':
            if is_corpus(corpus)[0]:
                if dictionary is not None:
                    self.dictionary = dictionary
                # NOTE(review): presumably this is meant to detect a placeholder id2word that was
                # generated without a real dictionary -- confirm that such a mapping really returns
                # the integer 0 for id 0, otherwise this guard never fires.
                elif model.id2word[0] == 0:
                    raise ValueError(
                        "The associated dictionary should be provided with the corpus or 'id2word' for topic model "
                        "should be set as the dictionary.")
                else:
                    self.dictionary = model.id2word
                self.corpus = corpus
            elif texts is not None:
                self.texts = texts
                self.dictionary = Dictionary(self.texts) if dictionary is None else dictionary
                self.corpus = [self.dictionary.doc2bow(text) for text in self.texts]
            else:
                raise ValueError(
                    "Either 'corpus' with 'dictionary' or 'texts' should be provided for %s coherence." % coherence)

        elif coherence == 'c_v':
            if texts is None:
                raise ValueError("'texts' should be provided for %s coherence." % coherence)
            self.texts = texts
            # Honour a caller-supplied dictionary (same rule as the 'u_mass' texts branch) so that
            # the word ids used for probability estimation are the ones the caller intended.
            self.dictionary = Dictionary(self.texts) if dictionary is None else dictionary
            self.corpus = [self.dictionary.doc2bow(text) for text in self.texts]

        else:
            raise ValueError("%s coherence is not currently supported." % coherence)

        self.model = model
        self.topics = self._get_topics()
        self.coherence = coherence
        # Set pipeline parameters:
        if self.coherence == 'u_mass':
            self.seg = segmentation.s_one_pre
            self.prob = probability_estimation.p_boolean_document
            self.conf = direct_confirmation_measure.log_conditional_probability
            self.aggr = aggregation.arithmetic_mean

        elif self.coherence == 'c_v':
            self.seg = segmentation.s_one_set
            self.prob = probability_estimation.p_boolean_sliding_window
            self.conf = indirect_confirmation_measure.cosine_similarity
            self.aggr = aggregation.arithmetic_mean

    def __str__(self):
        """Return a description listing the four pipeline stage callables."""
        return "CoherenceModel(segmentation=%s, probability estimation=%s, confirmation measure=%s, aggregation=%s)" % (
            self.seg, self.prob, self.conf, self.aggr)

    def _get_topics(self):
        """Internal helper function to return topics from a trained topic model.

        Returns a list with one entry per topic, each being the result of
        ``argsort(topic, topn=10, reverse=True)`` on that topic's row. The list is
        empty when the model type is not supported.
        """
        # FIXME : Meant to work for LDAModel, LdaVowpalWabbit right now. Make it work for others.
        if isinstance(self.model, LdaModel):
            topic_rows = self.model.state.get_lambda()
        elif isinstance(self.model, LdaVowpalWabbit):
            topic_rows = self.model._get_topics()
        else:
            # Keep the historical "empty list" result, but do not stay silent about it.
            logger.warning("cannot extract topics from unsupported model type %s; using an empty topic list",
                           type(self.model).__name__)
            topic_rows = []
        return [argsort(topic, topn=10, reverse=True) for topic in topic_rows]

    def get_coherence(self):
        """Run the configured four stage pipeline and return the aggregated coherence value."""
        if self.coherence == 'u_mass':
            segmented_topics = self.seg(self.topics)
            per_topic_postings, num_docs = self.prob(self.corpus, segmented_topics)
            confirmed_measures = self.conf(segmented_topics, per_topic_postings, num_docs)
            return self.aggr(confirmed_measures)

        elif self.coherence == 'c_v':
            segmented_topics = self.seg(self.topics)
            per_topic_postings, num_windows = self.prob(texts=self.texts, segmented_topics=segmented_topics,
                                                        dictionary=self.dictionary,
                                                        window_size=2)  # FIXME : Change window size to 110 finally.
            confirmed_measures = self.conf(self.topics, segmented_topics, per_topic_postings, 'nlr', 1, num_windows)
            return self.aggr(confirmed_measures)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz> | ||
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html | ||
|
||
""" | ||
Automated tests for direct confirmation measures in the direct_confirmation_measure module. | ||
""" | ||
|
||
import logging | ||
import unittest | ||
|
||
from gensim.topic_coherence import direct_confirmation_measure | ||
|
||
class TestDirectConfirmationMeasure(unittest.TestCase):
    """Check each direct confirmation measure against a hand-computed toy example."""

    def setUp(self):
        # Toy fixture, small enough to verify by hand; the formulas live in the
        # module under test. One topic, segmented into the single pair (1, 2).
        self.segmentation = [[(1, 2)]]
        # Per-id posting sets used by the measures: id 1 -> {2, 3, 4}, id 2 -> {3, 5}.
        self.posting_list = {1: set([2, 3, 4]), 2: set([3, 5])}
        self.num_docs = 5

    def _first_value(self, measure):
        """Apply `measure` to the toy fixture and return the first element of its result."""
        results = measure(self.segmentation, self.posting_list, self.num_docs)
        return results[0]

    def testLogConditionalProbability(self):
        """Test log_conditional_probability()"""
        value = self._first_value(direct_confirmation_measure.log_conditional_probability)
        # Hand computation: ln(1 / 2) = -0.693147181
        self.assertAlmostEqual(value, -0.693147181)

    def testLogRatioMeasure(self):
        """Test log_ratio_measure()"""
        value = self._first_value(direct_confirmation_measure.log_ratio_measure)
        # Hand computation: ln{(1 / 5) / [(3 / 5) * (2 / 5)]} = -0.182321557
        self.assertAlmostEqual(value, -0.182321557)

    def testNormalizedLogRatioMeasure(self):
        """Test normalized_log_ratio_measure()"""
        value = self._first_value(direct_confirmation_measure.normalized_log_ratio_measure)
        # Hand computation: -0.182321557 / ln(1 / 5) = 0.113282753
        self.assertAlmostEqual(value, 0.113282753)
|
||
if __name__ == '__main__':
    # Run as a script: limit log output to warnings and above, then hand over to unittest.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.WARNING)
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz> | ||
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html | ||
|
||
""" | ||
Automated tests for indirect confirmation measures in the indirect_confirmation_measure module. | ||
""" | ||
|
||
import logging | ||
import unittest | ||
|
||
from gensim.topic_coherence import indirect_confirmation_measure | ||
|
||
import numpy as np | ||
from numpy import array | ||
|
||
class TestIndirectConfirmation(unittest.TestCase):
    """Check the indirect confirmation measure against a hand-computed toy example."""

    def setUp(self):
        # Toy fixture, small enough to verify by hand; the formulas live in the
        # module under test.
        self.topics = [np.array([1, 2])]
        # Result from s_one_set segmentation:
        self.segmentation = [[(1, array([1, 2])), (2, array([1, 2]))]]
        self.posting_list = {1: set([2, 3, 4]), 2: set([3, 5])}
        self.gamma = 1
        self.measure = 'nlr'
        self.num_docs = 5

    def testCosineSimilarity(self):
        """Test cosine_similarity()"""
        obtained = indirect_confirmation_measure.cosine_similarity(
            self.topics,
            self.segmentation,
            self.posting_list,
            self.measure,
            self.gamma,
            self.num_docs,
        )
        # How the expected numbers are derived:
        # 1. Take (1, array([1, 2]). Take w' which is 1.
        # 2. Calculate nlr(1, 1), nlr(1, 2). This is our first vector.
        # 3. Take w* which is array([1, 2]).
        # 4. Calculate nlr(1, 1) + nlr(2, 1). Calculate nlr(1, 2), nlr(2, 2). This is our second vector.
        # 5. Find out cosine similarity between these two vectors.
        # 6. Similarly for the second segmentation.
        expected = [0.6230, 0.6230]  # To account for EPSILON approximation
        for position in (0, 1):
            self.assertAlmostEqual(obtained[position], expected[position], 4)
|
||
if __name__ == '__main__':
    # Run as a script: limit log output to warnings and above, then hand over to unittest.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.WARNING)
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.