Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Added sklearn wrapper for LDASeq model #1405

Merged

Conversation

chinmayapancholi13
Copy link
Contributor

This PR adds a scikit-learn wrapper for Gensim's LDASeq model.

"""
Sklearn wrapper for LdaSeq model. Class derived from gensim.models.LdaSeqModel
"""
self.corpus = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why you needed a field for a corpus?

Copy link
Contributor Author

@chinmayapancholi13 chinmayapancholi13 Jun 14, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@menshikh-iv In my opinion, the user might be interested to know about the corpus used for training the model (using the get_params function). Should we continue to store this value?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chinmayapancholi13 No, sklearn does not store X, so we should not

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@menshikh-iv Yes, that is true for sklearn. Removing corpus attribute from all the wrappers then.

Sklearn wrapper for LdaSeq model. Class derived from gensim.models.LdaSeqModel
"""
self.corpus = None
self.model = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please do this field "private" (start with underscores)

initialize='gensim', sstats=None, lda_model=None, obs_variance=0.5, chain_variance=0.005, passes=10,
random_state=None, lda_inference_max_iter=25, em_min_iter=6, em_max_iter=20, chunksize=100)
"""
self.corpus = X
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't need to save X.

"""
Fit the model according to the given training data.
Calls gensim.models.LdaSeqModel:
>>> gensim.models.LdaSeqModel(corpus=None, time_slice=None, id2word=None, alphas=0.01, num_topics=10,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove this block >>> ... , this example does not help for a new user.

Copy link
Contributor Author

@chinmayapancholi13 chinmayapancholi13 Jun 14, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@menshikh-iv Should we remove this >>> .... statement in all the model wrappers? This line basically tells us how the associated Gensim model is actually called.

Copy link
Contributor

@menshikh-iv menshikh-iv Jun 14, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You just need to specify the class that is used (you have already done above) and write where a user can read the documentation.

em_min_iter=self.em_min_iter, em_max_iter=self.em_max_iter, chunksize=self.chunksize)
return self

def transform(self, docs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chek case, when you create instance and call transform immediately (without fit), you need to raise exception like sklearn

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, please add an example of docs param in docstring.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@menshikh-iv For checking if the model has been fitted, would it be a good idea to check if self.gensim_model is None or not? This approach would clearly give an error when fit hasn't been called before calling transform but this also allows the user to set the value of self.gensim_model through set_params function (or even as wrapper.gensim_model=...) and then call transform function, which makes sense for us to allow.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I completely forgot about set_param, so, I think if you disable gensim_model in set_param, you can check model is None (it does not cover all cases, but covers the most obvious)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you elaborate the meaning of "disabling" gensim_model param from the function set_params?
Actually, gensim_model is a public attribute of the model so it can be set like ldaseq_wrapper.gensim_model = some_model, which is almost the same as using set_params function to set this value. So, checking whether self.gensim_model is None should be enough, right?
This would be like :

    def transform(self, docs):
        """
        Return the topic proportions for the documents passed.
        """
        if self.gensim_model is None:
            raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.")

        # The input as array of array
        check = lambda x: [x] if isinstance(x[0], tuple) else x
        ..........................................................................
        ..........................................................................
        ..........................................................................
        ..........................................................................

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, as a temporary option.

return np.reshape(np.array(X), (len(docs), self.num_topics))

def partial_fit(self, X):
raise NotImplementedError("'partial_fit' has not been implemented for the LDA Seq model")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LDA Seq model -> SklLdaSeqModel

for key in param_dict.keys():
self.assertEqual(model_params[key], param_dict[key])


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add persistence test with pickle

Copy link
Contributor

@menshikh-iv menshikh-iv Jun 16, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And add test with pipeline

score = text_ldaseq.score(corpus, test_target)
self.assertGreater(score, 0.50)

def testPersistence(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's sanity check only.
For persistence, you need to compare current and loaded models. For this purpose, you need to compare current and loaded inner matrices OR get corpus, transform it with both variant and compare results

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I have now added code for comparing the vectors transformed from original and loaded models, in addition to this sanity check. :)

text_ldaseq = Pipeline((('features', model,), ('classifier', clf)))
text_ldaseq.fit(corpus, test_target)
score = text_ldaseq.score(corpus, test_target)
self.assertGreater(score, 0.50)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's will be correct every time? No needed to fix seeds for reproducibility?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We now have a fixed seed which is set before the test testPipeline to ensure that we get similar values.

@menshikh-iv
Copy link
Contributor

Thank you @chinmayapancholi13 👍

@menshikh-iv menshikh-iv merged commit 477a3a3 into piskvorky:develop Jun 20, 2017
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants