Commit
Only sync fp32_to_fp16 stream for the top-most (root) ShardParams wrapper (#42)

* Only sync fp32_to_fp16 stream for the top-most (root) ShardParams wrapper

* Fix mypy, add test, address some comments

* Add missing assert

* Comments
myleott authored Feb 2, 2021
1 parent 5bb212f commit cbd243e
Showing 5 changed files with 158 additions and 82 deletions.
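
The gist of the change, in miniature: each wrapper decides lazily, on its first forward pass, whether it is the top-most (root) instance; the root creates a single fp32-to-fp16 CUDA stream, hands it to every nested wrapper, and is the only one that synchronizes that stream with the default stream. The sketch below is a simplified standalone illustration of that pattern, not the fairscale implementation: the class TinyWrapper, its _pre_forward_init, and the omission of all sharding and mixed-precision logic are assumptions made for brevity, and it needs a CUDA device to run.

from typing import Optional

import torch
import torch.nn as nn


class TinyWrapper(nn.Module):
    """Nested wrappers share one fp32->fp16 stream; only the root syncs it."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module
        self._is_root: Optional[bool] = None  # unknown until the first forward pass
        self._fp32_to_fp16_stream: Optional[torch.cuda.Stream] = None

    def _pre_forward_init(self) -> None:
        if self._is_root is None:
            # No enclosing TinyWrapper claimed us as a child, so we are the root;
            # mark every nested TinyWrapper as a non-root child.
            self._is_root = True
            for name, m in self.named_modules():
                if name != "" and isinstance(m, TinyWrapper):
                    m._is_root = False
        if self._is_root and self._fp32_to_fp16_stream is None:
            # The root creates the stream once and shares it with all children.
            self._fp32_to_fp16_stream = torch.cuda.Stream()
            for name, m in self.named_modules():
                if name != "" and isinstance(m, TinyWrapper):
                    m._fp32_to_fp16_stream = self._fp32_to_fp16_stream
        if self._is_root:
            # Only the root waits on the default stream (the point of this commit);
            # children reuse the already-synchronized shared stream.
            self._fp32_to_fp16_stream.wait_stream(torch.cuda.current_stream())

    def forward(self, *args, **kwargs):
        self._pre_forward_init()
        return self.module(*args, **kwargs)
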
120 changes: 76 additions & 44 deletions fairscale/nn/data_parallel/shard_params_data_parallel.py
@@ -11,6 +11,7 @@
import torch.distributed as dist
from torch.distributed import ProcessGroup
import torch.nn as nn
+from torch.nn import Parameter

from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.utils.containers import (
@@ -133,6 +134,11 @@ def __init__(
for n, p in self.named_parameters():
assert getattr(p, "_is_sharded", False), f"found unsharded parameter: {n} ; {p.size()}"

+# Flag to indicate if this instance is wrapped by any other
+# ShardParamsDataParallel instances. This flag is only set after the
+# first forward pass.
+self._is_root: Optional[bool] = None

@torch.no_grad()
def _shard_initial_params(self) -> None:
"""
@@ -195,8 +201,8 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
"""Intercept state setting and perform needed changes on params."""
super().__setstate__(state)

-def fixup(p: torch.nn.Parameter, size: int) -> torch.nn.Parameter:
-assert isinstance(p, torch.nn.Parameter)
+def fixup(p: Parameter, size: torch.Size) -> Parameter:
+assert isinstance(p, Parameter)
p.data = p.data.clone() # move tensors out of shared memory
# Ignore mypy error since we add additional fields to a param.
p._is_sharded = True
@@ -247,49 +253,71 @@ def load_local_state_dict(
return self.module.load_state_dict(state_dict, strict)

@torch.no_grad()
-def _pre_forward_init(self) -> None:
-did_init = False
-for p in self.params:
-if hasattr(p, "_full_param"):
-continue
-did_init = True
-assert p._is_sharded
+def _init_param(self, p: Parameter) -> None:
+assert p._is_sharded
+assert not hasattr(p, "_full_param")

-p._fp32_shard = p.data
+p._fp32_shard = p.data

-if self.mixed_precision:
-assert p._fp32_shard.dtype == torch.float32
+if self.mixed_precision:
+assert p._fp32_shard.dtype == torch.float32

-if self.cpu_offload:
-assert p._fp32_shard.device == torch.device("cpu")
-p._fp32_shard = p._fp32_shard.pin_memory()
+if self.cpu_offload:
+assert p._fp32_shard.device == torch.device("cpu")
+p._fp32_shard = p._fp32_shard.pin_memory()

-p._fp16_shard = torch.zeros_like(
-p._fp32_shard,
-device=self.compute_device,
-dtype=self.compute_dtype,
-)
-free_storage_(p._fp16_shard)
-p._full_param = torch.zeros(p._orig_size, device=self.compute_device, dtype=self.compute_dtype)
-else:
-p._fp16_shard = None # use _fp32_shard
-p._full_param = p._fp32_shard.new_empty(p._orig_size)
+p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
+free_storage_(p._fp16_shard)
+else:
+p._fp16_shard = None # use _fp32_shard

-p._full_param = p._full_param.to(dtype=self.compute_dtype, device=self.compute_device)
-free_storage_(p._full_param)
+p.data = p._fp32_shard

-p.data = p._fp32_shard
+p._full_param = torch.zeros(p._orig_size, device=self.compute_device, dtype=self.compute_dtype)
+free_storage_(p._full_param)

-if self.move_grads_to_cpu:
-if self.mixed_precision and not self.fp32_reduce_scatter:
-grad_dtype = torch.float16
-else:
-grad_dtype = torch.float32
-p._cpu_grad = torch.zeros_like(p.data, dtype=grad_dtype, device="cpu").pin_memory()
+if self.move_grads_to_cpu:
+if self.mixed_precision and not self.fp32_reduce_scatter:
+grad_dtype = torch.float16
+else:
+grad_dtype = torch.float32
+p._cpu_grad = torch.zeros_like(p.data, dtype=grad_dtype, device="cpu").pin_memory()

-if did_init:
-self._fp32_to_fp16_stream = torch.cuda.Stream()
-self._fp32_to_fp16_stream.wait_stream(torch.cuda.current_stream())
+@torch.no_grad()
+def _pre_forward_init(self) -> None:
+first_time_params = [p for p in self.params if not hasattr(p, "_full_param")]
+for p in first_time_params:
+self._init_param(p)
+
+if len(first_time_params) > 0:
+if self._is_root is None:
+# This implies that no other ShardParamsDataParallel instance
+# wraps this instance, otherwise it would have already set this
+# flag to False.
+self._is_root = True
+
+# As the root, we now set all children instances to False.
+for n, m in self.named_modules():
+if n != "" and isinstance(m, ShardParamsDataParallel):
+assert m._is_root is None
+m._is_root = False
+
+if self._is_root:
+# Stream for moving FP32 master params (which may be on CPU) to
+# FP16 for computation. We share this stream with all children
+# instances, which allows them to overlap transfers across the
+# forward pass without synchronizing with the default stream.
+self._fp32_to_fp16_stream = torch.cuda.Stream()
+
+for n, m in self.named_modules():
+if n != "" and isinstance(m, ShardParamsDataParallel):
+m._fp32_to_fp16_stream = self._fp32_to_fp16_stream
+
+assert self._is_root is not None
+if self._is_root:
+# The top-most (root) instance needs to synchronize with the default
+# stream, so we don't move the FP32 master weights prematurely.
+self._fp32_to_fp16_stream.wait_stream(torch.cuda.current_stream())

def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
self._pre_forward_init()
@@ -313,8 +341,12 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
# initialized with the correct dtype and size.
self._use_fp32_param_shard()

-if not torch.is_grad_enabled():
-return outputs
+if torch.is_grad_enabled():
+outputs = self._register_pre_backward_hooks(outputs)
+
+return outputs
+
+def _register_pre_backward_hooks(self, outputs: Any) -> Any:

# Register pre-backward hook to run before the wrapped module's backward.
pre_backward_hook_has_run = [False]
@@ -352,7 +384,7 @@ def _register_post_backward_hooks(self) -> None:
p._shard_bwd_hook = (grad_acc, handle)

@torch.no_grad()
-def _post_backward_hook(self, param: torch.nn.Parameter, *unused: Any) -> None:
+def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
if param.grad is None:
return
if param.grad.requires_grad:
@@ -407,7 +439,7 @@ def _use_full_params(self) -> None:
p.data = p._full_param

@torch.no_grad()
-def _free_full_params(self, params: Optional[List[torch.nn.Parameter]] = None) -> None:
+def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
"""Free up storage for full parameters."""
if params is None:
params = self.params
@@ -423,15 +455,15 @@ def _free_full_params(self, params: Optional[List[torch.nn.Parameter]] = None) -
free_storage_(p._full_param)

@torch.no_grad()
-def _use_fp32_param_shard(self, params: Optional[List[torch.nn.Parameter]] = None) -> None:
+def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
"""Use FP32 shard for a list of params."""
if params is None:
params = self.params
for p in params:
p.data = p._fp32_shard

@torch.no_grad()
-def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[torch.nn.Parameter]] = None) -> None:
+def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = None) -> None:
"""Cast FP32 param shard to FP16 for a list of params."""
if params is None:
params = self.params
@@ -444,7 +476,7 @@ def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[torch.nn.Paramet
torch.cuda.current_stream().wait_stream(self._fp32_to_fp16_stream)

@torch.no_grad()
-def _free_fp16_param_shard(self, params: Optional[List[torch.nn.Parameter]] = None) -> None:
+def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
"""Free storage for FP16 shards for a list of params."""
if params is None:
params = self.params
4 changes: 2 additions & 2 deletions fairscale/nn/misc/flatten_params_wrapper.py
@@ -130,10 +130,10 @@ def __getattr__(self, name: str) -> Any:
except AttributeError:
return getattr(self.module, name) # fallback to wrapped module

-def state_dict(self, prefix: str = "", keep_vars: bool = False) -> "OrderedDict[str, Tensor]": # type: ignore
+def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, Tensor]": # type: ignore
"""Return an unflattened state_dict."""
with self.unflatten_params():
-return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
+return self.module.state_dict(*args, **kwargs)

def flat_state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
"""Return the flattened state_dict."""
22 changes: 17 additions & 5 deletions fairscale/utils/testing.py
@@ -294,11 +294,7 @@ def __init__(self, embed_dim: int, num_heads: int) -> None:
self.ln_1 = nn.LayerNorm(embed_dim)
self.ln_2 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(embed_dim, num_heads) # type: ignore
-self.mlp = nn.Sequential(
-nn.Linear(embed_dim, embed_dim * 4),
-nn.GELU(),
-nn.Linear(embed_dim * 4, embed_dim),
-)
+self.mlp = nn.Sequential(nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Linear(embed_dim * 4, embed_dim),)

def forward(self, *inputs: Any, **kwargs: Any) -> Tensor:
x = inputs[0]
@@ -452,3 +448,19 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
self._check("loss.device", loss.device, self.expected_loss_device)

return loss


+@functools.lru_cache
+def get_cycles_per_ms() -> float:
+"""Approximate number of cycles per millisecond for torch.cuda._sleep
+Copied from: github.com/pytorch/pytorch/blob/master/test/test_cuda.py
+"""
+start = torch.cuda.Event(enable_timing=True)
+end = torch.cuda.Event(enable_timing=True)
+start.record()
+torch.cuda._sleep(1000000)
+end.record()
+end.synchronize()
+cycles_per_ms = 1000000 / start.elapsed_time(end)
+return cycles_per_ms
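
For context, a helper like get_cycles_per_ms is normally paired with torch.cuda._sleep (the private PyTorch busy-wait used above) to enqueue GPU work of roughly known wall-clock length on the current stream, which makes stream overlap and synchronization effects observable in tests. A minimal usage sketch, assuming a CUDA device and that the helper is imported from fairscale.utils.testing as added above:

import torch

from fairscale.utils.testing import get_cycles_per_ms

# Enqueue roughly 100 ms of GPU busy-waiting on the current stream; work
# submitted to other streams can then be timed against it.
torch.cuda._sleep(int(100 * get_cycles_per_ms()))
torch.cuda.synchronize()
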
13 changes: 7 additions & 6 deletions stubs/torch/nn/parameter.pyi
@@ -1,17 +1,18 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

-from .. import Tensor
+from typing import Optional
+from .. import Size, Tensor
import builtins

class Parameter(Tensor):
# These are dynamic attributes added by shard_params_data_parallel class.
# Added here for better type checking.
_is_sharded: bool
-_orig_size: int
-_cpu_grad: Parameter
-_full_param: Parameter
-_fp32_shard: Parameter
-_fp16_shard: Parameter
+_orig_size: Size
+_cpu_grad: Tensor
+_full_param: Tensor
+_fp32_shard: Tensor
+_fp16_shard: Optional[Tensor]

def __init__(self, data: Tensor, requires_grad: builtins.bool = True): ...
