Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the 'keep_tokens' parameter to 'filter_extremes' and test it #1210

Merged
merged 2 commits into from
Mar 13, 2017
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions gensim/corpora/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,14 +172,16 @@ def doc2bow(self, document, allow_update=False, return_missing=False):
else:
return result

def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000):
def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000, keep_tokens=None):
"""
Filter out tokens that appear in

1. less than `no_below` documents (absolute number) or
2. more than `no_above` documents (fraction of total corpus size, *not*
absolute number).
3. after (1) and (2), keep only the first `keep_n` most frequent tokens (or
3. if tokens are given in keep_tokens (list of strings), they will be kept regardless of
the `no_below` and `no_above` settings
4. after (1), (2) and (3), keep only the first `keep_n` most frequent tokens (or
keep all if `None`).

After the pruning, shrink resulting gaps in word ids.
Expand All @@ -193,6 +195,12 @@ def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000):
good_ids = (
v for v in itervalues(self.token2id)
if no_below <= self.dfs.get(v, 0) <= no_above_abs)
# add ids of keep_tokens elements to good_ids
if keep_tokens:
keep_ids = [self.token2id[v] for v in keep_tokens if v in self.token2id]
good_ids_copy = (v for v in itervalues(self.token2id) if no_below <= self.dfs.get(v, 0) <= no_above_abs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please keep to a single iteration of token2id by adding a or v in keep_tokens check if keep_tokens is present

keep_ids = list(set(keep_ids).union(set(good_ids_copy)))
good_ids = keep_ids
good_ids = sorted(good_ids, key=self.dfs.get, reverse=True)
if keep_n is not None:
good_ids = good_ids[:keep_n]
Expand Down
21 changes: 21 additions & 0 deletions gensim/test/test_corpora_dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,27 @@ def testFilter(self):
d.filter_extremes(no_below=2, no_above=1.0, keep_n=4)
expected = {0: 3, 1: 3, 2: 3, 3: 3}
self.assertEqual(d.dfs, expected)

def testFilterKeepTokens_keepTokens(self):
# provide keep_tokens argument, keep the tokens given
d = Dictionary(self.texts)
d.filter_extremes(no_below=3, no_above=1.0, keep_tokens=['human', 'survey'])
expected = set(['graph', 'trees', 'human', 'system', 'user', 'survey'])
self.assertEqual(set(d.token2id.keys()), expected)

def testFilterKeepTokens_unchangedFunctionality(self):
# do not provide keep_tokens argument, filter_extremes functionality is unchanged
d = Dictionary(self.texts)
d.filter_extremes(no_below=3, no_above=1.0)
expected = set(['graph', 'trees', 'system', 'user'])
self.assertEqual(set(d.token2id.keys()), expected)

def testFilterKeepTokens_unseenToken(self):
# do provide keep_tokens argument with unseen tokens, filter_extremes functionality is unchanged
d = Dictionary(self.texts)
d.filter_extremes(no_below=3, no_above=1.0, keep_tokens=['unknown_token'])
expected = set(['graph', 'trees', 'system', 'user'])
self.assertEqual(set(d.token2id.keys()), expected)

def testFilterMostFrequent(self):
d = Dictionary(self.texts)
Expand Down