From a4ad6132566a31ceb7c82b2c293768f2cdf23c94 Mon Sep 17 00:00:00 2001
From: nicolasrosa
Date: Thu, 14 Mar 2024 14:42:10 -0300
Subject: [PATCH] [Mod] multinode.py minor changes

---
 distributed/ddp-tutorial-series/multinode.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/distributed/ddp-tutorial-series/multinode.py b/distributed/ddp-tutorial-series/multinode.py
index e80636bcc4..893989931b 100644
--- a/distributed/ddp-tutorial-series/multinode.py
+++ b/distributed/ddp-tutorial-series/multinode.py
@@ -2,18 +2,22 @@
 import torch.nn.functional as F
 from torch.utils.data import Dataset, DataLoader
 from datautils import MyTrainDataset
+from utils import print_nodes_info
 
+# --- Additional modules required for Distributed Training
 import torch.multiprocessing as mp
 from torch.utils.data.distributed import DistributedSampler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed import init_process_group, destroy_process_group
 import os
+# ---
 
 
 def ddp_setup():
     init_process_group(backend="nccl")
     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
 
+
 class Trainer:
     def __init__(
         self,
@@ -53,7 +57,7 @@ def _run_batch(self, source, targets):
 
     def _run_epoch(self, epoch):
         b_sz = len(next(iter(self.train_data))[0])
-        print(f"[GPU{self.global_rank}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
+        print(f"[GPU{self.global_rank}] Epoch {epoch} | Batch size: {b_sz} | Steps: {len(self.train_data)}")
         self.train_data.sampler.set_epoch(epoch)
         for source, targets in self.train_data:
             source = source.to(self.local_rank)
@@ -108,5 +112,9 @@ def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str
     parser.add_argument('save_every', type=int, help='How often to save a snapshot')
     parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
     args = parser.parse_args()
-    
+
+    # --- Print the environment variables
+    print_nodes_info()
+    # ---
+
     main(args.save_every, args.total_epochs, args.batch_size)
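
Note: print_nodes_info comes from a local utils module that is not included
in this patch, so its body is not shown here. A minimal sketch of what such
a helper might look like, assuming it simply echoes the rendezvous
environment variables that torchrun exports to every worker (the body below
is an illustration, not the author's implementation):

    # utils.py -- hypothetical sketch; the real module is not in this patch
    import os

    def print_nodes_info():
        """Print the distributed-training environment variables of this process."""
        # torchrun sets these for each spawned worker process
        for var in ("MASTER_ADDR", "MASTER_PORT", "RANK", "LOCAL_RANK",
                    "WORLD_SIZE", "LOCAL_WORLD_SIZE"):
            # os.environ.get keeps the helper usable outside torchrun as well
            print(f"{var}={os.environ.get(var, '<unset>')}")

Called once per process right after argument parsing (as the last hunk does),
this prints each worker's rank/world-size view, which can help verify that
all nodes joined the job before training starts.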