From 8e374d3401ebfc14e785832bd2fbbb23fba0b861 Mon Sep 17 00:00:00 2001 From: markroxor Date: Thu, 20 Oct 2016 15:27:36 +0530 Subject: [PATCH] handling dm_mean --- gensim/models/doc2vec.py | 5 ++--- gensim/models/word2vec.py | 4 ++-- gensim/test/test_doc2vec.py | 4 ++-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/gensim/models/doc2vec.py b/gensim/models/doc2vec.py index d5ac7d9084..07789c8ff7 100644 --- a/gensim/models/doc2vec.py +++ b/gensim/models/doc2vec.py @@ -529,7 +529,7 @@ def repeat(self, word_count): class Doc2Vec(Word2Vec): """Class for training, using and evaluating neural networks described in http://arxiv.org/pdf/1405.4053v2.pdf""" - def __init__(self, documents=None, + def __init__(self, documents=None,dm_mean=None, dm=1, dbow_words=0, dm_concat=0, dm_tag_count=1, docvecs=None, docvecs_mapfile=None, comment=None, trim_rule=None, **kwargs): """ @@ -600,9 +600,8 @@ def __init__(self, documents=None, """ - super(Doc2Vec, self).__init__( - sg=(1 + dm) % 2, + sg=(1 + dm) % 2, dm_mean=dm_mean, null_word=dm_concat, **kwargs) self.dbow_words = dbow_words diff --git a/gensim/models/word2vec.py b/gensim/models/word2vec.py index 4979706330..f9bf727a35 100644 --- a/gensim/models/word2vec.py +++ b/gensim/models/word2vec.py @@ -469,11 +469,11 @@ def __init__( self.sorted_vocab = sorted_vocab self.batch_words = batch_words - if "dm_mean" in kwargs: + if "dm_mean" in kwargs and kwargs["dm_mean"] is not None: self.cbow_mean = int(kwargs["dm_mean"]) else: self.cbow_mean = int(cbow_mean) - + if sentences is not None: if isinstance(sentences, GeneratorType): raise TypeError("You can't pass a generator as the sentences argument. Try an iterator.") diff --git a/gensim/test/test_doc2vec.py b/gensim/test/test_doc2vec.py index 15de64c846..e03d35ba80 100644 --- a/gensim/test/test_doc2vec.py +++ b/gensim/test/test_doc2vec.py @@ -171,7 +171,7 @@ def model_sanity(self, model): def test_training(self): """Test doc2vec training.""" corpus = DocsLeeCorpus() - model = doc2vec.Doc2Vec(size=100, min_count=2, iter=20, window=8, sample=0.01, workers=1) + model = doc2vec.Doc2Vec(size=100, min_count=2, iter=20, workers=1) model.build_vocab(corpus) self.assertEqual(model.docvecs.doctag_syn0.shape, (300, 100)) model.train(corpus) @@ -179,7 +179,7 @@ def test_training(self): self.model_sanity(model) # build vocab and train in one step; must be the same as above - model2 = doc2vec.Doc2Vec(corpus, size=100, min_count=2, iter=20, window=8, sample=0.01, workers=1) + model2 = doc2vec.Doc2Vec(corpus, size=100, min_count=2, iter=20, workers=1) self.models_equal(model, model2) def test_dbow_hs(self):