Add debug flag to TPU Training Plugins (PT_XLA_DEBUG) (#7219)
kaushikb11 authored Apr 27, 2021
1 parent e76ebd6 commit 94fcaaf
Showing 5 changed files with 47 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -117,6 +117,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added new `EarlyStopping` parameters `stopping_threshold` and `divergence_threshold` ([#6868](https://github.com/PyTorchLightning/pytorch-lightning/pull/6868))
 
 
+- Added `debug` flag to TPU Training Plugins (PT_XLA_DEBUG) ([#7219](https://github.com/PyTorchLightning/pytorch-lightning/pull/7219))
+
+
 - Added new `UnrepeatedDistributedSampler` and `IndexBatchSamplerWrapper` for tracking distributed predictions ([#7215](https://github.com/PyTorchLightning/pytorch-lightning/pull/7215))
 
 
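Note: a minimal usage sketch for the new `debug` flag, mirroring the test added at the bottom of this commit (the `pytorch_lightning.plugins` import path and `tpu_cores=8` follow that test; everything else about the Trainer is left at defaults and is only illustrative):

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import TPUSpawnPlugin

# debug=True makes the plugin export PT_XLA_DEBUG=1 before dispatch,
# the environment variable consumed by torch_xla's debugging tooling.
trainer = Trainer(tpu_cores=8, plugins=TPUSpawnPlugin(debug=True))
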
4 changes: 3 additions & 1 deletion pytorch_lightning/accelerators/tpu.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from typing import Any, Callable, Union
 
 from torch.optim import Optimizer
@@ -51,7 +52,8 @@ def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
         return super().setup(trainer, model)
 
     def teardown(self) -> None:
-        pass
+        if "PT_XLA_DEBUG" in os.environ:
+            del os.environ["PT_XLA_DEBUG"]
 
     def run_optimizer_step(
         self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
8 changes: 7 additions & 1 deletion pytorch_lightning/plugins/training_type/single_tpu.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+
 import torch
 
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
@@ -24,11 +26,12 @@
 class SingleTPUPlugin(SingleDevicePlugin):
     """ Plugin for training on a single TPU device. """
 
-    def __init__(self, device: int):
+    def __init__(self, device: int, debug: bool = False):
 
         device = xm.xla_device(device)
         super().__init__(device)
 
+        self.debug = debug
         self.tpu_local_core_rank = 0
         self.tpu_global_core_rank = 0
 
Expand All @@ -47,6 +50,9 @@ def pre_dispatch(self) -> None:
if isinstance(self.device, int):
self.device = xm.xla_device(self.device)

if self.debug:
os.environ["PT_XLA_DEBUG"] = str(1)

self.tpu_local_core_rank = xm.get_local_ordinal()
self.tpu_global_core_rank = xm.get_ordinal()

7 changes: 6 additions & 1 deletion pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -49,8 +49,9 @@
 class TPUSpawnPlugin(DDPSpawnPlugin):
     """ Plugin for training multiple TPU devices using the :func:`torch.multiprocessing.spawn` method. """
 
-    def __init__(self, parallel_devices: Optional[List[int]] = None, **_: Any) -> None:
+    def __init__(self, parallel_devices: Optional[List[int]] = None, debug: bool = False, **_: Any) -> None:
         super().__init__(parallel_devices, num_nodes=1, cluster_environment=None, sync_batchnorm=False)
+        self.debug = debug
         self.tpu_local_core_rank = 0
         self.tpu_global_core_rank = 0
         self.start_method = None
@@ -104,6 +105,10 @@ def connect(self, model: 'pl.LightningModule') -> None:
         self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model))
         return super().connect(model)
 
+    def pre_dispatch(self):
+        if self.debug:
+            os.environ["PT_XLA_DEBUG"] = str(1)
+
     def setup(self, model: Module) -> Module:
         self.create_mp_queue()
         return self.model
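
Note: the flag is a thin environment-variable toggle: `pre_dispatch` exports `PT_XLA_DEBUG=1` when `debug=True`, and the TPU accelerator's `teardown` (see `accelerators/tpu.py` above) removes it again. A standalone sketch of that pattern, using hypothetical helper names, purely for illustration:

import os

def enable_pt_xla_debug(debug: bool) -> None:
    # What the plugins' pre_dispatch does: opt in to the debug output.
    if debug:
        os.environ["PT_XLA_DEBUG"] = str(1)

def clear_pt_xla_debug() -> None:
    # What the accelerator's teardown does: leave the environment clean afterwards.
    if "PT_XLA_DEBUG" in os.environ:
        del os.environ["PT_XLA_DEBUG"]
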
28 changes: 28 additions & 0 deletions tests/models/test_tpu.py
@@ -442,3 +442,31 @@ def test_sync_dist(rank):
         assert res["test_tensor"].item() == 8, "Result-Log does not work properly with TPU Spawn and Tensors"
 
     xmp.spawn(test_sync_dist, nprocs=8, start_method='fork')
+
+
+@RunIf(tpu=True)
+@pl_multi_process_test
+def test_tpu_debug_mode(tmpdir):
+    """Test if debug mode works on TPU."""
+
+    class DebugModel(BoringModel):
+
+        def on_train_start(self):
+            assert os.environ.get("PT_XLA_DEBUG") == str(1), "PT_XLA_DEBUG was not set in environment variables"
+
+        def teardown(self, stage):
+            assert "PT_XLA_DEBUG" not in os.environ
+
+    tutils.reset_seed()
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        max_epochs=4,
+        tpu_cores=8,
+        limit_train_batches=0.4,
+        limit_val_batches=0.4,
+        plugins=TPUSpawnPlugin(debug=True),
+    )
+
+    model = DebugModel()
+    tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
