Commit 708f0d0

Add replacements for replace_sampler_ddp, resume_from_checkpoint_fit_path and a few occurrences of validation_epoch_end

Signed-off-by: Abhishree <abhishreetm@gmail.com>
athitten committed Apr 20, 2023
1 parent d31740b commit 708f0d0
Showing 54 changed files with 59 additions and 60 deletions.
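
The replacements in this commit track the PyTorch Lightning 2.0 renames: the Trainer flag replace_sampler_ddp became use_distributed_sampler, the checkpoint connector attribute resume_from_checkpoint_fit_path was replaced by _ckpt_path, and the validation_epoch_end hook became on_validation_epoch_end. A minimal sketch of the Trainer flag rename, assuming PyTorch Lightning >= 2.0 and omitting the NeMo config plumbing:

# Illustrative sketch, not part of the diff.
import pytorch_lightning as pl

# Lightning 1.x spelling, removed in 2.0:
# trainer = pl.Trainer(replace_sampler_ddp=False)

# Lightning 2.0 spelling used throughout this commit:
trainer = pl.Trainer(use_distributed_sampler=False)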
@@ -5,7 +5,7 @@ trainer:
  num_nodes: 1
  max_steps: 2285714 # precedence over max_epochs
  num_sanity_val_steps: 0 # needed for bert pretraining from preproc
-  replace_sampler_ddp: false # needed for bert pretraining from preproc
+  use_distributed_sampler: false # needed for bert pretraining from preproc
  accumulate_grad_batches: 1 # accumulates grads every k batches
  precision: 16 # 16 to use AMP
  accelerator: gpu
@@ -12,7 +12,7 @@ trainer:
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
-  replace_sampler_ddp: False
+  use_distributed_sampler: False
  max_epochs: -1 # PTL default. In practice, max_steps will be reached first.
  max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10
@@ -8,7 +8,7 @@ trainer:
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
-  replace_sampler_ddp: False
+  use_distributed_sampler: False
  max_epochs: -1 # PTL default. In practice we don't usually train for more than 1 epoch.
  max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10
@@ -8,7 +8,7 @@ trainer:
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
-  replace_sampler_ddp: False
+  use_distributed_sampler: False
  max_epochs: -1 # PTL default. In practice, max_steps will be reached first.
  max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10
@@ -7,7 +7,7 @@ trainer:
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
-  replace_sampler_ddp: False
+  use_distributed_sampler: False
  max_epochs: 3 # min 25 recommended
  max_steps: -1 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10 # frequency with which training steps are logged
@@ -11,7 +11,7 @@ trainer:
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
-  replace_sampler_ddp: False
+  use_distributed_sampler: False
  max_epochs: -1 # PTL default. In practice we don't usually train for more than 1 epoch.
  max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10
@@ -7,7 +7,7 @@ trainer:
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
-  replace_sampler_ddp: False
+  use_distributed_sampler: False
  max_epochs: -1 # PTL default. In practice we don't usually train for more than 1 epoch.
  max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10
@@ -13,7 +13,7 @@ trainer:
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
-  replace_sampler_ddp: False
+  use_distributed_sampler: False
  max_epochs: -1 # PTL default. In practice we don't usually train for more than 1 epoch.
  max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10
@@ -7,7 +7,7 @@ trainer:
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
-  replace_sampler_ddp: False
+  use_distributed_sampler: False
  max_epochs: -1 # PTL default. In practice, max_steps will be reached first.
  max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10
@@ -12,7 +12,7 @@ trainer:
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
-  replace_sampler_ddp: False
+  use_distributed_sampler: False
  max_epochs: -1 # PTL default. In practice, max_steps will be reached first.
  max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10
@@ -7,7 +7,7 @@ trainer:
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
-  replace_sampler_ddp: False
+  use_distributed_sampler: False
  benchmark: False

exp_manager:
@@ -7,7 +7,7 @@ trainer:
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
-  replace_sampler_ddp: False
+  use_distributed_sampler: False
  benchmark: False

exp_manager:
@@ -7,7 +7,7 @@ trainer:
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
-  replace_sampler_ddp: False
+  use_distributed_sampler: False
  max_epochs: 3
  max_steps: -1 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10
@@ -7,7 +7,7 @@ trainer:
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
-  replace_sampler_ddp: False
+  use_distributed_sampler: False
  max_epochs: 3
  max_steps: -1 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10
@@ -7,7 +7,7 @@ trainer:
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
-  replace_sampler_ddp: False
+  use_distributed_sampler: False
  max_epochs: 10
  max_steps: -1 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10
@@ -8,7 +8,7 @@ trainer:
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
-  replace_sampler_ddp: False
+  use_distributed_sampler: False
  max_epochs: -1 # PTL default. In practice, max_steps will be reached first.
  max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10
@@ -7,7 +7,7 @@ trainer:
  precision: 16
  logger: False
  enable_checkpointing: False
-  replace_sampler_ddp: False
+  use_distributed_sampler: False
  max_epochs: 10
  max_steps: -1
  log_every_n_steps: 10
@@ -12,7 +12,7 @@ trainer:
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
-  replace_sampler_ddp: False
+  use_distributed_sampler: False
  max_epochs: -1 # PTL default. In practice, max_steps will be reached first.
  max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10
@@ -67,7 +67,7 @@ def main(cfg) -> None:
    if cfg.model.resume_from_checkpoint is not None:
        resume_from_checkpoint = cfg.model.resume_from_checkpoint
    else:
-        resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
+        resume_from_checkpoint = trainer._checkpoint_connector._ckpt_path
    logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

    trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)
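
For reference, Lightning 2.0 also accepts the resume path directly at fit time via the public ckpt_path argument; a hedged sketch of that alternative route, which is not what this commit does:

# Illustrative sketch; assumes `trainer`, `model`, and `resume_from_checkpoint` are already defined in the script.
trainer.fit(model, ckpt_path=resume_from_checkpoint)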
@@ -64,7 +64,7 @@ def main(cfg) -> None:
    exp_manager(trainer, cfg.exp_manager)

    # update resume from checkpoint found by exp_manager
-    resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
+    resume_from_checkpoint = trainer._checkpoint_connector._ckpt_path
    # resume_from_checkpoint = uninject_model_parallel_rank(resume_from_checkpoint)
    logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

2 changes: 1 addition & 1 deletion examples/nlp/language_modeling/megatron_gpt_pretraining.py
@@ -71,7 +71,7 @@ def main(cfg) -> None:
    if cfg.model.resume_from_checkpoint is not None:
        resume_from_checkpoint = cfg.model.resume_from_checkpoint
    else:
-        resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
+        resume_from_checkpoint = trainer._checkpoint_connector._ckpt_path

    logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

2 changes: 1 addition & 1 deletion examples/nlp/language_modeling/megatron_retro_fine_tune.py
@@ -102,7 +102,7 @@ def main(cfg) -> None:
    exp_manager(trainer, cfg.exp_manager)

    # update resume from checkpoint found by exp_manager
-    resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
+    resume_from_checkpoint = trainer._checkpoint_connector._ckpt_path
    logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

    trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)
@@ -64,7 +64,7 @@ def main(cfg) -> None:
    exp_manager(trainer, cfg.exp_manager)

    # update resume from checkpoint found by exp_manager
-    resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
+    resume_from_checkpoint = trainer._checkpoint_connector._ckpt_path
    # resume_from_checkpoint = uninject_model_parallel_rank(resume_from_checkpoint)
    logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

@@ -67,7 +67,7 @@ def main(cfg) -> None:
    exp_manager(trainer, cfg.exp_manager)

    # update resume from checkpoint found by exp_manager
-    resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
+    resume_from_checkpoint = trainer._checkpoint_connector._ckpt_path
    # resume_from_checkpoint = uninject_model_parallel_rank(resume_from_checkpoint)
    logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

@@ -67,7 +67,7 @@ def main(cfg) -> None:
    if cfg.model.resume_from_checkpoint is not None:
        resume_from_checkpoint = cfg.model.resume_from_checkpoint
    else:
-        resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
+        resume_from_checkpoint = trainer._checkpoint_connector._ckpt_path
    logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

    trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)
2 changes: 1 addition & 1 deletion examples/nlp/language_modeling/megatron_t5_pretraining.py
@@ -67,7 +67,7 @@ def main(cfg) -> None:
    if cfg.model.resume_from_checkpoint is not None:
        resume_from_checkpoint = cfg.model.resume_from_checkpoint
    else:
-        resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
+        resume_from_checkpoint = trainer._checkpoint_connector._ckpt_path
    logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

    trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)
@@ -166,7 +166,7 @@ def main(cfg) -> None:
    if cfg.model.resume_from_checkpoint is not None:
        resume_from_checkpoint = cfg.model.resume_from_checkpoint
    else:
-        resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
+        resume_from_checkpoint = trainer._checkpoint_connector._ckpt_path
    logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

    trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)
@@ -7,7 +7,7 @@ trainer:
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
-  replace_sampler_ddp: False
+  use_distributed_sampler: False
  max_epochs: -1
  max_steps: 100 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10
@@ -7,7 +7,7 @@ trainer:
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
-  replace_sampler_ddp: False
+  use_distributed_sampler: False
  max_epochs: -1
  max_steps: 100 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10
@@ -7,7 +7,7 @@ trainer:
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
-  replace_sampler_ddp: False
+  use_distributed_sampler: False
  max_epochs: 9999
  max_steps: 20000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10 # frequency with which training steps are logged
@@ -7,7 +7,7 @@ trainer:
  precision: 16
  logger: False
  enable_checkpointing: False
-  replace_sampler_ddp: False
+  use_distributed_sampler: False
  max_epochs: -1
  max_steps: 100
  log_every_n_steps: 10
@@ -7,7 +7,7 @@ trainer:
  precision: 16
  logger: False
  enable_checkpointing: False
-  replace_sampler_ddp: False
+  use_distributed_sampler: False
  max_epochs: -1
  max_steps: 100
  log_every_n_steps: 10
2 changes: 1 addition & 1 deletion examples/nlp/language_modeling/tuning/megatron_gpt_sft.py
@@ -162,7 +162,7 @@ def main(cfg) -> None:
    if cfg.model.resume_from_checkpoint is not None:
        resume_from_checkpoint = cfg.model.resume_from_checkpoint
    else:
-        resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
+        resume_from_checkpoint = trainer._checkpoint_connector._ckpt_path
    logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

    trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)
@@ -12,7 +12,7 @@ trainer:
  accelerator: gpu
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
-  replace_sampler_ddp: False
+  use_distributed_sampler: False
  max_epochs: 1000 # PTL default. In practice, max_steps will be reached first.
  max_steps: 400000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10
2 changes: 1 addition & 1 deletion examples/nlp/machine_translation/megatron_nmt_training.py
@@ -73,7 +73,7 @@ def main(cfg) -> None:
    if cfg.model.resume_from_checkpoint is not None:
        resume_from_checkpoint = cfg.model.resume_from_checkpoint
    else:
-        resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
+        resume_from_checkpoint = trainer._checkpoint_connector._ckpt_path
    logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

    trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)
2 changes: 1 addition & 1 deletion examples/tts/vits.py
@@ -21,7 +21,7 @@

@hydra_runner(config_path="conf", config_name="vits")
def main(cfg):
-    trainer = pl.Trainer(replace_sampler_ddp=False, **cfg.trainer)
+    trainer = pl.Trainer(use_distributed_sampler=False, **cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    model = VitsModel(cfg=cfg.model, trainer=trainer)

2 changes: 1 addition & 1 deletion nemo/collections/asr/metrics/multi_binary_acc.py
@@ -38,7 +38,7 @@ def validation_step(self, batch, batch_idx):
        f1_acc = self._accuracy.compute()
        return {'val_loss': loss, 'val_f1_acc': f1_acc}
-    def validation_epoch_end(self, outputs):
+    def on_validation_epoch_end(self, outputs):
        ...
        val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
        correct_counts = torch.stack([x['val_correct_counts'] for x in outputs]).sum(axis=0)
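
Note that in Lightning 2.0 the on_validation_epoch_end hook itself no longer receives an outputs argument; a common migration accumulates the step outputs on the module instead. A minimal sketch under that assumption (the validation_step_outputs buffer and the dummy loss are illustrative, not taken from this commit):

import torch
import pytorch_lightning as pl

class ExampleModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.validation_step_outputs = []  # buffer for per-step results

    def validation_step(self, batch, batch_idx):
        loss = torch.zeros(())  # dummy loss; a real model would compute it from `batch`
        out = {'val_loss': loss}
        self.validation_step_outputs.append(out)
        return out

    def on_validation_epoch_end(self):
        # Aggregate what validation_step collected, then clear the buffer for the next epoch.
        val_loss_mean = torch.stack([x['val_loss'] for x in self.validation_step_outputs]).mean()
        self.log('val_loss', val_loss_mean)
        self.validation_step_outputs.clear()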
2 changes: 1 addition & 1 deletion nemo/collections/asr/metrics/rnnt_wer.py
@@ -1147,7 +1147,7 @@ def validation_step(self, batch, batch_idx):
        wer_num, wer_denom = self.__wer(predictions, transcript, transcript_len)
        return {'val_loss': loss_value, 'val_wer_num': wer_num, 'val_wer_denom': wer_denom}
-    def validation_epoch_end(self, outputs):
+    def on_validation_epoch_end(self, outputs):
        ...
        wer_num = torch.stack([x['val_wer_num'] for x in outputs]).sum()
        wer_denom = torch.stack([x['val_wer_denom'] for x in outputs]).sum()
2 changes: 1 addition & 1 deletion nemo/collections/asr/metrics/rnnt_wer_bpe.py
@@ -326,7 +326,7 @@ def validation_step(self, batch, batch_idx):
        wer_num, wer_denom = self.__wer(predictions, transcript, transcript_len)
        return {'val_loss': loss_value, 'val_wer_num': wer_num, 'val_wer_denom': wer_denom}
-    def validation_epoch_end(self, outputs):
+    def on_validation_epoch_end(self, outputs):
        ...
        wer_num = torch.stack([x['val_wer_num'] for x in outputs]).sum()
        wer_denom = torch.stack([x['val_wer_denom'] for x in outputs]).sum()
2 changes: 1 addition & 1 deletion nemo/collections/asr/metrics/wer.py
@@ -1090,7 +1090,7 @@ def validation_step(self, batch, batch_idx):
        wer_num, wer_denom = self.__wer(predictions, transcript, transcript_len)
        return {'val_loss': loss_value, 'val_wer_num': wer_num, 'val_wer_denom': wer_denom}
-    def validation_epoch_end(self, outputs):
+    def on_validation_epoch_end(self, outputs):
        ...
        wer_num = torch.stack([x['val_wer_num'] for x in outputs]).sum()
        wer_denom = torch.stack([x['val_wer_denom'] for x in outputs]).sum()
2 changes: 1 addition & 1 deletion nemo/collections/asr/metrics/wer_bpe.py
@@ -215,7 +215,7 @@ def validation_step(self, batch, batch_idx):
        wer_num, wer_denom = self.__wer(predictions, transcript, transcript_len)
        return {'val_loss': loss_value, 'val_wer_num': wer_num, 'val_wer_denom': wer_denom}
-    def validation_epoch_end(self, outputs):
+    def on_validation_epoch_end(self, outputs):
        ...
        wer_num = torch.stack([x['val_wer_num'] for x in outputs]).sum()
        wer_denom = torch.stack([x['val_wer_denom'] for x in outputs]).sum()
@@ -561,7 +561,7 @@ def setup(self, stage=None):
            f'Total number of model parameters: {total_num_parameters:.2e}.'
        )

-        resume_checkpoint_path = self.trainer._checkpoint_connector.resume_from_checkpoint_fit_path
+        resume_checkpoint_path = self.trainer._checkpoint_connector._ckpt_path
        if resume_checkpoint_path:
            init_consumed_samples = self._extract_consumed_samples_from_ckpt(resume_checkpoint_path)
        else:
@@ -807,7 +807,7 @@ def setup(self, stage=None):
            f'Total number of model parameters: {total_num_parameters:.2e}.'
        )

-        resume_checkpoint_path = self.trainer._checkpoint_connector.resume_from_checkpoint_fit_path
+        resume_checkpoint_path = self.trainer._checkpoint_connector._ckpt_path
        if resume_checkpoint_path:
            init_consumed_samples = self._extract_consumed_samples_from_ckpt(resume_checkpoint_path)
        else:
@@ -141,7 +141,7 @@ def _metrics_require_string2category_map(self):
    def setup(self, stage=None):
        # NOTE: super().__init__ will try and setup train/val/test datasets, but we sidestep this using a if self._train_ds is not None condition
        # We then set things up for real only once setup() of this class is called.
-        resume_checkpoint_path = self.trainer._checkpoint_connector.resume_from_checkpoint_fit_path
+        resume_checkpoint_path = self.trainer._checkpoint_connector._ckpt_path
        if resume_checkpoint_path:
            init_consumed_samples = self._extract_consumed_samples_from_ckpt(resume_checkpoint_path)
        else: