
Commit

Merge branch 'tpu/debug' of https://github.com/kaushikb11/pytorch-lightning into tpu/debug
kaushikb11 committed Apr 27, 2021
2 parents 9086b5e + 424860a commit bd63e0d
Showing 33 changed files with 454 additions and 87 deletions.
13 changes: 12 additions & 1 deletion CHANGELOG.md
@@ -117,6 +117,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `debug` flag to TPU Training Plugins (PT_XLA_DEBUG) ([#7219](https://github.com/PyTorchLightning/pytorch-lightning/pull/7219))


- Added new `UnrepeatedDistributedSampler` and `IndexBatchSamplerWrapper` for tracking distributed predictions ([#7215](https://github.com/PyTorchLightning/pytorch-lightning/pull/7215))


- Added `trainer.predict(return_predictions=None|False|True)` ([#7215](https://github.com/PyTorchLightning/pytorch-lightning/pull/7215))


### Changed

@@ -147,7 +152,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed warnings and recommendations for dataloaders in `ddp_spawn` ([#6762](https://github.com/PyTorchLightning/pytorch-lightning/pull/6762/))


- `pl.seed_everyting` will now also set the seed on the `DistributedSampler` ([#7024](https://github.com/PyTorchLightning/pytorch-lightning/pull/7024))
- `pl.seed_everything` will now also set the seed on the `DistributedSampler` ([#7024](https://github.com/PyTorchLightning/pytorch-lightning/pull/7024))


- Changed default setting for communication of multi-node training using `DDPShardedPlugin` ([#6937](https://github.com/PyTorchLightning/pytorch-lightning/pull/6937))


### Deprecated
@@ -347,6 +355,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed parsing for pre-release package versions ([#6999](https://github.com/PyTorchLightning/pytorch-lightning/pull/6999))


- Fixed `num_sanity_val_steps` affecting reproducibility of training data shuffling ([#7014](https://github.com/PyTorchLightning/pytorch-lightning/pull/7014))


- Fixed resetting device after `fitting/evaluating/predicting` ([#7188](https://github.com/PyTorchLightning/pytorch-lightning/pull/7188))


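A hedged usage sketch of the new `trainer.predict(return_predictions=...)` flag noted in the changelog above. The `LitModel`, the random dataloader, and the hook name `predict_step` follow the 1.3-era API and are illustrative only, not part of this commit:

```python
import torch
from torch.utils.data import DataLoader

from pytorch_lightning import LightningModule, Trainer


class LitModel(LightningModule):
    """Toy model used only to exercise `trainer.predict`."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        return self(batch)


# A Tensor works as a dataset here: indexing yields rows, the default collate stacks them.
predict_loader = DataLoader(torch.randn(64, 32), batch_size=16)
trainer = Trainer(logger=False)

# Collect and return the per-batch outputs (4 batches of shape [16, 2] here).
preds = trainer.predict(LitModel(), dataloaders=predict_loader, return_predictions=True)
print(len(preds), preds[0].shape)

# Skip collecting outputs entirely, e.g. when predictions are written out inside predict_step.
trainer.predict(LitModel(), dataloaders=predict_loader, return_predictions=False)
```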
2 changes: 1 addition & 1 deletion docs/source/advanced/multi_gpu.rst
@@ -675,7 +675,7 @@ To use Sharded Training, you need to first install FairScale using the command below
.. code-block:: python
# train using Sharded DDP
trainer = Trainer(accelerator='ddp', plugins='ddp_sharded')
trainer = Trainer(plugins='ddp_sharded')
Sharded Training can work across all DDP variants by adding the additional ``--plugins ddp_sharded`` flag.

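For reference, a minimal sketch of launching the sharded plugin shown in the docs change above. The GPU count and the model are placeholders; per the updated snippet, `accelerator='ddp'` may also be omitted:

```python
from pytorch_lightning import Trainer

# Requires FairScale: pip install fairscale
trainer = Trainer(gpus=2, accelerator="ddp", plugins="ddp_sharded")
# trainer.fit(model)  # `model` is any LightningModule
```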
7 changes: 6 additions & 1 deletion pytorch_lightning/accelerators/accelerator.py
@@ -331,7 +331,12 @@ def clip_gradients(
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
) -> None:
"""clips all the optimizer parameters to the given value"""
self.precision_plugin.clip_gradients(optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm)
self.precision_plugin.clip_gradients(
optimizer,
clip_val,
gradient_clip_algorithm=gradient_clip_algorithm,
model=self.model,
)

def on_train_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
"""Hook to do something on the end of an training epoch
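The accelerator now forwards its wrapped model to the precision plugin's `clip_gradients`. A rough sketch of a custom precision plugin accepting the extended signature; the clipping policy and import paths below assume the 1.3-era layout and are illustrative, not Lightning's built-in behaviour:

```python
from typing import Optional, Union

import torch
from torch.nn import Module
from torch.optim import Optimizer

from pytorch_lightning.plugins import PrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType


class ModelAwarePrecisionPlugin(PrecisionPlugin):
    """Hypothetical plugin that clips against the forwarded model when available."""

    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Union[int, float],
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
        model: Optional[Module] = None,
    ) -> None:
        if clip_val is None or float(clip_val) <= 0:
            return
        # Prefer the model's own parameters; fall back to the optimizer's param groups.
        if model is not None:
            parameters = list(model.parameters())
        else:
            parameters = [p for group in optimizer.param_groups for p in group["params"]]
        if gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
            torch.nn.utils.clip_grad_value_(parameters, float(clip_val))
        else:
            torch.nn.utils.clip_grad_norm_(parameters, float(clip_val))


# Passed to the Trainer like any other plugin, e.g.:
# trainer = Trainer(gradient_clip_val=0.5, plugins=[ModelAwarePrecisionPlugin()])
```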
2 changes: 1 addition & 1 deletion pytorch_lightning/core/datamodule.py
@@ -26,7 +26,7 @@

class _DataModuleWrapper(type):

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.__has_added_checks = False

2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -74,7 +74,7 @@ class LightningModule(
"model_size",
] + DeviceDtypeModuleMixin.__jit_unused_properties__

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

# see (https://github.com/pytorch/pytorch/blob/3e6bb5233f9ca2c5aa55d9cda22a7ee85439aa6e/
70 changes: 69 additions & 1 deletion pytorch_lightning/overrides/distributed.py
@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Any
from typing import Any, Iterator, List, Optional

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import BatchSampler, DistributedSampler, Sampler

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
@@ -75,3 +76,70 @@ def prepare_for_backward(model: DistributedDataParallel, output: Any):
model.reducer.prepare_for_backward([])
else:
model.require_forward_param_sync = False


class UnrepeatedDistributedSampler(DistributedSampler):
"""
A fork of the pytorch DistributedSampler that doesn't repeat data, instead
allowing the number of batches per process to be off-by-one from each other.
This makes this sampler usable for predictions (it's deterministic and
doesn't require shuffling). It is potentially unsafe to use this sampler for
training, because during training the DistributedDataParallel syncs buffers
on each forward pass, so it could freeze if one of the processes runs one
fewer batch. During prediction, buffers are only synced on the first batch,
so this is safe to use as long as each process runs at least one batch. We
verify this in an assert.
Taken from https://github.com/jpuigcerver/PyLaia/blob/v1.0.0/laia/data/unpadded_distributed_sampler.py
and https://github.com/pytorch/pytorch/issues/25162#issuecomment-634146002
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.num_samples = len(range(self.rank, len(self.dataset), self.num_replicas))
self.total_size = len(self.dataset)
# If any process has at least one batch, every other process needs to
# have at least one batch, or the DistributedDataParallel could lock up.
assert self.num_samples >= 1 or self.total_size == 0

def __iter__(self) -> Iterator[List[int]]:
if self.shuffle:
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))

assert len(indices) == self.total_size

# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples

return iter(indices)


class IndexBatchSamplerWrapper:
"""This class is used to wrap a :class:`torch.utils.data.BatchSampler` and capture its indices."""

def __init__(self, sampler: BatchSampler) -> None:
self._sampler = sampler
self.batch_indices: Optional[List[int]] = None

def __iter__(self) -> Iterator[List[int]]:
for batch in self._sampler:
self.batch_indices = batch
yield batch

@property
def drop_last(self) -> bool:
return self._sampler.drop_last

@property
def batch_size(self) -> int:
return self._sampler.batch_size

@property
def sampler(self) -> Sampler:
return self._sampler.sampler
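A small, hedged sketch of the new `IndexBatchSamplerWrapper` above in isolation: it passes batches through unchanged while remembering which dataset indices made up the latest batch. (The companion `UnrepeatedDistributedSampler` needs an initialized process group, so it is left out of this snippet.)

```python
from torch.utils.data import BatchSampler, SequentialSampler

from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper

dataset = list(range(10))  # placeholder dataset
batch_sampler = IndexBatchSamplerWrapper(
    BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)
)

for batch_indices in batch_sampler:
    # The yielded value and `batch_sampler.batch_indices` are the same list of
    # dataset indices, e.g. [0, 1, 2, 3], then [4, 5, 6, 7], then [8, 9].
    assert batch_indices == batch_sampler.batch_indices
    print(batch_indices)
```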
6 changes: 4 additions & 2 deletions pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Union
from typing import Any, Callable, Optional, Union

from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
@@ -79,8 +80,9 @@ def clip_gradients(
optimizer: Optimizer,
clip_val: Union[int, float],
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
model: Optional[Module] = None,
) -> None:
"""
DeepSpeed handles clipping gradients via the training type plugin.
DeepSpeed handles clipping gradients internally via the training type plugin.
"""
pass
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/precision/precision_plugin.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Callable, List, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
from torch import Tensor
@@ -104,6 +104,7 @@ def clip_gradients(
optimizer: Optimizer,
clip_val: Union[int, float],
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
model: Optional[Module] = None
) -> None:
"""Clips the gradients"""
if clip_val is None:
7 changes: 2 additions & 5 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -37,7 +37,7 @@
)
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.seed import reset_seed

if _HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
@@ -180,10 +180,7 @@ def _call_children_scripts(self):
sleep(delay)

def setup_distributed(self):
# TODO: check if needed
seed = os.environ.get("PL_GLOBAL_SEED")
if seed is not None:
seed_everything(int(seed))
reset_seed()

# determine which process we are and world size
self.set_world_ranks()
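The inline `PL_GLOBAL_SEED` handling deleted above is now centralized in the `reset_seed()` utility. Conceptually it re-applies the seed that `seed_everything()` exported to the environment of child processes; a sketch mirroring the removed lines (the real helper lives in `pytorch_lightning.utilities.seed`):

```python
import os

from pytorch_lightning.utilities.seed import seed_everything


def reset_seed_sketch() -> None:
    """Re-seed a worker process from the seed exported by `seed_everything()`."""
    seed = os.environ.get("PL_GLOBAL_SEED")
    if seed is not None:
        seed_everything(int(seed))
```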
7 changes: 2 additions & 5 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -32,7 +32,7 @@
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.seed import reset_seed

if _TORCH_GREATER_EQUAL_1_8:
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
@@ -132,10 +132,7 @@ def start_predicting(self, trainer):
def new_process(self, process_idx, trainer, mp_queue):
self.mp_queue = mp_queue

# TODO: check if needed
seed = os.environ.get("PL_GLOBAL_SEED")
if seed is not None:
seed_everything(int(seed))
reset_seed()

self.set_world_ranks(process_idx)

23 changes: 20 additions & 3 deletions pytorch_lightning/plugins/training_type/sharded.py
@@ -20,7 +20,8 @@
from pytorch_lightning.core.optimizer import is_lightning_optimizer
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
@@ -32,10 +33,15 @@
class DDPShardedPlugin(DDPPlugin):
""" Optimizer and gradient sharded training provided by FairScale. """

_REDUCE_BUFFER_SIZE_DEFAULT = 2**23 # 8M

def configure_ddp(self):
self._wrap_optimizers()
self._model = ShardedDataParallel(
LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers
LightningShardedDataParallel(self.model),
sharded_optimizer=self.lightning_module.trainer.optimizers,
# For multi-node training, enabling bucketing will improve performance.
reduce_buffer_size=self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0,
)
setattr(self._model, "require_backward_grad_sync", False)

@@ -47,6 +53,12 @@ def _reinit_optimizers_with_oss(self):
if not isinstance(optimizer, OSS):
optim_class = type(optimizer)
zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
is_fp16 = self.lightning_module.trainer.precision == 16
# For multi-node training, compressing the model shards in fp16 before broadcasting
# improves performance. When using PyTorch AMP, it will not degrade
# the model performance.
zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
optimizers[x] = zero_optimizer
del optimizer
trainer = self.lightning_module.trainer
@@ -58,7 +70,7 @@ def _wrap_optimizers(self):
return
self._reinit_optimizers_with_oss()

def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]:
def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
if is_lightning_optimizer(optimizer):
optimizer = optimizer._optimizer
optimizer.consolidate_state_dict()
@@ -74,6 +86,11 @@ def _optim_state_dict(self, optimizer):

@property
def lightning_module(self) -> LightningModule:
if not _FAIRSCALE_AVAILABLE: # pragma: no cover
raise MisconfigurationException(
"`DDPShardedPlugin` requires `fairscale` to be installed."
" Install it by running `pip install fairscale`."
)
return unwrap_lightning_module_sharded(self._model)

def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
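Both additions above only change behaviour for multi-node runs: reduction bucketing is enabled when `num_nodes > 1`, and OSS state is broadcast in fp16 when 16-bit precision is also used. A hedged configuration sketch that would exercise them (node/GPU counts and the model are placeholders):

```python
from pytorch_lightning import Trainer

trainer = Trainer(
    gpus=8,
    num_nodes=2,            # > 1 enables the 8M reduce buffer in DDPShardedPlugin
    accelerator="ddp",
    precision=16,           # combined with multiple nodes, enables fp16 OSS broadcast
    plugins="ddp_sharded",
)
# trainer.fit(model)
```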
6 changes: 6 additions & 0 deletions pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -21,6 +21,7 @@
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
@@ -71,6 +72,11 @@ def _optim_state_dict(self, optimizer):

@property
def lightning_module(self) -> LightningModule:
if not _FAIRSCALE_AVAILABLE: # pragma: no cover
raise MisconfigurationException(
"`DDPSpawnShardedPlugin` requires `fairscale` to be installed."
" Install it by running `pip install fairscale`."
)
return unwrap_lightning_module_sharded(self._model)

def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
6 changes: 2 additions & 4 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -32,7 +32,7 @@
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.seed import reset_seed

if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm
@@ -142,9 +142,7 @@ def set_world_ranks(self, process_idx: int = 0) -> None:
def new_process(self, process_idx: int, trainer, mp_queue) -> None:
self.mp_queue = mp_queue

seed = os.environ.get("PL_GLOBAL_SEED")
if seed is not None:
seed_everything(int(seed))
reset_seed()

self.tpu_local_core_rank = xm.get_local_ordinal()
self.tpu_global_core_rank = xm.get_ordinal()
2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/connectors/callback_connector.py
@@ -58,6 +58,8 @@ def on_trainer_init(
# configure swa callback
self._configure_swa_callbacks()

# configure the timer callback.
# responsible to stop the training when max_time is reached.
self._configure_timer_callback(max_time)

# init progress bar
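The timer callback wired up here stops fitting once the Trainer's `max_time` budget is spent. A hedged usage sketch; the duration value is illustrative, and a `datetime.timedelta` or a dict such as `{"hours": 12}` should also be accepted:

```python
from pytorch_lightning import Trainer

# Stop training after 12 hours (DD:HH:MM:SS).
trainer = Trainer(max_time="00:12:00:00")
# trainer.fit(model)
```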