diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py
index f79051baf..52b080d63 100644
--- a/fairscale/nn/data_parallel/shard_params_data_parallel.py
+++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py
@@ -11,6 +11,7 @@
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 import torch.nn as nn
+from torch.nn import Parameter
 
 from fairscale.nn.misc import FlattenParamsWrapper
 from fairscale.utils.containers import (
@@ -133,6 +134,11 @@ def __init__(
         for n, p in self.named_parameters():
             assert getattr(p, "_is_sharded", False), f"found unsharded parameter: {n} ; {p.size()}"
 
+        # Flag to indicate if this instance is wrapped by any other
+        # ShardParamsDataParallel instances. This flag is only set after the
+        # first forward pass.
+        self._is_root: Optional[bool] = None
+
     @torch.no_grad()
     def _shard_initial_params(self) -> None:
         """
@@ -195,8 +201,8 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
         """Intercept state setting and perform needed changes on params."""
         super().__setstate__(state)
 
-        def fixup(p: torch.nn.Parameter, size: int) -> torch.nn.Parameter:
-            assert isinstance(p, torch.nn.Parameter)
+        def fixup(p: Parameter, size: torch.Size) -> Parameter:
+            assert isinstance(p, Parameter)
             p.data = p.data.clone()  # move tensors out of shared memory
             # Ignore mypy error since we add additional fields to a param.
             p._is_sharded = True
@@ -247,49 +253,71 @@ def load_local_state_dict(
         return self.module.load_state_dict(state_dict, strict)
 
     @torch.no_grad()
-    def _pre_forward_init(self) -> None:
-        did_init = False
-        for p in self.params:
-            if hasattr(p, "_full_param"):
-                continue
-            did_init = True
-            assert p._is_sharded
+    def _init_param(self, p: Parameter) -> None:
+        assert p._is_sharded
+        assert not hasattr(p, "_full_param")
 
-            p._fp32_shard = p.data
+        p._fp32_shard = p.data
 
-            if self.mixed_precision:
-                assert p._fp32_shard.dtype == torch.float32
+        if self.mixed_precision:
+            assert p._fp32_shard.dtype == torch.float32
 
-                if self.cpu_offload:
-                    assert p._fp32_shard.device == torch.device("cpu")
-                    p._fp32_shard = p._fp32_shard.pin_memory()
+            if self.cpu_offload:
+                assert p._fp32_shard.device == torch.device("cpu")
+                p._fp32_shard = p._fp32_shard.pin_memory()
 
-                p._fp16_shard = torch.zeros_like(
-                    p._fp32_shard,
-                    device=self.compute_device,
-                    dtype=self.compute_dtype,
-                )
-                free_storage_(p._fp16_shard)
-                p._full_param = torch.zeros(p._orig_size, device=self.compute_device, dtype=self.compute_dtype)
-            else:
-                p._fp16_shard = None  # use _fp32_shard
-                p._full_param = p._fp32_shard.new_empty(p._orig_size)
+            p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
+            free_storage_(p._fp16_shard)
+        else:
+            p._fp16_shard = None  # use _fp32_shard
 
-            p._full_param = p._full_param.to(dtype=self.compute_dtype, device=self.compute_device)
-            free_storage_(p._full_param)
+        p.data = p._fp32_shard
 
-            p.data = p._fp32_shard
+        p._full_param = torch.zeros(p._orig_size, device=self.compute_device, dtype=self.compute_dtype)
+        free_storage_(p._full_param)
 
-            if self.move_grads_to_cpu:
-                if self.mixed_precision and not self.fp32_reduce_scatter:
-                    grad_dtype = torch.float16
-                else:
-                    grad_dtype = torch.float32
-                p._cpu_grad = torch.zeros_like(p.data, dtype=grad_dtype, device="cpu").pin_memory()
+        if self.move_grads_to_cpu:
+            if self.mixed_precision and not self.fp32_reduce_scatter:
+                grad_dtype = torch.float16
+            else:
+                grad_dtype = torch.float32
+            p._cpu_grad = torch.zeros_like(p.data, dtype=grad_dtype, device="cpu").pin_memory()
 
-        if did_init:
-            self._fp32_to_fp16_stream = torch.cuda.Stream()
-            self._fp32_to_fp16_stream.wait_stream(torch.cuda.current_stream())
+    @torch.no_grad()
+    def _pre_forward_init(self) -> None:
+        first_time_params = [p for p in self.params if not hasattr(p, "_full_param")]
+        for p in first_time_params:
+            self._init_param(p)
+
+        if len(first_time_params) > 0:
+            if self._is_root is None:
+                # This implies that no other ShardParamsDataParallel instance
+                # wraps this instance, otherwise it would have already set this
+                # flag to False.
+                self._is_root = True
+
+                # As the root, we now set all children instances to False.
+                for n, m in self.named_modules():
+                    if n != "" and isinstance(m, ShardParamsDataParallel):
+                        assert m._is_root is None
+                        m._is_root = False
+
+            if self._is_root:
+                # Stream for moving FP32 master params (which may be on CPU) to
+                # FP16 for computation. We share this stream with all children
+                # instances, which allows them to overlap transfers across the
+                # forward pass without synchronizing with the default stream.
+                self._fp32_to_fp16_stream = torch.cuda.Stream()
+
+                for n, m in self.named_modules():
+                    if n != "" and isinstance(m, ShardParamsDataParallel):
+                        m._fp32_to_fp16_stream = self._fp32_to_fp16_stream
+
+        assert self._is_root is not None
+        if self._is_root:
+            # The top-most (root) instance needs to synchronize with the default
+            # stream, so we don't move the FP32 master weights prematurely.
+            self._fp32_to_fp16_stream.wait_stream(torch.cuda.current_stream())
 
     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         self._pre_forward_init()
@@ -313,8 +341,12 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
             # initialized with the correct dtype and size.
             self._use_fp32_param_shard()
 
-        if not torch.is_grad_enabled():
-            return outputs
+        if torch.is_grad_enabled():
+            outputs = self._register_pre_backward_hooks(outputs)
+
+        return outputs
+
+    def _register_pre_backward_hooks(self, outputs: Any) -> Any:
 
         # Register pre-backward hook to run before the wrapped module's backward.
         pre_backward_hook_has_run = [False]
@@ -352,7 +384,7 @@ def _register_post_backward_hooks(self) -> None:
             p._shard_bwd_hook = (grad_acc, handle)
 
     @torch.no_grad()
-    def _post_backward_hook(self, param: torch.nn.Parameter, *unused: Any) -> None:
+    def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         if param.grad is None:
             return
         if param.grad.requires_grad:
@@ -407,7 +439,7 @@ def _use_full_params(self) -> None:
             p.data = p._full_param
 
     @torch.no_grad()
-    def _free_full_params(self, params: Optional[List[torch.nn.Parameter]] = None) -> None:
+    def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
         """Free up storage for full parameters."""
         if params is None:
             params = self.params
@@ -423,7 +455,7 @@ def _free_full_params(self, params: Optional[List[torch.nn.Parameter]] = None) -
             free_storage_(p._full_param)
 
     @torch.no_grad()
-    def _use_fp32_param_shard(self, params: Optional[List[torch.nn.Parameter]] = None) -> None:
+    def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
         """Use FP32 shard for a list of params."""
         if params is None:
             params = self.params
@@ -431,7 +463,7 @@ def _use_fp32_param_shard(self, params: Optional[List[torch.nn.Parameter]] = Non
             p.data = p._fp32_shard
 
     @torch.no_grad()
-    def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[torch.nn.Parameter]] = None) -> None:
+    def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = None) -> None:
         """Cast FP32 param shard to FP16 for a list of params."""
         if params is None:
             params = self.params
@@ -444,7 +476,7 @@ def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[torch.nn.Paramet
         torch.cuda.current_stream().wait_stream(self._fp32_to_fp16_stream)
 
     @torch.no_grad()
-    def _free_fp16_param_shard(self, params: Optional[List[torch.nn.Parameter]] = None) -> None:
+    def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
         """Free storage for FP16 shards for a list of params."""
         if params is None:
             params = self.params
diff --git a/fairscale/nn/misc/flatten_params_wrapper.py b/fairscale/nn/misc/flatten_params_wrapper.py
index 6e039f78b..d05f0f3a0 100644
--- a/fairscale/nn/misc/flatten_params_wrapper.py
+++ b/fairscale/nn/misc/flatten_params_wrapper.py
@@ -130,10 +130,10 @@ def __getattr__(self, name: str) -> Any:
         except AttributeError:
             return getattr(self.module, name)  # fallback to wrapped module
 
-    def state_dict(self, prefix: str = "", keep_vars: bool = False) -> "OrderedDict[str, Tensor]":  # type: ignore
+    def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, Tensor]":  # type: ignore
         """Return an unflattened state_dict."""
         with self.unflatten_params():
-            return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
+            return self.module.state_dict(*args, **kwargs)
 
     def flat_state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
         """Return the flattened state_dict."""
diff --git a/fairscale/utils/testing.py b/fairscale/utils/testing.py
index 683825d4c..323c0e4e7 100644
--- a/fairscale/utils/testing.py
+++ b/fairscale/utils/testing.py
@@ -294,11 +294,7 @@ def __init__(self, embed_dim: int, num_heads: int) -> None:
         self.ln_1 = nn.LayerNorm(embed_dim)
         self.ln_2 = nn.LayerNorm(embed_dim)
         self.attn = nn.MultiheadAttention(embed_dim, num_heads)  # type: ignore
-        self.mlp = nn.Sequential(
-            nn.Linear(embed_dim, embed_dim * 4),
-            nn.GELU(),
-            nn.Linear(embed_dim * 4, embed_dim),
-        )
+        self.mlp = nn.Sequential(nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Linear(embed_dim * 4, embed_dim),)
 
     def forward(self, *inputs: Any, **kwargs: Any) -> Tensor:
         x = inputs[0]
@@ -452,3 +448,19 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
         self._check("loss.device", loss.device, self.expected_loss_device)
 
         return loss
+
+
+@functools.lru_cache
+def get_cycles_per_ms() -> float:
+    """Approximate number of cycles per millisecond for torch.cuda._sleep
+
+    Copied from: github.com/pytorch/pytorch/blob/master/test/test_cuda.py
+    """
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+    start.record()
+    torch.cuda._sleep(1000000)
+    end.record()
+    end.synchronize()
+    cycles_per_ms = 1000000 / start.elapsed_time(end)
+    return cycles_per_ms
diff --git a/stubs/torch/nn/parameter.pyi b/stubs/torch/nn/parameter.pyi
index ae0ea0861..05f24df38 100644
--- a/stubs/torch/nn/parameter.pyi
+++ b/stubs/torch/nn/parameter.pyi
@@ -1,17 +1,18 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
-from .. import Tensor
+from typing import Optional
+from .. import Size, Tensor
 
 import builtins
 
 class Parameter(Tensor):
     # These are dynamic attributes added by shard_params_data_parallel class.
     # Added here for better type checking.
     _is_sharded: bool
-    _orig_size: int
-    _cpu_grad: Parameter
-    _full_param: Parameter
-    _fp32_shard: Parameter
-    _fp16_shard: Parameter
+    _orig_size: Size
+    _cpu_grad: Tensor
+    _full_param: Tensor
+    _fp32_shard: Tensor
+    _fp16_shard: Optional[Tensor]
 
     def __init__(self, data: Tensor, requires_grad: builtins.bool = True): ...
diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py
index e184c2584..289bded37 100644
--- a/tests/nn/data_parallel/test_shard_params_data_parallel.py
+++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py
@@ -7,6 +7,7 @@
 import itertools
 import sys
 import tempfile
+from typing import Dict
 import unittest
 from unittest import mock
 
@@ -14,8 +15,7 @@
 from torch import nn
 
 from fairscale.nn.data_parallel import ShardParamsDataParallel
-from fairscale.utils.testing import DeviceAndTypeCheckModule, objects_are_equal
-from typing import Dict
+from fairscale.utils.testing import DeviceAndTypeCheckModule, get_cycles_per_ms, objects_are_equal
 
 
 class DistributedTest(unittest.TestCase):
@@ -30,6 +30,7 @@ def setUp(self):
             raise unittest.SkipTest("NCCL doesn't support Windows, skipping test")
         if torch.cuda.device_count() < 2:
             raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping")
+        torch.manual_seed(0)  # keep everything deterministic
 
     @staticmethod
     def _train_for_several_steps(model, num_steps, autocast):
@@ -42,7 +43,7 @@ def _train_for_several_steps(model, num_steps, autocast):
             input = model.module.get_input(torch.device("cuda"))
             output = model(*input)
             loss = model.module.get_loss(input, output).to(model_device)
-            print(f'loss device: {loss.device}')
+            print(f"loss device: {loss.device}")
             assert loss.dtype == torch.float32
             loss.backward()
             optim.step()
@@ -122,9 +123,7 @@ def _test_dtypes(cfg: Dict, autocast, in_dtype, p_dtype, loss_dtype, reduce_dtyp
         orig_reduce_scatter = ShardParamsDataParallel._reduce_scatter
 
         model = DeviceAndTypeCheckModule(
-            expected_input_dtype=in_dtype,
-            expected_param_dtype=p_dtype,
-            expected_loss_dtype=loss_dtype,
+            expected_input_dtype=in_dtype, expected_param_dtype=p_dtype, expected_loss_dtype=loss_dtype,
         )
 
         def _reduce_scatter(self, tensor):
@@ -152,8 +151,7 @@ def test_transformer(self):
         for config in itertools.product([True, False], repeat=len(keys)):
             config = dict(zip(keys, config))
             spawn_and_init(
-                functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config),
-                world_size=2,
+                functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config), world_size=2,
             )
 
     def test_cpu_offload_and_cpu_grads(self):
@@ -164,12 +162,34 @@
     def test_cpu_offload_and_cuda_grads(self):
         # If grads are on gpu, but model and optimizer are on cpu, backward breaks.
         config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": False}
-        with self.assertRaises(Exception): # RuntimeError inside spawn
-            test_fn = functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False)
+        with self.assertRaises(Exception):  # RuntimeError inside spawn
+            test_fn = functools.partial(
+                self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False
+            )
             spawn_and_init(test_fn)
 
+    def test_delayed_optim_step(self):
+        # We use a model with a long CUDA delay right before the optimizer step.
+        # This tests our streams logic, and that we don't start the FP32 -> FP16
+        # transfer until after the optimization step completes.
+        config = {"mixed_precision": True}
+        test_fn = functools.partial(self._test_identical_outputs, self._delayed_optim_step_model, config)
+        spawn_and_init(test_fn)
+
     @classmethod
-    def _test_identical_outputs(cls, model_cls, config, rank, group, num_steps=3, use_cuda=True):
+    def _delayed_optim_step_model(cls, rank, group, config=None):
+        def _maybe_wrap(layer):
+            if config is not None:
+                return ShardParamsDataParallel(layer, group, **config)
+            return layer
+
+        model = nn.Sequential(
+            nn.Linear(8, 4), _maybe_wrap(nn.Linear(4, 16)), _maybe_wrap(nn.Linear(16, 4)), nn.Linear(4, 8),
+        )
+        return ModuleWithDelay(model, delay_after_loss_ms=250)
+
+    @classmethod
+    def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3, use_cuda=True):
         if config["mixed_precision"]:
             autocast = True
             # Force the compute dtype to be torch.float32 so that we get
@@ -181,14 +201,13 @@ def _test_identical_outputs(cls, model_cls, config, rank, group, num_steps=3, us
             autocast = False
 
         # Establish reference behavior with PyTorch DDP (+ optionally autocast).
-        model = nn.parallel.DistributedDataParallel(
-            model_cls().cuda(), device_ids=[rank], output_device=rank, process_group=group
-        )
+        model = model_init_fn(rank, group).cuda()
+        model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, process_group=group)
        ref_loss = cls._train_for_several_steps(model, num_steps, autocast)
         ref_state_dict = model.module.state_dict()
 
         # Confirm we get the same behavior using ShardParamsDataParallel.
-        model = ShardParamsDataParallel(model_cls(), group, **config)
+        model = ShardParamsDataParallel(model_init_fn(rank, group, config), group, **config)
         if use_cuda:
             model = model.cuda()
         else:
@@ -204,18 +223,14 @@ def _test_identical_outputs(cls, model_cls, config, rank, group, num_steps=3, us
 
 
 class TransformerWithSharedParams(nn.Module):
-    def __init__(self):
+    def __init__(self, *args, **kwargs):
         super().__init__()
         torch.manual_seed(0)  # keep everything deterministic
         d_model = 16
         d_vocab = 32
         self.embed_tokens = nn.Embedding(d_vocab, d_model)
         self.transformer = nn.Transformer(
-            d_model=d_model,
-            num_encoder_layers=2,
-            num_decoder_layers=2,
-            dim_feedforward=8,
-            dropout=0.1,
+            d_model=d_model, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=8, dropout=0.1,
         )
         self.output_proj = nn.Linear(d_model, d_vocab)
         # share the embedding and output projection weights
@@ -238,6 +253,25 @@ def get_loss(self, input, output):
         return nn.functional.cross_entropy(output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum")
 
 
+class ModuleWithDelay(nn.Module):
+    def __init__(self, module, delay_after_loss_ms):
+        super().__init__()
+        self.module = module
+        self.delay_after_loss_ms = delay_after_loss_ms
+
+    def get_input(self, device):
+        torch.manual_seed(1)  # keep everything deterministic
+        return (torch.rand(4, 8, device=device),)
+
+    def forward(self, x):
+        return self.module(x)
+
+    def get_loss(self, input, output):
+        loss = output.sum()
+        torch.cuda._sleep(int(self.delay_after_loss_ms * get_cycles_per_ms()))
+        return loss
+
+
 def spawn_and_init(fn, world_size=2, args=None):
     if args is None:
         args = ()
@@ -252,10 +286,7 @@ def spawn_and_init(fn, world_size=2, args=None):
 
 def distributed_init(rank, world_size, tmp_file):
     torch.distributed.init_process_group(
-        backend="nccl",
-        init_method="file://{}".format(tmp_file),
-        world_size=world_size,
-        rank=rank,
+        backend="nccl", init_method="file://{}".format(tmp_file), world_size=world_size, rank=rank,
    )
     torch.cuda.set_device(rank)
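Note (reviewer addition, not part of the patch): below is a hypothetical usage sketch of the nested-wrapping pattern that the new _is_root / shared _fp32_to_fp16_stream logic in this diff is exercising, modeled on the _delayed_optim_step_model test helper above. The helper name build_nested_model, the layer sizes, and the mixed_precision=True config are illustrative assumptions; it also assumes a NCCL process group has already been initialized (e.g. as in the distributed_init helper).

    # Hypothetical sketch: wrap a few inner layers individually, then the full
    # model, as the tests in this diff do. Assumes torch.distributed has been
    # initialized with the NCCL backend and a CUDA device is selected.
    import torch
    import torch.nn as nn

    from fairscale.nn.data_parallel import ShardParamsDataParallel


    def build_nested_model(group, **config):
        # Inner layers become child ShardParamsDataParallel instances; on the
        # first forward pass the outermost wrapper marks itself as root and
        # shares its FP32 -> FP16 transfer stream with these children.
        model = nn.Sequential(
            nn.Linear(8, 4),
            ShardParamsDataParallel(nn.Linear(4, 16), group, **config),
            ShardParamsDataParallel(nn.Linear(16, 4), group, **config),
            nn.Linear(4, 8),
        )
        return ShardParamsDataParallel(model, group, **config).cuda()


    group = torch.distributed.new_group()
    model = build_nested_model(group, mixed_precision=True)
    out = model(torch.rand(4, 8, device="cuda"))
    out.sum().backward()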