-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Fix backward incompatibility due to random_state
#1327
Changes from 6 commits
43d3f00
addb8c7
0edd447
45152e9
ba5bfb8
81d3b3f
562f959
8687b17
129bd8e
2010559
c7194c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1004,6 +1004,11 @@ def save(self, fname, ignore=['state', 'dispatcher'], separately=None, *args, ** | |
""" | ||
if self.state is not None: | ||
self.state.save(utils.smart_extension(fname, '.state'), *args, **kwargs) | ||
|
||
# Save 'random_state' separately | ||
if self.random_state is not None: | ||
utils.pickle(self.random_state, utils.smart_extension(fname, '.random_state')) | ||
|
||
# Save the dictionary separately if not in 'ignore'. | ||
if 'id2word' not in ignore: | ||
utils.pickle(self.id2word, utils.smart_extension(fname, '.id2word')) | ||
|
@@ -1023,9 +1028,9 @@ def save(self, fname, ignore=['state', 'dispatcher'], separately=None, *args, ** | |
separately_explicit = ['expElogbeta', 'sstats'] | ||
# Also add 'alpha' and 'eta' to separately list if they are set 'auto' or some | ||
# array manually. | ||
if (isinstance(self.alpha, six.string_types) and self.alpha == 'auto') or len(self.alpha.shape) != 1: | ||
if (isinstance(self.alpha, six.string_types) and self.alpha == 'auto') or (isinstance(self.alpha, np.ndarray) and len(self.alpha.shape) != 1): | ||
separately_explicit.append('alpha') | ||
if (isinstance(self.eta, six.string_types) and self.eta == 'auto') or len(self.eta.shape) != 1: | ||
if (isinstance(self.eta, six.string_types) and self.eta == 'auto') or (isinstance(self.eta, np.ndarray) and len(self.eta.shape) != 1): | ||
separately_explicit.append('eta') | ||
# Merge separately_explicit with separately. | ||
if separately: | ||
|
@@ -1054,13 +1059,22 @@ def load(cls, fname, *args, **kwargs): | |
result.state = super(LdaModel, cls).load(state_fname, *args, **kwargs) | ||
except Exception as e: | ||
logging.warning("failed to load state from %s: %s", state_fname, e) | ||
id2word_fname = utils.smart_extension(fname, '.id2word') | ||
if (os.path.isfile(id2word_fname)): | ||
try: | ||
result.id2word = utils.unpickle(id2word_fname) | ||
except Exception as e: | ||
logging.warning("failed to load id2word dictionary from %s: %s", id2word_fname, e) | ||
|
||
random_state_fname = utils.smart_extension(fname, '.random_state') | ||
if (os.path.isfile(random_state_fname)): | ||
result.random_state = utils.unpickle(random_state_fname) | ||
else: | ||
result.id2word = None | ||
logging.warning("random_state not stored on disk so using default value") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If random_state is not stored, that means that it is a new version of the model and it is going to be loaded in the main pickle load. Please change the logic. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @tmylk Your suggestion is a bit different from the earlier discussion on the issue. Hence, I want to make sure I understand the desired solution before making the changes. If I understand it correctly, what you are suggesting is:
But for models saved using a pre-0.13.2 version of Gensim, there was no Please correct me if I am wrong or missing something here. Otherwise, if this is indeed what we need, I could go ahead and make the appropriate changes. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add a check that this indeed happened: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @tmylk I believe I have understood the solution that we want. However, I have a minor doubt about where would the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @tmylk Could you please respond to this query? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think the user should not have to think about additional files; they save the whole model.
I think it's correct There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The simple way to do this is check that random_state was loaded, if not - you There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @menshikh-iv Then there is no
And in this solution, we don't need to make any changes in the |
||
result.random_state = utils.get_random_state(None) | ||
|
||
if not result.id2word: | ||
id2word_fname = utils.smart_extension(fname, '.id2word') | ||
if (os.path.isfile(id2word_fname)): | ||
try: | ||
result.id2word = utils.unpickle(id2word_fname) | ||
except Exception as e: | ||
logging.warning("failed to load id2word dictionary from %s: %s", id2word_fname, e) | ||
else: | ||
result.id2word = None | ||
return result | ||
# endclass LdaModel |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,7 +46,7 @@ | |
|
||
def testfile(test_fname=''): | ||
# temporary data will be stored to this file | ||
fname = 'gensim_models_' + test_fname + '.tst' | ||
fname = 'gensim_models_' + test_fname + '.tst' | ||
return os.path.join(tempfile.gettempdir(), fname) | ||
|
||
|
||
|
@@ -247,9 +247,9 @@ def testGetDocumentTopics(self): | |
|
||
#Test case to use the get_document_topic function for the corpus | ||
all_topics = model.get_document_topics(self.corpus, per_word_topics=True) | ||
|
||
self.assertEqual(model.state.numdocs, len(corpus)) | ||
|
||
for topic in all_topics: | ||
self.assertTrue(isinstance(topic, tuple)) | ||
for k, v in topic[0]: # list of doc_topics | ||
|
@@ -269,9 +269,9 @@ def testGetDocumentTopics(self): | |
word_phi_count_na = 0 | ||
|
||
all_topics = model.get_document_topics(self.corpus, minimum_probability=0.8, minimum_phi_value=1.0, per_word_topics=True) | ||
|
||
self.assertEqual(model.state.numdocs, len(corpus)) | ||
|
||
for topic in all_topics: | ||
self.assertTrue(isinstance(topic, tuple)) | ||
for k, v in topic[0]: # list of doc_topics | ||
|
@@ -470,6 +470,29 @@ def testLargeMmapCompressed(self): | |
# test loading the large model arrays with mmap | ||
self.assertRaises(IOError, self.class_.load, fname, mmap='r') | ||
|
||
def testId2WordBackwardCompatibility(self): | ||
# load a model saved using a pre-0.13.2 version of Gensim | ||
pre_0_13_2_fname = datapath('pre_0_13_2_model') | ||
model_pre_0_13_2 = self.class_.load(pre_0_13_2_fname) | ||
|
||
model_topics = model_pre_0_13_2.print_topics(num_topics=3, num_words=3) | ||
|
||
for i in model_topics: | ||
self.assertTrue(isinstance(i[0], int)) | ||
self.assertTrue(isinstance(i[1], six.string_types)) | ||
|
||
def testRandomStateBackwardCompatibility(self): | ||
# load a model saved using a pre-0.13.2 version of Gensim | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test is identical to the previous one. A single test that checks all the fields is enough. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure. I thought that since these were two different issues, we'd want separate tests to verify both are resolved. I'll make the update according to your suggestion and remove the earlier test here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Better to avoid code duplication. |
||
pre_0_13_2_fname = datapath('pre_0_13_2_model') | ||
model_pre_0_13_2 = self.class_.load(pre_0_13_2_fname) | ||
|
||
# set `num_topics` less than `model_pre_0_13_2.num_topics` so that `model_pre_0_13_2.random_state` is used | ||
model_topics = model_pre_0_13_2.print_topics(num_topics=2, num_words=3) | ||
|
||
for i in model_topics: | ||
self.assertTrue(isinstance(i[0], int)) | ||
self.assertTrue(isinstance(i[1], six.string_types)) | ||
|
||
#endclass TestLdaModel | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the purpose of saving it separately?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tmylk I think it's better for the case where the file doesn't exist (and it's more flexible).