
Commit

update
tchaton committed Feb 24, 2021
1 parent c36e00a commit 59dbb83
Showing 2 changed files with 42 additions and 9 deletions.
8 changes: 8 additions & 0 deletions pytorch_lightning/plugins/precision/native_amp.py
@@ -61,6 +61,7 @@ def backward(
         # unscale gradient to allow analyze within `on_after_backward`
         if not should_accumulate and model.automatic_optimization:
             self.scaler.unscale_(optimizer)
+            self.move_grad_to_cpu(model.trainer.model)
 
         return closure_loss
 
@@ -88,6 +89,13 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
         self.scaler.step(optimizer)
         self.scaler.update()
 
+    def move_grad_to_cpu(self, model):
+        if hasattr(model, "cpu_offload"):
+            if model.cpu_offload:
+                for param in model.params:
+                    param._cpu_grad.copy_(param.grad.data, non_blocking=True)
+                    param.grad.data = param._cpu_grad
+
     @contextmanager
     def train_step_context(self) -> Generator[autocast, None, None]:
         """Enable autocast context"""
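The ordering the new hook relies on: gradients are unscaled on the GPU first, and only afterwards copied into the CPU buffers, so the scaler's inf/NaN checks and any CPU-side optimizer see true gradient values. Below is a minimal standalone sketch of that ordering, assuming a FairScale-style wrapped model that exposes params, per-parameter _cpu_grad buffers and a cpu_offload flag as referenced in the diff above; the helper name is illustrative, not part of this commit.

import torch
from torch.cuda.amp import GradScaler


def unscale_then_offload(scaler: GradScaler, optimizer: torch.optim.Optimizer, model) -> None:
    # 1. Unscale while the gradients still live on the GPU, so inf/NaN checks
    #    and the later optimizer step operate on true gradient values.
    scaler.unscale_(optimizer)
    # 2. Only then copy each gradient shard into its pre-allocated, pinned CPU
    #    buffer and re-point .grad at it, mirroring move_grad_to_cpu above.
    if getattr(model, "cpu_offload", False):
        for param in model.params:
            param._cpu_grad.copy_(param.grad.data, non_blocking=True)
            param.grad.data = param._cpu_grad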
43 changes: 34 additions & 9 deletions pytorch_lightning/plugins/training_type/full_sharded.py
@@ -22,29 +22,53 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _FAIRSCALE_AVAILABLE:
-    from fairscale.nn.data_parallel import FullyShardedDataParallel
+    from fairscale.nn.data_parallel.fully_sharded_data_parallel import (
+        FullyShardedDataParallel, Parameter, TrainingState)
 
     from pytorch_lightning.overrides.fairscale import (
         LightningFullShardedDataParallel,
         unwrap_lightning_module_full_sharded,
     )
 
 
+class LightningFullyShardedDataParallel(FullyShardedDataParallel):
+    def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
+        """Hook to call on each param after the reduce-scatter."""
+        assert torch.cuda.current_stream() == self._streams["post_backward"]
+        assert param.grad is not None
+        self.assert_state(TrainingState.BACKWARD)
+        param.grad.data = reduced_grad
+        # Cast grad to param's dtype (typically FP32). Note: we do this
+        # before the move_grads_to_cpu step so that this entire hook remains
+        # non-blocking. The downside is a bit more D2H transfer in that case.
+        if self.mixed_precision:
+            param.grad.data = param.grad.data.to(dtype=param.data.dtype)
+        # Optionally move gradients to CPU, typically used if one is running
+        # the optimizer on the CPU.
+        # issues with this part
+
+        # This part needs to be done after unscaling the gradients.
+        #if self.move_grads_to_cpu:
+        # param._cpu_grad.copy_(param.grad.data, non_blocking=True)
+        # param.grad.data = param._cpu_grad
+        # Don't let this memory get reused until after the transfers.
+        #reduced_grad.record_stream(torch.cuda.current_stream())
+
+
 class FullShardedPlugin(DDPPlugin):
 
     def __init__(
         self,
         cpu_offload: bool = True,
         flatten_parameters: bool = False,
-        reshard_after_forward: bool = True,
-        move_grads_to_cpu: Optional[bool] = None,
-        fp32_reduce_scatter: Optional[bool] = None,
+        reshard_after_forward: bool = False,
+        fp32_reduce_scatter: Optional[bool] = False,
         compute_dtype: Optional[torch.dtype] = None,
         bucket_cap_mb: int = 25,
         parallel_devices: Optional[List[torch.device]] = None,
         num_nodes: int = 1,
         cluster_environment: ClusterEnvironment = None,
-        sync_batchnorm: Optional[bool] = False
+        sync_batchnorm: Optional[bool] = False,
     ):
         """
@@ -72,7 +96,7 @@ def __init__(
             cpu_offload: Offload FP32 params to CPU. Only useable in precision=16 mode (default: False).
-            move_grads_to_cpu: Moves gradient shards to CPU after reducation.
+            move_grads_to_cpu: Moves gradient shards to CPU after reduction.
                 Only disable if using CPU based optimizers (defaults to ``cpu_offload``).
             flatten_parameters: Flattens parameter into single contiguous tensor for speed efficiency
@@ -105,26 +129,27 @@ def __init__(
             raise MisconfigurationException("Currently sync batch norm is not supported by Full Sharded Training.")
         super().__init__(parallel_devices, num_nodes, cluster_environment, sync_batchnorm=sync_batchnorm)
         self.cpu_offload = cpu_offload
-        self.move_grads_to_cpu = move_grads_to_cpu
         self.flatten_parameters = flatten_parameters
         self.reshard_after_forward = reshard_after_forward
         self.fp32_reduce_scatter = fp32_reduce_scatter
         self.compute_dtype = compute_dtype
         self.bucket_cap_mb = bucket_cap_mb
 
     def configure_ddp(self):
-        precision = self.lightning_module.trainer.precision
+        trainer = self.lightning_module.trainer
+        precision = trainer.precision
         self.model = FullyShardedDataParallel(
             LightningFullShardedDataParallel(self.model),
             cpu_offload=self.cpu_offload,
-            move_grads_to_cpu=self.move_grads_to_cpu,
+            move_grads_to_cpu=self.cpu_offload,
             flatten_parameters=self.flatten_parameters,
             mixed_precision=precision == "mixed",
             reshard_after_forward=self.reshard_after_forward,
             fp32_reduce_scatter=self.fp32_reduce_scatter,
             compute_dtype=self.compute_dtype,
             bucket_cap_mb=self.bucket_cap_mb,
         )
+        trainer.accelerator.setup_optimizers(trainer)
 
     @property
     def lightning_module(self) -> LightningModule:
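For context, a hypothetical usage sketch of the plugin as configured by this commit; it is not part of the diff, and it assumes FullShardedPlugin is importable from this module and accepted through Trainer(plugins=...). With cpu_offload=True the wrapper now also offloads gradients (move_grads_to_cpu=self.cpu_offload), and the optimizers are re-created after wrapping via setup_optimizers.

import pytorch_lightning as pl
from pytorch_lightning.plugins.training_type.full_sharded import FullShardedPlugin

# cpu_offload requires 16-bit precision, so native AMP
# (NativeMixedPrecisionPlugin) drives the unscale-then-offload path above.
trainer = pl.Trainer(
    gpus=2,
    precision=16,
    plugins=FullShardedPlugin(cpu_offload=True, flatten_parameters=True),
)
# trainer.fit(model)  # model is any LightningModule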
