Rebuild cumulative table on load. Fix #1180 (#1181)
* Rebuild cumulative table on load. Fix #1180
* Train on sentences instead of corpus
tmylk authored Mar 3, 2017
1 parent 4f0e2ae commit 63cf941
Showing 3 changed files with 22 additions and 4 deletions.
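The user-visible effect of the change is that a model saved by an older gensim can again be loaded and trained further with negative sampling, because load() now rebuilds the cumulative sampling table. A minimal sketch of that load-then-continue-training flow is shown below; the file name and the extra sentences are illustrative assumptions, not taken from this commit, while the build_vocab(update=True) and train(sentences) calls mirror the ones used in the tests further down.

    # Sketch only: resume training on a previously saved model (hypothetical file name).
    from gensim.models import word2vec

    model = word2vec.Word2Vec.load('old_model.w2v')   # load() rebuilds cum_table when negative sampling is on
    more_sentences = [['human', 'machine', 'interface'], ['graph', 'trees']]
    model.build_vocab(more_sentences, update=True)     # grow the vocabulary online
    model.train(more_sentences)                        # continue training on the new sentences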
gensim/models/word2vec.py (2 changes: 1 addition & 1 deletion)
@@ -1271,7 +1271,7 @@ def load(cls, *args, **kwargs):
         # update older models
         if hasattr(model, 'table'):
             delattr(model, 'table')  # discard in favor of cum_table
-        if model.negative and hasattr(model, 'index2word'):
+        if model.negative and hasattr(model.wv, 'index2word'):
             model.make_cum_table()  # rebuild cum_table from vocabulary
         if not hasattr(model, 'corpus_count'):
             model.corpus_count = None
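The one-line change above is the core of the fix: the hasattr check now looks at model.wv, where index2word lives after the vectors were moved onto the wv attribute, so load() rebuilds cum_table for models trained with negative sampling (the missing rebuild is what #1180 is about, per the commit title). For readers unfamiliar with the table itself, here is a rough sketch of what a cumulative table is used for in negative sampling; it is an illustration, not gensim's make_cum_table, though the 0.75 exponent is the conventional word2vec default.

    # Illustrative cumulative table: each word owns a slice of [0, 1) proportional
    # to count**0.75; a negative sample is drawn by bisecting a random number into it.
    import random
    from bisect import bisect_left

    def build_cum_table(counts, power=0.75):
        total = sum(c ** power for c in counts)
        cum_table, running = [], 0.0
        for c in counts:
            running += c ** power
            cum_table.append(running / total)
        return cum_table

    def draw_negative(cum_table):
        return bisect_left(cum_table, random.random())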
gensim/test/test_doc2vec.py (12 changes: 9 additions & 3 deletions)
@@ -142,7 +142,7 @@ def test_similarity_unseen_docs(self):
         model.build_vocab(corpus)
         self.assertTrue(model.docvecs.similarity_unseen_docs(model, rome_str, rome_str) > model.docvecs.similarity_unseen_docs(model, rome_str, car_str))
 
-    def model_sanity(self, model):
+    def model_sanity(self, model, keep_training=True):
         """Any non-trivial model on DocsLeeCorpus can pass these sanity checks"""
         fire1 = 0  # doc 0 sydney fires
         fire2 = 8  # doc 8 sydney fires
@@ -179,6 +179,12 @@ def model_sanity(self, model):
         # fire docs should be closer than fire-tennis
         self.assertTrue(model.docvecs.similarity(fire1, fire2) > model.docvecs.similarity(fire1, tennis1))
 
+        # keep training after save
+        if keep_training:
+            model.save(testfile())
+            loaded = doc2vec.Doc2Vec.load(testfile())
+            loaded.train(sentences)
+
     def test_training(self):
         """Test doc2vec training."""
         corpus = DocsLeeCorpus()
@@ -316,10 +322,10 @@ def test_delete_temporary_training_data(self):
         model.delete_temporary_training_data(keep_doctags_vectors=True, keep_inference=True)
         self.assertTrue(model.docvecs and hasattr(model.docvecs, 'doctag_syn0'))
         self.assertTrue(hasattr(model, 'syn1'))
-        self.model_sanity(model)
+        self.model_sanity(model, keep_training=False)
         model = doc2vec.Doc2Vec(list_corpus, dm=1, dm_mean=1, size=24, window=4, hs=0, negative=1, alpha=0.05, min_count=2, iter=20)
         model.delete_temporary_training_data(keep_doctags_vectors=True, keep_inference=True)
-        self.model_sanity(model)
+        self.model_sanity(model, keep_training=False)
         self.assertTrue(hasattr(model, 'syn1neg'))
 
     @log_capture()
gensim/test/test_word2vec.py (12 changes: 12 additions & 0 deletions)
@@ -95,6 +95,18 @@ def testOnlineLearning(self):
         self.assertEqual(len(model_hs.wv.vocab), 14)
         self.assertEqual(len(model_neg.wv.vocab), 14)
 
+    def testOnlineLearningAfterSave(self):
+        """Test that the algorithm is able to add new words to the
+        vocabulary and to a trained model when using a sorted vocabulary"""
+        model_neg = word2vec.Word2Vec(sentences, size=10, min_count=0, seed=42, hs=0, negative=5)
+        model_neg.save(testfile())
+        model_neg = word2vec.Word2Vec.load(testfile())
+        self.assertTrue(len(model_neg.wv.vocab), 12)
+        model_neg.build_vocab(new_sentences, update=True)
+        model_neg.train(new_sentences)
+        self.assertEqual(len(model_neg.wv.vocab), 14)
+
+
     def onlineSanity(self, model):
         terro, others = [], []
         for l in list_corpus: