Cythonize token block dataset (#834)
Summary: Cythonized the token block dataset code; it's `> 100x` faster. Token block for the entire `bookwiki+CC+stories+openweb` takes just ~`39.9` seconds.

TODO:
1) I think I can make it 2x faster.
2) Cleanup.

EDIT History: ~~First pass at parallelizing `token_block_dataset`. The code feels somewhat complicated and cluttered. It is 2-3x faster, though, in my tests on the `bookwiki` dataset with both `complete` and `complete_doc` modes. myleott, can you take a look for correctness? I am still not 100% sure that I am not missing corner cases.~~

Pull Request resolved: fairinternal/fairseq-py#834

Test Plan: Imported from GitHub, without a `Test Plan:` line.

Test workflow: f133816198

Reviewed By: myleott

Differential Revision: D16970257

Pulled By: myleott

fbshipit-source-id: ec45a308193c9e9f3e7075336c15df4723228d6f
1 parent 6e2bd79, commit 4fc3953
Showing 6 changed files with 285 additions and 154 deletions.
@@ -0,0 +1,67 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

cimport cython
cimport numpy as np

DTYPE = np.int64
ctypedef np.int64_t DTYPE_t


cdef _is_batch_full(list batch, long num_tokens, long max_tokens, long max_sentences):
    if len(batch) == 0:
        return 0
    if len(batch) == max_sentences:
        return 1
    if num_tokens > max_tokens:
        return 1
    return 0


@cython.cdivision(True)
cpdef list batch_by_size_fast(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    long max_tokens,
    long max_sentences,
    int bsz_mult,
):
    cdef long sample_len = 0
    cdef list sample_lens = []
    cdef list batch = []
    cdef list batches = []
    cdef long mod_len
    cdef long i
    cdef long idx
    cdef long num_tokens
    cdef DTYPE_t[:] indices_view = indices

    for i in range(len(indices_view)):
        idx = indices_view[i]
        num_tokens = num_tokens_fn(idx)
        sample_lens.append(num_tokens)
        sample_len = max(sample_len, num_tokens)

        assert sample_len <= max_tokens, (
            "sentence at index {} of size {} exceeds max_tokens "
            "limit of {}!".format(idx, sample_len, max_tokens)
        )
        num_tokens = (len(batch) + 1) * sample_len

        if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
            mod_len = max(
                bsz_mult * (len(batch) // bsz_mult),
                len(batch) % bsz_mult,
            )
            batches.append(batch[:mod_len])
            batch = batch[mod_len:]
            sample_lens = sample_lens[mod_len:]
            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
        batch.append(idx)
    if len(batch) > 0:
        batches.append(batch)
    return batches
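For stepping through the batching logic without compiling the extension, here is a minimal pure-Python sketch that mirrors the control flow of `batch_by_size_fast` above. It is an illustration, not part of the commit: the names `batch_by_size_py` and `is_batch_full`, the sample lengths, and the limits are all assumptions made up for this example, and it takes a plain list of indices rather than an `int64` NumPy array.

```python
def is_batch_full(batch, num_tokens, max_tokens, max_sentences):
    """Mirror of _is_batch_full: an empty batch is never considered full."""
    if len(batch) == 0:
        return False
    if len(batch) == max_sentences:
        return True
    return num_tokens > max_tokens


def batch_by_size_py(indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult):
    """Hypothetical pure-Python equivalent of the Cython loop (illustration only)."""
    sample_len = 0      # longest sample among those tracked in sample_lens
    sample_lens = []    # lengths of the samples in `batch`, plus the current one
    batch = []
    batches = []

    for idx in indices:
        num_tokens = num_tokens_fn(idx)
        sample_lens.append(num_tokens)
        sample_len = max(sample_len, num_tokens)
        assert sample_len <= max_tokens, (
            "sentence at index {} of size {} exceeds max_tokens "
            "limit of {}!".format(idx, sample_len, max_tokens)
        )

        # Padded cost if `idx` joined the batch: every row is padded to sample_len.
        padded_tokens = (len(batch) + 1) * sample_len

        if is_batch_full(batch, padded_tokens, max_tokens, max_sentences):
            # Emit the largest prefix whose size is a multiple of bsz_mult,
            # or the whole batch when it is smaller than bsz_mult.
            mod_len = max(
                bsz_mult * (len(batch) // bsz_mult),
                len(batch) % bsz_mult,
            )
            batches.append(batch[:mod_len])
            batch = batch[mod_len:]
            sample_lens = sample_lens[mod_len:]
            sample_len = max(sample_lens) if sample_lens else 0
        batch.append(idx)

    if batch:
        batches.append(batch)
    return batches


# Made-up sample lengths, indexed by sample id.
lengths = [3, 4, 5, 6, 2, 2]
result = batch_by_size_py(
    indices=range(len(lengths)),
    num_tokens_fn=lambda i: lengths[i],
    max_tokens=12,
    max_sentences=3,
    bsz_mult=1,
)
print(result)  # [[0, 1], [2, 3], [4, 5]]
```

In this run the third sample would push the padded cost to 3 * 5 = 15 tokens, above the limit of 12, so the first two samples are emitted as a batch before it is added. The same happens when sample 4 arrives.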
Hi Naman, the PyText repo (which depends on FairSeq master) is broken today due to `from Cython.Build import cythonize`. I think it's caused by this commit, which introduces Cython without adding it as a dependency. Could you take a look?
Reference: PyText CircleCI build failure log: https://circleci.com/gh/facebookresearch/pytext/12865?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link
@hudeven, this should be fixed now; please reach out if it's still an issue.
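The failure reported above comes from an unconditional `from Cython.Build import cythonize` running on a machine without Cython. The thread does not say how it was fixed, so the following is only a sketch of one common way to guard such an import in a build script. The helper names `load_cythonize` and `describe_build` and the file name `example_fast.pyx` are hypothetical.

```python
import importlib


def load_cythonize():
    """Return Cython.Build.cythonize, or None when Cython is not installed."""
    try:
        return importlib.import_module("Cython.Build").cythonize
    except ImportError:
        # Covers ModuleNotFoundError too; the caller decides how to proceed.
        return None


def describe_build(pyx_sources):
    """Hypothetical helper: report whether the .pyx sources can be compiled here."""
    cythonize = load_cythonize()
    sources = list(pyx_sources)
    if cythonize is None:
        return {
            "can_compile": False,
            "hint": "Cython is missing; declare it as a build-time requirement",
            "sources": sources,
        }
    return {"can_compile": True, "hint": "", "sources": sources}


info = describe_build(["example_fast.pyx"])  # file name is made up
print(info["can_compile"], info["hint"])
```

A guard like this only turns the crash into a clear message. Downstream installs such as the PyText build also need Cython declared as a build requirement so that it is present before the build script runs.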