diff --git a/gensim/models/wrappers/fasttext.py b/gensim/models/wrappers/fasttext.py
index 926e994eaf..fbf301d78d 100644
--- a/gensim/models/wrappers/fasttext.py
+++ b/gensim/models/wrappers/fasttext.py
@@ -81,13 +81,13 @@ def word_vec(self, word, use_norm=False):
         else:
             word_vec = np.zeros(self.syn0_all.shape[1])
             ngrams = FastText.compute_ngrams(word, self.min_n, self.max_n)
+            ngrams = [ng for ng in ngrams if ng in self.ngrams]
             if use_norm:
                 ngram_weights = self.syn0_all_norm
             else:
                 ngram_weights = self.syn0_all
             for ngram in ngrams:
-                if ngram in self.ngrams:
-                    word_vec += ngram_weights[self.ngrams[ngram]]
+                word_vec += ngram_weights[self.ngrams[ngram]]
             if word_vec.any():
                 return word_vec / len(ngrams)
             else:  # No ngrams of the word are present in self.ngrams
@@ -344,7 +344,7 @@ def init_ngrams(self):
         ngram_indices = []
         for i, ngram in enumerate(all_ngrams):
             ngram_hash = self.ft_hash(ngram)
-            ngram_indices.append((len(self.wv.vocab) + ngram_hash) % self.bucket)
+            ngram_indices.append(len(self.wv.vocab) + ngram_hash % self.bucket)
             self.wv.ngrams[ngram] = i
         self.wv.syn0_all = self.wv.syn0_all.take(ngram_indices, axis=0)

@@ -353,10 +353,10 @@ def compute_ngrams(word, min_n, max_n):
         ngram_indices = []
         BOW, EOW = ('<', '>')  # Used by FastText to attach to all words as prefix and suffix
         extended_word = BOW + word + EOW
-        ngrams = set()
-        for i in range(len(extended_word) - min_n + 1):
-            for j in range(min_n, max(len(extended_word) - max_n, max_n + 1)):
-                ngrams.add(extended_word[i:i+j])
+        ngrams = []
+        for ngram_length in range(min_n, min(len(extended_word), max_n) + 1):
+            for i in range(0, len(extended_word) - ngram_length + 1):
+                ngrams.append(extended_word[i:i + ngram_length])
         return ngrams

     @staticmethod
diff --git a/gensim/test/test_data/lee_fasttext b/gensim/test/test_data/lee_fasttext
new file mode 100644
index 0000000000..355fcec858
Binary files /dev/null and b/gensim/test/test_data/lee_fasttext differ
diff --git a/gensim/test/test_fasttext_wrapper.py b/gensim/test/test_fasttext_wrapper.py
index b0de495ad4..b21d3d5a1a 100644
--- a/gensim/test/test_fasttext_wrapper.py
+++ b/gensim/test/test_fasttext_wrapper.py
@@ -29,7 +29,6 @@ def testfile():


 class TestFastText(unittest.TestCase):
-    @classmethod
     def setUp(self):
         ft_home = os.environ.get('FT_HOME', None)
         self.ft_path = os.path.join(ft_home, 'fasttext') if ft_home else None
@@ -37,8 +36,7 @@ def setUp(self):
         self.test_model_file = datapath('lee_fasttext')
         self.test_new_model_file = datapath('lee_fasttext_new')
         # Load pre-trained model to perform tests in case FastText binary isn't available in test environment
-        self.test_model = fasttext.FastText.load_fasttext_format(self.test_model_file)
-        self.test_new_model = fasttext.FastText.load_fasttext_format(self.test_new_model_file)
+        self.test_model = fasttext.FastText.load(self.test_model_file)

     def model_sanity(self, model):
         """Even tiny models trained on any corpus should pass these sanity checks"""
@@ -74,7 +72,10 @@ def testMinCount(self):
         if self.ft_path is None:
             logger.info("FT_HOME env variable not set, skipping test")
             return  # Use self.skipTest once python < 2.7 is no longer supported
-        self.assertTrue('forests' not in self.test_model.wv.vocab)
+        test_model_min_count_5 = fasttext.FastText.train(
+            self.ft_path, self.corpus_file, output_file=testfile(), size=10, min_count=5)
+        self.assertTrue('forests' not in test_model_min_count_5.wv.vocab)
+
         test_model_min_count_1 = fasttext.FastText.train(
             self.ft_path, self.corpus_file, output_file=testfile(), size=10, min_count=1)
         self.assertTrue('forests' in test_model_min_count_1.wv.vocab)
@@ -115,60 +116,104 @@ def testNormalizedVectorsNotSaved(self):

     def testLoadFastTextFormat(self):
         """Test model successfully loaded from fastText .vec and .bin files"""
-        model = fasttext.FastText.load_fasttext_format(self.test_model_file)
+        try:
+            model = fasttext.FastText.load_fasttext_format(self.test_model_file)
+        except Exception as exc:
+            self.fail('Unable to load FastText model from file %s: %s' % (self.test_model_file, exc))
         vocab_size, model_size = 1762, 10
-        self.assertEqual(self.test_model.wv.syn0.shape, (vocab_size, model_size))
-        self.assertEqual(len(self.test_model.wv.vocab), vocab_size, model_size)
-        self.assertEqual(self.test_model.wv.syn0_all.shape, (self.test_model.num_ngram_vectors, model_size))
-        expected_vec = [-0.5714373588562012,
-                        -0.008556111715734005,
-                        0.15747803449630737,
-                        -0.6785456538200378,
-                        -0.25458523631095886,
-                        -0.5807671546936035,
-                        -0.09912964701652527,
-                        1.1446694135665894,
-                        0.23417705297470093,
-                        0.06000664085149765]
-        self.assertTrue(numpy.allclose(self.test_model["hundred"], expected_vec, 0.001))
-        self.assertEquals(self.test_model.min_count, 5)
-        self.assertEquals(self.test_model.window, 5)
-        self.assertEquals(self.test_model.iter, 5)
-        self.assertEquals(self.test_model.negative, 5)
-        self.assertEquals(self.test_model.sample, 0.0001)
-        self.assertEquals(self.test_model.bucket, 1000)
-        self.assertEquals(self.test_model.wv.max_n, 6)
-        self.assertEquals(self.test_model.wv.min_n, 3)
+        self.assertEqual(model.wv.syn0.shape, (vocab_size, model_size))
+        self.assertEqual(len(model.wv.vocab), vocab_size)
+        self.assertEqual(model.wv.syn0_all.shape, (model.num_ngram_vectors, model_size))
+
+        expected_vec = [
+            -0.57144,
+            -0.0085561,
+            0.15748,
+            -0.67855,
+            -0.25459,
+            -0.58077,
+            -0.09913,
+            1.1447,
+            0.23418,
+            0.060007
+        ]  # obtained using ./fasttext print-word-vectors lee_fasttext.bin
+        self.assertTrue(numpy.allclose(model["hundred"], expected_vec, atol=1e-4))
+
+        # vectors for OOV words are slightly different from the original FastText due to discarding unused ngrams
+        # obtained using a modified version of ./fasttext print-word-vectors lee_fasttext.bin
+        expected_vec_oov = [
+            -0.23825,
+            -0.58482,
+            -0.22276,
+            -0.41215,
+            0.91015,
+            -1.6786,
+            -0.26724,
+            0.58818,
+            0.57828,
+            0.75801
+        ]
+        self.assertTrue(numpy.allclose(model["rejection"], expected_vec_oov, atol=1e-4))
+
+        self.assertEquals(model.min_count, 5)
+        self.assertEquals(model.window, 5)
+        self.assertEquals(model.iter, 5)
+        self.assertEquals(model.negative, 5)
+        self.assertEquals(model.sample, 0.0001)
+        self.assertEquals(model.bucket, 1000)
+        self.assertEquals(model.wv.max_n, 6)
+        self.assertEquals(model.wv.min_n, 3)
         self.model_sanity(model)

     def testLoadFastTextNewFormat(self):
         """ Test model successfully loaded from fastText (new format) .vec and .bin files """
-        new_model = fasttext.FastText.load_fasttext_format(self.test_new_model_file)
+        try:
+            new_model = fasttext.FastText.load_fasttext_format(self.test_new_model_file)
+        except Exception as exc:
+            self.fail('Unable to load FastText model from file %s: %s' % (self.test_new_model_file, exc))
         vocab_size, model_size = 1763, 10
-        self.assertEqual(self.test_new_model.wv.syn0.shape, (vocab_size, model_size))
-        self.assertEqual(len(self.test_new_model.wv.vocab), vocab_size, model_size)
-        self.assertEqual(self.test_new_model.wv.syn0_all.shape, (self.test_new_model.num_ngram_vectors, model_size))
-
-        expected_vec_new = [-0.025627,
-                            -0.11448,
-                            0.18116,
-                            -0.96779,
-                            0.2532,
-                            -0.93224,
-                            0.3929,
-                            0.12679,
-                            -0.19685,
-                            -0.13179]  # obtained using ./fasttext print-word-vectors lee_fasttext_new.bin < queries.txt
-
-        self.assertTrue(numpy.allclose(self.test_new_model["hundred"], expected_vec_new, 0.001))
-        self.assertEquals(self.test_new_model.min_count, 5)
-        self.assertEquals(self.test_new_model.window, 5)
-        self.assertEquals(self.test_new_model.iter, 5)
-        self.assertEquals(self.test_new_model.negative, 5)
-        self.assertEquals(self.test_new_model.sample, 0.0001)
-        self.assertEquals(self.test_new_model.bucket, 1000)
-        self.assertEquals(self.test_new_model.wv.max_n, 6)
-        self.assertEquals(self.test_new_model.wv.min_n, 3)
+        self.assertEqual(new_model.wv.syn0.shape, (vocab_size, model_size))
+        self.assertEqual(len(new_model.wv.vocab), vocab_size)
+        self.assertEqual(new_model.wv.syn0_all.shape, (new_model.num_ngram_vectors, model_size))
+
+        expected_vec = [
+            -0.025627,
+            -0.11448,
+            0.18116,
+            -0.96779,
+            0.2532,
+            -0.93224,
+            0.3929,
+            0.12679,
+            -0.19685,
+            -0.13179
+        ]  # obtained using ./fasttext print-word-vectors lee_fasttext_new.bin
+        self.assertTrue(numpy.allclose(new_model["hundred"], expected_vec, atol=1e-4))
+
+        # vectors for OOV words are slightly different from the original FastText due to discarding unused ngrams
+        # obtained using a modified version of ./fasttext print-word-vectors lee_fasttext_new.bin
+        expected_vec_oov = [
+            -0.53378,
+            -0.19,
+            0.013482,
+            -0.86767,
+            -0.21684,
+            -0.89928,
+            0.45124,
+            0.18025,
+            -0.14128,
+            0.22508
+        ]
+        self.assertTrue(numpy.allclose(new_model["rejection"], expected_vec_oov, atol=1e-4))
+
+        self.assertEquals(new_model.min_count, 5)
+        self.assertEquals(new_model.window, 5)
+        self.assertEquals(new_model.iter, 5)
+        self.assertEquals(new_model.negative, 5)
+        self.assertEquals(new_model.sample, 0.0001)
+        self.assertEquals(new_model.bucket, 1000)
+        self.assertEquals(new_model.wv.max_n, 6)
+        self.assertEquals(new_model.wv.min_n, 3)
         self.model_sanity(new_model)

     def testLoadModelWithNonAsciiVocab(self):
@@ -248,7 +293,7 @@ def testContains(self):
         self.assertTrue('night' in self.test_model)
         # Out of vocab check
         self.assertFalse('nights' in self.test_model.wv.vocab)
-        self.assertTrue('night' in self.test_model)
+        self.assertTrue('nights' in self.test_model)
         # Word with no ngrams in model
         self.assertFalse('a!@' in self.test_model.wv.vocab)
         self.assertFalse('a!@' in self.test_model)
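
Note on the `init_ngrams` hunk above: this is an operator-precedence fix, not a refactor. fastText stores the vector for an ngram at row `nwords + (hash % bucket)` of the combined embedding matrix, i.e. the hash is reduced modulo the bucket count first and the vocabulary offset is added afterwards; the old code computed `(nwords + hash) % bucket`, which wraps the whole sum and looks up the wrong rows. A standalone sketch with toy numbers (the hash value is invented for illustration and this is not gensim's API):

```python
# Vocab size and bucket count match the lee_fasttext test model;
# the hash value is made up for the example.
nwords, bucket = 1762, 1000
ngram_hash = 2353  # pretend ft_hash('<hun') returned this

old_index = (nwords + ngram_hash) % bucket  # 4115 % 1000 = 115 -> wrong row, inside the vocab range
new_index = nwords + ngram_hash % bucket    # 1762 + 353 = 2115 -> correct row, past the vocab vectors

print(old_index, new_index)  # 115 2115
```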
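The `word_vec` hunk hoists the `ng in self.ngrams` filter out of the summation loop, so `word_vec / len(ngrams)` now divides by the number of ngram vectors actually summed rather than by every ngram of the word; together with discarding unused ngrams at load time, this is why the tests expect OOV vectors that differ slightly from stock fastText. The rewritten `compute_ngrams` also switches from a set to a list, matching fastText's enumeration (duplicates kept, ngram lengths from `min_n` up to `max_n` capped by the bracketed word's length). A quick standalone sanity check, with the loop copied verbatim from the patch:

```python
def compute_ngrams(word, min_n, max_n):
    # Same logic as the patched static method: bracket the word with the
    # BOW/EOW markers '<' and '>', then emit every character ngram whose
    # length lies in [min_n, min(len(extended_word), max_n)].
    extended_word = '<' + word + '>'
    ngrams = []
    for ngram_length in range(min_n, min(len(extended_word), max_n) + 1):
        for i in range(0, len(extended_word) - ngram_length + 1):
            ngrams.append(extended_word[i:i + ngram_length])
    return ngrams

print(compute_ngrams('night', 3, 6))
# ['<ni', 'nig', 'igh', 'ght', 'ht>',
#  '<nig', 'nigh', 'ight', 'ght>',
#  '<nigh', 'night', 'ight>',
#  '<night', 'night>']
```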