From 0924919707f944b1d523f7885d2b7acfd583d875 Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Mon, 7 Oct 2024 14:45:05 -0400
Subject: [PATCH 1/3] Refactor scaler to util

---
 src/accelerate/accelerator.py      | 28 +++-------------------------
 src/accelerate/utils/__init__.py   |  1 +
 src/accelerate/utils/modeling.py   | 28 ++++++++++++++++++++++++++++
 src/accelerate/utils/operations.py |  2 +-
 4 files changed, 33 insertions(+), 26 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index de0ce0f374d..9d9ff54ec36 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -32,7 +32,6 @@
 import torch
 import torch.utils.hooks as hooks
 from huggingface_hub import split_torch_state_dict_into_shards
-from packaging import version
 
 from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
 from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
@@ -78,6 +77,7 @@
     extract_model_from_parallel,
     gather,
     gather_object,
+    get_grad_scaler,
     get_mixed_precision_context_manager,
     get_pretty_name,
     is_bf16_available,
@@ -136,7 +136,6 @@
 
 
 if is_torch_xla_available():
-    import torch_xla.amp as xamp
     import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
 
@@ -484,25 +483,7 @@ def __init__(
             ):
                 raise ValueError(f"fp16 mixed precision requires a GPU (not {self.device.type!r}).")
             kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
-            if self.distributed_type == DistributedType.FSDP:
-                from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
-
-                self.scaler = ShardedGradScaler(**kwargs)
-            elif is_torch_xla_available(check_is_gpu=True):
-                self.scaler = xamp.GradScaler(**kwargs)
-            elif is_mlu_available():
-                self.scaler = torch.mlu.amp.GradScaler(**kwargs)
-            elif is_musa_available():
-                self.scaler = torch.musa.amp.GradScaler(**kwargs)
-            elif is_npu_available():
-                self.scaler = torch.npu.amp.GradScaler(**kwargs)
-            elif is_xpu_available():
-                self.scaler = torch.amp.GradScaler("xpu", **kwargs)
-            else:
-                if version.parse(torch.__version__) > version.parse("2.3"):
-                    self.scaler = torch.amp.GradScaler("cuda", **kwargs)
-                else:
-                    self.scaler = torch.cuda.amp.GradScaler(**kwargs)
+            self.scaler = get_grad_scaler(self.distributed_type == DistributedType.FSDP, **kwargs)
 
         elif self.state.mixed_precision == "bf16" and self.distributed_type not in (
             DistributedType.DEEPSPEED,
@@ -526,10 +507,7 @@ def __init__(
                 )
             elif self.distributed_type != DistributedType.DEEPSPEED:
                 # MS-AMP requires `GradScaler` even with bf16 autocast w/ single GPU or DDP:
-                if version.parse(torch.__version__) > version.parse("2.3"):
-                    self.scaler = torch.amp.GradScaler("cuda")
-                else:
-                    self.scaler = torch.cuda.amp.GradScaler()
+                self.scaler = get_grad_scaler(**kwargs)
 
         # Start of internal step tracking
         self.step = 0
diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py
index 324fcd17886..0c57cfb3eb3 100644
--- a/src/accelerate/utils/__init__.py
+++ b/src/accelerate/utils/__init__.py
@@ -130,6 +130,7 @@
     dtype_byte_size,
     find_tied_parameters,
     get_balanced_memory,
+    get_grad_scaler,
     get_max_layer_size,
     get_max_memory,
     get_mixed_precision_context_manager,
diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py
index f4230c55994..7eb48f48484 100644
--- a/src/accelerate/utils/modeling.py
+++ b/src/accelerate/utils/modeling.py
@@ -1871,3 +1871,31 @@ def get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwarg
         return torch.autocast(device_type=device_type, **autocast_kwargs)
     else:
         return contextlib.nullcontext()
+
+
+def get_grad_scaler(use_fsdp=False, **kwargs):
+    """
+    A generic helper which will initialize the correct `GradScaler` implementation based on the environment and return
+    it.
+    """
+    if use_fsdp:
+        from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+
+        return ShardedGradScaler(**kwargs)
+    if is_torch_xla_available(check_is_gpu=True):
+        import torch_xla.amp as xamp
+
+        return xamp.GradScaler(**kwargs)
+    elif is_mlu_available():
+        return torch.mlu.amp.GradScaler(**kwargs)
+    elif is_musa_available():
+        return torch.musa.amp.GradScaler(**kwargs)
+    elif is_npu_available():
+        return torch.npu.amp.GradScaler(**kwargs)
+    elif is_xpu_available():
+        return torch.amp.GradScaler("xpu", **kwargs)
+    else:
+        if is_torch_version(">=", "2.3"):
+            return torch.amp.GradScaler("cuda", **kwargs)
+        else:
+            return torch.cuda.amp.GradScaler(**kwargs)
diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py
index 162009e76b6..66931779909 100644
--- a/src/accelerate/utils/operations.py
+++ b/src/accelerate/utils/operations.py
@@ -29,10 +29,10 @@
 from .imports import (
     is_npu_available,
     is_torch_distributed_available,
-    is_torch_version,
     is_torch_xla_available,
     is_xpu_available,
 )
+from .versions import is_torch_version
 
 
 if is_torch_xla_available():
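
Usage sketch (illustrative note, not part of the patch series): with [PATCH 1/3] applied, a caller gets the right scaler from a single call instead of repeating the backend branching that was just removed from `Accelerator.__init__`. The snippet assumes a CUDA machine running fp16 autocast and an accelerate build that contains this patch; the tiny model, the optimizer, and the `growth_interval` value are made up for illustration, and the kwargs are simply forwarded to whichever `GradScaler` constructor the helper picks.

import torch

from accelerate.utils import get_grad_scaler

model = torch.nn.Linear(8, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# One call replaces the old per-backend if/elif chain; extra kwargs are passed
# through to the selected GradScaler implementation.
scaler = get_grad_scaler(growth_interval=1000)

inputs = torch.randn(4, 8, device="cuda")
targets = torch.randint(0, 2, (4,), device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
scaler.scale(loss).backward()  # scale the loss so fp16 gradients do not underflow
scaler.step(optimizer)         # unscales gradients and skips the step on inf/nan
scaler.update()                # adjusts the loss scale for the next iteration
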
""" if use_fsdp: from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler From 21cab59ff4aa2359aa65dc0d0e8a889c6c91031f Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Tue, 8 Oct 2024 10:54:37 -0400 Subject: [PATCH 3/3] Use the distributed_type directly --- benchmarks/fp8/ms_amp/ddp.py | 8 ++------ benchmarks/fp8/ms_amp/non_distributed.py | 8 ++------ src/accelerate/accelerator.py | 2 +- src/accelerate/utils/modeling.py | 8 ++++---- 4 files changed, 9 insertions(+), 17 deletions(-) diff --git a/benchmarks/fp8/ms_amp/ddp.py b/benchmarks/fp8/ms_amp/ddp.py index 25d9fc0a7bf..ce80cded35e 100644 --- a/benchmarks/fp8/ms_amp/ddp.py +++ b/benchmarks/fp8/ms_amp/ddp.py @@ -22,12 +22,11 @@ import msamp import torch from fp8_utils import evaluate_model, get_training_utilities -from packaging import version from torch.nn.parallel import DistributedDataParallel as DDP from accelerate import Accelerator from accelerate.state import AcceleratorState -from accelerate.utils import FP8RecipeKwargs, set_seed +from accelerate.utils import FP8RecipeKwargs, get_grad_scaler, set_seed MODEL_NAME = "bert-base-cased" @@ -36,10 +35,7 @@ def train_baseline(opt_level="O2"): set_seed(42) - if version.parse(torch.__version__) > version.parse("2.3"): - scaler = torch.amp.GradScaler("cuda") - else: - scaler = torch.cuda.amp.GradScaler() + scaler = get_grad_scaler() model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) accelerator = Accelerator() device = accelerator.device diff --git a/benchmarks/fp8/ms_amp/non_distributed.py b/benchmarks/fp8/ms_amp/non_distributed.py index 5fc659117ca..6e4284baf3f 100644 --- a/benchmarks/fp8/ms_amp/non_distributed.py +++ b/benchmarks/fp8/ms_amp/non_distributed.py @@ -22,11 +22,10 @@ import msamp import torch from fp8_utils import evaluate_model, get_training_utilities -from packaging import version from accelerate import Accelerator from accelerate.state import AcceleratorState -from accelerate.utils import FP8RecipeKwargs, set_seed +from accelerate.utils import FP8RecipeKwargs, get_grad_scaler, set_seed MODEL_NAME = "bert-base-cased" @@ -42,10 +41,7 @@ def train_baseline(opt_level="O2"): base_model_results = evaluate_model(model, eval_dataloader, METRIC) model.train() - if version.parse(torch.__version__) > version.parse("2.3"): - scaler = torch.amp.GradScaler("cuda") - else: - scaler = torch.cuda.amp.GradScaler() + scaler = get_grad_scaler() for batch in train_dataloader: batch = batch.to("cuda") diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 9d9ff54ec36..f35571a0356 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -483,7 +483,7 @@ def __init__( ): raise ValueError(f"fp16 mixed precision requires a GPU (not {self.device.type!r}).") kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {} - self.scaler = get_grad_scaler(self.distributed_type == DistributedType.FSDP, **kwargs) + self.scaler = get_grad_scaler(self.distributed_type, **kwargs) elif self.state.mixed_precision == "bf16" and self.distributed_type not in ( DistributedType.DEEPSPEED, diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index a1548b7fcf9..181694b633c 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -1873,18 +1873,18 @@ def get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwarg return 
From 21cab59ff4aa2359aa65dc0d0e8a889c6c91031f Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Tue, 8 Oct 2024 10:54:37 -0400
Subject: [PATCH 3/3] Use the distributed_type directly

---
 benchmarks/fp8/ms_amp/ddp.py             | 8 ++------
 benchmarks/fp8/ms_amp/non_distributed.py | 8 ++------
 src/accelerate/accelerator.py            | 2 +-
 src/accelerate/utils/modeling.py         | 8 ++++----
 4 files changed, 9 insertions(+), 17 deletions(-)

diff --git a/benchmarks/fp8/ms_amp/ddp.py b/benchmarks/fp8/ms_amp/ddp.py
index 25d9fc0a7bf..ce80cded35e 100644
--- a/benchmarks/fp8/ms_amp/ddp.py
+++ b/benchmarks/fp8/ms_amp/ddp.py
@@ -22,12 +22,11 @@
 import msamp
 import torch
 from fp8_utils import evaluate_model, get_training_utilities
-from packaging import version
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from accelerate import Accelerator
 from accelerate.state import AcceleratorState
-from accelerate.utils import FP8RecipeKwargs, set_seed
+from accelerate.utils import FP8RecipeKwargs, get_grad_scaler, set_seed
 
 
 MODEL_NAME = "bert-base-cased"
@@ -36,10 +35,7 @@
 
 def train_baseline(opt_level="O2"):
     set_seed(42)
-    if version.parse(torch.__version__) > version.parse("2.3"):
-        scaler = torch.amp.GradScaler("cuda")
-    else:
-        scaler = torch.cuda.amp.GradScaler()
+    scaler = get_grad_scaler()
     model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
     accelerator = Accelerator()
     device = accelerator.device
diff --git a/benchmarks/fp8/ms_amp/non_distributed.py b/benchmarks/fp8/ms_amp/non_distributed.py
index 5fc659117ca..6e4284baf3f 100644
--- a/benchmarks/fp8/ms_amp/non_distributed.py
+++ b/benchmarks/fp8/ms_amp/non_distributed.py
@@ -22,11 +22,10 @@
 import msamp
 import torch
 from fp8_utils import evaluate_model, get_training_utilities
-from packaging import version
 
 from accelerate import Accelerator
 from accelerate.state import AcceleratorState
-from accelerate.utils import FP8RecipeKwargs, set_seed
+from accelerate.utils import FP8RecipeKwargs, get_grad_scaler, set_seed
 
 
 MODEL_NAME = "bert-base-cased"
@@ -42,10 +41,7 @@ def train_baseline(opt_level="O2"):
     base_model_results = evaluate_model(model, eval_dataloader, METRIC)
     model.train()
 
-    if version.parse(torch.__version__) > version.parse("2.3"):
-        scaler = torch.amp.GradScaler("cuda")
-    else:
-        scaler = torch.cuda.amp.GradScaler()
+    scaler = get_grad_scaler()
 
     for batch in train_dataloader:
         batch = batch.to("cuda")
diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 9d9ff54ec36..f35571a0356 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -483,7 +483,7 @@ def __init__(
             ):
                 raise ValueError(f"fp16 mixed precision requires a GPU (not {self.device.type!r}).")
             kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
-            self.scaler = get_grad_scaler(self.distributed_type == DistributedType.FSDP, **kwargs)
+            self.scaler = get_grad_scaler(self.distributed_type, **kwargs)
 
         elif self.state.mixed_precision == "bf16" and self.distributed_type not in (
             DistributedType.DEEPSPEED,
diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py
index a1548b7fcf9..181694b633c 100644
--- a/src/accelerate/utils/modeling.py
+++ b/src/accelerate/utils/modeling.py
@@ -1873,18 +1873,18 @@ def get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwarg
         return contextlib.nullcontext()
 
 
-def get_grad_scaler(use_fsdp: bool = False, **kwargs):
+def get_grad_scaler(distributed_type: DistributedType = None, **kwargs):
     """
     A generic helper which will initialize the correct `GradScaler` implementation based on the environment and return
     it.
 
     Args:
-        use_fsdp (`bool`, *optional*, defaults to False):
-            Whether FSDP is enabled.
+        distributed_type (`DistributedType`, *optional*, defaults to None):
+            The type of distributed environment.
         kwargs:
             Additional arguments for the utilized `GradScaler` constructor.
     """
-    if use_fsdp:
+    if distributed_type == DistributedType.FSDP:
         from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
         return ShardedGradScaler(**kwargs)
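
End-state sketch after [PATCH 3/3] (illustrative, not part of the patch series): passing the `DistributedType` directly means the call site in `Accelerator.__init__` no longer collapses it into an FSDP boolean, and callers can hand over whatever environment they are in. The `init_scale` value below is arbitrary.

from accelerate.utils import DistributedType, get_grad_scaler
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

# FSDP maps to the sharded scaler; scaler kwargs still pass straight through.
scaler = get_grad_scaler(DistributedType.FSDP, init_scale=2.0**15)
assert isinstance(scaler, ShardedGradScaler)

# Any other value (or the default None) falls through to the device-specific
# GradScaler chosen from the running environment.
default_scaler = get_grad_scaler()
print(type(default_scaler).__name__)
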