Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FSDP integration #6152

Closed
wants to merge 77 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
78f1eb4
Add initial FSDP integration
Feb 23, 2021
c36e00a
Fix error in refactor
Feb 23, 2021
59dbb83
update
tchaton Feb 24, 2021
19a1440
Revert "update"
Feb 24, 2021
3b38615
Address reviews
Feb 24, 2021
5ff06ab
Fix doc string
Feb 24, 2021
36434f0
Even moar code review
Feb 24, 2021
c61a190
Add deprecation
Feb 24, 2021
1c4f011
Merge branch 'master' into feat/fsdp
Feb 25, 2021
02599e6
Fix name of test
Feb 25, 2021
e79977a
Integrate nesting, fix bugs across implementation
Mar 1, 2021
d15d4b5
Merge branch 'master' into feat/fsdp
Mar 2, 2021
ebf1818
Formatting types
Mar 2, 2021
290e8fd
Add additional tests for accelerator model
Mar 2, 2021
5c5f762
Fix import
Mar 2, 2021
d28438b
Few test fixes, expose params
Mar 3, 2021
ab591a8
Allow training_type_plugin to delay optimizer configure
Mar 3, 2021
23ccdb8
Merge branch 'feat/fsdp_2n' into feat/fsdp
Mar 3, 2021
a60f2c0
Add missing references to trainer, add a CPU accelerator based test
Mar 3, 2021
3d4e6df
Merge branch 'feat/fsdp_2n' into feat/fsdp
Mar 4, 2021
516bd04
Update for latest API changes to fairscale
Mar 9, 2021
9f8864f
Add base hook for model parallel
Mar 23, 2021
eac5344
fix callback signature
kaushikb11 Mar 25, 2021
32df0cb
Simplify hook
Mar 25, 2021
282a133
Add hook logic
Mar 25, 2021
7a94e72
add tests
kaushikb11 Mar 25, 2021
8091481
add property setter
kaushikb11 Mar 25, 2021
633fc77
add logic for being called once
kaushikb11 Mar 25, 2021
c99a36f
Update changelog
kaushikb11 Mar 25, 2021
a68c8d7
Merge branch 'master' into feat/model_parallel_hook
kaushikb11 Mar 25, 2021
9529a22
Fix
kaushikb11 Mar 25, 2021
3c1c782
fix return type
kaushikb11 Mar 25, 2021
7daba43
Merge branch 'master' into feat/fsdp
Mar 25, 2021
87ec222
Fix property name
Mar 25, 2021
966b2e5
Merge branch 'feat/model_parallel_hook' into feat/fsdp
Mar 25, 2021
5f6e039
Updaet wrapper, use latest fixes for hooks
Mar 25, 2021
b512e72
Swap hook order
Mar 25, 2021
8ba82df
Merge branch 'master' into feat/fsdp
Mar 29, 2021
1e5ca37
Small changes
Mar 29, 2021
936dc1a
Fixes
Mar 29, 2021
a6de18e
Remove activation checkpointing
Apr 1, 2021
8684f94
Turn off auto wrap by default
Apr 1, 2021
76091ae
Move to trainer.model
Apr 7, 2021
226d498
fix reference
Apr 7, 2021
cd63c10
Merge branch 'master' into feat/fsdp
Apr 7, 2021
b881e2f
Remove flag
Apr 7, 2021
e8959be
Fix imports
Apr 7, 2021
52478ac
Fix versions, update docs
Apr 7, 2021
b7f1896
Fix clip gradients
Apr 8, 2021
a62f8d8
Merge branch 'master' into feat/fsdp
Apr 10, 2021
69c33f1
Merge branch 'master' into feat/fsdp
Apr 14, 2021
9fa26c0
Fixes
Apr 14, 2021
56f23ce
pull
Apr 14, 2021
9ca3f0c
Few changes across the board
Apr 14, 2021
b53ba36
Fix imports
Apr 14, 2021
0da5249
Set none
Apr 14, 2021
90c6479
Swap to warnings
Apr 14, 2021
69d8178
Remove fairscale from container
Apr 14, 2021
a459d10
pull
Apr 14, 2021
a7842d9
Update dockers/base-cuda/Dockerfile
Apr 14, 2021
48ee83f
Add defaults, add test to ensure nested wrapper is set correctly
Apr 15, 2021
57a696c
Remove deprecation as this will be removed completely
Apr 15, 2021
36889b8
Check for nested FSDP wrappers, and omit wrapping algorithm
Apr 16, 2021
89b8cb5
Merge branch 'master' into feat/fsdp
Apr 16, 2021
0c1d2de
Update pytorch_lightning/trainer/connectors/accelerator_connector.py
Apr 21, 2021
592bb28
Address code review points
Apr 21, 2021
4e230c9
Merge branch 'master' into feat/fsdp
Apr 26, 2021
ca8e586
Add back missing model that was removed from clipping signature
Apr 26, 2021
54f501d
Do not pass model through, accelerator does it
Apr 26, 2021
02925cc
Merge branch 'master' into feat/fsdp
Apr 27, 2021
b67f1a9
Fix merge
Apr 27, 2021
132eb64
Fix imports
Apr 27, 2021
e6ce3cf
Changes to precision plugin
Apr 27, 2021
01153af
Require 2 GPU for multi gpu test
Apr 27, 2021
6cfe57d
Merge branch 'master' into feat/fsdp
May 2, 2021
efa81ab
Use callback in test, swap to DynamicLossScaler from fairscale to tes…
May 4, 2021
78d52b5
Disable loss scaler for now
May 4, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions pytorch_lightning/overrides/fairscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,48 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_FULLY_SHARDED_AVAILABLE


class LightningShardedDataParallel(_LightningModuleWrapperBase):
# Just do this for later docstrings
pass


LightningShardedDataParallel = None
if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel

class LightningShardedDataParallel(_LightningModuleWrapperBase):
# Just do this for later docstrings
pass

def unwrap_lightning_module_sharded(wrapped_model) -> LightningModule:
model = wrapped_model
if isinstance(model, ShardedDataParallel):
model = model.module

return unwrap_lightning_module(model)


class LightningFullyShardedModule(_LightningModuleWrapperBase):
# Just do this for later docstrings
pass


if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
from fairscale.nn import FlattenParamsWrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel

def unwrap_lightning_module_fully_sharded(wrapped_model) -> LightningModule:
"""
Unwrap the lightning module within the FSDP wrapper. This is recursive as FSDP can be nested, meaning
the LightningModule could be a few layers deep.
"""
model = wrapped_model
if isinstance(model, FullyShardedDataParallel):
model = unwrap_lightning_module_fully_sharded(model.module)
# Additional check if we're using a flattened parameters buffer
elif isinstance(model, FlattenParamsWrapper):
model = unwrap_lightning_module_fully_sharded(model.module)
if isinstance(model, _LightningModuleWrapperBase):
model = unwrap_lightning_module_fully_sharded(model.module)
return model
6 changes: 6 additions & 0 deletions pytorch_lightning/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import ( # noqa: F401
FullyShardedNativeMixedPrecisionPlugin,
)
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401
Expand All @@ -15,6 +18,7 @@
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.fully_sharded import FullyShardedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401
Expand All @@ -35,6 +39,8 @@
"DeepSpeedPlugin",
"DeepSpeedPrecisionPlugin",
"DoublePrecisionPlugin",
"FullyShardedPlugin",
"FullyShardedNativeMixedPrecisionPlugin",
"HorovodPlugin",
"NativeMixedPrecisionPlugin",
"PrecisionPlugin",
Expand Down
3 changes: 3 additions & 0 deletions pytorch_lightning/plugins/precision/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import ( # noqa: F401
FullyShardedNativeMixedPrecisionPlugin,
)
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
Expand Down
39 changes: 39 additions & 0 deletions pytorch_lightning/plugins/precision/fully_sharded_native_amp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast, Optional, Union

from torch.nn import Module
from torch.optim import Optimizer

from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE, GradClipAlgorithmType

if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
from fairscale.nn.data_parallel import FullyShardedDataParallel


class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
"""Mixed Precision for Full Sharded Training"""

def clip_gradients(
self,
optimizer: 'Optimizer',
clip_val: Union[int, float],
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
model: Optional[Module] = None
) -> None:
# Model manages clipping of gradients
model = cast(FullyShardedDataParallel, model)
# todo: expose norm type once precision plugin supports this.
model.clip_grad_norm_(clip_val, norm_type=2.0)
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.fully_sharded import FullyShardedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401
Expand Down
218 changes: 218 additions & 0 deletions pytorch_lightning/plugins/training_type/fully_sharded.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from typing import Dict, Generator, List, Optional

import torch

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
from fairscale.nn import auto_wrap, default_auto_wrap_policy, enable_wrap, FlattenParamsWrapper, wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel

from pytorch_lightning.overrides.fairscale import LightningFullyShardedModule, unwrap_lightning_module_fully_sharded


class FullyShardedPlugin(DDPPlugin):

def __init__(
self,
cpu_offload: bool = False,
flatten_parameters: bool = True,
reshard_after_forward: bool = True,
move_grads_to_cpu: Optional[bool] = None,
fp32_reduce_scatter: Optional[bool] = None,
compute_dtype: Optional[torch.dtype] = None,
bucket_cap_mb: int = 25,
automatic_module_wrap: bool = False,
min_num_params: int = 1e8,
parallel_devices: Optional[List[torch.device]] = None,
num_nodes: Optional[int] = None,
cluster_environment: ClusterEnvironment = None,
sync_batchnorm: Optional[bool] = None
):
"""

Provides capabilities to run training using the Full Sharded capabilities provided by FairScale.

Full Sharded Training shards the entire model across all available GPUs, allowing you to scale model
size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
at parity with PyTorch DDP, whilst scaling our model sizes dramatically. The technique is similar
to ZeRO-Stage 3 but has been built for upstreaming to PyTorch.

`For more information: https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html`.

.. warning:: ``FullyShardedPlugin`` is in beta and subject to change.

Defaults have been set and options have been exposed, but may require configuration
based on your level of memory/speed efficiency.
We suggest having a look at this PR for more information.
`https://github.com/facebookresearch/fairscale/pull/413`


Many of the helpful doc strings below came from the original FairScale documentation:
`https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html`

Arguments:

cpu_offload: Offload FP32 params to CPU. Only useable in precision=16 mode (default: False).

move_grads_to_cpu: Moves gradient shards to CPU after reduction.
Only disable if using CPU based optimizers (defaults to ``cpu_offload``).

flatten_parameters: Flattens parameter into single contiguous tensor for speed efficiency
(default: False).

reshard_after_forward: Reshard parameters after the forward pass, which saves memory but slows
down training. Only revelant when nesting FullyShardedDataParallel wrappers inside the model.
(default: False).

fp32_reduce_scatter: Reduce-Scatter gradients in FP32. Only relevant in mixed precision
(default: None)

compute_dtype: dtype for full parameters for computation. Default to torch.float32,
unless using mixed precision, in which case defaults to torch.float16.

bucket_cap_mb: bucket parameters so that gradient reduction
can potentially overlap with backward computation.
bucket_cap_mb controls the bucket size in MegaBytes (MB).
Buckets are sub-divided based on world_size,
so the max shard size is roughly bucket_cap_mb / world_size.
Values <= 0 disable bucketing. (Default: 25).

automatic_module_wrap: Automatically wrap the lightning module with Fully Sharded recursively.
Using ``min_num_params`` to determine the amount of parameters to wrap at a time.
(default: False)

min_num_params: Number of parameters to wrap when using FairScale ``auto_wrap``.
(default: 1e8)

"""
if not _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
raise MisconfigurationException(
"Full Sharded Training is not available. Install the latest FairScale via `pip install fairscale -U`"
)

super().__init__(parallel_devices, num_nodes, cluster_environment, sync_batchnorm)
self.cpu_offload = cpu_offload
self.move_grads_to_cpu = move_grads_to_cpu
self.flatten_parameters = flatten_parameters
self.reshard_after_forward = reshard_after_forward
self.fp32_reduce_scatter = fp32_reduce_scatter
self.compute_dtype = compute_dtype
self.bucket_cap_mb = bucket_cap_mb
self.automatic_module_wrap = automatic_module_wrap
self.min_num_params = min_num_params
self._process_group = None

@property
def process_group(self):
if self._process_group is None:
self._process_group = torch.distributed.new_group()
return self._process_group

def setup_distributed(self):
super().setup_distributed()
if self.root_device.type == "cuda":
torch.cuda.set_device(self.root_device)

@contextlib.contextmanager
def model_sharded_context(self) -> Generator:
precision = self.lightning_module.trainer.precision

def wrap_policy(*args, **kwargs):
return default_auto_wrap_policy(*args, **kwargs, min_num_params=self.min_num_params)

with enable_wrap(
wrapper_cls=FullyShardedDataParallel,
auto_wrap_policy=wrap_policy,
process_group=self.process_group,
cpu_offload=self.cpu_offload,
move_grads_to_cpu=self.move_grads_to_cpu,
flatten_parameters=self.flatten_parameters,
mixed_precision=precision == "mixed",
reshard_after_forward=self.reshard_after_forward,
fp32_reduce_scatter=self.fp32_reduce_scatter,
compute_dtype=self.compute_dtype,
bucket_cap_mb=self.bucket_cap_mb,
):
yield

def configure_ddp(self):
with self.model_sharded_context():
if self.automatic_module_wrap and not self._model_has_nested_fsdp():
self.model = auto_wrap(LightningFullyShardedModule(self.model))
if not isinstance(self.model, FullyShardedDataParallel):
self.model = wrap(self.model)
else:
self.model = wrap(LightningFullyShardedModule(self.model))
Comment on lines +159 to +164
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if manually wrapping contents inside the lightning module, is this final outer layer wrap needed? or could we defer this to the user in the lightning module too?

then we could not wrap model in the dummy LightningFullyShardedModule to map forward to one of the step functions. would it also mean users don't have to refer to self.trainer.model inside of the lightning module?

would this avoid the parameter flattening issue across stages?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I understand! Theoretically I think so, since then we're just using the LM as a wrapper. So the cases I see:

  1. User wraps nothing, expects module to be wrapped by Lightning, and potentially auto_wrap to handle recursive wrapping
  2. User wraps some of the layers in configure_sharded_model but then expects all other layers to be included in a higher wrapper class (wrap the entire LM)
  3. User wraps all of the layers in configure_sharded_model, doesn't require any high level wrapping

Solutions

  1. This should be default behaviour, i.e plugins=fsdp or plugins=fsdp_auto_wrap
  2. This should be the same as 1., i.e plugins=fsdp or plugins=fsdp_auto_wrap
  3. This could be plugins=fsdp_manual where we do not wrap the highest level module, allowing the user to do whatever they'd like in configure_optimizers.

In either case, it's important to fix the flattening issue for 1. and 2. which for most users trying out will be the first step I think. Thoughts @ananthsub?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes exactly those 3 cases!

To unblock the initial integration, I was wondering if we should start with option #3 to unblock power users in release candidates with the caveat that they are responsible for the full wrapping. Maybe this could can be option on the plugin as to whether the outer wrap on lightning module needs to be applied in order to distinguish between cases 2 and 3.

Completely agreed with you that most users will opt for cases 1 and 2, so we'll need to figure out the parameter flattening, whether in lightning or fairscale, but wanted to offer this as one way we could sequence these cases

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I started making changes to see if any issues arise with case 3 and a few observation:

The user still has to define a single model, may it be a Module containing modules in a sequential wrapper, or just defining their own model structure defining a forward function. This means self.model will still probably be required in every case for FSDP to work in configure_optimizers.

I also ran into an issue where clipping grad norms which in manual mode cannot be handled automatically, as we do not wrap the model:

class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
    """Mixed Precision for Full Sharded Training"""

    def clip_gradients(`
        self,
        optimizer: 'Optimizer',
        clip_val: Union[int, float],
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
        model: Optional[Module] = None
    ) -> None:
        # Model manages clipping of gradients
        model = cast(FullyShardedDataParallel, model)
        # todo: expose norm type once precision plugin supports this.
        model.clip_grad_norm_(clip_val, norm_type=2.0) # This breaks

A potential solution albeit not as elegant as I'd like, would be to go through the immediate children of the LightningModule, find the root FSDP module and call clip_grad_norm_ on it. I assume this will be a negligible cost added on top of the training loop but what are your thoughts @ananthsub?


if not self.cpu_offload:
# When using CPU Offload, FSDP will manage the CUDA movement for us
self.model_to_device()
# setup optimizers after fully sharded has wrapped the lightning module
self.lightning_module.trainer.accelerator.setup_optimizers(self.lightning_module.trainer)

def model_to_device(self):
self.model.to(self.root_device)
# ensure we update the device type in the lightning module
self.lightning_module.to(self.root_device)
Comment on lines +173 to +175
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we might need to be cautious about this, as fsdp_module.to(device) will summon full parameters first: https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L348-L367

and when we perform teardown for GPU memory cleanup, we have self.lightning_module.cpu()

https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/accelerators/gpu.py#L50-L55


def pre_dispatch(self):
if self.sync_batchnorm:
self.model = self.configure_sync_batchnorm(self.model)
self.configure_ddp()
self.barrier()

@property
def lightning_module(self) -> LightningModule:
return unwrap_lightning_module_fully_sharded(self.model)

def on_save(self, checkpoint: dict) -> dict:
state_dict = self.collate_state_dict()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SeanNaren , after getting detailed memory usage, I finally figured out why originally the full model fits in one GPU, but when checkpointing, it OOM

because in checkpoint_connector (https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L270-L277), we have

  model = self.trainer.lightning_module

  checkpoint = {
      'epoch': current_epoch,
      'global_step': global_step,
      'pytorch-lightning_version': pytorch_lightning.__version__,
      'state_dict': model.state_dict(),
  }

here, we try to collect again, this would double the size.

One easy workaround now, is to add

del  checkpoint['state_dict']

but this is not ideal, we summon the full parameters twice which is unnecessary.

I feel, we should modify that file to let training type plugin to control, something like trainer.accelerator.training_type_plugin.state_dict()

especially when we would like to collect only sharded state dict in the future.

cc @ananthsub

@min-xu-ai , I think this is the root cause for OOM, facebookresearch/fairscale#658 should not be problem (for setting state_dict_device=torch.device("cpu"), CPU OOM should be similar problem as we also double the model storage in CPU)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @shuyingsunshine21 for your help here! This makes sense since we're allocating memory new memory.

I agree with allowing the training type plugin to return the state dict, we already rely on the accelerator to dump the optimizer dicts. I'm happy to make the change!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SeanNaren , thanks, no worry, if you have not already made the change, I could help send a small PR for that.

checkpoint['state_dict'] = state_dict
return checkpoint

def collate_state_dict(self):
"""
Collects the models sharded state dict from all processes before returning.
Returns: The unsharded model state dict.
"""
state_dict = self.model.state_dict()
# Remove module prefix from state dict as this is the behaviour of state dict.
state_dict = {k.partition('module.')[2]: state_dict[k] for k in state_dict.keys()}
return state_dict

@property
def setup_optimizers_in_pre_dispatch(self) -> bool:
# Setup optimizers after the Fully Sharded Model has been made
return True

def _model_has_nested_fsdp(self):
for module in self.model.modules():
if isinstance(module, FullyShardedDataParallel):
return True
return False

@classmethod
def register_plugins(cls, plugin_registry: Dict):
plugin_registry.register("fsdp", cls, description="Fully Sharded with LightningModule wrap")
plugin_registry.register(
"fsdp_offload", cls, description="Fully Sharded Training with CPU Offloading.", cpu_offload=True
)
Loading