Refactor scaler to util #3142

Merged (3 commits) on Oct 8, 2024
8 changes: 2 additions & 6 deletions benchmarks/fp8/ms_amp/ddp.py
@@ -22,12 +22,11 @@
import msamp
import torch
from fp8_utils import evaluate_model, get_training_utilities
-from packaging import version
from torch.nn.parallel import DistributedDataParallel as DDP

from accelerate import Accelerator
from accelerate.state import AcceleratorState
-from accelerate.utils import FP8RecipeKwargs, set_seed
+from accelerate.utils import FP8RecipeKwargs, get_grad_scaler, set_seed


MODEL_NAME = "bert-base-cased"
@@ -36,10 +35,7 @@

def train_baseline(opt_level="O2"):
    set_seed(42)
-    if version.parse(torch.__version__) > version.parse("2.3"):
-        scaler = torch.amp.GradScaler("cuda")
-    else:
-        scaler = torch.cuda.amp.GradScaler()
+    scaler = get_grad_scaler()
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
    accelerator = Accelerator()
    device = accelerator.device
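How the benchmark consumes the scaler is unchanged; only its construction moves behind the helper. As a rough sketch of that consumption (assuming a single CUDA GPU and hypothetical stand-ins for the benchmark's BERT model and optimizer, not its real `get_training_utilities`), the object returned by `get_grad_scaler()` is driven like any PyTorch gradient scaler:

```python
import torch

from accelerate.utils import get_grad_scaler

# Hypothetical stand-ins for the benchmark's model/optimizer.
model = torch.nn.Linear(16, 2).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = get_grad_scaler()  # torch.amp.GradScaler("cuda") on torch >= 2.3, legacy API otherwise

inputs = torch.randn(8, 16, device="cuda")
targets = torch.randint(0, 2, (8,), device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)

scaler.scale(loss).backward()  # scale the loss so fp16 gradients don't underflow
scaler.step(optimizer)         # unscales gradients, skips the step on inf/nan
scaler.update()                # adapts the scale factor for the next iteration
optimizer.zero_grad()
```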
8 changes: 2 additions & 6 deletions benchmarks/fp8/ms_amp/non_distributed.py
@@ -22,11 +22,10 @@
import msamp
import torch
from fp8_utils import evaluate_model, get_training_utilities
-from packaging import version

from accelerate import Accelerator
from accelerate.state import AcceleratorState
-from accelerate.utils import FP8RecipeKwargs, set_seed
+from accelerate.utils import FP8RecipeKwargs, get_grad_scaler, set_seed


MODEL_NAME = "bert-base-cased"
@@ -42,10 +41,7 @@ def train_baseline(opt_level="O2"):

    base_model_results = evaluate_model(model, eval_dataloader, METRIC)
    model.train()
-    if version.parse(torch.__version__) > version.parse("2.3"):
-        scaler = torch.amp.GradScaler("cuda")
-    else:
-        scaler = torch.cuda.amp.GradScaler()
+    scaler = get_grad_scaler()

    for batch in train_dataloader:
        batch = batch.to("cuda")
4 changes: 4 additions & 0 deletions docs/source/package_reference/utilities.md
@@ -126,6 +126,10 @@ These include data operations that mimic the same `torch` ops but can be used on

[[autodoc]] utils.gather_object

+[[autodoc]] utils.get_grad_scaler
+
+[[autodoc]] utils.get_mixed_precision_context_manager
+
[[autodoc]] utils.listify

[[autodoc]] utils.pad_across_processes
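The two newly documented utilities pair naturally: one builds the scaler, the other the autocast context. A minimal sketch of using them together (assuming this branch is installed, a single CUDA GPU, and `native_amp=True` so that autocast is actually enabled):

```python
import torch

from accelerate.utils import get_grad_scaler, get_mixed_precision_context_manager

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

scaler = get_grad_scaler()
# With native_amp=False (the default) this is a no-op nullcontext; with True it
# returns a torch.autocast context for the detected device.
autocast_ctx = get_mixed_precision_context_manager(native_amp=True)

x = torch.randn(4, 8, device="cuda")
with autocast_ctx:
    loss = model(x).float().pow(2).mean()

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```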
28 changes: 3 additions & 25 deletions src/accelerate/accelerator.py
@@ -32,7 +32,6 @@
import torch
import torch.utils.hooks as hooks
from huggingface_hub import split_torch_state_dict_into_shards
-from packaging import version

from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
@@ -78,6 +77,7 @@
    extract_model_from_parallel,
    gather,
    gather_object,
+    get_grad_scaler,
    get_mixed_precision_context_manager,
    get_pretty_name,
    is_bf16_available,
@@ -136,7 +136,6 @@


if is_torch_xla_available():
-    import torch_xla.amp as xamp
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp

@@ -484,25 +483,7 @@ def __init__(
            ):
                raise ValueError(f"fp16 mixed precision requires a GPU (not {self.device.type!r}).")
            kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
-            if self.distributed_type == DistributedType.FSDP:
-                from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
-
-                self.scaler = ShardedGradScaler(**kwargs)
-            elif is_torch_xla_available(check_is_gpu=True):
-                self.scaler = xamp.GradScaler(**kwargs)
-            elif is_mlu_available():
-                self.scaler = torch.mlu.amp.GradScaler(**kwargs)
-            elif is_musa_available():
-                self.scaler = torch.musa.amp.GradScaler(**kwargs)
-            elif is_npu_available():
-                self.scaler = torch.npu.amp.GradScaler(**kwargs)
-            elif is_xpu_available():
-                self.scaler = torch.amp.GradScaler("xpu", **kwargs)
-            else:
-                if version.parse(torch.__version__) > version.parse("2.3"):
-                    self.scaler = torch.amp.GradScaler("cuda", **kwargs)
-                else:
-                    self.scaler = torch.cuda.amp.GradScaler(**kwargs)
+            self.scaler = get_grad_scaler(self.distributed_type, **kwargs)

        elif self.state.mixed_precision == "bf16" and self.distributed_type not in (
            DistributedType.DEEPSPEED,
@@ -526,10 +507,7 @@ def __init__(
                    )
                elif self.distributed_type != DistributedType.DEEPSPEED:
                    # MS-AMP requires `GradScaler` even with bf16 autocast w/ single GPU or DDP:
-                    if version.parse(torch.__version__) > version.parse("2.3"):
-                        self.scaler = torch.amp.GradScaler("cuda")
-                    else:
-                        self.scaler = torch.cuda.amp.GradScaler()
+                    self.scaler = get_grad_scaler(**kwargs)

        # Start of internal step tracking
        self.step = 0
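Scaler options supplied through a `GradScalerKwargs` handler keep working: they are still collected via `scaler_handler.to_kwargs()` and now flow into `get_grad_scaler` as `**kwargs`. A hedged sketch (assuming a single CUDA GPU so the fp16 branch above is taken):

```python
from accelerate import Accelerator
from accelerate.utils import GradScalerKwargs

# These fields mirror torch's GradScaler constructor arguments and end up in the
# kwargs that Accelerator forwards to get_grad_scaler.
handler = GradScalerKwargs(init_scale=2.0**14, growth_interval=1000)
accelerator = Accelerator(mixed_precision="fp16", kwargs_handlers=[handler])

print(type(accelerator.scaler))  # e.g. torch.amp.grad_scaler.GradScaler on torch >= 2.3
```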
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
@@ -130,6 +130,7 @@
    dtype_byte_size,
    find_tied_parameters,
    get_balanced_memory,
+    get_grad_scaler,
    get_max_layer_size,
    get_max_memory,
    get_mixed_precision_context_manager,
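With the re-export in place, downstream code can pull the helper from the public `accelerate.utils` namespace rather than the module it is defined in (a one-line sketch, assuming this branch is installed):

```python
from accelerate.utils import get_grad_scaler

# Dispatches on the detected backend (FSDP, XLA, MLU, MUSA, NPU, XPU, CUDA).
scaler = get_grad_scaler()
```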
34 changes: 34 additions & 0 deletions src/accelerate/utils/modeling.py
@@ -1871,3 +1871,37 @@ def get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwarg
        return torch.autocast(device_type=device_type, **autocast_kwargs)
    else:
        return contextlib.nullcontext()
+
+
+def get_grad_scaler(distributed_type: DistributedType = None, **kwargs):
+    """
+    A generic helper which will initialize the correct `GradScaler` implementation based on the environment and return
+    it.
+
+    Args:
+        distributed_type (`DistributedType`, *optional*, defaults to None):
+            The type of distributed environment.
+        kwargs:
+            Additional arguments for the utilized `GradScaler` constructor.
+    """
+    if distributed_type == DistributedType.FSDP:
+        from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+
+        return ShardedGradScaler(**kwargs)
+    if is_torch_xla_available(check_is_gpu=True):
+        import torch_xla.amp as xamp
+
+        return xamp.GradScaler(**kwargs)
+    elif is_mlu_available():
+        return torch.mlu.amp.GradScaler(**kwargs)
+    elif is_musa_available():
+        return torch.musa.amp.GradScaler(**kwargs)
+    elif is_npu_available():
+        return torch.npu.amp.GradScaler(**kwargs)
+    elif is_xpu_available():
+        return torch.amp.GradScaler("xpu", **kwargs)
+    else:
+        if is_torch_version(">=", "2.3"):
+            return torch.amp.GradScaler("cuda", **kwargs)
+        else:
+            return torch.cuda.amp.GradScaler(**kwargs)
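For reference, a small usage sketch of the new helper's dispatch (assumptions: a CUDA-only machine; the FSDP call is shown only for the type it selects, since a sharded scaler is only meaningful inside an initialized FSDP run):

```python
from accelerate.utils import DistributedType, get_grad_scaler

# Environment-based dispatch; extra kwargs reach the chosen GradScaler constructor.
scaler = get_grad_scaler(init_scale=2.0**16)
print(type(scaler))  # torch.amp.GradScaler on torch >= 2.3 with a CUDA device

# Passing DistributedType.FSDP explicitly selects the sharded implementation.
fsdp_scaler = get_grad_scaler(DistributedType.FSDP)
print(type(fsdp_scaler))  # torch.distributed.fsdp.sharded_grad_scaler.ShardedGradScaler
```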
2 changes: 1 addition & 1 deletion src/accelerate/utils/operations.py
@@ -29,10 +29,10 @@
from .imports import (
    is_npu_available,
    is_torch_distributed_available,
    is_torch_version,
    is_torch_xla_available,
    is_xpu_available,
)
from .versions import is_torch_version


if is_torch_xla_available():