Skip to content

Commit

Permalink
handling dm_mean
Browse files Browse the repository at this point in the history
  • Loading branch information
markroxor committed Oct 20, 2016
1 parent c144bf3 commit 49163a0
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 7 deletions.
5 changes: 2 additions & 3 deletions gensim/models/doc2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ def repeat(self, word_count):

class Doc2Vec(Word2Vec):
"""Class for training, using and evaluating neural networks described in http://arxiv.org/pdf/1405.4053v2.pdf"""
def __init__(self, documents=None,
def __init__(self, documents=None,dm_mean=None,
dm=1, dbow_words=0, dm_concat=0, dm_tag_count=1,
docvecs=None, docvecs_mapfile=None, comment=None, trim_rule=None, **kwargs):
"""
Expand Down Expand Up @@ -600,9 +600,8 @@ def __init__(self, documents=None,
"""


super(Doc2Vec, self).__init__(
sg=(1 + dm) % 2,
sg=(1 + dm) % 2, dm_mean=dm_mean,
null_word=dm_concat, **kwargs)

self.dbow_words = dbow_words
Expand Down
4 changes: 2 additions & 2 deletions gensim/models/word2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,11 +459,11 @@ def __init__(
self.sorted_vocab = sorted_vocab
self.batch_words = batch_words

if "dm_mean" in kwargs:
if kwargs["dm_mean"] is not None:
self.cbow_mean = int(kwargs["dm_mean"])
else:
self.cbow_mean = int(cbow_mean)

if sentences is not None:
if isinstance(sentences, GeneratorType):
raise TypeError("You can't pass a generator as the sentences argument. Try an iterator.")
Expand Down
4 changes: 2 additions & 2 deletions gensim/test/test_doc2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,15 @@ def model_sanity(self, model):
def test_training(self):
"""Test doc2vec training."""
corpus = DocsLeeCorpus()
model = doc2vec.Doc2Vec(size=100, min_count=2, iter=20, window=8, sample=0.01, workers=1)
model = doc2vec.Doc2Vec(size=100, min_count=2, iter=20, workers=1)
model.build_vocab(corpus)
self.assertEqual(model.docvecs.doctag_syn0.shape, (300, 100))
model.train(corpus)

self.model_sanity(model)

# build vocab and train in one step; must be the same as above
model2 = doc2vec.Doc2Vec(corpus, size=100, min_count=2, iter=20, window=8, sample=0.01, workers=1)
model2 = doc2vec.Doc2Vec(corpus, size=100, min_count=2, iter=20, workers=1)
self.models_equal(model, model2)

def test_dbow_hs(self):
Expand Down

0 comments on commit 49163a0

Please sign in to comment.