From 465eecc0e4b173d8a7968dbe2f9351255f7e2ec2 Mon Sep 17 00:00:00 2001
From: Justus Schock <12886177+justusschock@users.noreply.github.com>
Date: Fri, 17 Apr 2020 09:51:26 +0200
Subject: [PATCH 1/6] Add explicit flag for ddp sampler replacement

---
 pytorch_lightning/trainer/data_loading.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index ca9e09939d8e4..98f41712e39dc 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -61,6 +61,7 @@ class TrainerDataLoadingMixin(ABC):
     train_percent_check: float
     val_percent_check: float
     test_percent_check: float
+    replace_sampler_ddp: bool

     @abstractmethod
     def is_overriden(self, *args):
@@ -89,9 +90,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
         if not isinstance(dataloader, DataLoader):
             return dataloader

-        need_dist_sampler = self.use_ddp or self.use_ddp2 or self.use_tpu
-
-        if need_dist_sampler:
+        if self.replace_sampler_ddp:

             skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']

From b4a51bb3677444e8d6f04530ab77c615717f9b9b Mon Sep 17 00:00:00 2001
From: Justus Schock <12886177+justusschock@users.noreply.github.com>
Date: Fri, 17 Apr 2020 09:56:21 +0200
Subject: [PATCH 2/6] Add flag for sampler replacement in ddp

---
 pytorch_lightning/trainer/trainer.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 0e9bb209c7879..cc91494864a4c 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -127,6 +127,7 @@ def __init__(
         benchmark: bool = False,
         reload_dataloaders_every_epoch: bool = False,
         auto_lr_find: Union[bool, str] = False,
+        replace_sampler_ddp: bool = True,
         default_save_path=None,  # backward compatible, todo: remove in v0.8.0
         gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
         nb_gpu_nodes=None,  # backward compatible, todo: remove in v0.8.0
@@ -281,6 +282,9 @@ def __init__(
             trying to optimize initial learning for faster convergence. Sets learning
             rate in self.hparams.lr | self.hparams.learning_rate in the lightning module.
             To use a different key, set a string instead of True with the key name.
+
+        replace_sampler_ddp: Explicitly enables or disables sampler replacement.
+            If not specified, this will be toggled automatically when DDP is used.

         benchmark: If true enables cudnn.benchmark.
@@ -362,6 +366,7 @@ def __init__(
         self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

         self.auto_lr_find = auto_lr_find
+        self.replace_sampler_ddp = replace_sampler_ddp

         self.truncated_bptt_steps = truncated_bptt_steps
         self.resume_from_checkpoint = resume_from_checkpoint

From 0bb65192b37c2c069401dde20bc9475ba7091b53 Mon Sep 17 00:00:00 2001
From: Justus Schock <12886177+justusschock@users.noreply.github.com>
Date: Fri, 17 Apr 2020 09:57:19 +0200
Subject: [PATCH 3/6] Update data_loading.py

---
 pytorch_lightning/trainer/data_loading.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 98f41712e39dc..bfbf058407947 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -89,8 +89,8 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
         # don't do anything if it's not a dataloader
         if not isinstance(dataloader, DataLoader):
             return dataloader
-
-        if self.replace_sampler_ddp:
+        need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_tpu)
+        if self.replace_sampler_ddp and need_dist_sampler:

             skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']

From 03bcca1dc963574b45fa58663c5ec94d41b1baf4 Mon Sep 17 00:00:00 2001
From: Justus Schock <12886177+justusschock@users.noreply.github.com>
Date: Fri, 17 Apr 2020 10:07:25 +0200
Subject: [PATCH 4/6] Update CHANGELOG.md

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ff333c0fcabb8..dc5844f2b8191 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Added flag `replace_sampler_ddp` to manually disable sampler replacement in DDP ([#1513](https://github.com/PyTorchLightning/pytorch-lightning/pull/1513))
 - Added `auto_select_gpus` flag to trainer that enables automatic selection of available GPUs on exclusive mode systems.
 - Added learning rate finder ([#1347](https://github.com/PyTorchLightning/pytorch-lightning/pull/1347))

From d0c0380de87739dd9ed8236a4ab2ff05a59ef332 Mon Sep 17 00:00:00 2001
From: Justus Schock <12886177+justusschock@users.noreply.github.com>
Date: Fri, 17 Apr 2020 10:08:54 +0200
Subject: [PATCH 5/6] pep8 fixes

---
 pytorch_lightning/trainer/trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index cc91494864a4c..1177c0f9fcb96 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -282,8 +282,8 @@ def __init__(
             trying to optimize initial learning for faster convergence. Sets learning
             rate in self.hparams.lr | self.hparams.learning_rate in the lightning module.
             To use a different key, set a string instead of True with the key name.
-
-        replace_sampler_ddp: Explicitly enables or disables sampler replacement.
+
+        replace_sampler_ddp: Explicitly enables or disables sampler replacement.
             If not specified, this will be toggled automatically when DDP is used.

         benchmark: If true enables cudnn.benchmark.
From 8ef962f716aca400b6496deb3c8eeec5990974dd Mon Sep 17 00:00:00 2001
From: Justus Schock <12886177+justusschock@users.noreply.github.com>
Date: Fri, 17 Apr 2020 10:15:26 +0200
Subject: [PATCH 6/6] pep8

---
 pytorch_lightning/trainer/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 1177c0f9fcb96..4324fca09f55a 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -282,7 +282,7 @@ def __init__(
             trying to optimize initial learning for faster convergence. Sets learning
             rate in self.hparams.lr | self.hparams.learning_rate in the lightning module.
             To use a different key, set a string instead of True with the key name.
-
+
         replace_sampler_ddp: Explicitly enables or disables sampler replacement.
             If not specified, this will be toggled automatically when DDP is used.
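
For reference, a minimal usage sketch of the flag these patches add. It assumes the 0.7.x-era `Trainer` API that this series targets (`gpus`, `distributed_backend`, `Trainer.fit(model, train_dataloader=...)`); `LitModel` and `my_dataset` are hypothetical placeholders, not anything introduced by this PR.

```python
# Sketch only: assumes the 0.7.x-era Lightning API these patches target.
# LitModel and my_dataset are hypothetical placeholders.
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from pytorch_lightning import Trainer

# Default (replace_sampler_ddp=True): Lightning swaps in a DistributedSampler
# for the dataloader whenever ddp/ddp2/TPU is active.
trainer = Trainer(gpus=2, distributed_backend='ddp')
trainer.fit(LitModel(), train_dataloader=DataLoader(my_dataset, batch_size=32))

# replace_sampler_ddp=False: Lightning leaves the dataloader untouched, so the
# user must attach a distributed sampler themselves, otherwise every process
# iterates over the full dataset.
sampler = DistributedSampler(my_dataset)
loader = DataLoader(my_dataset, batch_size=32, sampler=sampler)
trainer = Trainer(gpus=2, distributed_backend='ddp', replace_sampler_ddp=False)
trainer.fit(LitModel(), train_dataloader=loader)
```

The flag only opts out of the automatic replacement; the distributed/TPU check (`need_dist_sampler`) from patch 3 still guards against wrapping dataloaders in single-process runs.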