-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[WIP] Add sklearn wrapper for LDA code (#932)
* adding basic sklearn wrapper for LDA code * updating changelog * adding test case,adding id2word,deleting showtopics * adding relevant ipynb * adding transfrom and other get methods and modifying print_topics * stylizing code to follow conventions * removing redundant default argumen values * adding partial_fit * adding a line in test_sklearn_integration * using LDAModel as Parent Class * adding docs, modifying getparam * changing class name.Adding comments * adding test case for update and transform * adding init * updating changes,fixed typo and changing file name * deleted base.py * adding better testPartialFit method and minor changes due to change in class * change name of test class * adding changes in classname to ipynb * Updating CHANGELOG.md * Updated Main Model. Added fit_predict to class for example * added sklearn countvectorizer example to ipynb * adding logistic regression example * adding if condition for csr_matrix to ldamodel * adding check for fit csrmatrix also stylizing code * minor bug.solved, fit should convert X to corpus * removing fit_predict.adding csr_matrix check for update * minor updates in ipynb * adding rst file * removed "basic" , added rst update to log * changing indentation in texts * added file preamble, removed unnecessary space * following more pep8 conventions * removing unnecessary comments * changing isinstance csr_matrix to issparse * changed to hanging indentation * changing main filename * changing module name in test * updating ipynb with main filename * changed class name * changed file name * fixing filename typo * adding html file * deleting html file * vertical indentation fixes * adding file to apiref.rst
- Loading branch information
Showing
7 changed files
with
532 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,325 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Using wrappers for Scikit learn API" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"This tutorial is about using gensim models as a part of your scikit learn workflow with the help of wrappers found at ```gensim.sklearn_integration.sklearn_wrapper_gensim_ldaModel```" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The wrapper available (as of now) are :\n", | ||
"* LdaModel (```gensim.sklearn_integration.sklearn_wrapper_gensim_ldaModel.SklearnWrapperLdaModel```),which implements gensim's ```LdaModel``` in a scikit-learn interface" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### LdaModel" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"To use LdaModel begin with importing LdaModel wrapper" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from gensim.sklearn_integration.sklearn_wrapper_gensim_ldaModel import SklearnWrapperLdaModel" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Next we will create a dummy set of texts and convert it into a corpus" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from gensim.corpora import Dictionary\n", | ||
"texts = [['complier', 'system', 'computer'],\n", | ||
" ['eulerian', 'node', 'cycle', 'graph', 'tree', 'path'],\n", | ||
" ['graph', 'flow', 'network', 'graph'],\n", | ||
" ['loading', 'computer', 'system'],\n", | ||
" ['user', 'server', 'system'],\n", | ||
" ['tree','hamiltonian'],\n", | ||
" ['graph', 'trees'],\n", | ||
" ['computer', 'kernel', 'malfunction','computer'],\n", | ||
" ['server','system','computer']]\n", | ||
"dictionary = Dictionary(texts)\n", | ||
"corpus = [dictionary.doc2bow(text) for text in texts]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Then to run the LdaModel on it" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"WARNING:gensim.models.ldamodel:too few updates, training might not converge; consider increasing the number of passes or iterations to improve accuracy\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"[(0,\n", | ||
" u'0.164*\"computer\" + 0.117*\"system\" + 0.105*\"graph\" + 0.061*\"server\" + 0.057*\"tree\" + 0.046*\"malfunction\" + 0.045*\"kernel\" + 0.045*\"complier\" + 0.043*\"loading\" + 0.039*\"hamiltonian\"'),\n", | ||
" (1,\n", | ||
" u'0.102*\"graph\" + 0.083*\"system\" + 0.072*\"tree\" + 0.064*\"server\" + 0.059*\"user\" + 0.059*\"computer\" + 0.057*\"trees\" + 0.056*\"eulerian\" + 0.055*\"node\" + 0.052*\"flow\"')]" | ||
] | ||
}, | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"model=SklearnWrapperLdaModel(num_topics=2,id2word=dictionary,iterations=20, random_state=1)\n", | ||
"model.fit(corpus)\n", | ||
"model.print_topics(2)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"source": [ | ||
"### Integration with Sklearn" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"To provide a better example of how it can be used with Sklearn, Let's use CountVectorizer method of sklearn. For this example we will use [20 Newsgroups data set](http://qwone.com/~jason/20Newsgroups/). We will only use the categories rec.sport.baseball and sci.crypt and use it to generate topics." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"from gensim import matutils\n", | ||
"from gensim.models.ldamodel import LdaModel\n", | ||
"from sklearn.datasets import fetch_20newsgroups\n", | ||
"from sklearn.feature_extraction.text import CountVectorizer\n", | ||
"from gensim.sklearn_integration.sklearn_wrapper_gensim_ldaModel import SklearnWrapperLdaModel" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"rand = np.random.mtrand.RandomState(1) # set seed for getting same result\n", | ||
"cats = ['rec.sport.baseball', 'sci.crypt']\n", | ||
"data = fetch_20newsgroups(subset='train',\n", | ||
" categories=cats,\n", | ||
" shuffle=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Next, we use countvectorizer to convert the collection of text documents to a matrix of token counts." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"vec = CountVectorizer(min_df=10, stop_words='english')\n", | ||
"\n", | ||
"X = vec.fit_transform(data.data)\n", | ||
"vocab = vec.get_feature_names() #vocab to be converted to id2word \n", | ||
"\n", | ||
"id2word=dict([(i, s) for i, s in enumerate(vocab)])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Next, we just need to fit X and id2word to our Lda wrapper." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"[(0,\n", | ||
" u'0.018*\"cryptography\" + 0.018*\"face\" + 0.017*\"fierkelab\" + 0.008*\"abuse\" + 0.007*\"constitutional\" + 0.007*\"collection\" + 0.007*\"finish\" + 0.007*\"150\" + 0.007*\"fast\" + 0.006*\"difference\"'),\n", | ||
" (1,\n", | ||
" u'0.022*\"corporate\" + 0.022*\"accurate\" + 0.012*\"chance\" + 0.008*\"decipher\" + 0.008*\"example\" + 0.008*\"basically\" + 0.008*\"dawson\" + 0.008*\"cases\" + 0.008*\"consideration\" + 0.008*\"follow\"'),\n", | ||
" (2,\n", | ||
" u'0.034*\"argue\" + 0.031*\"456\" + 0.031*\"arithmetic\" + 0.024*\"courtesy\" + 0.020*\"beastmaster\" + 0.019*\"bitnet\" + 0.015*\"false\" + 0.015*\"classified\" + 0.014*\"cubs\" + 0.014*\"digex\"'),\n", | ||
" (3,\n", | ||
" u'0.108*\"abroad\" + 0.089*\"asking\" + 0.060*\"cryptography\" + 0.035*\"certain\" + 0.030*\"ciphertext\" + 0.030*\"book\" + 0.028*\"69\" + 0.028*\"demand\" + 0.028*\"87\" + 0.027*\"cracking\"'),\n", | ||
" (4,\n", | ||
" u'0.022*\"clark\" + 0.019*\"authentication\" + 0.017*\"candidates\" + 0.016*\"decryption\" + 0.015*\"attempt\" + 0.013*\"creation\" + 0.013*\"1993apr5\" + 0.013*\"acceptable\" + 0.013*\"algorithms\" + 0.013*\"employer\"')]" | ||
] | ||
}, | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"obj=SklearnWrapperLdaModel(id2word=id2word,num_topics=5,passes=20)\n", | ||
"lda=obj.fit(X)\n", | ||
"lda.print_topics()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"source": [ | ||
"#### Using together with Scikit learn's Logistic Regression" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Now lets try Sklearn's logistic classifier to classify the given categories into two types.Ideally we should get postive weights when cryptography is talked about and negative when baseball is talked about." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from sklearn import linear_model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"def print_features(clf, vocab, n=10):\n", | ||
" ''' Better printing for sorted list '''\n", | ||
" coef = clf.coef_[0]\n", | ||
" print 'Positive features: %s' % (' '.join(['%s:%.2f' % (vocab[j], coef[j]) for j in np.argsort(coef)[::-1][:n] if coef[j] > 0]))\n", | ||
" print 'Negative features: %s' % (' '.join(['%s:%.2f' % (vocab[j], coef[j]) for j in np.argsort(coef)[:n] if coef[j] < 0]))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Positive features: clipper:1.50 code:1.24 key:1.04 encryption:0.95 chip:0.37 nsa:0.37 government:0.36 uk:0.36 org:0.23 cryptography:0.23\n", | ||
"Negative features: baseball:-1.32 game:-0.71 year:-0.61 team:-0.38 edu:-0.27 games:-0.26 players:-0.23 ball:-0.17 season:-0.14 phillies:-0.11\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"clf=linear_model.LogisticRegression(penalty='l1', C=0.1) #l1 penalty used\n", | ||
"clf.fit(X,data.target)\n", | ||
"print_features(clf,vocab)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 2", | ||
"language": "python", | ||
"name": "python2" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
9 changes: 9 additions & 0 deletions
9
docs/src/sklearn_integration/sklearn_wrapper_gensim_ldamodel.rst
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
:mod:`sklearn_integration.sklearn_wrapper_gensim_ldamodel.SklearnWrapperLdaModel` -- Scikit learn wrapper for Latent Dirichlet Allocation | ||
====================================================== | ||
|
||
.. automodule:: gensim.sklearn_integration.sklearn_wrapper_gensim_ldamodel.SklearnWrapperLdaModel | ||
:synopsis: Scikit learn wrapper for LDA model | ||
:members: | ||
:inherited-members: | ||
:undoc-members: | ||
:show-inheritance: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz> | ||
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html | ||
"""Scikit learn wrapper for gensim. | ||
Contains various gensim based implementations which match with scikit-learn standards. | ||
See [1] for complete set of conventions. | ||
[1] http://scikit-learn.org/stable/developers/ | ||
""" |
Oops, something went wrong.