diff --git a/gensim/models/doc2vec.py b/gensim/models/doc2vec.py index 07789c8ff7..824d4cda68 100644 --- a/gensim/models/doc2vec.py +++ b/gensim/models/doc2vec.py @@ -529,7 +529,7 @@ def repeat(self, word_count): class Doc2Vec(Word2Vec): """Class for training, using and evaluating neural networks described in http://arxiv.org/pdf/1405.4053v2.pdf""" - def __init__(self, documents=None,dm_mean=None, + def __init__(self, documents=None, dm_mean=None, dm=1, dbow_words=0, dm_concat=0, dm_tag_count=1, docvecs=None, docvecs_mapfile=None, comment=None, trim_rule=None, **kwargs): """ @@ -601,9 +601,12 @@ def __init__(self, documents=None,dm_mean=None, """ super(Doc2Vec, self).__init__( - sg=(1 + dm) % 2, dm_mean=dm_mean, + sg=(1 + dm) % 2, null_word=dm_concat, **kwargs) + if dm_mean is not None: + self.cbow_mean = dm_mean + self.dbow_words = dbow_words self.dm_concat = dm_concat self.dm_tag_count = dm_tag_count diff --git a/gensim/models/word2vec.py b/gensim/models/word2vec.py index f9bf727a35..a2642092f5 100644 --- a/gensim/models/word2vec.py +++ b/gensim/models/word2vec.py @@ -365,7 +365,7 @@ def __init__( self, sentences=None, size=100, alpha=0.025, window=5, min_count=5, max_vocab_size=None, sample=1e-3, seed=1, workers=3, min_alpha=0.0001, sg=0, hs=0, negative=5, cbow_mean=1, hashfxn=hash, iter=5, null_word=0, - trim_rule=None, sorted_vocab=1, batch_words=MAX_WORDS_IN_BATCH,**kwargs): + trim_rule=None, sorted_vocab=1, batch_words=MAX_WORDS_IN_BATCH): """ Initialize the model from an iterable of `sentences`. Each sentence is a list of words (unicode strings) that will be used for training. @@ -461,6 +461,7 @@ def __init__( self.min_alpha = float(min_alpha) self.hs = hs self.negative = negative + self.cbow_mean = int(cbow_mean) self.hashfxn = hashfxn self.iter = iter self.null_word = null_word @@ -469,11 +470,6 @@ def __init__( self.sorted_vocab = sorted_vocab self.batch_words = batch_words - if "dm_mean" in kwargs and kwargs["dm_mean"] is not None: - self.cbow_mean = int(kwargs["dm_mean"]) - else: - self.cbow_mean = int(cbow_mean) - if sentences is not None: if isinstance(sentences, GeneratorType): raise TypeError("You can't pass a generator as the sentences argument. Try an iterator.") diff --git a/gensim/test/test_doc2vec.py b/gensim/test/test_doc2vec.py index e03d35ba80..09ab714a8e 100644 --- a/gensim/test/test_doc2vec.py +++ b/gensim/test/test_doc2vec.py @@ -84,11 +84,10 @@ def test_int_doctags(self): """Test doc2vec doctag alternatives""" corpus = DocsLeeCorpus() - size = 300 - model = doc2vec.Doc2Vec(min_count=1, size=size) + model = doc2vec.Doc2Vec(min_count=1) model.build_vocab(corpus) - self.assertEqual(len(model.docvecs.doctag_syn0), size) - self.assertEqual(model.docvecs[0].shape, (size,)) + self.assertEqual(len(model.docvecs.doctag_syn0), 300) + self.assertEqual(model.docvecs[0].shape, (100,)) self.assertRaises(KeyError, model.__getitem__, '_*0') def test_missing_string_doctag(self): @@ -107,13 +106,12 @@ def test_string_doctags(self): # force duplicated tags corpus = corpus[0:10] + corpus - size = 300 - model = doc2vec.Doc2Vec(size=size, min_count=1) + model = doc2vec.Doc2Vec(min_count=1) model.build_vocab(corpus) - self.assertEqual(len(model.docvecs.doctag_syn0), size) - self.assertEqual(model.docvecs[0].shape, (size,)) - self.assertEqual(model.docvecs['_*0'].shape, (size,)) + self.assertEqual(len(model.docvecs.doctag_syn0), 300) + self.assertEqual(model.docvecs[0].shape, (100,)) + self.assertEqual(model.docvecs['_*0'].shape, (100,)) self.assertTrue(all(model.docvecs['_*0'] == model.docvecs[0])) self.assertTrue(max(d.offset for d in model.docvecs.doctags.values()) < len(model.docvecs.doctags)) self.assertTrue(max(model.docvecs._int_index(str_key) for str_key in model.docvecs.doctags.keys()) < len(model.docvecs.doctag_syn0)) diff --git a/gensim/test/test_wikicorpus.py b/gensim/test/test_wikicorpus.py index eeb4f40ec7..588b51dc9b 100644 --- a/gensim/test/test_wikicorpus.py +++ b/gensim/test/test_wikicorpus.py @@ -42,11 +42,10 @@ def setUp(self): # self.assertEqual(type(first), list) # self.assertTrue(isinstance(first[0], bytes) or isinstance(first[0], str)) def test_get_texts_returns_generator_of_lists(self): - logger.debug("Current Python Version is "+str(sys.version_info)) if sys.version_info < (2, 7, 0): return - + wc = WikiCorpus(datapath(FILENAME)) l = wc.get_texts() self.assertEqual(type(l), types.GeneratorType) @@ -62,7 +61,6 @@ def test_first_element(self): 2) autism """ - if sys.version_info < (2, 7, 0): return wc = WikiCorpus(datapath(FILENAME))