Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multiprocessing support for BM25 #2146

Merged
merged 15 commits into from
Aug 13, 2018
52 changes: 44 additions & 8 deletions gensim/summarization/bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
... ["cat", "outer", "space"],
... ["wag", "dog"]
... ]
>>> result = get_bm25_weights(corpus)
>>> result = get_bm25_weights(corpus, n_jobs=-1)


Data:
Expand All @@ -37,7 +37,9 @@
import math
from six import iteritems
from six.moves import xrange

from functools import partial
from multiprocessing import Pool
from ..utils import effective_n_jobs

PARAM_K1 = 1.5
PARAM_B = 0.75
Expand Down Expand Up @@ -152,14 +154,42 @@ def get_scores(self, document, average_idf):
return scores


def get_bm25_weights(corpus):
def _get_scores(bm25, document, average_idf):
"""Helper function for retrieving bm25 scores of given `document` in parallel
in relation to every item in corpus.

Parameters
----------
bm25 : BM25 object
BM25 object fitted on the corpus where documents are retrieved.
document : list of str
Document to be scored.
average_idf : float
Average idf in corpus.

Returns
-------
list of float
BM25 scores.

"""
scores = []
for index in xrange(bm25.corpus_size):
score = bm25.get_score(document, index, average_idf)
scores.append(score)
return scores


def get_bm25_weights(corpus, n_jobs=1):
    """Returns BM25 scores (weights) of documents in corpus.
    Each document has to be weighted with every document in given corpus.

    Parameters
    ----------
    corpus : list of list of str
        Corpus of documents.
    n_jobs : int
        The number of processes to use for computing bm25.
        -1 means "use all available CPU cores".

    Returns
    -------
    list of list of float
        BM25 scores.

    Examples
    --------
    >>> from gensim.summarization.bm25 import get_bm25_weights
    >>> corpus = [
    ...     ["black", "cat", "white", "cat"],
    ...     ["cat", "outer", "space"],
    ...     ["wag", "dog"]
    ... ]
    >>> result = get_bm25_weights(corpus, n_jobs=-1)

    """
    bm25 = BM25(corpus)
    average_idf = sum(float(val) for val in bm25.idf.values()) / len(bm25.idf)

    n_processes = effective_n_jobs(n_jobs)
    if n_processes == 1:
        # Serial path: no pool overhead when a single worker is requested.
        return [bm25.get_scores(doc, average_idf) for doc in corpus]

    # Bind everything except `document`, so Pool.map can feed one corpus
    # document per task.
    get_score = partial(_get_scores, bm25, average_idf=average_idf)
    pool = Pool(n_processes)
    weights = pool.map(get_score, corpus)
    # close() must come before join(): join() waits for the workers to exit,
    # which only happens once close() signals that no more work is coming.
    pool.close()
    pool.join()
    return weights
9 changes: 9 additions & 0 deletions gensim/test/test_BM25.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ def test_disjoint_docs_if_weight_zero(self):
self.assertAlmostEqual(weights[0][1], 0)
self.assertAlmostEqual(weights[1][0], 0)

def test_multiprocessing(self):
    """Result should be the same using different numbers of processes."""
    weights1 = get_bm25_weights(common_texts)
    weights2 = get_bm25_weights(common_texts, n_jobs=2)
    weights3 = get_bm25_weights(common_texts, n_jobs=-1)
    # assertAlmostEqual on nested lists only passes via exact `==`; on a
    # mismatch it raises TypeError (lists don't support subtraction) instead
    # of a meaningful assertion failure.  Compare the floats element-wise.
    self.assertEqual(len(weights1), len(weights2))
    self.assertEqual(len(weights1), len(weights3))
    for row1, row2, row3 in zip(weights1, weights2, weights3):
        self.assertEqual(len(row1), len(row2))
        self.assertEqual(len(row1), len(row3))
        for w1, w2, w3 in zip(row1, row2, row3):
            self.assertAlmostEqual(w1, w2)
            self.assertAlmostEqual(w1, w3)


if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
Expand Down
28 changes: 28 additions & 0 deletions gensim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@

from smart_open import smart_open

from multiprocessing import cpu_count

if sys.version_info[0] >= 3:
unicode = str

Expand Down Expand Up @@ -2025,3 +2027,29 @@ def lazy_flatten(nested_list):
yield sub
else:
yield el


def effective_n_jobs(n_jobs):
    """Determine how many jobs can actually run in parallel.

    Follows the scikit-learn convention: a negative value counts backwards
    from the number of CPUs, so ``n_jobs=-1`` means "use every available
    CPU core".

    Parameters
    ----------
    n_jobs : int
        Number of workers requested by caller.

    Returns
    -------
    int
        Number of effective jobs.

    Raises
    ------
    ValueError
        If ``n_jobs`` is 0, which has no meaning.

    """
    # Guard clauses, from the invalid case to the common ones.
    if n_jobs == 0:
        raise ValueError('n_jobs == 0 in Parallel has no meaning')
    if n_jobs is None:
        return 1
    if n_jobs < 0:
        # -1 -> all cores, -2 -> all but one, ... never below a single job.
        return max(cpu_count() + 1 + n_jobs, 1)
    return n_jobs