[FEATURE] offer load_w2v_binary method to load w2v binary file (#620)
* ✨ (w2v) offer load_w2v_binary method to load w2v binary file

* 🐛 (w2v) expand user path

* 🚧 (w2v) add preload flag; it decides whether to load the default model or not.

* 🐛 (w2v) preserve 0 for unknown tok

* ✅ (w2v) add test for load_w2v_binary

* ♻️ (w2v) If source is None, construct an empty Word2Vec class object.

* ♻️ (w2v) add classmethod: from_binary as a constructor

* 🎨 (w2v) rename from_binary -> from_w2v_binary

* ♻️ (w2v) check source in w2v constructor

* ♻️ (w2v) use file name extension to identify binary file

* 🎨 (w2v) pylint: fix unused argument 'encoding' (unused-argument)

* 📝 (w2v) update doc

* 📝 (w2v) update description in w2v constructor
Gary Lai authored and eric-haibin-lin committed Mar 26, 2019
1 parent 6ec0c84 commit e1910c5
Showing 4 changed files with 554 additions and 6 deletions.
108 changes: 102 additions & 6 deletions src/gluonnlp/embedding/token_embedding.py
@@ -448,7 +448,6 @@ def allow_extend(self):
        """
        return self._allow_extend

    @property
    def unknown_lookup(self):
        """Vector lookup for unknown tokens.
@@ -997,6 +996,8 @@ class Word2Vec(TokenEmbedding):
    ----------
    source : str, default 'GoogleNews-vectors-negative300'
        The name of the pre-trained token embedding file.
        A binary pre-trained file outside of the source list can also be used with this
        constructor by passing its path, which must end with the .bin file extension.
    embedding_root : str, default '$MXNET_HOME/embedding'
        The root directory for storing embedding-related files.
        MXNET_HOME defaults to '~/.mxnet'.
@@ -1019,10 +1020,105 @@ class Word2Vec(TokenEmbedding):
    source_file_hash = C.WORD2VEC_NPZ_SHA1

    def __init__(self, source='GoogleNews-vectors-negative300',
                 embedding_root=os.path.join(get_home_dir(), 'embedding'), encoding='utf8',
                 **kwargs):
        super(Word2Vec, self).__init__(**kwargs)
        if source.endswith('.bin'):
            pretrained_file_path = os.path.expanduser(source)
            self._load_w2v_binary(pretrained_file_path, encoding=encoding)
        else:
            self._check_source(self.source_file_hash, source)
            pretrained_file_path = self._get_file_path(self.source_file_hash,
                                                       embedding_root, source)
            self._load_embedding(pretrained_file_path, elem_delim=' ')
    def _load_w2v_binary(self, pretrained_file_path, encoding='utf8'):
        """Load embedding vectors from a binary pre-trained token embedding file.

        Parameters
        ----------
        pretrained_file_path : str
            The path to a binary pre-trained token embedding file ending with the .bin
            file extension.
        encoding : str
            The encoding type of the file.
        """
        self._idx_to_token = [self.unknown_token] if self.unknown_token else []
        if self.unknown_token:
            self._token_to_idx = DefaultLookupDict(C.UNK_IDX)
        else:
            self._token_to_idx = {}
        self._token_to_idx.update((token, idx) for idx, token in enumerate(self._idx_to_token))
        self._idx_to_vec = None
        all_elems = []
        tokens = set()
        loaded_unknown_vec = None
        pretrained_file_path = os.path.expanduser(pretrained_file_path)
        with io.open(pretrained_file_path, 'rb') as f:
            header = f.readline().decode(encoding=encoding)
            vocab_size, vec_len = (int(x) for x in header.split())
            if self.unknown_token:
                # Reserve a vector slot for the unknown token at the very beginning
                # because the unknown token index is 0.
                all_elems.extend([0] * vec_len)
            binary_len = np.dtype(np.float32).itemsize * vec_len
            for line_num in range(vocab_size):
                token = []
                while True:
                    ch = f.read(1)
                    if ch == b' ':
                        break
                    if ch == b'':
                        raise EOFError('unexpected end of input; is count incorrect or file '
                                       'otherwise damaged?')
                    if ch != b'\n':  # ignore newlines in front of words (some binary files have them)
                        token.append(ch)
                try:
                    token = b''.join(token).decode(encoding=encoding)
                except ValueError:
                    warnings.warn('line {} in {}: failed to decode. Skipping.'
                                  .format(line_num, pretrained_file_path))
                    continue
                elems = np.fromstring(f.read(binary_len), dtype=np.float32)

                assert len(elems) > 1, 'line {} in {}: unexpected data format.'.format(
                    line_num, pretrained_file_path)

                if token == self.unknown_token and loaded_unknown_vec is None:
                    loaded_unknown_vec = elems
                    tokens.add(self.unknown_token)
                elif token in tokens:
                    warnings.warn('line {} in {}: duplicate embedding found for '
                                  'token "{}". Skipped.'.format(line_num, pretrained_file_path,
                                                                token))
                else:
                    assert len(elems) == vec_len, \
                        'line {} in {}: found vector of inconsistent dimension for token ' \
                        '"{}". expected dim: {}, found: {}'.format(line_num,
                                                                   pretrained_file_path,
                                                                   token, vec_len, len(elems))
                    all_elems.extend(elems)
                    self._idx_to_token.append(token)
                    self._token_to_idx[token] = len(self._idx_to_token) - 1
                    tokens.add(token)
        self._idx_to_vec = nd.array(all_elems).reshape((-1, vec_len))

        if self.unknown_token:
            if loaded_unknown_vec is None:
                self._idx_to_vec[C.UNK_IDX] = self._init_unknown_vec(shape=vec_len)
            else:
                self._idx_to_vec[C.UNK_IDX] = nd.array(loaded_unknown_vec)
    @classmethod
    def from_w2v_binary(cls, pretrained_file_path, encoding='utf8'):
        """Load embedding vectors from a binary pre-trained token embedding file.

        Parameters
        ----------
        pretrained_file_path : str
            The path to a binary pre-trained token embedding file ending with the .bin
            file extension.
        encoding : str
            The encoding type of the file.
        """
        return cls(source=pretrained_file_path, encoding=encoding)
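
For context, the change makes the following two usage patterns equivalent. This is a minimal sketch, assuming gluonnlp is importable and a local word2vec binary file exists at '~/models/vectors.bin' (a hypothetical path, not part of this commit):

import gluonnlp as nlp

# Construct from a local binary file via the new classmethod.
w2v = nlp.embedding.Word2Vec.from_w2v_binary('~/models/vectors.bin')

# Equivalently, pass the .bin path directly as `source`; the constructor
# dispatches on the .bin extension and calls _load_w2v_binary instead of
# downloading a named pre-trained source.
w2v_alt = nlp.embedding.Word2Vec(source='~/models/vectors.bin', encoding='utf8')

# Either object supports the usual TokenEmbedding lookups, e.g. indexing by token.
vectors = w2v['hello']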
16 changes: 16 additions & 0 deletions tests/unittest/train/test_embedding.py
@@ -89,3 +89,19 @@ def test_fasttext_embedding_load_binary_compare_vec():
        np.isclose(a=token_embedding_vec.idx_to_vec.asnumpy(),
                   b=idx_to_vec.asnumpy(), atol=0.001))
    assert all(token in model for token in token_embedding_vec.idx_to_token)


def test_word2vec_embedding_load_binary_format():
    test_dir = os.path.dirname(os.path.realpath(__file__))
    word2vec_vec = nlp.embedding.Word2Vec.from_file(
        os.path.join(str(test_dir), 'test_embedding', 'lorem_ipsum_w2v.vec'),
        elem_delim=' '
    )
    word2vec_bin = nlp.embedding.Word2Vec.from_w2v_binary(
        os.path.join(str(test_dir), 'test_embedding', 'lorem_ipsum_w2v.bin')
    )
    idx_to_vec = word2vec_bin[word2vec_vec.idx_to_token]
    assert np.all(
        np.isclose(a=word2vec_vec.idx_to_vec.asnumpy(),
                   b=idx_to_vec.asnumpy(), atol=0.001))
    assert all(token in word2vec_bin for token in word2vec_vec.idx_to_token)
Binary file tests/unittest/train/test_embedding/lorem_ipsum_w2v.bin not shown.
