Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Auto Parallel] Improve the fine-grained APIs #46552

Merged
merged 18 commits into from
Oct 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/paddle/distributed/auto_parallel/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,10 @@ def set_field_default_config(category, field, default_value):
set_field_default_config(TUNING, "profile_end_step", 1)
set_field_default_config(TUNING, "run_after_tuning", True)
set_field_default_config(TUNING, "verbose", True)

#########################################
# dataset configuration
#########################################
DATASET = "dataset"
set_field_default_config(DATASET, "enable", False)
set_field_default_config(DATASET, "num_shards", 1)
183 changes: 133 additions & 50 deletions python/paddle/distributed/auto_parallel/dist_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,11 @@

import paddle
from paddle.io import BatchSampler, IterableDataset
from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler
from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler, DistributedBatchSampler
from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn


class DistributedDataLoader(metaclass=abc.ABCMeta):

def __init__(self, dataset, batch_size=1, epochs=1, drop_last=False):
if isinstance(dataset, IterableDataset):
self.dataset_kind = _DatasetKind.ITER
else:
self.dataset_kind = _DatasetKind.MAP

self.dataset = dataset
self.epochs = epochs
self.drop_last = drop_last

if batch_size is None:
self.batch_size = None
self.batch_sampler = None
else:
self.batch_size = batch_size
if isinstance(dataset, IterableDataset):
self.batch_sampler = _InfiniteIterableSampler(
dataset, batch_size)
else:
self.batch_sampler = BatchSampler(dataset,
batch_size=batch_size,
shuffle=False,
drop_last=drop_last)

self.auto_collate_batch = self.batch_sampler is not None
self.sampler_iter = iter(self.index_sampler)
class DistributedDataLoaderBase(metaclass=abc.ABCMeta):

@abc.abstractmethod
def __iter__(self):
Expand All @@ -58,48 +31,70 @@ def __iter__(self):
def __next__(self):
raise NotImplementedError

@property
def index_sampler(self):
if self.auto_collate_batch:
return self.batch_sampler
else:
if self.dataset_kind == _DatasetKind.MAP:
return list(range(len(self.dataset)))
else:
return _InfiniteIterableSampler(self.dataset, 1)


class NonIterableGeneratorLoader(DistributedDataLoader):
class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):

def __init__(self,
dataset,
feed_list,
places,
feed_list=None,
capacity=None,
use_double_buffer=True,
iterable=True,
return_list=False,
use_multiprocess=False,
drop_last=True,
places=None,
batch_size=1,
epochs=1,
steps_per_epoch=None,
collate_fn=None,
split_data=True,
data_parallel_world_size=[],
data_parallel_rank=[],
drop_last=False,
split_data=True):
data_parallel_rank=[]):
self.dataset = dataset
self.feed_list = feed_list
self.capacity = capacity
self.use_double_buffer = use_double_buffer
self.iterable = iterable
self.return_list = return_list
self.use_multiprocess = use_multiprocess
self.drop_last = drop_last
self.places = places
self.batch_size = batch_size
self.epochs = epochs
self.steps_per_epoch = steps_per_epoch

self.collate_fn = collate_fn
self.split_data = split_data
assert len(data_parallel_world_size) == len(feed_list)
assert len(data_parallel_rank) == len(feed_list)
self.dp_world_sizes = data_parallel_world_size
self.dp_ranks = data_parallel_rank
self.split_data = split_data

super(NonIterableGeneratorLoader,
self).__init__(dataset, batch_size, epochs, drop_last)
if isinstance(dataset, IterableDataset):
self.dataset_kind = _DatasetKind.ITER
else:
self.dataset_kind = _DatasetKind.MAP

if self.batch_size is None:
self.batch_sampler = None
else:
if isinstance(dataset, IterableDataset):
self.batch_sampler = _InfiniteIterableSampler(
dataset, batch_size)
else:
self.batch_sampler = BatchSampler(dataset,
batch_size=batch_size,
shuffle=False,
drop_last=drop_last)

self.auto_collate_batch = self.batch_sampler is not None
self.sampler_iter = iter(self.index_sampler)

if self.auto_collate_batch:
self.collate_fn = collate_fn or default_collate_fn
else:
self.collate_fn = collate_fn or default_convert_fn

self.dataset_fetcher = _DatasetKind.create_fetcher(
self.dataset_kind, self.dataset, self.auto_collate_batch,
self.collate_fn, self.drop_last)
Expand All @@ -115,8 +110,10 @@ def __iter__(self):
def __next__(self):
if not self._steps:
self._cur_step += 1
return None
elif self._cur_step < self._steps:
self._cur_step += 1
return None
else:
self._inner_dataloader.reset()
self.sampler_iter = iter(self.index_sampler)
Expand All @@ -138,6 +135,16 @@ def _infer_steps(self):
)
return steps_per_epoch

@property
def index_sampler(self):
if self.auto_collate_batch:
return self.batch_sampler
else:
if self.dataset_kind == _DatasetKind.MAP:
return list(range(len(self.dataset)))
else:
return _InfiniteIterableSampler(self.dataset, 1)

def _create_inner_dataloader(self):

def data_generator():
Expand Down Expand Up @@ -170,7 +177,83 @@ def data_generator():
yield partial_data

dataloader = paddle.fluid.io.DataLoader.from_generator(
feed_list=self.feed_list, capacity=70, iterable=False)
feed_list=self.feed_list,
capacity=self.capacity,
use_double_buffer=self.use_double_buffer,
# iterable=self.iterable,
iterable=False,
return_list=self.return_list,
use_multiprocess=self.use_multiprocess,
drop_last=self.drop_last)
dataloader.set_batch_generator(data_generator, self.places)

return dataloader


class DistributedDataLoader(DistributedDataLoaderBase):

def __init__(self,
dataset,
feed_list=None,
places=None,
return_list=True,
batch_size=1,
shuffle=False,
drop_last=False,
collate_fn=None,
num_workers=0,
use_buffer_reader=True,
use_shared_memory=True,
timeout=0,
worker_init_fn=None,
epochs=1,
steps_per_epoch=None,
split_data=True,
data_parallel_world_size=[],
data_parallel_rank=[]):
self.dataset = dataset
self.feed_list = feed_list
self.return_list = return_list
self.places = places
self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last
self.collate_fn = collate_fn
self.num_workers = num_workers
self.use_buffer_reader = use_buffer_reader
self.use_shared_memory = use_shared_memory
self.timeout = timeout
self.worker_init_fn = worker_init_fn
self.epochs = epochs
self.steps_per_epoch = steps_per_epoch
self.dp_world_sizes = data_parallel_world_size
self.dp_ranks = data_parallel_rank
self.split_data = split_data
# TODO: rank info
self.batch_sampler = DistributedBatchSampler(
self.dataset, self.batch_size, self.dp_world_sizes[0],
self.dp_ranks[0], self.shuffle, self.drop_last)
self._inner_dataloader = self._create_inner_dataloader()

def __iter__(self):
return self

def __next__(self):
return next(self.data)

def _create_inner_dataloader(self):
dataloader = paddle.fluid.io.DataLoader(
self.dataset,
feed_list=self.feed_list,
places=self.places,
return_list=self.return_list,
batch_sampler=self.batch_sampler,
collate_fn=self.collate_fn,
num_workers=self.num_workers,
use_buffer_reader=self.use_buffer_reader,
use_shared_memory=self.use_shared_memory,
timeout=self.timeout,
worker_init_fn=self.worker_init_fn)
self.data = (x for x in dataloader)

return dataloader
Loading