Skip to content

Commit

Permalink
Fixes dtype of model.wv.syn0_vocab on updating vocab for `FastText` model. Fix #1759 (#1760)
Browse files Browse the repository at this point in the history

* extends `test_online` to check dtype

* casts numpy array returned by `random.uniform` to `float32`
  • Loading branch information
manneshiva authored and menshikh-iv committed Dec 5, 2017
1 parent aad849a commit ea1f3cf
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
8 changes: 4 additions & 4 deletions gensim/models/fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,12 @@ def init_ngrams(self, update=False):
new_vocab_rows = rand_obj.uniform(
-1.0 / self.vector_size, 1.0 / self.vector_size,
(len(self.wv.vocab) - self.old_vocab_len, self.vector_size)
)
).astype(REAL)
new_vocab_lockf_rows = ones((len(self.wv.vocab) - self.old_vocab_len, self.vector_size), dtype=REAL)
new_ngram_rows = rand_obj.uniform(
-1.0 / self.vector_size, 1.0 / self.vector_size,
(len(self.wv.hash2index) - self.old_hash2index_len, self.vector_size)
)
).astype(REAL)
new_ngram_lockf_rows = ones(
(len(self.wv.hash2index) - self.old_hash2index_len, self.vector_size), dtype=REAL
)
Expand All @@ -194,11 +194,11 @@ def reset_ngram_weights(self):
for index in range(len(self.wv.vocab)):
self.wv.syn0_vocab[index] = rand_obj.uniform(
-1.0 / self.vector_size, 1.0 / self.vector_size, self.vector_size
)
).astype(REAL)
for index in range(len(self.wv.hash2index)):
self.wv.syn0_ngrams[index] = rand_obj.uniform(
-1.0 / self.vector_size, 1.0 / self.vector_size, self.vector_size
)
).astype(REAL)

def _do_train_job(self, sentences, alpha, inits):
work, neu1 = inits
Expand Down
1 change: 1 addition & 0 deletions gensim/test/test_fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ def online_sanity(self, model):
self.assertFalse('terrorism' in model.wv.vocab)
self.assertFalse('orism>' in model.wv.ngrams)
model.build_vocab(terro, update=True) # update vocab
self.assertTrue(model.wv.syn0_ngrams.dtype == 'float32')
self.assertTrue('terrorism' in model.wv.vocab)
self.assertTrue('orism>' in model.wv.ngrams)
orig0_all = np.copy(model.wv.syn0_ngrams)
Expand Down

0 comments on commit ea1f3cf

Please sign in to comment.