Modify trainer.precision check and other small edits
Signed-off-by: Abhishree <abhishreetm@gmail.com>
athitten committed May 1, 2023
1 parent 3901c89 commit 154dead
Showing 3 changed files with 5 additions and 6 deletions.
@@ -147,11 +147,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         else:
             self.model = Float16Module(module=self.model, precision=cfg.precision)

-        if self.trainer.precision == 'bf16':
+        if self.trainer.precision[0:4] == 'bf16':
             self.autocast_dtype = torch.bfloat16
-        elif int(self.trainer.precision) == 32:
+        elif self.trainer.precision[0:2] == '32':
             self.autocast_dtype = torch.float
-        elif int(self.trainer.precision) == 16:
+        elif self.trainer.precision[0:2] == '16':
             self.autocast_dtype = torch.half
         else:
             raise ValueError('precision must be in [32, 16, "bf16"]')
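This change appears to track PyTorch Lightning 2.0, where Trainer.precision is reported as a string such as '16-mixed', 'bf16-mixed', or '32-true' rather than 16, 32, or 'bf16', so the old equality and int(...) checks no longer match. A minimal sketch of the prefix-matching idea (the helper precision_to_dtype and the str() coercion are illustrative additions, not NeMo code):

import torch

def precision_to_dtype(precision):
    """Map a Lightning precision value to an autocast dtype.

    Accepts legacy values (16, 32, 'bf16') as well as the Lightning 2.0
    string forms ('16-mixed', '32-true', 'bf16-mixed') by prefix matching.
    """
    precision = str(precision)
    if precision[0:4] == 'bf16':    # 'bf16', 'bf16-mixed'
        return torch.bfloat16
    elif precision[0:2] == '32':    # '32', '32-true'
        return torch.float
    elif precision[0:2] == '16':    # '16', '16-mixed'
        return torch.half
    raise ValueError('precision must be in [32, 16, "bf16"]')

assert precision_to_dtype('bf16-mixed') is torch.bfloat16
assert precision_to_dtype('16-mixed') is torch.half
assert precision_to_dtype(32) is torch.float  # legacy int still works via str()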
tests/collections/nlp/test_nlp_exportables.py (3 changes: 1 addition & 2 deletions)
@@ -82,10 +82,9 @@ def setup_method(self):
                 "accumulate_grad_batches": 1,
                 "precision": 32,
                 "accelerator": "gpu",
-                "strategy": None,
+                "strategy": 'auto',
                 "log_every_n_steps": 1,
                 "val_check_interval": 1,
-                "resume_from_checkpoint": None,
                 "enable_checkpointing": False,
                 "logger": False,
             },
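Both test-config edits also track Lightning 2.0: strategy=None is no longer a valid Trainer argument (the default is now 'auto'), and the resume_from_checkpoint constructor argument was removed in favor of passing a checkpoint path at fit time. A hedged sketch of the replacement pattern (the model and checkpoint path are placeholders):

import pytorch_lightning as pl

# Lightning < 2.0: pl.Trainer(strategy=None, resume_from_checkpoint='last.ckpt')
trainer = pl.Trainer(accelerator='gpu', strategy='auto', precision=32)

# Resuming now happens per call, not in the Trainer constructor:
# trainer.fit(model, ckpt_path='last.ckpt')  # 'model' and the path are placeholders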
tests/collections/nlp/test_rampup_batch_size.py (2 changes: 1 addition & 1 deletion)
@@ -124,7 +124,7 @@ def trainer_cfg():
         'precision': 16,
         'logger': False,
         'enable_checkpointing': False,
-        'replace_sampler_ddp': False,
+        'use_distributed_sampler': False,
         'max_epochs': 1,
         'max_steps': 150,
         'log_every_n_steps': 10,
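Similarly, Lightning 2.0 renamed the replace_sampler_ddp Trainer flag to use_distributed_sampler; the semantics (whether Lightning wraps the dataloader's sampler with a DistributedSampler) are unchanged. For example:

import pytorch_lightning as pl

# Lightning < 2.0:  pl.Trainer(replace_sampler_ddp=False)
# Lightning >= 2.0:
trainer = pl.Trainer(use_distributed_sampler=False)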
