Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use FastSS for fast kNN over Levenshtein distance #3146

Merged
merged 27 commits into from
May 20, 2021
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
af5833d
Use VP-Tree for fast kNN over Levenshtein distance
Witiko May 15, 2021
fb98a43
Use DAWG for fast approximate kNN over Levenshtein distance
Witiko May 15, 2021
80ec65f
Remove itertools import
Witiko May 16, 2021
40b96bc
Improve unit tests for the levenshtein module
Witiko May 16, 2021
48cb664
Implement a back-off strategy for max_distance
Witiko May 16, 2021
ba36d01
Apply suggestions from code review
Witiko May 16, 2021
567b0a4
Replace DAWG with TinyFastSS
Witiko May 16, 2021
b675507
clean up FastSS
piskvorky May 16, 2021
6c2d033
Return only similarities greater than zero
Witiko May 16, 2021
80b99d0
Merge branch 'levenshtein-ball-tree' of github.com:Witiko/gensim into…
Witiko May 16, 2021
362f458
reintroduce max_dist to FastSS query
piskvorky May 16, 2021
22a0221
Silence flake8 about fastss
Witiko May 16, 2021
6350a22
Merge branch 'levenshtein-ball-tree' of github.com:Witiko/gensim into…
Witiko May 16, 2021
e2e1d9f
Suggest max_distance <= 2
Witiko May 16, 2021
80cdb7e
Suggest max_distance < 3
Witiko May 16, 2021
c91bda5
Eagerly filter out zero similarities
Witiko May 16, 2021
18fe2a2
clarify + add comments
piskvorky May 16, 2021
9e614c0
Merge branch 'levenshtein-ball-tree' of github.com:Witiko/gensim into…
piskvorky May 16, 2021
da501b5
update docs
piskvorky May 16, 2021
de2ec13
minor doc fixes
piskvorky May 16, 2021
6c4abc5
clean up FastSS & logging
piskvorky May 17, 2021
05284d1
cythonize FastSS
piskvorky May 18, 2021
9381965
remove dead code
piskvorky May 18, 2021
7054f90
update Cython to 0.29.23
piskvorky May 18, 2021
ae91204
make max_distance=2 the default in LevenshteinSimilarityIndex
piskvorky May 18, 2021
86e8a25
FastSS cleanup
piskvorky May 19, 2021
7655d75
rewrite the editdist function (Levenshtein) in C
piskvorky May 19, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 1 addition & 13 deletions gensim/similarities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,7 @@
"""

# bring classes directly into package namespace, to save some typing
import warnings
try:
import Levenshtein # noqa:F401
except ImportError:
msg = (
"The gensim.similarities.levenshtein submodule is disabled, because the optional "
"Levenshtein package <https://pypi.org/project/python-Levenshtein/> is unavailable. "
"Install Levenhstein (e.g. `pip install python-Levenshtein`) to suppress this warning."
)
warnings.warn(msg)
LevenshteinSimilarityIndex = None
else:
from .levenshtein import LevenshteinSimilarityIndex # noqa:F401
from .levenshtein import LevenshteinSimilarityIndex # noqa:F401
from .docsim import ( # noqa:F401
Similarity,
MatrixSimilarity,
Expand Down
Empty file modified gensim/similarities/docsim.py
100755 → 100644
Empty file.
159 changes: 159 additions & 0 deletions gensim/similarities/fastss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2021 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
# Code adapted from TinyFastSS (public domain), https://github.com/fujimotos/TinyFastSS

"""Create and query FastSS index for fast approximate string similarity search."""

import struct
import itertools

ENCODING = 'utf-8'
DELIMITER = b'\x00'


def editdist(s1, s2):
    """Return the Levenshtein distance between two strings.

    The distance is the minimum number of single-character insertions,
    deletions and substitutions needed to turn `s1` into `s2`.

    Only two rows of the dynamic-programming table are kept at any time, so
    memory use is proportional to the shorter of the two inputs.

    >>> editdist('aiu', 'aie')
    1
    """
    # The distance is symmetric; keep the shorter string on the column axis
    # so that the rolling rows are as small as possible.
    if len(s1) < len(s2):
        s1, s2 = s2, s1
    if not s2:
        # Turning a string into the empty string costs one deletion per character.
        return len(s1)

    # previous[j] == distance between the processed prefix of s1 and s2[:j].
    previous = list(range(len(s2) + 1))
    for i, char1 in enumerate(s1, start=1):
        current = [i]
        for j, char2 in enumerate(s2, start=1):
            if char1 == char2:
                # Matching characters cost nothing: carry the diagonal over.
                current.append(previous[j - 1])
            else:
                current.append(1 + min(
                    previous[j],       # deletion from s1
                    current[j - 1],    # insertion into s1
                    previous[j - 1],   # substitution
                ))
        previous = current

    return previous[-1]


def indexkeys(word, max_dist):
    """Return the set of index keys ("variants") of a word.

    A variant is `word` with up to `max_dist` characters removed (but never
    more characters than the word has), keeping the remaining characters in
    their original order. The unmodified word is included as well.

    >>> sorted(indexkeys('aiu', 1))
    ['ai', 'aiu', 'au', 'iu']
    """
    n_chars = len(word)
    # Cannot delete more characters than the word contains.
    max_deletions = min(max_dist, n_chars)
    return {
        ''.join(kept)
        for n_deleted in range(max_deletions + 1)
        for kept in itertools.combinations(word, n_chars - n_deleted)
    }


def int2byte(i):
    r"""Encode an integer in the range 0..255 as a single unsigned byte.

    Values outside that range are rejected by `struct.pack`.

    >>> int2byte(1)
    b'\x01'
    """
    packed = struct.pack('B', i)
    return packed


def byte2int(b):
    r"""Decode a single unsigned byte into its integer value.

    >>> byte2int(b'\x01')
    1
    """
    # struct.unpack always returns a tuple; 'B' yields exactly one element.
    (value,) = struct.unpack('B', b)
    return value


def set2bytes(s):
    r"""Serialize a set of unicode strings into bytes.

    The strings are sorted, encoded with `ENCODING` and joined with
    `DELIMITER`.

    >>> set2bytes({u'a', u'b', u'c'})
    b'a\x00b\x00c'
    """
    encoded_words = (text.encode(ENCODING) for text in sorted(s))
    return DELIMITER.join(encoded_words)


def bytes2set(b):
    r"""Deserialize bytes into a set of unicode strings.

    The input is split on `DELIMITER` and every piece is decoded with
    `ENCODING`. Empty input yields an empty set.

    >>> sorted(bytes2set(b'a\x00b\x00c'))
    ['a', 'b', 'c']
    """
    if b:
        return {piece.decode(ENCODING) for piece in b.split(DELIMITER)}
    return set()


class FastSS:
    """In-memory FastSS index for approximate string lookup.

    Every indexed word is stored under each of its deletion variants (see
    `indexkeys`), so that words within a bounded edit distance of a query
    share at least one key with it.
    """

    def __init__(self, max_dist=2):
        """max_dist: the upper threshold of edit distance of words from the index."""
        self.max_dist = max_dist
        # Maps an encoded deletion variant to the serialized set of words that produce it.
        self.db = {}

    def __str__(self):
        description = (self.__class__.__name__, self.max_dist, len(self.db))
        return "%s<max_dist=%s, db_size=%i>" % description

    def __contains__(self, word):
        """Return True if `word` itself was added to the index."""
        bkey = word.encode(ENCODING)
        # The word is one of its own variants, so it is stored under its own key.
        return bkey in self.db and word in bytes2set(self.db[bkey])

    def add(self, word):
        """Add a string to the index."""
        for variant in indexkeys(word, self.max_dist):
            bkey = variant.encode(ENCODING)

            if bkey in self.db:
                # Merge with the words already registered under this variant.
                words = bytes2set(self.db[bkey]) | {word}
            else:
                words = {word}

            self.db[bkey] = set2bytes(words)

    def query(self, word, max_dist=None):
        """Find all words from the index that are within max_dist of `word`.

        When `max_dist` is None, the value given to the constructor is used.
        Raises ValueError if `max_dist` exceeds the constructor's value.

        Returns a dict that maps each distance in 0..max_dist to the list of
        indexed words at exactly that edit distance from `word`.
        """
        if max_dist is None:
            max_dist = self.max_dist
        elif max_dist > self.max_dist:
            raise ValueError(
                f"query max_dist={max_dist} cannot be greater than "
                f"max_dist={self.max_dist} specified in the constructor"
            )

        # Collect every indexed word that shares a deletion variant with the query.
        candidates = set()
        for variant in indexkeys(word, max_dist):
            bkey = variant.encode(ENCODING)
            if bkey in self.db:
                candidates.update(bytes2set(self.db[bkey]))

        # Sharing a variant does not guarantee closeness: verify the real distance.
        by_distance = {distance: [] for distance in range(max_dist + 1)}
        for candidate in candidates:
            distance = editdist(word, candidate)
            if distance <= max_dist:
                by_distance[distance].append(candidate)

        return by_distance
Loading