Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Fix backward incompatibility due to random_state #1327

Merged
merged 11 commits into from
May 25, 2017
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions gensim/models/ldamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,11 @@ def save(self, fname, ignore=['state', 'dispatcher'], separately=None, *args, **
"""
if self.state is not None:
self.state.save(utils.smart_extension(fname, '.state'), *args, **kwargs)

# Save 'random_state' separately
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the purpose of saving it separately?

Copy link
Contributor

@menshikh-iv menshikh-iv May 17, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tmylk I think It's better for case "file don't exist" (and more flexible)

if self.random_state is not None:
utils.pickle(self.random_state, utils.smart_extension(fname, '.random_state'))

# Save the dictionary separately if not in 'ignore'.
if 'id2word' not in ignore:
utils.pickle(self.id2word, utils.smart_extension(fname, '.id2word'))
Expand All @@ -1023,9 +1028,9 @@ def save(self, fname, ignore=['state', 'dispatcher'], separately=None, *args, **
separately_explicit = ['expElogbeta', 'sstats']
# Also add 'alpha' and 'eta' to separately list if they are set 'auto' or some
# array manually.
if (isinstance(self.alpha, six.string_types) and self.alpha == 'auto') or len(self.alpha.shape) != 1:
if (isinstance(self.alpha, six.string_types) and self.alpha == 'auto') or (isinstance(self.alpha, np.ndarray) and len(self.alpha.shape) != 1):
separately_explicit.append('alpha')
if (isinstance(self.eta, six.string_types) and self.eta == 'auto') or len(self.eta.shape) != 1:
if (isinstance(self.eta, six.string_types) and self.eta == 'auto') or (isinstance(self.eta, np.ndarray) and len(self.eta.shape) != 1):
separately_explicit.append('eta')
# Merge separately_explicit with separately.
if separately:
Expand Down Expand Up @@ -1054,13 +1059,22 @@ def load(cls, fname, *args, **kwargs):
result.state = super(LdaModel, cls).load(state_fname, *args, **kwargs)
except Exception as e:
logging.warning("failed to load state from %s: %s", state_fname, e)
id2word_fname = utils.smart_extension(fname, '.id2word')
if (os.path.isfile(id2word_fname)):
try:
result.id2word = utils.unpickle(id2word_fname)
except Exception as e:
logging.warning("failed to load id2word dictionary from %s: %s", id2word_fname, e)

random_state_fname = utils.smart_extension(fname, '.random_state')
if (os.path.isfile(random_state_fname)):
result.random_state = utils.unpickle(random_state_fname)
else:
result.id2word = None
logging.warning("random_state not stored on disk so using default value")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if random_state is not stored that means that it is a new version of the model and it is going to be loaded in the main pickle load. Please change the logic

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tmylk Your suggestion is a bit different from the earlier discussion on the issue. Hence, I want to make sure I understand the desired solution before making the changes.

If I understand it correctly, what you are suggesting is :

  • If there is indeed a file with the extension .random_state present on disk, this means that the model was saved using a pre-0.13.2 version of Gensim. So we use this file to set result.random_state at the time of loading.
  • However, if there is no such file present on disk, then this means that the model was saved using a post-0.13.2 version of Gensim and thus result.random_state got set at the time of the main pickle load. So in this case, we don't need to do anything else.

But for models saved using a pre-0.13.2 version of Gensim, there was no .random_state file created at the time of saving the model. So while loading such a model from disk, where would the .random_state file come from in this case? Is the user responsible for creating this file explicitly in such a case? If this is true, then I believe we don't need to make any changes in the save function for LdaModel at all and we just need to change the load function.

Please correct me if I am wrong or missing something here. Otherwise, if this is indeed what we need, I could go ahead and make the appropriate changes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a check that this indeed happened: thus result.random_state got set at the time of the main pickle load.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tmylk I believe I have understood the solution that we want. However, I have a minor doubt about where would the .random_state file come from? Is it that the user is responsible for creating it explicitly always and we (i.e. from within the save function) need not create it ever?
If this is true, then in case we are loading a pre-0.13.2 model and no .random_state file exists on disk, then should we set result.random_state using a default value like get_random_state(None) with a logger.warning?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tmylk Could you please respond to this query?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it that the user is responsible for creating it explicitly always

I think the user should not think about additional files, he saves the whole model

If this is true, then in case we are loading a pre-0.13.2 model and no .random_state file exists on disk, then should we set result.random_state using a default value like get_random_state(None) with a logger.warning?

I think it's correct

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The simple way to do this is check that random_state was loaded, if not - you set result.random_state using a default value like get_random_state(None) with a logger.warning

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@menshikh-iv Then there is no .random_state file involved at all, correct? To summarize, the solution is :

  • First, load the entire model.
  • Check if result.random_state was set or not. For the newer (post 0.13.2) models, it would have been set through the main pickle load. For the older (pre 0.13.2) models, result.random_state would not be set through the main pickle load so we set result.random_state to get_random_state(None).

And in this solution, we don't need to make any changes in the save function, just the load function.

result.random_state = utils.get_random_state(None)

if not result.id2word:
id2word_fname = utils.smart_extension(fname, '.id2word')
if (os.path.isfile(id2word_fname)):
try:
result.id2word = utils.unpickle(id2word_fname)
except Exception as e:
logging.warning("failed to load id2word dictionary from %s: %s", id2word_fname, e)
else:
result.id2word = None
return result
# endclass LdaModel
Binary file added gensim/test/test_data/pre_0_13_2_model
Binary file not shown.
Binary file added gensim/test/test_data/pre_0_13_2_model.state
Binary file not shown.
33 changes: 28 additions & 5 deletions gensim/test/test_ldamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@

def testfile(test_fname=''):
# temporary data will be stored to this file
fname = 'gensim_models_' + test_fname + '.tst'
fname = 'gensim_models_' + test_fname + '.tst'
return os.path.join(tempfile.gettempdir(), fname)


Expand Down Expand Up @@ -247,9 +247,9 @@ def testGetDocumentTopics(self):

#Test case to use the get_document_topic function for the corpus
all_topics = model.get_document_topics(self.corpus, per_word_topics=True)

self.assertEqual(model.state.numdocs, len(corpus))

for topic in all_topics:
self.assertTrue(isinstance(topic, tuple))
for k, v in topic[0]: # list of doc_topics
Expand All @@ -269,9 +269,9 @@ def testGetDocumentTopics(self):
word_phi_count_na = 0

all_topics = model.get_document_topics(self.corpus, minimum_probability=0.8, minimum_phi_value=1.0, per_word_topics=True)

self.assertEqual(model.state.numdocs, len(corpus))

for topic in all_topics:
self.assertTrue(isinstance(topic, tuple))
for k, v in topic[0]: # list of doc_topics
Expand Down Expand Up @@ -470,6 +470,29 @@ def testLargeMmapCompressed(self):
# test loading the large model arrays with mmap
self.assertRaises(IOError, self.class_.load, fname, mmap='r')

def testId2WordBackwardCompatibility(self):
# load a model saved using a pre-0.13.2 version of Gensim
pre_0_13_2_fname = datapath('pre_0_13_2_model')
model_pre_0_13_2 = self.class_.load(pre_0_13_2_fname)

model_topics = model_pre_0_13_2.print_topics(num_topics=3, num_words=3)

for i in model_topics:
self.assertTrue(isinstance(i[0], int))
self.assertTrue(isinstance(i[1], six.string_types))

def testRandomStateBackwardCompatibility(self):
# load a model saved using a pre-0.13.2 version of Gensim
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this test is identical to the previous one. just one test is enough that checks all the fields

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. I thought since these were two different issues so we'd want to put separate tests to verify both are resolved. I'll make the update according to your suggestion and remove the earlier test here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to avoid code duplication

pre_0_13_2_fname = datapath('pre_0_13_2_model')
model_pre_0_13_2 = self.class_.load(pre_0_13_2_fname)

# set `num_topics` less than `model_pre_0_13_2.num_topics` so that `model_pre_0_13_2.random_state` is used
model_topics = model_pre_0_13_2.print_topics(num_topics=2, num_words=3)

for i in model_topics:
self.assertTrue(isinstance(i[0], int))
self.assertTrue(isinstance(i[1], six.string_types))

#endclass TestLdaModel


Expand Down