From 10e8119110ddc3f74c4d8bab7c846dc342c3fb6f Mon Sep 17 00:00:00 2001
From: Aravinda Kumar
Date: Thu, 18 Jul 2024 16:08:16 +0530
Subject: [PATCH] added fsdp mnist example

---
 distributed/{ => fsdp}/FSDP/.gitignore             |   0
 distributed/{ => fsdp}/FSDP/README.md              |   0
 distributed/{ => fsdp}/FSDP/T5_training.py         |   0
 .../{ => fsdp}/FSDP/configs/__init__.py            |   0
 distributed/{ => fsdp}/FSDP/configs/fsdp.py        |   0
 .../{ => fsdp}/FSDP/configs/training.py            |   0
 .../{ => fsdp}/FSDP/download_dataset.sh            |   0
 .../FSDP/model_checkpointing/__init__.py           |   0
 .../model_checkpointing/checkpoint_handler.py      |   0
 .../{ => fsdp}/FSDP/policies/__init__.py           |   0
 .../activation_checkpointing_functions.py          |   0
 .../FSDP/policies/mixed_precision.py               |   0
 .../{ => fsdp}/FSDP/policies/wrapping.py           |   0
 distributed/{ => fsdp}/FSDP/requirements.txt       |   0
 .../{ => fsdp}/FSDP/summarization_dataset.py       |   0
 distributed/{ => fsdp}/FSDP/utils/__init__.py      |   0
 .../{ => fsdp}/FSDP/utils/environment.py           |   0
 .../{ => fsdp}/FSDP/utils/train_utils.py           |   0
 distributed/fsdp/fsdp-mnist/README.md              |   7 +
 distributed/fsdp/fsdp-mnist/fsdp_mnist.py          | 198 ++++++++++++++++++
 20 files changed, 205 insertions(+)
 rename distributed/{ => fsdp}/FSDP/.gitignore (100%)
 rename distributed/{ => fsdp}/FSDP/README.md (100%)
 rename distributed/{ => fsdp}/FSDP/T5_training.py (100%)
 rename distributed/{ => fsdp}/FSDP/configs/__init__.py (100%)
 rename distributed/{ => fsdp}/FSDP/configs/fsdp.py (100%)
 rename distributed/{ => fsdp}/FSDP/configs/training.py (100%)
 rename distributed/{ => fsdp}/FSDP/download_dataset.sh (100%)
 rename distributed/{ => fsdp}/FSDP/model_checkpointing/__init__.py (100%)
 rename distributed/{ => fsdp}/FSDP/model_checkpointing/checkpoint_handler.py (100%)
 rename distributed/{ => fsdp}/FSDP/policies/__init__.py (100%)
 rename distributed/{ => fsdp}/FSDP/policies/activation_checkpointing_functions.py (100%)
 rename distributed/{ => fsdp}/FSDP/policies/mixed_precision.py (100%)
 rename distributed/{ => fsdp}/FSDP/policies/wrapping.py (100%)
 rename distributed/{ => fsdp}/FSDP/requirements.txt (100%)
 rename distributed/{ => fsdp}/FSDP/summarization_dataset.py (100%)
 rename distributed/{ => fsdp}/FSDP/utils/__init__.py (100%)
 rename distributed/{ => fsdp}/FSDP/utils/environment.py (100%)
 rename distributed/{ => fsdp}/FSDP/utils/train_utils.py (100%)
 create mode 100644 distributed/fsdp/fsdp-mnist/README.md
 create mode 100644 distributed/fsdp/fsdp-mnist/fsdp_mnist.py

diff --git a/distributed/FSDP/.gitignore b/distributed/fsdp/FSDP/.gitignore
similarity index 100%
rename from distributed/FSDP/.gitignore
rename to distributed/fsdp/FSDP/.gitignore
diff --git a/distributed/FSDP/README.md b/distributed/fsdp/FSDP/README.md
similarity index 100%
rename from distributed/FSDP/README.md
rename to distributed/fsdp/FSDP/README.md
diff --git a/distributed/FSDP/T5_training.py b/distributed/fsdp/FSDP/T5_training.py
similarity index 100%
rename from distributed/FSDP/T5_training.py
rename to distributed/fsdp/FSDP/T5_training.py
diff --git a/distributed/FSDP/configs/__init__.py b/distributed/fsdp/FSDP/configs/__init__.py
similarity index 100%
rename from distributed/FSDP/configs/__init__.py
rename to distributed/fsdp/FSDP/configs/__init__.py
diff --git a/distributed/FSDP/configs/fsdp.py b/distributed/fsdp/FSDP/configs/fsdp.py
similarity index 100%
rename from distributed/FSDP/configs/fsdp.py
rename to distributed/fsdp/FSDP/configs/fsdp.py
diff --git a/distributed/FSDP/configs/training.py b/distributed/fsdp/FSDP/configs/training.py
similarity index 100%
rename from distributed/FSDP/configs/training.py
rename to distributed/fsdp/FSDP/configs/training.py
diff --git a/distributed/FSDP/download_dataset.sh b/distributed/fsdp/FSDP/download_dataset.sh
similarity index 100%
rename from distributed/FSDP/download_dataset.sh
rename to distributed/fsdp/FSDP/download_dataset.sh
diff --git a/distributed/FSDP/model_checkpointing/__init__.py b/distributed/fsdp/FSDP/model_checkpointing/__init__.py
similarity index 100%
rename from distributed/FSDP/model_checkpointing/__init__.py
rename to distributed/fsdp/FSDP/model_checkpointing/__init__.py
diff --git a/distributed/FSDP/model_checkpointing/checkpoint_handler.py b/distributed/fsdp/FSDP/model_checkpointing/checkpoint_handler.py
similarity index 100%
rename from distributed/FSDP/model_checkpointing/checkpoint_handler.py
rename to distributed/fsdp/FSDP/model_checkpointing/checkpoint_handler.py
diff --git a/distributed/FSDP/policies/__init__.py b/distributed/fsdp/FSDP/policies/__init__.py
similarity index 100%
rename from distributed/FSDP/policies/__init__.py
rename to distributed/fsdp/FSDP/policies/__init__.py
diff --git a/distributed/FSDP/policies/activation_checkpointing_functions.py b/distributed/fsdp/FSDP/policies/activation_checkpointing_functions.py
similarity index 100%
rename from distributed/FSDP/policies/activation_checkpointing_functions.py
rename to distributed/fsdp/FSDP/policies/activation_checkpointing_functions.py
diff --git a/distributed/FSDP/policies/mixed_precision.py b/distributed/fsdp/FSDP/policies/mixed_precision.py
similarity index 100%
rename from distributed/FSDP/policies/mixed_precision.py
rename to distributed/fsdp/FSDP/policies/mixed_precision.py
diff --git a/distributed/FSDP/policies/wrapping.py b/distributed/fsdp/FSDP/policies/wrapping.py
similarity index 100%
rename from distributed/FSDP/policies/wrapping.py
rename to distributed/fsdp/FSDP/policies/wrapping.py
diff --git a/distributed/FSDP/requirements.txt b/distributed/fsdp/FSDP/requirements.txt
similarity index 100%
rename from distributed/FSDP/requirements.txt
rename to distributed/fsdp/FSDP/requirements.txt
diff --git a/distributed/FSDP/summarization_dataset.py b/distributed/fsdp/FSDP/summarization_dataset.py
similarity index 100%
rename from distributed/FSDP/summarization_dataset.py
rename to distributed/fsdp/FSDP/summarization_dataset.py
diff --git a/distributed/FSDP/utils/__init__.py b/distributed/fsdp/FSDP/utils/__init__.py
similarity index 100%
rename from distributed/FSDP/utils/__init__.py
rename to distributed/fsdp/FSDP/utils/__init__.py
diff --git a/distributed/FSDP/utils/environment.py b/distributed/fsdp/FSDP/utils/environment.py
similarity index 100%
rename from distributed/FSDP/utils/environment.py
rename to distributed/fsdp/FSDP/utils/environment.py
diff --git a/distributed/FSDP/utils/train_utils.py b/distributed/fsdp/FSDP/utils/train_utils.py
similarity index 100%
rename from distributed/FSDP/utils/train_utils.py
rename to distributed/fsdp/FSDP/utils/train_utils.py
diff --git a/distributed/fsdp/fsdp-mnist/README.md b/distributed/fsdp/fsdp-mnist/README.md
new file mode 100644
index 0000000000..73dab0a188
--- /dev/null
+++ b/distributed/fsdp/fsdp-mnist/README.md
@@ -0,0 +1,7 @@
+## FSDP MNIST
+
+To run a simple MNIST example with FSDP:
+
+```bash
+python fsdp_mnist.py
+```
diff --git a/distributed/fsdp/fsdp-mnist/fsdp_mnist.py b/distributed/fsdp/fsdp-mnist/fsdp_mnist.py
new file mode 100644
index 0000000000..9ec17a1e85
--- /dev/null
+++ b/distributed/fsdp/fsdp-mnist/fsdp_mnist.py
@@ -0,0 +1,198 @@
+# Based on: https://github.com/pytorch/examples/blob/master/mnist/main.py
+import os
+import argparse
+import functools
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+
+
+from torch.optim.lr_scheduler import StepLR
+
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data.distributed import DistributedSampler
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
+    CPUOffload,
+    BackwardPrefetch,
+)
+from torch.distributed.fsdp.wrap import (
+    size_based_auto_wrap_policy,
+    enable_wrap,
+    wrap,
+)
+
+def setup(rank, world_size):
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '12355'
+
+    # initialize the process group
+    dist.init_process_group("nccl", rank=rank, world_size=world_size)
+
+def cleanup():
+    dist.destroy_process_group()
+
+class Net(nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(1, 32, 3, 1)
+        self.conv2 = nn.Conv2d(32, 64, 3, 1)
+        self.dropout1 = nn.Dropout(0.25)
+        self.dropout2 = nn.Dropout(0.5)
+        self.fc1 = nn.Linear(9216, 128)
+        self.fc2 = nn.Linear(128, 10)
+
+    def forward(self, x):
+
+        x = self.conv1(x)
+        x = F.relu(x)
+        x = self.conv2(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+        x = self.dropout1(x)
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.dropout2(x)
+        x = self.fc2(x)
+        output = F.log_softmax(x, dim=1)
+        return output
+
+def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
+    model.train()
+    ddp_loss = torch.zeros(2).to(rank)
+    if sampler:
+        sampler.set_epoch(epoch)
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(rank), target.to(rank)
+        optimizer.zero_grad()
+        output = model(data)
+        loss = F.nll_loss(output, target, reduction='sum')
+        loss.backward()
+        optimizer.step()
+        ddp_loss[0] += loss.item()
+        ddp_loss[1] += len(data)
+
+    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
+    if rank == 0:
+        print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))
+
+def test(model, rank, world_size, test_loader):
+    model.eval()
+    correct = 0
+    ddp_loss = torch.zeros(3).to(rank)
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(rank), target.to(rank)
+            output = model(data)
+            ddp_loss[0] += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
+            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
+            ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item()
+            ddp_loss[2] += len(data)
+
+    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
+
+    if rank == 0:
+        test_loss = ddp_loss[0] / ddp_loss[2]
+        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
+            test_loss, int(ddp_loss[1]), int(ddp_loss[2]),
+            100. * ddp_loss[1] / ddp_loss[2]))
+
+def fsdp_main(rank, world_size, args):
+    setup(rank, world_size)
+
+    transform=transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.1307,), (0.3081,))
+    ])
+
+    dataset1 = datasets.MNIST('../data', train=True, download=True,
+                        transform=transform)
+    dataset2 = datasets.MNIST('../data', train=False,
+                        transform=transform)
+
+    sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True)
+    sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size)
+
+    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
+    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
+    cuda_kwargs = {'num_workers': 2,
+                   'pin_memory': True,
+                   'shuffle': False}
+    train_kwargs.update(cuda_kwargs)
+    test_kwargs.update(cuda_kwargs)
+
+    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
+    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
+    my_auto_wrap_policy = functools.partial(
+        size_based_auto_wrap_policy, min_num_params=20000
+    )
+    torch.cuda.set_device(rank)
+
+
+    init_start_event = torch.cuda.Event(enable_timing=True)
+    init_end_event = torch.cuda.Event(enable_timing=True)
+
+    model = Net().to(rank)
+
+    model = FSDP(model,
+        auto_wrap_policy=my_auto_wrap_policy,
+        cpu_offload=CPUOffload(offload_params=True)
+    )
+
+    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
+
+    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
+    init_start_event.record()
+    for epoch in range(1, args.epochs + 1):
+        train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
+        test(model, rank, world_size, test_loader)
+        scheduler.step()
+
+    init_end_event.record()
+
+    if rank == 0:
+        print(f"CUDA event elapsed time: {init_start_event.elapsed_time(init_end_event) / 1000}sec")
+        print(f"{model}")
+
+    if args.save_model:
+        # use a barrier to make sure training is done on all ranks
+        dist.barrier()
+        states = model.state_dict()
+        if rank == 0:
+            torch.save(states, "mnist_cnn.pt")
+
+    cleanup()
+
+if __name__ == '__main__':
+    # Training settings
+    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
+    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
+                        help='input batch size for training (default: 64)')
+    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
+                        help='input batch size for testing (default: 1000)')
+    parser.add_argument('--epochs', type=int, default=10, metavar='N',
+                        help='number of epochs to train (default: 10)')
+    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
+                        help='learning rate (default: 1.0)')
+    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
+                        help='Learning rate step gamma (default: 0.7)')
+    parser.add_argument('--no-cuda', action='store_true', default=False,
+                        help='disables CUDA training')
+    parser.add_argument('--seed', type=int, default=1, metavar='S',
+                        help='random seed (default: 1)')
+    parser.add_argument('--save-model', action='store_true', default=False,
+                        help='For Saving the current Model')
+    args = parser.parse_args()
+
+    torch.manual_seed(args.seed)
+
+    WORLD_SIZE = torch.cuda.device_count()
+    mp.spawn(fsdp_main,
+        args=(WORLD_SIZE, args),
+        nprocs=WORLD_SIZE,
+        join=True)
\ No newline at end of file
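
Usage note (not part of the patch): when the example is run with `--save-model`, rank 0 writes `mnist_cnn.pt` from `model.state_dict()` of the FSDP-wrapped model. Below is a minimal sketch of loading that checkpoint for single-device evaluation. It assumes the saved file holds a full, unsharded state dict whose keys match the plain `Net` module; with other FSDP state-dict configurations or PyTorch versions the key layout may differ.

```python
# Sketch: load the checkpoint written by `python fsdp_mnist.py --save-model`.
# Assumes mnist_cnn.pt contains a full (unsharded) state dict with keys matching Net.
import torch

from fsdp_mnist import Net  # Net as defined in the new example

model = Net()
state_dict = torch.load("mnist_cnn.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```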