[WIP] Adding unsupervised FastText to Gensim #1525

Merged: menshikh-iv merged 35 commits into piskvorky:develop from chinmayapancholi13:fasttext_gensim on Sep 19, 2017.
Commits (35):
a815c84 added initial code for CBOW (chinmayapancholi13)
102c14a updated unit tests for fasttext (chinmayapancholi13)
4c449df corrected use of matrix and precomputed ngrams for vocab words (chinmayapancholi13)
f49df54 added EOS token in 'LineSentence' class (chinmayapancholi13)
1fcb8fa added skipgram training code (chinmayapancholi13)
82fda3c updated unit tests for fasttext (chinmayapancholi13)
cd59034 seeded 'np.random' with 'self.seed' (chinmayapancholi13)
353f7a8 added test for persistence (chinmayapancholi13)
569a026 updated seeding numpy obj (chinmayapancholi13)
c228b8d updated (unclean) fasttext code for review (chinmayapancholi13)
29c627f updated fasttext tutorial notebook (chinmayapancholi13)
acbfdf2 added 'save' and 'load_fasttext_format' functions (chinmayapancholi13)
cb7a2ad updated unit tests for fasttext (chinmayapancholi13)
5a18297 cleaned main fasttext code (chinmayapancholi13)
4b98722 updated unittests (chinmayapancholi13)
cf1f3e0 removed EOS token from LineSentence (chinmayapancholi13)
d986242 fixed flake8 errors (chinmayapancholi13)
bce17ff [WIP] added online learning (chinmayapancholi13)
cb84001 added tests for online learning (chinmayapancholi13)
fbe8bdc Merge branch 'develop' of https://github.com/RaRe-Technologies/gensim… (chinmayapancholi13)
58c673a flake8 fixes (chinmayapancholi13)
893ef76 refactored code to remove redundancy (chinmayapancholi13)
e12f6c0 reusing 'word_vec' from 'FastTextKeyedVectors' (chinmayapancholi13)
39d14bd flake8 fixes (chinmayapancholi13)
d3ec5a8 split 'syn0_all' into 'syn0_vocab' and 'syn0_ngrams' (chinmayapancholi13)
0854622 removed 'init_wv' param from Word2Vec (chinmayapancholi13)
904882a updated unittests (chinmayapancholi13)
a9e7d03 flake8 errors fixed (chinmayapancholi13)
ec58512 fixed oov word_vec (chinmayapancholi13)
2ed7d31 removed merge conflicts (chinmayapancholi13)
daace4a updated test_training unittest (chinmayapancholi13)
58c531b Merge branch 'develop' into fasttext_gensim (menshikh-iv)
3ffa103 Fix broken merge (menshikh-iv)
2b0583b useless change (need to re-run Appveyour) (menshikh-iv)
55d731a Add skipIf for Appveyor x32 (avoid memory error) (menshikh-iv)
Diff view (new file):
@@ -0,0 +1,231 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging

import numpy as np
from numpy import zeros, ones, vstack, sum as np_sum, empty, float32 as REAL

from gensim.models.word2vec import Word2Vec, train_sg_pair, train_cbow_pair
from gensim.models.wrappers.fasttext import FastTextKeyedVectors
from gensim.models.wrappers.fasttext import FastText as Ft_Wrapper, compute_ngrams, ft_hash

logger = logging.getLogger(__name__)

MAX_WORDS_IN_BATCH = 10000


def train_batch_cbow(model, sentences, alpha, work=None, neu1=None):
    # Train one batch of sentences in CBOW mode: the input layer l1 sums the
    # context words' vocab vectors and all of their char-ngram vectors, then
    # train_cbow_pair updates the model from that composed context.
    result = 0
    for sentence in sentences:
        word_vocabs = [model.wv.vocab[w] for w in sentence if w in model.wv.vocab and
                       model.wv.vocab[w].sample_int > model.random.rand() * 2**32]
        for pos, word in enumerate(word_vocabs):
            reduced_window = model.random.randint(model.window)
            start = max(0, pos - model.window + reduced_window)
            window_pos = enumerate(word_vocabs[start:(pos + model.window + 1 - reduced_window)], start)
            word2_indices = [word2.index for pos2, word2 in window_pos if (word2 is not None and pos2 != pos)]

            word2_subwords = []
            vocab_subwords_indices = []
            ngrams_subwords_indices = []

            for index in word2_indices:
                vocab_subwords_indices += [index]
                word2_subwords += model.wv.ngrams_word[model.wv.index2word[index]]

            for subword in word2_subwords:
                ngrams_subwords_indices.append(model.wv.ngrams[subword])

            l1_vocab = np_sum(model.wv.syn0_vocab[vocab_subwords_indices], axis=0)  # 1 x vector_size
            l1_ngrams = np_sum(model.wv.syn0_ngrams[ngrams_subwords_indices], axis=0)  # 1 x vector_size

            l1 = np_sum([l1_vocab, l1_ngrams], axis=0)
            subwords_indices = [vocab_subwords_indices] + [ngrams_subwords_indices]
            if (subwords_indices[0] or subwords_indices[1]) and model.cbow_mean:
                l1 /= (len(subwords_indices[0]) + len(subwords_indices[1]))

            # train on the sliding window for target word
            train_cbow_pair(model, word, subwords_indices, l1, alpha, is_ft=True)
        result += len(word_vocabs)
    return result


def train_batch_sg(model, sentences, alpha, work=None):
    # Train one batch of sentences in skipgram mode: each centre word is
    # represented by its vocab index plus its char-ngram indices, and every
    # word in the (reduced) window is predicted from that representation.
    result = 0
    for sentence in sentences:
        word_vocabs = [model.wv.vocab[w] for w in sentence if w in model.wv.vocab and
                       model.wv.vocab[w].sample_int > model.random.rand() * 2**32]
        for pos, word in enumerate(word_vocabs):
            reduced_window = model.random.randint(model.window)  # `b` in the original word2vec code
            # now go over all words from the (reduced) window, predicting each one in turn
            start = max(0, pos - model.window + reduced_window)

            subwords_indices = [word.index]
            word2_subwords = model.wv.ngrams_word[model.wv.index2word[word.index]]

            for subword in word2_subwords:
                subwords_indices.append(model.wv.ngrams[subword])

            for pos2, word2 in enumerate(word_vocabs[start:(pos + model.window + 1 - reduced_window)], start):
                if pos2 != pos:  # don't train on the `word` itself
                    train_sg_pair(model, model.wv.index2word[word2.index], subwords_indices, alpha, is_ft=True)

        result += len(word_vocabs)
    return result


class FastText(Word2Vec):
    def __init__(
            self, sentences=None, sg=0, hs=0, size=100, alpha=0.025, window=5, min_count=5,
            max_vocab_size=None, word_ngrams=1, loss='ns', sample=1e-3, seed=1, workers=3, min_alpha=0.0001,
            negative=5, cbow_mean=1, hashfxn=hash, iter=5, null_word=0, min_n=3, max_n=6, sorted_vocab=1,
            bucket=2000000, trim_rule=None, batch_words=MAX_WORDS_IN_BATCH):

        # fastText specific params
        self.bucket = bucket
        self.word_ngrams = word_ngrams
        self.min_n = min_n
        self.max_n = max_n
        if self.word_ngrams <= 1 and self.max_n == 0:
            self.bucket = 0

        super(FastText, self).__init__(
            sentences=sentences, size=size, alpha=alpha, window=window, min_count=min_count,
            max_vocab_size=max_vocab_size, sample=sample, seed=seed, workers=workers, min_alpha=min_alpha,
            sg=sg, hs=hs, negative=negative, cbow_mean=cbow_mean, hashfxn=hashfxn, iter=iter,
            null_word=null_word, trim_rule=trim_rule, sorted_vocab=sorted_vocab, batch_words=batch_words)

    def initialize_word_vectors(self):
        self.wv = FastTextKeyedVectors()
        self.wv.min_n = self.min_n
        self.wv.max_n = self.max_n

    def build_vocab(self, sentences, keep_raw_vocab=False, trim_rule=None, progress_per=10000, update=False):
        if update:
            if not len(self.wv.vocab):
                raise RuntimeError(
                    "You cannot do an online vocabulary-update of a model which has no prior vocabulary. "
                    "First build the vocabulary of your model with a corpus "
                    "before doing an online update.")
            self.old_vocab_len = len(self.wv.vocab)
            self.old_hash2index_len = len(self.wv.hash2index)

        super(FastText, self).build_vocab(
            sentences, keep_raw_vocab=keep_raw_vocab, trim_rule=trim_rule,
            progress_per=progress_per, update=update)
        self.init_ngrams(update=update)

    def init_ngrams(self, update=False):
        if not update:
            self.wv.ngrams = {}
            self.wv.syn0_vocab = empty((len(self.wv.vocab), self.vector_size), dtype=REAL)
            self.syn0_vocab_lockf = ones((len(self.wv.vocab), self.vector_size), dtype=REAL)

            self.wv.syn0_ngrams = empty((self.bucket, self.vector_size), dtype=REAL)
            self.syn0_ngrams_lockf = ones((self.bucket, self.vector_size), dtype=REAL)

            all_ngrams = []
            for w, v in self.wv.vocab.items():
                self.wv.ngrams_word[w] = compute_ngrams(w, self.min_n, self.max_n)
                all_ngrams += self.wv.ngrams_word[w]

            all_ngrams = list(set(all_ngrams))
            self.num_ngram_vectors = len(all_ngrams)
            logger.info("Total number of ngrams is %d", len(all_ngrams))

            # Each ngram is hashed into one of `bucket` rows, so distinct ngrams
            # can collide and share a single trainable vector, as in fastText.
            self.wv.hash2index = {}
            ngram_indices = []
            new_hash_count = 0
            for i, ngram in enumerate(all_ngrams):
                ngram_hash = ft_hash(ngram)
                if ngram_hash in self.wv.hash2index:
                    self.wv.ngrams[ngram] = self.wv.hash2index[ngram_hash]
                else:
                    ngram_indices.append(ngram_hash % self.bucket)
                    self.wv.hash2index[ngram_hash] = new_hash_count
                    self.wv.ngrams[ngram] = self.wv.hash2index[ngram_hash]
                    new_hash_count = new_hash_count + 1

            # keep only the buckets actually used by some ngram
            self.wv.syn0_ngrams = self.wv.syn0_ngrams.take(ngram_indices, axis=0)
            self.syn0_ngrams_lockf = self.syn0_ngrams_lockf.take(ngram_indices, axis=0)
            self.reset_ngram_weights()
        else:
            new_ngrams = []
            for w, v in self.wv.vocab.items():
                self.wv.ngrams_word[w] = compute_ngrams(w, self.min_n, self.max_n)
                new_ngrams += [ng for ng in self.wv.ngrams_word[w] if ng not in self.wv.ngrams]

            new_ngrams = list(set(new_ngrams))
            logger.info("Number of new ngrams is %d", len(new_ngrams))
            new_hash_count = 0
            for i, ngram in enumerate(new_ngrams):
                ngram_hash = ft_hash(ngram)
                if ngram_hash not in self.wv.hash2index:
                    self.wv.hash2index[ngram_hash] = new_hash_count + self.old_hash2index_len
                    self.wv.ngrams[ngram] = self.wv.hash2index[ngram_hash]
                    new_hash_count = new_hash_count + 1
                else:
                    self.wv.ngrams[ngram] = self.wv.hash2index[ngram_hash]

            rand_obj = np.random
            rand_obj.seed(self.seed)
            new_vocab_rows = rand_obj.uniform(
                -1.0 / self.vector_size, 1.0 / self.vector_size,
                (len(self.wv.vocab) - self.old_vocab_len, self.vector_size))
            new_vocab_lockf_rows = ones((len(self.wv.vocab) - self.old_vocab_len, self.vector_size), dtype=REAL)
            new_ngram_rows = rand_obj.uniform(
                -1.0 / self.vector_size, 1.0 / self.vector_size,
                (len(self.wv.hash2index) - self.old_hash2index_len, self.vector_size))
            new_ngram_lockf_rows = ones((len(self.wv.hash2index) - self.old_hash2index_len, self.vector_size), dtype=REAL)

            self.wv.syn0_vocab = vstack([self.wv.syn0_vocab, new_vocab_rows])
            self.syn0_vocab_lockf = vstack([self.syn0_vocab_lockf, new_vocab_lockf_rows])
            self.wv.syn0_ngrams = vstack([self.wv.syn0_ngrams, new_ngram_rows])
            self.syn0_ngrams_lockf = vstack([self.syn0_ngrams_lockf, new_ngram_lockf_rows])

    def reset_ngram_weights(self):
        rand_obj = np.random
        rand_obj.seed(self.seed)
        for index in range(len(self.wv.vocab)):
            self.wv.syn0_vocab[index] = rand_obj.uniform(
                -1.0 / self.vector_size, 1.0 / self.vector_size, self.vector_size)
        for index in range(len(self.wv.hash2index)):
            self.wv.syn0_ngrams[index] = rand_obj.uniform(
                -1.0 / self.vector_size, 1.0 / self.vector_size, self.vector_size)

    def _do_train_job(self, sentences, alpha, inits):
        work, neu1 = inits
        tally = 0
        if self.sg:
            tally += train_batch_sg(self, sentences, alpha, work)
        else:
            tally += train_batch_cbow(self, sentences, alpha, work, neu1)

        return tally, self._raw_word_count(sentences)

    def train(self, sentences, total_examples=None, total_words=None,
              epochs=None, start_alpha=None, end_alpha=None,
              word_count=0, queue_factor=2, report_delay=1.0):
        self.neg_labels = []
        if self.negative > 0:
            # precompute negative labels optimization for pure-python training
            self.neg_labels = zeros(self.negative + 1)
            self.neg_labels[0] = 1.

        # NOTE: delegates with the model's stored corpus_count, iter and alpha
        # values; the total_examples/epochs/start_alpha/end_alpha arguments
        # passed in above are not forwarded.
        Word2Vec.train(self, sentences, total_examples=self.corpus_count, epochs=self.iter,
                       start_alpha=self.alpha, end_alpha=self.min_alpha)
        self.get_vocab_word_vecs()

    def __getitem__(self, word):
        return self.word_vec(word)

    def get_vocab_word_vecs(self):
        # Compose the final vector of each vocab word as the average of its
        # own vector and all of its char-ngram vectors.
        for w, v in self.wv.vocab.items():
            word_vec = np.copy(self.wv.syn0_vocab[v.index])  # copy, so the in-place ops below don't mutate syn0_vocab
            ngrams = self.wv.ngrams_word[w]
            ngram_weights = self.wv.syn0_ngrams
            for ngram in ngrams:
                word_vec += ngram_weights[self.wv.ngrams[ngram]]
            word_vec /= (len(ngrams) + 1)
            self.wv.syn0[v.index] = word_vec

    def word_vec(self, word, use_norm=False):
        return FastTextKeyedVectors.word_vec(self.wv, word, use_norm=use_norm)

    @classmethod
    def load_fasttext_format(cls, *args, **kwargs):
        return Ft_Wrapper.load_fasttext_format(*args, **kwargs)

    def save(self, *args, **kwargs):
        # don't persist the derived *_norm arrays; they can be recomputed
        kwargs['ignore'] = kwargs.get('ignore', ['syn0norm', 'syn0_vocab_norm', 'syn0_ngrams_norm'])
        super(FastText, self).save(*args, **kwargs)
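
For context, here is a minimal usage sketch of the class this diff adds, based on the constructor signature and methods shown above. The import path, toy corpus, and parameter values are illustrative assumptions, not part of the PR:

# Hedged usage sketch; corpus, save path and hyperparameter values are assumptions.
from gensim.models.fasttext import FastText  # module path assumed for the new file

sentences = [
    ['human', 'interface', 'computer'],
    ['survey', 'user', 'computer', 'system', 'response', 'time'],
]

model = FastText(sentences, sg=0, size=100, window=5, min_count=1, iter=5)  # sg=0 selects CBOW

vec = model['computer']  # in-vocab lookup; OOV words are composed from char ngrams

# online vocabulary update, added in commit bce17ff:
more_sentences = [['graph', 'minors', 'survey']]
model.build_vocab(more_sentences, update=True)
model.train(more_sentences)

model.save('/tmp/fasttext_gensim.model')  # skips the *_norm arrays by default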
Review conversation:
Any reason for not reusing 'word_vec' from 'FastTextKeyedVectors' here?

Done.
I get this error if I try to access an OOV word:

print model["use"]  # only "user" is available in the vocabulary

Looks like the 'FastTextKeyedVectors' class has no 'min_n' attribute? Here's a minimal example to reproduce my issue: https://gist.github.com/rilut/31f41d5cf3da075d43cf7e4f2c993b76

Thanks 🙏
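
To make the report above concrete, here is a minimal sketch of the failing access pattern (the toy corpus is an illustrative assumption; the linked gist contains the original repro):

# Toy corpus is an assumption; 'use' is out-of-vocabulary but shares char ngrams with 'user'.
model = FastText([['user', 'interface', 'user', 'computer']], min_count=1)
print(model['user'])  # in-vocabulary lookup works
print(model['use'])   # OOV lookup; this failed before the fix because 'min_n' and
                      # 'max_n' were never set on FastTextKeyedVectors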
Btw, since self.wv = FastTextKeyedVectors(), can we just do self.wv.word_vec(word, word_norm)? (Just asking.)
@rilut Seems like I missed setting 'min_n' and 'max_n' in 'FastTextKeyedVectors' after I had refactored the code between the wrapper and the main 'models.FastText' class. Anyway, thanks a lot for pointing this out! I have made the necessary changes and also added a unit test to ensure that this error doesn't go unnoticed in the future.

And yes, FastTextKeyedVectors.word_vec(self.wv, word, use_norm=use_norm) and self.wv.word_vec(word, word_norm) are equivalent. I preferred the former form at the time since it made the borrowed usage of the function more explicit.
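
In other words, the two call forms discussed above reach the same method (a sketch, assuming a trained model):

# explicit, unbound-style call used in the PR:
vec_a = FastTextKeyedVectors.word_vec(model.wv, 'use', use_norm=False)
# equivalent call through the bound method on the keyed vectors:
vec_b = model.wv.word_vec('use', use_norm=False)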
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@chinmayapancholi13 Oh ok, that's good. Thanks for your fix and hard work! I really appreciate it.