
Commit

adding to docs, dk review comments
vchiley committed Feb 28, 2023
1 parent 1e252f7 commit 58f210b
Showing 3 changed files with 54 additions and 19 deletions.
12 changes: 6 additions & 6 deletions composer/trainer/dist_strategy.py
@@ -141,8 +141,8 @@ def prepare_fsdp_module(model: torch.nn.Module, optimizers: Optional[Union[torch
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp.flatten_params_wrapper import FlattenParamsWrapper

from composer.trainer.mosaic_fsdp import (MosaicFullyShardedDataParallel, _get_cpu_offload, _get_mixed_precision,
backward_prefetch_map, sharding_map)
from composer.trainer.mosaic_fsdp import (MosaicFullyShardedDataParallel, backward_prefetch_map, get_cpu_offload,
get_mixed_precision, sharding_map)

if optimizers:
optimizers_tuple = ensure_tuple(optimizers)
@@ -158,13 +158,13 @@ def prepare_fsdp_module(model: torch.nn.Module, optimizers: Optional[Union[torch
sharding_map_key = fsdp_config.get('sharding_strategy', 'FULL_SHARD').upper()
sharding_strategy = sharding_map[sharding_map_key]

cpu_offload = _get_cpu_offload(cpu_offload=fsdp_config.get('cpu_offload', False))
cpu_offload = get_cpu_offload(cpu_offload=fsdp_config.get('cpu_offload', False))

mixed_precision = fsdp_config.get('mixed_precision', 'DEFAULT')
keep_low_precision_grads = fsdp_config.get('keep_low_precision_grads', False)
mixed_precision, param_dtype, _, _ = _get_mixed_precision(precision,
mixed_precision=mixed_precision,
keep_low_precision_grads=keep_low_precision_grads)
mixed_precision, param_dtype, _, _ = get_mixed_precision(precision,
mixed_precision=mixed_precision,
keep_low_precision_grads=keep_low_precision_grads)

# Note: FSDP does support the use of torch.float32 with sharding.
# They just never expected a user to pass in torch.float32 into mixed_precision as a param_dtype.
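
For reference, a minimal sketch (not part of this commit) of an fsdp_config whose keys are read by prepare_fsdp_module above, together with the helper calls it makes; the precision value is an assumption.

    from composer.trainer.mosaic_fsdp import get_cpu_offload, get_mixed_precision

    # illustrative fsdp_config; keys mirror the .get() calls in prepare_fsdp_module
    fsdp_config = {
        'sharding_strategy': 'FULL_SHARD',   # looked up in sharding_map
        'cpu_offload': False,                # True currently raises inside get_cpu_offload
        'mixed_precision': 'DEFAULT',
        'keep_low_precision_grads': False,
    }

    precision = 'amp_bf16'  # assumed; normally supplied by the Trainer

    cpu_offload = get_cpu_offload(cpu_offload=fsdp_config.get('cpu_offload', False))
    mixed_precision, param_dtype, _, _ = get_mixed_precision(
        precision,
        mixed_precision=fsdp_config.get('mixed_precision', 'DEFAULT'),
        keep_low_precision_grads=fsdp_config.get('keep_low_precision_grads', False))
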
35 changes: 22 additions & 13 deletions composer/trainer/mosaic_fsdp.py
@@ -26,9 +26,9 @@
'sharding_map',
'backward_prefetch_map',
'get_torch_dtype',
'_get_mixed_precision',
'_get_cpu_offload',
'_get_process_group',
'get_mixed_precision',
'get_cpu_offload',
'get_process_group',
'MosaicFullyShardedDataParallel',
]

@@ -64,8 +64,8 @@ def get_torch_dtype(dtype: Union[Precision, str]):
raise ValueError(f'Not sure how to convert dtype={dtype} to a torch dtype.')


def _get_mixed_precision(precision, mixed_precision='DEFAULT', keep_low_precision_grads=False):
"""Helper fn for configuring mixed_precision."""
def get_mixed_precision(precision, mixed_precision='DEFAULT', keep_low_precision_grads=False):
"""Helper function for configuring mixed_precision."""
param_dtype = None
reduce_dtype = None
buffer_dtype = None
@@ -105,16 +105,16 @@ def _get_mixed_precision(precision, mixed_precisio
return mixed_precision, param_dtype, reduce_dtype, buffer_dtype


def _get_cpu_offload(cpu_offload=False):
def get_cpu_offload(cpu_offload=False):
"""Helper fn for configuring cpu_offload."""
cpu_offload = CPUOffload(offload_params=True) if cpu_offload else None
if cpu_offload is not None:
raise ValueError('FSDP CPU Offload not supported yet.')
return cpu_offload
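
A short sketch (not from the commit) of how the helper behaves as written above:

    from composer.trainer.mosaic_fsdp import get_cpu_offload

    offload = get_cpu_offload(cpu_offload=False)   # returns None: no CPU offload configured
    # get_cpu_offload(cpu_offload=True) would build CPUOffload(offload_params=True)
    # and then raise ValueError('FSDP CPU Offload not supported yet.')
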


def _get_process_group(pg, process_group_cache=None):
"""Helper fn for configuring and/or retrieving process groups."""
def get_process_group(pg, process_group_cache=None):
"""Helper function for configuring and/or retrieving process groups."""
warnings.warn(f'Instantiating FSDP with custom process groups is an experimental feature.')

# Return regular process_groups as is, no caching
@@ -138,14 +138,14 @@ def _get_process_group(pg, process_group_cache=None):
if isinstance(pg, str) and pg.startswith('set'):
k = int(pg.strip('set'))
world_size = dist.get_world_size()
if world_size % k:
if world_size % k != 0:
raise RuntimeError(f'{world_size} must be divisible by set size ({k})')
start = dist.get_global_rank() // k * k
ranks = tuple(range(start, start + k))
elif isinstance(pg, str) and pg.startswith('mod'):
k = int(pg.strip('mod'))
world_size = dist.get_world_size()
if world_size % k:
if world_size % k != 0:
raise RuntimeError(f'{world_size} must be divisible by mod ({k})')
ranks = tuple(range(dist.get_global_rank() % k, world_size, k))
elif isinstance(pg, (list, tuple)):
@@ -156,7 +156,8 @@ def _get_process_group(pg, process_group_cache=None):
if process_group_cache is not None and ranks in process_group_cache:
warnings.warn(
f'On rank={dist.get_global_rank()} using cached process group with {ranks=}. ' +\
'Instantiate new process group if this is what was intended.'
'If the intention was to use a new process group, a new process group can be instantiated and passed ' +\
"in as an argument (`'process_group': newly_instantiated_process_group_object,`)"
)
return process_group_cache[ranks]
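
To make the 'setK' / 'modK' keywords concrete, a small illustration (not part of the commit) of the rank arithmetic above, assuming a world size of 8 and global rank 5:

    world_size, global_rank, k = 8, 5, 2

    # 'set2': groups of k consecutive ranks; rank 5 falls in the group (4, 5)
    start = global_rank // k * k
    set_ranks = tuple(range(start, start + k))                 # (4, 5)

    # 'mod2': ranks sharing the same remainder modulo k; rank 5 falls in (1, 3, 5, 7)
    mod_ranks = tuple(range(global_rank % k, world_size, k))   # (1, 3, 5, 7)
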

Expand Down Expand Up @@ -185,6 +186,10 @@ def _custom_recursive_wrap(module: nn.Module,
modified version of
https://github.com/pytorch/pytorch/blob/d922c29a22e4bf0fba49526f7536395eb8cd66f4/torch/distributed/fsdp/wrap.py#L353
which recursively wraps modules as FSDP modules for parameter sharding.
This modification enables the user to pass custom FSDP arguments for every wrapped module.
The added process_group_cache enables different FSDP modules to, when appropriate, use the
same process group instead of instantiating a new process group.
Automatically wrap child modules of *module* that meet the given
criteria with :func:`auto_wrap`. Does not rely on _ConfigAutoWrap.
@@ -253,14 +258,14 @@ def _custom_recursive_wrap(module: nn.Module,
'backward_prefetch'] not in backward_prefetch_map.values():
module_kwargs['backward_prefetch'] = backward_prefetch_map[module_kwargs['backward_prefetch'].upper()]
if 'cpu_offload' in module_kwargs and not isinstance(module_kwargs['cpu_offload'], CPUOffload):
module_kwargs['cpu_offload'] = _get_cpu_offload(cpu_offload=module_kwargs['cpu_offload'].upper())
module_kwargs['cpu_offload'] = get_cpu_offload(cpu_offload=module_kwargs['cpu_offload'].upper())
if 'mixed_precision' in module_kwargs and not isinstance(module_kwargs['mixed_precision'], MixedPrecision):
# `precision` needs to set `'mixed_precision'`, but `precision` is not part of fsdp kwargs
raise NotImplementedError(
f"Automated setting of custom per module mixed_precision is not implemented, but it can be set if `isinstance(module_kwargs['mixed_precision'], MixedPrecision)`"
)
if 'process_group' in module_kwargs:
module_kwargs['process_group'] = _get_process_group(module_kwargs['process_group'], process_group_cache)
module_kwargs['process_group'] = get_process_group(module_kwargs['process_group'], process_group_cache)

final_kwargs = {**kwargs, **module_kwargs}

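The merge above gives per-module settings precedence over the global FSDP kwargs; a tiny sketch with illustrative values (not from the commit):

    kwargs = {'process_group': None, 'backward_prefetch': 'BACKWARD_POST'}  # global FSDP args
    module_kwargs = {'process_group': 'node'}                               # returned by fsdp_wrap_fn
    final_kwargs = {**kwargs, **module_kwargs}
    assert final_kwargs == {'process_group': 'node', 'backward_prefetch': 'BACKWARD_POST'}
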
@@ -283,6 +288,10 @@ def _auto_wrap(
modified version of
https://github.com/pytorch/pytorch/blob/d922c29a22e4bf0fba49526f7536395eb8cd66f4/torch/distributed/fsdp/fully_sharded_data_parallel.py#L1252
FSDP's _auto_wrap recursively wraps modules as FSDP modules for parameter sharding.
This modification enables the user to pass custom FSDP arguments for every wrapped module.
The added process_group_cache enables different FSDP modules to, when appropriate, use the
same process group instead of instantiating a new process group.
Recursively auto wraps the root module given by the key "module" in
``auto_wrap_kwargs`` with the arguments in ``auto_wrap_kwargs`` and
26 changes: 26 additions & 0 deletions docs/source/notes/distributed_training.rst
@@ -306,6 +306,32 @@ In `gpt.py <https://github.com/mosaicml/examples/blob/6972fe3000d5a5480d8757ff71
A very similar auto wrap policy is provided for activation checkpointing, with analogous rule #1 that looks for :code:`module._activation_checkpointing = True | False` and rule #2 that looks for :code:`def activation_checkpointing_fn(module: torch.nn.Module) -> bool`.
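
As a minimal sketch (reusing the :code:`Block` class from the example above), such a function might simply be:

.. code:: python

    # checkpoint activations for every transformer Block
    def activation_checkpointing_fn(module: torch.nn.Module) -> bool:
        return isinstance(module, Block)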


**Experimental:** Composer enables users to specify custom FSDP args for individual wrapped modules. This is done by having :code:`fsdp_wrap_fn` return a dictionary of args instead of a bool.

.. code:: python

    # FSDP Wrap function
    def fsdp_wrap_fn(self, module):
        if isinstance(module, Block):
            return True
        # extends FSDP wrapping to custom args
        if isinstance(module, BlockWithCustomArgs):
            return {
                'process_group': 'node',
                'mixed_precision': 'FULL',
            }
        # defaults to False
        return False

While the user can instantiate and pass in process groups directly, Composer also enables process groups to be specified with the following keywords (where :code:`K` is an integer); a sketch using these keywords follows the list.

* :code:`'self'`: the degenerate case where each process group operates only within its own rank (:code:`'self'` == :code:`'set1'`). This is useful when you do not want a layer to be synchronized across accelerators.
* :code:`'node'`: instantiates process groups which operate within a node (:code:`'node'` == :code:`f'set{local_world_size}'`). This is useful for Expert Layers in MoE models.
* :code:`'local_rank_across_nodes'`: instantiates process groups with the same local rank across all nodes (:code:`'local_rank_across_nodes'` == :code:`f'mod{local_world_size}'`). This is useful for Tensor Parallel Layers.
* :code:`'setK'`: instantiates process groups which operate within a set of :code:`K` GPUs. This is useful for Expert Layers in MoE models.
* :code:`'modK'`: instantiates process groups which operate on every :code:`K`-th GPU. This is useful for Tensor Parallel Layers.
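
In the sketch below, the :code:`ExpertLayer` and :code:`TensorParallelLayer` class names are placeholders, not part of Composer.

.. code:: python

    def fsdp_wrap_fn(self, module):
        if isinstance(module, ExpertLayer):
            # shard expert weights only within each node
            return {'process_group': 'node'}
        if isinstance(module, TensorParallelLayer):
            # shard across every 8th rank (same local rank, assuming 8 GPUs per node)
            return {'process_group': 'mod8'}
        return isinstance(module, Block)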


Saving and Loading Sharded Checkpoints with FSDP
------------------------------------------------
To save and load sharded checkpoints with FSDP, you can make use of the field, :code:`state_dict_type` in :code:`fsdp_config`.
