Skip to content

Commit

Permalink
fix RuntimeError in export_phrases (change defaultdict to dict) (#3041)
Browse files Browse the repository at this point in the history
* fix typo

* fix test cases for test_export_phrases

* add test cases for test_find_phrases

* Fix #3031 Runtime error in phrases.py

* remove unused variable reference

* fix newline to end of file

* fix formattingpy

* Update CHANGELOG.md

* Update CHANGELOG.md

Co-authored-by: Michael Penkov <m@penkov.dev>
  • Loading branch information
thalishsajeed and mpenkov authored Feb 13, 2021
1 parent 0502284 commit cfc9e95
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 16 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
Changes
=======

## Unreleased

- fix RuntimeError in export_phrases (change defaultdict to dict) (PR [#3041](https://github.com/RaRe-Technologies/gensim/pull/3041), [@thalishsajeed](https://github.com/thalishsajeed))

## 4.0.0beta, 2020-10-31

**⚠️ Gensim 4.0 contains breaking API changes! See the [Migration guide](https://github.com/RaRe-Technologies/gensim/wiki/Migrating-from-Gensim-3.x-to-4) to update your existing Gensim 3.x code and models.**
Expand Down
20 changes: 10 additions & 10 deletions gensim/models/phrases.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
"""

import logging
from collections import defaultdict
import itertools
from math import log
import pickle
Expand Down Expand Up @@ -412,7 +411,7 @@ def load(cls, *args, **kwargs):
if not isinstance(word, str):
logger.info("old version of %s loaded, upgrading %i words in memory", cls.__name__, len(model.vocab))
logger.info("re-save the loaded model to avoid this upgrade in the future")
vocab = defaultdict(int)
vocab = {}
for key, value in model.vocab.items(): # needs lots of extra RAM temporarily!
vocab[str(key, encoding='utf8')] = value
model.vocab = vocab
Expand Down Expand Up @@ -554,7 +553,7 @@ def __init__(
self.min_count = min_count
self.threshold = threshold
self.max_vocab_size = max_vocab_size
self.vocab = defaultdict(int) # mapping between token => its count
self.vocab = {} # mapping between token => its count
self.min_reduce = 1 # ignore any tokens with count smaller than this
self.delimiter = delimiter
self.progress_per = progress_per
Expand All @@ -579,7 +578,7 @@ def __str__(self):
def _learn_vocab(sentences, max_vocab_size, delimiter, connector_words, progress_per):
"""Collect unigram and bigram counts from the `sentences` iterable."""
sentence_no, total_words, min_reduce = -1, 0, 1
vocab = defaultdict(int)
vocab = {}
logger.info("collecting all words and their counts")
for sentence_no, sentence in enumerate(sentences):
if sentence_no % progress_per == 0:
Expand All @@ -590,10 +589,11 @@ def _learn_vocab(sentences, max_vocab_size, delimiter, connector_words, progress
start_token, in_between = None, []
for word in sentence:
if word not in connector_words:
vocab[word] += 1
vocab[word] = vocab.get(word, 0) + 1
if start_token is not None:
phrase_tokens = itertools.chain([start_token], in_between, [word])
vocab[delimiter.join(phrase_tokens)] += 1
joined_phrase_token = delimiter.join(phrase_tokens)
vocab[joined_phrase_token] = vocab.get(joined_phrase_token, 0) + 1
start_token, in_between = word, [] # treat word as both end of a phrase AND beginning of another
elif start_token is not None:
in_between.append(word)
Expand Down Expand Up @@ -654,7 +654,7 @@ def add_vocab(self, sentences):
logger.info("merging %i counts into %s", len(vocab), self)
self.min_reduce = max(self.min_reduce, min_reduce)
for word, count in vocab.items():
self.vocab[word] += count
self.vocab[word] = self.vocab.get(word, 0) + count
if len(self.vocab) > self.max_vocab_size:
utils.prune_vocab(self.vocab, self.min_reduce)
self.min_reduce += 1
Expand All @@ -666,17 +666,17 @@ def add_vocab(self, sentences):

def score_candidate(self, word_a, word_b, in_between):
# Micro optimization: check for quick early-out conditions, before the actual scoring.
word_a_cnt = self.vocab[word_a]
word_a_cnt = self.vocab.get(word_a, 0)
if word_a_cnt <= 0:
return None, None

word_b_cnt = self.vocab[word_b]
word_b_cnt = self.vocab.get(word_b, 0)
if word_b_cnt <= 0:
return None, None

phrase = self.delimiter.join([word_a] + in_between + [word_b])
# XXX: Why do we care about *all* phrase tokens? Why not just score the start+end bigram?
phrase_cnt = self.vocab[phrase]
phrase_cnt = self.vocab.get(phrase, 0)
if phrase_cnt <= 0:
return None, None

Expand Down
46 changes: 40 additions & 6 deletions gensim/test/test_phrases.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,34 @@ def dumb_scorer(worda_count, wordb_count, bigram_count, len_vocab, min_count, co
class TestPhrasesModel(PhrasesCommon, unittest.TestCase):

def test_export_phrases(self):
"""Test Phrases bigram export phrases."""
"""Test Phrases bigram and trigram export phrases."""
bigram = Phrases(self.sentences, min_count=1, threshold=1, delimiter=' ')
trigram = Phrases(bigram[self.sentences], min_count=1, threshold=1, delimiter=' ')
seen_bigrams = set(bigram.export_phrases().keys())
seen_trigrams = set(trigram.export_phrases().keys())

assert seen_bigrams == set([
'human interface',
'response time',
'graph minors',
'minors survey',
])

assert seen_trigrams == set([
'human interface',
'graph minors survey',
])

def test_find_phrases(self):
"""Test Phrases bigram find phrases."""
bigram = Phrases(self.sentences, min_count=1, threshold=1, delimiter=' ')
seen_bigrams = set(bigram.find_phrases(self.sentences).keys())

assert seen_bigrams == {
assert seen_bigrams == set([
'response time',
'graph minors',
'human interface',
}
])

def test_multiple_bigrams_single_entry(self):
"""Test a single entry produces multiple bigrams."""
Expand Down Expand Up @@ -441,7 +460,7 @@ def test_multiple_bigrams_single_entry(self):
'human interface',
])

def test_export_phrases(self):
def test_find_phrases(self):
"""Test Phrases bigram export phrases."""
bigram = Phrases(self.sentences, min_count=1, threshold=1, connector_words=self.connector_words, delimiter=' ')
seen_bigrams = set(bigram.find_phrases(self.sentences).keys())
Expand All @@ -453,6 +472,21 @@ def test_export_phrases(self):
'lack of interest',
])

def test_export_phrases(self):
"""Test Phrases bigram export phrases."""
bigram = Phrases(self.sentences, min_count=1, threshold=1, delimiter=' ')
seen_bigrams = set(bigram.export_phrases().keys())
assert seen_bigrams == set([
'and graph',
'data and',
'graph of',
'graph survey',
'human interface',
'lack of',
'of interest',
'of trees',
])

def test_scoring_default(self):
""" test the default scoring, from the mikolov word2vec paper """
bigram = Phrases(self.sentences, min_count=1, threshold=1, connector_words=self.connector_words)
Expand Down Expand Up @@ -510,9 +544,9 @@ def test__getitem__(self):
assert phrased_sentence == ['data_and_graph', 'survey', 'for', 'human_interface']


class TestFrozenPhrasesModelCompatibilty(unittest.TestCase):
class TestFrozenPhrasesModelCompatibility(unittest.TestCase):

def test_compatibilty(self):
def test_compatibility(self):
phrases = Phrases.load(datapath("phrases-3.6.0.model"))
phraser = FrozenPhrases.load(datapath("phraser-3.6.0.model"))
test_sentences = ['trees', 'graph', 'minors']
Expand Down

0 comments on commit cfc9e95

Please sign in to comment.