Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed KeyError in coherence model #2830

Merged
merged 13 commits into from
Jun 29, 2021
16 changes: 11 additions & 5 deletions gensim/models/coherencemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class CoherenceModel(interfaces.TransformationABC):
>>> coherence = cm.get_coherence() # get coherence value

"""

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you add this blank line?

def __init__(self, model=None, topics=None, texts=None, corpus=None, dictionary=None,
window_size=None, keyed_vectors=None, coherence='c_v', topn=20, processes=-1):
"""
Expand Down Expand Up @@ -441,11 +442,16 @@ def topics(self, topics):
self._topics = new_topics

def _ensure_elements_are_ids(self, topic):
try:
return np.array([self.dictionary.token2id[token] for token in topic])
except KeyError: # might be a list of token ids already, but let's verify all in dict
topic = (self.dictionary.id2token[_id] for _id in topic)
return np.array([self.dictionary.token2id[token] for token in topic])
elements_are_tokens = np.array([self.dictionary.token2id[token]
for token in topic if token in self.dictionary.token2id])
topic_tokens_from_id = (self.dictionary.id2token[_id] for _id in topic if _id in self.dictionary.id2token)
elements_are_ids = np.array([self.dictionary.token2id[token] for token in topic_tokens_from_id])
if elements_are_tokens.size > elements_are_ids.size:
return elements_are_tokens
elif elements_are_ids.size > elements_are_tokens.size:
return elements_are_ids
else:
raise Exception("Topic list is not a list of lists of tokens or ids")
pietrotrope marked this conversation as resolved.
Show resolved Hide resolved

def _update_accumulator(self, new_topics):
if self._relevant_ids_will_differ(new_topics):
Expand Down