Skip to content

Commit

Permalink
Fix fasttext model loading from gzip files (#2476)
Browse files Browse the repository at this point in the history
* added some assertions

* extract np.fromfile function, add tests around it

* add gzip test case

* get matrix loading working with gzip

* remove assertion, some tests trip it

* apply comments from review

* make flake8 happy

* Update gensim/models/_fasttext_bin.py

Co-Authored-By: mpenkov <m@penkov.dev>

* More review responses
  • Loading branch information
mpenkov authored May 6, 2019
1 parent 40792c6 commit 790b9a7
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 7 deletions.
64 changes: 57 additions & 7 deletions gensim/models/_fasttext_bin.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

import codecs
import collections
import gzip
import io
import logging
import struct
Expand Down Expand Up @@ -74,6 +75,14 @@
('t', 'd'),
]

_FLOAT_SIZE = struct.calcsize('@f')
if _FLOAT_SIZE == 4:
_FLOAT_DTYPE = np.dtype(np.float32)
elif _FLOAT_SIZE == 8:
_FLOAT_DTYPE = np.dtype(np.float64)
else:
_FLOAT_DTYPE = None


def _yield_field_names():
for name, _ in _OLD_HEADER_FORMAT + _NEW_HEADER_FORMAT:
Expand Down Expand Up @@ -220,24 +229,65 @@ def _load_matrix(fin, new_format=True):
The number of columns of the array will correspond to the vector size.
"""
if _FLOAT_DTYPE is None:
raise ValueError('bad _FLOAT_SIZE: %r' % _FLOAT_SIZE)

if new_format:
_struct_unpack(fin, '@?') # bool quant_input in fasttext.cc

num_vectors, dim = _struct_unpack(fin, '@2q')
count = num_vectors * dim

float_size = struct.calcsize('@f')
if float_size == 4:
dtype = np.dtype(np.float32)
elif float_size == 8:
dtype = np.dtype(np.float64)
#
# numpy.fromfile doesn't play well with gzip.GzipFile as input:
#
# - https://github.com/RaRe-Technologies/gensim/pull/2476
# - https://github.com/numpy/numpy/issues/13470
#
# Until they fix it, we have to apply a workaround. We only apply the
# workaround when it's necessary, because np.fromfile is heavily optimized
# and very efficient (when it works).
#
if isinstance(fin, gzip.GzipFile):
logger.warning(
'Loading model from a compressed .gz file. This can be slow. '
'This is a work-around for a bug in NumPy: https://github.com/numpy/numpy/issues/13470. '
'Consider decompressing your model file for a faster load. '
)
matrix = _fromfile(fin, _FLOAT_DTYPE, count)
else:
raise ValueError("Incompatible float size: %r" % float_size)
matrix = np.fromfile(fin, _FLOAT_DTYPE, count)

matrix = np.fromfile(fin, dtype=dtype, count=num_vectors * dim)
assert matrix.shape == (count,), 'expected (%r,), got %r' % (count, matrix.shape)
matrix = matrix.reshape((num_vectors, dim))
return matrix


def _batched_generator(fin, count, batch_size=1e6):
    """Read `count` floats from `fin`.

    Batches up read calls to avoid I/O overhead.  Keeps no more than
    `batch_size` floats in memory at once.

    Parameters
    ----------
    fin : file-like object
        The stream to read from.  Passed unchanged to `_struct_unpack`.
    count : int
        The total number of floats to read.
    batch_size : int or float, optional
        The maximum number of floats to read in a single call.  Truncated to
        an integer before use.

    Yields
    ------
    float
        The values read from the stream, in order.

    Raises
    ------
    ValueError
        If `batch_size` truncates to less than one.  Because this is a
        generator, the error surfaces when iteration starts.

    """
    #
    # The default is spelled as the float 1e6.  Work with integers from here
    # on: '%d' truncates a fractional batch size when building the format
    # string, whereas subtracting it from the count would not, so the two
    # would drift apart.
    #
    batch_size = int(batch_size)
    if batch_size < 1:
        # A non-positive batch size would never make progress in the loop.
        raise ValueError('batch_size must be at least 1, got %r' % batch_size)

    remaining = int(count)
    full_batch_fmt = '@%df' % batch_size

    while remaining > batch_size:
        for value in _struct_unpack(fin, full_batch_fmt):
            yield value
        remaining -= batch_size

    # Whatever is left fits in a single, possibly smaller, batch.
    for value in _struct_unpack(fin, '@%df' % remaining):
        yield value


def _fromfile(fin, dtype, count):
    """Reimplementation of numpy.fromfile.

    Reads `count` floats from `fin` through `_batched_generator` and collects
    them into a one-dimensional numpy array.

    Parameters
    ----------
    fin : file-like object
        The stream to read from.
    dtype : numpy.dtype
        The data type of the resulting array.
    count : int
        The number of floats to read.

    Returns
    -------
    numpy.ndarray
        A one-dimensional array of the values read.

    """
    #
    # Passing the expected number of items lets fromiter allocate the output
    # array once, instead of growing it repeatedly as values arrive.
    #
    return np.fromiter(_batched_generator(fin, count), dtype=dtype, count=int(count))


def load(fin, encoding='utf-8', full_model=True):
"""Load a model from a binary stream.
Expand Down
Binary file added gensim/test/test_data/reproduce.dat
Binary file not shown.
Binary file added gensim/test/test_data/reproduce.dat.gz
Binary file not shown.
23 changes: 23 additions & 0 deletions gensim/test/test_fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
from __future__ import division

import gzip
import io
import logging
import unittest
Expand Down Expand Up @@ -1289,6 +1290,28 @@ def test_bad_unicode(self):
self.assertEqual(nlabels, -1)


_BYTES = b'the quick brown fox jumps over the lazy dog'
_ARRAY = np.array([0., 1., 2., 3., 4., 5., 6., 7., 8.], dtype=np.dtype('float32'))


class TestFromfile(unittest.TestCase):
    """Run the same `_fromfile` checks against reproduce.dat and reproduce.dat.gz."""

    def test_decompressed(self):
        with open(datapath('reproduce.dat'), 'rb') as fin:
            self._run(fin)

    def test_compressed(self):
        with gzip.GzipFile(datapath('reproduce.dat.gz'), 'rb') as fin:
            self._run(fin)

    def _run(self, fin):
        """Read the byte prefix, then the floats, and compare both to the expected values."""
        actual = fin.read(len(_BYTES))
        self.assertEqual(_BYTES, actual)

        array = gensim.models._fasttext_bin._fromfile(fin, _ARRAY.dtype, _ARRAY.shape[0])
        # Diagnostic output only; a passing run should not log at ERROR level.
        logger.debug('array: %r', array)

        # np.allclose broadcasts its arguments, so pin the shape explicitly.
        self.assertEqual(_ARRAY.shape, array.shape)
        self.assertTrue(np.allclose(_ARRAY, array))


if __name__ == '__main__':
    # Turn on verbose logging when this test module is run directly.
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s : %(levelname)s : %(message)s',
    )
    unittest.main()

0 comments on commit 790b9a7

Please sign in to comment.