
Commit fe3dc53
Merge pull request #1409 from jayantj/ft_oov_fix
[MRG] FastText out-of-vocab fix
menshikh-iv authored Jun 13, 2017
2 parents 85d8ba1 + f11c68e commit fe3dc53
Showing 3 changed files with 105 additions and 60 deletions.
14 changes: 7 additions & 7 deletions gensim/models/wrappers/fasttext.py
@@ -81,13 +81,13 @@ def word_vec(self, word, use_norm=False):
         else:
             word_vec = np.zeros(self.syn0_all.shape[1])
             ngrams = FastText.compute_ngrams(word, self.min_n, self.max_n)
+            ngrams = [ng for ng in ngrams if ng in self.ngrams]
             if use_norm:
                 ngram_weights = self.syn0_all_norm
             else:
                 ngram_weights = self.syn0_all
             for ngram in ngrams:
-                if ngram in self.ngrams:
-                    word_vec += ngram_weights[self.ngrams[ngram]]
+                word_vec += ngram_weights[self.ngrams[ngram]]
             if word_vec.any():
                 return word_vec / len(ngrams)
             else:  # No ngrams of the word are present in self.ngrams
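For context: before this hunk, ngrams missing from self.ngrams were skipped in the sum but still counted in the len(ngrams) denominator, so vectors for out-of-vocabulary words came out deflated. Filtering the list first makes the result a true mean over the ngrams actually present. A minimal sketch with toy data (invented names and values, not the gensim API):

import numpy as np

ngram_vectors = {'<re': np.array([1.0, 1.0]), 'rej': np.array([3.0, 1.0])}
all_ngrams = ['<re', 'rej', 'zzz']  # 'zzz' is not in the model

# old behaviour: skip missing ngrams in the sum, but divide by the full count
old = sum(ngram_vectors[ng] for ng in all_ngrams if ng in ngram_vectors) / len(all_ngrams)

# new behaviour: filter first, then average over the ngrams actually present
present = [ng for ng in all_ngrams if ng in ngram_vectors]
new = sum(ngram_vectors[ng] for ng in present) / len(present)

print(old)  # [ 1.33333333  0.66666667] -- deflated
print(new)  # [ 2.  1.] -- true mean of the available ngram vectors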
@@ -344,7 +344,7 @@ def init_ngrams(self):
         ngram_indices = []
         for i, ngram in enumerate(all_ngrams):
             ngram_hash = self.ft_hash(ngram)
-            ngram_indices.append((len(self.wv.vocab) + ngram_hash) % self.bucket)
+            ngram_indices.append(len(self.wv.vocab) + ngram_hash % self.bucket)
             self.wv.ngrams[ngram] = i
         self.wv.syn0_all = self.wv.syn0_all.take(ngram_indices, axis=0)

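The init_ngrams change is an operator-precedence fix: in Python, % binds tighter than +, so the new expression computes len(vocab) + (hash % bucket), matching fastText's native layout where ngram rows sit after the vocabulary block, while the old parenthesization wrapped indices back into the vocabulary range. A toy illustration (numbers invented):

vocab_len, ngram_hash, bucket = 1762, 2166136261, 1000

old_index = (vocab_len + ngram_hash) % bucket  # 23   -- collides with vocabulary rows
new_index = vocab_len + ngram_hash % bucket    # 2023 -- correctly offset past the vocab block

print(old_index, new_index)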
@@ -353,10 +353,10 @@ def compute_ngrams(word, min_n, max_n):
         ngram_indices = []
         BOW, EOW = ('<', '>')  # Used by FastText to attach to all words as prefix and suffix
         extended_word = BOW + word + EOW
-        ngrams = set()
-        for i in range(len(extended_word) - min_n + 1):
-            for j in range(min_n, max(len(extended_word) - max_n, max_n + 1)):
-                ngrams.add(extended_word[i:i+j])
+        ngrams = []
+        for ngram_length in range(min_n, min(len(extended_word), max_n) + 1):
+            for i in range(0, len(extended_word) - ngram_length + 1):
+                ngrams.append(extended_word[i:i + ngram_length])
         return ngrams

     @staticmethod
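The rewritten compute_ngrams enumerates every character ngram of '<' + word + '>' with length between min_n and max_n, keeping duplicates and order, whereas the old set-based version could produce wrong-length slices and silently dropped repeats. A standalone copy of the new logic, runnable as a quick check (the printed list is a worked example for fastText's default min_n=3, max_n=6):

def compute_ngrams(word, min_n, max_n):
    extended_word = '<' + word + '>'  # BOW/EOW markers, as in fastText
    ngrams = []
    for ngram_length in range(min_n, min(len(extended_word), max_n) + 1):
        for i in range(0, len(extended_word) - ngram_length + 1):
            ngrams.append(extended_word[i:i + ngram_length])
    return ngrams

print(compute_ngrams('test', 3, 6))
# ['<te', 'tes', 'est', 'st>', '<tes', 'test', 'est>', '<test', 'test>', '<test>']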
Binary file added gensim/test/test_data/lee_fasttext
151 changes: 98 additions & 53 deletions gensim/test/test_fasttext_wrapper.py
@@ -29,16 +29,14 @@ def testfile():


 class TestFastText(unittest.TestCase):
-    @classmethod
     def setUp(self):
         ft_home = os.environ.get('FT_HOME', None)
         self.ft_path = os.path.join(ft_home, 'fasttext') if ft_home else None
         self.corpus_file = datapath('lee_background.cor')
         self.test_model_file = datapath('lee_fasttext')
         self.test_new_model_file = datapath('lee_fasttext_new')
         # Load pre-trained model to perform tests in case FastText binary isn't available in test environment
-        self.test_model = fasttext.FastText.load_fasttext_format(self.test_model_file)
-        self.test_new_model = fasttext.FastText.load_fasttext_format(self.test_new_model_file)
+        self.test_model = fasttext.FastText.load(self.test_model_file)

     def model_sanity(self, model):
         """Even tiny models trained on any corpus should pass these sanity checks"""
@@ -74,7 +72,10 @@ def testMinCount(self):
         if self.ft_path is None:
             logger.info("FT_HOME env variable not set, skipping test")
             return  # Use self.skipTest once python < 2.7 is no longer supported
-        self.assertTrue('forests' not in self.test_model.wv.vocab)
+        test_model_min_count_5 = fasttext.FastText.train(
+            self.ft_path, self.corpus_file, output_file=testfile(), size=10, min_count=5)
+        self.assertTrue('forests' not in test_model_min_count_5.wv.vocab)
+
         test_model_min_count_1 = fasttext.FastText.train(
             self.ft_path, self.corpus_file, output_file=testfile(), size=10, min_count=1)
         self.assertTrue('forests' in test_model_min_count_1.wv.vocab)
Expand Down Expand Up @@ -115,60 +116,104 @@ def testNormalizedVectorsNotSaved(self):

def testLoadFastTextFormat(self):
"""Test model successfully loaded from fastText .vec and .bin files"""
model = fasttext.FastText.load_fasttext_format(self.test_model_file)
try:
model = fasttext.FastText.load_fasttext_format(self.test_model_file)
except Exception as exc:
self.fail('Unable to load FastText model from file %s: %s' % (self.test_model_file, exc))
vocab_size, model_size = 1762, 10
self.assertEqual(self.test_model.wv.syn0.shape, (vocab_size, model_size))
self.assertEqual(len(self.test_model.wv.vocab), vocab_size, model_size)
self.assertEqual(self.test_model.wv.syn0_all.shape, (self.test_model.num_ngram_vectors, model_size))
expected_vec = [-0.5714373588562012,
-0.008556111715734005,
0.15747803449630737,
-0.6785456538200378,
-0.25458523631095886,
-0.5807671546936035,
-0.09912964701652527,
1.1446694135665894,
0.23417705297470093,
0.06000664085149765]
self.assertTrue(numpy.allclose(self.test_model["hundred"], expected_vec, 0.001))
self.assertEquals(self.test_model.min_count, 5)
self.assertEquals(self.test_model.window, 5)
self.assertEquals(self.test_model.iter, 5)
self.assertEquals(self.test_model.negative, 5)
self.assertEquals(self.test_model.sample, 0.0001)
self.assertEquals(self.test_model.bucket, 1000)
self.assertEquals(self.test_model.wv.max_n, 6)
self.assertEquals(self.test_model.wv.min_n, 3)
self.assertEqual(model.wv.syn0.shape, (vocab_size, model_size))
self.assertEqual(len(model.wv.vocab), vocab_size, model_size)
self.assertEqual(model.wv.syn0_all.shape, (model.num_ngram_vectors, model_size))

expected_vec = [
-0.57144,
-0.0085561,
0.15748,
-0.67855,
-0.25459,
-0.58077,
-0.09913,
1.1447,
0.23418,
0.060007
] # obtained using ./fasttext print-word-vectors lee_fasttext_new.bin
self.assertTrue(numpy.allclose(model["hundred"], expected_vec, atol=1e-4))

# vector for oov words are slightly different from original FastText due to discarding unused ngrams
# obtained using a modified version of ./fasttext print-word-vectors lee_fasttext_new.bin
expected_vec_oov = [
-0.23825,
-0.58482,
-0.22276,
-0.41215,
0.91015,
-1.6786,
-0.26724,
0.58818,
0.57828,
0.75801
]
self.assertTrue(numpy.allclose(model["rejection"], expected_vec_oov, atol=1e-4))

self.assertEquals(model.min_count, 5)
self.assertEquals(model.window, 5)
self.assertEquals(model.iter, 5)
self.assertEquals(model.negative, 5)
self.assertEquals(model.sample, 0.0001)
self.assertEquals(model.bucket, 1000)
self.assertEquals(model.wv.max_n, 6)
self.assertEquals(model.wv.min_n, 3)
self.model_sanity(model)

def testLoadFastTextNewFormat(self):
""" Test model successfully loaded from fastText (new format) .vec and .bin files """
new_model = fasttext.FastText.load_fasttext_format(self.test_new_model_file)
try:
new_model = fasttext.FastText.load_fasttext_format(self.test_new_model_file)
except Exception as exc:
self.fail('Unable to load FastText model from file %s: %s' % (self.test_new_model_file, exc))
vocab_size, model_size = 1763, 10
self.assertEqual(self.test_new_model.wv.syn0.shape, (vocab_size, model_size))
self.assertEqual(len(self.test_new_model.wv.vocab), vocab_size, model_size)
self.assertEqual(self.test_new_model.wv.syn0_all.shape, (self.test_new_model.num_ngram_vectors, model_size))

expected_vec_new = [-0.025627,
-0.11448,
0.18116,
-0.96779,
0.2532,
-0.93224,
0.3929,
0.12679,
-0.19685,
-0.13179] # obtained using ./fasttext print-word-vectors lee_fasttext_new.bin < queries.txt

self.assertTrue(numpy.allclose(self.test_new_model["hundred"], expected_vec_new, 0.001))
self.assertEquals(self.test_new_model.min_count, 5)
self.assertEquals(self.test_new_model.window, 5)
self.assertEquals(self.test_new_model.iter, 5)
self.assertEquals(self.test_new_model.negative, 5)
self.assertEquals(self.test_new_model.sample, 0.0001)
self.assertEquals(self.test_new_model.bucket, 1000)
self.assertEquals(self.test_new_model.wv.max_n, 6)
self.assertEquals(self.test_new_model.wv.min_n, 3)
self.assertEqual(new_model.wv.syn0.shape, (vocab_size, model_size))
self.assertEqual(len(new_model.wv.vocab), vocab_size, model_size)
self.assertEqual(new_model.wv.syn0_all.shape, (new_model.num_ngram_vectors, model_size))

expected_vec = [
-0.025627,
-0.11448,
0.18116,
-0.96779,
0.2532,
-0.93224,
0.3929,
0.12679,
-0.19685,
-0.13179
] # obtained using ./fasttext print-word-vectors lee_fasttext_new.bin
self.assertTrue(numpy.allclose(new_model["hundred"], expected_vec, atol=1e-4))

# vector for oov words are slightly different from original FastText due to discarding unused ngrams
# obtained using a modified version of ./fasttext print-word-vectors lee_fasttext_new.bin
expected_vec_oov = [
-0.53378,
-0.19,
0.013482,
-0.86767,
-0.21684,
-0.89928,
0.45124,
0.18025,
-0.14128,
0.22508
]
self.assertTrue(numpy.allclose(new_model["rejection"], expected_vec_oov, atol=1e-4))

self.assertEquals(new_model.min_count, 5)
self.assertEquals(new_model.window, 5)
self.assertEquals(new_model.iter, 5)
self.assertEquals(new_model.negative, 5)
self.assertEquals(new_model.sample, 0.0001)
self.assertEquals(new_model.bucket, 1000)
self.assertEquals(new_model.wv.max_n, 6)
self.assertEquals(new_model.wv.min_n, 3)
self.model_sanity(new_model)

def testLoadModelWithNonAsciiVocab(self):
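One subtlety in the rewritten assertions above: numpy.allclose's third positional argument is rtol, not atol, so the old calls set 0.001 as a relative tolerance. The new atol=1e-4 keyword makes the intended absolute comparison explicit, which matters for near-zero vector components. A short demonstration of the difference:

import numpy

print(numpy.allclose(0.0, 5e-5, 0.001))      # False: 0.001 is rtol; atol defaults to 1e-8
print(numpy.allclose(0.0, 5e-5, atol=1e-4))  # True: explicit absolute tolerance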
@@ -248,7 +293,7 @@ def testContains(self):
         self.assertTrue('night' in self.test_model)
         # Out of vocab check
         self.assertFalse('nights' in self.test_model.wv.vocab)
-        self.assertTrue('night' in self.test_model)
+        self.assertTrue('nights' in self.test_model)
         # Word with no ngrams in model
         self.assertFalse('a!@' in self.test_model.wv.vocab)
         self.assertFalse('a!@' in self.test_model)
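The corrected assertion pins down the membership semantics for out-of-vocabulary words: a word absent from the vocabulary is still "in" the model when vectors for its character ngrams are available. A sketch of querying these cases, assuming the lee_fasttext.bin test data from gensim/test/test_data is available in the working directory:

from gensim.models.wrappers import fasttext

model = fasttext.FastText.load_fasttext_format('lee_fasttext')
print('night' in model.wv.vocab)   # True  -- in-vocabulary word
print('nights' in model.wv.vocab)  # False -- not in the vocabulary ...
print('nights' in model)           # True  -- ... but its ngrams are in the model
print('a!@' in model)              # False -- none of its ngrams are known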