Add common_terms parameter to sklearn_api.PhrasesTransformer (#2074)
* add common_terms parameter

The parameter is propagated to the underlying models.Phrases class.

* add tests for new common_terms parameter

* utilize models.phrases.Phraser class

This avoids the following warning:
"UserWarning: For a faster implementation, use the
gensim.models.phrases.Phraser class"

* add testCompareToOld, add pre-trained Phrases model

* use pickle to load old PhrasesTransformer

* allow setting Phrases model without setting Phraser model

A pre-trained Phrases model (self.gensim_model) may be set directly to avoid
calling the fit() method. In transform(), the Phraser model (self.phraser),
which is also required, is instantiated on first use if it does not exist yet
(see the usage sketch before the file diffs below).

* open pickle file

* add __setstate__ for backward compatibility

* use pickle protocol 2

* test loading new phrases transformer
pmlk authored and menshikh-iv committed Oct 4, 2018
1 parent 4543646 commit 367bdbd
Showing 4 changed files with 104 additions and 4 deletions.
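Not part of the commit itself: a minimal usage sketch of the "pre-trained Phrases model" path described in the commit message, with made-up toy sentences and parameters. The Phraser wrapper is built lazily inside transform(), which is also what avoids the UserWarning quoted above.

from gensim.models import Phrases
from gensim.sklearn_api.phrases import PhrasesTransformer

sentences = [
    [u'the', u'bank', u'of', u'america', u'offices', u'are', u'open'],
    [u'the', u'bank', u'of', u'america', u'offices', u'are', u'closed'],
]

# Train a plain models.Phrases instance elsewhere (or load one from disk) ...
pretrained = Phrases(sentences, min_count=1, threshold=1, common_terms=frozenset(['of', 'the', 'are']))

# ... and attach it to the transformer instead of calling fit().
transformer = PhrasesTransformer(min_count=1, threshold=1)
transformer.gensim_model = pretrained

# transform() sees that self.phraser is still None and instantiates the faster
# gensim.models.phrases.Phraser on first use: no refitting, no UserWarning.
print(transformer.transform(sentences))  # should detect e.g. 'bank_of_america'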
25 changes: 21 additions & 4 deletions gensim/sklearn_api/phrases.py
@@ -32,6 +32,7 @@
from sklearn.exceptions import NotFittedError

from gensim import models
from gensim.models.phrases import Phraser


class PhrasesTransformer(TransformerMixin, BaseEstimator):
@@ -44,7 +45,7 @@ class PhrasesTransformer(TransformerMixin, BaseEstimator):
"""
def __init__(self, min_count=5, threshold=10.0, max_vocab_size=40000000,
- delimiter=b'_', progress_per=10000, scoring='default'):
+ delimiter=b'_', progress_per=10000, scoring='default', common_terms=frozenset()):
"""
Parameters
@@ -87,15 +88,25 @@ def __init__(self, min_count=5, threshold=10.0, max_vocab_size=40000000,
A scoring function without any of these parameters (even if the parameters are not used) will
raise a ValueError on initialization of the Phrases class. The scoring function must be pickleable.
common_terms : set of str, optional
List of "stop words" that won't affect frequency count of expressions containing them.
Allow to detect expressions like "bank_of_america" or "eye_of_the_beholder".
"""
self.gensim_model = None
self.phraser = None
self.min_count = min_count
self.threshold = threshold
self.max_vocab_size = max_vocab_size
self.delimiter = delimiter
self.progress_per = progress_per
self.scoring = scoring
self.common_terms = common_terms

def __setstate__(self, state):
self.__dict__ = state
self.common_terms = frozenset()
self.phraser = None

def fit(self, X, y=None):
"""Fit the model according to the given training data.
@@ -114,8 +125,9 @@ def fit(self, X, y=None):
self.gensim_model = models.Phrases(
sentences=X, min_count=self.min_count, threshold=self.threshold,
max_vocab_size=self.max_vocab_size, delimiter=self.delimiter,
- progress_per=self.progress_per, scoring=self.scoring
+ progress_per=self.progress_per, scoring=self.scoring, common_terms=self.common_terms
)
self.phraser = Phraser(self.gensim_model)
return self

def transform(self, docs):
@@ -139,10 +151,14 @@ def transform(self, docs):
"This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method."
)

if self.phraser is None:
self.phraser = Phraser(self.gensim_model)

# input as python lists
if isinstance(docs[0], string_types):
docs = [docs]
- return [self.gensim_model[doc] for doc in docs]
+ return [self.phraser[doc] for doc in docs]

def partial_fit(self, X):
"""Train model over a potentially incomplete set of sentences.
@@ -166,8 +182,9 @@ def partial_fit(self, X):
self.gensim_model = models.Phrases(
sentences=X, min_count=self.min_count, threshold=self.threshold,
max_vocab_size=self.max_vocab_size, delimiter=self.delimiter,
progress_per=self.progress_per, scoring=self.scoring
progress_per=self.progress_per, scoring=self.scoring, common_terms=self.common_terms
)

self.gensim_model.add_vocab(X)
self.phraser = Phraser(self.gensim_model)
return self
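
A hedged sketch of the new parameter from the caller's side, reusing the corpus and the expected output of the tests added below; everything else (variable names, the print call) is illustrative.

from gensim.sklearn_api.phrases import PhrasesTransformer

docs = [
    [u'the', u'mayor', u'of', u'new', u'york', u'was', u'there'],
    [u'the', u'mayor', u'of', u'new', u'orleans', u'was', u'there'],
    [u'the', u'bank', u'of', u'america', u'offices', u'are', u'open'],
    [u'the', u'bank', u'of', u'america', u'offices', u'are', u'closed'],
]

# common_terms is forwarded to models.Phrases, so stop words such as "of" no
# longer prevent multiword expressions that contain them from being detected.
model = PhrasesTransformer(min_count=1, threshold=1, common_terms=["of", "the", "was", "are"])
transformed = model.fit_transform(docs)
# Per the new tests, e.g. the third document becomes:
#   [u'the', u'bank_of_america', u'offices', u'are', u'open']
print(transformed)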
Binary file not shown.
Binary file not shown.
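
The two binary entries above are presumably the pickled fixtures loaded by the new tests below. A sketch of the backward-compatibility path they exercise, assuming the fixture file is available locally under the name used by the tests (the path is illustrative):

import pickle

# Unpickling a transformer saved before this change goes through the new
# __setstate__, which restores the saved attributes and then fills in the ones
# an old pickle lacks: common_terms=frozenset() and phraser=None.
with open("phrases-transformer-v3-5-0.pkl", "rb") as f:
    old_transformer = pickle.load(f)

assert old_transformer.common_terms == frozenset()
assert old_transformer.phraser is None

# transform() then builds the Phraser lazily from the stored Phrases model; per
# testCompareToOld below this yields [u'graph_minors', u'survey', u'human_interface'].
print(old_transformer.transform([u'graph', u'minors', u'survey', u'human', u'interface'])[0])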
83 changes: 83 additions & 0 deletions gensim/test/test_sklearn_api.py
@@ -287,6 +287,14 @@
['graph', 'minors', 'survey', 'human', 'interface']
]

common_terms = ["of", "the", "was", "are"]
phrases_w_common_terms = [
[u'the', u'mayor', u'of', u'new', u'york', u'was', u'there'],
[u'the', u'mayor', u'of', u'new', u'orleans', u'was', u'there'],
[u'the', u'bank', u'of', u'america', u'offices', u'are', u'open'],
[u'the', u'bank', u'of', u'america', u'offices', u'are', u'closed']
]


class TestLdaWrapper(unittest.TestCase):
def setUp(self):
@@ -1152,6 +1160,81 @@ def testModelNotFitted(self):
self.assertRaises(NotFittedError, phrases_transformer.transform, phrases_sentences[0])


class TestPhrasesTransformerCommonTerms(unittest.TestCase):
def setUp(self):
self.model = PhrasesTransformer(min_count=1, threshold=1, common_terms=common_terms)
self.expected_transformations = [
[u'the', u'mayor_of_new', u'york', u'was', u'there'],
[u'the', u'mayor_of_new', u'orleans', u'was', u'there'],
[u'the', u'bank_of_america', u'offices', u'are', u'open'],
[u'the', u'bank_of_america', u'offices', u'are', u'closed']
]

def testCompareToOld(self):
with open(datapath("phrases-transformer-v3-5-0.pkl"), "rb") as old_phrases_transformer_pkl:
old_phrases_transformer = pickle.load(old_phrases_transformer_pkl)
doc = phrases_sentences[-1]
phrase_tokens = old_phrases_transformer.transform(doc)[0]
expected_phrase_tokens = [u'graph_minors', u'survey', u'human_interface']
self.assertEqual(phrase_tokens, expected_phrase_tokens)

self.model.fit(phrases_sentences)
new_phrase_tokens = self.model.transform(doc)[0]
self.assertEqual(new_phrase_tokens, phrase_tokens)

def testLoadNew(self):
with open(datapath("phrases-transformer-new-v3-5-0.pkl"), "rb") as new_phrases_transformer_pkl:
old_phrases_transformer = pickle.load(new_phrases_transformer_pkl)
doc = phrases_sentences[-1]
phrase_tokens = old_phrases_transformer.transform(doc)[0]
expected_phrase_tokens = [u'graph_minors', u'survey', u'human_interface']
self.assertEqual(phrase_tokens, expected_phrase_tokens)

self.model.fit(phrases_sentences)
new_phrase_tokens = self.model.transform(doc)[0]
self.assertEqual(new_phrase_tokens, phrase_tokens)

def testFitAndTransform(self):
self.model.fit(phrases_w_common_terms)

transformed = self.model.transform(phrases_w_common_terms)
self.assertEqual(transformed, self.expected_transformations)

def testFitTransform(self):
transformed = self.model.fit_transform(phrases_w_common_terms)
self.assertEqual(transformed, self.expected_transformations)

def testPartialFit(self):
# fit half of the sentences
self.model.fit(phrases_w_common_terms[:2])

expected_transformations_0 = [
[u'the', u'mayor_of_new', u'york', u'was', u'there'],
[u'the', u'mayor_of_new', u'orleans', u'was', u'there'],
[u'the', u'bank', u'of', u'america', u'offices', u'are', u'open'],
[u'the', u'bank', u'of', u'america', u'offices', u'are', u'closed']
]
# transform all sentences, second half should be same as original
transformed_0 = self.model.transform(phrases_w_common_terms)
self.assertEqual(transformed_0, expected_transformations_0)

# fit remaining sentences, result should be the same as in the other tests
self.model.partial_fit(phrases_w_common_terms[2:])
transformed_1 = self.model.fit_transform(phrases_w_common_terms)
self.assertEqual(transformed_1, self.expected_transformations)

new_phrases = [[u'offices', u'are', u'open'], [u'offices', u'are', u'closed']]
self.model.partial_fit(new_phrases)
expected_transformations_2 = [
[u'the', u'mayor_of_new', u'york', u'was', u'there'],
[u'the', u'mayor_of_new', u'orleans', u'was', u'there'],
[u'the', u'bank_of_america', u'offices_are_open'],
[u'the', u'bank_of_america', u'offices_are_closed']
]
transformed_2 = self.model.transform(phrases_w_common_terms)
self.assertEqual(transformed_2, expected_transformations_2)


# specifically test pluggable scoring in Phrases, because of possible pickling issues with the function parameter

# this is intentionally in main rather than a class method to support pickling
