From 94fcaaf5d78fe38642c02c8b178195b08431a5ca Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Wed, 28 Apr 2021 02:04:25 +0530 Subject: [PATCH] Add `debug` flag to TPU Training Plugins (PT_XLA_DEBUG) (#7219) --- CHANGELOG.md | 3 ++ pytorch_lightning/accelerators/tpu.py | 4 ++- .../plugins/training_type/single_tpu.py | 8 +++++- .../plugins/training_type/tpu_spawn.py | 7 ++++- tests/models/test_tpu.py | 28 +++++++++++++++++++ 5 files changed, 47 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a08263748d5d..8247d1eb549e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -117,6 +117,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added new `EarlyStopping` parameters `stopping_threshold` and `divergence_threshold` ([#6868](https://github.com/PyTorchLightning/pytorch-lightning/pull/6868)) +- Added `debug` flag to TPU Training Plugins (PT_XLA_DEBUG) ([#7219](https://github.com/PyTorchLightning/pytorch-lightning/pull/7219)) + + - Added new `UnrepeatedDistributedSampler` and `IndexBatchSamplerWrapper` for tracking distributed predictions ([#7215](https://github.com/PyTorchLightning/pytorch-lightning/pull/7215)) diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 2f8852159b4f8..6bbf88e35d026 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os from typing import Any, Callable, Union from torch.optim import Optimizer @@ -51,7 +52,8 @@ def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None: return super().setup(trainer, model) def teardown(self) -> None: - pass + if "PT_XLA_DEBUG" in os.environ: + del os.environ["PT_XLA_DEBUG"] def run_optimizer_step( self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index 60cfaef9842fa..fce325f322cc3 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os + import torch from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin @@ -24,11 +26,12 @@ class SingleTPUPlugin(SingleDevicePlugin): """ Plugin for training on a single TPU device. """ - def __init__(self, device: int): + def __init__(self, device: int, debug: bool = False): device = xm.xla_device(device) super().__init__(device) + self.debug = debug self.tpu_local_core_rank = 0 self.tpu_global_core_rank = 0 @@ -47,6 +50,9 @@ def pre_dispatch(self) -> None: if isinstance(self.device, int): self.device = xm.xla_device(self.device) + if self.debug: + os.environ["PT_XLA_DEBUG"] = str(1) + self.tpu_local_core_rank = xm.get_local_ordinal() self.tpu_global_core_rank = xm.get_ordinal() diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 693f57da3cf4f..2303b27a7ea3b 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -49,8 +49,9 @@ class TPUSpawnPlugin(DDPSpawnPlugin): """ Plugin for training multiple TPU devices using the :func:`torch.multiprocessing.spawn` method. """ - def __init__(self, parallel_devices: Optional[List[int]] = None, **_: Any) -> None: + def __init__(self, parallel_devices: Optional[List[int]] = None, debug: bool = False, **_: Any) -> None: super().__init__(parallel_devices, num_nodes=1, cluster_environment=None, sync_batchnorm=False) + self.debug = debug self.tpu_local_core_rank = 0 self.tpu_global_core_rank = 0 self.start_method = None @@ -104,6 +105,10 @@ def connect(self, model: 'pl.LightningModule') -> None: self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model)) return super().connect(model) + def pre_dispatch(self): + if self.debug: + os.environ["PT_XLA_DEBUG"] = str(1) + def setup(self, model: Module) -> Module: self.create_mp_queue() return self.model diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 7154e036bcbf5..39be1620909ee 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -442,3 +442,31 @@ def test_sync_dist(rank): assert res["test_tensor"].item() == 8, "Result-Log does not work properly with TPU Spawn and Tensors" xmp.spawn(test_sync_dist, nprocs=8, start_method='fork') + + +@RunIf(tpu=True) +@pl_multi_process_test +def test_tpu_debug_mode(tmpdir): + """Test if debug mode works on TPU.""" + + class DebugModel(BoringModel): + + def on_train_start(self): + assert os.environ.get("PT_XLA_DEBUG") == str(1), "PT_XLA_DEBUG was not set in environment variables" + + def teardown(self, stage): + assert "PT_XLA_DEBUG" not in os.environ + + tutils.reset_seed() + trainer_options = dict( + default_root_dir=tmpdir, + progress_bar_refresh_rate=0, + max_epochs=4, + tpu_cores=8, + limit_train_batches=0.4, + limit_val_batches=0.4, + plugins=TPUSpawnPlugin(debug=True), + ) + + model = DebugModel() + tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)