diff --git a/gensim/corpora/dictionary.py b/gensim/corpora/dictionary.py
index 1fd7e31e61..484684c26d 100644
--- a/gensim/corpora/dictionary.py
+++ b/gensim/corpora/dictionary.py
@@ -172,14 +172,16 @@ def doc2bow(self, document, allow_update=False, return_missing=False):
         else:
             return result
 
-    def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000):
+    def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000, keep_tokens=None):
         """
         Filter out tokens that appear in
 
         1. less than `no_below` documents (absolute number) or
         2. more than `no_above` documents (fraction of total corpus size, *not*
            absolute number).
-        3. after (1) and (2), keep only the first `keep_n` most frequent tokens (or
+        3. if tokens are given in `keep_tokens` (list of strings), they will be kept regardless of
+           the `no_below` and `no_above` settings
+        4. after (1), (2) and (3), keep only the first `keep_n` most frequent tokens (or
            keep all if `None`).
 
         After the pruning, shrink resulting gaps in word ids.
@@ -190,9 +192,16 @@ def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000):
         no_above_abs = int(no_above * self.num_docs)  # convert fractional threshold to absolute threshold
 
         # determine which tokens to keep
-        good_ids = (
-            v for v in itervalues(self.token2id)
-            if no_below <= self.dfs.get(v, 0) <= no_above_abs)
+        if keep_tokens:
+            # ids to retain unconditionally; a set gives O(1) membership tests below
+            keep_ids = {self.token2id[v] for v in keep_tokens if v in self.token2id}
+            good_ids = (v for v in itervalues(self.token2id)
+                if no_below <= self.dfs.get(v, 0) <= no_above_abs
+                or v in keep_ids)
+        else:
+            good_ids = (
+                v for v in itervalues(self.token2id)
+                if no_below <= self.dfs.get(v, 0) <= no_above_abs)
         good_ids = sorted(good_ids, key=self.dfs.get, reverse=True)
         if keep_n is not None:
             good_ids = good_ids[:keep_n]
diff --git a/gensim/test/test_corpora_dictionary.py b/gensim/test/test_corpora_dictionary.py
index bbf5fa339d..16c499b245 100644
--- a/gensim/test/test_corpora_dictionary.py
+++ b/gensim/test/test_corpora_dictionary.py
@@ -120,6 +120,27 @@ def testFilter(self):
         d.filter_extremes(no_below=2, no_above=1.0, keep_n=4)
         expected = {0: 3, 1: 3, 2: 3, 3: 3}
         self.assertEqual(d.dfs, expected)
+
+    def testFilterKeepTokens_keepTokens(self):
+        # provide keep_tokens argument: the listed tokens are kept despite filtering
+        d = Dictionary(self.texts)
+        d.filter_extremes(no_below=3, no_above=1.0, keep_tokens=['human', 'survey'])
+        expected = set(['graph', 'trees', 'human', 'system', 'user', 'survey'])
+        self.assertEqual(set(d.token2id.keys()), expected)
+
+    def testFilterKeepTokens_unchangedFunctionality(self):
+        # no keep_tokens argument given: filter_extremes functionality is unchanged
+        d = Dictionary(self.texts)
+        d.filter_extremes(no_below=3, no_above=1.0)
+        expected = set(['graph', 'trees', 'system', 'user'])
+        self.assertEqual(set(d.token2id.keys()), expected)
+
+    def testFilterKeepTokens_unseenToken(self):
+        # keep_tokens containing only unseen tokens: filter_extremes functionality is unchanged
+        d = Dictionary(self.texts)
+        d.filter_extremes(no_below=3, no_above=1.0, keep_tokens=['unknown_token'])
+        expected = set(['graph', 'trees', 'system', 'user'])
+        self.assertEqual(set(d.token2id.keys()), expected)
 
     def testFilterMostFrequent(self):
         d = Dictionary(self.texts)