diff --git a/CHANGELOG.md b/CHANGELOG.md index de379e60c0..f99bfe48b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ Changes * [#3115](https://github.com/RaRe-Technologies/gensim/pull/3115): Make LSI dispatcher CLI param for number of jobs optional, by [@robguinness](https://github.com/robguinness) * [#3128](https://github.com/RaRe-Technologies/gensim/pull/3128): Materialize and copy the corpus passed to SoftCosineSimilarity, by [@Witiko](https://github.com/Witiko) * [#3131](https://github.com/RaRe-Technologies/gensim/pull/3131): Added import to Nmf docs, and to models/__init__.py, by [@properGrammar](https://github.com/properGrammar) +* [#3163](https://github.com/RaRe-Technologies/gensim/pull/3163): Optimize word mover distance (WMD) computation, by [@flowlight0](https://github.com/flowlight0) * [#2965](https://github.com/RaRe-Technologies/gensim/pull/2965): Remove strip_punctuation2 alias of strip_punctuation, by [@sciatro](https://github.com/sciatro) ### :books: Documentation diff --git a/gensim/models/keyedvectors.py b/gensim/models/keyedvectors.py index daa5482184..b3f27cad2d 100644 --- a/gensim/models/keyedvectors.py +++ b/gensim/models/keyedvectors.py @@ -178,6 +178,7 @@ ) import numpy as np from scipy import stats +from scipy.spatial.distance import cdist from gensim import utils, matutils # utility fnc for pickling, common scipy operations etc from gensim.corpora.dictionary import Dictionary @@ -901,23 +902,16 @@ def wmdistance(self, document1, document2, norm=True): # Both documents are composed of a single unique token => zero distance. return 0.0 - # Sets for faster look-up. - docset1 = set(document1) - docset2 = set(document2) + doclist1 = list(set(document1)) + doclist2 = list(set(document2)) + v1 = np.array([self.get_vector(token, norm=norm) for token in doclist1]) + v2 = np.array([self.get_vector(token, norm=norm) for token in doclist2]) + doc1_indices = dictionary.doc2idx(doclist1) + doc2_indices = dictionary.doc2idx(doclist2) # Compute distance matrix. distance_matrix = zeros((vocab_len, vocab_len), dtype=double) - for i, t1 in dictionary.items(): - if t1 not in docset1: - continue - - for j, t2 in dictionary.items(): - if t2 not in docset2 or distance_matrix[i, j] != 0.0: - continue - - # Compute Euclidean distance between (potentially unit-normed) word vectors. - distance_matrix[i, j] = distance_matrix[j, i] = np.sqrt( - np_sum((self.get_vector(t1, norm=norm) - self.get_vector(t2, norm=norm))**2)) + distance_matrix[np.ix_(doc1_indices, doc2_indices)] = cdist(v1, v2) if abs(np_sum(distance_matrix)) < 1e-8: # `emd` gets stuck if the distance matrix contains only zeros.