Skip to content

Commit

Permalink
rewrite the editdist function (Levenshtein) in C
Browse files Browse the repository at this point in the history
  • Loading branch information
piskvorky committed May 20, 2021
1 parent 86e8a25 commit 7655d75
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 53 deletions.
135 changes: 84 additions & 51 deletions gensim/similarities/fastss.pyx
Original file line number Diff line number Diff line change
@@ -1,67 +1,106 @@
#!/usr/bin/env cython
# cython: language_level=3
# cython: boundscheck=False
# cython: wraparound=False
# cython: cdivision=True
# cython: embedsignature=True
# coding: utf-8
#
# Copyright (C) 2021 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
# Code adapted from TinyFastSS (public domain), https://github.com/fujimotos/TinyFastSS

"""Create and query FastSS index for fast approximate string similarity search."""
"""Fast approximate string similarity search using the FastSS algorithm."""

import itertools

from cpython.ref cimport PyObject

DEF MAX_WORD_LENGTH = 254 # a trade-off between speed (fast stack allocations) and versatility (long strings)

DEF MAX_WORD_LENGTH = 10000 # Maximum allowed word length, in characters. Must fit in the C `int` range.

def editdist(s1: unicode, s2: unicode, max_dist=None):

cdef extern from *:
"""
#define WIDTH int
#define MAX_WORD_LENGTH 10000
/*
 * Levenshtein (edit) distance between two unicode objects, with early exit.
 *
 * s1, s2:  unicode objects; they are accessed only through the PyUnicode_* macros
 *          and are NOT type-checked here, so the caller must pass real `str` objects.
 * maximum: the largest distance the caller cares about.
 *
 * Returns:
 *   - the exact distance whenever that distance is <= maximum;
 *   - some value > maximum otherwise (`maximum + 1` on the early-exit paths);
 *   - -2 if the longer string has more than MAX_WORD_LENGTH code points.
 *
 * The two strings may use different internal storage widths (PyUnicode_KIND):
 * every code point is widened to Py_UCS4 before it is compared, so e.g. a pure-ASCII
 * string can be compared against one containing non-Latin-1 characters.
 * (Previously a kind mismatch was reported as error code -1; that code is no longer
 * produced.)
 */
int ceditdist(PyObject * s1, PyObject * s2, WIDTH maximum) {
    WIDTH row1[MAX_WORD_LENGTH + 1];
    WIDTH row2[MAX_WORD_LENGTH + 1];
    WIDTH * CYTHON_RESTRICT pos_new;
    WIDTH * CYTHON_RESTRICT pos_old;
    int row_flip = 1;  /* Does pos_new represent row1 or row2? */
    int kind1 = PyUnicode_KIND(s1);  /* How many bytes per unicode codepoint in s1? */
    int kind2 = PyUnicode_KIND(s2);  /* ... and in s2? The two may legitimately differ. */

    WIDTH len_s1 = (WIDTH)PyUnicode_GET_LENGTH(s1);
    WIDTH len_s2 = (WIDTH)PyUnicode_GET_LENGTH(s2);
    if (len_s1 > len_s2) {
        /* Make s1 the shorter string; its kind and length must travel with it. */
        PyObject * tmp = s1; s1 = s2; s2 = tmp;
        const WIDTH tmpi = len_s1; len_s1 = len_s2; len_s2 = tmpi;
        const int tmpk = kind1; kind1 = kind2; kind2 = tmpk;
    }

    /* The distance is at least the difference in lengths. */
    if (len_s2 - len_s1 > maximum) return maximum + 1;
    /* The DP rows live on the stack and hold at most MAX_WORD_LENGTH + 1 cells. */
    if (len_s2 > MAX_WORD_LENGTH) return -2;

    void * s1_data = PyUnicode_DATA(s1);
    void * s2_data = PyUnicode_DATA(s2);

    /* Row for the empty prefix of s2: turning s1[:k] into "" costs k deletions. */
    for (WIDTH tmpi = 0; tmpi <= len_s1; tmpi++) row2[tmpi] = tmpi;

    for (WIDTH i2 = 0; i2 < len_s2; i2++) {
        /* Seeded from the row's first cell (i2 + 1): is it already above maximum? */
        int all_bad = i2 >= maximum;
        const Py_UCS4 ch = PyUnicode_READ(kind2, s2_data, i2);

        /* Alternate which buffer is the row being written and which is the previous row. */
        row_flip = 1 - row_flip;
        if (row_flip) {
            pos_new = row2; pos_old = row1;
        } else {
            pos_new = row1; pos_old = row2;
        }

        *pos_new = i2 + 1;
        for (WIDTH i1 = 0; i1 < len_s1; i1++) {
            /* Diagonal cell: the cost carried over unchanged when the characters match. */
            WIDTH val = *(pos_old++);
            if (ch != PyUnicode_READ(kind1, s1_data, i1)) {
                const WIDTH _val1 = *pos_old;  /* cell above */
                const WIDTH _val2 = *pos_new;  /* cell to the left */
                if (_val1 < val) val = _val1;
                if (_val2 < val) val = _val2;
                val += 1;
            }
            *(++pos_new) = val;
            if (all_bad && val <= maximum) all_bad = 0;
        }

        /* Every cell of this row exceeds maximum, so the final distance must too. */
        if (all_bad) return maximum + 1;
    }

    return row_flip ? row2[len_s1] : row1[len_s1];
}
"""
If the Levenshtein distance between two strings is <= max_dist, return that distance.
Otherwise return max_dist+1.
int ceditdist(PyObject *s1, PyObject *s2, int maximum)


def editdist(s1: str, s2: str, max_dist=None):
"""
Return the Levenshtein distance between two strings.
Use `max_dist` to control the maximum distance you care about. If the actual distance is larger
than `max_dist`, editdist will return early, with the value `max_dist+1`.
This is a performance optimization – for example if anything above distance 2 is uninteresting
to your application, call editdist with `max_dist=2` and ignore any return value greater than 2.
Leave `max_dist=None` (default) to always return the full Levenshtein distance (slower).
"""
if s1 == s2:
return 0

if len(s1) > len(s2):
s1, s2 = s2, s1

if len(s2) > MAX_WORD_LENGTH:
result = ceditdist(<PyObject *>s1, <PyObject *>s2, MAX_WORD_LENGTH if max_dist is None else int(max_dist))
if result >= 0:
return result
elif result == -1:
raise ValueError("incompatible types of unicode strings")
elif result == -2:
raise ValueError(f"editdist doesn't support strings longer than {MAX_WORD_LENGTH} characters")

cdef unsigned char len_s1 = len(s1)
cdef unsigned char len_s2 = len(s2)
cdef unsigned char maximum = min(len_s2, max_dist or MAX_WORD_LENGTH)

if len_s2 - len_s1 > maximum:
return maximum + 1

cdef unsigned char all_bad, i1, i2, val
cdef unsigned char[MAX_WORD_LENGTH + 1] row1, row2
cdef unsigned char * row_new = &row1[0]
cdef unsigned char * row_old = &row2[0]
for i1 in range(len_s1 + 1):
row_old[i1] = i1

for i2 in range(len_s2):
row_new[0] = i2 + 1
all_bad = i2 >= maximum
for i1 in range(len_s1):
if s1[i1] == s2[i2]:
val = row_old[i1]
else:
val = 1 + min((row_old[i1], row_old[i1 + 1], row_new[i1]))
row_new[i1 + 1] = val
if all_bad and val <= maximum:
all_bad = 0
if all_bad:
return maximum + 1
row_new, row_old = row_old, row_new

return row_old[len_s1]
else:
raise ValueError(f"editdist returned an error: {result}")


def indexkeys(word, max_dist):
Expand All @@ -75,9 +114,7 @@ def indexkeys(word, max_dist):
limit = min(max_dist, wordlen) + 1

for dist in range(limit):
variants = itertools.combinations(word, wordlen - dist)

for variant in variants:
for variant in itertools.combinations(word, wordlen - dist):
res.add(''.join(variant))

return res
Expand All @@ -98,10 +135,7 @@ def bytes2set(b):
>>> bytes2set(b'a\x00b\x00c')
{u'a', u'b', u'c'}
"""
if not b:
return set()

return set(b.decode('utf8').split('\x00'))
return set(b.decode('utf8').split('\x00')) if b else set()


class FastSS:
Expand Down Expand Up @@ -147,8 +181,7 @@ class FastSS:
max_dist = self.max_dist
if max_dist > self.max_dist:
raise ValueError(
f"query max_dist={max_dist} cannot be greater than "
f"max_dist={self.max_dist} specified in the constructor"
f"query max_dist={max_dist} cannot be greater than max_dist={self.max_dist} from the constructor"
)

res = {d: [] for d in range(max_dist + 1)}
Expand Down
4 changes: 2 additions & 2 deletions gensim/similarities/levenshtein.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from gensim.similarities.termsim import TermSimilarityIndex
from gensim import utils
try:
from gensim.similarities.fastss import FastSS
from gensim.similarities.fastss import FastSS, editdist # noqa:F401
except ImportError:
raise utils.NO_CYTHON

Expand All @@ -29,7 +29,7 @@ class LevenshteinSimilarityIndex(TermSimilarityIndex):
"Levenshtein similarity" is a modification of the Levenshtein (edit) distance,
defined in [charletetal17]_.
This implementation uses a neighbourhood algorithm (FastSS)
This implementation uses the FastSS neighbourhood algorithm
for fast kNN nearest-neighbor retrieval.
Parameters
Expand Down

0 comments on commit 7655d75

Please sign in to comment.