Commit 54f60c3
Set amp_o2 to false
Signed-off-by: Virginia Adams <vadams@nvidia.com>
vadam5 committed Jun 8, 2022
1 parent 1986b59 commit 54f60c3
Showing 2 changed files with 12 additions and 16 deletions.
@@ -8,8 +8,8 @@ trainer:
   logger: False # logger provided by exp_manager
   enable_checkpointing: False
   replace_sampler_ddp: False
-  max_epochs: null
-  max_steps: 6000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
+  max_epochs: 3
+  max_steps: null # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
   log_every_n_steps: 10
   val_check_interval: 1.0
   accumulate_grad_batches: 1
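Note on the swapped trainer fields: the inline comment on max_steps defines consumed_samples as global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches. A minimal sketch of that arithmetic, using assumed example batch sizes rather than values from this config:

# Worked example of the consumed_samples formula from the config comment.
# micro_batch_size and data_parallel_size below are assumed for illustration.
global_step = 6000            # the old max_steps value in this diff
micro_batch_size = 4          # assumed
data_parallel_size = 2        # assumed
accumulate_grad_batches = 1   # matches the config above

consumed_samples = (
    global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
)
print(consumed_samples)  # 48000 samples consumed after 6000 optimizer steps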
@@ -91,6 +91,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
 
         # Need to overwrite some params in frozen model's config before restoring
         with open_dict(frozen_model_cfg):
+            frozen_model_cfg.megatron_amp_O2 = False
             frozen_model_cfg.micro_batch_size = self.cfg.micro_batch_size
             frozen_model_cfg.global_batch_size = self.cfg.global_batch_size
             frozen_model_cfg.precision = trainer.precision
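The new megatron_amp_O2 override sits inside open_dict because OmegaConf configs are usually in struct mode, which rejects keys that are not already present; open_dict lifts that restriction for the duration of the block. A self-contained sketch of the pattern, with a toy config standing in for the real frozen_model_cfg:

from omegaconf import OmegaConf, open_dict

# Assumed toy stand-in for frozen_model_cfg.
frozen_model_cfg = OmegaConf.create({"micro_batch_size": 4})
OmegaConf.set_struct(frozen_model_cfg, True)  # struct mode: new keys normally rejected

with open_dict(frozen_model_cfg):
    # Inside open_dict, setting a previously missing key is allowed.
    frozen_model_cfg.megatron_amp_O2 = False

print(frozen_model_cfg.megatron_amp_O2)  # False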
@@ -104,23 +105,18 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
             override_config_path=frozen_model_cfg,
         )
 
-        self.float_type = torch.float
+        if self.frozen_model.cfg.precision == 16:
+            self.float_type = torch.float16
+        elif self.frozen_model.cfg.precision == 'bf16':
+            self.float_type = torch.bfloat16
+        else:
+            self.float_type = torch.float
 
-        # Make prompt learning model able to load gpt models trained with amp_o2
-        if self.frozen_model.megatron_amp_o2:
-            self.frozen_model.model = self.frozen_model.model.module
 
-        if self.frozen_model.cfg.precision == 16:
-            self.float_type = torch.float16
-            raise ValueError(
-                "fp16 training is not yet supported with O2. Please set megatron_amp_O2 to False in the model config."
-            )
-        elif self.frozen_model.cfg.precision == 'bf16':
-            self.float_type = torch.bfloat16
-        else:
-            raise ValueError(f'Precision {self.frozen_model.precision} Not supported with O2')
+        # if self.frozen_model.megatron_amp_o2:
+        #     self.frozen_model.model = self.frozen_model.model.module
 
-        self.megatron_amp_o2 = self.frozen_model.cfg.get('megatron_amp_O2', False)
+        self.megatron_amp_o2 = False
         self.tokenizer = self.frozen_model.tokenizer
         self.hidden_size = self.frozen_model.cfg.hidden_size
         self.existing_tasks = list(self.cfg.get('existing_tasks', []))
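Net effect of the last hunk: float_type is now derived only from the frozen model's precision setting, and the O2-specific module unwrapping and errors are removed. A standalone sketch of the new mapping, with the helper name chosen here purely for illustration:

import torch

def float_type_for(precision):
    # Mirrors the added branch: 16 -> float16, 'bf16' -> bfloat16, else float32.
    if precision == 16:
        return torch.float16
    elif precision == 'bf16':
        return torch.bfloat16
    return torch.float

assert float_type_for(16) is torch.float16
assert float_type_for('bf16') is torch.bfloat16
assert float_type_for(32) is torch.float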
