From 9b4576a3e7a945e30cc958865c4732d471473afd Mon Sep 17 00:00:00 2001 From: Tomasz Oliwa Date: Mon, 13 Mar 2017 14:09:24 -0500 Subject: [PATCH 1/2] Add the 'keep_tokens' parameter to 'filter_extremes' and test it Add the optional 'keep_tokens' parameter to the 'filter_extremes' method in dictionary.py. This parameter can contain a list of tokens, which will be kept regardless of the 'no_below' and 'no_above' settings. This can be useful if the research goal is to enforce certain tokens to appear in topics, and still be able to filter all other extremes. If 'keep_tokens' is not given, the functionality of 'filter_extremes' is unchanged. Unit tests are also provided to assert examples of the above. --- gensim/corpora/dictionary.py | 12 ++++++++++-- gensim/test/test_corpora_dictionary.py | 21 +++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/gensim/corpora/dictionary.py b/gensim/corpora/dictionary.py index 1fd7e31e61..163d464a39 100644 --- a/gensim/corpora/dictionary.py +++ b/gensim/corpora/dictionary.py @@ -172,14 +172,16 @@ def doc2bow(self, document, allow_update=False, return_missing=False): else: return result - def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000): + def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000, keep_tokens=None): """ Filter out tokens that appear in 1. less than `no_below` documents (absolute number) or 2. more than `no_above` documents (fraction of total corpus size, *not* absolute number). - 3. after (1) and (2), keep only the first `keep_n` most frequent tokens (or + 3. if tokens are given in keep_tokens (list of strings), they will be kept regardless of + the `no_below` and `no_above` settings + 4. after (1), (2) and (3), keep only the first `keep_n` most frequent tokens (or keep all if `None`). After the pruning, shrink resulting gaps in word ids. @@ -193,6 +195,12 @@ def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000): good_ids = ( v for v in itervalues(self.token2id) if no_below <= self.dfs.get(v, 0) <= no_above_abs) + # add ids of keep_tokens elements to good_ids + if keep_tokens: + keep_ids = [self.token2id[v] for v in keep_tokens if v in self.token2id] + good_ids_copy = (v for v in itervalues(self.token2id) if no_below <= self.dfs.get(v, 0) <= no_above_abs) + keep_ids = list(set(keep_ids).union(set(good_ids_copy))) + good_ids = keep_ids good_ids = sorted(good_ids, key=self.dfs.get, reverse=True) if keep_n is not None: good_ids = good_ids[:keep_n] diff --git a/gensim/test/test_corpora_dictionary.py b/gensim/test/test_corpora_dictionary.py index bbf5fa339d..16c499b245 100644 --- a/gensim/test/test_corpora_dictionary.py +++ b/gensim/test/test_corpora_dictionary.py @@ -120,6 +120,27 @@ def testFilter(self): d.filter_extremes(no_below=2, no_above=1.0, keep_n=4) expected = {0: 3, 1: 3, 2: 3, 3: 3} self.assertEqual(d.dfs, expected) + + def testFilterKeepTokens_keepTokens(self): + # provide keep_tokens argument, keep the tokens given + d = Dictionary(self.texts) + d.filter_extremes(no_below=3, no_above=1.0, keep_tokens=['human', 'survey']) + expected = set(['graph', 'trees', 'human', 'system', 'user', 'survey']) + self.assertEqual(set(d.token2id.keys()), expected) + + def testFilterKeepTokens_unchangedFunctionality(self): + # do not provide keep_tokens argument, filter_extremes functionality is unchanged + d = Dictionary(self.texts) + d.filter_extremes(no_below=3, no_above=1.0) + expected = set(['graph', 'trees', 'system', 'user']) + self.assertEqual(set(d.token2id.keys()), expected) + + def testFilterKeepTokens_unseenToken(self): + # do provide keep_tokens argument with unseen tokens, filter_extremes functionality is unchanged + d = Dictionary(self.texts) + d.filter_extremes(no_below=3, no_above=1.0, keep_tokens=['unknown_token']) + expected = set(['graph', 'trees', 'system', 'user']) + self.assertEqual(set(d.token2id.keys()), expected) def testFilterMostFrequent(self): d = Dictionary(self.texts) From ee6b4f736a7a05605477885f0b8322d84444ce18 Mon Sep 17 00:00:00 2001 From: Tomasz Oliwa Date: Mon, 13 Mar 2017 16:29:53 -0500 Subject: [PATCH 2/2] Create good_ids only once Create good_ids only once as per optimization suggestion, regardless if 'keep_tokens' is provided or not. --- gensim/corpora/dictionary.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/gensim/corpora/dictionary.py b/gensim/corpora/dictionary.py index 163d464a39..484684c26d 100644 --- a/gensim/corpora/dictionary.py +++ b/gensim/corpora/dictionary.py @@ -192,15 +192,15 @@ def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000, keep_tokens=N no_above_abs = int(no_above * self.num_docs) # convert fractional threshold to absolute threshold # determine which tokens to keep - good_ids = ( - v for v in itervalues(self.token2id) - if no_below <= self.dfs.get(v, 0) <= no_above_abs) - # add ids of keep_tokens elements to good_ids if keep_tokens: keep_ids = [self.token2id[v] for v in keep_tokens if v in self.token2id] - good_ids_copy = (v for v in itervalues(self.token2id) if no_below <= self.dfs.get(v, 0) <= no_above_abs) - keep_ids = list(set(keep_ids).union(set(good_ids_copy))) - good_ids = keep_ids + good_ids = (v for v in itervalues(self.token2id) + if no_below <= self.dfs.get(v, 0) <= no_above_abs + or v in keep_ids) + else: + good_ids = ( + v for v in itervalues(self.token2id) + if no_below <= self.dfs.get(v, 0) <= no_above_abs) good_ids = sorted(good_ids, key=self.dfs.get, reverse=True) if keep_n is not None: good_ids = good_ids[:keep_n]