diff --git a/train.py b/train.py index 9b64b8822e..84f9745772 100644 --- a/train.py +++ b/train.py @@ -81,21 +81,21 @@ def parse_args(): # System args training_group.add_argument('--device', default='cuda', type=str) - training_group.add_argument('--dtype', default='float16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16', type=str) - training_group.add_argument('--compile', type=bool, default=False) + training_group.add_argument("--dtype", type=str, default="bfloat16", choices=["bfloat16", "float16", "float32"], help="torch data type for inference, e.g. 'int8'") + training_group.add_argument('--compile', default=False, action=argparse.BooleanOptionalAction) # Logging args logging_group.add_argument('--log_project', default='out-test', type=str) logging_group.add_argument('--log_run_name', default='logs-test', type=str) # Tensorboard args - logging_group.add_argument('--tensorboard_log', type=bool, default=True) + logging_group.add_argument('--tensorboard_log', default=True, action=argparse.BooleanOptionalAction) logging_group.add_argument('--tensorboard_log_dir', type=str, default='logs') logging_group.add_argument('--tensorboard_project', type=str, default='out-test') logging_group.add_argument('--tensorboard_run_name', type=str, default='logs-test') # Wandb args - logging_group.add_argument('--wandb_log', type=bool, default=False) + logging_group.add_argument('--wandb_log', default=False, action=argparse.BooleanOptionalAction) logging_group.add_argument('--wandb_project', type=str, default='out-test') logging_group.add_argument('--wandb_run_name', type=str, default='logs-test') @@ -140,7 +140,7 @@ def setup(self): torch.backends.cudnn.allow_tf32 = True self.device_type = 'cuda' if 'cuda' in self.args.device else 'cpu' - self.ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[self.args.dtype] + self.ptdtype = {"bfloat16" : torch.bfloat16, "float16" : torch.float16, "float32" : torch.float32}[self.args.dtype] self.ctx = nullcontext() if self.device_type == 'cpu' else torch.amp.autocast(device_type=self.device_type, dtype=self.ptdtype) # Data loader