From f7813d6d3321fbef5dfe04f4294ed913c8c784e6 Mon Sep 17 00:00:00 2001 From: anj-s <32556631+anj-s@users.noreply.github.com> Date: Thu, 25 Feb 2021 16:22:04 -0800 Subject: [PATCH] [feature] Add support for OffloadModel to enable training large models on 1 GPU. (#432) * clean start * removing per layer split strategy, probably not that useful indeed * initial transformer benchmark * hack, enable testing ViT + offload, python3 benchmarks/oss.py --epochs 2 --optim_type oss_offload_ddp --batch_size=32 --model vit_large_patch16_224 * proper cuda streams and device, something off in terms of mems consumption * minor, stashing * unit test fix * removing all the distributed parts * simpler test, needs debugging * working OOP, running a model which does not fit on the gpu memory * spring cleaning * removing the ill-advised optimizer bits, better keep that orthogonal * [offload] Add support for activation offloading + other changes (#367) * initial fwd/bwd commit * checkpoint work * modify shard loop * activation offloading and test to start with * fix lint errors * update comments * fix lint * remove unused var * remove commented out lines * modify name * remove break * remove profiler comments * avoid saving inputs * fix lint errors Co-authored-by: Anjali Sridhar * [offload] Add support for fp16 training (#374) * initial fwd/bwd commit * checkpoint work * modify shard loop * activation offloading and test to start with * fix lint errors * update comments * fix lint * remove unused var * remove commented out lines * modify name * remove break * remove profiler comments * add support for fp16 * add unit tests * fix lint errors * fix test failure Co-authored-by: Anjali Sridhar * [offload] Add support for activation checkpointing for all layers. (#381) * initial fwd/bwd commit * checkpoint work * modify shard loop * activation offloading and test to start with * fix lint errors * update comments * fix lint * remove unused var * remove commented out lines * modify name * remove break * remove profiler comments * add support for fp16 * add unit tests * fix lint errors * fix test failure * cp work, incorrect output dimensions still need to be fixed * fixed activation outputs * intermediate cp of work * add tests * fix lint errors Co-authored-by: Anjali Sridhar * add support for microbatches * revert benchmark config changes * add parametrization * fix lint errors and tests * skip test for 1.5 * fix lint errors * skip test if there are no GPUs * fix lint errors * fix lint errors * move experimental to the fairscale repo * lint error fixes * modify test imports * lint error fixes * move offload files to the experimental directory * move tests and benchmarks to their forlder * fix mypy errors * cp intermediate working benchmarks * more changes * split benchmark configs * remove print statements * fix lint errors * remove unused print * stress testing * remove unused file * change param nae * lint fixes * move file to the right folder * offload_experimental * add doc string * add error message Co-authored-by: Benjamin Lefaudeux Co-authored-by: Benjamin Lefaudeux Co-authored-by: Anjali Sridhar --- benchmarks/datasets/wikitext2_data.py | 2 +- benchmarks/experimental/offload.py | 415 +++++++++++++++++++ benchmarks/golden_configs/lm_wikitext2.py | 123 ++++-- benchmarks/pipe.py | 2 +- fairscale/experimental/nn/offload.py | 469 ++++++++++++++++++++++ tests/experimental/nn/test_offload.py | 138 +++++++ 6 files changed, 1110 insertions(+), 39 deletions(-) create mode 100755 benchmarks/experimental/offload.py create 
mode 100644 fairscale/experimental/nn/offload.py create mode 100644 tests/experimental/nn/test_offload.py diff --git a/benchmarks/datasets/wikitext2_data.py b/benchmarks/datasets/wikitext2_data.py index 6c3abfb77..1229b3bcd 100644 --- a/benchmarks/datasets/wikitext2_data.py +++ b/benchmarks/datasets/wikitext2_data.py @@ -44,7 +44,7 @@ def data_process(raw_text_iter): test_dataset = data_process(iter(io.open(test_filepath, encoding="utf8"))) def batchify(data): - batch_size = args.batch_size + batch_size = benchmark_config["batch_size"] return _batchify(data, batch_size) total_batch_size = _get_total_batch_size(benchmark_config, model_specs) diff --git a/benchmarks/experimental/offload.py b/benchmarks/experimental/offload.py new file mode 100755 index 000000000..b5f9a2b9f --- /dev/null +++ b/benchmarks/experimental/offload.py @@ -0,0 +1,415 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import contextlib +from functools import reduce +import logging +import math +import operator +import time + +import numpy as np +import torch +from torch.optim import Adam +from torch.utils.data.dataloader import DataLoader +from torchvision.datasets import FakeData +from torchvision.transforms import ToTensor + +from benchmarks.datasets.wikitext2_data import get_real_dataloaders as get_real_wikitext2_dataloaders +from benchmarks.datasets.wikitext2_data import get_synthetic_dataloaders as get_synthetic_wikitext2_dataloaders +from benchmarks.golden_configs.lm_wikitext2 import Offload_Sequential as offload_seq +from benchmarks.golden_configs.lm_wikitext2 import Offload_Transformer as lm_wikitext2 +from benchmarks.models import transformer_lm +from fairscale.experimental.nn.offload import OffloadModel + + +def init_random_seed(seed: int): + + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + np.random.seed(seed) + + +def get_model_and_optimizer(args, device, benchmark_config, model_specs): + """Return instantiated model and optimizer function.""" + + if args.model_name == "lm": + model = get_lm_model(args, device, model_specs) + lr = benchmark_config["lr"] + + def make_adam(params): + return Adam(params, lr=lr) + + optimizer = make_adam + elif args.model_name == "seq": + model = get_seq_model(args, device, model_specs) + optimizer = torch.optim.SGD + + model = OffloadModel( + model_cpu=model, + device=torch.device("cuda"), + offload_device=torch.device("cpu"), + num_slices=benchmark_config["slices"], + checkpoint_activation=benchmark_config["checkpoint_activation"], + num_microbatches=benchmark_config["num_microbatches"], + ) + + return model, optimizer + + +def get_seq_model(args, device, model_specs): + model = torch.nn.Sequential( + torch.nn.Linear(model_specs["inputs"] * model_specs["inputs"], model_specs["hidden"]), + *([torch.nn.Linear(model_specs["hidden"], model_specs["hidden"]) for _ in range(model_specs["layers"])]), + torch.nn.Linear(model_specs["hidden"], model_specs["outputs"]), + ) + return model.cpu() + + +def get_lm_model(args, device, config): + """Get language model(based on GPT-2) used for sequence prediction.""" + + ninp = config["ninp"] + nhead = config["nhead"] + initrange = config["initrange"] + dropout = config["dropout"] + vocab_size = config["vocab_size"] + nhid = config["nhid"] + ndecoder = config["num_decoder_layers"] + + return transformer_lm.TransformerLM(vocab_size, ninp, nhead, nhid, dropout, 
initrange, ndecoder).to(device) + + +def log_number_of_parameters(model): + + num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters())) + logging.info(f"training model, #params = {num_params}") + + +def _get_fp16_context(use_fp16=False): + if use_fp16: + return torch.cuda.amp.autocast() + else: + return contextlib.nullcontext() + + +def _get_profiler_context(use_profiler=False): + if use_profiler: + return torch.autograd.profiler.profile(use_cuda=True, profile_memory=True) + else: + return contextlib.nullcontext() + + +def _get_profiler_record_context(record_name, use_profiler=False): + if use_profiler: + return torch.autograd.profiler.record_function(record_name) + else: + return contextlib.nullcontext() + + +def train_seq(model_config, benchmark_config, model_specs, args): + device = torch.device("cuda") + torch.cuda.set_device(0) + torch.manual_seed(5) + + model = model_config["model"] + criterion = benchmark_config["criterion"] + optimizer = model_config["optimizer"](model.parameters(), lr=benchmark_config["lr"]) + dataloader, _, _ = model_config["data"] + + def train_epoch(args): + model.train() + for batch_inputs, batch_outputs in dataloader: + batch_inputs, batch_outputs = batch_inputs.to("cuda"), batch_outputs.to("cuda") + start = time.time_ns() + with _get_profiler_context() as prof: + optimizer.zero_grad() + inputs = batch_inputs.reshape(-1, model_specs["inputs"] * model_specs["inputs"]) + with _get_profiler_record_context("model_training"): + with _get_fp16_context(use_fp16=args.use_fp16): + output = model(inputs) + loss = criterion(output, target=batch_outputs) + loss.backward() + optimizer.step() + logging.info( + "Memory stats are {:.2f}GB".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] / 2 ** 30) + ) + logging.info( + "Loss {:.2f} - throughput {:.2f}fps".format( + loss.item(), benchmark_config["batch_size"] / (time.time_ns() - start) * 10 ** 9 + ) + ) + if args.use_profiler: + prof.export_chrome_trace("/tmp/offload_prof") + + train_epoch(args) + + +def train(model_config, model, benchmark_config, model_specs, args): + lm_dataloader, _, _ = model_config["data"] + criterion = benchmark_config["criterion"] + vocab_size = model_specs["vocab_size"] + optimizer = model_config["optimizer"] + + model.train() + log_number_of_parameters(model) + + total_loss = 0.0 + word_counter = 0 + + optimizer = optimizer(model.parameters()) + + total_tokens = 0 + total_tokens_per_log_interval = 0 + bptt = 2 + start_time = time.time() + epoch_start_time = 0.0 + + def get_batch(source): + seq_len = len(source) - 1 + data = source[0:seq_len] + target = source[1 : 1 + seq_len] + return data, target + + for i, batch in enumerate(lm_dataloader): + if i == 1: + epoch_start_time = time.time() + + source, target = get_batch(batch) + + if i > 0: + total_tokens += source.numel() + + optimizer.zero_grad() + output = model(source) + + target = target.to("cuda") + output = output.to(target.device) + loss = criterion(output.view(-1, vocab_size), target.view(-1)) + loss.backward() + + torch.nn.utils.clip_grad_value_(model.parameters(), model_specs["clip_value"]) + optimizer.step() + + total_loss += loss.item() + log_interval = 1 + total_tokens_per_log_interval += source.numel() + if i % log_interval == 0 and i > 0: + cur_loss = total_loss / log_interval + elapsed = time.time() - start_time + print( + "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format( + i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss) + ) + ) + 
            total_tokens_per_log_interval = 0
+            total_loss = 0
+            start_time = time.time()
+    if epoch_start_time != 0:
+        wps = total_tokens / (time.time() - epoch_start_time)
+    else:
+        raise RuntimeError(
+            "Unable to benchmark on a single batch. Increase the size of the dataset and rerun the benchmark."
+        )
+    return wps, loss.item()
+
+
+def verify_peak_memory(rank, golden_config, std_dev):
+    print("Peak allocated bytes on cuda:{:d}: {:1d}".format(rank, torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"]))
+    current_device_usage = torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"]
+    golden_ref = golden_config["peak_mem_usage"][rank]
+    if not current_device_usage < golden_ref * std_dev:
+        raise RuntimeError(
+            "Peak memory usage for cuda device {:d} is {:d} which "
+            "exceeds the golden reference value of {:d}.".format(rank, current_device_usage, golden_ref)
+        )
+
+
+def verify_lm_run(wps, golden_config, args):
+    """Verify that words per second for a given benchmark run matches the golden data."""
+
+    # Verify wps only on the last rank in multiprocess pipe
+    if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
+        # Assert that words per second is within 3 standard deviations of the average
+        # of five golden runs
+        print("Throughput(wps) is {:.2f}.".format(wps))
+        if not wps > (golden_config["avg_wps"] - (3 * golden_config["std_dev_wps"])):
+            raise RuntimeError(
+                "Throughput(wps):{:.2f} is below the golden threshold of an "
+                "average value of {:.2f} and standard dev of {:.2f}.".format(
+                    wps, golden_config["avg_wps"], golden_config["std_dev_wps"]
+                )
+            )
+
+    if args.multiprocess:
+        verify_peak_memory(dist.get_rank(), golden_config, 1.5)
+    else:
+        for i in range(4):
+            verify_peak_memory(i, golden_config, 1.1)
+
+
+def benchmark_language_model(model_config, model, benchmark_config, model_specs, args):
+    epoch = benchmark_config["epochs"]
+    start_time = time.time()
+    print("-" * 110)
+    print("| start of epoch {:1d}".format(epoch))
+    print("-" * 110)
+    wps, loss = train(model_config, model, benchmark_config, model_specs, args)
+    elapsed_time = time.time() - start_time
+    print("-" * 110)
+    print("| end of epoch {:1d} | time: {:5.2f}s | train loss {:5.2f} ".format(epoch, elapsed_time, loss))
+    print("-" * 110)
+    print("Throughput(wps) is {:.2f}.".format(wps))
+    print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]))
+    # TODO(anj-s): Enable golden config data verification.
+
+
+def get_synthetic_dataloaders(args, device, benchmark_config, model_specs):
+    """Returns dataloaders for synthetic data."""
+
+    if args.model_name == "lm":
+        return get_synthetic_wikitext2_dataloaders(args, benchmark_config, model_specs)
+    elif args.model_name == "seq":
+        transform = ToTensor()
+        dataloader = DataLoader(
+            FakeData(
+                image_size=(1, model_specs["inputs"], model_specs["inputs"]),
+                num_classes=model_specs["outputs"],
+                transform=transform,
+            ),
+            batch_size=benchmark_config["batch_size"],
+        )
+        return dataloader, dataloader, dataloader
+    else:
+        raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")
+
+
+def get_real_dataloaders(args, device, benchmark_config, model_specs):
+    """Returns dataloaders for real data."""
+
+    if args.model_name == "lm":
+        data = get_real_wikitext2_dataloaders(args, benchmark_config, model_specs)
+        ntokens, train_dataloader, valid_dataloader, test_dataloader = data
+        model_specs["vocab_size"] = ntokens
+        return train_dataloader, valid_dataloader, test_dataloader
+    else:
+        raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")
+
+
+def create_model_config(args, benchmark_config=None, model_specs=None):
+    """Return a dict with the given model, dataset and optimizer."""
+
+    # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device("cpu")
+
+    if args.model_name == "lm":
+        if args.use_synthetic_data:
+            dataloader_fn = get_synthetic_dataloaders
+        else:
+            dataloader_fn = get_real_dataloaders
+
+        data = dataloader_fn(args, device, benchmark_config, model_specs)
+        model, optimizer = get_model_and_optimizer(args, device, benchmark_config, model_specs)
+        return {
+            "model": model,
+            "optimizer": optimizer,
+            "data": data,
+        }
+    elif args.model_name == "seq":
+
+        data = get_synthetic_dataloaders(
+            args, device, offload_seq.get_benchmark_config(), offload_seq.get_model_config()
+        )
+        model, optimizer = get_model_and_optimizer(args, device, benchmark_config, model_specs)
+        return {
+            "model": model,
+            "optimizer": optimizer,
+            "data": data,
+        }
+    else:
+        raise RuntimeError(f"Unrecognized args.model_name {args.model_name}")
+
+
+def create_benchmark_config(model_name):
+    """Return a dict with configurations required for benchmarking `model_name` model."""
+
+    if model_name == "lm":
+        return lm_wikitext2.get_benchmark_config()
+    elif model_name == "seq":
+        return offload_seq.get_benchmark_config()
+    else:
+        raise RuntimeError(f"Unrecognized model name {model_name}")
+
+
+def get_golden_config(model_name, args):
+    """Return a dict with the golden data for throughput and memory usage."""
+
+    if model_name == "lm":
+        return lm_wikitext2.get_golden_real_stats(False)
+    else:
+        raise RuntimeError(f"Unrecognized model name {model_name}")
+
+
+def get_model_specs(model_name):
+    """Return a dict with configurations required for configuring `model_name` model."""
+
+    if model_name == "lm":
+        return lm_wikitext2.get_model_config()
+    elif model_name == "seq":
+        return offload_seq.get_model_config()
+    else:
+        raise RuntimeError(f"Unrecognized model name {model_name}")
+
+
+def run_benchmark(args):
+    """Benchmark a given model using a single process and a single device."""
+
+    # We need at least 1 GPU to benchmark the offload model API.
+ num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 0 + assert num_devices > 0 + init_random_seed(0) + + if args.model_name == "lm": + benchmark_config = create_benchmark_config(args.model_name) + model_specs = get_model_specs(args.model_name) + model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs) + model = model_config["model"] + + if args.dry_run: + train(model_config, model, benchmark_config, args) + else: + benchmark_language_model(model_config, model, benchmark_config, model_specs, args) + elif args.model_name == "seq": + benchmark_config = create_benchmark_config(args.model_name) + model_specs = get_model_specs(args.model_name) + model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs) + model = model_config["model"] + train_seq(model_config, benchmark_config, model_specs, args) + else: + raise RuntimeError(f"Unable to recognize model name {args.model_name}") + + +parser = argparse.ArgumentParser(description="benchmark") +parser.add_argument("--dry_run", action="store_true", help="Run a sample training run without regression testing.") +parser.add_argument( + "--debug", action="store_true", help="Print debugging statements which is more verbose than the default." +) +parser.add_argument( + "--model_name", default="lm", type=str, help="Language Model(LM) used to benchmark nn.pipe.", +) +parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.") +parser.add_argument("--use_fp16", action="store_true", default=False) +parser.add_argument("--checkpoint_activation", action="store_true", default=False) +parser.add_argument("--use_profiler", action="store_true", default=False) + + +if __name__ == "__main__": + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG) + logging.info("Benchmark arguments: %s" % args) + + run_benchmark(args) diff --git a/benchmarks/golden_configs/lm_wikitext2.py b/benchmarks/golden_configs/lm_wikitext2.py index c9aedefad..255878b07 100644 --- a/benchmarks/golden_configs/lm_wikitext2.py +++ b/benchmarks/golden_configs/lm_wikitext2.py @@ -5,46 +5,95 @@ from fairscale.optim import GradScaler -def get_model_config(): - return { - "vocab_size": 10000, - "ninp": 2048, # embedding dimension - "nhid": 2048, # the dimension of the feedforward network model in nn.TransformerEncoder - "nhead": 32, # the number of heads in the multiheadattention models - "dropout": 0, - "initrange": 0.1, - "scaler": GradScaler(), - "clip_value": 0.05, - "num_decoder_layers": 10, - "seq_len": 32, - } - - -def get_benchmark_config(): - - return { - "epochs": 1, - "lr": 0.001, # learning rate - "batch_size": 8, - "criterion": nn.CrossEntropyLoss(), - } - - -def get_golden_real_stats(multiprocess=False): - if not multiprocess: +class Offload_Transformer: + def get_model_config(): return { - "avg_wps": 703.778, - "std_dev_wps": 5.732, - "peak_mem_usage": [2320996352, 1396742144, 1396742144, 2340010496], + "vocab_size": 10000, + "ninp": 2048, # embedding dimension + "nhid": 2048, # the dimension of the feedforward network model in nn.TransformerEncoder + "nhead": 32, # the number of heads in the multiheadattention models + "dropout": 0, + "initrange": 0.1, + "scaler": GradScaler(), + "clip_value": 0.05, + "num_decoder_layers": 10, + "seq_len": 32, } - else: + + def get_benchmark_config(): + + return { + "epochs": 1, + "lr": 0.001, # learning rate + "batch_size": 8, + 
"criterion": nn.CrossEntropyLoss(), + "checkpoint_activation": True, + "num_microbatches": 4, + "slices": 3, + } + + +class Offload_Sequential: + def get_model_config(): + return { + "inputs": 100, + "outputs": 5, + "hidden": 1000, + "layers": 100, + "clip_value": 0.05, + } + + def get_benchmark_config(): + + return { + "epochs": 1, + "lr": 0.001, # learning rate + "batch_size": 8, + "criterion": nn.CrossEntropyLoss(), + "slices": 3, + "checkpoint_activation": True, + "num_microbatches": 4, + } + + +class Pipe: + def get_model_config(): + return { + "vocab_size": 10000, + "ninp": 2048, # embedding dimension + "nhid": 2048, # the dimension of the feedforward network model in nn.TransformerEncoder + "nhead": 32, # the number of heads in the multiheadattention models + "dropout": 0, + "initrange": 0.1, + "scaler": GradScaler(), + "clip_value": 0.05, + "num_decoder_layers": 10, + "seq_len": 32, + } + + def get_benchmark_config(): + return { - "avg_wps": 647.404, - "std_dev_wps": 14.51, - "peak_mem_usage": [3305007616, 2578692608, 3304524288, 2578692608], + "epochs": 1, + "lr": 0.001, # learning rate + "batch_size": 8, + "criterion": nn.CrossEntropyLoss(), } + def get_golden_real_stats(multiprocess=False): + if not multiprocess: + return { + "avg_wps": 703.778, + "std_dev_wps": 5.732, + "peak_mem_usage": [2320996352, 1396742144, 1396742144, 2340010496], + } + else: + return { + "avg_wps": 647.404, + "std_dev_wps": 14.51, + "peak_mem_usage": [3305007616, 2578692608, 3304524288, 2578692608], + } -def get_golden_synthetic_stats(): - # TODO(anj-s): Add support for synthetic regression benchmarks - raise NotImplementedError("Synthetic data benchmarks are not supported.") + def get_golden_synthetic_stats(): + # TODO(anj-s): Add support for synthetic regression benchmarks + raise NotImplementedError("Synthetic data benchmarks are not supported.") diff --git a/benchmarks/pipe.py b/benchmarks/pipe.py index 947ada40b..38554750e 100644 --- a/benchmarks/pipe.py +++ b/benchmarks/pipe.py @@ -12,7 +12,6 @@ from datasets.wikitext2_data import get_real_dataloaders as get_real_wikitext2_dataloaders from datasets.wikitext2_data import get_synthetic_dataloaders as get_synthetic_wikitext2_dataloaders -from golden_configs import lm_wikitext2 from models import transformer_lm import numpy as np import torch @@ -22,6 +21,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Adam +from benchmarks.golden_configs.lm_wikitext2 import Pipe as lm_wikitext2 from fairscale.nn import Pipe from fairscale.nn.model_parallel import initialize_model_parallel from fairscale.nn.model_parallel.initialize import get_data_parallel_group, get_pipeline_parallel_group diff --git a/fairscale/experimental/nn/offload.py b/fairscale/experimental/nn/offload.py new file mode 100644 index 000000000..3e4e5bee6 --- /dev/null +++ b/fairscale/experimental/nn/offload.py @@ -0,0 +1,469 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +from builtins import isinstance +import functools +import logging +from typing import Any, List, Tuple + +import torch +from torch import nn + + +def conditional_amp_fwd_decorator(orig_func): # type: ignore + + if hasattr(torch.cuda.amp, "custom_fwd"): + return torch.cuda.amp.custom_fwd(orig_func) # type: ignore + + @functools.wraps(orig_func) + def inner_decorator(*args: Any, **kwargs: Any) -> Any: + return orig_func(*args, **kwargs) + + return inner_decorator + + +def conditional_amp_bwd_decorator(orig_func): # type: ignore + if hasattr(torch.cuda.amp, "custom_bwd"): + return torch.cuda.amp.custom_bwd(orig_func) # type: ignore + + @functools.wraps(orig_func) + def inner_decorator(*args: Any, **kwargs: Any) -> Any: + return orig_func(*args, **kwargs) + + return inner_decorator + + +def _split(modules: nn.Sequential, number_splits: int) -> List[List[nn.Module]]: + number_splits = min(len(modules), number_splits) + splits: List[List[nn.Module]] = [[] for _ in range(number_splits)] + + # Count the number of parameters per exposed layer, use that as a proxy for memory footprint + total_number_params = sum([sum(p.numel() for p in m.parameters()) for m in modules]) + number_parameters_per_shard = total_number_params // number_splits + + current_shard = 0 + + logging.info( + f"This model has {total_number_params/1e6:.2f}M parameters, aiming for {number_parameters_per_shard/1e6:.2f}M parameters per shard" + ) + + for m in modules: + # Number of parameters in the current shard + current_shard_params = sum(p.numel() for sm in splits[current_shard] for p in sm.parameters()) + + # This shard is big enough, point to the next one + if ( + current_shard_params > 0 + and current_shard_params + sum(p.numel() for p in m.parameters()) > number_parameters_per_shard + and current_shard < number_splits - 1 + ): + current_shard += 1 + + splits[current_shard].append(m) + + for i, split in enumerate(splits): + current_shard_params = sum(p.numel() for sm in split for p in sm.parameters()) + logging.info(f"Shard {i} holds {current_shard_params/1e6:.2f}M parameters") + + return splits + + +class ModelShard(nn.Module): + """ + Wrap one shard of the model, make it possible to load parameters on the + fly for the FW and BW pass on the given device. 
+ """ + + def __init__( + self, cpu_model_shard: nn.Module, device: torch.device, offload_device: torch.device, index: int, + ): + super().__init__() + self.model_shard = cpu_model_shard + self.index = index + + # Save all the parameter sizes to be able to restore them + self.device = device + torch.cuda.device(self.device) + + self.offload_device = offload_device + + self.model_shard.to(offload_device) + self.cuda_stream = torch.cuda.Stream( + device=self.device + ) # needed to make sure load/offload really run in parallel with compute + + def forward(self, *inputs): # type: ignore + return self.model_shard(*inputs) if isinstance(inputs, tuple) else self.model_shard(inputs) + + def to(self, device: torch.device) -> "ModelShard": # type: ignore + # Make sure that the lookahead and lookback shards are not captured by this call + self.model_shard.to(device) + return self + + def train(self, mode: bool = True) -> "ModelShard": + # Make sure that the lookahead and lookback shards are not captured by this call + self.model_shard.train(mode) + return self + + def to_device(self) -> None: + self.model_shard.to(device=self.device, non_blocking=True) + + def forward_load(self, non_blocking: bool = True) -> None: + with torch.cuda.stream(self.cuda_stream): + # Restore all the parameter buffers + self.model_shard.to(device=self.device, non_blocking=non_blocking) + + def backward_load(self, non_blocking: bool = True) -> None: + with torch.cuda.stream(self.cuda_stream): + self.model_shard.to(self.device, non_blocking=non_blocking) + + def forward_drop(self, non_blocking: bool = True) -> None: + with torch.cuda.stream(self.cuda_stream): + self.model_shard.to(self.offload_device, non_blocking=non_blocking) + + def backward_drop(self, non_blocking: bool = True) -> None: + with torch.cuda.stream(self.cuda_stream): + self.model_shard.to(self.offload_device, non_blocking=non_blocking) + + +class ActivationCheckpointing(torch.autograd.Function): + """ + This Function enables checkpointing of intermediate activations at + shard boundaries by overriding the forward and backward pass of the nn.Module. + + - In the FW pass, it drops parameters in the previous shard and + loads parameters for the next shard. No graph is constructed in the FW pass. + This enables us to offload intermediate activations present at the shard + boundaries. + + - In the BW pass, it does the reverse. We run the forward pass using the + saved intermediate activations and calculate gradients as needed. + The trade-off is latency vs memory when using activation checkpointing. + + - Follows heavily from https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html#checkpoint. + + NOTE: see https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function + """ + + @staticmethod + @conditional_amp_fwd_decorator # type: ignore + def forward(ctx: Any, inputs: Any, model_instance: Any) -> Any: + inputs = inputs if isinstance(inputs, tuple) else (inputs,) + + ctx.inputs = inputs + ctx.model_instance = model_instance + # TODO(anj-s): We might need to store this for each boundary activation. + # Currently we assume all boundary activation inputs require + ctx.grad_requirements = tuple(x.requires_grad for x in inputs) + ctx.fwd_rng_state = torch.get_rng_state() + + # List of input activations starting with the given input. + model_instance._activations = [inputs] + # Enumerate through layer shards and apply activations from the previous shard. 
+ for index, layer_shard in enumerate(model_instance.model_slices): + # Bring in the current activations onto the device. + model_instance._activations[index] = tuple([a.cuda() for a in list(model_instance._activations[index])]) + # Bring in the current layer shard onto the device. + layer_shard.forward_load() + # Apply the FP and store the activations on the CPU. + inputs = model_instance._activations[index] + + with torch.no_grad(): + output_list: List[Any] = [] + for given_input in inputs: + given_input_list = torch.chunk(given_input, model_instance._num_microbatches) + given_output_list = [] + for inputs in given_input_list: + output = layer_shard(inputs) + given_output_list.append(output) + given_output = torch.cat(given_output_list).squeeze(-1) + output_list.append(given_output) + output = tuple(output_list) + + output = output if isinstance(output, tuple) else (output,) + # The last instance will lose the gradient function if we move it to the CPU. + # This is because all grad function are present on the device that ran the FW pass. + if index == len(model_instance.model_slices) - 1: + model_instance._activations.append(output) + else: + model_instance._activations.append(tuple([a.cpu() for a in list(output)])) + # Move the layer shard back to the CPU. + layer_shard.forward_drop() + + # TODO(anj-s): Check device of the result to make sure the outputs and targets match device. + result = model_instance._activations[-1] + for r in result: + r.requires_grad = True + return result[0] if len(result) == 1 else result + + @staticmethod + @conditional_amp_bwd_decorator + def backward(ctx, *grad_outputs): # type: ignore + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible") + inputs = ctx.inputs + model_instance = ctx.model_instance + + for i, need_grad in enumerate(ctx.grad_requirements): + inputs[i].requires_grad = need_grad + + all_grads = [grad_outputs] + + final_index = len(model_instance._activations) - 1 + + for model_shard, activation in zip( + reversed(model_instance.model_slices), reversed(model_instance._activations[:-1]) + ): + # Move the activation to the device. + activation = tuple([a.cuda() for a in list(activation)]) + # One of the inputs to the FW pass must require grad. + for a in activation: + a.requires_grad = True + + # Move the model shard to the device. + model_shard.backward_load() + # Store the BW pass state. + bwd_rng_state = torch.get_rng_state() + + # TODO(anj-s): Why detach inputs? + activation = torch.utils.checkpoint.detach_variable(activation) + # Get the last gradient calculation. + final_grads = all_grads[-1] + if isinstance(activation, torch.Tensor): + activation = (activation,) + if isinstance(final_grads, torch.Tensor): + final_grads = (final_grads,) + # Iterate through all the inputs/outputs of a shard (there could be multiple). + chunked_grad_list: List[Any] = [] + # Chunk the activation and grad based on the number of microbatches that are set. + for chunked_activation, chunked_grad in zip( + torch.chunk(*activation, model_instance._num_microbatches), # type: ignore + torch.chunk(*final_grads, model_instance._num_microbatches), # type: ignore + ): + # Set the states to what it used to be before the forward pass. 
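+                # Restoring the RNG state saved at the start of the forward pass means that
+                # stochastic layers (e.g. dropout) behave the same way during this recompute
+                # as they did originally; the backward RNG state saved above is restored
+                # again once the recomputation is done.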
+ torch.set_rng_state(ctx.fwd_rng_state) + + if isinstance(chunked_activation, torch.Tensor): + chunked_activation = (chunked_activation,) # type: ignore + if isinstance(chunked_grad, torch.Tensor): + chunked_grad = (chunked_grad,) # type: ignore + + # Since we need a grad value of a non leaf element we need to set these properties. + for a in chunked_activation: + a.requires_grad = True + a.retain_grad() + + with torch.enable_grad(): + # calculate the output of the last shard wrt to the stored activation at the slice boundary. + outputs = model_shard(*chunked_activation) + + # Set the states back to what it was at the start of this function. + torch.set_rng_state(bwd_rng_state) + torch.autograd.backward(outputs, chunked_grad) + chunked_grad_list += [a.grad for a in chunked_activation] + + # Append the list of grads to the all_grads list and this should be on the CPU. + all_grads.append(torch.cat(chunked_grad_list).squeeze(-1)) # type: ignore + # Move activation back to the CPU. + # TODO(anj-s): Why does moving activations to CPU cause the .grad property to be None? + activation = tuple([a.cpu() for a in list(activation)]) + # Move the shard back to the CPU. + model_shard.backward_drop() + detached_inputs = model_instance._activations[0] + grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs) + return (None, None) + grads + + +class ShardSyncLayer(torch.autograd.Function): + """ + The shard sync layer is a synchronization point between model shards. + + - In the forward pass, it drops parameters in the previous shard and + loads parameters for the next shard. + + - In the backward pass, it does the reverse. + + It does not change or create any outputs at all, instead it just + forwards the input as the output. + + NOTE: see https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function + """ + + @staticmethod + @conditional_amp_fwd_decorator # type: ignore + def forward(ctx: Any, inputs: Any, index: int, model_slices: Any, model_instance: Any) -> Any: + drop_index = index + load_index = index + 1 + max_slices = len(model_slices) + + if drop_index >= 0: + # Move shard from device to offload device. + logging.info(f"Dropping shard {drop_index}") + model_slices[drop_index].forward_drop() + + if load_index < max_slices: + # Load shard from offload device to device. + logging.info(f"Loading shard{load_index}") + model_slices[load_index].forward_load() + + ctx.index = index + ctx.model_slices = model_slices + ctx.model_instance = model_instance + + return inputs if isinstance(inputs, tuple) else (inputs,) + + @staticmethod + @conditional_amp_bwd_decorator + def backward(ctx, *grad_outputs): # type: ignore + + load_index = ctx.index + drop_index = load_index + 1 + model_slices = ctx.model_slices + model_instance = ctx.model_instance + + # TODO(anj-s): Are these redundant in the backward pass? + if drop_index == len(model_slices): + # Drop the last activation since it is still on the CPU + # after the loss.backward() call. + model_instance._activations[-1] = tuple([a.cuda() for a in list(model_instance._activations[-1])]) + + if drop_index < len(model_slices): + # Move shard from device to offload device. + logging.info(f"Backward Dropping shard {drop_index}") + model_slices[drop_index].backward_drop() + model_instance._activations[drop_index] = tuple( + [a.cpu() for a in list(model_instance._activations[drop_index])] + ) + + if load_index >= 0: + # Load shard from offload device to device. 
+ logging.info(f"Backward Loading shard{load_index}") + model_slices[load_index].backward_load() + model_instance._activations[load_index] = tuple( + [a.cuda() for a in list(model_instance._activations[load_index])] + ) + + # The returned variables need to mirror the forward inputs + # TODO(anj-s): Why do we need to do this? + if isinstance(grad_outputs, tuple): + return grad_outputs[0], None, None, None + + return grad_outputs, None, None, None + + +class OffloadModel(nn.Module): + """Wrapper used offload parts of a model to the CPU. + + The model is sharded into chunks and at each iteration, a + single chunk is copied from CPU->GPU, FW pass is computed and + the chunk is copied back to CPU. This process is repeated for + all the chunks. In the BW pass, the same process happens in + reverse. + + Note: OffloadModel currently only supports nn.Sequential models. + + Args: + module (~torch.nn.Sequential): Module to be offloaded. + + device (torch.device): + Device where the active model should reside. + + offload_device (torch.device): + Device where the inactive model should reside. + + num_slices (int): + Number of slices into which the model should be chunked. + + checkpoint_activation (bool): + Boolean to indicate if we want to checkpoint intermediate + activation states on the CPU. Default value is False. + + num_microbatches (int): + Number of microbatches which should be run per model + shard on device. + """ + + def __init__( + self, + model_cpu: nn.Sequential, + device: torch.device, + offload_device: torch.device = torch.device("cpu"), + num_slices: int = 5, + checkpoint_activation: bool = False, + num_microbatches: int = 1, + ): + super().__init__() + # TODO(anj-s): Add error checks for cuda and sequential model. + + self.device = device + self.offload_device = offload_device + + # Slice the model into roughly equivalent sequential shards. + splits = _split(model_cpu, num_slices) + + # List of model shards that will be placed on/off the device. + self.model_slices: List[nn.Module] = [] + + for i, split in enumerate(splits): + # Add one model handling this slice + self.model_slices.append( + ModelShard( + cpu_model_shard=nn.Sequential(*split), device=device, offload_device=offload_device, index=i, + ) + ) + + # Expose a unified view of the slices + self.model = torch.nn.Sequential(*self.model_slices) + + # intermediate activations at the slice boundaries. + self._activations: List[Tuple] = [] + + # Currently we only support microbatches with activation checkpointing. + if not checkpoint_activation and num_microbatches > 1: + raise RuntimeError("We currently only support microbatches with activation checkpointing.") + + # Bool indicating if we want to checkpoint activation on the host. + self._checkpoint_activation = checkpoint_activation + + # Number of microbatches to run per batch on the device + self._num_microbatches = num_microbatches + + def forward(self, *inputs: Any, **_: Any) -> Any: + # At least one of the inputs needs to have `requires_grad` set. + # TODO(anj-s): Should we require users to set this or should we set it here? 
+ set_at_least_once = False + for inp in inputs: + if inp.dtype == torch.long: + continue + inp.requires_grad = True + set_at_least_once = True + + if not set_at_least_once: + raise RuntimeError("We need at least one of the inputs to require grads.") + + if self._checkpoint_activation: + return ActivationCheckpointing.apply(*inputs, self) + + self._activations = [] + for index in range(-1, len(self.model_slices)): + if index >= 0: + # TODO(anj-s): This might be a redundant call since we have the previous + # activation on the device already. + self._activations[index] = tuple([a.cuda() for a in list(self._activations[index])]) + inputs = self._activations[index] + inputs = self.model_slices[index](*inputs) + # Call the custom autograd hooks (discard/load slices FW and BW) + inputs = ShardSyncLayer.apply(inputs, index, self.model_slices, self) + self._activations.append(inputs) + if index >= 0: + self._activations[index] = tuple([a.cpu() for a in list(self._activations[index])]) + + # We don't move the last activation/output since the target is present + # on the device. + # TODO(anj-s): It is now a requirement that the target tensors be placed on the + # device. + result = self._activations[-1] + return result[0] if len(result) == 1 else result diff --git a/tests/experimental/nn/test_offload.py b/tests/experimental/nn/test_offload.py new file mode 100644 index 000000000..f6d4d5ab2 --- /dev/null +++ b/tests/experimental/nn/test_offload.py @@ -0,0 +1,138 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +""" +Testing Offload Module +""" + +import contextlib +import copy + +import numpy as np +import pytest +import torch + +from fairscale.experimental.nn.offload import OffloadModel +from fairscale.utils.testing import skip_if_no_cuda + + +def _init(): + torch.cuda.set_device(0) + torch.manual_seed(0) + np.random.seed(0) + device = torch.device("cuda") + offload_device = torch.device("cpu") + return device, offload_device + + +@skip_if_no_cuda +def test_single_run(): + device, offload_device = _init() + model = _get_model() + + offload_model = OffloadModel(model_cpu=model, device=device, offload_device=offload_device, num_slices=2,) + offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001) + + input = torch.ones(2, 2).to(device) + labels = torch.ones(2, 2).to(device) + offload_model.train() + pred = offload_model(input) + loss_fn = torch.nn.MSELoss(reduction="sum") + loss = loss_fn(pred, labels) + loss.backward() + offload_optimizer.step() + + +def _get_model(num_inputs=2, num_hidden=2, num_layers=1, num_outputs=2): + model = torch.nn.Sequential( + torch.nn.Linear(num_inputs, num_hidden), + *([torch.nn.Linear(num_hidden, num_hidden) for _ in range(num_layers)]), + torch.nn.Linear(num_hidden, num_outputs), + ) + return model + + +def _check_parity(rmodel, omodel, ropt, oopt, rloss, oloss): + + for oparams, rparams in zip(omodel.parameters(), rmodel.parameters()): + assert torch.allclose(oparams, rparams, atol=1e-2), f"Model params are different {oparams} {rparams}" + + for o_pg, reg_pg in zip(oopt.param_groups, ropt.param_groups): + for o_pg, reg_pg in zip(o_pg["params"], reg_pg["params"]): + assert torch.allclose( + o_pg, reg_pg, atol=1e-2 + ), f"Model parameters differ in between Offlad and Vanilla {[o_pg]} {reg_pg}" + + for o_buf, reg_buf in zip(omodel.buffers(), rmodel.buffers()): + assert torch.allclose(o_buf, reg_buf, atol=1e-2), "Model 
buffers differ in between Offload and Vanilla." + + +def _get_fp16_context(use_fp16=False): + if use_fp16: + return torch.cuda.amp.autocast() + else: + return contextlib.nullcontext() + + +def _train(model, optimizer, use_fp16, device): + + inputs = torch.ones(32, 2).to(device) + labels = torch.ones(32, 2).to(device) + loss_fn = torch.nn.MSELoss(reduction="sum") + model.train() + with _get_fp16_context(use_fp16): + pred = model(inputs) + loss = loss_fn(pred, labels) + loss.backward() + optimizer.step() + return model, optimizer, loss + + +def _train_reg_model(model, device, offload_device, use_fp16=False): + reg_model = copy.deepcopy(model) + reg_model = reg_model.cuda() + reg_optimizer = torch.optim.SGD(reg_model.parameters(), lr=0.001) + return _train(reg_model, reg_optimizer, use_fp16, device) + + +def _train_offload_model( + model, device, offload_device, use_fp16=False, checkpoint_activation=False, num_microbatches=1 +): + omodel = copy.deepcopy(model) + offload_model = OffloadModel( + model_cpu=omodel, + device=device, + offload_device=offload_device, + num_slices=2, + checkpoint_activation=checkpoint_activation, + num_microbatches=num_microbatches, + ) + offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001) + return _train(offload_model, offload_optimizer, use_fp16, device) + + +@skip_if_no_cuda +@pytest.mark.parametrize("use_fp16", [True, False]) +@pytest.mark.parametrize("checkpoint_activation", [True, False]) +@pytest.mark.parametrize("num_microbatches", [1, 5]) +def test_correctness(use_fp16, checkpoint_activation, num_microbatches): + if (use_fp16 or checkpoint_activation) and not hasattr(torch.cuda.amp, "custom_fwd"): + pytest.skip(f"AMP APIs are not supported in torch version {torch.__version__}") + + if not checkpoint_activation and num_microbatches > 1: + pytest.skip("We only support microbatches with activation offloading.") + + device, offload_device = _init() + model = _get_model() + rmodel, ropt, rloss = _train_reg_model(model, device, offload_device) + omodel, oopt, oloss = _train_offload_model( + model, + device, + offload_device, + use_fp16=use_fp16, + checkpoint_activation=checkpoint_activation, + num_microbatches=num_microbatches, + ) + _check_parity(rmodel.cpu(), omodel.cpu(), ropt, oopt, rloss, oloss)
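
For reference, the following is a minimal usage sketch of the OffloadModel API introduced by this patch, modeled on test_single_run in tests/experimental/nn/test_offload.py. The three 2-unit Linear layers, the learning rate and the slice count are illustrative values, not something prescribed by the patch.

# Minimal OffloadModel usage sketch (assumes a machine with at least one CUDA device).
import torch

from fairscale.experimental.nn.offload import OffloadModel

# A small nn.Sequential stand-in for a model that would normally not fit in GPU memory.
model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 2), torch.nn.Linear(2, 2))

offload_model = OffloadModel(
    model_cpu=model,                      # nn.Sequential that starts on the CPU and gets sharded
    device=torch.device("cuda"),          # device that runs the currently active shard
    offload_device=torch.device("cpu"),   # device that holds the inactive shards
    num_slices=2,
    checkpoint_activation=False,
    num_microbatches=1,
)
optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)

# Inputs and targets live on the GPU; the model parameters are paged in shard by shard.
inputs = torch.ones(2, 2).to("cuda")
labels = torch.ones(2, 2).to("cuda")

offload_model.train()
pred = offload_model(inputs)
loss = torch.nn.MSELoss(reduction="sum")(pred, labels)
loss.backward()
optimizer.step()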