diff --git a/torchbenchmark/models/BERT_pytorch/__init__.py b/torchbenchmark/models/BERT_pytorch/__init__.py
index 4df8874ddc..3505240570 100644
--- a/torchbenchmark/models/BERT_pytorch/__init__.py
+++ b/torchbenchmark/models/BERT_pytorch/__init__.py
@@ -86,7 +86,7 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
             '--vocab_path', f'{root}/data/vocab.small',
             '--output_path', 'bert.model',
         ]) # Avoid reading sys.argv here
-        args.with_cuda = self.device == 'cuda'
+        args.device = self.device
         args.script = False
         args.on_memory = True

@@ -147,7 +147,7 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):

         trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader,
                               lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
-                              with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq, debug=args.debug)
+                              device=args.device, device_ids=args.device_ids, log_freq=args.log_freq, debug=args.debug)

         if test == "eval":
             bert.eval()
diff --git a/torchbenchmark/models/BERT_pytorch/bert_pytorch/__init__.py b/torchbenchmark/models/BERT_pytorch/bert_pytorch/__init__.py
index 1b8f65ebb2..fd2f6e802c 100644
--- a/torchbenchmark/models/BERT_pytorch/bert_pytorch/__init__.py
+++ b/torchbenchmark/models/BERT_pytorch/bert_pytorch/__init__.py
@@ -20,10 +20,10 @@ def parse_args(args=None):
     parser.add_argument("-e", "--epochs", type=int, default=10, help="number of epochs")
     parser.add_argument("-w", "--num_workers", type=int, default=0, help="dataloader worker size")

-    parser.add_argument("--with_cuda", type=bool, default=True, help="training with CUDA: true, or false")
+    parser.add_argument("--device", default=0, help="Device to use for training, str or int (CUDA only)")
     parser.add_argument("--log_freq", type=int, default=10, help="printing loss every n iter: setting n")
     parser.add_argument("--corpus_lines", type=int, default=None, help="total number of lines in corpus")
-    parser.add_argument("--cuda_devices", type=int, nargs='+', default=None, help="CUDA device ids")
+    parser.add_argument("--device_ids", nargs='+', default=None, help="Device ids, str or int (CUDA only)")
     parser.add_argument("--on_memory", type=bool, default=True, help="Loading on memory: true or false")

     parser.add_argument("--lr", type=float, default=1e-3, help="learning rate of adam")
@@ -32,4 +32,8 @@
     parser.add_argument("--adam_beta2", type=float, default=0.999, help="adam first beta value")

     parsed_args = parser.parse_args(args)
+    if isinstance(parsed_args.device, str) and parsed_args.device.isdigit():
+        parsed_args.device = int(parsed_args.device)
+    if isinstance(parsed_args.device_ids, str) and parsed_args.device_ids.isdigit():
+        parsed_args.device_ids = int(parsed_args.device_ids)
     return parsed_args
diff --git a/torchbenchmark/models/BERT_pytorch/bert_pytorch/__main__.py b/torchbenchmark/models/BERT_pytorch/bert_pytorch/__main__.py
index 1e91d38e24..35a1fd4d3e 100644
--- a/torchbenchmark/models/BERT_pytorch/bert_pytorch/__main__.py
+++ b/torchbenchmark/models/BERT_pytorch/bert_pytorch/__main__.py
@@ -49,7 +49,7 @@ def train():
     print("Creating BERT Trainer")
     trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader,
                           lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
-                          with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq, debug=args.debug)
+                          device=args.device, device_ids=args.device_ids, log_freq=args.log_freq, debug=args.debug)

     print("Training Start")
     for epoch in range(args.epochs):
diff --git a/torchbenchmark/models/BERT_pytorch/bert_pytorch/trainer/pretrain.py b/torchbenchmark/models/BERT_pytorch/bert_pytorch/trainer/pretrain.py
index 67840ef54e..57f3168b8f 100644
--- a/torchbenchmark/models/BERT_pytorch/bert_pytorch/trainer/pretrain.py
+++ b/torchbenchmark/models/BERT_pytorch/bert_pytorch/trainer/pretrain.py
@@ -21,7 +21,7 @@ class BERTTrainer:
     def __init__(self, bert: BERT, vocab_size: int,
                  train_dataloader: DataLoader, test_dataloader: DataLoader = None,
                  lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
-                 with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, debug: str = None):
+                 device: str = "cuda", device_ids=None, log_freq: int = 10, debug: str = None):
         """
         :param bert: BERT model which you want to train
         :param vocab_size: total word vocab size
@@ -30,13 +30,12 @@ def __init__(self, bert: BERT, vocab_size: int,
         :param lr: learning rate of optimizer
         :param betas: Adam optimizer betas
         :param weight_decay: Adam optimizer weight decay param
-        :param with_cuda: traning with cuda
+        :param device: device to use for training
         :param log_freq: logging frequency of the batch iteration
         """

         # Setup cuda device for BERT training, argument -c, --cuda should be true
-        cuda_condition = torch.cuda.is_available() and with_cuda
-        self.device = torch.device("cuda:0" if cuda_condition else "cpu")
+        self.device = torch.device(device)

         # This BERT model will be saved every epoch
         self.bert = bert
@@ -44,8 +43,8 @@ def __init__(self, bert: BERT, vocab_size: int,
         self.model = BERTLM(bert, vocab_size).to(self.device)

         # Distributed GPU training if CUDA can detect more than 1 GPU
-        if with_cuda and torch.cuda.device_count() > 1:
-            self.model = nn.DataParallel(self.model, device_ids=device_ids)
+        if self.device.type == "cuda" and torch.cuda.device_count() > 1:
+            self.model = nn.DataParallel(self.model, device_ids=device_ids)

         # Setting the train and test data loader
         self.train_data = train_dataloader
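Not part of the patch: a minimal sketch of the device handling the trainer change above introduces, assuming only torch is installed; a trivial nn.Linear stands in for BERTLM and the dataloaders are elided.

# Sketch (assumption, not from the patch): the caller now passes a device
# string/ordinal, and DataParallel is engaged only for multi-GPU CUDA runs.
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(8, 8).to(device)  # stand-in for BERTLM(bert, vocab_size)
if device.type == "cuda" and torch.cuda.device_count() > 1:
    # device_ids=None lets DataParallel use all visible GPUs, matching the new default.
    model = nn.DataParallel(model, device_ids=None)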