diff --git a/benchmarks/fp8/ms_amp/ddp.py b/benchmarks/fp8/ms_amp/ddp.py
index 25d9fc0a7bf..ce80cded35e 100644
--- a/benchmarks/fp8/ms_amp/ddp.py
+++ b/benchmarks/fp8/ms_amp/ddp.py
@@ -22,12 +22,11 @@
 import msamp
 import torch
 from fp8_utils import evaluate_model, get_training_utilities
-from packaging import version
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from accelerate import Accelerator
 from accelerate.state import AcceleratorState
-from accelerate.utils import FP8RecipeKwargs, set_seed
+from accelerate.utils import FP8RecipeKwargs, get_grad_scaler, set_seed
 
 
 MODEL_NAME = "bert-base-cased"
@@ -36,10 +35,7 @@
 
 def train_baseline(opt_level="O2"):
     set_seed(42)
-    if version.parse(torch.__version__) > version.parse("2.3"):
-        scaler = torch.amp.GradScaler("cuda")
-    else:
-        scaler = torch.cuda.amp.GradScaler()
+    scaler = get_grad_scaler()
     model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
     accelerator = Accelerator()
     device = accelerator.device
diff --git a/benchmarks/fp8/ms_amp/non_distributed.py b/benchmarks/fp8/ms_amp/non_distributed.py
index 5fc659117ca..6e4284baf3f 100644
--- a/benchmarks/fp8/ms_amp/non_distributed.py
+++ b/benchmarks/fp8/ms_amp/non_distributed.py
@@ -22,11 +22,10 @@
 import msamp
 import torch
 from fp8_utils import evaluate_model, get_training_utilities
-from packaging import version
 
 from accelerate import Accelerator
 from accelerate.state import AcceleratorState
-from accelerate.utils import FP8RecipeKwargs, set_seed
+from accelerate.utils import FP8RecipeKwargs, get_grad_scaler, set_seed
 
 
 MODEL_NAME = "bert-base-cased"
@@ -42,10 +41,7 @@ def train_baseline(opt_level="O2"):
 
     base_model_results = evaluate_model(model, eval_dataloader, METRIC)
     model.train()
-    if version.parse(torch.__version__) > version.parse("2.3"):
-        scaler = torch.amp.GradScaler("cuda")
-    else:
-        scaler = torch.cuda.amp.GradScaler()
+    scaler = get_grad_scaler()
 
     for batch in train_dataloader:
         batch = batch.to("cuda")
diff --git a/docs/source/package_reference/utilities.md b/docs/source/package_reference/utilities.md
index 9e7aece6df7..40d18e686d1 100644
--- a/docs/source/package_reference/utilities.md
+++ b/docs/source/package_reference/utilities.md
@@ -126,6 +126,10 @@ These include data operations that mimic the same `torch` ops but can be used on
 
 [[autodoc]] utils.gather_object
 
+[[autodoc]] utils.get_grad_scaler
+
+[[autodoc]] utils.get_mixed_precision_context_manager
+
 [[autodoc]] utils.listify
 
 [[autodoc]] utils.pad_across_processes
diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index de0ce0f374d..f35571a0356 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -32,7 +32,6 @@
 import torch
 import torch.utils.hooks as hooks
 from huggingface_hub import split_torch_state_dict_into_shards
-from packaging import version
 
 from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
 from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
@@ -78,6 +77,7 @@
     extract_model_from_parallel,
     gather,
     gather_object,
+    get_grad_scaler,
     get_mixed_precision_context_manager,
     get_pretty_name,
     is_bf16_available,
@@ -136,7 +136,6 @@
 
 
 if is_torch_xla_available():
-    import torch_xla.amp as xamp
     import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
 
@@ -484,25 +483,7 @@ def __init__(
             ):
                 raise ValueError(f"fp16 mixed precision requires a GPU (not {self.device.type!r}).")
             kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
-            if self.distributed_type == DistributedType.FSDP:
-                from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
-
-                self.scaler = ShardedGradScaler(**kwargs)
-            elif is_torch_xla_available(check_is_gpu=True):
-                self.scaler = xamp.GradScaler(**kwargs)
-            elif is_mlu_available():
-                self.scaler = torch.mlu.amp.GradScaler(**kwargs)
-            elif is_musa_available():
-                self.scaler = torch.musa.amp.GradScaler(**kwargs)
-            elif is_npu_available():
-                self.scaler = torch.npu.amp.GradScaler(**kwargs)
-            elif is_xpu_available():
-                self.scaler = torch.amp.GradScaler("xpu", **kwargs)
-            else:
-                if version.parse(torch.__version__) > version.parse("2.3"):
-                    self.scaler = torch.amp.GradScaler("cuda", **kwargs)
-                else:
-                    self.scaler = torch.cuda.amp.GradScaler(**kwargs)
+            self.scaler = get_grad_scaler(self.distributed_type, **kwargs)
 
         elif self.state.mixed_precision == "bf16" and self.distributed_type not in (
             DistributedType.DEEPSPEED,
@@ -526,10 +507,7 @@ def __init__(
                     )
                 elif self.distributed_type != DistributedType.DEEPSPEED:
                     # MS-AMP requires `GradScaler` even with bf16 autocast w/ single GPU or DDP:
-                    if version.parse(torch.__version__) > version.parse("2.3"):
-                        self.scaler = torch.amp.GradScaler("cuda")
-                    else:
-                        self.scaler = torch.cuda.amp.GradScaler()
+                    self.scaler = get_grad_scaler(**kwargs)
 
         # Start of internal step tracking
         self.step = 0
diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py
index 324fcd17886..0c57cfb3eb3 100644
--- a/src/accelerate/utils/__init__.py
+++ b/src/accelerate/utils/__init__.py
@@ -130,6 +130,7 @@
     dtype_byte_size,
     find_tied_parameters,
    get_balanced_memory,
+    get_grad_scaler,
     get_max_layer_size,
     get_max_memory,
     get_mixed_precision_context_manager,
diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py
index f4230c55994..181694b633c 100644
--- a/src/accelerate/utils/modeling.py
+++ b/src/accelerate/utils/modeling.py
@@ -1871,3 +1871,37 @@ def get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwarg
         return torch.autocast(device_type=device_type, **autocast_kwargs)
     else:
         return contextlib.nullcontext()
+
+
+def get_grad_scaler(distributed_type: DistributedType = None, **kwargs):
+    """
+    A generic helper which will initialize the correct `GradScaler` implementation based on the environment and return
+    it.
+
+    Args:
+        distributed_type (`DistributedType`, *optional*, defaults to None):
+            The type of distributed environment.
+        kwargs:
+            Additional arguments for the utilized `GradScaler` constructor.
+    """
+    if distributed_type == DistributedType.FSDP:
+        from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+
+        return ShardedGradScaler(**kwargs)
+    if is_torch_xla_available(check_is_gpu=True):
+        import torch_xla.amp as xamp
+
+        return xamp.GradScaler(**kwargs)
+    elif is_mlu_available():
+        return torch.mlu.amp.GradScaler(**kwargs)
+    elif is_musa_available():
+        return torch.musa.amp.GradScaler(**kwargs)
+    elif is_npu_available():
+        return torch.npu.amp.GradScaler(**kwargs)
+    elif is_xpu_available():
+        return torch.amp.GradScaler("xpu", **kwargs)
+    else:
+        if is_torch_version(">=", "2.3"):
+            return torch.amp.GradScaler("cuda", **kwargs)
+        else:
+            return torch.cuda.amp.GradScaler(**kwargs)
diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py
index 162009e76b6..66931779909 100644
--- a/src/accelerate/utils/operations.py
+++ b/src/accelerate/utils/operations.py
@@ -29,10 +29,10 @@
 from .imports import (
     is_npu_available,
     is_torch_distributed_available,
-    is_torch_version,
     is_torch_xla_available,
     is_xpu_available,
 )
+from .versions import is_torch_version
 
 
 if is_torch_xla_available():
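The snippet below is not part of the diff; it is a minimal usage sketch of the new `accelerate.utils.get_grad_scaler` helper, mirroring how the updated MS-AMP benchmarks call it. The model, optimizer, and batch are made-up placeholders, and the scale/step/update pattern is the standard torch AMP loop rather than anything introduced by this change.

    import torch

    from accelerate.utils import get_grad_scaler

    # Selects ShardedGradScaler for FSDP, the XLA/MLU/MUSA/NPU/XPU variants when those
    # backends are available, and otherwise a CUDA GradScaler (torch.amp on torch >= 2.3).
    scaler = get_grad_scaler()

    # Placeholder model/optimizer/data for illustration only.
    model = torch.nn.Linear(8, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    batch = torch.randn(4, 8, device="cuda")

    # Standard fp16 autocast + gradient-scaling step.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(batch).sum()

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()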