Cythonize token block dataset (#834)
Summary:
Cythonized the token block dataset code; it's `> 100x` faster. Building token blocks for the entire `bookwiki+CC+stories+openweb` corpus takes just ~`39.9` seconds.

TODO:
1) I think I can make it another 2x faster.
2) Cleanup.

EDIT History:
~~First pass at parallelizing `token_block_dataset`. The code feels somewhat complicated and cluttered.
It is 2-3x faster, though, in my tests on the `bookwiki` dataset with both `complete` and `complete_doc` modes.
@myleott, can you take a look for correctness? I am still not 100% sure that I am not missing corner cases.~~
Pull Request resolved: fairinternal/fairseq-py#834

Test Plan:
Imported from GitHub, without a `Test Plan:` line.

Test workflow: f133816198

Reviewed By: myleott

Differential Revision: D16970257

Pulled By: myleott

fbshipit-source-id: ec45a308193c9e9f3e7075336c15df4723228d6f
Naman Goyal authored and facebook-github-bot committed Aug 23, 2019
1 parent 6e2bd79 commit 4fc3953
Showing 6 changed files with 285 additions and 154 deletions.
46 changes: 9 additions & 37 deletions fairseq/data/data_utils.py
@@ -12,6 +12,10 @@
 import os
 
 import numpy as np
+import sys
+import types
+
+from fairseq.data.data_utils_fast import batch_by_size_fast
 
 
 def infer_language_pair(path):
@@ -196,45 +200,13 @@ def batch_by_size(
         required_batch_size_multiple (int, optional): require batch size to
             be a multiple of N (default: 1).
     """
-    max_tokens = max_tokens if max_tokens is not None else float('Inf')
-    max_sentences = max_sentences if max_sentences is not None else float('Inf')
+    max_tokens = max_tokens if max_tokens is not None else sys.maxsize
+    max_sentences = max_sentences if max_sentences is not None else sys.maxsize
     bsz_mult = required_batch_size_multiple
 
-    batch = []
-
-    def is_batch_full(num_tokens):
-        if len(batch) == 0:
-            return False
-        if len(batch) == max_sentences:
-            return True
-        if num_tokens > max_tokens:
-            return True
-        return False
-
-    sample_len = 0
-    sample_lens = []
-    for idx in indices:
-        sample_lens.append(num_tokens_fn(idx))
-        sample_len = max(sample_len, sample_lens[-1])
-        assert sample_len <= max_tokens, (
-            "sentence at index {} of size {} exceeds max_tokens "
-            "limit of {}!".format(idx, sample_len, max_tokens)
-        )
-        num_tokens = (len(batch) + 1) * sample_len
-        if is_batch_full(num_tokens):
-            mod_len = max(
-                bsz_mult * (len(batch) // bsz_mult),
-                len(batch) % bsz_mult,
-            )
-            yield batch[:mod_len]
-            batch = batch[mod_len:]
-            sample_lens = sample_lens[mod_len:]
-            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
-
-        batch.append(idx)
-
-    if len(batch) > 0:
-        yield batch
+    if isinstance(indices, types.GeneratorType):
+        indices = np.fromiter(indices, dtype=np.int64, count=-1)
+    return batch_by_size_fast(indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult)
 
 
 def process_bpe_symbol(sentence: str, bpe_symbol: str):
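
For context, here is a minimal usage sketch of the rewritten `batch_by_size`; the toy sentence lengths and `num_tokens_fn` below are illustrative, not from the commit:

```python
# Illustrative only: batch_by_size now delegates to the Cython
# batch_by_size_fast and returns a list of batches of indices.
import numpy as np
from fairseq.data import data_utils

lengths = [12, 7, 31, 5, 18, 9]  # tokens per sentence (made-up values)
indices = np.arange(len(lengths), dtype=np.int64)

batches = data_utils.batch_by_size(
    indices,
    num_tokens_fn=lambda i: lengths[i],  # called once per index inside the Cython loop
    max_tokens=40,                       # cap on len(batch) * longest sample in the batch
    max_sentences=None,                  # no explicit cap on sentences per batch
    required_batch_size_multiple=1,
)
print(batches)  # a list of lists of dataset indices
```

One behavioral note visible in the diff: the old implementation was a generator (`yield batch`), while the Cython-backed version materializes and returns the full list of batches, so callers that merely iterate over the result are unaffected.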
67 changes: 67 additions & 0 deletions fairseq/data/data_utils_fast.pyx
@@ -0,0 +1,67 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+
+cimport cython
+cimport numpy as np
+
+DTYPE = np.int64
+ctypedef np.int64_t DTYPE_t
+
+
+cdef _is_batch_full(list batch, long num_tokens, long max_tokens, long max_sentences):
+    if len(batch) == 0:
+        return 0
+    if len(batch) == max_sentences:
+        return 1
+    if num_tokens > max_tokens:
+        return 1
+    return 0
+
+
+@cython.cdivision(True)
+cpdef list batch_by_size_fast(
+    np.ndarray[DTYPE_t, ndim=1] indices,
+    num_tokens_fn,
+    long max_tokens,
+    long max_sentences,
+    int bsz_mult,
+):
+    cdef long sample_len = 0
+    cdef list sample_lens = []
+    cdef list batch = []
+    cdef list batches = []
+    cdef long mod_len
+    cdef long i
+    cdef long idx
+    cdef long num_tokens
+    cdef DTYPE_t[:] indices_view = indices
+
+    for i in range(len(indices_view)):
+        idx = indices_view[i]
+        num_tokens = num_tokens_fn(idx)
+        sample_lens.append(num_tokens)
+        sample_len = max(sample_len, num_tokens)
+
+        assert sample_len <= max_tokens, (
+            "sentence at index {} of size {} exceeds max_tokens "
+            "limit of {}!".format(idx, sample_len, max_tokens)
+        )
+        num_tokens = (len(batch) + 1) * sample_len
+
+        if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
+            mod_len = max(
+                bsz_mult * (len(batch) // bsz_mult),
+                len(batch) % bsz_mult,
+            )
+            batches.append(batch[:mod_len])
+            batch = batch[mod_len:]
+            sample_lens = sample_lens[mod_len:]
+            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
+        batch.append(idx)
+    if len(batch) > 0:
+        batches.append(batch)
+    return batches
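
A `.pyx` module has to be compiled before `batch_by_size_fast` can be imported. Here is a minimal sketch of the build step, assuming a standard setuptools `setup.py` (the commit's actual build changes are in files not shown above; `token_block_utils_fast.pyx` is referenced by the import in the next file):

```python
# Sketch only: assumes Cython and numpy are installed at build time.
from setuptools import setup
from Cython.Build import cythonize
import numpy

setup(
    name="fairseq",
    ext_modules=cythonize([
        "fairseq/data/data_utils_fast.pyx",
        "fairseq/data/token_block_utils_fast.pyx",
    ]),
    include_dirs=[numpy.get_include()],  # required for `cimport numpy`
)
```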
134 changes: 18 additions & 116 deletions fairseq/data/token_block_dataset.py
@@ -3,11 +3,14 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-import math
-
 import numpy as np
 import torch
 
+from fairseq.data.token_block_utils_fast import (
+    _get_slice_indices_fast,
+    _get_block_to_dataset_index_fast,
+)
+
 from fairseq.data import FairseqDataset, plasma_utils


@@ -33,7 +36,6 @@ class TokenBlockDataset(FairseqDataset):
             'complete_doc' break mode). Typically 1 if the sentences have eos
             and 0 otherwise.
     """
-
     def __init__(
         self,
         dataset,
@@ -50,70 +52,22 @@ def __init__(
         self.pad = pad
         self.eos = eos
         self.include_targets = include_targets
-        slice_indices = []
 
         assert len(dataset) == len(sizes)
-        sizes = np.array(sizes, dtype=int)
+        assert len(dataset) > 0
 
-        if break_mode is None or break_mode == "none":
-            total_size = sum(sizes)
-            length = math.ceil(total_size / block_size)
+        if isinstance(sizes, list):
+            sizes = np.array(sizes, dtype=np.int64)
+        else:
+            sizes = sizes.astype(np.int64)
 
-            def block_at(i):
-                start = i * block_size
-                end = min(start + block_size, total_size)
-                return (start, end)
+        break_mode = break_mode if break_mode is not None else 'none'
 
-            slice_indices = [block_at(i) for i in range(length)]
-        elif break_mode == "complete":
-            tok_idx = 0
-            sz_idx = 0
-            curr_size = 0
-            while sz_idx < len(sizes):
-                if curr_size + sizes[sz_idx] <= block_size or curr_size == 0:
-                    curr_size += sizes[sz_idx]
-                    sz_idx += 1
-                else:
-                    slice_indices.append((tok_idx, tok_idx + curr_size))
-                    tok_idx += curr_size
-                    curr_size = 0
-            if curr_size > 0:
-                slice_indices.append((tok_idx, tok_idx + curr_size))
-        elif break_mode == "complete_doc":
-            tok_idx = 0
-            sz_idx = 0
-            curr_size = 0
-            while sz_idx < len(sizes):
-                if (
-                    (curr_size + sizes[sz_idx] <= block_size or curr_size == 0)
-                    # an empty sentence indicates end-of-document:
-                    and sizes[sz_idx] != document_sep_len
-                ):
-                    curr_size += sizes[sz_idx]
-                    sz_idx += 1
-                else:
-                    if curr_size > 1:
-                        slice_indices.append((tok_idx, tok_idx + curr_size))
-                    tok_idx += curr_size
-                    curr_size = 0
-                    if sizes[sz_idx] == document_sep_len:
-                        tok_idx += sizes[sz_idx]
-                        sz_idx += 1
-            if curr_size > 1:
-                slice_indices.append((tok_idx, tok_idx + curr_size))
-        elif break_mode == "eos":
-            slice_indices = np.empty((len(sizes), 2), dtype=int)
-            if not torch.is_tensor(sizes):
-                sizes = torch.tensor(sizes)
-            cumsum = torch.cumsum(sizes, dim=0)
-            slice_indices[0] = [0, sizes[0]]
-            if len(cumsum) > 1:
-                slice_indices[1:] = cumsum.unfold(0, 2, 1)
-        else:
-            raise ValueError("Invalid break_mode: " + break_mode)
+        # For the "eos" break mode, block_size is not a required parameter.
+        if break_mode == "eos" and block_size is None:
+            block_size = 0
 
-        slice_indices = np.array(slice_indices, dtype=int)
+        slice_indices = _get_slice_indices_fast(sizes, break_mode, block_size, document_sep_len)
         self._sizes = slice_indices[:, 1] - slice_indices[:, 0]
 
         # build index mapping block indices to the underlying dataset indices
@@ -130,23 +84,10 @@ def block_at(i):
                 1,
             )
         else:
-            ds = DatasetSearcher(sizes)
-            block_to_dataset_index = np.empty((len(slice_indices), 3), dtype=int)
-            for i, (s, e) in enumerate(slice_indices):
-                ds.seek(s)
-                start_ds_idx = ds.current_index
-                start_offset = ds.current_offset
-                if e <= s:
-                    end_ds_idx = start_ds_idx
-                else:
-                    ds.seek(e - 1)
-                    end_ds_idx = ds.current_index
-                block_to_dataset_index[i] = (
-                    start_ds_idx,  # starting index in dataset
-                    start_offset,  # starting offset within starting index
-                    end_ds_idx,  # ending index in dataset
-                )
-
+            block_to_dataset_index = _get_block_to_dataset_index_fast(
+                sizes,
+                slice_indices,
+            )
         self._slice_indices = plasma_utils.PlasmaArray(slice_indices)
         self._sizes = plasma_utils.PlasmaArray(self._sizes)
         self._block_to_dataset_index = plasma_utils.PlasmaArray(block_to_dataset_index)
@@ -215,42 +156,3 @@ def prefetch(self, indices):
                 for ds_idx in range(start_ds_idx, end_ds_idx + 1)
             }
         )
-
-
-class DatasetSearcher(object):
-    """Helper for mapping "flat" indices to indices and offsets in an
-    underlying dataset."""
-
-    def __init__(self, sizes):
-        self.sizes = sizes
-        self.reset()
-
-    def reset(self):
-        self.current_index = 0  # index in underlying dataset
-        self.current_offset = 0  # offset within current index in underlying dataset
-        self.current_i = 0  # "flat" index
-
-    def seek(self, i):
-        assert i >= 0
-
-        def step():
-            if i < self.current_i:
-                self.reset()
-            if i > self.current_i:
-                to_consume = i - self.current_i
-                remaining = self.sizes[self.current_index] - self.current_offset
-                if remaining > to_consume:
-                    self.current_offset += to_consume
-                    self.current_i += to_consume
-                else:
-                    assert remaining > 0
-                    self.current_i += remaining
-                    self.current_index += 1
-                    self.current_offset = 0
-                return True
-            return False
-
-        not_done = True
-        while not_done:
-            not_done = step()
-        assert self.current_i == i
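
For intuition, `_get_slice_indices_fast` keeps the same contract as the removed Python branches: it returns an `(N, 2)` array of `[start, end)` token offsets into the flattened corpus. Below is a pure-Python reference for the `'none'` break mode, mirroring the removed `block_at` logic (illustrative only; the real implementation lives in the Cython module):

```python
import math
import numpy as np

def slice_indices_none(sizes, block_size):
    # Chop the concatenated token stream into fixed-size blocks,
    # exactly as the removed block_at() helper did.
    total_size = int(np.sum(sizes))
    length = math.ceil(total_size / block_size)
    return np.array(
        [(i * block_size, min((i + 1) * block_size, total_size))
         for i in range(length)],
        dtype=np.int64,
    )

# Three sentences of 5, 3, and 6 tokens with block_size=4:
print(slice_indices_none(np.array([5, 3, 6]), 4))
# -> [[ 0  4]
#     [ 4  8]
#     [ 8 12]
#     [12 14]]
```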

2 comments on commit 4fc3953

@hudeven commented on 4fc3953 (Aug 23, 2019)


Hi Naman, the PyText repo (which depends on FairSeq master) is broken today due to `from Cython.Build import cythonize`. I think it's caused by this commit, which introduces Cython without adding it as a dependency. Could you take a look?

Reference: PyText CircleCI build failure log: https://circleci.com/gh/facebookresearch/pytext/12865?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link

@myleott (Contributor) commented

@hudeven, this should be fixed now, please reach out if it's still an issue
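
For anyone hitting the same error: a common way to avoid a hard build-time Cython dependency is to ship the generated `.c` sources and fall back to them when Cython is absent. A sketch of that pattern follows, as an assumption about the kind of fix, not necessarily what was committed:

```python
# Hypothetical setup.py pattern: build from .pyx when Cython exists,
# otherwise compile checked-in, pre-generated .c sources.
from setuptools import setup, Extension
import numpy

try:
    from Cython.Build import cythonize
    HAS_CYTHON = True
except ImportError:
    HAS_CYTHON = False

ext = ".pyx" if HAS_CYTHON else ".c"
extensions = [
    Extension(
        "fairseq.data.data_utils_fast",
        ["fairseq/data/data_utils_fast" + ext],
        include_dirs=[numpy.get_include()],  # for `cimport numpy`
    ),
]
if HAS_CYTHON:
    extensions = cythonize(extensions)

setup(name="fairseq", ext_modules=extensions)
```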
