Skip to content

Commit

Permalink
[Mod] multigpu.py debug print added
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolasrosa committed Mar 14, 2024
1 parent a4ad613 commit 1e81da3
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions distributed/ddp-tutorial-series/multigpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datautils import MyTrainDataset
from icecream import ic

# --- Additional modules required for Distributed Training
import torch.multiprocessing as mp
Expand Down Expand Up @@ -86,6 +87,8 @@ def prepare_dataloader(dataset: Dataset, batch_size: int):


def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_size: int):
ic(rank, world_size)

ddp_setup(rank, world_size)
dataset, model, optimizer = load_train_objs()
train_data = prepare_dataloader(dataset, batch_size)
Expand Down

0 comments on commit 1e81da3

Please sign in to comment.