fsdp with custom process groups #2006

Merged 36 commits into dev from vchil/fsdp_process_group on Mar 1, 2023

Commits (36)
3a280cd  wip  (abhi-mosaic, Feb 11, 2023)
e84bbc6  remove exit  (abhi-mosaic, Feb 11, 2023)
855360b  wip  (abhi-mosaic, Feb 11, 2023)
acec8c7  Merge branch 'dev' into abhi/fsdp_process_group  (abhi-mosaic, Feb 23, 2023)
6fb7be2  updates  (abhi-mosaic, Feb 24, 2023)
8c8d0c2  print out  (abhi-mosaic, Feb 24, 2023)
f7ca399  remove fun features  (abhi-mosaic, Feb 24, 2023)
f245b23  add pg cache  (vchiley, Feb 24, 2023)
c3b43ca  pg testing  (vchiley, Feb 24, 2023)
a690d36  fix pg instantiation bug  (vchiley, Feb 25, 2023)
0976448  undo extra edits  (vchiley, Feb 26, 2023)
955cd2f  destroy process groups  (vchiley, Feb 26, 2023)
1193c34  lint  (vchiley, Feb 26, 2023)
e5f8011  add pg options  (vchiley, Feb 26, 2023)
14b74a5  rm destroy_process_group  (vchiley, Feb 26, 2023)
0a3d6f0  propagating license  (vchiley, Feb 26, 2023)
3fbe1c9  making links perminent  (vchiley, Feb 26, 2023)
f57eb74  Merge branch 'dev' into vchil/fsdp_process_group  (mvpatel2000, Feb 27, 2023)
74aec1a  enable mosaicml ui for more custom per module fsdp kwargs  (vchiley, Feb 27, 2023)
4a31f62  daya review comments  (vchiley, Feb 27, 2023)
fbe6a30  lint  (vchiley, Feb 27, 2023)
a6d895f  Merge branch 'dev' into vchil/fsdp_process_group  (vchiley, Feb 27, 2023)
d201777  clean up  (vchiley, Feb 27, 2023)
1e9ebe2  Merge branch 'dev' into vchil/fsdp_process_group  (vchiley, Feb 27, 2023)
d532a04  add amp_fp8 warning  (vchiley, Feb 27, 2023)
3d6822f  Merge branch 'dev' into vchil/fsdp_process_group  (vchiley, Feb 27, 2023)
1e252f7  Update composer/trainer/mosaic_fsdp.py  (vchiley, Feb 27, 2023)
15ac04e  adding to docs, dk review comments  (vchiley, Feb 28, 2023)
eb2efde  updt docs  (vchiley, Feb 28, 2023)
f7a27f7  Merge branch 'dev' into vchil/fsdp_process_group  (vchiley, Feb 28, 2023)
281844c  Apply suggestions from code review  (vchiley, Feb 28, 2023)
e57ae38  dk review  (vchiley, Feb 28, 2023)
c5d0f68  Merge branch 'dev' into vchil/fsdp_process_group  (vchiley, Feb 28, 2023)
28df3a1  Merge branch 'dev' into vchil/fsdp_process_group  (vchiley, Feb 28, 2023)
e236dbd  Merge branch 'dev' into vchil/fsdp_process_group  (mvpatel2000, Feb 28, 2023)
774bb8f  Merge branch 'dev' into vchil/fsdp_process_group  (mvpatel2000, Mar 1, 2023)
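
For context on how this change is exercised, here is a minimal usage sketch of the Trainer's fsdp_config. The 'sharding_strategy', 'mixed_precision', and 'backward_prefetch' keys appear in the diff below; the 'process_group' key and the value 'node' are assumptions inferred only from the PR title and commit messages ("add pg options", "add pg cache") and are not confirmed by this diff.

from composer import Trainer

# Hypothetical config sketch. 'sharding_strategy', 'mixed_precision', and
# 'backward_prefetch' are read from fsdp_config in the diff below;
# 'process_group' and its value 'node' are assumptions based on the PR
# title and commit messages, not on code shown here.
fsdp_config = {
    'sharding_strategy': 'FULL_SHARD',
    'mixed_precision': 'DEFAULT',
    'backward_prefetch': 'BACKWARD_POST',
    'process_group': 'node',  # assumed option: shard only within each node
}

trainer = Trainer(
    model=composer_model,           # assumed: an existing ComposerModel
    train_dataloader=train_loader,  # assumed: an existing DataLoader
    max_duration='1ep',
    fsdp_config=fsdp_config,
)
trainer.fit()
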
80 changes: 10 additions & 70 deletions composer/trainer/dist_strategy.py
@@ -123,23 +123,6 @@ def prepare_ddp_module(module: torch.nn.Module, find_unused_parameters: bool) ->
'with distributed support.')


def get_torch_dtype(dtype: Union[Precision, str]):
"""Convert common string representations of dtypes to torch dtypes."""
dtype = dtype.value if isinstance(dtype, Precision) else dtype
if dtype in ['float32', 'torch.float32', 'fp32']:
return torch.float32
elif dtype in ['float16', 'torch.float16', 'half', 'fp16', 'amp', 'amp_fp16']:
return torch.float16
elif dtype in ['bfloat16', 'bfloat', 'torch.bfloat16', 'bf16', 'amp_bf16']:
return torch.bfloat16
elif dtype in ['amp_fp8']:
# We use torch.bfloat16 by default for amp_fp8 as there is no
# fp8 datatype in PyTorch yet.
return torch.bfloat16
else:
raise ValueError(f'Not sure how to convert dtype={dtype} to a torch dtype.')


def prepare_fsdp_module(model: torch.nn.Module, optimizers: Optional[Union[torch.optim.Optimizer,
Sequence[torch.optim.Optimizer]]],
fsdp_config: Dict[str, Any], precision: Precision) -> None:
@@ -155,10 +138,12 @@ def prepare_fsdp_module(model: torch.nn.Module, optimizers: Optional[Union[torch
raise RuntimeError('To use FSDP with Composer, you must use torch>=1.13.0.')
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (apply_activation_checkpointing,
checkpoint_wrapper)
from torch.distributed.fsdp import (BackwardPrefetch, CPUOffload, FullyShardedDataParallel, MixedPrecision,
ShardingStrategy)
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp.flatten_params_wrapper import FlattenParamsWrapper

from composer.trainer.mosaic_fsdp import (MosaicFullyShardedDataParallel, backward_prefetch_map, get_cpu_offload,
get_mixed_precision, sharding_map)

if optimizers:
optimizers_tuple = ensure_tuple(optimizers)
if len(optimizers_tuple) != 1:
@@ -170,47 +155,16 @@ def prepare_fsdp_module(model: torch.nn.Module, optimizers: Optional[Union[torch
optim.param_groups.clear()
optim.state.clear()

sharding_map = {
'NO_SHARD': ShardingStrategy.NO_SHARD,
'SHARD_GRAD_OP': ShardingStrategy.SHARD_GRAD_OP,
'FULL_SHARD': ShardingStrategy.FULL_SHARD,
}
sharding_map_key = fsdp_config.get('sharding_strategy', 'FULL_SHARD').upper()
sharding_strategy = sharding_map[sharding_map_key]

cpu_offload = CPUOffload(offload_params=True) if fsdp_config.get('cpu_offload', False) else None
if cpu_offload is not None:
raise ValueError('FSDP CPU Offload not supported yet.')
cpu_offload = get_cpu_offload(cpu_offload=fsdp_config.get('cpu_offload', False))

mixed_precision = fsdp_config.get('mixed_precision', 'DEFAULT')
param_dtype = None
reduce_dtype = None
buffer_dtype = None
if isinstance(mixed_precision, dict):
param_dtype = mixed_precision.get('param_dtype', None)
if param_dtype is not None:
param_dtype = get_torch_dtype(param_dtype)
reduce_dtype = mixed_precision.get('reduce_dtype', None)
if reduce_dtype is not None:
reduce_dtype = get_torch_dtype(reduce_dtype)
buffer_dtype = mixed_precision.get('buffer_dtype', None)
if buffer_dtype is not None:
buffer_dtype = get_torch_dtype(buffer_dtype)
elif isinstance(mixed_precision, str):
mixed_precision = mixed_precision.upper()
if mixed_precision == 'FULL':
pass
elif mixed_precision == 'DEFAULT':
reduce_dtype = get_torch_dtype(precision)
buffer_dtype = torch.float32
elif mixed_precision == 'PURE':
param_dtype = get_torch_dtype(precision)
reduce_dtype = get_torch_dtype(precision)
buffer_dtype = get_torch_dtype(precision)
else:
raise ValueError(f'Unable to interpret mixed_precision={mixed_precision}')
else:
raise ValueError(f'Unable to interpret mixed_precision={mixed_precision}')
keep_low_precision_grads = fsdp_config.get('keep_low_precision_grads', False)
mixed_precision, param_dtype, _, _ = get_mixed_precision(precision,
mixed_precision=mixed_precision,
keep_low_precision_grads=keep_low_precision_grads)

# Note: FSDP does support the use of torch.float32 with sharding.
# They just never expected a user to pass in torch.float32 into mixed_precision as a param_dtype.
@@ -232,20 +186,6 @@ def prepare_fsdp_module(model: torch.nn.Module, optimizers: Optional[Union[torch
f'Consider using `amp` or `bf16` for precision or setting param_dtype in mixed_precision to `None` '
f'with sharding strategy `{sharding_map_key}.`')

keep_low_precision_grads = fsdp_config.get('keep_low_precision_grads', False)

mixed_precision = MixedPrecision(
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
buffer_dtype=buffer_dtype,
keep_low_precision_grads=keep_low_precision_grads,
)

backward_prefetch_map = {
'NONE': None,
'BACKWARD_PRE': BackwardPrefetch.BACKWARD_PRE,
'BACKWARD_POST': BackwardPrefetch.BACKWARD_POST,
}
backward_prefetch = backward_prefetch_map[fsdp_config.get('backward_prefetch', 'BACKWARD_POST').upper()]
min_params = int(float(fsdp_config.get('min_params', 1e9)))
activation_checkpointing = fsdp_config.get('activation_checkpointing', False)
@@ -346,7 +286,7 @@ def _auto_wrap_policy(module: torch.nn.Module, recurse: bool, unwrapped_params:
else:
return is_large

fsdp_obj = FullyShardedDataParallel(
fsdp_obj = MosaicFullyShardedDataParallel(
obj,
sharding_strategy=sharding_strategy,
auto_wrap_policy=_auto_wrap_policy,
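
For readers following the refactor, below is a sketch of what the helpers now imported from composer.trainer.mosaic_fsdp plausibly look like, reconstructed from the inline logic this diff removes from dist_strategy.py. The actual contents of mosaic_fsdp.py are not part of this diff, so the signatures and the simplifications here (string-only precision handling, omitted dict-form mixed_precision) are assumptions.

# Sketch reconstructed from the removed inline logic above; the real
# composer/trainer/mosaic_fsdp.py may differ in details.
import torch
from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision, ShardingStrategy

sharding_map = {
    'NO_SHARD': ShardingStrategy.NO_SHARD,
    'SHARD_GRAD_OP': ShardingStrategy.SHARD_GRAD_OP,
    'FULL_SHARD': ShardingStrategy.FULL_SHARD,
}

backward_prefetch_map = {
    'NONE': None,
    'BACKWARD_PRE': BackwardPrefetch.BACKWARD_PRE,
    'BACKWARD_POST': BackwardPrefetch.BACKWARD_POST,
}


def get_torch_dtype(dtype: str) -> torch.dtype:
    """Convert common string representations of dtypes to torch dtypes."""
    if dtype in ('float32', 'torch.float32', 'fp32'):
        return torch.float32
    if dtype in ('float16', 'torch.float16', 'half', 'fp16', 'amp', 'amp_fp16'):
        return torch.float16
    if dtype in ('bfloat16', 'bfloat', 'torch.bfloat16', 'bf16', 'amp_bf16', 'amp_fp8'):
        # amp_fp8 falls back to bfloat16 since PyTorch has no fp8 dtype yet.
        return torch.bfloat16
    raise ValueError(f'Not sure how to convert dtype={dtype} to a torch dtype.')


def get_cpu_offload(cpu_offload: bool = False):
    """Mirror of the removed inline check: CPU offload remains unsupported."""
    if cpu_offload:
        raise ValueError('FSDP CPU Offload not supported yet.')
    return None


def get_mixed_precision(precision: str, mixed_precision='DEFAULT', keep_low_precision_grads=False):
    """Map a precision string plus a mixed_precision spec to FSDP's MixedPrecision.

    Only the string branch ('FULL'/'DEFAULT'/'PURE') of the removed logic is
    sketched; the dict form (explicit param/reduce/buffer dtypes) is omitted.
    """
    param_dtype = reduce_dtype = buffer_dtype = None
    mixed_precision = mixed_precision.upper()
    if mixed_precision == 'FULL':
        pass  # keep params, reductions, and buffers in float32
    elif mixed_precision == 'DEFAULT':
        reduce_dtype = get_torch_dtype(precision)  # low-precision gradient reduction
        buffer_dtype = torch.float32
    elif mixed_precision == 'PURE':
        param_dtype = reduce_dtype = buffer_dtype = get_torch_dtype(precision)
    else:
        raise ValueError(f'Unable to interpret mixed_precision={mixed_precision}')

    fsdp_mixed_precision = MixedPrecision(
        param_dtype=param_dtype,
        reduce_dtype=reduce_dtype,
        buffer_dtype=buffer_dtype,
        keep_low_precision_grads=keep_low_precision_grads,
    )
    return fsdp_mixed_precision, param_dtype, reduce_dtype, buffer_dtype

This matches the call site in the diff, which unpacks four return values: mixed_precision, param_dtype, _, _ = get_mixed_precision(...). In Composer the precision argument is actually a Precision enum; the real helper presumably accepts that as well.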