diff --git a/gensim/models/word2vec.py b/gensim/models/word2vec.py
index 50be048f4a..b4fab149aa 100755
--- a/gensim/models/word2vec.py
+++ b/gensim/models/word2vec.py
@@ -425,7 +425,8 @@ class Word2Vec(BaseWordEmbeddingsModel):
     def __init__(self, sentences=None, size=100, alpha=0.025, window=5, min_count=5,
                  max_vocab_size=None, sample=1e-3, seed=1, workers=3, min_alpha=0.0001,
                  sg=0, hs=0, negative=5, cbow_mean=1, hashfxn=hash, iter=5, null_word=0,
-                 trim_rule=None, sorted_vocab=1, batch_words=MAX_WORDS_IN_BATCH, compute_loss=False, callbacks=()):
+                 trim_rule=None, sorted_vocab=1, batch_words=MAX_WORDS_IN_BATCH, compute_loss=False, callbacks=(),
+                 max_final_vocab=None):
         """
         Initialize the model from an iterable of `sentences`. Each sentence is a
         list of words (unicode strings) that will be used for training.
@@ -462,6 +463,10 @@ def __init__(self, sentences=None, size=100, alpha=0.025, window=5, min_count=5,
             Limits the RAM during vocabulary building; if there are more unique
             words than this, then prune the infrequent ones. Every 10 million word types
             need about 1GB of RAM. Set to `None` for no limit.
+        max_final_vocab : int
+            Limits the vocab to a target vocab size by automatically picking a matching min_count. If the specified
+            min_count is more than the calculated min_count, the specified min_count will be used.
+            Set to `None` if not required.
         sample : float
             The threshold for configuring which higher-frequency words are randomly downsampled,
             useful range is (0, 1e-5).
@@ -510,6 +515,7 @@ def __init__(self, sentences=None, size=100, alpha=0.025, window=5, min_count=5,
            >>> say_vector = model['say']  # get vector for word

        """
+        self.max_final_vocab = max_final_vocab
        self.callbacks = callbacks
        self.load = call_on_class_only

@@ -517,7 +523,7 @@ def __init__(self, sentences=None, size=100, alpha=0.025, window=5, min_count=5,
        self.wv = Word2VecKeyedVectors(size)
        self.vocabulary = Word2VecVocab(
            max_vocab_size=max_vocab_size, min_count=min_count, sample=sample,
-            sorted_vocab=bool(sorted_vocab), null_word=null_word)
+            sorted_vocab=bool(sorted_vocab), null_word=null_word, max_final_vocab=max_final_vocab)
        self.trainables = Word2VecTrainables(seed=seed, vector_size=size, hashfxn=hashfxn)

        super(Word2Vec, self).__init__(
@@ -972,7 +978,14 @@ def load(cls, *args, **kwargs):
            Returns the loaded model as an instance of :class: `~gensim.models.word2vec.Word2Vec`.
        """
        try:
-            return super(Word2Vec, cls).load(*args, **kwargs)
+            model = super(Word2Vec, cls).load(*args, **kwargs)
+
+            # for backward compatibility for `max_final_vocab` feature
+            if not hasattr(model, 'max_final_vocab'):
+                model.max_final_vocab = None
+                model.vocabulary.max_final_vocab = None
+
+            return model
        except AttributeError:
            logger.info('Model saved using code from earlier Gensim Version. Re-loading old model in a compatible way.')
            from gensim.models.deprecated.word2vec import load_old_word2vec
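The `load` shim above follows a common backward-compatibility pattern for pickled models: attributes added after a model was saved are simply absent from the unpickled object, so they are defaulted on load. Below is a minimal standalone sketch of that pattern, assuming plain `pickle` serialization; the `Model` class and its `load` helper are illustrative only, not gensim's actual `SaveLoad` machinery.

```python
import pickle


class Model(object):
    """Toy model class that gained a `max_final_vocab` attribute in a later version."""

    def __init__(self, max_final_vocab=None):
        self.max_final_vocab = max_final_vocab

    @classmethod
    def load(cls, path):
        with open(path, 'rb') as f:
            model = pickle.load(f)
        # Pickles written before the attribute existed lack it entirely;
        # default it here so downstream code can rely on its presence.
        if not hasattr(model, 'max_final_vocab'):
            model.max_final_vocab = None
        return model
```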
@@ -1131,7 +1144,8 @@ def __iter__(self):


 class Word2VecVocab(utils.SaveLoad):
-    def __init__(self, max_vocab_size=None, min_count=5, sample=1e-3, sorted_vocab=True, null_word=0):
+    def __init__(self, max_vocab_size=None, min_count=5, sample=1e-3, sorted_vocab=True, null_word=0,
+                 max_final_vocab=None):
         self.max_vocab_size = max_vocab_size
         self.min_count = min_count
         self.sample = sample
@@ -1139,6 +1153,7 @@ def __init__(self, max_vocab_size=None, min_count=5, sample=1e-3, sorted_vocab=T
         self.null_word = null_word
         self.cum_table = None  # for negative sampling
         self.raw_vocab = None
+        self.max_final_vocab = max_final_vocab

     def scan_vocab(self, sentences, progress_per=10000, trim_rule=None):
         """Do an initial scan of all words appearing in sentences."""
@@ -1204,6 +1219,23 @@ def prepare_vocab(self, hs, negative, wv, update=False, keep_raw_vocab=False, tr
         sample = sample or self.sample
         drop_total = drop_unique = 0

+        # set effective_min_count to min_count in case max_final_vocab isn't set
+        self.effective_min_count = min_count
+
+        # if max_final_vocab is specified instead of min_count
+        # pick a min_count which satisfies max_final_vocab as well as possible
+        if self.max_final_vocab is not None:
+            sorted_vocab = sorted(self.raw_vocab.keys(), key=lambda word: self.raw_vocab[word], reverse=True)
+            calc_min_count = 1
+
+            if self.max_final_vocab < len(sorted_vocab):
+                calc_min_count = self.raw_vocab[sorted_vocab[self.max_final_vocab]] + 1
+
+            self.effective_min_count = max(calc_min_count, min_count)
+            logger.info("max_final_vocab=%d and min_count=%d resulted in calc_min_count=%d, effective_min_count=%d",
+                self.max_final_vocab, min_count, calc_min_count, self.effective_min_count
+            )
+
         if not update:
             logger.info("Loading a fresh vocabulary")
             retain_total, retain_words = 0, []
@@ -1216,7 +1248,7 @@ def prepare_vocab(self, hs, negative, wv, update=False, keep_raw_vocab=False, tr
                 wv.vocab = {}

             for word, v in iteritems(self.raw_vocab):
-                if keep_vocab_item(word, v, min_count, trim_rule=trim_rule):
+                if keep_vocab_item(word, v, self.effective_min_count, trim_rule=trim_rule):
                     retain_words.append(word)
                     retain_total += v
                     if not dry_run:
@@ -1228,21 +1260,21 @@ def prepare_vocab(self, hs, negative, wv, update=False, keep_raw_vocab=False, tr
             original_unique_total = len(retain_words) + drop_unique
             retain_unique_pct = len(retain_words) * 100 / max(original_unique_total, 1)
             logger.info(
-                "min_count=%d retains %i unique words (%i%% of original %i, drops %i)",
-                min_count, len(retain_words), retain_unique_pct, original_unique_total, drop_unique
+                "effective_min_count=%d retains %i unique words (%i%% of original %i, drops %i)",
+                self.effective_min_count, len(retain_words), retain_unique_pct, original_unique_total, drop_unique
             )
             original_total = retain_total + drop_total
             retain_pct = retain_total * 100 / max(original_total, 1)
             logger.info(
-                "min_count=%d leaves %i word corpus (%i%% of original %i, drops %i)",
-                min_count, retain_total, retain_pct, original_total, drop_total
+                "effective_min_count=%d leaves %i word corpus (%i%% of original %i, drops %i)",
+                self.effective_min_count, retain_total, retain_pct, original_total, drop_total
             )
         else:
             logger.info("Updating model with new vocabulary")
             new_total = pre_exist_total = 0
             new_words = pre_exist_words = []
             for word, v in iteritems(self.raw_vocab):
-                if keep_vocab_item(word, v, min_count, trim_rule=trim_rule):
+                if keep_vocab_item(word, v, self.effective_min_count, trim_rule=trim_rule):
                     if word in wv.vocab:
                         pre_exist_words.append(word)
                         pre_exist_total += v
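The `prepare_vocab` change above is the core of the feature: when `max_final_vocab` is set, the raw vocabulary is sorted by descending frequency, the count of the first word past the cap (plus one) becomes the candidate threshold, and the effective threshold is the stricter of that value and the user-specified `min_count`. The following is a standalone sketch of the same calculation, assuming `raw_vocab` is a plain `dict` of word counts; `pick_effective_min_count` is a hypothetical helper name, not part of gensim.

```python
def pick_effective_min_count(raw_vocab, max_final_vocab, min_count):
    """Choose a frequency threshold that keeps at most `max_final_vocab`
    words, but never one less restrictive than the given `min_count`."""
    if max_final_vocab is None:
        return min_count
    # Words ordered from most to least frequent.
    sorted_words = sorted(raw_vocab, key=raw_vocab.get, reverse=True)
    calc_min_count = 1
    if max_final_vocab < len(sorted_words):
        # One more than the count of the first word past the cap, so every
        # retained word is strictly more frequent than that word.
        calc_min_count = raw_vocab[sorted_words[max_final_vocab]] + 1
    return max(calc_min_count, min_count)


counts = {'the': 10, 'cat': 4, 'sat': 4, 'mat': 1}
print(pick_effective_min_count(counts, max_final_vocab=2, min_count=1))  # 5
```

Note that all ties at the cutoff frequency are dropped (retention keeps words with `count >= threshold`), so the final vocabulary can end up smaller than `max_final_vocab`: the cap is an upper bound, not an exact size. In the example above, both four-count words are dropped and only one word survives a cap of two.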
diff --git a/gensim/test/test_data/word2vec_3.3 b/gensim/test/test_data/word2vec_3.3
new file mode 100644
index 0000000000..17f869bc4a
Binary files /dev/null and b/gensim/test/test_data/word2vec_3.3 differ
diff --git a/gensim/test/test_word2vec.py b/gensim/test/test_word2vec.py
index d1ec2703e9..9641a332a1 100644
--- a/gensim/test/test_word2vec.py
+++ b/gensim/test/test_word2vec.py
@@ -145,6 +145,27 @@ def testTotalWordCount(self):
         total_words = model.vocabulary.scan_vocab(sentences)[0]
         self.assertEqual(total_words, 29)

+    def testMaxFinalVocab(self):
+        # Test the less restrictive case: max_final_vocab is specified
+        # but has no effect, since min_count is already stricter
+        model = word2vec.Word2Vec(size=10, max_final_vocab=4, min_count=4, sample=0)
+        model.vocabulary.scan_vocab(sentences)
+        reported_values = model.vocabulary.prepare_vocab(wv=model.wv, hs=0, negative=0)
+        self.assertEqual(reported_values['drop_unique'], 11)
+        self.assertEqual(reported_values['retain_total'], 4)
+        self.assertEqual(reported_values['num_retained_words'], 1)
+        self.assertEqual(model.vocabulary.effective_min_count, 4)
+
+        # Test the more restrictive case: max_final_vocab results in an
+        # effective min_count stricter than the specified min_count
+        model = word2vec.Word2Vec(size=10, max_final_vocab=4, min_count=2, sample=0)
+        model.vocabulary.scan_vocab(sentences)
+        reported_values = model.vocabulary.prepare_vocab(wv=model.wv, hs=0, negative=0)
+        self.assertEqual(reported_values['drop_unique'], 8)
+        self.assertEqual(reported_values['retain_total'], 13)
+        self.assertEqual(reported_values['num_retained_words'], 4)
+        self.assertEqual(model.vocabulary.effective_min_count, 3)
+
     def testOnlineLearning(self):
         """Test that the algorithm is able to add new words to the
         vocabulary and to a trained model when using a sorted vocabulary"""
@@ -753,6 +774,12 @@ def testLoadOldModel(self):
         self.assertTrue(model.trainables.vectors_lockf.shape == (12,))
         self.assertTrue(model.vocabulary.cum_table.shape == (12,))

+        # test max_final_vocab backward compatibility for a model saved with Gensim 3.3
+        model_file = 'word2vec_3.3'
+        model = word2vec.Word2Vec.load(datapath(model_file))
+        self.assertEqual(model.max_final_vocab, None)
+        self.assertEqual(model.vocabulary.max_final_vocab, None)
+
     @log_capture()
     def testBuildVocabWarning(self, l):
         """Test if warning is raised on non-ideal input to a word2vec model"""
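For context, this is how the new parameter might be used once the patch is applied; the toy corpus below is made up, `size` and `min_count` are pre-existing parameters, and `max_final_vocab` is the one added by this diff.

```python
from gensim.models import Word2Vec

sentences = [
    ['human', 'interface', 'computer'],
    ['survey', 'user', 'computer', 'system', 'response', 'time'],
    ['graph', 'minors', 'survey'],
]

# Keep at most 10 words in the final vocabulary; a matching min_count
# is derived automatically and exposed after vocabulary building as
# model.vocabulary.effective_min_count.
model = Word2Vec(sentences, size=10, min_count=1, max_final_vocab=10)
print(model.vocabulary.effective_min_count)
```

Passing both parameters is allowed; whichever threshold is stricter wins, which is exactly the behavior exercised by the two halves of `testMaxFinalVocab` above.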