-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Adding sklearn wrapper for LDA code #932
Changes from 28 commits
08f417c
61a6f8c
66be324
cffa95b
10badc6
62a4d2f
b7eff2d
2a193fd
a32f8dc
a048ddc
ac1d28e
0d6cc0a
5d8c1a6
894784c
7a5ca4b
b35baba
13a136d
682f045
9fda951
380ea5f
e2485d4
3015896
a76eda4
97c1530
20a63ac
c0b2c5c
bd656a8
d749ba0
21119c5
14f984b
a3895b5
f832737
bc352a0
7cc39da
0ba233c
e23a8a4
041a32e
e7120f0
8a0950d
bd8bced
bb5872b
777576e
e50c3f9
e521269
51931fa
7ba30d6
82d1fdc
4f3441e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,10 @@ Unreleased: | |
|
||
None | ||
|
||
0.13.5, 2016-12-31 | ||
|
||
* Added sklearn wrapper for LdaModel (Basic LDA Model) along with relevant test cases and ipynb draft. (@AadityaJ,[#932](https://github.com/RaRe-Technologies/gensim/pull/932)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is "Basic LDA Model"? |
||
|
||
0.13.4.1, 2017-01-04 | ||
|
||
* Disable direct access warnings on save and load of Word2vec/Doc2vec (@tmylk, [#1072](https://github.com/RaRe-Technologies/gensim/pull/1072)) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,325 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Using wrappers for Scikit learn API" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"This tutorial is about using gensim models as a part of your scikit learn workflow with the help of wrappers found at ```gensim.sklearn_integration.SklearnWrapperGensimLdaModel```" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The wrapper available (as of now) are :\n", | ||
"* LdaModel (```gensim.sklearn_integration.SklearnWrapperGensimLdaModel.SklearnWrapperLdaModel```),which implements gensim's ```LdaModel``` in a scikit-learn interface" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### LdaModel" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"To use LdaModel begin with importing LdaModel wrapper" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from gensim.sklearn_integration.SklearnWrapperGensimLdaModel import SklearnWrapperLdaModel" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Next we will create a dummy set of texts and convert it into a corpus" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add the examples to ipynb from https://gist.github.com/AadityaJ/c98da3d01f76f068242c17b5e1593973 |
||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from gensim.corpora import Dictionary\n", | ||
"texts = [['complier', 'system', 'computer'],\n", | ||
" ['eulerian', 'node', 'cycle', 'graph', 'tree', 'path'],\n", | ||
" ['graph', 'flow', 'network', 'graph'],\n", | ||
" ['loading', 'computer', 'system'],\n", | ||
" ['user', 'server', 'system'],\n", | ||
" ['tree','hamiltonian'],\n", | ||
" ['graph', 'trees'],\n", | ||
" ['computer', 'kernel', 'malfunction','computer'],\n", | ||
" ['server','system','computer']]\n", | ||
"dictionary = Dictionary(texts)\n", | ||
"corpus = [dictionary.doc2bow(text) for text in texts]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Then to run the LdaModel on it" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"WARNING:gensim.models.ldamodel:too few updates, training might not converge; consider increasing the number of passes or iterations to improve accuracy\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"[(0,\n", | ||
" u'0.164*\"computer\" + 0.117*\"system\" + 0.105*\"graph\" + 0.061*\"server\" + 0.057*\"tree\" + 0.046*\"malfunction\" + 0.045*\"kernel\" + 0.045*\"complier\" + 0.043*\"loading\" + 0.039*\"hamiltonian\"'),\n", | ||
" (1,\n", | ||
" u'0.102*\"graph\" + 0.083*\"system\" + 0.072*\"tree\" + 0.064*\"server\" + 0.059*\"user\" + 0.059*\"computer\" + 0.057*\"trees\" + 0.056*\"eulerian\" + 0.055*\"node\" + 0.052*\"flow\"')]" | ||
] | ||
}, | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"model=SklearnWrapperLdaModel(num_topics=2,id2word=dictionary,iterations=20, random_state=1)\n", | ||
"model.fit(corpus)\n", | ||
"model.print_topics(2)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"source": [ | ||
"### Integration with Sklearn" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"To provide a better example of how it can be used with Sklearn, Let's use CountVectorizer method of sklearn. For this example we will use [20 Newsgroups data set](http://qwone.com/~jason/20Newsgroups/). We will only use the categories rec.sport.baseball and sci.crypt and use it to generate topics." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"from gensim import matutils\n", | ||
"from gensim.models.ldamodel import LdaModel\n", | ||
"from sklearn.datasets import fetch_20newsgroups\n", | ||
"from sklearn.feature_extraction.text import CountVectorizer\n", | ||
"from gensim.sklearn_integration.SklearnWrapperGensimLdaModel import SklearnWrapperLdaModel" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"rand = np.random.mtrand.RandomState(1) # set seed for getting same result\n", | ||
"cats = ['rec.sport.baseball', 'sci.crypt']\n", | ||
"data = fetch_20newsgroups(subset='train',\n", | ||
" categories=cats,\n", | ||
" shuffle=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Next, we use countvectorizer to convert the collection of text documents to a matrix of token counts." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"vec = CountVectorizer(min_df=10, stop_words='english')\n", | ||
"\n", | ||
"X = vec.fit_transform(data.data)\n", | ||
"vocab = vec.get_feature_names() #vocab to be converted to id2word \n", | ||
"\n", | ||
"id2word=dict([(i, s) for i, s in enumerate(vocab)])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Next, we just need to fit X and id2word to our Lda wrapper." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"[(0,\n", | ||
" u'0.027*\"accurate\" + 0.022*\"corporate\" + 0.012*\"consideration\" + 0.012*\"chance\" + 0.011*\"decipher\" + 0.011*\"example\" + 0.010*\"basically\" + 0.010*\"cases\" + 0.010*\"follow\" + 0.009*\"dawson\"'),\n", | ||
" (1,\n", | ||
" u'0.018*\"face\" + 0.012*\"fierkelab\" + 0.009*\"cryptography\" + 0.008*\"abuse\" + 0.007*\"150\" + 0.007*\"finish\" + 0.006*\"communication\" + 0.006*\"fast\" + 0.006*\"constitutional\" + 0.006*\"database\"'),\n", | ||
" (2,\n", | ||
" u'0.074*\"abroad\" + 0.066*\"asking\" + 0.055*\"cryptography\" + 0.030*\"arithmetic\" + 0.028*\"argue\" + 0.028*\"ciphertext\" + 0.025*\"456\" + 0.023*\"courtesy\" + 0.019*\"facts\" + 0.015*\"beastmaster\"'),\n", | ||
" (3,\n", | ||
" u'0.014*\"cryptography\" + 0.014*\"clark\" + 0.012*\"authentication\" + 0.008*\"corporate\" + 0.008*\"1993apr5\" + 0.007*\"brett\" + 0.006*\"acceptable\" + 0.006*\"chance\" + 0.006*\"considering\" + 0.006*\"accurate\"'),\n", | ||
" (4,\n", | ||
" u'0.051*\"certain\" + 0.043*\"book\" + 0.041*\"69\" + 0.040*\"demand\" + 0.040*\"87\" + 0.039*\"cracking\" + 0.039*\"abroad\" + 0.037*\"farm\" + 0.019*\"asking\" + 0.015*\"cryptosystems\"')]" | ||
] | ||
}, | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"obj=SklearnWrapperLdaModel(id2word=id2word,num_topics=5,passes=20)\n", | ||
"lda=obj.fit_predict(X)\n", | ||
"lda.print_topics()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"source": [ | ||
"#### Comparison to Sklearn's Logistic Regression" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You mean "Using Gensim LDA together with Sklearn Logistic Regression"? |
||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Now lets try Sklearn's logistic classifier to classify the given categories into two types.Ideally we should get postive weights when cryptography is talked about and negative when baseball is talked about." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from sklearn import linear_model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"def print_features(clf, vocab, n=10):\n", | ||
" ''' Better printing for sorted list '''\n", | ||
" coef = clf.coef_[0]\n", | ||
" print 'Positive features: %s' % (' '.join(['%s:%.2f' % (vocab[j], coef[j]) for j in np.argsort(coef)[::-1][:n] if coef[j] > 0]))\n", | ||
" print 'Negative features: %s' % (' '.join(['%s:%.2f' % (vocab[j], coef[j]) for j in np.argsort(coef)[:n] if coef[j] < 0]))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Positive features: clipper:1.49 code:1.23 key:1.04 encryption:0.95 government:0.37 chip:0.37 nsa:0.37 uk:0.36 org:0.23 cryptography:0.23\n", | ||
"Negative features: baseball:-1.33 game:-0.72 year:-0.61 team:-0.38 edu:-0.27 games:-0.27 players:-0.23 ball:-0.17 season:-0.14 phillies:-0.11\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"clf=linear_model.LogisticRegression(penalty='l1', C=0.1) #l1 penalty used\n", | ||
"clf.fit(X,data.target)\n", | ||
"print_features(clf,vocab)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 2", | ||
"language": "python", | ||
"name": "python2" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please add this to the top section for 0.13.5 release