Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FSDP with full state dict #7487

Merged
merged 79 commits into from
May 24, 2021
Merged
Show file tree
Hide file tree
Changes from 54 commits
Commits
Show all changes
79 commits
Select commit Hold shift + click to select a range
89f284d
Fix some test errors
Mar 23, 2021
80cfbff
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 23, 2021
536c132
checkpoint consolidation
Mar 24, 2021
f172101
Update ddp_spawn.py
shuyingsunshine21 Mar 24, 2021
bf70e43
Update test_metric_result_integration.py
shuyingsunshine21 Mar 24, 2021
ea74906
Update test_results.py
shuyingsunshine21 Mar 24, 2021
a9aae99
Update utils.py
shuyingsunshine21 Mar 24, 2021
70fe5da
Update utils.py
shuyingsunshine21 Mar 24, 2021
0d23d75
Update test_all_gather_grad.py
shuyingsunshine21 Mar 24, 2021
ca6f98b
Update test_all_gather_grad.py
shuyingsunshine21 Mar 24, 2021
c5053da
Merge pull request #1 from shuyingsunshine21/shuyingsunshine21-checkp…
shuyingsunshine21 Mar 24, 2021
9d4a2b8
Update test_results.py
shuyingsunshine21 Mar 24, 2021
7635b4f
Revert "Update test_results.py"
shuyingsunshine21 Mar 24, 2021
d64f90c
Revert "Merge pull request #1 from shuyingsunshine21/shuyingsunshine2…
shuyingsunshine21 Mar 24, 2021
dcdcd29
Revert "Update test_all_gather_grad.py"
shuyingsunshine21 Mar 24, 2021
8651d54
Revert "Update utils.py"
shuyingsunshine21 Mar 24, 2021
15f4b9e
Revert "Update utils.py"
shuyingsunshine21 Mar 24, 2021
250d0aa
Revert "Update test_results.py"
shuyingsunshine21 Mar 24, 2021
6c095b2
Revert "Update test_metric_result_integration.py"
shuyingsunshine21 Mar 24, 2021
8222dc9
Revert "Update ddp_spawn.py"
shuyingsunshine21 Mar 24, 2021
3a9fde9
Revert "checkpoint consolidation"
shuyingsunshine21 Mar 24, 2021
7a369f4
Revert "Revert "checkpoint consolidation""
shuyingsunshine21 Mar 24, 2021
b4a0b9e
Revert "Revert "Revert "checkpoint consolidation"""
shuyingsunshine21 Mar 24, 2021
5cf1db1
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 24, 2021
0ce7e05
Revert "Revert "Update ddp_spawn.py""
shuyingsunshine21 Mar 24, 2021
fe9736d
Revert "Revert "Update test_metric_result_integration.py""
shuyingsunshine21 Mar 24, 2021
c314ef6
Revert "Revert "Update test_results.py""
shuyingsunshine21 Mar 24, 2021
c3feda0
Revert "Revert "Update utils.py""
shuyingsunshine21 Mar 24, 2021
c759477
Revert "Revert "Update test_all_gather_grad.py""
shuyingsunshine21 Mar 24, 2021
7a8e540
Merge branch 'master' of https://github.com/shuyingsunshine21/pytorch…
Mar 24, 2021
ab8b849
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 24, 2021
4e67db2
modify distributed environment to make test pass
Mar 24, 2021
67b6188
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 25, 2021
179d47e
rebase
Apr 8, 2021
f9afa07
rebase to upstream master
Apr 8, 2021
b461e44
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Apr 8, 2021
e1bbc4d
fix version for ddp plugin test
Apr 8, 2021
8270d0d
fix
Apr 8, 2021
803d5dd
fix
Apr 8, 2021
ce1a19b
changelog
Apr 8, 2021
c6a13be
Update CHANGELOG.md
carmocca Apr 9, 2021
e274758
Merge pull request #3 from shuyingsunshine21/ddp_plugin_test_fix
shuyingsunshine21 Apr 10, 2021
f337156
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Apr 14, 2021
35bb931
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Apr 15, 2021
c938a9c
rebase
May 11, 2021
ad93cde
rebase
May 11, 2021
b5d989d
fsdp with full state dict
May 11, 2021
3b86dd2
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
May 11, 2021
d7c33ab
fix missing import
May 11, 2021
1537637
modify unitest
May 12, 2021
c128e03
fix
May 12, 2021
18aeb9e
fix
May 12, 2021
8adbe0c
fix typo
May 12, 2021
70f8bf0
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
May 13, 2021
7fb7c60
modify test and add changelog
May 19, 2021
362d1a8
rebase
May 19, 2021
310b0ee
fix
May 19, 2021
e7bef62
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 19, 2021
77bd7c8
limit max_epoch to 1 for testing
May 20, 2021
f64a54c
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
May 20, 2021
a5ee90c
rebase
May 20, 2021
af31f6a
test
May 20, 2021
a044f58
fix
May 20, 2021
78897a2
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
May 20, 2021
5fac05c
update
May 20, 2021
8fb38dc
testing remove special for multi gpu
May 20, 2021
53be8f4
assert gpu
May 21, 2021
2fdab94
add assertion for gpu
May 21, 2021
ffb985d
fix
May 21, 2021
3189f4c
Re-enable special test, use ModelCheckpoint
May 21, 2021
c33fbcb
Fix paths
May 21, 2021
0067aeb
Fix path passing
May 21, 2021
320c35d
test
May 22, 2021
5ea37a5
test
May 22, 2021
9e599c2
fix test
May 22, 2021
6c96625
fix
May 22, 2021
32cc552
pre-commit format
May 22, 2021
25329fe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 22, 2021
f43a36f
Merge branch 'fsdp' of https://github.com/shuyingsunshine21/pytorch-l…
May 22, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def on_train_start(self) -> None:
torch.cuda.empty_cache()

def teardown(self) -> None:
# Note: this would be problematic for large model (which could not fit in one GPU)
# as FSDP module.to(device) would first summon all parameters
# (TODO: need to figure out solution)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@shuyingsunshine21 as a short-term workaround, could we add teardown to the training type plugin? and have the accelerator's teardown call the training type plugin instead?

this way we could start moving all the device-specific logic into each plugin to prep for the next refactor here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sg, will have a separate PR for that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#7579 for reference.

Will remove this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO #7579 is not in line with the accelerator design and I believe we should not continue these patterns. I will comment there on the PR.

self.lightning_module.cpu()
Copy link
Contributor

@ananthsub ananthsub May 11, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@awaelchli @justusschock @SeanNaren this is pointing out that teardown will be training type specific, not necessarily just device specific. if we're rethinking this in #7324 i think combining training type and accelerator, or parameterizing training type to accept what devices are being used, will make this easier to work through


# clean up memory
Expand Down
20 changes: 13 additions & 7 deletions pytorch_lightning/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import ( # noqa: F401
FullyShardedNativeMixedPrecisionPlugin,
)
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401
Expand All @@ -15,6 +18,7 @@
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.fully_sharded import DDPFullyShardedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401
Expand All @@ -32,24 +36,26 @@
"DDP2Plugin",
"DDPPlugin",
"DDPSpawnPlugin",
"DDPFullyShardedPlugin",
"DeepSpeedPlugin",
"DeepSpeedPrecisionPlugin",
"DoublePrecisionPlugin",
"HorovodPlugin",
"NativeMixedPrecisionPlugin",
"PrecisionPlugin",
"ShardedNativeMixedPrecisionPlugin",
"FullyShardedNativeMixedPrecisionPlugin"
"SingleDevicePlugin",
"SingleTPUPlugin",
"TPUHalfPrecisionPlugin",
"TPUSpawnPlugin",
'RPCPlugin',
'RPCSequentialPlugin',
'TrainingTypePlugin',
'ParallelPlugin',
'Plugin',
'DDPShardedPlugin',
'DDPSpawnShardedPlugin',
"RPCPlugin",
"RPCSequentialPlugin",
"TrainingTypePlugin",
"ParallelPlugin",
"Plugin",
"DDPShardedPlugin",
"DDPSpawnShardedPlugin",
]

from pathlib import Path
Expand Down
3 changes: 3 additions & 0 deletions pytorch_lightning/plugins/precision/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import ( # noqa: F401
FullyShardedNativeMixedPrecisionPlugin,
)
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
Expand Down
46 changes: 46 additions & 0 deletions pytorch_lightning/plugins/precision/fully_sharded_native_amp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union

from torch.nn import Module
from torch.optim import Optimizer

from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType


class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would we need this if the precision plugin was owned by the training type plugin?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

then should not be needed.

"""Mixed Precision for Full Sharded Training"""

precision = "mixed"

def clip_gradients(
self,
optimizer: Optimizer,
clip_val: Union[int, float],
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.VALUE,
model: Optional[Module] = None,
) -> None:
clip_val = float(clip_val)
if clip_val <= 0:
return
# see https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html
# section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect
# for FSDP module. To overcome this, needs to call sharded_module.clip_grad_norm(clip_val)
# however we rely on LightningModule's configure_sharded_model to wrap FSDP, it would be hard to
# trace back the root FSDP. Now we only support clip by value.
assert (
gradient_clip_algorithm == GradClipAlgorithmType.VALUE
), "`gradient_clip_algorithm`: `norm` is currently not supported for `FullyShardedNativeMixedPrecisionPlugin`"
self.clip_grad_by_value(optimizer, clip_val)
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.fully_sharded import DDPFullyShardedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401
Expand Down
206 changes: 206 additions & 0 deletions pytorch_lightning/plugins/training_type/fully_sharded.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from typing import Any, Dict, Generator, List, Optional, Union

import torch
from torch import Tensor
from torch.nn import Module

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE

if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
from fairscale.nn import default_auto_wrap_policy, enable_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel


class DDPFullyShardedPlugin(DDPPlugin):

def __init__(
self,
cpu_offload: bool = False,
flatten_parameters: bool = True,
reshard_after_forward: bool = True,
move_grads_to_cpu: Optional[bool] = None,
fp32_reduce_scatter: Optional[bool] = None,
compute_dtype: Optional[torch.dtype] = None,
bucket_cap_mb: int = 25,
# pyre-ignore[9]
carmocca marked this conversation as resolved.
Show resolved Hide resolved
min_num_params: int = 1e8,
state_dict_to_cpu: bool = True,
parallel_devices: Optional[List[torch.device]] = None,
# pyre-ignore[9]
cluster_environment: ClusterEnvironment = None,
):
"""
Plugin for Fully Sharded Data Parallel provided by FairScale.

Full Sharded Training shards the entire model across all available GPUs, allowing you to scale model
size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
at parity with PyTorch DDP, whilst scaling our model sizes dramatically. The technique is similar
to ZeRO-Stage 3 but has been built for upstreaming to PyTorch.
`For more information: https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html`.
.. warning:: ``FullyShardedPlugin`` is in beta and subject to change.

Defaults have been set and options have been exposed, but may require configuration
based on your level of memory/speed efficiency. We suggest having a look at this PR for more information.
`https://github.com/facebookresearch/fairscale/pull/413`

Many of the helpful doc strings below came from the original FairScale documentation:
`https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html`

Arguments:
cpu_offload: Offload FP32 params to CPU. Only usable in precision=16 mode.
(Default: False).
move_grads_to_cpu: Moves gradient shards to CPU after reduction.
Only disable if using CPU based optimizers
(Default to ``cpu_offload``).
flatten_parameters: Flattens parameter into single contiguous tensor for speed efficiency
(Default: True).
reshard_after_forward: Reshard parameters after the forward pass, which saves memory but slows
down training. This is only relevant when resharding individual layers.
(Default: True).
fp32_reduce_scatter: Reduce-Scatter gradients in FP32. Only relevant in mixed precision
(Default: None).
compute_dtype: dtype for full parameters for computation. Default to torch.float32,
unless using mixed precision, in which case defaults to torch.float16.
(Default: None).
bucket_cap_mb: bucket parameters so that gradient reduction
can potentially overlap with backward computation.
bucket_cap_mb controls the bucket size in MegaBytes (MB).
Buckets are sub-divided based on world_size,
so the max shard size is roughly bucket_cap_mb / world_size.
Values <= 0 disable bucketing.
(Default: 25).
min_num_params: Number of parameters to wrap when using FairScale ``auto_wrap``.
(Default: 1e8)
state_dict_to_cpu: Whether to return parameters (returned by :func:`state_dict`) on CPU device.
If ``False``, this will default to ``compute_device``.
(Defautl: True).
"""

super().__init__(
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
)
self.cpu_offload = cpu_offload
self.move_grads_to_cpu = move_grads_to_cpu
self.flatten_parameters = flatten_parameters
self.reshard_after_forward = reshard_after_forward
self.fp32_reduce_scatter = fp32_reduce_scatter
self.compute_dtype = compute_dtype
self.bucket_cap_mb = bucket_cap_mb
self.min_num_params = min_num_params
self.state_dict_device = torch.device("cpu") if state_dict_to_cpu else None
self._process_group = None

@property
def process_group(self):
if self._process_group is None:
self._process_group = torch.distributed.new_group()
return self._process_group

def setup_distributed(self):
super().setup_distributed()
if self.root_device.type == "cuda":
Copy link
Contributor

@ananthsub ananthsub May 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does FSDP work for CPU too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it should not, cc @SeanNaren

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haven't tested it, but iirc it relies on CUDA streams being available

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@shuyingsunshine21 in this case let's assert that since the DDP plugin this extends supports ddp_cpu. do you think that should happen here or in the accelerator connector?

additionally, should the FSDP plugin extend the parallel plugin instead of the DDP plugin? could we run into other issues if the ddp plugin is changed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's assert that since the DDP plugin this extends supports ddp_cpu. do you think that should happen here or in the accelerator connector

I think we could add assertion here. As we are thinking to flatten/merge accelerator/plugin, the AcceleratorStrategy should take care of this? wdyt?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should the FSDP plugin extend the parallel plugin instead of the DDP plugin? could we run into other issues if the ddp plugin is changed?

that is a good question, need to take a look

torch.cuda.set_device(self.root_device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: add return type hints


@contextlib.contextmanager
def model_sharded_context(self) -> Generator:
precision = self.lightning_module.trainer.precision

def wrap_policy(*args, **kwargs):
return default_auto_wrap_policy(*args, **kwargs, min_num_params=self.min_num_params)

with enable_wrap(
wrapper_cls=FullyShardedDataParallel,
auto_wrap_policy=wrap_policy,
process_group=self.process_group,
cpu_offload=self.cpu_offload,
move_grads_to_cpu=self.move_grads_to_cpu,
flatten_parameters=self.flatten_parameters,
mixed_precision=precision == "mixed",
reshard_after_forward=self.reshard_after_forward,
fp32_reduce_scatter=self.fp32_reduce_scatter,
compute_dtype=self.compute_dtype,
bucket_cap_mb=self.bucket_cap_mb,
state_dict_device=self.state_dict_device,
):
yield

def connect(self, model: Module) -> None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is needed when we load from ckpt, as now we checkpoint the full state, to load, we need the unwrapped model first, and load. This allows us to configure shard (i.e. wrap) afterwards.

Added unittest for making sure that for different stages start (on_x_start), _assert_layer_fsdp_instance passes.

super().connect(model)
model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False)
if not model_call_configure_sharded_model_hook:
# if model has not called configure sharded model, we reset
# the training type plugin's call_configure_sharded_model_hook
# to give trainer a chance to configure.
self.call_configure_sharded_model_hook = True

def configure_ddp(self):
if not self.cpu_offload:
# When using CPU Offload, FSDP will manage the CUDA movement for us.
# Note: this would be problematic for large model (which could not fit in one GPU)
# as FSDP module.to(device) would first summon all parameters
# (TODO: need to figure out solution)
self.model_to_device()

# setup optimizers after fully sharded has wrapped the lightning module
self.lightning_module.trainer.accelerator.setup_optimizers(self.lightning_module.trainer)

def pre_dispatch(self):
if self.sync_batchnorm:
self.model = self.configure_sync_batchnorm(self.model)
self.configure_ddp()
self.barrier()

def model_to_device(self):
# ensure we update the device type in the lightning module
self.lightning_module.to(self.root_device)

def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
# Currently it is same as default TrainingTypePlugin, i.e. return
# the full state dict for FSDP, in the future, we will provide sharded
# state dict.
return super().lightning_module_state_dict()

@property
def setup_optimizers_in_pre_dispatch(self) -> bool:
# Setup optimizers after the Fully Sharded Model has been made
return True

def training_step(self, *args, **kwargs):
return self.model.training_step(*args, **kwargs)

def validation_step(self, *args, **kwargs):
return self.model.validation_step(*args, **kwargs)

def test_step(self, *args, **kwargs):
return self.model.test_step(*args, **kwargs)

def predict_step(self, *args, **kwargs):
return self.model.predict_step(*args, **kwargs)

def post_training_step(self):
pass

@classmethod
def register_plugins(cls, plugin_registry: Dict):
plugin_registry.register(
"fsdp",
cls,
description="Fully sharded training with checkpointing the full state dict.",
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for this, by looking at this facebookresearch/fairscale#413

probably we could register one with reshard_after_forward=False for speed, maybe call, fsdp_fast?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After some thought not sure how important this is; currently a user has to wrap their model in nested FSDP wrap or auto_wraps, so they probably want reshard_after_forward to be set to True; I'm not entirely sure the benefit if you wrap intermediate layers and do not call reshard_after_forward to release memory!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we name it ddp_fsdp for consistency or ddp_fully_sharded (preferred) ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This already exists in the highest level (in the accelerator connector via the Enum it's called ddp_fully_sharded.

But going off #7259 (comment) I think we should introduce sdp fsdp as aliases in the plugin registry for sharded DDP and fully sharded DDP, as they kinda match ddp better!

Loading