
Commit 7c6afb2

Update fasttext wrapper
jbaiter committed Feb 20, 2018
1 parent 779180e commit 7c6afb2
Showing 2 changed files with 28 additions and 30 deletions.
6 changes: 6 additions & 0 deletions gensim/models/wrappers/fasttext.py
@@ -278,6 +278,12 @@ def load(cls, *args, **kwargs):
         if hasattr(model.wv, 'syn0_all'):
             setattr(model.wv, 'syn0_ngrams', model.wv.syn0_all)
             delattr(model.wv, 'syn0_all')
+            setattr(model.wv, 'bucket', model.wv.syn0_ngrams.shape[0])
+        if not hasattr(model.wv, 'hash2index') and hasattr(model.wv, 'ngrams'):
+            model.wv.hash2index = {}
+            for i, ngram in enumerate(model.wv.ngrams):
+                ngram_hash = ft_hash(ngram) % model.wv.bucket
+                model.wv.hash2index[ngram_hash] = i
         return model
 
     @classmethod
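
Note: the shim above rebuilds the ngram lookup table the same way fastText assigns ngrams to buckets: hash the ngram, then take the hash modulo the bucket count. A minimal standalone sketch of that mapping (fnv1a_hash here is a stand-in for gensim's ft_hash, which implements the FNV-1a variant hash used by fastText; build_hash2index mirrors the loop in load() above):

def fnv1a_hash(ngram):
    # FNV-1a over the ngram's characters, kept in uint32 range like fastText.
    h = 2166136261
    for ch in ngram:
        h ^= ord(ch)
        h = (h * 16777619) & 0xFFFFFFFF
    return h

def build_hash2index(ngrams, bucket):
    # Each ngram hashes into one of `bucket` slots, and the slot maps to the
    # ngram's row index; later ngrams win on collisions, matching the dict
    # assignment order in load() above.
    return {fnv1a_hash(ng) % bucket: i for i, ng in enumerate(ngrams)}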
52 changes: 22 additions & 30 deletions gensim/test/test_fasttext_wrapper.py
@@ -143,16 +143,16 @@ def testLoadFastTextFormat(self):
         # vector for oov words are slightly different from original FastText due to discarding unused ngrams
         # obtained using a modified version of ./fasttext print-word-vectors lee_fasttext_new.bin
         expected_vec_oov = [
-            -0.23825,
-            -0.58482,
-            -0.22276,
-            -0.41215,
-            0.91015,
-            -1.6786,
-            -0.26724,
-            0.58818,
-            0.57828,
-            0.75801
+            -0.21929,
+            -0.53778,
+            -0.22463,
+            -0.41735,
+            0.71737,
+            -1.59758,
+            -0.24833,
+            0.62028,
+            0.53203,
+            0.77568
         ]
         self.assertTrue(numpy.allclose(model["rejection"], expected_vec_oov, atol=1e-4))

@@ -194,16 +194,16 @@ def testLoadFastTextNewFormat(self):
         # vector for oov words are slightly different from original FastText due to discarding unused ngrams
         # obtained using a modified version of ./fasttext print-word-vectors lee_fasttext_new.bin
         expected_vec_oov = [
-            -0.53378,
-            -0.19,
-            0.013482,
-            -0.86767,
-            -0.21684,
-            -0.89928,
-            0.45124,
-            0.18025,
-            -0.14128,
-            0.22508
+            -0.49111,
+            -0.13122,
+            -0.02109,
+            -0.88769,
+            -0.20105,
+            -0.91732,
+            0.47243,
+            0.19708,
+            -0.17856,
+            0.19815
         ]
         self.assertTrue(numpy.allclose(new_model["rejection"], expected_vec_oov, atol=1e-4))
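
Note: the expected OOV vectors change because lookup for an out-of-vocabulary word now goes through the rebuilt hash2index, and ngrams whose rows were discarded contribute nothing. A minimal sketch of that OOV path, assuming the vector is the mean of the stored ngram rows (the '<'/'>' boundary markers and the 3-6 ngram range follow fastText's defaults; fnv1a_hash is the stand-in defined earlier):

import numpy as np

def oov_vector(word, wv, min_n=3, max_n=6):
    # Character ngrams of the word wrapped in fastText's boundary markers.
    padded = '<' + word + '>'
    ngrams = [padded[i:i + n]
              for n in range(min_n, max_n + 1)
              for i in range(len(padded) - n + 1)]
    # Only ngrams whose hashed bucket survived trimming contribute a row.
    rows = [wv.syn0_ngrams[wv.hash2index[fnv1a_hash(ng) % wv.bucket]]
            for ng in ngrams if fnv1a_hash(ng) % wv.bucket in wv.hash2index]
    if not rows:
        raise KeyError('all ngrams for word absent from model')
    return np.mean(rows, axis=0)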

Expand Down Expand Up @@ -296,8 +296,6 @@ def testLookup(self):
# Out of vocab check
self.assertFalse('nights' in self.test_model.wv.vocab)
self.assertTrue(numpy.allclose(self.test_model['nights'], self.test_model[['nights']]))
# Word with no ngrams in model
self.assertRaises(KeyError, lambda: self.test_model['a!@'])

def testContains(self):
"""Tests __contains__ for in-vocab and out-of-vocab words"""
@@ -307,20 +305,14 @@ def testContains(self):
         # Out of vocab check
         self.assertFalse('nights' in self.test_model.wv.vocab)
         self.assertTrue('nights' in self.test_model)
-        # Word with no ngrams in model
-        self.assertFalse('a!@' in self.test_model.wv.vocab)
-        self.assertFalse('a!@' in self.test_model)
 
     def testWmdistance(self):
         """Tests wmdistance for docs with in-vocab and out-of-vocab words"""
         doc = ['night', 'payment']
         oov_doc = ['nights', 'forests', 'payments']
-        ngrams_absent_doc = ['a!@', 'b#$']
 
         dist = self.test_model.wmdistance(doc, oov_doc)
         self.assertNotEqual(float('inf'), dist)
-        dist = self.test_model.wmdistance(doc, ngrams_absent_doc)
-        self.assertEqual(float('inf'), dist)
 
     def testDoesntMatch(self):
         """Tests doesnt_match for list of out-of-vocab words"""
@@ -363,8 +355,8 @@ def testPersistenceForOldVersions(self):
                                        1.19031894, 2.01627707, 1.98942184, -1.39095843, -0.65036952])
         self.assertTrue(numpy.allclose(loaded_model["the"], in_expected_vec, atol=1e-4))
         # out-of-vocab word
-        out_expected_vec = numpy.array([-1.34948218, -0.8686831, -1.51483142, -1.0164026, 0.56272298,
-                                        0.66228276, 1.06477463, 1.1355902, -0.80972326, -0.39845538])
+        out_expected_vec = numpy.array([-0.33959097, -0.21121596, -0.37212455, -0.25057459, 0.11222091,
+                                        0.17517674, 0.26949012, 0.29352987, -0.1930912, -0.09438948])
         self.assertTrue(numpy.allclose(loaded_model["random_word"], out_expected_vec, atol=1e-4))
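
Note: a hypothetical way to exercise the load() shim this test covers, loading a model pickled by an older gensim (the file name below is a placeholder, not the fixture the test actually uses):

from gensim.models.wrappers import FastText

loaded_model = FastText.load('old_gensim_fasttext.model')  # placeholder path
print(loaded_model.wv.syn0_ngrams.shape)  # renamed from syn0_all on load
print(len(loaded_model.wv.hash2index))    # mapping rebuilt by the shim
print(loaded_model['random_word'][:3])    # OOV lookup goes through hash2index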

