[Mod] multinode.py minor changes
nicolasrosa committed Mar 14, 2024
1 parent c5bac2b commit a4ad613
Showing 1 changed file with 10 additions and 2 deletions.
distributed/ddp-tutorial-series/multinode.py: 12 changes (10 additions & 2 deletions)
@@ -2,18 +2,22 @@
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datautils import MyTrainDataset
from utils import print_nodes_info

# --- Additional modules required for Distributed Training
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
# ---


def ddp_setup():
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))


class Trainer:
    def __init__(
        self,
@@ -53,7 +57,7 @@ def _run_batch(self, source, targets):

    def _run_epoch(self, epoch):
        b_sz = len(next(iter(self.train_data))[0])
        print(f"[GPU{self.global_rank}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        print(f"[GPU{self.global_rank}] Epoch {epoch} | Batch size: {b_sz} | Steps: {len(self.train_data)}")
        self.train_data.sampler.set_epoch(epoch)
        for source, targets in self.train_data:
            source = source.to(self.local_rank)
@@ -108,5 +112,9 @@ def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str
    parser.add_argument('save_every', type=int, help='How often to save a snapshot')
    parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
    args = parser.parse_args()


    # --- Print the environment variables
    print_nodes_info()
    # ---

    main(args.save_every, args.total_epochs, args.batch_size)
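
Note: the new import and the `print_nodes_info()` call refer to a `utils` module that is not part of this diff, so its exact contents are unknown. As a rough sketch of what such a helper might look like, the snippet below prints the distributed-training environment variables that torchrun sets for each process; the function name matches the import, but the body is an assumption, not the actual `utils.py`.

# utils.py (hypothetical sketch; the real file is not shown in this commit)
import os
import socket

def print_nodes_info() -> None:
    """Print the distributed environment variables for the current process."""
    # torchrun populates these for every worker; report 'unset' if launched differently.
    keys = ["MASTER_ADDR", "MASTER_PORT", "RANK", "LOCAL_RANK", "WORLD_SIZE"]
    info = " | ".join(f"{k}={os.environ.get(k, 'unset')}" for k in keys)
    print(f"[{socket.gethostname()}] {info}")

With a multi-node launch such as `torchrun --nnodes=2 --nproc_per_node=4 --rdzv_backend=c10d --rdzv_endpoint=<host>:<port> multinode.py 50 10` (the positional arguments being `total_epochs` and `save_every`), each process would report its own rank information before `main()` starts training; `ddp_setup()` then relies on the same variables when it calls `init_process_group` and reads `LOCAL_RANK`.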
