diff --git a/gensim/models/wrappers/fasttext.py b/gensim/models/wrappers/fasttext.py index 752e14eccb..8678b55d5c 100644 --- a/gensim/models/wrappers/fasttext.py +++ b/gensim/models/wrappers/fasttext.py @@ -278,6 +278,12 @@ def load(cls, *args, **kwargs): if hasattr(model.wv, 'syn0_all'): setattr(model.wv, 'syn0_ngrams', model.wv.syn0_all) delattr(model.wv, 'syn0_all') + setattr(model.wv, 'bucket', model.wv.syn0_ngrams.shape[0]) + if not hasattr(model.wv, 'hash2index') and hasattr(model.wv, 'ngrams'): + model.wv.hash2index = {} + for i, ngram in enumerate(model.wv.ngrams): + ngram_hash = ft_hash(ngram) % model.wv.bucket + model.wv.hash2index[ngram_hash] = i return model @classmethod diff --git a/gensim/test/test_fasttext_wrapper.py b/gensim/test/test_fasttext_wrapper.py index bc995f8159..247d4a0ba7 100644 --- a/gensim/test/test_fasttext_wrapper.py +++ b/gensim/test/test_fasttext_wrapper.py @@ -143,16 +143,16 @@ def testLoadFastTextFormat(self): # vector for oov words are slightly different from original FastText due to discarding unused ngrams # obtained using a modified version of ./fasttext print-word-vectors lee_fasttext_new.bin expected_vec_oov = [ - -0.23825, - -0.58482, - -0.22276, - -0.41215, - 0.91015, - -1.6786, - -0.26724, - 0.58818, - 0.57828, - 0.75801 + -0.21929, + -0.53778, + -0.22463, + -0.41735, + 0.71737, + -1.59758, + -0.24833, + 0.62028, + 0.53203, + 0.77568 ] self.assertTrue(numpy.allclose(model["rejection"], expected_vec_oov, atol=1e-4)) @@ -194,16 +194,16 @@ def testLoadFastTextNewFormat(self): # vector for oov words are slightly different from original FastText due to discarding unused ngrams # obtained using a modified version of ./fasttext print-word-vectors lee_fasttext_new.bin expected_vec_oov = [ - -0.53378, - -0.19, - 0.013482, - -0.86767, - -0.21684, - -0.89928, - 0.45124, - 0.18025, - -0.14128, - 0.22508 + -0.49111, + -0.13122, + -0.02109, + -0.88769, + -0.20105, + -0.91732, + 0.47243, + 0.19708, + -0.17856, + 0.19815 ] self.assertTrue(numpy.allclose(new_model["rejection"], expected_vec_oov, atol=1e-4)) @@ -296,8 +296,6 @@ def testLookup(self): # Out of vocab check self.assertFalse('nights' in self.test_model.wv.vocab) self.assertTrue(numpy.allclose(self.test_model['nights'], self.test_model[['nights']])) - # Word with no ngrams in model - self.assertRaises(KeyError, lambda: self.test_model['a!@']) def testContains(self): """Tests __contains__ for in-vocab and out-of-vocab words""" @@ -307,20 +305,14 @@ def testContains(self): # Out of vocab check self.assertFalse('nights' in self.test_model.wv.vocab) self.assertTrue('nights' in self.test_model) - # Word with no ngrams in model - self.assertFalse('a!@' in self.test_model.wv.vocab) - self.assertFalse('a!@' in self.test_model) def testWmdistance(self): """Tests wmdistance for docs with in-vocab and out-of-vocab words""" doc = ['night', 'payment'] oov_doc = ['nights', 'forests', 'payments'] - ngrams_absent_doc = ['a!@', 'b#$'] dist = self.test_model.wmdistance(doc, oov_doc) self.assertNotEqual(float('inf'), dist) - dist = self.test_model.wmdistance(doc, ngrams_absent_doc) - self.assertEqual(float('inf'), dist) def testDoesntMatch(self): """Tests doesnt_match for list of out-of-vocab words""" @@ -363,8 +355,8 @@ def testPersistenceForOldVersions(self): 1.19031894, 2.01627707, 1.98942184, -1.39095843, -0.65036952]) self.assertTrue(numpy.allclose(loaded_model["the"], in_expected_vec, atol=1e-4)) # out-of-vocab word - out_expected_vec = numpy.array([-1.34948218, -0.8686831, -1.51483142, -1.0164026, 0.56272298, - 0.66228276, 1.06477463, 1.1355902, -0.80972326, -0.39845538]) + out_expected_vec = numpy.array([-0.33959097, -0.21121596, -0.37212455, -0.25057459, 0.11222091, + 0.17517674, 0.26949012, 0.29352987, -0.1930912, -0.09438948]) self.assertTrue(numpy.allclose(loaded_model["random_word"], out_expected_vec, atol=1e-4))