Generalise BERT_pytorch's device option (#1801)
Summary:
Generalises the interface to the `BERT_pytorch` model so that you can use arbitrary devices, not just CPU or CUDA.

Pull Request resolved: #1801

Reviewed By: FindHao

Differential Revision: D47954195

Pulled By: xuzhao9

fbshipit-source-id: 51092614052a57c1c22e29b697a9aef2a8aa285b
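
For context, a minimal usage sketch (not part of the commit): assuming the TorchBench convention that the changed `__init__.py` exposes a `Model` class with the constructor signature shown in the diff below, the benchmark can now be handed any torch device string rather than only "cpu" or "cuda".

```python
# Sketch under the assumptions above; `Model` and its import path follow TorchBench conventions.
from torchbenchmark.models.BERT_pytorch import Model

# Previously only CPU vs. CUDA was distinguished (args.with_cuda); now the device
# string is forwarded as-is, so other backends such as "mps" can be requested too.
model = Model(test="eval", device="cpu")  # or "cuda", "mps", ... depending on hardware
```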
hmellor authored and facebook-github-bot committed Aug 1, 2023
1 parent 673279d commit 00c01df
Showing 4 changed files with 14 additions and 11 deletions.
torchbenchmark/models/BERT_pytorch/__init__.py (4 changes: 2 additions & 2 deletions)
@@ -86,7 +86,7 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
'--vocab_path', f'{root}/data/vocab.small',
'--output_path', 'bert.model',
]) # Avoid reading sys.argv here
- args.with_cuda = self.device == 'cuda'
+ args.device = self.device
args.script = False
args.on_memory = True

@@ -147,7 +147,7 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):

trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader,
lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
- with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq, debug=args.debug)
+ device=args.device, device_ids=args.device_ids, log_freq=args.log_freq, debug=args.debug)

if test == "eval":
bert.eval()
torchbenchmark/models/BERT_pytorch/bert_pytorch/__init__.py (8 changes: 6 additions & 2 deletions)
@@ -20,10 +20,10 @@ def parse_args(args=None):
parser.add_argument("-e", "--epochs", type=int, default=10, help="number of epochs")
parser.add_argument("-w", "--num_workers", type=int, default=0, help="dataloader worker size")

parser.add_argument("--with_cuda", type=bool, default=True, help="training with CUDA: true, or false")
parser.add_argument("--device", default=0, help="Device to use for training, str or int (CUDA only)")
parser.add_argument("--log_freq", type=int, default=10, help="printing loss every n iter: setting n")
parser.add_argument("--corpus_lines", type=int, default=None, help="total number of lines in corpus")
parser.add_argument("--cuda_devices", type=int, nargs='+', default=None, help="CUDA device ids")
parser.add_argument("--device_ids", nargs='+', default=None, help="Device ids, str or int (CUDA only)")
parser.add_argument("--on_memory", type=bool, default=True, help="Loading on memory: true or false")

parser.add_argument("--lr", type=float, default=1e-3, help="learning rate of adam")
@@ -32,4 +32,8 @@
parser.add_argument("--adam_beta2", type=float, default=0.999, help="adam first beta value")

parsed_args = parser.parse_args(args)
+ if isinstance(parsed_args.device, str) and parsed_args.device.isdigit():
+     parsed_args.device = int(parsed_args.device)
+ if isinstance(parsed_args.device_ids, str) and parsed_args.device_ids.isdigit():
+     parsed_args.device_ids = int(parsed_args.device_ids)
return parsed_args
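
The `isdigit()` normalisation above means a bare index such as `--device 0` is converted to an `int`, which `torch.device()` interprets as a CUDA ordinal, while device strings like "cpu" or "mps" pass through unchanged. A standalone sketch of that behaviour (hypothetical values, not part of the diff):

```python
import torch

for raw in ("0", "cpu", "cuda:1", "mps"):
    # Same normalisation parse_args now applies to --device (and --device_ids).
    value = int(raw) if raw.isdigit() else raw
    print(raw, "->", torch.device(value))
# "0"   -> cuda:0 (an int is treated as a CUDA ordinal)
# "cpu" -> cpu, "cuda:1" -> cuda:1, "mps" -> mps
```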
(third changed file)
@@ -49,7 +49,7 @@ def train():
print("Creating BERT Trainer")
trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader,
lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
- with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq, debug=args.debug)
+ device=args.device, device_ids=args.device_ids, log_freq=args.log_freq, debug=args.debug)

print("Training Start")
for epoch in range(args.epochs):
(fourth changed file)
@@ -21,7 +21,7 @@ class BERTTrainer:
def __init__(self, bert: BERT, vocab_size: int,
train_dataloader: DataLoader, test_dataloader: DataLoader = None,
lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
- with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, debug: str = None):
+ device: str = "cuda", device_ids=None, log_freq: int = 10, debug: str = None):
"""
:param bert: BERT model which you want to train
:param vocab_size: total word vocab size
@@ -30,22 +30,21 @@ def __init__(self, bert: BERT, vocab_size: int,
:param lr: learning rate of optimizer
:param betas: Adam optimizer betas
:param weight_decay: Adam optimizer weight decay param
- :param with_cuda: traning with cuda
+ :param device: device to use for training
:param log_freq: logging frequency of the batch iteration
"""

# Setup cuda device for BERT training, argument -c, --cuda should be true
- cuda_condition = torch.cuda.is_available() and with_cuda
- self.device = torch.device("cuda:0" if cuda_condition else "cpu")
+ self.device = torch.device(device)

# This BERT model will be saved every epoch
self.bert = bert
# Initialize the BERT Language Model, with BERT model
self.model = BERTLM(bert, vocab_size).to(self.device)

# Distributed GPU training if CUDA can detect more than 1 GPU
- if with_cuda and torch.cuda.device_count() > 1:
-     self.model = nn.DataParallel(self.model, device_ids=cuda_devices)
+ if self.device.type == "cuda" and torch.cuda.device_count() > 1:
+     self.model = nn.DataParallel(self.model, device_ids=device_ids)

# Setting the train and test data loader
self.train_data = train_dataloader
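
In short, the trainer now constructs its device directly from the `device` argument and only wraps the model in `DataParallel` when that device is CUDA and more than one GPU is visible. A minimal sketch of the same pattern (stand-in model and hypothetical device choice, not the trainer itself):

```python
import torch
import torch.nn as nn

device = torch.device("cpu")        # any torch device works here, e.g. "cuda:0" or "mps"
model = nn.Linear(8, 8).to(device)  # stand-in for BERTLM(bert, vocab_size)

# Mirror the updated trainer logic: DataParallel only for multi-GPU CUDA.
if device.type == "cuda" and torch.cuda.device_count() > 1:
    model = nn.DataParallel(model, device_ids=None)  # device_ids would come from args
```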
