diff --git a/gensim/summarization/bm25.py b/gensim/summarization/bm25.py
index 3a2bf5bbf6..50019c32fe 100644
--- a/gensim/summarization/bm25.py
+++ b/gensim/summarization/bm25.py
@@ -22,7 +22,7 @@
 ... ["cat", "outer", "space"],
 ... ["wag", "dog"]
 ... ]
->>> result = get_bm25_weights(corpus)
+>>> result = get_bm25_weights(corpus, n_jobs=-1)
 
 
 Data:
@@ -37,7 +37,9 @@
 import math
 from six import iteritems
 from six.moves import xrange
-
+from functools import partial
+from multiprocessing import Pool
+from ..utils import effective_n_jobs
 
 PARAM_K1 = 1.5
 PARAM_B = 0.75
@@ -152,7 +154,33 @@ def get_scores(self, document, average_idf):
         return scores
 
 
-def get_bm25_weights(corpus):
+def _get_scores(bm25, document, average_idf):
+    """Helper function for retrieving BM25 scores of a given `document`
+    in relation to every item in the corpus, used when scoring in parallel.
+
+    Parameters
+    ----------
+    bm25 : BM25 object
+        BM25 object fitted on the corpus against which `document` is scored.
+    document : list of str
+        Document to be scored.
+    average_idf : float
+        Average idf in corpus.
+
+    Returns
+    -------
+    list of float
+        BM25 scores.
+
+    """
+    scores = []
+    for index in xrange(bm25.corpus_size):
+        score = bm25.get_score(document, index, average_idf)
+        scores.append(score)
+    return scores
+
+
+def get_bm25_weights(corpus, n_jobs=1):
     """Returns BM25 scores (weights) of documents in corpus.
 
     Each document has to be weighted with every document in given corpus.
@@ -160,6 +188,8 @@ def get_bm25_weights(corpus):
     ----------
     corpus : list of list of str
         Corpus of documents.
+    n_jobs : int
+        The number of processes to use for computing BM25; -1 means all available CPU cores.
 
     Returns
     -------
@@ -174,15 +204,21 @@ def get_bm25_weights(corpus):
     ... ["cat", "outer", "space"],
     ... ["wag", "dog"]
     ... ]
-    >>> result = get_bm25_weights(corpus)
+    >>> result = get_bm25_weights(corpus, n_jobs=-1)
 
     """
    bm25 = BM25(corpus)
     average_idf = sum(float(val) for val in bm25.idf.values()) / len(bm25.idf)
 
-    weights = []
-    for doc in corpus:
-        scores = bm25.get_scores(doc, average_idf)
-        weights.append(scores)
+    n_processes = effective_n_jobs(n_jobs)
+    if n_processes == 1:
+        weights = [bm25.get_scores(doc, average_idf) for doc in corpus]
+        return weights
+
+    get_score = partial(_get_scores, bm25, average_idf=average_idf)
+    pool = Pool(n_processes)
+    weights = pool.map(get_score, corpus)
+    pool.close()
+    pool.join()
 
     return weights
diff --git a/gensim/test/test_BM25.py b/gensim/test/test_BM25.py
index a96302e8c9..e37efc2920 100644
--- a/gensim/test/test_BM25.py
+++ b/gensim/test/test_BM25.py
@@ -44,6 +44,15 @@ def test_disjoint_docs_if_weight_zero(self):
         self.assertAlmostEqual(weights[0][1], 0)
         self.assertAlmostEqual(weights[1][0], 0)
 
+    def test_multiprocessing(self):
+        """Result should be the same regardless of the number of processes."""
+        weights1 = get_bm25_weights(common_texts)
+        weights2 = get_bm25_weights(common_texts, n_jobs=2)
+        weights3 = get_bm25_weights(common_texts, n_jobs=-1)
+        self.assertAlmostEqual(weights1, weights2)
+        self.assertAlmostEqual(weights1, weights3)
+        self.assertAlmostEqual(weights2, weights3)
+
 
 if __name__ == '__main__':
     logging.basicConfig(level=logging.DEBUG)
diff --git a/gensim/utils.py b/gensim/utils.py
index ec02cf4bb2..35abc203d8 100644
--- a/gensim/utils.py
+++ b/gensim/utils.py
@@ -44,6 +44,8 @@
 
 from smart_open import smart_open
 
+from multiprocessing import cpu_count
+
 if sys.version_info[0] >= 3:
     unicode = str
 
@@ -2025,3 +2027,29 @@ def lazy_flatten(nested_list):
             yield sub
     else:
         yield el
+
+
+def effective_n_jobs(n_jobs):
+    """Determines the number of jobs that can run in parallel.
+
+    Just like in sklearn, passing n_jobs=-1 means using all available
+    CPU cores.
+
+    Parameters
+    ----------
+    n_jobs : int
+        Number of workers requested by caller.
+
+    Returns
+    -------
+    int
+        Number of effective jobs.
+
+    """
+    if n_jobs == 0:
+        raise ValueError('n_jobs == 0 in Parallel has no meaning')
+    elif n_jobs is None:
+        return 1
+    elif n_jobs < 0:
+        n_jobs = max(cpu_count() + 1 + n_jobs, 1)
+    return n_jobs
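
For reference, a minimal usage sketch of the new n_jobs parameter. This is an illustration, not part of the patch: it assumes the diff above has been applied to an installed gensim, and it reuses only the two documents visible in the docstring context.

from gensim.summarization.bm25 import get_bm25_weights
from gensim.utils import effective_n_jobs

if __name__ == '__main__':  # guard needed because a multiprocessing.Pool is spawned
    # Toy corpus, reusing the two documents visible in the docstring above.
    corpus = [
        ["cat", "outer", "space"],
        ["wag", "dog"],
    ]

    # n_jobs=1 keeps the original single-process code path.
    serial = get_bm25_weights(corpus, n_jobs=1)

    # n_jobs=-1 is expanded by effective_n_jobs() to all available CPU cores;
    # the per-document score lists are then computed in a multiprocessing.Pool.
    parallel = get_bm25_weights(corpus, n_jobs=-1)

    # Both paths return one list of BM25 scores per document in the corpus.
    assert len(serial) == len(parallel) == len(corpus)

    # effective_n_jobs() follows the sklearn convention: on an 8-core machine,
    # n_jobs=-1 -> max(8 + 1 - 1, 1) = 8 processes, -2 -> 7, None -> 1, 0 -> ValueError.
    print(effective_n_jobs(-1), effective_n_jobs(None))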