-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add smart information retrieval system for TfidfModel
. Fix #1785
#1791
Changes from 17 commits
5e1830b
6cef4b1
e8a3f16
648bf21
a6f1afb
d091138
951c549
40c0558
b35344c
0917e75
bef79cc
d3d431c
0e6f21e
7ee7560
f2251a4
b2def84
5b2d37a
ac4b154
0bacc08
51e0eb9
3039732
99e6a6f
7d63d9c
e5140f8
4afbadd
d2fe235
52ee3c4
48e84f7
6d2f47b
607ba61
d0878a4
b544c9c
c4e3656
98ffde5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,51 +6,100 @@ | |
|
||
|
||
import logging | ||
import math | ||
from functools import partial | ||
|
||
from gensim import interfaces, matutils, utils | ||
from six import iteritems | ||
|
||
import numpy as np | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def resolve_weights(smartirs):
    """
    Check the validity of a SMART information-retrieval scheme string.

    Parameters
    ----------
    smartirs : str
        A 3-letter mnemonic of the form ``ddd``: term-frequency weight,
        document-frequency weight, and document normalization, in that order.

    Returns
    -------
    (str, str, str)
        The three individual weight letters ``(w_tf, w_df, w_n)``.

    Raises
    ------
    ValueError
        If `smartirs` is not a string of length 3, or any letter is not a
        supported scheme.
    """
    # Use %r formatting instead of string concatenation: concatenating a
    # non-string argument would raise TypeError and mask the ValueError.
    if not isinstance(smartirs, str) or len(smartirs) != 3:
        raise ValueError("Expected a string of length 3, but got %r" % (smartirs,))

    w_tf, w_df, w_n = smartirs

    if w_tf not in 'nlabL':
        raise ValueError("Expected term frequency weight to be one of 'nlabL', but got %r" % (w_tf,))

    if w_df not in 'ntp':
        raise ValueError("Expected inverse document frequency weight to be one of 'ntp', but got %r" % (w_df,))

    if w_n not in 'ncb':
        raise ValueError("Expected normalization weight to be one of 'ncb', but got %r" % (w_n,))

    return w_tf, w_df, w_n
|
||
|
||
def df2idf(docfreq, totaldocs, log_base=2.0, add=0.0):
    """
    Compute the default inverse-document-frequency for a term::

        idf = add + log_{log_base}(totaldocs / docfreq)

    Parameters
    ----------
    docfreq : int
        Number of documents the term appears in (must be > 0).
    totaldocs : int
        Total number of documents in the corpus.
    log_base : float, optional
        Base of the logarithm, default 2.0.
    add : float, optional
        Constant offset added to the result, default 0.0.

    Returns
    -------
    float
        The idf weight for the term.
    """
    # np.log is used (rather than math.log) for consistency with the other
    # weighting helpers in this module; dividing by np.log(log_base)
    # implements the change of base.
    return add + np.log(float(totaldocs) / docfreq) / np.log(log_base)
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's a reason to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consistency, I am using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No problem this is only a question :) |
||
|
||
|
||
def precompute_idfs(wglobal, dfs, total_docs):
    """
    Precompute the inverse document frequency mapping for all terms.

    Parameters
    ----------
    wglobal : callable
        Global weighting function with signature ``wglobal(docfreq, total_docs)``.
    dfs : dict of (int, int)
        Mapping from term id to the term's document frequency.
    total_docs : int
        Total number of documents in the corpus.

    Returns
    -------
    dict of (int, float)
        Mapping from term id to its precomputed idf weight.
    """
    # Not strictly necessary -- idfs could be computed on the fly in
    # TfidfModel.__getitem__; precomputing here just speeds things up a little.
    # dict.items() behaves correctly on both py2 (list) and py3 (view),
    # so the six.iteritems indirection is unnecessary for this helper.
    return {termid: wglobal(df, total_docs) for termid, df in dfs.items()}
|
||
|
||
def updated_wlocal(tf, n_tf):
    """
    Apply the SMART term-frequency weighting scheme `n_tf` to raw counts.

    Parameters
    ----------
    tf : numpy.ndarray
        Array of raw term frequencies for one document.
    n_tf : {'n', 'l', 'a', 'b', 'L'}
        Scheme: natural, logarithm, augmented, boolean, or log average.

    Returns
    -------
    numpy.ndarray
        The weighted term frequencies.

    Raises
    ------
    ValueError
        If `n_tf` is not a supported scheme.
    """
    if n_tf == "n":
        return tf
    elif n_tf == "l":
        return 1 + np.log2(tf)
    elif n_tf == "a":
        return 0.5 + (0.5 * tf / tf.max(axis=0))
    elif n_tf == "b":
        return tf.astype('bool').astype('int')
    elif n_tf == "L":
        # log average: (1 + log2(tf)) / (1 + log2(mean(tf))).
        # NOTE: the base-change divisor belongs OUTSIDE np.log in the
        # denominator; the previous code computed np.log(tf.mean() / np.log(2)).
        return (1 + np.log2(tf)) / (1 + np.log2(tf.mean(axis=0)))
    # previously an unknown scheme fell through and silently returned None
    raise ValueError("Expected term frequency weight to be one of 'nlabL', but got %r" % (n_tf,))
|
||
|
||
def updated_wglobal(docfreq, totaldocs, n_df):
    """
    Apply the SMART document-frequency weighting scheme `n_df`.

    Parameters
    ----------
    docfreq : int
        Number of documents the term appears in.
    totaldocs : int
        Total number of documents in the corpus.
    n_df : {'n', 't', 'p'}
        Scheme: none, idf, or probabilistic idf.

    Returns
    -------
    float
        The global (document-frequency) weight for the term.

    Raises
    ------
    ValueError
        If `n_df` is not a supported scheme.
    """
    if n_df == "n":
        # 'none' is the identity; no need to route through utils.identity
        return docfreq
    elif n_df == "t":
        return np.log2(1.0 * totaldocs / docfreq)
    elif n_df == "p":
        return np.log2((1.0 * totaldocs - docfreq) / docfreq)
    # previously an unknown scheme fell through and silently returned None
    raise ValueError("Expected inverse document frequency weight to be one of 'ntp', but got %r" % (n_df,))
|
||
|
||
def updated_normalize(x, n_n):
    """
    Apply the SMART document-normalization scheme `n_n` to vector `x`.

    Parameters
    ----------
    x : list of (int, float)
        Document vector in sparse (termid, weight) format.
    n_n : {'n', 'c'}
        Scheme: 'n' -- no normalization, 'c' -- cosine (unit length).

    Returns
    -------
    list of (int, float)
        The (possibly normalized) vector.

    Raises
    ------
    ValueError
        If `n_n` is not an implemented scheme.
    """
    if n_n == "n":
        return x
    elif n_n == "c":
        return matutils.unitvec(x)
    # NOTE(review): resolve_weights also accepts 'b' for the normalization
    # slot, but no 'b' scheme is implemented here -- previously that case
    # fell through and silently returned None; fail loudly instead.
    raise ValueError("Expected normalization weight to be one of 'nc', but got %r" % (n_n,))
|
||
|
||
class TfidfModel(interfaces.TransformationABC): | ||
""" | ||
Objects of this class realize the transformation between word-document co-occurrence | ||
matrix (integers) into a locally/globally weighted TF_IDF matrix (positive floats). | ||
|
||
The main methods are: | ||
|
||
1. constructor, which calculates inverse document counts for all terms in the training corpus. | ||
2. the [] method, which transforms a simple count representation into the TfIdf | ||
space. | ||
|
||
Examples | ||
-------- | ||
>>> tfidf = TfidfModel(corpus) | ||
>>> print(tfidf[some_doc]) | ||
>>> tfidf.save('/tmp/foo.tfidf_model') | ||
|
||
Model persistency is achieved via its load/save methods. | ||
|
||
""" | ||
|
||
def __init__(self, corpus=None, id2word=None, dictionary=None, | ||
wlocal=utils.identity, wglobal=df2idf, normalize=True): | ||
def __init__(self, corpus=None, id2word=None, dictionary=None, wlocal=utils.identity, | ||
wglobal=df2idf, normalize=True, smartirs=None): | ||
""" | ||
Compute tf-idf by multiplying a local component (term frequency) with a | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you convert all docstrings in this file to numpy-style, according to my previous comment #1780 (comment) |
||
global component (inverse document frequency), and normalizing | ||
|
@@ -65,23 +114,64 @@ def __init__(self, corpus=None, id2word=None, dictionary=None, | |
|
||
so you can plug in your own custom `wlocal` and `wglobal` functions. | ||
|
||
Default for `wlocal` is identity (other options: math.sqrt, math.log1p, ...) | ||
and default for `wglobal` is `log_2(total_docs / doc_freq)`, giving the | ||
formula above. | ||
|
||
`normalize` dictates how the final transformed vectors will be normalized. | ||
`normalize=True` means set to unit length (default); `False` means don't | ||
normalize. You can also set `normalize` to your own function that accepts | ||
and returns a sparse vector. | ||
Parameters | ||
---------- | ||
corpus : dictionary.doc2bow | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. type should be |
||
Corpus is a list of sets where each set has two elements. First being the termid and | ||
second being the term frequency of each term in the document. | ||
id2word : dict | ||
id2word is an optional dictionary that maps the word_id to a token. | ||
In case id2word isn’t specified the mapping id2word[word_id] = str(word_id) will be used. | ||
dictionary :corpora.Dictionary | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. type should be
|
||
If `dictionary` is specified, it must be a `corpora.Dictionary` object | ||
and it will be used to directly construct the inverse document frequency | ||
mapping (then `corpus`, if specified, is ignored). | ||
wlocals : user specified function | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of
should be
everywhere |
||
Default for `wlocal` is identity (other options: math.sqrt, math.log1p, ...) | ||
wglobal : user specified function | ||
Default for `wglobal` is `log_2(total_docs / doc_freq)`, giving the | ||
formula above. | ||
normalize : user specified function | ||
It dictates how the final transformed vectors will be normalized. | ||
`normalize=True` means set to unit length (default); `False` means don't | ||
normalize. You can also set `normalize` to your own function that accepts | ||
and returns a sparse vector. | ||
smartirs : {'None' ,'str'} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. str, optional |
||
`smartirs` or SMART (System for the Mechanical Analysis and Retrieval of Text) | ||
Information Retrieval System, a mnemonic scheme for denoting tf-idf weighting | ||
variants in the vector space model. The mnemonic for representing a combination | ||
of weights takes the form ddd, where the letters represents the term weighting | ||
of the document vector. | ||
|
||
Term frequency weighing: | ||
natural - `n`, logarithm - `l` , augmented - `a`, boolean `b`, log average - `L`. | ||
Document frequency weighting: | ||
none - `n`, idf - `t`, prob idf - `p`. | ||
Document normalization: | ||
none - `n`, cosine - `c`. | ||
|
||
for more information visit https://en.wikipedia.org/wiki/SMART_Information_Retrieval_System | ||
|
||
Returns | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
------- | ||
x : gensim.models.tfidfmodel.TfidfModel | ||
|
||
If `dictionary` is specified, it must be a `corpora.Dictionary` object | ||
and it will be used to directly construct the inverse document frequency | ||
mapping (then `corpus`, if specified, is ignored). | ||
""" | ||
self.normalize = normalize | ||
|
||
self.id2word = id2word | ||
self.wlocal, self.wglobal = wlocal, wglobal | ||
self.wlocal, self.wglobal, self.normalize = wlocal, wglobal, normalize | ||
self.num_docs, self.num_nnz, self.idfs = None, None, None | ||
self.smartirs = smartirs | ||
|
||
# If smartirs is not None, override wlocal, wglobal and normalize | ||
if smartirs is not None: | ||
n_tf, n_df, n_n = resolve_weights(smartirs) | ||
|
||
self.wlocal = partial(updated_wlocal, n_tf=n_tf) | ||
self.wglobal = partial(updated_wglobal, n_df=n_df) | ||
self.normalize = partial(updated_normalize, n_n=n_n) | ||
|
||
if dictionary is not None: | ||
# user supplied a Dictionary object, which already contains all the | ||
# statistics we need to construct the IDF mapping. we can skip the | ||
|
@@ -113,6 +203,7 @@ def initialize(self, corpus): | |
logger.info("collecting document frequencies") | ||
dfs = {} | ||
numnnz, docno = 0, -1 | ||
|
||
for docno, bow in enumerate(corpus): | ||
if docno % 10000 == 0: | ||
logger.info("PROGRESS: processing document #%i", docno) | ||
|
@@ -124,7 +215,6 @@ def initialize(self, corpus): | |
self.num_docs = docno + 1 | ||
self.num_nnz = numnnz | ||
self.dfs = dfs | ||
|
||
# and finally compute the idf weights | ||
n_features = max(dfs) if dfs else 0 | ||
logger.info( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why you remove this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This showed the progress of the |
||
|
@@ -144,17 +234,27 @@ def __getitem__(self, bow, eps=1e-12): | |
|
||
# unknown (new) terms will be given zero weight (NOT infinity/huge weight, | ||
# as strict application of the IDF formula would dictate) | ||
|
||
termid_array, tf_array = [], [] | ||
for termid, tf in bow: | ||
termid_array.append(termid) | ||
tf_array.append(tf) | ||
|
||
tf_array = self.wlocal(np.array(tf_array)) | ||
|
||
vector = [ | ||
(termid, self.wlocal(tf) * self.idfs.get(termid)) | ||
for termid, tf in bow if self.idfs.get(termid, 0.0) != 0.0 | ||
(termid, tf * self.idfs.get(termid)) | ||
for termid, tf in zip(termid_array, tf_array) if self.idfs.get(termid, 0.0) != 0.0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is brittle; better compare floats for equality using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm confused, #1791 (comment) and #1791 (comment) contradict each other, what you mean @piskvorky? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To compare floats for equality, use |
||
] | ||
|
||
if self.normalize is True: | ||
self.normalize = matutils.unitvec | ||
elif self.normalize is False: | ||
self.normalize = utils.identity | ||
|
||
# and finally, normalize the vector either to unit length, or use a | ||
# user-defined normalization function | ||
if self.normalize is True: | ||
vector = matutils.unitvec(vector) | ||
elif self.normalize: | ||
vector = self.normalize(vector) | ||
vector = self.normalize(vector) | ||
|
||
# make sure there are no explicit zeroes in the vector (must be sparse) | ||
vector = [(termid, weight) for termid, weight in vector if abs(weight) > eps] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -973,13 +973,13 @@ def testTransform(self): | |
|
||
def testSetGetParams(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't forget to add more tests (also, check situations, when you pass |
||
# updating only one param | ||
self.model.set_params(normalize=False) | ||
self.model.set_params(smartirs='nnn') | ||
model_params = self.model.get_params() | ||
self.assertEqual(model_params["normalize"], False) | ||
self.assertEqual(model_params["smartirs"], 'nnn') | ||
|
||
# verify that the attributes values are also changed for `gensim_model` after fitting | ||
self.model.fit(self.corpus) | ||
self.assertEqual(getattr(self.model.gensim_model, 'normalize'), False) | ||
self.assertEqual(getattr(self.model.gensim_model, 'smartirs'), 'nnn') | ||
|
||
def testPipeline(self): | ||
with open(datapath('mini_newsgroup'), 'rb') as f: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
docstrings needed too (for all stuff here)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that
Checks for validity of smartirs parameter.
is enough. Do you have anything else in mind as well?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@markroxor need to add "Parameters" (type, description), "Raises" (type, reason), "Returns" (type, description)