diff --git a/CHANGELOG.md b/CHANGELOG.md index ff333c0fcabb8..dc5844f2b8191 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added flag `replace_sampler_ddp` to manually disaple sampler replacement in ddp ([#1513](https://github.com/PyTorchLightning/pytorch-lightning/pull/1513)) - Added `auto_select_gpus` flag to trainer that enables automatic selection of available GPUs on exclusive mode systems. - Added learning rate finder ([#1347](https://github.com/PyTorchLightning/pytorch-lightning/pull/1347)) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index ca9e09939d8e4..bfbf058407947 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -61,6 +61,7 @@ class TrainerDataLoadingMixin(ABC): train_percent_check: float val_percent_check: float test_percent_check: float + replace_sampler_ddp: bool @abstractmethod def is_overriden(self, *args): @@ -88,10 +89,8 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader: # don't do anything if it's not a dataloader if not isinstance(dataloader, DataLoader): return dataloader - - need_dist_sampler = self.use_ddp or self.use_ddp2 or self.use_tpu - - if need_dist_sampler: + need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_tpu) + if self.replace_sampler_ddp and need_dist_sampler: skip_keys = ['sampler', 'batch_sampler', 'dataset_kind'] diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0e9bb209c7879..4324fca09f55a 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -127,6 +127,7 @@ def __init__( benchmark: bool = False, reload_dataloaders_every_epoch: bool = False, auto_lr_find: Union[bool, str] = False, + replace_sampler_ddp: bool = True, default_save_path=None, # backward compatible, todo: remove in v0.8.0 gradient_clip=None, # backward compatible, todo: remove in v0.8.0 nb_gpu_nodes=None, # backward compatible, todo: remove in v0.8.0 @@ -282,6 +283,9 @@ def __init__( rate in self.hparams.lr | self.hparams.learning_rate in the lightning module. To use a different key, set a string instead of True with the key name. + replace_sampler_ddp: Explicitly enables or disables sampler replacement. + If not specified this will toggled automatically ddp is used + benchmark: If true enables cudnn.benchmark. terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the @@ -362,6 +366,7 @@ def __init__( self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch self.auto_lr_find = auto_lr_find + self.replace_sampler_ddp = replace_sampler_ddp self.truncated_bptt_steps = truncated_bptt_steps self.resume_from_checkpoint = resume_from_checkpoint