Skip to content

Commit

Permalink
Add seed and length for sample_text (#1422)
Browse files Browse the repository at this point in the history
* Create local random generator for sample_text & add length

* Fix typos
  • Loading branch information
vlejd authored and menshikh-iv committed Jun 22, 2017
1 parent dfb66f1 commit 0d47a6f
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 27 deletions.
50 changes: 39 additions & 11 deletions gensim/corpora/textcorpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,28 +98,56 @@ def get_texts(self):
else:
yield utils.tokenize(line, lowercase=True)

def sample_texts(self, n):
def sample_texts(self, n, seed=None, length=None):
"""
Yield n random texts from the corpus without replacement.
Yield n random documents from the corpus without replacement.
Given the the number of remaingin elements in stream is remaining and we need
to choose n elements, the probability for current element to be chosen is n/remaining.
If we choose it, we just decreese the n and move to the next element.
Given the number of remaining documents in a corpus, we need to choose n elements.
The probability for the current element to be chosen is n/remaining.
If we choose it, we just decrease the n and move to the next element.
Computing the corpus length may be a costly operation so you can use the optional
parameter `length` instead.
Args:
n (int): number of documents we want to sample.
seed (int|None): if specified, use it as a seed for local random generator.
length (int|None): if specified, use it as a guess of corpus length.
It must be positive and not greater than actual corpus length.
Yields:
list[str]: document represented as a list of tokens. See get_texts method.
Raises:
ValueError: when n is invalid or length was set incorrectly.
"""
length = len(self)
if not n <= length:
raise ValueError("sample larger than population")
random_generator = None
if seed is None:
random_generator = random
else:
random_generator = random.Random(seed)

if length is None:
length = len(self)

if not n <= length:
raise ValueError("n is larger than length of corpus.")
if not 0 <= n:
raise ValueError("negative sample size")
raise ValueError("Negative sample size.")

for i, sample in enumerate(self.get_texts()):
remaining_in_stream = length - i
chance = random.randint(1, remaining_in_stream)
if i == length:
break
remaining_in_corpus = length - i
chance = random_generator.randint(1, remaining_in_corpus)
if chance <= n:
n -= 1
yield sample

if n != 0:
# This means that length was set to be greater than number of items in corpus
# and we were not able to sample enough documents before the stream ended.
raise ValueError("length greater than number of documents in corpus")

def __len__(self):
if not hasattr(self, 'length'):
# cache the corpus length
Expand Down
46 changes: 30 additions & 16 deletions gensim/test/test_textcorpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,34 +21,48 @@
class TestTextCorpus(unittest.TestCase):
# TODO add tests for other methods

def test_sample_text(self):
class TestTextCorpus(TextCorpus):
def __init__(self):
self.data = [["document1"], ["document2"]]
class DummyTextCorpus(TextCorpus):
def __init__(self):
self.size = 10
self.data = [["document%s" % i] for i in range(self.size)]

def get_texts(self):
for document in self.data:
yield document
def get_texts(self):
for document in self.data:
yield document

corpus = TestTextCorpus()
def test_sample_text(self):
corpus = self.DummyTextCorpus()

sample1 = list(corpus.sample_texts(1))
self.assertEqual(len(sample1), 1)
document1 = sample1[0] == ["document1"]
document2 = sample1[0] == ["document2"]
self.assertTrue(document1 or document2)
self.assertIn(sample1[0], corpus.data)

sample2 = list(corpus.sample_texts(2))
self.assertEqual(len(sample2), 2)
self.assertEqual(sample2[0], ["document1"])
self.assertEqual(sample2[1], ["document2"])
sample2 = list(corpus.sample_texts(corpus.size))
self.assertEqual(len(sample2), corpus.size)
for i in range(corpus.size):
self.assertEqual(sample2[i], ["document%s" % i])

with self.assertRaises(ValueError):
list(corpus.sample_texts(3))
list(corpus.sample_texts(corpus.size + 1))

with self.assertRaises(ValueError):
list(corpus.sample_texts(-1))

def test_sample_text_length(self):
corpus = self.DummyTextCorpus()
sample1 = list(corpus.sample_texts(1, length=1))
self.assertEqual(sample1[0], ["document0"])

sample2 = list(corpus.sample_texts(2, length=2))
self.assertEqual(sample2[0], ["document0"])
self.assertEqual(sample2[1], ["document1"])

def test_sample_text_seed(self):
corpus = self.DummyTextCorpus()
sample1 = list(corpus.sample_texts(5, seed=42))
sample2 = list(corpus.sample_texts(5, seed=42))
self.assertEqual(sample1, sample2)


if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
Expand Down

0 comments on commit 0d47a6f

Please sign in to comment.