Skip to content

Commit

Permalink
Do integer checks using both six.integer_types and np.integer
Browse files Browse the repository at this point in the history
  • Loading branch information
bogdanteleaga committed Apr 3, 2017
1 parent 8e05520 commit 206682f
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 21 deletions.
4 changes: 2 additions & 2 deletions gensim/corpora/indexedcorpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""

import logging
import shelve
import six

import numpy

Expand Down Expand Up @@ -124,7 +124,7 @@ def __getitem__(self, docno):

if isinstance(docno, (slice, list, numpy.ndarray)):
return utils.SlicedCorpus(self, docno)
elif isinstance(docno, (int, numpy.integer)):
elif isinstance(docno, six.integer_types + (numpy.integer,)):
return self.docbyoffset(self.index[docno])
else:
raise ValueError('Unrecognised value for docno, use either a single integer, a slice or a numpy.ndarray')
Expand Down
7 changes: 1 addition & 6 deletions gensim/models/atmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,8 @@
# are included in the code where this is the case, for example in the log_perplexity
# and do_estep methods.

import pdb
from pdb import set_trace as st
from pprint import pprint

import logging
import numpy as np # for arrays, array broadcasting etc.
import numbers
from copy import deepcopy
from shutil import copyfile
from os.path import isfile
Expand Down Expand Up @@ -391,7 +386,7 @@ def inference(self, chunk, author2doc, doc2author, rhot, collect_sstats=False, c
doc_no = d
# Get the IDs and counts of all the words in the current document.
# TODO: this is duplication of code in LdaModel. Refactor.
if doc and not isinstance(doc[0][0], six.integer_types):
if doc and not isinstance(doc[0][0], six.integer_types + (np.integer,)):
# make sure the term IDs are ints, otherwise np will get upset
ids = [int(id) for id, _ in doc]
else:
Expand Down
18 changes: 9 additions & 9 deletions gensim/models/doc2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
from collections import namedtuple, defaultdict
from timeit import default_timer

from numpy import zeros, random, sum as np_sum, add as np_add, concatenate, \
from numpy import zeros, sum as np_sum, add as np_add, concatenate, \
repeat as np_repeat, array, float32 as REAL, empty, ones, memmap as np_memmap, \
sqrt, newaxis, ndarray, dot, vstack, dtype, divide as np_divide, integer

Expand All @@ -62,7 +62,7 @@
from gensim import utils, matutils # utility fnc for pickling, common scipy operations etc
from gensim.models.word2vec import Word2Vec, train_cbow_pair, train_sg_pair, train_batch_sg
from six.moves import xrange, zip
from six import string_types, integer_types, itervalues
from six import string_types, integer_types

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -296,7 +296,7 @@ def __init__(self, mapfile_path=None):

def note_doctag(self, key, document_no, document_length):
"""Note a document tag during initial corpus scan, for structure sizing."""
if isinstance(key, int):
if isinstance(key, integer_types + (integer,)):
self.max_rawint = max(self.max_rawint, key)
else:
if key in self.doctags:
Expand All @@ -318,7 +318,7 @@ def trained_item(self, indexed_tuple):

def _int_index(self, index):
"""Return int index for either string or int index"""
if isinstance(index, (int, integer)):
if isinstance(index, integer_types + (integer,)):
return index
else:
return self.max_rawint + 1 + self.doctags[index].offset
Expand Down Expand Up @@ -346,7 +346,7 @@ def __getitem__(self, index):
If a list, return designated tags' vector representations as a
2D numpy array: #tags x #vector_size.
"""
if isinstance(index, string_types + (int, integer)):
if isinstance(index, string_types + integer_types + (integer,)):
return self.doctag_syn0[self._int_index(index)]

return vstack([self[i] for i in index])
Expand All @@ -355,7 +355,7 @@ def __len__(self):
return self.count

def __contains__(self, index):
if isinstance(index, int):
if isinstance(index, integer_types + (integer,)):
return index < self.count
else:
return index in self.doctags
Expand Down Expand Up @@ -438,17 +438,17 @@ def most_similar(self, positive=[], negative=[], topn=10, clip_start=0, clip_end
self.init_sims()
clip_end = clip_end or len(self.doctag_syn0norm)

if isinstance(positive, string_types + integer_types) and not negative:
if isinstance(positive, string_types + integer_types + (integer,)) and not negative:
# allow calls like most_similar('dog'), as a shorthand for most_similar(['dog'])
positive = [positive]

# add weights for each doc, if not already present; default to 1.0 for positive and -1.0 for negative docs
positive = [
(doc, 1.0) if isinstance(doc, string_types + (ndarray,) + integer_types)
(doc, 1.0) if isinstance(doc, string_types + integer_types + (ndarray, integer))
else doc for doc in positive
]
negative = [
(doc, -1.0) if isinstance(doc, string_types + (ndarray,) + integer_types)
(doc, -1.0) if isinstance(doc, string_types + integer_types + (ndarray, integer))
else doc for doc in negative
]

Expand Down
8 changes: 4 additions & 4 deletions gensim/models/ldamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ def inference(self, chunk, collect_sstats=False):
# Lee&Seung trick which speeds things up by an order of magnitude, compared
# to Blei's original LDA-C code, cool!).
for d, doc in enumerate(chunk):
if len(doc) > 0 and not isinstance(doc[0][0], six.integer_types):
if len(doc) > 0 and not isinstance(doc[0][0], six.integer_types + (np.integer,)):
# make sure the term IDs are ints, otherwise np will get upset
ids = [int(id) for id, _ in doc]
else:
Expand Down Expand Up @@ -1006,7 +1006,7 @@ def save(self, fname, ignore=['state', 'dispatcher'], separately=None, *args, **
if 'id2word' not in ignore:
utils.pickle(self.id2word, utils.smart_extension(fname, '.id2word'))

# make sure 'state', 'id2word' and 'dispatcher' are ignored from the pickled object, even if
# make sure 'state', 'id2word' and 'dispatcher' are ignored from the pickled object, even if
# someone sets the ignore list themselves
if ignore is not None and ignore:
if isinstance(ignore, six.string_types):
Expand All @@ -1015,7 +1015,7 @@ def save(self, fname, ignore=['state', 'dispatcher'], separately=None, *args, **
ignore = list(set(['state', 'dispatcher', 'id2word']) | set(ignore))
else:
ignore = ['state', 'dispatcher', 'id2word']

# make sure 'expElogbeta' and 'sstats' are ignored from the pickled object, even if
# someone sets the separately list themselves.
separately_explicit = ['expElogbeta', 'sstats']
Expand All @@ -1034,7 +1034,7 @@ def save(self, fname, ignore=['state', 'dispatcher'], separately=None, *args, **
else:
separately = separately_explicit
super(LdaModel, self).save(fname, ignore=ignore, separately = separately, *args, **kwargs)

@classmethod
def load(cls, fname, *args, **kwargs):
"""
Expand Down

0 comments on commit 206682f

Please sign in to comment.