From ea1f3cfb44e47f122c63b0cbbd5c8c68f446ce64 Mon Sep 17 00:00:00 2001
From: Shiva Manne
Date: Tue, 5 Dec 2017 07:14:30 +0530
Subject: [PATCH] Fixes dtype of `model.wv.syn0_vocab` on updating `vocab` for `FastText` model. Fix #1759 (#1760)

* extends `test_online` to check dtype

* casts numpy array returned by `random.uniform` to `float32`
---
 gensim/models/fasttext.py    | 8 ++++----
 gensim/test/test_fasttext.py | 1 +
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/gensim/models/fasttext.py b/gensim/models/fasttext.py
index 6174754314..473f8dfa42 100644
--- a/gensim/models/fasttext.py
+++ b/gensim/models/fasttext.py
@@ -173,12 +173,12 @@ def init_ngrams(self, update=False):
             new_vocab_rows = rand_obj.uniform(
                 -1.0 / self.vector_size, 1.0 / self.vector_size,
                 (len(self.wv.vocab) - self.old_vocab_len, self.vector_size)
-            )
+            ).astype(REAL)
             new_vocab_lockf_rows = ones((len(self.wv.vocab) - self.old_vocab_len, self.vector_size), dtype=REAL)
             new_ngram_rows = rand_obj.uniform(
                 -1.0 / self.vector_size, 1.0 / self.vector_size,
                 (len(self.wv.hash2index) - self.old_hash2index_len, self.vector_size)
-            )
+            ).astype(REAL)
             new_ngram_lockf_rows = ones(
                 (len(self.wv.hash2index) - self.old_hash2index_len, self.vector_size), dtype=REAL
             )
@@ -194,11 +194,11 @@ def reset_ngram_weights(self):
         for index in range(len(self.wv.vocab)):
             self.wv.syn0_vocab[index] = rand_obj.uniform(
                 -1.0 / self.vector_size, 1.0 / self.vector_size, self.vector_size
-            )
+            ).astype(REAL)
         for index in range(len(self.wv.hash2index)):
             self.wv.syn0_ngrams[index] = rand_obj.uniform(
                 -1.0 / self.vector_size, 1.0 / self.vector_size, self.vector_size
-            )
+            ).astype(REAL)
 
     def _do_train_job(self, sentences, alpha, inits):
         work, neu1 = inits
diff --git a/gensim/test/test_fasttext.py b/gensim/test/test_fasttext.py
index 69aa9d074a..c7de6206d6 100644
--- a/gensim/test/test_fasttext.py
+++ b/gensim/test/test_fasttext.py
@@ -422,6 +422,7 @@ def online_sanity(self, model):
         self.assertFalse('terrorism' in model.wv.vocab)
         self.assertFalse('orism>' in model.wv.ngrams)
         model.build_vocab(terro, update=True)  # update vocab
+        self.assertTrue(model.wv.syn0_ngrams.dtype == 'float32')
         self.assertTrue('terrorism' in model.wv.vocab)
         self.assertTrue('orism>' in model.wv.ngrams)
         orig0_all = np.copy(model.wv.syn0_ngrams)