
Commit

added Python-only implementation for skip-gram model
chinmayapancholi13 committed May 24, 2017
1 parent 0939b32 commit 8949749
Showing 1 changed file with 15 additions and 21 deletions.
36 changes: 15 additions & 21 deletions gensim/models/word2vec.py
@@ -140,7 +140,7 @@
FAST_VERSION = -1
MAX_WORDS_IN_BATCH = 10000

- def train_batch_sg(model, sentences, alpha, work=None, print_freq=0):
+ def train_batch_sg(model, sentences, alpha, work=None, enable_loss_logging=False):
"""
Update skip-gram model by training on a sequence of sentences.
@@ -151,7 +151,6 @@ def train_batch_sg(model, sentences, alpha, work=None, print_freq=0):
will use the optimized version from word2vec_inner instead.
"""
- p_iter = 0
result = 0
for sentence in sentences:
word_vocabs = [model.wv.vocab[w] for w in sentence if w in model.wv.vocab and
@@ -164,14 +163,7 @@ def train_batch_sg(model, sentences, alpha, work=None, print_freq=0):
for pos2, word2 in enumerate(word_vocabs[start:(pos + model.window + 1 - reduced_window)], start):
# don't train on the `word` itself
if pos2 != pos:
- if print_freq != 0:
- if p_iter == 0:
- train_sg_pair(model, model.wv.index2word[word.index], word2.index, alpha, enable_loss_logging=True)
- else:
- train_sg_pair(model, model.wv.index2word[word.index], word2.index, alpha, enable_loss_logging=False)
- p_iter = (p_iter + 1) % print_freq
- else:
- train_sg_pair(model, model.wv.index2word[word.index], word2.index, alpha, enable_loss_logging=False)
+ train_sg_pair(model, model.wv.index2word[word.index], word2.index, alpha, enable_loss_logging=enable_loss_logging)

result += len(word_vocabs)
return result
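Since the rendered diff drops Python indentation, here is a sketch of how the rewritten inner loop of train_batch_sg reads after this commit. It paraphrases the hunk above; the vocabulary filter, reduced_window, and start lines come from context that is collapsed in this view, so treat those as assumptions rather than verbatim code.

```python
def train_batch_sg(model, sentences, alpha, work=None, enable_loss_logging=False):
    result = 0
    for sentence in sentences:
        # keep only in-vocabulary words (the downsampling filter is collapsed in this diff)
        word_vocabs = [model.wv.vocab[w] for w in sentence if w in model.wv.vocab]
        for pos, word in enumerate(word_vocabs):
            # assumed from collapsed context: randomly shrink the window, as in the C word2vec
            reduced_window = model.random.randint(model.window)
            start = max(0, pos - model.window + reduced_window)
            for pos2, word2 in enumerate(word_vocabs[start:(pos + model.window + 1 - reduced_window)], start):
                if pos2 != pos:  # don't train on the `word` itself
                    # the old print_freq / p_iter bookkeeping is gone; the flag is forwarded as-is
                    train_sg_pair(model, model.wv.index2word[word.index], word2.index, alpha,
                                  enable_loss_logging=enable_loss_logging)
        result += len(word_vocabs)
    return result
```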
@@ -280,9 +272,6 @@ def train_sg_pair(model, word, context_index, alpha, learn_vectors=True, learn_h

neu1e = zeros(l1.shape)

- if enable_loss_logging:
- train_error_value = 0

if model.hs:
# work on the entire tree at once, to push as much work into numpy's C routines as possible (performance)
l2a = deepcopy(model.syn1[predict_word.point]) # 2d matrix, codelen x layer1_size
@@ -296,7 +285,7 @@ def train_sg_pair(model, word, context_index, alpha, learn_vectors=True, learn_h
if enable_loss_logging:
sgn = (-1.0)**predict_word.code # ch function, 0-> 1, 1 -> -1
lprob = -log(expit(-sgn * dot(l1, l2a.T)))
- train_error_value += sum(lprob)
+ model.cumulative_training_loss += sum(lprob)

if model.negative:
# use this word (label = 1) + `negative` other random words not from this sentence (label = 0)
@@ -315,9 +304,8 @@ def train_sg_pair(model, word, context_index, alpha, learn_vectors=True, learn_h

# loss component corresponding to negative sampling
if enable_loss_logging:
- train_error_value -= sum(log(expit(-1 * prod_term[range(1, len(prod_term))]))) # for the sampled words
- train_error_value -= log(expit(prod_term[0])) # for the output word
- logger.info("current training loss : %f", train_error_value)
+ model.cumulative_training_loss -= sum(log(expit(-1 * prod_term[range(1, len(prod_term))]))) # for the sampled words
+ model.cumulative_training_loss -= log(expit(prod_term[0])) # for the output word

if learn_vectors:
l1 += neu1e * lock_factor # learn input -> hidden (mutates model.wv.syn0[word2.index], if that is l1)
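Restated in math (my reading of the two branches above, not text from the commit): with sigma the logistic function (`expit`) and l1 the input-word vector, the hierarchical-softmax branch adds, per (input word, predicted word) pair,

```latex
L_{\mathrm{hs}} = -\sum_{j=1}^{\mathrm{codelen}} \log \sigma\!\left(-(-1)^{c_j}\, \mathbf{l}_1 \cdot \mathbf{v}'_{n_j}\right)
```

where c_j is the j-th Huffman code bit and v'_{n_j} is row j of `l2a`; this is exactly `sum(-log(expit(-sgn * dot(l1, l2a.T))))`. The negative-sampling branch adds

```latex
L_{\mathrm{neg}} = -\log \sigma\!\left(\mathbf{l}_1 \cdot \mathbf{v}'_{w_O}\right) - \sum_{k=1}^{K} \log \sigma\!\left(-\mathbf{l}_1 \cdot \mathbf{v}'_{w_k}\right)
```

with w_O the true output word (`prod_term[0]`) and w_1..w_K the sampled noise words (`prod_term[1:]`), matching the two added lines above.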
@@ -497,6 +485,7 @@ def __init__(
self.sorted_vocab = sorted_vocab
self.batch_words = batch_words
self.model_trimmed_post_training = False
+ self.cumulative_training_loss = 0
if sentences is not None:
if isinstance(sentences, GeneratorType):
raise TypeError("You can't pass a generator as the sentences argument. Try an iterator.")
@@ -773,15 +762,15 @@ def reset_from(self, other_model):
self.corpus_count = other_model.corpus_count
self.reset_weights()

- def _do_train_job(self, sentences, alpha, inits):
+ def _do_train_job(self, sentences, alpha, inits, enable_loss_logging=False):
"""
Train a single batch of sentences. Return 2-tuple `(effective word count after
ignoring unknown words and sentence length trimming, total word count)`.
"""
work, neu1 = inits
tally = 0
if self.sg:
- tally += train_batch_sg(self, sentences, alpha, work)
+ tally += train_batch_sg(self, sentences, alpha, work, enable_loss_logging)
else:
tally += train_batch_cbow(self, sentences, alpha, work, neu1)
return tally, self._raw_word_count(sentences)
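One consequence visible in this hunk: only the skip-gram branch forwards the new flag, so loss logging applies when `sg=1`; the CBOW call is left untouched, consistent with the commit title. The same logic, with a clarifying comment added here for emphasis:

```python
if self.sg:
    tally += train_batch_sg(self, sentences, alpha, work, enable_loss_logging)
else:
    # CBOW path is unchanged by this commit, so enable_loss_logging has no effect when sg=0
    tally += train_batch_cbow(self, sentences, alpha, work, neu1)
```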
@@ -793,7 +782,7 @@ def _raw_word_count(self, job):
def train(self, sentences, total_examples=None, total_words=None,
epochs=None, start_alpha=None, end_alpha=None,
word_count=0,
- queue_factor=2, report_delay=1.0):
+ queue_factor=2, report_delay=1.0, enable_loss_logging=False):
"""
Update the model's neural weights from a sequence of sentences (can be a once-only generator stream).
For Word2Vec, each sentence must be a list of unicode strings. (Subclasses may accept other examples.)
@@ -819,6 +808,8 @@ def train(self, sentences, total_examples=None, total_words=None,
self.neg_labels = zeros(self.negative + 1)
self.neg_labels[0] = 1.

+ self.cumulative_training_loss = 0

logger.info(
"training model with %i workers on %i vocabulary and %i features, "
"using sg=%s hs=%s sample=%s negative=%s window=%s",
@@ -861,7 +852,7 @@ def worker_loop():
progress_queue.put(None)
break # no more jobs => quit this worker
sentences, alpha = job
- tally, raw_tally = self._do_train_job(sentences, alpha, (work, neu1))
+ tally, raw_tally = self._do_train_job(sentences, alpha, (work, neu1), enable_loss_logging)
progress_queue.put((len(sentences), tally, raw_tally)) # report back progress
jobs_processed += 1
logger.debug("worker exiting, processed %i jobs", jobs_processed)
@@ -1462,6 +1453,9 @@ def save_word2vec_format(self, fname, fvocab=None, binary=False):
"""Deprecated. Use model.wv.save_word2vec_format instead."""
raise DeprecationWarning("Deprecated. Use model.wv.save_word2vec_format instead.")

+ def get_latest_training_loss(self):
+ return self.cumulative_training_loss


class BrownCorpus(object):
"""Iterate over sentences from the Brown corpus (part of NLTK data)."""
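Putting the pieces of this diff together, a minimal usage sketch might look like the following. The `enable_loss_logging` keyword and `get_latest_training_loss()` come straight from this commit; everything else (the toy corpus, `sg=1`, `model.corpus_count`, `model.iter`) is the ordinary gensim 2.x Word2Vec API and is assumed here, as is the fact that the pure-Python training path (the commit title says "Python-only") is the one doing the accumulation.

```python
from gensim.models import Word2Vec

sentences = [["human", "interface", "computer"],
             ["survey", "user", "computer", "system", "response", "time"]]

# sg=1 selects the skip-gram path touched by this commit
model = Word2Vec(sentences, sg=1, min_count=1)

# Per this diff: the flag is threaded train -> _do_train_job -> train_batch_sg -> train_sg_pair,
# and the running total is reset to 0 at the start of every train() call.
model.train(sentences, total_examples=model.corpus_count, epochs=model.iter,
            enable_loss_logging=True)

print(model.get_latest_training_loss())  # cumulative loss for the most recent train() call
```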

