
Commit

refactor accelerator teardown -> training type plugin teardown (#7579)
shuyingsunshine21 authored May 22, 2021
1 parent a8d9b5f commit 2242423
Showing 15 changed files with 237 additions and 32 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -82,6 +82,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- MLFlowLogger now accepts `run_name` as a constructor argument ([#7622](https://github.com/PyTorchLightning/pytorch-lightning/issues/7622))


- Changed `teardown()` in `Accelerator` to allow `training_type_plugin` to customize `teardown` logic ([#7579](https://github.com/PyTorchLightning/pytorch-lightning/pull/7579))


### Deprecated


8 changes: 3 additions & 5 deletions pytorch_lightning/accelerators/accelerator.py
@@ -151,17 +151,15 @@ def lightning_module(self) -> 'pl.LightningModule':

@property
def root_device(self) -> torch.device:
"""Returns the root device"""
return self.training_type_plugin.root_device

def teardown(self) -> None:
"""
This method is called to tear down the training process.
It is the right place to release memory and free other ressources.
By default we add a barrier here to synchronize processes before returning
control back to the caller.
It is the right place to release memory and free other resources.
"""
self.barrier("teardown")
self.training_type_plugin.teardown()

def batch_to_device(
self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: Optional[int] = None
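
Note: with the accelerator now deferring to the plugin, custom cleanup no longer requires subclassing an accelerator; overriding teardown on a training type plugin is enough. A minimal sketch, assuming the usual plugins= argument of this release; MyDDPPlugin and its log message are illustrative, not part of this PR:

import logging

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin


class MyDDPPlugin(DDPPlugin):
    """Hypothetical plugin layering extra cleanup on top of the stock DDP teardown."""

    def teardown(self) -> None:
        logging.getLogger(__name__).info("releasing resources on %s", self.root_device)
        # The parent teardown (added to ParallelPlugin in this PR) moves the model back
        # to CPU and empties the CUDA cache when running on GPU.
        super().teardown()


# Accelerator.teardown() now simply forwards here at the end of training.
trainer = Trainer(gpus=2, accelerator="ddp", plugins=[MyDDPPlugin()], fast_dev_run=True)
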
7 changes: 0 additions & 7 deletions pytorch_lightning/accelerators/gpu.py
@@ -47,13 +47,6 @@ def on_train_start(self) -> None:
with torch.cuda.device(self.root_device):
torch.cuda.empty_cache()

def teardown(self) -> None:
self.lightning_module.cpu()

# clean up memory
with torch.cuda.device(self.root_device):
torch.cuda.empty_cache()

@staticmethod
def set_nvidia_flags(local_rank: int) -> None:
# set the correct cuda visible devices (using pci order)
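
Note: the GPU-specific cleanup removed here is not dropped; equivalent logic is added to ParallelPlugin.teardown and SingleDevicePlugin.teardown further down in this commit. For reference, the pattern is plain PyTorch and can be sketched independently of Lightning (release_cuda_memory is an illustrative name, not an API in this PR):

import torch
from torch import nn


def release_cuda_memory(module: nn.Module, device: torch.device) -> None:
    # Move parameters and buffers back to host memory so their CUDA storage can be freed ...
    module.cpu()
    # ... then ask this device's caching allocator to return unused blocks to the driver.
    with torch.cuda.device(device):
        torch.cuda.empty_cache()
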
5 changes: 0 additions & 5 deletions pytorch_lightning/accelerators/tpu.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Callable

from torch.optim import Optimizer
@@ -51,10 +50,6 @@ def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
return super().setup(trainer, model)

def teardown(self) -> None:
if "PT_XLA_DEBUG" in os.environ:
del os.environ["PT_XLA_DEBUG"]

def run_optimizer_step(
self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
) -> None:
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -96,7 +96,7 @@ def __init__(
self.set_world_ranks()

@property
def root_device(self):
def root_device(self) -> torch.device:
return self.parallel_devices[self.local_rank]

@property
@@ -126,7 +126,7 @@ def distributed_sampler_kwargs(self):
def _is_single_process_single_device(self) -> bool:
return True

def setup_environment(self):
def setup_environment(self) -> None:
# start the other scripts
if not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
self._call_children_scripts()
17 changes: 15 additions & 2 deletions pytorch_lightning/plugins/training_type/parallel.py
@@ -23,6 +23,7 @@
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.utilities import _XLA_AVAILABLE
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp


@@ -40,13 +41,17 @@ def __init__(

@property
@abstractmethod
def root_device(self):
def root_device(self) -> torch.device:
raise NotImplementedError

@property
def on_gpu(self):
def on_gpu(self) -> bool:
return self.root_device.type == "cuda" and torch.cuda.is_available()

@property
def on_tpu(self) -> bool:
return self.root_device.type == "xla" and _XLA_AVAILABLE

@property
def lightning_module(self):
return unwrap_lightning_module(self._model)
@@ -122,3 +127,11 @@ def block_backward_sync(self):
yield None
else:
yield None

def teardown(self) -> None:
if self.on_gpu:
# GPU teardown
self.lightning_module.cpu()
# clean up memory
with torch.cuda.device(self.root_device):
torch.cuda.empty_cache()
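
Note: the new on_tpu property requires both an XLA-typed root device and an importable torch_xla (via the _XLA_AVAILABLE flag), so it stays False on machines without XLA support even if a plugin were configured with an "xla" device. A standalone sketch of the same guard, with an illustrative stand-in for the availability flag:

import importlib.util

import torch

# Illustrative stand-in for pytorch_lightning.utilities._XLA_AVAILABLE.
_XLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None


def on_tpu(root_device: torch.device) -> bool:
    # Both conditions must hold: the root device is an XLA device *and* torch_xla is installed.
    return root_device.type == "xla" and _XLA_AVAILABLE
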
13 changes: 11 additions & 2 deletions pytorch_lightning/plugins/training_type/single_device.py
Expand Up @@ -16,6 +16,7 @@
import torch

from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.utilities import _XLA_AVAILABLE


class SingleDevicePlugin(TrainingTypePlugin):
@@ -30,11 +31,11 @@ def __init__(self, device: torch.device):

@property
def on_tpu(self) -> bool:
return False
return self.root_device.type == "xla" and _XLA_AVAILABLE

@property
def on_gpu(self) -> bool:
return self.device.type == "cuda" and torch.cuda.is_available()
return self.root_device.type == "cuda" and torch.cuda.is_available()

def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
"""
@@ -78,3 +79,11 @@ def barrier(self, *args, **kwargs) -> None:

def broadcast(self, obj: object, src: int = 0) -> object:
return obj

def teardown(self) -> None:
if self.on_gpu:
# GPU teardown
self.lightning_module.cpu()
# clean up memory
with torch.cuda.device(self.root_device):
torch.cuda.empty_cache()
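
Note: with on_gpu and on_tpu both derived from root_device, a SingleDevicePlugin reports its placement purely from the device it was constructed with. A short usage sketch (the CUDA part assumes a GPU machine; the import path follows this release's pytorch_lightning.plugins namespace):

import torch

from pytorch_lightning.plugins import SingleDevicePlugin

cpu_plugin = SingleDevicePlugin(torch.device("cpu"))
assert not cpu_plugin.on_gpu and not cpu_plugin.on_tpu

gpu_plugin = SingleDevicePlugin(torch.device("cuda", 0))
assert gpu_plugin.on_gpu       # "cuda" device type and torch.cuda.is_available()
assert not gpu_plugin.on_tpu   # device type is "cuda", not "xla"
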
8 changes: 4 additions & 4 deletions pytorch_lightning/plugins/training_type/single_tpu.py
@@ -35,10 +35,6 @@ def __init__(self, device: int, debug: bool = False):
self.tpu_local_core_rank = 0
self.tpu_global_core_rank = 0

@property
def on_tpu(self) -> bool:
return True

@property
def is_distributed(self) -> bool:
return False
@@ -63,3 +59,7 @@ def on_save(self, checkpoint: dict) -> dict:
https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
"""
return move_data_to_device(checkpoint, torch.device("cpu"))

def teardown(self) -> None:
# TPU teardown
os.environ.pop("PT_XLA_DEBUG", None)
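
Note: os.environ.pop with a default is a no-op when the variable is absent, so this teardown no longer needs the membership check that guarded the old del statement in the TPU accelerator. A quick standalone illustration of the two styles:

import os

# Old style: guard the deletion, since del raises KeyError when the key is missing.
if "PT_XLA_DEBUG" in os.environ:
    del os.environ["PT_XLA_DEBUG"]

# New style: pop with a default never raises, whether or not the variable is set.
os.environ.pop("PT_XLA_DEBUG", None)
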
14 changes: 9 additions & 5 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -71,7 +71,7 @@ def world_size(self) -> int:

@property
def root_device(self) -> torch.device:
return self.device
return xm.xla_device()

@staticmethod
def _validate_dataloader(dataloaders: Union[List[DataLoader], DataLoader]) -> None:
@@ -129,7 +129,7 @@ def is_distributed(self) -> bool:

def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader:
TPUSpawnPlugin._validate_dataloader(dataloader)
return MpDeviceLoader(dataloader, self.device)
return MpDeviceLoader(dataloader, self.root_device)

def configure_ddp(self) -> None:
pass
@@ -172,8 +172,7 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
time.sleep(2)

def model_to_device(self) -> None:
self.device = xm.xla_device()
self.model = self.wrapped_model.to(self.device)
self.model = self.wrapped_model.to(self.root_device)

def barrier(self, name: Optional[str] = None) -> None:
# HOST_WORLD_SIZE is None outside the xmp.spawn process
@@ -209,7 +208,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
buffer = io.BytesIO()
torch.save(obj, buffer)
data = bytearray(buffer.getbuffer())
data_tensor = torch.tensor(data, device=self.device, dtype=torch.float)
data_tensor = torch.tensor(data, device=self.root_device, dtype=torch.float)
data = xm.all_gather(data_tensor)
buffer = io.BytesIO(data.cpu().byte().numpy())
obj = torch.load(buffer)
@@ -302,3 +301,8 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
if isinstance(tensor, torch.Tensor) and tensor.dim() == 0:
tensor = tensor.unsqueeze(0)
return xm.all_gather(tensor)

def teardown(self) -> None:
# TPU teardown
os.environ.pop("PT_XLA_DEBUG", None)
self.barrier("teardown")
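
Note: root_device is now resolved on every access through xm.xla_device() instead of being cached as self.device inside model_to_device, so call sites such as process_dataloader, broadcast and teardown see the device of the current spawned process even before model_to_device has run. Reduced to plain Python, the pattern looks roughly like this (DeviceOwner and current_device are illustrative names, not Lightning API):

import os


def current_device() -> str:
    # Illustrative stand-in for xm.xla_device(): "which device does this process own right now?"
    return "xla:" + os.environ.get("LOCAL_RANK", "0")


class DeviceOwner:

    @property
    def root_device(self) -> str:
        # Derive the device on access instead of caching it in a setup method; a cached value
        # would be stale or missing for code that runs before setup or in a different process.
        return current_device()
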
16 changes: 16 additions & 0 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -60,11 +60,19 @@ def setup(self, model: Module) -> None:
@abstractmethod
def on_gpu(self) -> bool:
"""Returns whether the current process is done on GPU"""
raise NotImplementedError

@property
@abstractmethod
def on_tpu(self) -> bool:
"""Returns whether the current process is done on TPU"""
raise NotImplementedError

@property
@abstractmethod
def root_device(self) -> torch.device:
"""Returns the root device"""
raise NotImplementedError

@abstractmethod
def model_to_device(self) -> None:
@@ -290,6 +298,14 @@ def call_configure_sharded_model_hook(self) -> bool:
def call_configure_sharded_model_hook(self, mode: bool) -> None:
self._call_configure_sharded_model_hook = mode

@abstractmethod
def teardown(self) -> None:
"""
This method is called to tear down the training process.
It is the right place to release memory and free other resources.
"""
raise NotImplementedError

@classmethod
def register_plugins(cls, plugin_registry):
pass
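
Note: with on_gpu, on_tpu, root_device and teardown now abstract, every training type plugin has to provide them. A minimal, hypothetical CPU-only subclass sketches the new required surface; the other abstract methods of TrainingTypePlugin (reduce, barrier, broadcast, ...) are omitted, so this class is not instantiable as written:

import torch

from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin


class MyCPUPlugin(TrainingTypePlugin):
    """Hypothetical plugin showing only the members touched by this PR."""

    @property
    def on_gpu(self) -> bool:
        return False

    @property
    def on_tpu(self) -> bool:
        return False

    @property
    def root_device(self) -> torch.device:
        return torch.device("cpu")

    def model_to_device(self) -> None:
        # _model is assumed to be the attribute the base class stores the LightningModule in.
        self._model = self._model.to(self.root_device)

    def teardown(self) -> None:
        # Nothing to release for a CPU-only plugin.
        pass
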
48 changes: 48 additions & 0 deletions tests/plugins/test_ddp_plugin.py
@@ -0,0 +1,48 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf


class BoringModelGPU(BoringModel):

def on_train_start(self) -> None:
# make sure that the model is on GPU when training
assert self.device == torch.device(f"cuda:{self.trainer.training_type_plugin.local_rank}")
self.start_cuda_memory = torch.cuda.memory_allocated()


@RunIf(skip_windows=True, min_gpus=2, special=True)
def test_ddp_with_2_gpus():
"""Tests if device is set correctely when training and after teardown for DDPPlugin."""
trainer = Trainer(gpus=2, accelerator="ddp", fast_dev_run=True)
# assert training type plugin attributes for device setting
assert isinstance(trainer.training_type_plugin, DDPPlugin)
assert trainer.training_type_plugin.on_gpu
assert not trainer.training_type_plugin.on_tpu
local_rank = trainer.training_type_plugin.local_rank
assert trainer.training_type_plugin.root_device == torch.device(f"cuda:{local_rank}")

model = BoringModelGPU()

trainer.fit(model)

# assert after training, model is moved to CPU and memory is deallocated
assert model.device == torch.device("cpu")
cuda_memory = torch.cuda.memory_allocated()
assert cuda_memory < model.start_cuda_memory
42 changes: 42 additions & 0 deletions tests/plugins/test_ddp_spawn_plugin.py
@@ -0,0 +1,42 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPSpawnPlugin
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf


class BoringModelDDPCPU(BoringModel):

def on_train_start(self) -> None:
# make sure that the model is on CPU when training
assert self.device == torch.device("cpu")


@RunIf(skip_windows=True)
def test_ddp_cpu():
"""Tests if device is set correctely when training for DDPSpawnPlugin."""
trainer = Trainer(num_processes=2, fast_dev_run=True)
# assert training type plugin attributes for device setting

assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin)
assert not trainer.training_type_plugin.on_gpu
assert not trainer.training_type_plugin.on_tpu
assert trainer.training_type_plugin.root_device == torch.device("cpu")

model = BoringModelDDPCPU()

trainer.fit(model)
1 change: 1 addition & 0 deletions tests/plugins/test_deepspeed_plugin.py
@@ -568,6 +568,7 @@ def test_deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, cpu_offload
"""
Test to ensure with Stage 2 and multiple GPUs, accumulated grad batches works.
"""
os.environ['MASTER_PORT'] = "29500"
seed_everything(42)

class VerificationCallback(Callback):