Skip to content

Commit

Permalink
Update training script
Browse files Browse the repository at this point in the history
Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
  • Loading branch information
MaximumEntropy committed Jun 18, 2022
1 parent 5299acd commit 7a7ad85
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions examples/nlp/machine_translation/megatron_nmt_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,12 @@ def main(cfg) -> None:
pretrained_cfg.train_ds = cfg.model.train_ds
pretrained_cfg.train_ds.micro_batch_size = cfg.model.micro_batch_size
pretrained_cfg.train_ds.global_batch_size = cfg.model.global_batch_size
pretrained_cfg.validation_ds = cfg.model.validation_ds
pretrained_cfg.test_ds = cfg.model.test_ds
if hasattr(cfg.model, 'validation_ds'):
pretrained_cfg.validation_ds = cfg.model.validation_ds
else:
raise AttributeError(f"No validation dataset found in config.")
if hasattr(cfg.model, 'test_ds'):
pretrained_cfg.test_ds = cfg.model.test_ds

# Class target for the new class being restored.
pretrained_cfg.target = (
Expand Down

0 comments on commit 7a7ad85

Please sign in to comment.