diff --git a/gensim/matutils.py b/gensim/matutils.py index 3b4dce9c4a..bd5c8a34d5 100644 --- a/gensim/matutils.py +++ b/gensim/matutils.py @@ -1272,25 +1272,40 @@ def __init__(self, input, transposed=True): """ logger.info("initializing corpus reader from %s", input) self.input, self.transposed = input, transposed - with utils.file_or_filename(self.input) as lines: + + # 'with' statement behaviour without closing the file object + mgr = (utils.file_or_filename(self.input)) + exit = type(mgr).__exit__ + value = type(mgr).__enter__(mgr) + exc = True + try: try: - header = utils.to_unicode(next(lines)).strip() - if not header.lower().startswith('%%matrixmarket matrix coordinate real general'): - raise ValueError( - "File %s not in Matrix Market format with coordinate real general; instead found: \n%s" % - (self.input, header) - ) - except StopIteration: - pass - - self.num_docs = self.num_terms = self.num_nnz = 0 - for lineno, line in enumerate(lines): - line = utils.to_unicode(line) - if not line.startswith('%'): - self.num_docs, self.num_terms, self.num_nnz = (int(x) for x in line.split()) - if not self.transposed: - self.num_docs, self.num_terms = self.num_terms, self.num_docs - break + lines = value + try: + header = utils.to_unicode(next(lines)).strip() + if not header.lower().startswith('%%matrixmarket matrix coordinate real general'): + raise ValueError( + "File %s not in Matrix Market format with coordinate real general; instead found: \n%s" % + (self.input, header) + ) + except StopIteration: + pass + + self.num_docs = self.num_terms = self.num_nnz = 0 + for lineno, line in enumerate(lines): + line = utils.to_unicode(line) + if not line.startswith('%'): + self.num_docs, self.num_terms, self.num_nnz = (int(x) for x in line.split()) + if not self.transposed: + self.num_docs, self.num_terms = self.num_terms, self.num_docs + break + except: + exc = False + if not exit(mgr, *sys.exc_info()): + raise + finally: + if isinstance(self.input, string_types): + exit(mgr, None, None, None) logger.info( "accepted corpus with %i documents, %i features, %i non-zero entries", diff --git a/gensim/test/test_corpora.py b/gensim/test/test_corpora.py index f330dbd271..79a7b2c93f 100644 --- a/gensim/test/test_corpora.py +++ b/gensim/test/test_corpora.py @@ -224,6 +224,14 @@ def setUp(self): def test_serialize_compressed(self): # MmCorpus needs file write with seek => doesn't support compressed output (only input) pass + + def test_closed_file_object(self): + file_obj = open(datapath('testcorpus.mm')) + f = file_obj.closed + corpus = mmcorpus.MmCorpus(file_obj) + s = file_obj.closed + self.assertEqual(f, 0) + self.assertEqual(s, 0) def test_load(self): self.assertEqual(self.corpus.num_docs, 9)