Stateless timer fix for PTL 1.6 (#3925)
* Stateless timer fix for PTL 1.6

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Stateless timer PTL test

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Fix year

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Style

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Remove unused imports

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Style

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* GPU test

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Style

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* clean import

Signed-off-by: ericharper <complex451@gmail.com>

Co-authored-by: ericharper <complex451@gmail.com>
MaximumEntropy and ericharper committed Apr 8, 2022
1 parent 9da27f1 commit 29cce8e
Showing 2 changed files with 134 additions and 21 deletions.
25 changes: 4 additions & 21 deletions nemo/utils/exp_manager.py
@@ -33,8 +33,6 @@
 from pytorch_lightning.loggers import LoggerCollection as _LoggerCollection
 from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
 from pytorch_lightning.strategies.ddp import DDPStrategy
-from pytorch_lightning.trainer.states import RunningStage
-from pytorch_lightning.utilities.distributed import rank_zero_info
 
 from nemo.constants import NEMO_ENV_VARNAME_TESTING, NEMO_ENV_VARNAME_VERSION
 from nemo.utils import logging, timers
@@ -904,24 +902,9 @@ class StatelessTimer(Timer):
     def __init__(self, duration: timedelta = None, interval: str = Interval.step, verbose: bool = True,) -> None:
         super().__init__(duration, interval, verbose)
 
-    def on_save_checkpoint(self, trainer, pl_module, checkpoint) -> Dict[str, Any]:
-        return
+    # Override PTL Timer's state dict to not store elapsed time information so that we can restore and continue training.
+    def state_dict(self) -> Dict[str, Any]:
+        return {}
 
-    def on_load_checkpoint(self, trainer, pl_module, callback_state) -> None:
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         return
-
-    def _check_time_remaining(self, trainer) -> None:
-        # Default timer only checks for train time exceeding max_time, this includes time for all stages.
-        train_duration = self.time_elapsed(RunningStage.TRAINING)
-        validation_duration = self.time_elapsed(RunningStage.VALIDATING)
-        test_duration = self.time_elapsed(RunningStage.TESTING)
-        total_duration = train_duration + validation_duration + test_duration
-        should_stop = total_duration >= self._duration
-        # should_stop = trainer.training_type_plugin.broadcast(should_stop)
-        should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
-        trainer.should_stop = trainer.should_stop or should_stop
-        if should_stop and self._verbose:
-            rank_zero_info(f"Time limit reached. Signaling Trainer to stop.")
-            rank_zero_info(
-                f"Spent {timedelta(seconds=train_duration)} seconds on training, {timedelta(seconds=validation_duration)} seconds on validation and {timedelta(seconds=test_duration)} seconds on testing"
-            )
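
In PyTorch Lightning 1.6 the callback-state hooks moved from on_save_checkpoint/on_load_checkpoint to state_dict/load_state_dict, so the no-op overrides move with them: with an empty state_dict, checkpoints carry no elapsed-time information and a resumed run starts its time budget from zero. The custom _check_time_remaining override, which summed train, validation, and test time and reduced the stop decision through trainer.training_type_plugin, is dropped as well, leaving that check to the base Timer. A minimal sketch of the same idea outside NeMo (class and variable names here are illustrative, not part of this commit):

# A minimal sketch, assuming PyTorch Lightning >= 1.6, where callback state is
# persisted through Callback.state_dict()/load_state_dict(). Names are illustrative;
# NeMo's real implementation is StatelessTimer in nemo/utils/exp_manager.py (diff above).
from typing import Any, Dict

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Timer


class FreshBudgetTimer(Timer):
    def state_dict(self) -> Dict[str, Any]:
        # Persist nothing, so checkpoints carry no elapsed-time information.
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Ignore any timer state an older checkpoint might contain.
        return


# Attached like the stock Timer; the duration string is DD:HH:MM:SS, i.e. a 3 second budget.
trainer = Trainer(max_steps=10000, callbacks=[FreshBudgetTimer("00:00:00:03")])

Resuming from a checkpoint with such a callback attached restarts the clock, which is exactly what the test added below relies on.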
130 changes: 130 additions & 0 deletions tests/core_ptl/test_ptl_stateless_timer.py
@@ -0,0 +1,130 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil

import pytest
import torch
from omegaconf import OmegaConf
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.distributed import rank_zero_only

from nemo.core import ModelPT
from nemo.utils import logging
from nemo.utils.exp_manager import CallbackParams, ExpManagerConfig, StatelessTimer, exp_manager


class OnesDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_len):
        super().__init__()
        self.__dataset_len = dataset_len

    def __getitem__(self, *args):
        return torch.ones(2)

    def __len__(self):
        return self.__dataset_len


class ExampleModel(ModelPT):
    def __init__(self, *args, **kwargs):
        cfg = OmegaConf.structured({})
        super().__init__(cfg, trainer=kwargs.get('trainer', None))
        # dummy parameter in order to allow DDP to execute
        self.l1 = torch.nn.modules.Linear(in_features=2, out_features=1)

    def train_dataloader(self):
        dataset = OnesDataset(10000)
        return torch.utils.data.DataLoader(dataset, batch_size=2)

    def val_dataloader(self):
        dataset = OnesDataset(10)
        return torch.utils.data.DataLoader(dataset, batch_size=2)

    def predict_dataloader(self):
        dataset = OnesDataset(10)
        return torch.utils.data.DataLoader(dataset, batch_size=2)

    def forward(self, batch):
        return (self.l1(batch) - batch.mean(dim=1)).mean()

    def validation_step(self, batch, batch_idx):
        return (self.l1(batch) - batch.mean(dim=1)).mean()

    def training_step(self, batch, batch_idx):
        return (self.l1(batch) - batch.mean(dim=1)).mean()

    def list_available_models(self):
        pass

    def setup_training_data(self):
        pass

    def setup_validation_data(self):
        pass

    def validation_epoch_end(self, loss):
        self.log("val_loss", torch.stack(loss).mean())


class TestStatelessTimer:
    def setup_model(self):
        # Stateless timer for 3 seconds.
        # Max steps shouldn't matter for it should stop in 3 seconds based on the timer.
        # Val check interval makes sure a checkpoint is written and can be restored from.
        callback_params = CallbackParams()
        callback_params.monitor = "val_loss"
        callback_params.save_top_k = 1
        trainer = Trainer(
            devices=1,
            val_check_interval=5,
            max_steps=10000,
            accelerator='gpu',
            strategy='ddp',
            logger=None,
            callbacks=[StatelessTimer('00:00:00:03')],
            checkpoint_callback=False,
        )
        exp_manager_cfg = ExpManagerConfig(
            explicit_log_dir='./ptl_stateless_timer_check/',
            use_datetime_version=False,
            version="",
            resume_ignore_no_checkpoint=True,
            create_checkpoint_callback=True,
            checkpoint_callback_params=callback_params,
            resume_if_exists=True,
        )
        exp_manager(trainer, cfg=OmegaConf.structured(exp_manager_cfg))
        model = ExampleModel(trainer=trainer)
        trainer.fit(model)
        return trainer

    def cleanup(self):
        if os.path.exists('./ptl_stateless_timer_check'):
            shutil.rmtree('./ptl_stateless_timer_check', ignore_errors=True)

    @pytest.mark.run_only_on('GPU')
    @pytest.mark.unit
    def test_stateless_timer(self):
        self.cleanup()
        trainer = self.setup_model()
        global_step_1 = trainer.global_step
        trainer = self.setup_model()
        global_step_2 = trainer.global_step
        trainer = self.setup_model()
        global_step_3 = trainer.global_step
        logging.info(f"Global steps : {global_step_1}, {global_step_2}, {global_step_3}")
        assert global_step_3 > global_step_2 > global_step_1
        self.cleanup()
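
The test above needs a GPU and three short training runs. As a quicker, CPU-only sanity check of just the stateless behavior (a sketch under the assumptions of this commit, not part of it), the two overridden hooks can be exercised directly:

# A CPU-only sketch (not part of this commit): StatelessTimer should contribute
# nothing to a checkpoint and ignore whatever timer state a checkpoint carries.
from nemo.utils.exp_manager import StatelessTimer


def test_stateless_timer_persists_nothing():
    timer = StatelessTimer('00:00:00:03')  # DD:HH:MM:SS, i.e. a 3 second budget
    assert timer.state_dict() == {}  # nothing gets written into the checkpoint
    # Feeding in stale state (the dict layout here is made up) must be a no-op.
    timer.load_state_dict({'time_elapsed': {'train': 123.0}})
    assert timer.state_dict() == {}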
