-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Add ability to use Tensorflow to train a word2vec model #809
Changes from all commits
2fb9af7
2349cee
3f3acf7
58c908c
c2607c6
ad03abd
30b681d
7ce0d95
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
import logging
import os
import tempfile

import tensorflow as tf
from six import string_types
from tensorflow.models.embedding.word2vec_optimized import Word2Vec

from gensim import utils
from gensim.models.word2vec import Word2Vec as GensimWord2Vec, Vocab
# NOTE(review): a library must not call logging.basicConfig() at import time --
# that overrides the host application's logging setup. Expose a module-level
# logger instead and let applications attach handlers themselves.
logger = logging.getLogger(__name__)
|
||
|
||
class GensimWord2VecNoTraining(GensimWord2Vec):
    """Gensim word2vec with every training/serialization entry point disabled.

    A tensorflow-trained model only borrows gensim's query API (similarity
    lookups, ``__getitem__``, ...). Calling any of gensim's own training or
    vocab-building machinery on it would corrupt the model, so all such
    methods raise instead.
    """

    # Single source of truth for the error, so the message stays consistent
    # across every disabled entry point below.
    def _reject_training_call(self, *args, **kwargs):
        raise Exception("Cannot call a gensim training method on a tf trained model")

    def make_cum_table(self, *args, **kwargs):
        self._reject_training_call()

    def create_binary_tree(self, *args, **kwargs):
        self._reject_training_call()

    def build_vocab(self, *args, **kwargs):
        self._reject_training_call()

    def scan_vocab(self, *args, **kwargs):
        self._reject_training_call()

    def scale_vocab(self, *args, **kwargs):
        self._reject_training_call()

    def finalize_vocab(self, *args, **kwargs):
        self._reject_training_call()

    def sort_vocab(self, *args, **kwargs):
        self._reject_training_call()

    def _do_train_job(self, *args, **kwargs):
        self._reject_training_call()

    def train(self, *args, **kwargs):
        self._reject_training_call()

    def score(self, *args, **kwargs):
        self._reject_training_call()

    def save_word2vec_format(self, *args, **kwargs):
        self._reject_training_call()

    @classmethod
    def load_word2vec_format(cls, *args, **kwargs):
        # classmethod: cannot route through the instance helper above.
        raise Exception("Cannot call a gensim training method on a tf trained model")

    def intersect_word2vec_format(self, *args, **kwargs):
        self._reject_training_call()
||
|
||
class Options(object):
    """Hyper-parameters for a TF word2vec run (replaces TF's FLAGS-based Options)."""

    def __init__(self, train_data=None, save_path=None, eval_data=None,
                 embedding_size=200, epochs_to_train=15, learning_rate=0.025,
                 num_neg_samples=25, batch_size=500, concurrent_steps=12,
                 window_size=5, min_count=5, subsample=1e-3):
        """
        train_data: Training data file, e.g. the unzipped file from
            http://mattmahoney.net/dc/text8.zip. Required.

        save_path: Directory to write the model to (skipped when None).

        eval_data: Analogy questions file for evaluation (skipped when None).

        embedding_size: The embedding dimension size.

        epochs_to_train: Number of epochs to train. Each epoch processes the
            training data once.

        learning_rate: Initial learning rate.

        num_neg_samples: Number of negative samples per training example.

        batch_size: Number of training examples each step processes.

        concurrent_steps: The number of concurrent training steps.

        window_size: The number of words to predict to the left and right of
            the target word.

        min_count: The minimum number of word occurrences for it to be
            included in the vocabulary.

        subsample: Subsample threshold for word occurrence. Words that appear
            with higher frequency will be randomly down-sampled. Set to 0 to
            disable.

        Raises ValueError when train_data is None.
        """
        if train_data is None:
            raise ValueError("train_data must be specified.")

        # Attribute names deliberately mirror tensorflow's
        # word2vec_optimized Options so the TF model can consume this object.

        # Embedding dimension.
        self.emb_dim = embedding_size

        # The training text file.
        self.train_data = train_data

        # Number of negative samples per example.
        self.num_samples = num_neg_samples

        # The initial learning rate.
        self.learning_rate = learning_rate

        # Number of epochs to train. After these many epochs, the learning
        # rate decays linearly to zero and the training stops.
        self.epochs_to_train = epochs_to_train

        # Concurrent training steps.
        self.concurrent_steps = concurrent_steps

        # Number of examples for one training step.
        self.batch_size = batch_size

        # The number of words to predict to the left and right of the target word.
        self.window_size = window_size

        # The minimum number of word occurrences for it to be included in the
        # vocabulary.
        self.min_count = min_count

        # Subsampling threshold for word occurrence.
        self.subsample = subsample

        # Where to write out summaries.
        self.save_path = save_path

        # The text file for eval.
        self.eval_data = eval_data
|
||
def modified_tfw2v_init(self, options, session):
    """Replacement constructor for TF's Word2Vec taking explicit arguments.

    TF's reference ``word2vec_optimized.Word2Vec.__init__`` reads global
    FLAGS; this version takes an :class:`Options` instance and an open TF
    session instead, and only writes the vocab / reads analogy data when the
    corresponding option is actually set.
    """
    # Underscored names match the TF implementation's own attributes so its
    # other methods (build_graph, train, eval, ...) keep working unchanged.
    self._options = options
    self._session = session
    self._word2id = {}
    self._id2word = []
    self.build_graph()
    self.build_eval_graph()
    if options.save_path is not None:
        self.save_vocab()
    if options.eval_data is not None:
        self._read_analogies()


# Monkey-patch: swap in the FLAGS-free constructor above so Word2Vec can be
# driven programmatically.
# NOTE(review): patching an external class is fragile and affects every user
# of tensorflow's Word2Vec in this process -- prefer subclassing if TF's
# constructor ever becomes parameterisable.
Word2Vec.__init__ = modified_tfw2v_init
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is this meant to do? Such magic -- monkey-patching outside modules -- has to be very carefully motivated and documented. And best avoided. |
||
|
||
|
||
class TfWord2Vec(GensimWord2VecNoTraining):
    """Word2vec model trained with tensorflow but queried through gensim.

    Training is delegated to tensorflow's optimized word2vec implementation;
    the learned embeddings are then copied into gensim-style attributes
    (``syn0``, ``syn0norm``, ``index2word``, ``vocab``) so gensim's
    similarity API keeps working on the result.
    """

    def __init__(self, train_data=None, save_path=None, eval_data=None,
                 embedding_size=200, epochs_to_train=15, learning_rate=0.025,
                 num_neg_samples=25, batch_size=500, concurrent_steps=12,
                 window_size=5, min_count=5, subsample=1e-3):
        self.options = Options(train_data, save_path=save_path, eval_data=eval_data,
                               embedding_size=embedding_size, epochs_to_train=epochs_to_train,
                               learning_rate=learning_rate, num_neg_samples=num_neg_samples,
                               batch_size=batch_size, concurrent_steps=concurrent_steps,
                               window_size=window_size, min_count=min_count, subsample=subsample)
        self.convert_input(train_data)
        self.train()
        self.vocab = {}
        self.create_vocab()

    def train(self):
        """Run TF training and copy the learned embeddings onto this model."""
        with tf.Graph().as_default(), tf.Session() as session:
            self.model = Word2Vec(self.options, session)
            # range instead of py2-only xrange; the epoch count is tiny.
            for _ in range(self.options.epochs_to_train):
                self.model.train()  # process one epoch
                if self.options.eval_data is not None:
                    self.model.eval()  # evaluate analogy questions
            # Fetch results while the session is still open.
            self.syn0 = self.model._w_in
            self.syn0norm = session.run(tf.nn.l2_normalize(self.model._w_in, 1))
            self.index2word = self.model._id2word

    def create_vocab(self):
        """Populate gensim-style ``self.vocab`` from the trained TF model."""
        # NOTE(review): vocab_words appears to be attached to the options
        # object by the TF model during training, not by our Options class --
        # confirm against tensorflow's word2vec_optimized.
        for word in self.options.vocab_words:
            self.vocab[word] = Vocab(index=self.model._word2id[word])

    def convert_input(self, corpus):
        """Convert a gensim corpus (iterable of token lists) to a text file
        consumable by TF word2vec.

        If ``corpus`` is already a string it is assumed to name a suitable
        pre-tokenised text file and is used as-is.
        """
        # string_types covers unicode filenames on python 2, not just str.
        if not isinstance(corpus, string_types):
            # Portable temp file instead of a hard-wired /tmp path.
            handle, path = tempfile.mkstemp(prefix="tf_w2v_corpus")
            os.close(handle)
            with utils.smart_open(path, 'wb') as fout:
                for sentence in corpus:
                    for word in sentence:
                        # to_utf8 copes with unicode tokens; str(word) would not.
                        fout.write(utils.to_utf8(word) + b" ")
                    fout.write(b"\n")
            self.options.train_data = path

    def __getitem__(self, words):
        """Return the L2-normalised vector(s) for a word or a list of words."""
        if isinstance(words, string_types):
            # allow calls like trained_model['office'], as a shorthand for trained_model[['office']]
            return self.syn0norm[self.model._word2id[words]]
        return [self.syn0norm[self.model._word2id[word]] for word in words]
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import gensim, os, time
from gensim.models.tfword2vec import TfWord2Vec
from gensim.models.word2vec import Word2Vec, Text8Corpus


def benchmark(label, build_model):
    """Build a model via build_model(), print elapsed wall-clock seconds."""
    start = time.time()
    model = build_model()
    # print(...) form is valid on both python 2 and 3.
    print(label + ":\n" + str(time.time() - start))
    return model


# Guard the runs so importing this module has no side effects.
if __name__ == "__main__":
    # TF benchmark
    benchmark("Tensorflow", lambda: TfWord2Vec("text8", epochs_to_train=5, batch_size=100))

    # Gensim benchmark
    corpus = Text8Corpus("text8")
    benchmark("Gensim", lambda: Word2Vec(corpus))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't belong in module scope -- libraries do not set up logging. That's up to applications that use them.