
Loading and Saving LDA Models across Python 2 and 3. #913

Closed
wants to merge 9 commits
51 changes: 49 additions & 2 deletions gensim/models/ldamodel.py
@@ -43,6 +43,7 @@
from scipy.special import polygamma
from six.moves import xrange
import six
import json

# log(sum(exp(x))) that tries to avoid overflow
try:
@@ -979,7 +980,7 @@ def __getitem__(self, bow, eps=None):
"""
return self.get_document_topics(bow, eps)

def save(self, fname, ignore=['state', 'dispatcher'], *args, **kwargs):
def save(self, fname, ignore=['state', 'dispatcher'], separately = None, *args, **kwargs):
"""
Save the model to file.

@@ -1018,7 +1019,41 @@ def save(self, fname, ignore=['state', 'dispatcher'], *args, **kwargs):
ignore = list(set(['state', 'dispatcher']) | set(ignore))
else:
ignore = ['state', 'dispatcher']
super(LdaModel, self).save(fname, *args, ignore=ignore, **kwargs)

# Make sure 'expElogbeta' and 'sstats' are kept out of the pickled object, even if
# the caller sets the separately list themselves.
separately_explicit = ['expElogbeta', 'sstats']
# Also add 'alpha' and 'eta' to the separately list if they are set to 'auto' or to an
# array manually.
if (isinstance(self.alpha, six.string_types) and self.alpha == 'auto') or len(self.alpha.shape) != 1:
separately_explicit.append('alpha')
if (isinstance(self.eta, six.string_types) and self.eta == 'auto') or len(self.eta.shape) != 1:
separately_explicit.append('eta')
# Merge separately_explicit with separately.
if separately:
if isinstance(separately, six.string_types):
separately = [separately]
separately = [e for e in separately if e] # make sure None and '' are not in the list
separately = list(set(separately_explicit) | set(separately))
else:
separately = separately_explicit
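
A tiny standalone illustration of the merge above; the caller-supplied value is hypothetical and this snippet is not part of the diff:

separately_explicit = ['expElogbeta', 'sstats', 'alpha']    # e.g. when alpha == 'auto'
caller_separately = ['eta', None, '']                       # hypothetical user-supplied value
merged = list(set(separately_explicit) | set(e for e in caller_separately if e))
print(sorted(merged))  # ['alpha', 'eta', 'expElogbeta', 'sstats']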

# id2word needs to be saved separately.
# If id2word is not already in ignore, save it separately as JSON.
id2word = None
if self.id2word is not None and 'id2word' not in ignore:
id2word = dict((k,v) for k,v in self.id2word.iteritems())
Owner: PEP8: space after comma.

self.id2word = None # remove the dictionary from model
super(LdaModel, self).save(fname, ignore=ignore, separately = separately, *args, **kwargs)
self.id2word = id2word # restore the dictionary.

# Save the dictionary separately in json.
id2word_fname = utils.smart_extension(fname, '.json')
try:
with utils.smart_open(id2word_fname, 'w', encoding='utf-8') as fout:
json.dump(id2word, fout)
Owner: Better open the output as binary and write encoded utf8 to it.
Actually, the json module already produces binary strings in dump AFAIK, so what is this even for?

except Exception as e:
logging.warning("failed to save id2words dictionary in %s: %s", id2word_fname, e)

@classmethod
def load(cls, fname, *args, **kwargs):
@@ -1032,6 +1067,18 @@ def load(cls, fname, *args, **kwargs):
"""
kwargs['mmap'] = kwargs.get('mmap', None)
result = super(LdaModel, cls).load(fname, *args, **kwargs)
# Load the separately stored id2word dictionary saved in json.
id2word_fname = utils.smart_extension(fname, '.json')
Contributor: Please put all the files for one model in a special folder, so it is easy to keep track of them.

try:
with utils.smart_open(id2word_fname, 'r') as fin:
Owner: Open the file as binary, decode if necessary.

id2word = json.load(fin)
if id2word is not None:
result.id2word = utils.FakeDict(id2word)
else:
result.id2word = None
except Exception as e:
logging.warning("failed to load id2words from %s: %s", id2word_fname, e)

state_fname = utils.smart_extension(fname, '.state')
try:
result.state = super(LdaModel, cls).load(state_fname, *args, **kwargs)
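
Taken together with the save() changes, the intended cross-version round trip looks roughly like this hedged sketch; the path and toy corpus are illustrative, and it mirrors the tests added below:

import os
from gensim.corpora import Dictionary
from gensim.models import ldamodel

texts = [['human', 'interface', 'computer'], ['graph', 'trees', 'minors']]
dictionary = Dictionary(texts)
corpus = [dictionary.doc2bow(text) for text in texts]

fname = os.path.join('/tmp', 'ldamodel_cross_version')   # illustrative path
ldamodel.LdaModel(corpus, id2word=dictionary, num_topics=2).save(fname)
# Alongside fname this writes sidecar files such as fname + '.state',
# fname + '.expElogbeta.npy' and fname + '.json' (compare the fixtures below).

# Later, possibly under the other Python version:
model = ldamodel.LdaModel.load(fname)
print(model.num_topics)
print(model[[]])   # project an empty vector, as in the tests
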
Binary file added gensim/test/ldamodel_python_2_7
Binary file not shown.
Binary file added gensim/test/ldamodel_python_2_7.eta.npy
Binary file not shown.
Binary file added gensim/test/ldamodel_python_2_7.expElogbeta.npy
Binary file not shown.
1 change: 1 addition & 0 deletions gensim/test/ldamodel_python_2_7.json
@@ -0,0 +1 @@
{"0": "interface", "1": "computer", "2": "human", "3": "response", "4": "time", "5": "survey", "6": "system", "7": "user", "8": "eps", "9": "trees", "10": "graph", "11": "minors"}
Binary file added gensim/test/ldamodel_python_2_7.state
Binary file not shown.
Binary file added gensim/test/ldamodel_python_3_5
Binary file not shown.
Binary file added gensim/test/ldamodel_python_3_5.eta.npy
Binary file not shown.
Binary file added gensim/test/ldamodel_python_3_5.expElogbeta.npy
Binary file not shown.
1 change: 1 addition & 0 deletions gensim/test/ldamodel_python_3_5.json
@@ -0,0 +1 @@
{"0": "interface", "1": "human", "2": "computer", "3": "response", "4": "system", "5": "user", "6": "time", "7": "survey", "8": "eps", "9": "trees", "10": "graph", "11": "minors"}
Binary file added gensim/test/ldamodel_python_3_5.state
Binary file not shown.
18 changes: 18 additions & 0 deletions gensim/test/test_ldamodel.py
@@ -366,6 +366,24 @@ def testPersistence(self):
tstvec = []
self.assertTrue(np.allclose(model[tstvec], model2[tstvec])) # try projecting an empty vector

# Method used to save LDA models in Python 2.7 and 3.5 environments.
# def testSaveModelsForPythonVersion(self):
# fname = os.path.join(os.path.dirname(__file__), 'ldamodel_python_2_7')
# corpus = mmcorpus.MmCorpus(datapath('testcorpus.mm'))
# model = ldamodel.LdaModel(corpus, id2word=dictionary, num_topics=2, passes=100, random_state = 1000007)
# model.save(fname)
# logging.warning("LDA Model saved")

def testModelCompatibilityWithPythonVersions(self):
fname_model_2_7 = os.path.join(os.path.dirname(__file__), 'ldamodel_python_2_7')
model_2_7 = self.class_.load(fname_model_2_7)
fname_model_3_5 = os.path.join(os.path.dirname(__file__), 'ldamodel_python_3_5')
model_3_5 = self.class_.load(fname_model_3_5)
self.assertEqual(model_2_7.num_topics, model_3_5.num_topics)
self.assertTrue(numpy.allclose(model_2_7.expElogbeta, model_3_5.expElogbeta))
tstvec = []
self.assertTrue(numpy.allclose(model_2_7[tstvec], model_3_5[tstvec])) # try projecting an empty vector

def testPersistenceIgnore(self):
fname = testfile()
model = ldamodel.LdaModel(self.corpus, num_topics=2)
13 changes: 13 additions & 0 deletions gensim/test/test_word2vec.py
@@ -233,6 +233,19 @@ def testPersistenceWord2VecFormatCombinationWithStandardPersistence(self):
binary_model_with_vocab.save(testfile())
binary_model_with_vocab = word2vec.Word2Vec.load(testfile())
self.assertEqual(model.vocab['human'].count, binary_model_with_vocab.vocab['human'].count)

# def testSaveModelsForPythonVersion(self):
# fname = os.path.join(os.path.dirname(__file__), 'word2vecmodel_python_3_5')
# model = word2vec.Word2Vec(sentences, size=10, min_count=0, seed=42, hs=1, negative=0)
# model.save(fname)
# logging.warning("Word2Vec model saved")

def testModelCompatibilityWithPythonVersions(self):
fname_model_2_7 = os.path.join(os.path.dirname(__file__), 'word2vecmodel_python_2_7')
Owner: Use module_path and datapath defined above.

model_2_7 = word2vec.Word2Vec.load(fname_model_2_7)
fname_model_3_5 = os.path.join(os.path.dirname(__file__), 'word2vecmodel_python_3_5')
model_3_5 = word2vec.Word2Vec.load(fname_model_3_5)
self.models_equal(model_2_7, model_3_5)

def testLargeMmap(self):
"""Test storing/loading the entire model."""
Binary file added gensim/test/word2vecmodel_python_2_7
Binary file not shown.
Binary file added gensim/test/word2vecmodel_python_3_5
Binary file not shown.
8 changes: 5 additions & 3 deletions gensim/utils.py
@@ -907,10 +907,12 @@ def pickle(obj, fname, protocol=2):

def unpickle(fname):
"""Load pickled object from `fname`"""
with smart_open(fname) as f:
with smart_open(fname, 'rb') as f:
# Because of loading from S3, load() can't be used (smart_open file objects lack readline)
return _pickle.loads(f.read())

if sys.version_info > (3,0):
Owner: PEP8: space after comma.

return _pickle.load(f, encoding='latin1')
else:
return _pickle.loads(f.read())
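
For context on the encoding='latin1' branch: pickles written by Python 2 contain byte strings (including numpy array buffers) that Python 3's unpickler refuses to decode with its default ASCII encoding; 'latin1' maps all 256 byte values one-to-one, which is the usual workaround. A hedged standalone sketch (file name hypothetical); note that pickle.loads accepts the same keyword on Python 3:

import pickle

# Load an object pickled under Python 2 while running Python 3 (illustrative file name).
with open('model_saved_by_python2.pkl', 'rb') as f:
    obj = pickle.loads(f.read(), encoding='latin1')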

def revdict(d):
"""
Expand Down