Skip to content

Commit

Permalink
Add new argparse syntax for boolean values
Browse files Browse the repository at this point in the history
Utilized new syntax for booleans, which allow for setting positive or
negative values on the cli via `no` prefixing:

`--compile` - compiles (True)
`--no-compile` - does not compile (False)

This also allows for setting the default value independently from
presence of the boolean flag.
  • Loading branch information
gkielian committed Nov 8, 2023
1 parent 94a4fcd commit 4b0754b
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,21 +81,21 @@ def parse_args():

# System args
training_group.add_argument('--device', default='cuda', type=str)
training_group.add_argument('--dtype', default='float16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16', type=str)
training_group.add_argument('--compile', type=bool, default=False)
training_group.add_argument("--dtype", type=str, default="bfloat16", choices=["bfloat16", "float16", "float32"], help="torch data type for inference, e.g. 'int8'")
training_group.add_argument('--compile', default=False, action=argparse.BooleanOptionalAction)

# Logging args
logging_group.add_argument('--log_project', default='out-test', type=str)
logging_group.add_argument('--log_run_name', default='logs-test', type=str)

# Tensorboard args
logging_group.add_argument('--tensorboard_log', type=bool, default=True)
logging_group.add_argument('--tensorboard_log', default=True, action=argparse.BooleanOptionalAction)
logging_group.add_argument('--tensorboard_log_dir', type=str, default='logs')
logging_group.add_argument('--tensorboard_project', type=str, default='out-test')
logging_group.add_argument('--tensorboard_run_name', type=str, default='logs-test')

# Wandb args
logging_group.add_argument('--wandb_log', type=bool, default=False)
logging_group.add_argument('--wandb_log', default=False, action=argparse.BooleanOptionalAction)
logging_group.add_argument('--wandb_project', type=str, default='out-test')
logging_group.add_argument('--wandb_run_name', type=str, default='logs-test')

Expand Down Expand Up @@ -140,7 +140,7 @@ def setup(self):
torch.backends.cudnn.allow_tf32 = True

self.device_type = 'cuda' if 'cuda' in self.args.device else 'cpu'
self.ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[self.args.dtype]
self.ptdtype = {"bfloat16" : torch.bfloat16, "float16" : torch.float16, "float32" : torch.float32}[self.args.dtype]
self.ctx = nullcontext() if self.device_type == 'cpu' else torch.amp.autocast(device_type=self.device_type, dtype=self.ptdtype)

# Data loader
Expand Down

0 comments on commit 4b0754b

Please sign in to comment.