DataLoader: support disabling automatic batch collation (#28425)
* DataLoader: support disabling automatic batch collation. test=develop
heavengate authored Nov 16, 2020
1 parent c5c273c commit 89d27de
Showing 8 changed files with 327 additions and 36 deletions.
34 changes: 24 additions & 10 deletions python/paddle/fluid/dataloader/dataloader_iter.py
@@ -36,6 +36,7 @@
 from ..framework import in_dygraph_mode
 from ..multiprocess_utils import CleanupFuncRegistrar, _cleanup_mmap, _set_SIGCHLD_handler
 from .fetcher import _IterableDatasetFetcher, _MapDatasetFetcher
+from .batch_sampler import _InfiniteIterableSampler
 
 __all__ = ['get_worker_info']

@@ -100,11 +101,13 @@ class _DatasetKind(object):
     ITER = 1
 
     @staticmethod
-    def create_fetcher(kind, dataset, collate_fn, drop_last):
+    def create_fetcher(kind, dataset, auto_collate_batch, collate_fn, drop_last):
         if kind == _DatasetKind.MAP:
-            return _MapDatasetFetcher(dataset, collate_fn, drop_last)
+            return _MapDatasetFetcher(dataset, auto_collate_batch,
+                                      collate_fn, drop_last)
         elif kind == _DatasetKind.ITER:
-            return _IterableDatasetFetcher(dataset, collate_fn, drop_last)
+            return _IterableDatasetFetcher(dataset, auto_collate_batch,
+                                           collate_fn, drop_last)
         else:
             raise NotImplementedError("unknown Dataset kind {}".format(kind))
@@ -221,8 +224,7 @@ def __init__(self, loader):
         self._places = loader.places
         self._return_list = loader.return_list
         self._batch_sampler = loader.batch_sampler
-        self._sampler_iter = iter(loader.batch_sampler)
-        self._collate_fn = loader.collate_fn or default_collate_fn
+        self._auto_collate_batch = loader.auto_collate_batch
         self._num_workers = loader.num_workers
         self._use_buffer_reader = loader.use_buffer_reader
         self._use_shared_memory = loader.use_shared_memory
@@ -231,6 +233,16 @@ def __init__(self, loader):
         self._dataset_kind = loader.dataset_kind
         self._pin_memory = loader.pin_memory
 
+        if self._auto_collate_batch:
+            self._sampler_iter = iter(loader.batch_sampler)
+            self._collate_fn = loader.collate_fn or default_collate_fn
+        else:
+            if self._dataset_kind == _DatasetKind.MAP:
+                self._sampler_iter = iter(list(range(len(self._dataset))))
+            else:
+                self._sampler_iter = iter(_InfiniteIterableSampler(self._dataset, 1))
+            self._collate_fn = loader.collate_fn
+
         # LoDTensorBlockingQueue instance for create_py_reader and a thread
         # to put mini-batch data to self._blocking_queue, mini-batch data
         # will be get from:
@@ -257,7 +269,8 @@ def __init__(self, loader):
         super(_DataLoaderIterSingleProcess, self).__init__(loader)
 
         self._dataset_fetcher = _DatasetKind.create_fetcher(
-            self._dataset_kind, self._dataset, self._collate_fn, True)
+            self._dataset_kind, self._dataset, self._auto_collate_batch,
+            self._collate_fn, True)
 
         # NOTE: len(self._places) batch data compose as an output
         # iteration, set blocking_queue can cache 2 iteration datas
@@ -367,7 +380,7 @@ def __del__(self):
 
 # NOTE(chenweihang): _worker_loop must be top level method to be pickled
 def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
-                 collate_fn, init_fn, worker_id, num_workers,
+                 auto_collate_batch, collate_fn, init_fn, worker_id, num_workers,
                  use_shared_memory):
     try:
         # NOTE: [ mmap files clear ] When the child process exits unexpectedly,
@@ -388,7 +401,7 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
         if init_fn is not None:
             init_fn(worker_id)
         fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset,
-                                              collate_fn, True)
+                                              auto_collate_batch, collate_fn, True)
     except:
         init_exception = Exception("init_fn failed in worker {}: " \
             "{}".format(worker_id, sys.exc_info()))
@@ -511,8 +524,9 @@ def _init_workers(self):
                 target=_worker_loop,
                 args=(self._dataset, self._dataset_kind, indices_queue,
                       self._data_queue, self._workers_done_event,
-                      self._collate_fn, self._worker_init_fn, i,
-                      self._num_workers, self._use_shared_memory))
+                      self._auto_collate_batch, self._collate_fn,
+                      self._worker_init_fn, i, self._num_workers,
+                      self._use_shared_memory))
             worker.daemon = True
             worker.start()
             self._workers.append(worker)
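Taken together, the iterator changes above boil down to one decision: where the index stream comes from and which collate function applies. A minimal standalone sketch of that selection logic, with simplified names; `itertools.repeat` stands in for Paddle's internal `_InfiniteIterableSampler`, whose implementation is not shown in this diff:

```python
import itertools

def select_sampler_and_collate(loader, default_collate_fn, is_map_dataset):
    # Automatic batching: the batch sampler yields lists of indices, and the
    # default collate function is used when none is supplied.
    if loader.auto_collate_batch:
        return iter(loader.batch_sampler), loader.collate_fn or default_collate_fn
    # Batching disabled, map-style dataset: one index per already-batched sample.
    if is_map_dataset:
        return iter(list(range(len(loader.dataset)))), loader.collate_fn
    # Batching disabled, iterable dataset: an endless stream of dummy indices;
    # each fetch simply calls next() on the dataset iterator.
    return itertools.repeat([0]), loader.collate_fn
```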
49 changes: 31 additions & 18 deletions python/paddle/fluid/dataloader/fetcher.py
@@ -14,8 +14,9 @@
 
 
 class _DatasetFetcher(object):
-    def __init__(self, dataset, collate_fn, drop_last):
+    def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
         self.dataset = dataset
+        self.auto_collate_batch = auto_collate_batch
         self.collate_fn = collate_fn
         self.drop_last = drop_last

@@ -25,29 +26,41 @@ def fetch(self, batch_indices):
 
 
 class _IterableDatasetFetcher(_DatasetFetcher):
-    def __init__(self, dataset, collate_fn, drop_last):
-        super(_IterableDatasetFetcher, self).__init__(dataset, collate_fn,
-                                                      drop_last)
+    def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
+        super(_IterableDatasetFetcher, self).__init__(dataset, auto_collate_batch,
+                                                      collate_fn, drop_last)
         self.dataset_iter = iter(dataset)
 
     def fetch(self, batch_indices):
-        data = []
-        for _ in batch_indices:
-            try:
-                data.append(next(self.dataset_iter))
-            except StopIteration:
-                break
-        if len(data) == 0 or (self.drop_last and
-                              len(data) < len(batch_indices)):
-            raise StopIteration
-
-        return self.collate_fn(data)
+        if self.auto_collate_batch:
+            data = []
+            for _ in batch_indices:
+                try:
+                    data.append(next(self.dataset_iter))
+                except StopIteration:
+                    break
+            if len(data) == 0 or (self.drop_last and
+                                  len(data) < len(batch_indices)):
+                raise StopIteration
+        else:
+            data = next(self.dataset_iter)
 
+        if self.collate_fn:
+            data = self.collate_fn(data)
+        return data


 class _MapDatasetFetcher(_DatasetFetcher):
-    def __init__(self, dataset, collate_fn, drop_last):
-        super(_MapDatasetFetcher, self).__init__(dataset, collate_fn, drop_last)
+    def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
+        super(_MapDatasetFetcher, self).__init__(dataset, auto_collate_batch, collate_fn, drop_last)
 
     def fetch(self, batch_indices):
-        data = [self.dataset[idx] for idx in batch_indices]
-        return self.collate_fn(data)
+        if self.auto_collate_batch:
+            data = [self.dataset[idx] for idx in batch_indices]
+        else:
+            data = self.dataset[batch_indices]
 
+        if self.collate_fn:
+            data = self.collate_fn(data)
+        return data
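The practical difference between the two fetch paths is easiest to see with a toy map-style dataset. A hedged sketch in plain Python (not Paddle API), mirroring the logic above:

```python
class ToyBatchedDataset:
    """Toy map-style dataset where each item is already a full mini-batch."""

    def __init__(self, batches):
        self.batches = batches

    def __getitem__(self, idx):
        return self.batches[idx]

    def __len__(self):
        return len(self.batches)


dataset = ToyBatchedDataset([[1, 2, 3], [4, 5, 6]])

# auto_collate_batch=True: a list of indices arrives, individual samples are
# gathered and then handed to the collate function to form a batch.
gathered = [dataset[idx] for idx in [0, 1]]  # [[1, 2, 3], [4, 5, 6]]

# auto_collate_batch=False: a single index arrives and the dataset itself is
# expected to return an already-batched sample, passed through unchanged
# (collate_fn still runs on it if one was supplied).
prebatched = dataset[0]  # [1, 2, 3]
```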
36 changes: 32 additions & 4 deletions python/paddle/fluid/reader.py
@@ -163,6 +163,21 @@ class DataLoader(object):
     For :code:`batch_sampler` please see :code:`paddle.io.BatchSampler`
 
+    **Disable automatic batching**
+
+    In certain cases, such as some NLP tasks, users need to handle
+    batching manually in the dataset rather than rely on automatic
+    batching. For these cases, automatic batching is disabled when both
+    :attr:`batch_size` and :attr:`batch_sampler` are set to None; each
+    sample fetched from :attr:`dataset` should then already be a batch,
+    and it will be processed by the function given in :attr:`collate_fn`
+    or by :attr:`default_collate_fn`.
+
+    .. note::
+        When automatic batching is disabled, :attr:`default_collate_fn` will
+        do nothing to data from the dataset.
+
     Args:
         dataset(Dataset): the dataset to load data from, should be an
             instance of subclass of :code:`paddle.io.Dataset` or
@@ -185,7 +200,7 @@ class DataLoader(object):
         batch_sampler(BatchSampler): an instance of `paddle.io.BatchSampler`
             to generate batch indices to draw samples from :attr:`dataset`
             and combine a batch. Default None.
-        batch_size(int): sample number in a mini-batch, a substitution
+        batch_size(int|None): sample number in a mini-batch, a substitution
             parameter for :attr:`batch_sampler`, if :attr:`batch_sampler`
             is not set, a default `paddle.io.BatchSampler` will be used
             and initialize by :attr:`batch_size`, :attr:`shuffle` and
@@ -358,10 +373,15 @@ def __init__(self,
                 "batch_size/shuffle/drop_last should not be set when " \
                 "batch_sampler is given"
             self.batch_sampler = batch_sampler
+            self.batch_size = None
+        elif batch_size is None:
+            self.batch_sampler = None
+            self.batch_size = None
         else:
-            assert batch_size is not None and batch_size > 0, \
-                "batch_size should be a positive value when " \
+            assert batch_size > 0, \
+                "batch_size should be None or a positive value when " \
                 "batch_sampler is not given"
+            self.batch_size = batch_size
             if isinstance(dataset, IterableDataset):
                 self.batch_sampler = _InfiniteIterableSampler(dataset,
                                                               batch_size)
@@ -372,13 +392,21 @@ def __init__(self,
                 shuffle=shuffle,
                 drop_last=drop_last)
 
+        self.auto_collate_batch = self.batch_sampler is not None
+
         self.pin_memory = False
         if in_dygraph_mode():
             self.pin_memory = True if use_pinned_memory(
             ) is None else use_pinned_memory()
 
     def __len__(self):
-        return len(self.batch_sampler)
+        if self.dataset_kind == _DatasetKind.ITER:
+            raise ValueError("length of IterableDataset not supported")
+        else:
+            if self.batch_size is None:
+                return len(self.dataset)
+            else:
+                return len(self.batch_sampler)
 
     def __iter__(self):
         if self.num_workers == 0:
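Putting the reader.py changes together: disabling automatic batching amounts to passing batch_size=None with no batch_sampler, for a dataset whose samples are already whole batches. A hedged usage sketch against the paddle.io API of this release; `MyBatchedDataset` and all of its parameters are hypothetical:

```python
import numpy as np
from paddle.io import Dataset, DataLoader


class MyBatchedDataset(Dataset):
    """Hypothetical dataset: each sample is already a full mini-batch."""

    def __init__(self, num_batches, batch_size, feature_dim):
        self.num_batches = num_batches
        self.batch_size = batch_size
        self.feature_dim = feature_dim

    def __getitem__(self, idx):
        image = np.random.random(
            [self.batch_size, self.feature_dim]).astype('float32')
        label = np.random.randint(
            0, 9, [self.batch_size, 1]).astype('int64')
        return image, label

    def __len__(self):
        return self.num_batches


dataset = MyBatchedDataset(num_batches=10, batch_size=32, feature_dim=784)

# batch_size=None (and no batch_sampler) disables automatic batching: each
# __getitem__ result is forwarded as one iteration's data, and collate_fn,
# if given, is applied to the pre-batched sample as-is.
loader = DataLoader(dataset, batch_size=None, num_workers=0)

# Per the new __len__, length falls back to len(dataset) in this mode.
assert len(loader) == len(dataset)
```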
@@ -27,7 +27,7 @@
 from paddle.fluid.dygraph.nn import Linear
 from paddle.fluid.dygraph.base import to_variable
 
-from test_multiprocess_dataloader_static import RandomDataset, prepare_places
+from test_multiprocess_dataloader_static import RandomDataset, RandomBatchedDataset, prepare_places
 from test_multiprocess_dataloader_static import EPOCH_NUM, BATCH_SIZE, IMAGE_SIZE, SAMPLE_NUM, CLASS_NUM


@@ -122,5 +122,48 @@ def test_main(self):
             self.assertLess(diff, 1e-2)
 
 
+class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader):
+    def run_main(self, num_workers, places):
+        fluid.default_startup_program().random_seed = 1
+        fluid.default_main_program().random_seed = 1
+        with fluid.dygraph.guard(places[0]):
+            fc_net = SimpleFCNet()
+            optimizer = fluid.optimizer.Adam(parameter_list=fc_net.parameters())
+
+            dataset = RandomBatchedDataset(SAMPLE_NUM, CLASS_NUM)
+            dataloader = DataLoader(
+                dataset,
+                num_workers=num_workers,
+                batch_size=None,
+                drop_last=True)
+            assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
+
+            step_list = []
+            loss_list = []
+            start_t = time.time()
+            for _ in six.moves.range(EPOCH_NUM):
+                step = 0
+                for image, label in dataloader():
+                    out = fc_net(image)
+                    loss = fluid.layers.cross_entropy(out, label)
+                    avg_loss = fluid.layers.reduce_mean(loss)
+                    avg_loss.backward()
+                    optimizer.minimize(avg_loss)
+                    fc_net.clear_gradients()
+
+                    loss_list.append(np.mean(avg_loss.numpy()))
+                    step += 1
+                step_list.append(step)
+
+            end_t = time.time()
+            ret = {
+                "time": end_t - start_t,
+                "step": step_list,
+                "loss": np.array(loss_list)
+            }
+            print("time cost", ret['time'], 'step_list', ret['step'])
+            return ret
+
+
 if __name__ == '__main__':
     unittest.main()
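Both new test classes rely on RandomBatchedDataset, which is defined in the static test files whose diffs did not load on this page. A plausible reconstruction based only on how the tests use it; the constant values and the exact random recipe are assumptions:

```python
import numpy as np
from paddle.io import Dataset

# Module-level constants in the test files; values assumed for illustration.
BATCH_SIZE, IMAGE_SIZE = 32, 784


class RandomBatchedDataset(Dataset):
    """Assumed shape of the helper: each item is a random BATCH_SIZE batch."""

    def __init__(self, sample_num, class_num):
        # One item per batch, so len(dataset) == SAMPLE_NUM / BATCH_SIZE,
        # matching `assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)`.
        self.num_batches = int(sample_num / BATCH_SIZE)
        self.class_num = class_num

    def __getitem__(self, idx):
        images = np.random.random(
            [BATCH_SIZE, IMAGE_SIZE]).astype('float32')
        labels = np.random.randint(
            0, self.class_num - 1, [BATCH_SIZE, 1]).astype('int64')
        return images, labels

    def __len__(self):
        return self.num_batches
```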
@@ -188,7 +188,7 @@ def _collate_fn(sample_list):
             indices_queue.put(None)
             _worker_loop(loader._dataset, 0, indices_queue,
                          loader._data_queue, loader._workers_done_event,
-                         _collate_fn, _init_fn, 0, 1,
+                         True, _collate_fn, _init_fn, 0, 1,
                          loader._use_shared_memory)
             self.assertTrue(False)
         except AssertionError:
@@ -232,7 +232,7 @@ def _collate_fn(sample_list):
             loader._workers_done_event.set()
             _worker_loop(loader._dataset, 0, indices_queue,
                          loader._data_queue, loader._workers_done_event,
-                         _collate_fn, _init_fn, 0, 1,
+                         True, _collate_fn, _init_fn, 0, 1,
                          loader._use_shared_memory)
             self.assertTrue(True)
         except AssertionError:
@@ -27,7 +27,7 @@
 from paddle.fluid.dygraph.nn import Linear
 from paddle.fluid.dygraph.base import to_variable
 
-from test_multiprocess_dataloader_iterable_dataset_static import RandomDataset, prepare_places
+from test_multiprocess_dataloader_iterable_dataset_static import RandomDataset, RandomBatchedDataset, prepare_places
 from test_multiprocess_dataloader_iterable_dataset_static import EPOCH_NUM, BATCH_SIZE, IMAGE_SIZE, SAMPLE_NUM, CLASS_NUM


@@ -119,5 +119,46 @@ def test_main(self):
                 0]
 
 
+class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader):
+    def run_main(self, num_workers, places):
+        fluid.default_startup_program().random_seed = 1
+        fluid.default_main_program().random_seed = 1
+        with fluid.dygraph.guard(places[0]):
+            fc_net = SimpleFCNet()
+            optimizer = fluid.optimizer.Adam(parameter_list=fc_net.parameters())
+
+            dataset = RandomBatchedDataset(SAMPLE_NUM, CLASS_NUM)
+            dataloader = DataLoader(
+                dataset,
+                num_workers=num_workers,
+                batch_size=None,
+                drop_last=True)
+
+            step_list = []
+            loss_list = []
+            start_t = time.time()
+            for _ in six.moves.range(EPOCH_NUM):
+                step = 0
+                for image, label in dataloader():
+                    out = fc_net(image)
+                    loss = fluid.layers.cross_entropy(out, label)
+                    avg_loss = fluid.layers.reduce_mean(loss)
+                    avg_loss.backward()
+                    optimizer.minimize(avg_loss)
+                    fc_net.clear_gradients()
+
+                    loss_list.append(np.mean(avg_loss.numpy()))
+                    step += 1
+                step_list.append(step)
+
+            end_t = time.time()
+            ret = {
+                "time": end_t - start_t,
+                "step": step_list,
+                "loss": np.array(loss_list)
+            }
+            print("time cost", ret['time'], 'step_list', ret['step'])
+            return ret
+
+
 if __name__ == '__main__':
     unittest.main()
(Diffs for the remaining changed files did not load on this page.)