Skip to content

Commit

Permalink
Fix bug where saved Phrases model did not load its connector_words (#3116)
Browse files Browse the repository at this point in the history

* Fixed bug where connector_words was not restored when loading a saved Phrases model of version >= 4

Added tests for asserting persistence of phrases connector_words

* Update test_phrases.py

* Update phrases.py

* Update CHANGELOG.md

Co-authored-by: Michael Penkov <m@penkov.dev>
  • Loading branch information
aloknayak29 and mpenkov authored May 8, 2021
1 parent 8d70657 commit 351456b
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 22 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Changes
- Improve & unify docs for dirichlet priors (PR [#3125](https://github.com/RaRe-Technologies/gensim/pull/3125), [@jonaschn](https://github.com/jonaschn))
- Materialize and copy the corpus passed to SoftCosineSimilarity (PR [#3128](https://github.com/RaRe-Technologies/gensim/pull/3128), [@Witiko](https://github.com/Witiko))
- Make LSI dispatcher CLI param for number of jobs optional (PR [#3115](https://github.com/RaRe-Technologies/gensim/pull/3115), [@robguinness](https://github.com/robguinness))
- Fix bug where a saved Phrases model lost its connector_words on load (PR [#3116](https://github.com/RaRe-Technologies/gensim/pull/3116), [@aloknayak29](https://github.com/aloknayak29))

### Documentation

Expand Down
16 changes: 7 additions & 9 deletions gensim/models/phrases.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,15 +391,13 @@ def load(cls, *args, **kwargs):
raise ValueError(f'failed to load {cls.__name__} model, unknown scoring "{model.scoring}"')

# common_terms didn't exist pre-3.?, and was renamed to connector_words in 4.0.0.
if hasattr(model, "common_terms"):
model.connector_words = model.common_terms
del model.common_terms
else:
logger.warning(
'older version of %s loaded without common_terms attribute, setting connector_words to an empty set',
cls.__name__,
)
model.connector_words = frozenset()
if not hasattr(model, "connector_words"):
if hasattr(model, "common_terms"):
model.connector_words = model.common_terms
del model.common_terms
else:
logger.warning('loaded older version of %s, setting connector_words to an empty set', cls.__name__)
model.connector_words = frozenset()

if not hasattr(model, 'corpus_word_count'):
logger.warning('older version of %s loaded without corpus_word_count', cls.__name__)
Expand Down
45 changes: 32 additions & 13 deletions gensim/test/test_phrases.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,32 +305,42 @@ def test_pruning(self):


class TestPhrasesPersistence(PhrasesData, unittest.TestCase):

def test_save_load_custom_scorer(self):
"""Test saving and loading a Phrases object with a custom scorer."""
bigram = Phrases(self.sentences, min_count=1, threshold=.001, scoring=dumb_scorer)
with temporary_file("test.pkl") as fpath:
bigram = Phrases(self.sentences, min_count=1, threshold=.001, scoring=dumb_scorer)
bigram.save(fpath)
bigram_loaded = Phrases.load(fpath)
test_sentences = [['graph', 'minors', 'survey', 'human', 'interface', 'system']]
seen_scores = list(bigram_loaded.find_phrases(test_sentences).values())

assert all(score == 1 for score in seen_scores)
assert len(seen_scores) == 3 # 'graph minors' and 'survey human' and 'interface system'
test_sentences = [['graph', 'minors', 'survey', 'human', 'interface', 'system']]
seen_scores = list(bigram_loaded.find_phrases(test_sentences).values())

assert all(score == 1 for score in seen_scores)
assert len(seen_scores) == 3 # 'graph minors' and 'survey human' and 'interface system'

def test_save_load(self):
"""Test saving and loading a Phrases object."""
bigram = Phrases(self.sentences, min_count=1, threshold=1)
with temporary_file("test.pkl") as fpath:
bigram.save(fpath)
bigram_loaded = Phrases.load(fpath)

test_sentences = [['graph', 'minors', 'survey', 'human', 'interface', 'system']]
seen_scores = set(round(score, 3) for score in bigram_loaded.find_phrases(test_sentences).values())
assert seen_scores == set([
5.167, # score for graph minors
3.444 # score for human interface
])

def test_save_load_with_connector_words(self):
"""Test saving and loading a Phrases object."""
connector_words = frozenset({'of'})
bigram = Phrases(self.sentences, min_count=1, threshold=1, connector_words=connector_words)
with temporary_file("test.pkl") as fpath:
bigram = Phrases(self.sentences, min_count=1, threshold=1)
bigram.save(fpath)
bigram_loaded = Phrases.load(fpath)
test_sentences = [['graph', 'minors', 'survey', 'human', 'interface', 'system']]
seen_scores = set(round(score, 3) for score in bigram_loaded.find_phrases(test_sentences).values())

assert seen_scores == set([
5.167, # score for graph minors
3.444 # score for human interface
])
assert bigram_loaded.connector_words == connector_words

def test_save_load_string_scoring(self):
"""Test backwards compatibility with a previous version of Phrases with custom scoring."""
Expand Down Expand Up @@ -385,6 +395,15 @@ def test_save_load(self):
bigram_loaded[['graph', 'minors', 'survey', 'human', 'interface', 'system']],
['graph_minors', 'survey', 'human_interface', 'system'])

def test_save_load_with_connector_words(self):
"""Test saving and loading a FrozenPhrases object."""
connector_words = frozenset({'of'})
with temporary_file("test.pkl") as fpath:
bigram = FrozenPhrases(Phrases(self.sentences, min_count=1, threshold=1, connector_words=connector_words))
bigram.save(fpath)
bigram_loaded = FrozenPhrases.load(fpath)
self.assertEqual(bigram_loaded.connector_words, connector_words)

def test_save_load_string_scoring(self):
"""Test saving and loading a FrozenPhrases object with a string scoring parameter.
This should ensure backwards compatibility with the previous version of FrozenPhrases"""
Expand Down

0 comments on commit 351456b

Please sign in to comment.