
Commit

updated script for DDP
bhavya-work committed Nov 21, 2024
1 parent 8046d9f commit a522f0e
Showing 1 changed file with 40 additions and 43 deletions.
83 changes: 40 additions & 43 deletions references/detection/train_pytorch_DDP.py
@@ -9,20 +9,21 @@

import datetime
import hashlib
import multiprocessing
import time
import numpy as np
import torch
import wandb
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR, PolynomialLR
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torchvision.transforms.v2 import Compose, GaussianBlur, Normalize, RandomGrayscale, RandomPhotometricDistort
from tqdm.auto import tqdm

# The following import is required for DDP
import torch.distributed as dist
import torch.distributed as dist
import torch.multiprocessing as mp
import wandb
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR, PolynomialLR
from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from torchvision.transforms.v2 import Compose, GaussianBlur, Normalize, RandomGrayscale, RandomPhotometricDistort
from tqdm.auto import tqdm

from doctr import transforms as T
from doctr.datasets import DetectionDataset
@@ -137,7 +138,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a


@torch.no_grad()
def evaluate(model, val_loader, batch_transforms, val_metric,args, amp=False):
def evaluate(model, val_loader, batch_transforms, val_metric, args, amp=False):
# Model in eval mode
model.eval()
# Reset val metric
@@ -170,24 +171,21 @@ def evaluate(model, val_loader, batch_transforms, val_metric,args, amp=False):
return val_loss, recall, precision, mean_iou


def main(rank:int, world_size:int, args):

def main(rank: int, world_size: int, args):

[Codacy Static Code Analysis notice on references/detection/train_pytorch_DDP.py#L174: main is too complex (26) (MC0001)]
"""
Args:
----
Args:
rank (int): device id to put the model on
world_size (int): number of processes participating in the job
args: other arguments passed through the CLI
"""


print(args)

if rank == 0 and args.push_to_hub:
login_to_hub()

if not isinstance(args.workers, int):
args.workers = min(16, mp.cpu_count())
args.workers = min(16, multiprocessing.cpu_count())

torch.backends.cudnn.benchmark = True

@@ -227,7 +225,7 @@ def main(rank:int, world_size:int, args):
print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in {len(val_loader)} batches)")
with open(os.path.join(args.val_path, "labels.json"), "rb") as f:
val_hash = hashlib.sha256(f.read()).hexdigest()

class_names = val_set.class_names
else:
class_names = None
@@ -248,21 +246,22 @@ def main(rank:int, world_size:int, args):
model.load_state_dict(checkpoint)

# create default process group
device = torch.device('cuda', args.devices[rank])
dist.init_process_group(args.backend, rank = rank, world_size = world_size)
device = torch.device("cuda", args.devices[rank])
dist.init_process_group(args.backend, rank=rank, world_size=world_size)
# create local model
model = model.to(device)
# construct the DDP model
model = DDP(model, device_ids = [device])

# construct the DDP model
model = DDP(model, device_ids=[device])

if rank == 0:
# Metrics
val_metric = LocalizationConfusion(use_polygons=args.rotation and not args.eval_straight)

if rank == 0 and args.test_only:
print("Running evaluation")
val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric,args, amp=args.amp)
val_loss, recall, precision, mean_iou = evaluate(
model, val_loader, batch_transforms, val_metric, args, amp=args.amp
)
print(
f"Validation loss: {val_loss:.6} (Recall: {recall:.2%} | Precision: {precision:.2%} | "
f"Mean IoU: {mean_iou:.2%})"
@@ -286,7 +285,7 @@ def main(rank:int, world_size:int, args):
RandomPhotometricDistort(p=0.3),
lambda x: x, # Identity no transformation
])

# Image + target augmentations
sample_transforms = T.SampleCompose(
(
@@ -312,7 +311,6 @@ def main(rank:int, world_size:int, args):
]
)
)


# Load both train and val data generators
train_set = DetectionDataset(
@@ -328,17 +326,18 @@ def main(rank:int, world_size:int, args):
batch_size=args.batch_size,
drop_last=True,
num_workers=args.workers,
# sampler=(train_set),
sampler = DistributedSampler(train_set, num_replicas = world_size, rank = rank, shuffle = False, drop_last = True),
sampler=DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True),
pin_memory=torch.cuda.is_available(),
collate_fn=train_set.collate_fn,
)
print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in {len(train_loader)} batches) along with DistributedSampler")

print(
f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in {len(train_loader)} batches)"
)

with open(os.path.join(args.train_path, "labels.json"), "rb") as f:
train_hash = hashlib.sha256(f.read()).hexdigest()

if args.show_samples:
if rank == 0 and args.show_samples:
x, target = next(iter(train_loader))
plot_samples(x, target)
# return
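
Aside (not part of the diff): the DataLoader above swaps the single-process RandomSampler for a DistributedSampler, so each rank iterates over its own shard of the training set. The script passes shuffle=False; when shuffling is enabled, the usual pattern is to call sampler.set_epoch(epoch) at the start of every epoch so the shards are re-permuted. A small sketch under that assumption, using a toy dataset:

```python
# Illustrative sketch only; assumes dist.init_process_group() has already run.
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.randn(64, 3), torch.randint(0, 2, (64,)))
sampler = DistributedSampler(
    dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
)
loader = DataLoader(dataset, batch_size=8, sampler=sampler, drop_last=True)

for epoch in range(3):
    # Needed when shuffle=True so each epoch draws a different permutation;
    # with shuffle=False (as in the diff) every rank keeps a fixed shard.
    sampler.set_epoch(epoch)
    for images, targets in loader:
        pass  # training step would go here
```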
@@ -374,7 +373,7 @@ def main(rank:int, world_size:int, args):
exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name

# W&B
if rank==0 and args.wb:
if rank == 0 and args.wb:
run = wandb.init(
name=exp_name,
project="text-detection",
@@ -401,14 +400,15 @@ def main(rank:int, world_size:int, args):
if args.early_stop:
early_stopper = EarlyStopper(patience=args.early_stop_epochs, min_delta=args.early_stop_delta)


# Training loop
for epoch in range(args.epochs):
fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=args.amp)

if rank == 0:
# Validation loop at the end of each epoch
val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric,args, amp=args.amp)
val_loss, recall, precision, mean_iou = evaluate(
model, val_loader, batch_transforms, val_metric, args, amp=args.amp
)
if val_loss < min_loss:
print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
torch.save(model.module.state_dict(), f"./{exp_name}.pt")
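
Aside (not part of the diff): once the model is wrapped in DistributedDataParallel, its parameters live under model.module, so saving model.module.state_dict() as this hunk does produces a checkpoint without the "module." prefix, which loads straight into an unwrapped model. A short sketch of that round trip, with a toy model in place of the detection network:

```python
# Illustrative sketch only; a toy Linear model stands in for the detection model.
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# (inside a rank's process, after init_process_group and moving to the GPU)
ddp_model = DDP(nn.Linear(8, 2).cuda(), device_ids=[torch.cuda.current_device()])

# Saving through .module keeps the checkpoint free of the "module." prefix...
torch.save(ddp_model.module.state_dict(), "checkpoint.pt")

# ...so it reloads directly into a plain copy of the same architecture.
plain_model = nn.Linear(8, 2)
plain_model.load_state_dict(torch.load("checkpoint.pt", map_location="cpu"))
```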
@@ -433,14 +433,13 @@ def main(rank:int, world_size:int, args):
if args.early_stop and early_stopper.early_stop(val_loss):
print("Training halted early due to reaching patience limit.")
break

if rank == 0:
if args.wb:
run.finish()

if args.push_to_hub:
push_to_hf_hub(model, exp_name, task="detection", run_config=args)



def parse_args():
@@ -452,15 +451,15 @@ def parse_args():
)

# DDP related args
parser.add_argument('--backend', default='nccl', type=str, help='backend to use for torch DDP')
parser.add_argument("--backend", default="nccl", type=str, help="backend to use for torch DDP")

parser.add_argument("arch", type=str, help="text-detection model to train")
parser.add_argument("--train_path", type=str, required=True, help="path to training data folder")
parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder")
parser.add_argument("--name", type=str, default=None, help="Name of your training experiment")
parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on")
parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for training")
parser.add_argument("--devices", default=None, nargs='+',type=int, help="GPU devices to use for training")
parser.add_argument("--devices", default=None, nargs="+", type=int, help="GPU devices to use for training")
parser.add_argument(
"--save-interval-epoch", dest="save_interval_epoch", action="store_true", help="Save model every epoch"
)
@@ -506,17 +505,15 @@ def parse_args():
if __name__ == "__main__":
args = parse_args()
if not torch.cuda.is_available():
raise AssertionError('PyTorch cannot access your GPUs. please look into it bro !!!')
raise AssertionError("PyTorch cannot access your GPUs. please look into it bro !!!")

if not isinstance(args.devices, list):
args.devices = list(range(torch.cuda.device_count()))
args.devices = list(range(torch.cuda.device_count()))
# number of processes (one per GPU)
nprocs = len(args.devices)
# Environment variables which need to be
# set when using c10d's default "env"
# initialization mode.
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
mp.spawn(main, args=(nprocs, args), nprocs = nprocs, join=True)


os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
mp.spawn(main, args=(nprocs, args), nprocs=nprocs, join=True)
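
Aside (not part of the diff): the __main__ block relies on c10d's default "env" initialization, so MASTER_ADDR and MASTER_PORT are exported before mp.spawn forks one process per listed GPU; spawn injects the rank as the first argument of main. The sketch below replays the same launch flow with a trivial worker and adds an explicit dist.destroy_process_group() teardown, which is conventional but not shown in this diff:

```python
# Illustrative sketch only; a trivial worker replaces main() from the script.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    # mp.spawn supplies `rank` as the first positional argument.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    print(f"rank {rank}/{world_size} initialized on GPU {rank}")
    dist.destroy_process_group()  # explicit teardown (not shown in the diff)


if __name__ == "__main__":
    # Same "env" rendezvous used by the script's __main__ block.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    nprocs = torch.cuda.device_count()
    mp.spawn(worker, args=(nprocs,), nprocs=nprocs, join=True)
```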
