Skip to content

Commit

Permalink
[Feature] Support persistent_workers in DataLoader (PyTorch>=1.7.0) (o…
Browse files Browse the repository at this point in the history
  • Loading branch information
xvjiarui authored Jun 28, 2021
1 parent 60baa4e commit bf746bf
Showing 1 changed file with 33 additions and 24 deletions.
57 changes: 33 additions & 24 deletions mmseg/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from functools import partial

import numpy as np
import torch
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import Registry, build_from_cfg
from mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader
from torch.utils.data import DistributedSampler
from torch.utils.data import DataLoader, DistributedSampler

if platform.system() != 'Windows':
# https://github.com/pytorch/pytorch/issues/973
Expand Down Expand Up @@ -84,7 +84,7 @@ def build_dataloader(dataset,
seed=None,
drop_last=False,
pin_memory=True,
dataloader_type='PoolDataLoader',
persistent_workers=True,
**kwargs):
"""Build PyTorch DataLoader.
Expand All @@ -106,7 +106,11 @@ def build_dataloader(dataset,
Default: False
pin_memory (bool): Whether to use pin_memory in DataLoader.
Default: True
dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader'
persistent_workers (bool): If True, the data loader will not shut down
the worker processes after a dataset has been consumed once.
This keeps the workers' Dataset instances alive.
The argument only takes effect in PyTorch>=1.7.0.
Default: True
kwargs: any keyword argument to be used to initialize DataLoader
Returns:
Expand All @@ -128,26 +132,31 @@ def build_dataloader(dataset,
worker_init_fn, num_workers=num_workers, rank=rank,
seed=seed) if seed is not None else None

assert dataloader_type in (
'DataLoader',
'PoolDataLoader'), f'unsupported dataloader {dataloader_type}'

if dataloader_type == 'PoolDataLoader':
dataloader = PoolDataLoader
elif dataloader_type == 'DataLoader':
dataloader = DataLoader

data_loader = dataloader(
dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers,
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
pin_memory=pin_memory,
shuffle=shuffle,
worker_init_fn=init_fn,
drop_last=drop_last,
**kwargs)
if torch.__version__ >= '1.7.0':
data_loader = DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers,
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
pin_memory=pin_memory,
shuffle=shuffle,
worker_init_fn=init_fn,
drop_last=drop_last,
persistent_workers=persistent_workers,
**kwargs)
else:
data_loader = DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers,
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
pin_memory=pin_memory,
shuffle=shuffle,
worker_init_fn=init_fn,
drop_last=drop_last,
**kwargs)

return data_loader

Expand Down

0 comments on commit bf746bf

Please sign in to comment.