fix deprecated torch.cuda.amp.GradScaler FutureWarning for pytorch 2.4+ #3132

Merged: 9 commits, Oct 7, 2024
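In short, the PR replaces direct `torch.cuda.amp.GradScaler(...)` construction with a version check so that PyTorch 2.4+ uses the device-agnostic `torch.amp.GradScaler`. A minimal sketch of that guard (the helper name `make_grad_scaler` is illustrative, not part of the PR):

```python
import torch
from packaging import version


def make_grad_scaler(**kwargs):
    # torch.cuda.amp.GradScaler emits a FutureWarning on PyTorch 2.4+,
    # where the device-agnostic torch.amp.GradScaler is preferred.
    if version.parse(torch.__version__) > version.parse("2.3"):
        return torch.amp.GradScaler("cuda", **kwargs)
    return torch.cuda.amp.GradScaler(**kwargs)
```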
6 changes: 5 additions & 1 deletion benchmarks/fp8/ms_amp/ddp.py
@@ -22,6 +22,7 @@
import msamp
import torch
from fp8_utils import evaluate_model, get_training_utilities
from packaging import version
from torch.nn.parallel import DistributedDataParallel as DDP

from accelerate import Accelerator
@@ -35,7 +36,10 @@

def train_baseline(opt_level="O2"):
set_seed(42)
scaler = torch.cuda.amp.GradScaler()
if version.parse(torch.__version__) > version.parse("2.3"):
scaler = torch.amp.GradScaler("cuda")
else:
scaler = torch.cuda.amp.GradScaler()
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
accelerator = Accelerator()
device = accelerator.device
6 changes: 5 additions & 1 deletion benchmarks/fp8/ms_amp/non_distributed.py
@@ -22,6 +22,7 @@
import msamp
import torch
from fp8_utils import evaluate_model, get_training_utilities
from packaging import version

from accelerate import Accelerator
from accelerate.state import AcceleratorState
@@ -41,7 +42,10 @@ def train_baseline(opt_level="O2"):

base_model_results = evaluate_model(model, eval_dataloader, METRIC)
model.train()
scaler = torch.cuda.amp.GradScaler()
if version.parse(torch.__version__) > version.parse("2.3"):
scaler = torch.amp.GradScaler("cuda")
else:
scaler = torch.cuda.amp.GradScaler()

for batch in train_dataloader:
batch = batch.to("cuda")
16 changes: 13 additions & 3 deletions src/accelerate/accelerator.py
@@ -32,6 +32,7 @@
import torch
import torch.utils.hooks as hooks
from huggingface_hub import split_torch_state_dict_into_shards
from packaging import version

from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
@@ -494,11 +495,17 @@ def __init__(
elif is_musa_available():
self.scaler = torch.musa.amp.GradScaler(**kwargs)
elif is_npu_available():
self.scaler = torch.npu.amp.GradScaler(**kwargs)
if version.parse(torch.__version__) > version.parse("2.3"):
Review comment (Member): Since we both can't test this, let's not touch NPU for now.

self.scaler = torch.amp.GradScaler("npu", **kwargs)
else:
self.scaler = torch.npu.amp.GradScaler(**kwargs)
elif is_xpu_available():
self.scaler = torch.amp.GradScaler("xpu", **kwargs)
else:
self.scaler = torch.cuda.amp.GradScaler(**kwargs)
if version.parse(torch.__version__) > version.parse("2.3"):
self.scaler = torch.amp.GradScaler("cuda", **kwargs)
else:
self.scaler = torch.cuda.amp.GradScaler(**kwargs)

elif self.state.mixed_precision == "bf16" and self.distributed_type not in (
DistributedType.DEEPSPEED,
Expand All @@ -522,7 +529,10 @@ def __init__(
)
elif self.distributed_type != DistributedType.DEEPSPEED:
# MS-AMP requires `GradScaler` even with bf16 autocast w/ single GPU or DDP:
self.scaler = torch.cuda.amp.GradScaler()
if version.parse(torch.__version__) > version.parse("2.3"):
self.scaler = torch.amp.GradScaler("cuda")
else:
self.scaler = torch.cuda.amp.GradScaler()

# Start of internal step tracking
self.step = 0
6 changes: 3 additions & 3 deletions src/accelerate/checkpointing.py
@@ -86,8 +86,8 @@ def save_accelerator_state(
The current process index in the Accelerator state
step (`int`):
The current step in the internal step tracker
scaler (`torch.cuda.amp.GradScaler`, *optional*):
An optional gradient scaler instance to save
scaler (`torch.amp.GradScaler`, *optional*) for pytorch>2.3:
An optional gradient scaler instance to save; for lower versions, see `torch.cuda.amp.GradScaler`
Review comment (Member): I think we don't need this extra explanation in the docstring, it should be quite clear as is what this refers to.

save_on_each_node (`bool`, *optional*):
Whether to save on every node, or only the main node.
safe_serialization (`bool`, *optional*, defaults to `True`):
@@ -186,7 +186,7 @@ def load_accelerator_state(
A list of learning rate schedulers
process_index (`int`):
The current process index in the Accelerator state
scaler (`torch.cuda.amp.GradScaler`, *optional*):
scaler (`torch.amp.GradScaler`, *optional*):
An optional *GradScaler* instance to load
map_location (`str`, *optional*):
What device to load the optimizer state onto. Should be one of either "cpu" or "on_device".
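For context (not part of this diff), the scaler documented above is normally saved and restored through the public `Accelerator` checkpointing helpers rather than by calling these functions directly; a rough sketch, assuming fp16 mixed precision so a scaler actually exists:

```python
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp16")  # a GradScaler is created internally for fp16
# ... prepare model/optimizer/dataloaders with accelerator.prepare(...) and train ...
accelerator.save_state("checkpoint_dir")  # serializes model, optimizer, and scaler state
accelerator.load_state("checkpoint_dir")  # restores them, scaler included
```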
5 changes: 3 additions & 2 deletions src/accelerate/utils/dataclasses.py
@@ -209,8 +209,9 @@ def register_comm_hook(self, model):
class GradScalerKwargs(KwargsHandler):
"""
Use this object in your [`Accelerator`] to customize the behavior of mixed precision, specifically how the
`torch.cuda.amp.GradScaler` used is created. Please refer to the documentation of this
[scaler](https://pytorch.org/docs/stable/amp.html?highlight=gradscaler) for more information on each argument.
`torch.amp.GradScaler` used is created for pytorch>2.3 or `torch.cuda.amp.GradScaler` for lower versions. Please
Review comment (Member): Same, let's not overexplain.

refer to the documentation of this [scaler](https://pytorch.org/docs/stable/amp.html?highlight=gradscaler) for more
information on each argument.

<Tip warning={true}>

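As a usage illustration (the values below are arbitrary, not taken from the PR), a `GradScalerKwargs` handler is passed to `Accelerator` like this; each field is forwarded to the underlying `GradScaler` constructor:

```python
from accelerate import Accelerator
from accelerate.utils import GradScalerKwargs

# Arbitrary example values; any GradScaler constructor argument can be customized here.
scaler_kwargs = GradScalerKwargs(init_scale=1024.0, growth_interval=100)
accelerator = Accelerator(mixed_precision="fp16", kwargs_handlers=[scaler_kwargs])
```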