Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Unicode string incompatibility in gensim.similarities.fastss.editdist #3178

Merged
merged 5 commits into from
Jun 29, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions gensim/similarities/fastss.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ cdef extern from *:
WIDTH * CYTHON_RESTRICT pos_new;
WIDTH * CYTHON_RESTRICT pos_old;
int row_flip = 1; /* Does pos_new represent row1 or row2? */
int kind = PyUnicode_KIND(s1); /* How many bytes per unicode codepoint? */
if (kind != PyUnicode_KIND(s2)) return -1;
int kind1 = PyUnicode_KIND(s1); /* How many bytes per unicode codepoint? */
int kind2 = PyUnicode_KIND(s2);

WIDTH len_s1 = (WIDTH)PyUnicode_GET_LENGTH(s1);
WIDTH len_s2 = (WIDTH)PyUnicode_GET_LENGTH(s2);
Expand All @@ -39,15 +39,15 @@ cdef extern from *:
const WIDTH tmpi = len_s1; len_s1 = len_s2; len_s2 = tmpi;
}
if (len_s2 - len_s1 > maximum) return maximum + 1;
if (len_s2 > MAX_WORD_LENGTH) return -2;
if (len_s2 > MAX_WORD_LENGTH) return -1;
void * s1_data = PyUnicode_DATA(s1);
void * s2_data = PyUnicode_DATA(s2);

for (WIDTH tmpi = 0; tmpi <= len_s1; tmpi++) row2[tmpi] = tmpi;

for (WIDTH i2 = 0; i2 < len_s2; i2++) {
int all_bad = i2 >= maximum;
const Py_UCS4 ch = PyUnicode_READ(kind, s2_data, i2);
const Py_UCS4 ch = PyUnicode_READ(kind2, s2_data, i2);
row_flip = 1 - row_flip;
if (row_flip) {
pos_new = row2; pos_old = row1;
Expand All @@ -58,7 +58,7 @@ cdef extern from *:

for (WIDTH i1 = 0; i1 < len_s1; i1++) {
WIDTH val = *(pos_old++);
if (ch != PyUnicode_READ(kind, s1_data, i1)) {
if (ch != PyUnicode_READ(kind1, s1_data, i1)) {
const WIDTH _val1 = *pos_old;
const WIDTH _val2 = *pos_new;
if (_val1 < val) val = _val1;
Expand Down Expand Up @@ -96,8 +96,6 @@ def editdist(s1: str, s2: str, max_dist=None):
if result >= 0:
return result
elif result == -1:
raise ValueError("incompatible types of unicode strings")
elif result == -2:
raise ValueError(f"editdist doesn't support strings longer than {MAX_WORD_LENGTH} characters")
else:
raise ValueError(f"editdist returned an error: {result}")
Expand Down
27 changes: 27 additions & 0 deletions gensim/test/test_similarities.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from gensim.similarities import SparseTermSimilarityMatrix
from gensim.similarities import LevenshteinSimilarityIndex
from gensim.similarities.docsim import _nlargest
from gensim.similarities.fastss import editdist

try:
from pyemd import emd # noqa:F401
Expand Down Expand Up @@ -1631,6 +1632,32 @@ def test_most_similar(self):
self.assertTrue(numpy.allclose(first_similarities**2.0, second_similarities))


class TestFastSS(unittest.TestCase):
    """Exercise `editdist` on string pairs with equal and with differing Unicode kinds.

    Every pair below differs in exactly two character positions, so the
    expected edit distance is always 2; only the characters used (ASCII,
    accented, emoji) vary between the test cases.
    """

    def test_editdist_same_unicode_kind_latin1(self):
        """Test editdist returns the expected result with two Latin-1 strings."""
        expected = 2  # substitutions: 'Z' -> 's' and 'z' -> 's'
        actual = editdist('Zizka', 'siska')
        # assertEqual (unlike a bare assert) still runs under `python -O`
        # and reports both values on failure.
        self.assertEqual(expected, actual)

    def test_editdist_same_unicode_kind_ucs2(self):
        """Test editdist returns the expected result with two UCS-2 strings."""
        expected = 2  # substitutions: 'Ž' -> 'š' and 'ž' -> 'š'
        actual = editdist('Žižka', 'šiška')
        self.assertEqual(expected, actual)

    def test_editdist_same_unicode_kind_ucs4(self):
        """Test editdist returns the expected result with two UCS-4 strings."""
        # The trailing emoji (U+1F600) lies outside the Basic Multilingual Plane;
        # it is identical in both strings and so adds nothing to the distance.
        expected = 2  # substitutions: 'Ž' -> 'š' and 'ž' -> 'š'
        actual = editdist('Žižka 😀', 'šiška 😀')
        self.assertEqual(expected, actual)

    def test_editdist_different_unicode_kinds(self):
        """Test editdist returns the expected result with strings of different Unicode kinds."""
        # One argument contains non-Latin-1 characters while the other is pure ASCII.
        expected = 2  # substitutions: 'Ž' -> 's' and 'ž' -> 's'
        actual = editdist('Žižka', 'siska')
        self.assertEqual(expected, actual)


# Script entry point: enable DEBUG-level logging, then run the unittest test runner.
if __name__ == '__main__':
    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
    unittest.main()