Skip to content

Commit

Permalink
Add the 'keep_tokens' parameter to 'filter_extremes' (#1210)
Browse files Browse the repository at this point in the history
* Add the 'keep_tokens' parameter to 'filter_extremes' and test it

Add the optional 'keep_tokens' parameter to the 'filter_extremes'
method in dictionary.py. This parameter can contain a list of tokens,
which will be kept regardless of the 'no_below' and 'no_above' settings.
This can be useful if the research goal is to enforce certain tokens to
appear in topics, and still be able to filter all other extremes.

If 'keep_tokens' is not given, the functionality of 'filter_extremes' is
unchanged.

Unit tests are also provided to assert examples of the above.

* Create good_ids only once

Create good_ids only once as per optimization
suggestion, regardless if 'keep_tokens' is provided or not.
  • Loading branch information
Tomasz Oliwa authored and tmylk committed Mar 13, 2017
1 parent e6405c9 commit 8c869cb
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 5 deletions.
18 changes: 13 additions & 5 deletions gensim/corpora/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,14 +172,16 @@ def doc2bow(self, document, allow_update=False, return_missing=False):
else:
return result

def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000):
def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000, keep_tokens=None):
"""
Filter out tokens that appear in
1. less than `no_below` documents (absolute number) or
2. more than `no_above` documents (fraction of total corpus size, *not*
absolute number).
3. after (1) and (2), keep only the first `keep_n` most frequent tokens (or
3. if tokens are given in keep_tokens (list of strings), they will be kept regardless of
the `no_below` and `no_above` settings
4. after (1), (2) and (3), keep only the first `keep_n` most frequent tokens (or
keep all if `None`).
After the pruning, shrink resulting gaps in word ids.
Expand All @@ -190,9 +192,15 @@ def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000):
no_above_abs = int(no_above * self.num_docs) # convert fractional threshold to absolute threshold

# determine which tokens to keep
good_ids = (
v for v in itervalues(self.token2id)
if no_below <= self.dfs.get(v, 0) <= no_above_abs)
if keep_tokens:
keep_ids = [self.token2id[v] for v in keep_tokens if v in self.token2id]
good_ids = (v for v in itervalues(self.token2id)
if no_below <= self.dfs.get(v, 0) <= no_above_abs
or v in keep_ids)
else:
good_ids = (
v for v in itervalues(self.token2id)
if no_below <= self.dfs.get(v, 0) <= no_above_abs)
good_ids = sorted(good_ids, key=self.dfs.get, reverse=True)
if keep_n is not None:
good_ids = good_ids[:keep_n]
Expand Down
21 changes: 21 additions & 0 deletions gensim/test/test_corpora_dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,27 @@ def testFilter(self):
d.filter_extremes(no_below=2, no_above=1.0, keep_n=4)
expected = {0: 3, 1: 3, 2: 3, 3: 3}
self.assertEqual(d.dfs, expected)

def testFilterKeepTokens_keepTokens(self):
# provide keep_tokens argument, keep the tokens given
d = Dictionary(self.texts)
d.filter_extremes(no_below=3, no_above=1.0, keep_tokens=['human', 'survey'])
expected = set(['graph', 'trees', 'human', 'system', 'user', 'survey'])
self.assertEqual(set(d.token2id.keys()), expected)

def testFilterKeepTokens_unchangedFunctionality(self):
# do not provide keep_tokens argument, filter_extremes functionality is unchanged
d = Dictionary(self.texts)
d.filter_extremes(no_below=3, no_above=1.0)
expected = set(['graph', 'trees', 'system', 'user'])
self.assertEqual(set(d.token2id.keys()), expected)

def testFilterKeepTokens_unseenToken(self):
# do provide keep_tokens argument with unseen tokens, filter_extremes functionality is unchanged
d = Dictionary(self.texts)
d.filter_extremes(no_below=3, no_above=1.0, keep_tokens=['unknown_token'])
expected = set(['graph', 'trees', 'system', 'user'])
self.assertEqual(set(d.token2id.keys()), expected)

def testFilterMostFrequent(self):
d = Dictionary(self.texts)
Expand Down

0 comments on commit 8c869cb

Please sign in to comment.