forked from piskvorky/gensim
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Sklearn API for Gensim models (piskvorky#1462)
* created sklearn wrapper for Doc2Vec * PEP8 fix * added 'transform' function and refactored code * updated d2v skl api code * added unittests for sklearn api for d2v model * fixed flake8 errors * added skl api class for Text2Bow model * updated docstring for d2vmodel api * updated text2bow skl api code * added unittests for text2bow skl api class * updated 'testPipeline' and 'testTransform' for text2bow * added 'tokenizer' param to text2bow skl api * updated unittests for text2bow * removed get_params and set_params functions from existing classes * added tfidf api class * added unittests for tfidf api class * flake8 fixes * added skl api for hdpmodel * added unittests for hdp model api class * flake8 fixes * updated hdp api class * added 'testPartialFit' and 'testPipeline' tests for hdp api class * flake8 fixes * added skl API class for phrases * added unit tests for phrases API class * flake8 fixes * added 'testPartialFit' function for 'TestPhrasesTransformer' * updated 'testPipeline' function for 'TestText2BowTransformer' * updated code for transform function for HDP transformer * updated tests as discussed in PR 1473 * added examples for new models in ipynb * unpinned sklearn version for running unit-tests * updated 'Pipeline' initialization format * updated 'Pipeline' initialization format in ipynb
- Loading branch information
1 parent
b4bd541
commit a1d539f
Showing
9 changed files
with
1,217 additions
and
45 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz> | ||
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html | ||
|
||
""" | ||
Scikit learn interface for gensim for easy use of gensim with scikit-learn | ||
Follows scikit-learn API conventions | ||
""" | ||
|
||
import numpy as np | ||
from six import string_types | ||
from sklearn.base import TransformerMixin, BaseEstimator | ||
from sklearn.exceptions import NotFittedError | ||
|
||
from gensim import models | ||
|
||
|
||
class D2VTransformer(TransformerMixin, BaseEstimator): | ||
""" | ||
Base Doc2Vec module | ||
""" | ||
|
||
def __init__(self, dm_mean=None, dm=1, dbow_words=0, dm_concat=0, | ||
dm_tag_count=1, docvecs=None, docvecs_mapfile=None, | ||
comment=None, trim_rule=None, size=100, alpha=0.025, | ||
window=5, min_count=5, max_vocab_size=None, sample=1e-3, | ||
seed=1, workers=3, min_alpha=0.0001, hs=0, negative=5, | ||
cbow_mean=1, hashfxn=hash, iter=5, sorted_vocab=1, | ||
batch_words=10000): | ||
""" | ||
Sklearn api for Doc2Vec model. See gensim.models.Doc2Vec and gensim.models.Word2Vec for parameter details. | ||
""" | ||
self.gensim_model = None | ||
self.dm_mean = dm_mean | ||
self.dm = dm | ||
self.dbow_words = dbow_words | ||
self.dm_concat = dm_concat | ||
self.dm_tag_count = dm_tag_count | ||
self.docvecs = docvecs | ||
self.docvecs_mapfile = docvecs_mapfile | ||
self.comment = comment | ||
self.trim_rule = trim_rule | ||
|
||
# attributes associated with gensim.models.Word2Vec | ||
self.size = size | ||
self.alpha = alpha | ||
self.window = window | ||
self.min_count = min_count | ||
self.max_vocab_size = max_vocab_size | ||
self.sample = sample | ||
self.seed = seed | ||
self.workers = workers | ||
self.min_alpha = min_alpha | ||
self.hs = hs | ||
self.negative = negative | ||
self.cbow_mean = int(cbow_mean) | ||
self.hashfxn = hashfxn | ||
self.iter = iter | ||
self.sorted_vocab = sorted_vocab | ||
self.batch_words = batch_words | ||
|
||
def fit(self, X, y=None): | ||
""" | ||
Fit the model according to the given training data. | ||
Calls gensim.models.Doc2Vec | ||
""" | ||
self.gensim_model = models.Doc2Vec(documents=X, dm_mean=self.dm_mean, dm=self.dm, | ||
dbow_words=self.dbow_words, dm_concat=self.dm_concat, dm_tag_count=self.dm_tag_count, | ||
docvecs=self.docvecs, docvecs_mapfile=self.docvecs_mapfile, comment=self.comment, | ||
trim_rule=self.trim_rule, size=self.size, alpha=self.alpha, window=self.window, | ||
min_count=self.min_count, max_vocab_size=self.max_vocab_size, sample=self.sample, | ||
seed=self.seed, workers=self.workers, min_alpha=self.min_alpha, hs=self.hs, | ||
negative=self.negative, cbow_mean=self.cbow_mean, hashfxn=self.hashfxn, | ||
iter=self.iter, sorted_vocab=self.sorted_vocab, batch_words=self.batch_words) | ||
return self | ||
|
||
def transform(self, docs): | ||
""" | ||
Return the vector representations for the input documents. | ||
The input `docs` should be a list of lists like : [ ['calculus', 'mathematical'], ['geometry', 'operations', 'curves'] ] | ||
or a single document like : ['calculus', 'mathematical'] | ||
""" | ||
if self.gensim_model is None: | ||
raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.") | ||
|
||
# The input as array of array | ||
check = lambda x: [x] if isinstance(x[0], string_types) else x | ||
docs = check(docs) | ||
X = [[] for _ in range(0, len(docs))] | ||
|
||
for k, v in enumerate(docs): | ||
doc_vec = self.gensim_model.infer_vector(v) | ||
X[k] = doc_vec | ||
|
||
return np.reshape(np.array(X), (len(docs), self.gensim_model.vector_size)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz> | ||
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html | ||
|
||
""" | ||
Scikit learn interface for gensim for easy use of gensim with scikit-learn | ||
Follows scikit-learn API conventions | ||
""" | ||
|
||
import numpy as np | ||
from scipy import sparse | ||
from sklearn.base import TransformerMixin, BaseEstimator | ||
from sklearn.exceptions import NotFittedError | ||
|
||
from gensim import models | ||
from gensim import matutils | ||
|
||
|
||
class HdpTransformer(TransformerMixin, BaseEstimator): | ||
""" | ||
Base HDP module | ||
""" | ||
|
||
def __init__(self, id2word, max_chunks=None, max_time=None, | ||
chunksize=256, kappa=1.0, tau=64.0, K=15, T=150, alpha=1, | ||
gamma=1, eta=0.01, scale=1.0, var_converge=0.0001, | ||
outputdir=None, random_state=None): | ||
""" | ||
Sklearn api for HDP model. See gensim.models.HdpModel for parameter details. | ||
""" | ||
self.gensim_model = None | ||
self.id2word = id2word | ||
self.max_chunks = max_chunks | ||
self.max_time = max_time | ||
self.chunksize = chunksize | ||
self.kappa = kappa | ||
self.tau = tau | ||
self.K = K | ||
self.T = T | ||
self.alpha = alpha | ||
self.gamma = gamma | ||
self.eta = eta | ||
self.scale = scale | ||
self.var_converge = var_converge | ||
self.outputdir = outputdir | ||
self.random_state = random_state | ||
|
||
def fit(self, X, y=None): | ||
""" | ||
Fit the model according to the given training data. | ||
Calls gensim.models.HdpModel | ||
""" | ||
if sparse.issparse(X): | ||
corpus = matutils.Sparse2Corpus(X) | ||
else: | ||
corpus = X | ||
|
||
self.gensim_model = models.HdpModel(corpus=corpus, id2word=self.id2word, max_chunks=self.max_chunks, | ||
max_time=self.max_time, chunksize=self.chunksize, kappa=self.kappa, tau=self.tau, | ||
K=self.K, T=self.T, alpha=self.alpha, gamma=self.gamma, eta=self.eta, scale=self.scale, | ||
var_converge=self.var_converge, outputdir=self.outputdir, random_state=self.random_state) | ||
return self | ||
|
||
def transform(self, docs): | ||
""" | ||
Takes a list of documents as input ('docs'). | ||
Returns a matrix of topic distribution for the given document bow, where a_ij | ||
indicates (topic_i, topic_probability_j). | ||
The input `docs` should be in BOW format and can be a list of documents like : [ [(4, 1), (7, 1)], [(9, 1), (13, 1)], [(2, 1), (6, 1)] ] | ||
or a single document like : [(4, 1), (7, 1)] | ||
""" | ||
if self.gensim_model is None: | ||
raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.") | ||
|
||
# The input as array of array | ||
check = lambda x: [x] if isinstance(x[0], tuple) else x | ||
docs = check(docs) | ||
X = [[] for _ in range(0, len(docs))] | ||
|
||
max_num_topics = 0 | ||
for k, v in enumerate(docs): | ||
X[k] = self.gensim_model[v] | ||
max_num_topics = max(max_num_topics, max(list(map(lambda x: x[0], X[k]))) + 1) | ||
|
||
for k, v in enumerate(X): | ||
# returning dense representation for compatibility with sklearn but we should go back to sparse representation in the future | ||
dense_vec = matutils.sparse2full(v, max_num_topics) | ||
X[k] = dense_vec | ||
|
||
return np.reshape(np.array(X), (len(docs), max_num_topics)) | ||
|
||
def partial_fit(self, X): | ||
""" | ||
Train model over X. | ||
""" | ||
if sparse.issparse(X): | ||
X = matutils.Sparse2Corpus(X) | ||
|
||
if self.gensim_model is None: | ||
self.gensim_model = models.HdpModel(id2word=self.id2word, max_chunks=self.max_chunks, | ||
max_time=self.max_time, chunksize=self.chunksize, kappa=self.kappa, tau=self.tau, | ||
K=self.K, T=self.T, alpha=self.alpha, gamma=self.gamma, eta=self.eta, scale=self.scale, | ||
var_converge=self.var_converge, outputdir=self.outputdir, random_state=self.random_state) | ||
|
||
self.gensim_model.update(corpus=X) | ||
return self |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz> | ||
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html | ||
|
||
""" | ||
Scikit learn interface for gensim for easy use of gensim with scikit-learn | ||
Follows scikit-learn API conventions | ||
""" | ||
|
||
from six import string_types | ||
from sklearn.base import TransformerMixin, BaseEstimator | ||
from sklearn.exceptions import NotFittedError | ||
|
||
from gensim import models | ||
|
||
|
||
class PhrasesTransformer(TransformerMixin, BaseEstimator): | ||
""" | ||
Base Phrases module | ||
""" | ||
|
||
def __init__(self, min_count=5, threshold=10.0, max_vocab_size=40000000, | ||
delimiter=b'_', progress_per=10000): | ||
""" | ||
Sklearn wrapper for Phrases model. | ||
""" | ||
self.gensim_model = None | ||
self.min_count = min_count | ||
self.threshold = threshold | ||
self.max_vocab_size = max_vocab_size | ||
self.delimiter = delimiter | ||
self.progress_per = progress_per | ||
|
||
def fit(self, X, y=None): | ||
""" | ||
Fit the model according to the given training data. | ||
""" | ||
self.gensim_model = models.Phrases(sentences=X, min_count=self.min_count, threshold=self.threshold, | ||
max_vocab_size=self.max_vocab_size, delimiter=self.delimiter, progress_per=self.progress_per) | ||
return self | ||
|
||
def transform(self, docs): | ||
""" | ||
Return the input documents to return phrase tokens. | ||
""" | ||
if self.gensim_model is None: | ||
raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.") | ||
|
||
# input as python lists | ||
check = lambda x: [x] if isinstance(x[0], string_types) else x | ||
docs = check(docs) | ||
X = [[] for _ in range(0, len(docs))] | ||
|
||
for k, v in enumerate(docs): | ||
phrase_tokens = self.gensim_model[v] | ||
X[k] = phrase_tokens | ||
|
||
return X | ||
|
||
def partial_fit(self, X): | ||
if self.gensim_model is None: | ||
self.gensim_model = models.Phrases(sentences=X, min_count=self.min_count, threshold=self.threshold, | ||
max_vocab_size=self.max_vocab_size, delimiter=self.delimiter, progress_per=self.progress_per) | ||
|
||
self.gensim_model.add_vocab(X) | ||
return self |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz> | ||
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html | ||
|
||
""" | ||
Scikit learn interface for gensim for easy use of gensim with scikit-learn | ||
Follows scikit-learn API conventions | ||
""" | ||
|
||
from six import string_types | ||
from sklearn.base import TransformerMixin, BaseEstimator | ||
from sklearn.exceptions import NotFittedError | ||
|
||
from gensim.corpora import Dictionary | ||
from gensim.utils import tokenize | ||
|
||
|
||
class Text2BowTransformer(TransformerMixin, BaseEstimator): | ||
""" | ||
Base Text2Bow module | ||
""" | ||
|
||
def __init__(self, prune_at=2000000, tokenizer=tokenize): | ||
""" | ||
Sklearn wrapper for Text2Bow model. | ||
""" | ||
self.gensim_model = None | ||
self.prune_at = prune_at | ||
self.tokenizer = tokenizer | ||
|
||
def fit(self, X, y=None): | ||
""" | ||
Fit the model according to the given training data. | ||
""" | ||
tokenized_docs = list(map(lambda x: list(self.tokenizer(x)), X)) | ||
self.gensim_model = Dictionary(documents=tokenized_docs, prune_at=self.prune_at) | ||
return self | ||
|
||
def transform(self, docs): | ||
""" | ||
Return the BOW format for the input documents. | ||
""" | ||
if self.gensim_model is None: | ||
raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.") | ||
|
||
# input as python lists | ||
check = lambda x: [x] if isinstance(x, string_types) else x | ||
docs = check(docs) | ||
tokenized_docs = list(map(lambda x: list(self.tokenizer(x)), docs)) | ||
X = [[] for _ in range(0, len(tokenized_docs))] | ||
|
||
for k, v in enumerate(tokenized_docs): | ||
bow_val = self.gensim_model.doc2bow(v) | ||
X[k] = bow_val | ||
|
||
return X | ||
|
||
def partial_fit(self, X): | ||
if self.gensim_model is None: | ||
self.gensim_model = Dictionary(prune_at=self.prune_at) | ||
|
||
tokenized_docs = list(map(lambda x: list(self.tokenizer(x)), X)) | ||
self.gensim_model.add_documents(tokenized_docs) | ||
return self |
Oops, something went wrong.