diff --git a/gensim/sklearn_api/phrases.py b/gensim/sklearn_api/phrases.py
index 7579a09cc9..5c1cfa83b0 100644
--- a/gensim/sklearn_api/phrases.py
+++ b/gensim/sklearn_api/phrases.py
@@ -30,6 +30,7 @@
 from sklearn.exceptions import NotFittedError
 
 from gensim import models
+from gensim.models.phrases import Phraser
 
 
 class PhrasesTransformer(TransformerMixin, BaseEstimator):
@@ -41,7 +42,7 @@ class PhrasesTransformer(TransformerMixin, BaseEstimator):
 
     """
     def __init__(self, min_count=5, threshold=10.0, max_vocab_size=40000000,
-                 delimiter=b'_', progress_per=10000, scoring='default'):
+                 delimiter=b'_', progress_per=10000, scoring='default', common_terms=frozenset()):
         """
 
         Parameters
@@ -84,15 +85,25 @@ def __init__(self, min_count=5, threshold=10.0, max_vocab_size=40000000,
             A scoring function without any of these parameters (even if the parameters are not used) will raise a
             ValueError on initialization of the Phrases class. The scoring function must be pickleable.
+        common_terms : set of str, optional
+            List of "stop words" that won't affect frequency count of expressions containing them.
+            Allow to detect expressions like "bank_of_america" or "eye_of_the_beholder".
 
         """
         self.gensim_model = None
+        self.phraser = None
         self.min_count = min_count
         self.threshold = threshold
         self.max_vocab_size = max_vocab_size
         self.delimiter = delimiter
         self.progress_per = progress_per
         self.scoring = scoring
+        self.common_terms = common_terms
+
+    def __setstate__(self, state):
+        self.__dict__ = state
+        self.common_terms = frozenset()
+        self.phraser = None
 
     def fit(self, X, y=None):
         """Fit the model according to the given training data.
@@ -111,8 +122,9 @@ def fit(self, X, y=None):
         self.gensim_model = models.Phrases(
             sentences=X, min_count=self.min_count, threshold=self.threshold,
             max_vocab_size=self.max_vocab_size, delimiter=self.delimiter,
-            progress_per=self.progress_per, scoring=self.scoring
+            progress_per=self.progress_per, scoring=self.scoring, common_terms=self.common_terms
         )
+        self.phraser = Phraser(self.gensim_model)
         return self
 
     def transform(self, docs):
@@ -136,10 +148,14 @@ def transform(self, docs):
                 "This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method."
             )
 
+        if self.phraser is None:
+            self.phraser = Phraser(self.gensim_model)
+
         # input as python lists
         if isinstance(docs[0], string_types):
             docs = [docs]
-        return [self.gensim_model[doc] for doc in docs]
+
+        return [self.phraser[doc] for doc in docs]
 
     def partial_fit(self, X):
         """Train model over a potentially incomplete set of sentences.
@@ -163,8 +179,9 @@ def partial_fit(self, X):
             self.gensim_model = models.Phrases(
                 sentences=X, min_count=self.min_count, threshold=self.threshold,
                 max_vocab_size=self.max_vocab_size, delimiter=self.delimiter,
-                progress_per=self.progress_per, scoring=self.scoring
+                progress_per=self.progress_per, scoring=self.scoring, common_terms=self.common_terms
             )
 
         self.gensim_model.add_vocab(X)
+        self.phraser = Phraser(self.gensim_model)
         return self
diff --git a/gensim/test/test_data/phrases-transformer-new-v3-5-0.pkl b/gensim/test/test_data/phrases-transformer-new-v3-5-0.pkl
new file mode 100644
index 0000000000..7799418058
Binary files /dev/null and b/gensim/test/test_data/phrases-transformer-new-v3-5-0.pkl differ
diff --git a/gensim/test/test_data/phrases-transformer-v3-5-0.pkl b/gensim/test/test_data/phrases-transformer-v3-5-0.pkl
new file mode 100644
index 0000000000..8ffef6763b
Binary files /dev/null and b/gensim/test/test_data/phrases-transformer-v3-5-0.pkl differ
diff --git a/gensim/test/test_sklearn_api.py b/gensim/test/test_sklearn_api.py
index ed5516df37..014e3526c5 100644
--- a/gensim/test/test_sklearn_api.py
+++ b/gensim/test/test_sklearn_api.py
@@ -286,6 +286,14 @@
     ['graph', 'minors', 'survey', 'human', 'interface']
 ]
 
+common_terms = ["of", "the", "was", "are"]
+phrases_w_common_terms = [
+    [u'the', u'mayor', u'of', u'new', u'york', u'was', u'there'],
+    [u'the', u'mayor', u'of', u'new', u'orleans', u'was', u'there'],
+    [u'the', u'bank', u'of', u'america', u'offices', u'are', u'open'],
+    [u'the', u'bank', u'of', u'america', u'offices', u'are', u'closed']
+]
+
 
 class TestLdaWrapper(unittest.TestCase):
     def setUp(self):
@@ -1151,6 +1159,81 @@ def testModelNotFitted(self):
         self.assertRaises(NotFittedError, phrases_transformer.transform, phrases_sentences[0])
 
 
+class TestPhrasesTransformerCommonTerms(unittest.TestCase):
+    def setUp(self):
+        self.model = PhrasesTransformer(min_count=1, threshold=1, common_terms=common_terms)
+        self.expected_transformations = [
+            [u'the', u'mayor_of_new', u'york', u'was', u'there'],
+            [u'the', u'mayor_of_new', u'orleans', u'was', u'there'],
+            [u'the', u'bank_of_america', u'offices', u'are', u'open'],
+            [u'the', u'bank_of_america', u'offices', u'are', u'closed']
+        ]
+
+    def testCompareToOld(self):
+        with open(datapath("phrases-transformer-v3-5-0.pkl"), "rb") as old_phrases_transformer_pkl:
+            old_phrases_transformer = pickle.load(old_phrases_transformer_pkl)
+        doc = phrases_sentences[-1]
+        phrase_tokens = old_phrases_transformer.transform(doc)[0]
+        expected_phrase_tokens = [u'graph_minors', u'survey', u'human_interface']
+        self.assertEqual(phrase_tokens, expected_phrase_tokens)
+
+        self.model.fit(phrases_sentences)
+        new_phrase_tokens = self.model.transform(doc)[0]
+        self.assertEqual(new_phrase_tokens, phrase_tokens)
+
+    def testLoadNew(self):
+        with open(datapath("phrases-transformer-new-v3-5-0.pkl"), "rb") as new_phrases_transformer_pkl:
+            new_phrases_transformer = pickle.load(new_phrases_transformer_pkl)
+        doc = phrases_sentences[-1]
+        phrase_tokens = new_phrases_transformer.transform(doc)[0]
+        expected_phrase_tokens = [u'graph_minors', u'survey', u'human_interface']
+        self.assertEqual(phrase_tokens, expected_phrase_tokens)
+
+        self.model.fit(phrases_sentences)
+        new_phrase_tokens = self.model.transform(doc)[0]
+        self.assertEqual(new_phrase_tokens, phrase_tokens)
+
+    def testFitAndTransform(self):
+        self.model.fit(phrases_w_common_terms)
+
+        transformed = self.model.transform(phrases_w_common_terms)
+        self.assertEqual(transformed, self.expected_transformations)
+
+    def testFitTransform(self):
+        transformed = self.model.fit_transform(phrases_w_common_terms)
+        self.assertEqual(transformed, self.expected_transformations)
+
+    def testPartialFit(self):
+        # fit half of the sentences
+        self.model.fit(phrases_w_common_terms[:2])
+
+        expected_transformations_0 = [
+            [u'the', u'mayor_of_new', u'york', u'was', u'there'],
+            [u'the', u'mayor_of_new', u'orleans', u'was', u'there'],
+            [u'the', u'bank', u'of', u'america', u'offices', u'are', u'open'],
+            [u'the', u'bank', u'of', u'america', u'offices', u'are', u'closed']
+        ]
+        # transform all sentences, second half should be same as original
+        transformed_0 = self.model.transform(phrases_w_common_terms)
+        self.assertEqual(transformed_0, expected_transformations_0)
+
+        # fit remaining sentences, result should be the same as in the other tests
+        self.model.partial_fit(phrases_w_common_terms[2:])
+        transformed_1 = self.model.transform(phrases_w_common_terms)
+        self.assertEqual(transformed_1, self.expected_transformations)
+
+        new_phrases = [[u'offices', u'are', u'open'], [u'offices', u'are', u'closed']]
+        self.model.partial_fit(new_phrases)
+        expected_transformations_2 = [
+            [u'the', u'mayor_of_new', u'york', u'was', u'there'],
+            [u'the', u'mayor_of_new', u'orleans', u'was', u'there'],
+            [u'the', u'bank_of_america', u'offices_are_open'],
+            [u'the', u'bank_of_america', u'offices_are_closed']
+        ]
+        transformed_2 = self.model.transform(phrases_w_common_terms)
+        self.assertEqual(transformed_2, expected_transformations_2)
+
+
 # specifically test pluggable scoring in Phrases, because possible pickling issues with function parameter
 # this is intentionally in main rather than a class method to support pickling
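
Reviewer note (not part of the patch): below is a minimal usage sketch of the new common_terms option. It reuses the fixture sentences and expected output from TestPhrasesTransformerCommonTerms above and assumes a gensim build with this change applied.

# Minimal usage sketch (assumes this patch is applied); mirrors the
# TestPhrasesTransformerCommonTerms fixture and its expected output.
from gensim.sklearn_api.phrases import PhrasesTransformer

docs = [
    [u'the', u'mayor', u'of', u'new', u'york', u'was', u'there'],
    [u'the', u'mayor', u'of', u'new', u'orleans', u'was', u'there'],
    [u'the', u'bank', u'of', u'america', u'offices', u'are', u'open'],
    [u'the', u'bank', u'of', u'america', u'offices', u'are', u'closed'],
]

# Connector words ("of", "the", "was", "are") do not count towards phrase
# frequency, so expressions that span them can still be promoted to phrases.
model = PhrasesTransformer(min_count=1, threshold=1, common_terms=frozenset(["of", "the", "was", "are"]))
model.fit(docs)

print(model.transform(docs))
# Expected, per the tests above:
# [['the', 'mayor_of_new', 'york', 'was', 'there'],
#  ['the', 'mayor_of_new', 'orleans', 'was', 'there'],
#  ['the', 'bank_of_america', 'offices', 'are', 'open'],
#  ['the', 'bank_of_america', 'offices', 'are', 'closed']]

After this change, transform() applies a Phraser built in fit()/partial_fit() (and rebuilt lazily when missing), which is smaller and faster for phrase detection than applying the full Phrases model, while the __setstate__ hook backfills common_terms and phraser so transformers pickled by older releases (such as the 3.5.0 test pickles above) still load.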