Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize word mover distance (WMD) computation #3163

Merged
merged 3 commits into from
Jun 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Changes
* [#3115](https://github.com/RaRe-Technologies/gensim/pull/3115): Make LSI dispatcher CLI param for number of jobs optional, by [@robguinness](https://github.com/robguinness)
* [#3128](https://github.com/RaRe-Technologies/gensim/pull/3128): Materialize and copy the corpus passed to SoftCosineSimilarity, by [@Witiko](https://github.com/Witiko)
* [#3131](https://github.com/RaRe-Technologies/gensim/pull/3131): Added import to Nmf docs, and to models/__init__.py, by [@properGrammar](https://github.com/properGrammar)
* [#3163](https://github.com/RaRe-Technologies/gensim/pull/3163): Optimize word mover distance (WMD) computation, by [@flowlight0](https://github.com/flowlight0)
* [#2965](https://github.com/RaRe-Technologies/gensim/pull/2965): Remove strip_punctuation2 alias of strip_punctuation, by [@sciatro](https://github.com/sciatro)

### :books: Documentation
Expand Down
22 changes: 8 additions & 14 deletions gensim/models/keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@
)
import numpy as np
from scipy import stats
from scipy.spatial.distance import cdist

from gensim import utils, matutils # utility fnc for pickling, common scipy operations etc
from gensim.corpora.dictionary import Dictionary
Expand Down Expand Up @@ -901,23 +902,16 @@ def wmdistance(self, document1, document2, norm=True):
# Both documents are composed of a single unique token => zero distance.
return 0.0

# Sets for faster look-up.
docset1 = set(document1)
docset2 = set(document2)
doclist1 = list(set(document1))
doclist2 = list(set(document2))
v1 = np.array([self.get_vector(token, norm=norm) for token in doclist1])
v2 = np.array([self.get_vector(token, norm=norm) for token in doclist2])
doc1_indices = dictionary.doc2idx(doclist1)
doc2_indices = dictionary.doc2idx(doclist2)

# Compute distance matrix.
distance_matrix = zeros((vocab_len, vocab_len), dtype=double)
for i, t1 in dictionary.items():
if t1 not in docset1:
continue

for j, t2 in dictionary.items():
if t2 not in docset2 or distance_matrix[i, j] != 0.0:
continue

# Compute Euclidean distance between (potentially unit-normed) word vectors.
distance_matrix[i, j] = distance_matrix[j, i] = np.sqrt(
np_sum((self.get_vector(t1, norm=norm) - self.get_vector(t2, norm=norm))**2))
distance_matrix[np.ix_(doc1_indices, doc2_indices)] = cdist(v1, v2)

if abs(np_sum(distance_matrix)) < 1e-8:
# `emd` gets stuck if the distance matrix contains only zeros.
Expand Down