Skip to content

Commit

Permalink
Add Sklearn API for Gensim models (piskvorky#1462)
Browse files Browse the repository at this point in the history
* created sklearn wrapper for Doc2Vec

* PEP8 fix

* added 'transform' function and refactored code

* updated d2v skl api code

* added unittests for sklearn api for d2v model

* fixed flake8 errors

* added skl api class for Text2Bow model

* updated docstring for d2vmodel api

* updated text2bow skl api code

* added unittests for text2bow skl api class

* updated 'testPipeline' and 'testTransform' for text2bow

* added 'tokenizer' param to text2bow skl api

* updated unittests for text2bow

* removed get_params and set_params functions from existing classes

* added tfidf api class

* added unittests for tfidf api class

* flake8 fixes

* added skl api for hdpmodel

* added unittests for hdp model api class

* flake8 fixes

* updated hdp api class

* added 'testPartialFit' and 'testPipeline' tests for hdp api class

* flake8 fixes

* added skl API class for phrases

* added unit tests for phrases API class

* flake8 fixes

* added 'testPartialFit' function for 'TestPhrasesTransformer'

* updated 'testPipeline' function for 'TestText2BowTransformer'

* updated code for transform function for HDP transformer

* updated tests as discussed in PR 1473

* added examples for new models in ipynb

* unpinned sklearn version for running unit-tests

* updated 'Pipeline' initialization format

* updated 'Pipeline' initialization format in ipynb
  • Loading branch information
chinmayapancholi13 authored and fabriciorsf committed Aug 23, 2017
1 parent b4bd541 commit a1d539f
Show file tree
Hide file tree
Showing 9 changed files with 1,217 additions and 45 deletions.
483 changes: 447 additions & 36 deletions docs/notebooks/sklearn_api.ipynb

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions gensim/sklearn_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,8 @@
from .ldaseqmodel import LdaSeqTransformer # noqa: F401
from .w2vmodel import W2VTransformer # noqa: F401
from .atmodel import AuthorTopicTransformer # noqa: F401
from .d2vmodel import D2VTransformer # noqa: F401
from .text2bow import Text2BowTransformer # noqa: F401
from .tfidf import TfIdfTransformer # noqa: F401
from .hdp import HdpTransformer # noqa: F401
from .phrases import PhrasesTransformer # noqa: F401
97 changes: 97 additions & 0 deletions gensim/sklearn_api/d2vmodel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
Scikit learn interface for gensim for easy use of gensim with scikit-learn
Follows scikit-learn API conventions
"""

import numpy as np
from six import string_types
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.exceptions import NotFittedError

from gensim import models


class D2VTransformer(TransformerMixin, BaseEstimator):
"""
Base Doc2Vec module
"""

def __init__(self, dm_mean=None, dm=1, dbow_words=0, dm_concat=0,
dm_tag_count=1, docvecs=None, docvecs_mapfile=None,
comment=None, trim_rule=None, size=100, alpha=0.025,
window=5, min_count=5, max_vocab_size=None, sample=1e-3,
seed=1, workers=3, min_alpha=0.0001, hs=0, negative=5,
cbow_mean=1, hashfxn=hash, iter=5, sorted_vocab=1,
batch_words=10000):
"""
Sklearn api for Doc2Vec model. See gensim.models.Doc2Vec and gensim.models.Word2Vec for parameter details.
"""
self.gensim_model = None
self.dm_mean = dm_mean
self.dm = dm
self.dbow_words = dbow_words
self.dm_concat = dm_concat
self.dm_tag_count = dm_tag_count
self.docvecs = docvecs
self.docvecs_mapfile = docvecs_mapfile
self.comment = comment
self.trim_rule = trim_rule

# attributes associated with gensim.models.Word2Vec
self.size = size
self.alpha = alpha
self.window = window
self.min_count = min_count
self.max_vocab_size = max_vocab_size
self.sample = sample
self.seed = seed
self.workers = workers
self.min_alpha = min_alpha
self.hs = hs
self.negative = negative
self.cbow_mean = int(cbow_mean)
self.hashfxn = hashfxn
self.iter = iter
self.sorted_vocab = sorted_vocab
self.batch_words = batch_words

def fit(self, X, y=None):
"""
Fit the model according to the given training data.
Calls gensim.models.Doc2Vec
"""
self.gensim_model = models.Doc2Vec(documents=X, dm_mean=self.dm_mean, dm=self.dm,
dbow_words=self.dbow_words, dm_concat=self.dm_concat, dm_tag_count=self.dm_tag_count,
docvecs=self.docvecs, docvecs_mapfile=self.docvecs_mapfile, comment=self.comment,
trim_rule=self.trim_rule, size=self.size, alpha=self.alpha, window=self.window,
min_count=self.min_count, max_vocab_size=self.max_vocab_size, sample=self.sample,
seed=self.seed, workers=self.workers, min_alpha=self.min_alpha, hs=self.hs,
negative=self.negative, cbow_mean=self.cbow_mean, hashfxn=self.hashfxn,
iter=self.iter, sorted_vocab=self.sorted_vocab, batch_words=self.batch_words)
return self

def transform(self, docs):
"""
Return the vector representations for the input documents.
The input `docs` should be a list of lists like : [ ['calculus', 'mathematical'], ['geometry', 'operations', 'curves'] ]
or a single document like : ['calculus', 'mathematical']
"""
if self.gensim_model is None:
raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.")

# The input as array of array
check = lambda x: [x] if isinstance(x[0], string_types) else x
docs = check(docs)
X = [[] for _ in range(0, len(docs))]

for k, v in enumerate(docs):
doc_vec = self.gensim_model.infer_vector(v)
X[k] = doc_vec

return np.reshape(np.array(X), (len(docs), self.gensim_model.vector_size))
108 changes: 108 additions & 0 deletions gensim/sklearn_api/hdp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
Scikit learn interface for gensim for easy use of gensim with scikit-learn
Follows scikit-learn API conventions
"""

import numpy as np
from scipy import sparse
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.exceptions import NotFittedError

from gensim import models
from gensim import matutils


class HdpTransformer(TransformerMixin, BaseEstimator):
"""
Base HDP module
"""

def __init__(self, id2word, max_chunks=None, max_time=None,
chunksize=256, kappa=1.0, tau=64.0, K=15, T=150, alpha=1,
gamma=1, eta=0.01, scale=1.0, var_converge=0.0001,
outputdir=None, random_state=None):
"""
Sklearn api for HDP model. See gensim.models.HdpModel for parameter details.
"""
self.gensim_model = None
self.id2word = id2word
self.max_chunks = max_chunks
self.max_time = max_time
self.chunksize = chunksize
self.kappa = kappa
self.tau = tau
self.K = K
self.T = T
self.alpha = alpha
self.gamma = gamma
self.eta = eta
self.scale = scale
self.var_converge = var_converge
self.outputdir = outputdir
self.random_state = random_state

def fit(self, X, y=None):
"""
Fit the model according to the given training data.
Calls gensim.models.HdpModel
"""
if sparse.issparse(X):
corpus = matutils.Sparse2Corpus(X)
else:
corpus = X

self.gensim_model = models.HdpModel(corpus=corpus, id2word=self.id2word, max_chunks=self.max_chunks,
max_time=self.max_time, chunksize=self.chunksize, kappa=self.kappa, tau=self.tau,
K=self.K, T=self.T, alpha=self.alpha, gamma=self.gamma, eta=self.eta, scale=self.scale,
var_converge=self.var_converge, outputdir=self.outputdir, random_state=self.random_state)
return self

def transform(self, docs):
"""
Takes a list of documents as input ('docs').
Returns a matrix of topic distribution for the given document bow, where a_ij
indicates (topic_i, topic_probability_j).
The input `docs` should be in BOW format and can be a list of documents like : [ [(4, 1), (7, 1)], [(9, 1), (13, 1)], [(2, 1), (6, 1)] ]
or a single document like : [(4, 1), (7, 1)]
"""
if self.gensim_model is None:
raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.")

# The input as array of array
check = lambda x: [x] if isinstance(x[0], tuple) else x
docs = check(docs)
X = [[] for _ in range(0, len(docs))]

max_num_topics = 0
for k, v in enumerate(docs):
X[k] = self.gensim_model[v]
max_num_topics = max(max_num_topics, max(list(map(lambda x: x[0], X[k]))) + 1)

for k, v in enumerate(X):
# returning dense representation for compatibility with sklearn but we should go back to sparse representation in the future
dense_vec = matutils.sparse2full(v, max_num_topics)
X[k] = dense_vec

return np.reshape(np.array(X), (len(docs), max_num_topics))

def partial_fit(self, X):
"""
Train model over X.
"""
if sparse.issparse(X):
X = matutils.Sparse2Corpus(X)

if self.gensim_model is None:
self.gensim_model = models.HdpModel(id2word=self.id2word, max_chunks=self.max_chunks,
max_time=self.max_time, chunksize=self.chunksize, kappa=self.kappa, tau=self.tau,
K=self.K, T=self.T, alpha=self.alpha, gamma=self.gamma, eta=self.eta, scale=self.scale,
var_converge=self.var_converge, outputdir=self.outputdir, random_state=self.random_state)

self.gensim_model.update(corpus=X)
return self
68 changes: 68 additions & 0 deletions gensim/sklearn_api/phrases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
Scikit learn interface for gensim for easy use of gensim with scikit-learn
Follows scikit-learn API conventions
"""

from six import string_types
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.exceptions import NotFittedError

from gensim import models


class PhrasesTransformer(TransformerMixin, BaseEstimator):
"""
Base Phrases module
"""

def __init__(self, min_count=5, threshold=10.0, max_vocab_size=40000000,
delimiter=b'_', progress_per=10000):
"""
Sklearn wrapper for Phrases model.
"""
self.gensim_model = None
self.min_count = min_count
self.threshold = threshold
self.max_vocab_size = max_vocab_size
self.delimiter = delimiter
self.progress_per = progress_per

def fit(self, X, y=None):
"""
Fit the model according to the given training data.
"""
self.gensim_model = models.Phrases(sentences=X, min_count=self.min_count, threshold=self.threshold,
max_vocab_size=self.max_vocab_size, delimiter=self.delimiter, progress_per=self.progress_per)
return self

def transform(self, docs):
"""
Return the input documents to return phrase tokens.
"""
if self.gensim_model is None:
raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.")

# input as python lists
check = lambda x: [x] if isinstance(x[0], string_types) else x
docs = check(docs)
X = [[] for _ in range(0, len(docs))]

for k, v in enumerate(docs):
phrase_tokens = self.gensim_model[v]
X[k] = phrase_tokens

return X

def partial_fit(self, X):
if self.gensim_model is None:
self.gensim_model = models.Phrases(sentences=X, min_count=self.min_count, threshold=self.threshold,
max_vocab_size=self.max_vocab_size, delimiter=self.delimiter, progress_per=self.progress_per)

self.gensim_model.add_vocab(X)
return self
66 changes: 66 additions & 0 deletions gensim/sklearn_api/text2bow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
Scikit learn interface for gensim for easy use of gensim with scikit-learn
Follows scikit-learn API conventions
"""

from six import string_types
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.exceptions import NotFittedError

from gensim.corpora import Dictionary
from gensim.utils import tokenize


class Text2BowTransformer(TransformerMixin, BaseEstimator):
"""
Base Text2Bow module
"""

def __init__(self, prune_at=2000000, tokenizer=tokenize):
"""
Sklearn wrapper for Text2Bow model.
"""
self.gensim_model = None
self.prune_at = prune_at
self.tokenizer = tokenizer

def fit(self, X, y=None):
"""
Fit the model according to the given training data.
"""
tokenized_docs = list(map(lambda x: list(self.tokenizer(x)), X))
self.gensim_model = Dictionary(documents=tokenized_docs, prune_at=self.prune_at)
return self

def transform(self, docs):
"""
Return the BOW format for the input documents.
"""
if self.gensim_model is None:
raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.")

# input as python lists
check = lambda x: [x] if isinstance(x, string_types) else x
docs = check(docs)
tokenized_docs = list(map(lambda x: list(self.tokenizer(x)), docs))
X = [[] for _ in range(0, len(tokenized_docs))]

for k, v in enumerate(tokenized_docs):
bow_val = self.gensim_model.doc2bow(v)
X[k] = bow_val

return X

def partial_fit(self, X):
if self.gensim_model is None:
self.gensim_model = Dictionary(prune_at=self.prune_at)

tokenized_docs = list(map(lambda x: list(self.tokenizer(x)), X))
self.gensim_model.add_documents(tokenized_docs)
return self
Loading

0 comments on commit a1d539f

Please sign in to comment.