Optimize word mover distance (WMD) computation (#3163)
* Faster WMD computation by removing a nested loop

* Update CHANGELOG.md

Co-authored-by: Michael Penkov <m@penkov.dev>
flowlight0 and mpenkov authored Jun 29, 2021
1 parent 2f23566 commit b378b1b
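
To illustrate the optimization named in the commit message, here is a minimal standalone sketch (not part of the commit; the toy vector shapes and random data are invented for illustration). It checks that a single `scipy.spatial.distance.cdist` call yields the same pairwise Euclidean distances as the nested loop it replaces:

```python
import numpy as np
from scipy.spatial.distance import cdist

rng = np.random.default_rng(0)
v1 = rng.standard_normal((3, 5))  # word vectors for document 1's unique tokens (toy data)
v2 = rng.standard_normal((4, 5))  # word vectors for document 2's unique tokens (toy data)

# Old style: explicit nested loop over token pairs.
loop_dist = np.zeros((3, 4))
for i in range(3):
    for j in range(4):
        loop_dist[i, j] = np.sqrt(np.sum((v1[i] - v2[j]) ** 2))

# New style: one vectorized call; Euclidean distance is cdist's default metric.
vec_dist = cdist(v1, v2)

assert np.allclose(loop_dist, vec_dist)
```

Besides being shorter, the vectorized call runs the pairwise arithmetic in SciPy's compiled code and only touches token pairs drawn one from each document, rather than testing every pair in the combined vocabulary as the old loop did.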
Showing 2 changed files with 9 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@ Changes
 * [#3115](https://github.com/RaRe-Technologies/gensim/pull/3115): Make LSI dispatcher CLI param for number of jobs optional, by [@robguinness](https://github.com/robguinness)
 * [#3128](https://github.com/RaRe-Technologies/gensim/pull/3128): Materialize and copy the corpus passed to SoftCosineSimilarity, by [@Witiko](https://github.com/Witiko)
 * [#3131](https://github.com/RaRe-Technologies/gensim/pull/3131): Added import to Nmf docs, and to models/__init__.py, by [@properGrammar](https://github.com/properGrammar)
+* [#3163](https://github.com/RaRe-Technologies/gensim/pull/3163): Optimize word mover distance (WMD) computation, by [@flowlight0](https://github.com/flowlight0)
 * [#2965](https://github.com/RaRe-Technologies/gensim/pull/2965): Remove strip_punctuation2 alias of strip_punctuation, by [@sciatro](https://github.com/sciatro)
 
 ### :books: Documentation
22 changes: 8 additions & 14 deletions gensim/models/keyedvectors.py
@@ -178,6 +178,7 @@
 )
 import numpy as np
 from scipy import stats
+from scipy.spatial.distance import cdist
 
 from gensim import utils, matutils  # utility fnc for pickling, common scipy operations etc
 from gensim.corpora.dictionary import Dictionary
@@ -901,23 +902,16 @@ def wmdistance(self, document1, document2, norm=True):
             # Both documents are composed of a single unique token => zero distance.
             return 0.0
 
-        # Sets for faster look-up.
-        docset1 = set(document1)
-        docset2 = set(document2)
+        doclist1 = list(set(document1))
+        doclist2 = list(set(document2))
+        v1 = np.array([self.get_vector(token, norm=norm) for token in doclist1])
+        v2 = np.array([self.get_vector(token, norm=norm) for token in doclist2])
+        doc1_indices = dictionary.doc2idx(doclist1)
+        doc2_indices = dictionary.doc2idx(doclist2)
 
         # Compute distance matrix.
         distance_matrix = zeros((vocab_len, vocab_len), dtype=double)
-        for i, t1 in dictionary.items():
-            if t1 not in docset1:
-                continue
-
-            for j, t2 in dictionary.items():
-                if t2 not in docset2 or distance_matrix[i, j] != 0.0:
-                    continue
-
-                # Compute Euclidean distance between (potentially unit-normed) word vectors.
-                distance_matrix[i, j] = distance_matrix[j, i] = np.sqrt(
-                    np_sum((self.get_vector(t1, norm=norm) - self.get_vector(t2, norm=norm))**2))
+        distance_matrix[np.ix_(doc1_indices, doc2_indices)] = cdist(v1, v2)
 
         if abs(np_sum(distance_matrix)) < 1e-8:
             # `emd` gets stuck if the distance matrix contains only zeros.
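
The replacement line relies on `numpy.ix_` to scatter the dense block returned by `cdist` into the matching rows and columns of the full `vocab_len x vocab_len` matrix. A small standalone sketch of that indexing pattern (the sizes and index lists below are made up for illustration):

```python
import numpy as np
from scipy.spatial.distance import cdist

vocab_len = 6
rng = np.random.default_rng(1)
v1 = rng.standard_normal((2, 4))  # vectors for document 1's unique tokens (toy data)
v2 = rng.standard_normal((3, 4))  # vectors for document 2's unique tokens (toy data)
doc1_indices = [0, 3]             # hypothetical dictionary ids for document 1's tokens
doc2_indices = [1, 4, 5]          # hypothetical dictionary ids for document 2's tokens

distance_matrix = np.zeros((vocab_len, vocab_len))
# np.ix_ builds an open mesh from the two index lists, so the 2x3 cdist result
# lands in exactly those rows and columns; the rest of the matrix stays zero.
distance_matrix[np.ix_(doc1_indices, doc2_indices)] = cdist(v1, v2)
print(distance_matrix.round(2))
```

Unlike the old loop, this fills only the block pairing document 1's tokens with document 2's tokens and no longer mirrors each value to keep the matrix symmetric; presumably that is sufficient because the transport problem only moves mass from the first document's tokens to the second's.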
