Set the state device dependent on the Accelerator on multi-GPU #1220

Merged (13 commits) on Apr 6, 2023
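
In short, this change lets `Accelerator.load_state` accept a `map_location` so that, on multi-GPU runs, model and optimizer states are restored directly onto each process's device instead of always round-tripping through CPU. A minimal usage sketch based on the diff below; the model, optimizer, and `"ckpt"` directory are illustrative and not part of this PR:

```python
import torch
from torch import nn

from accelerate import Accelerator

accelerator = Accelerator()

# Illustrative model/optimizer; any prepared objects behave the same way.
model = nn.Linear(8, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

accelerator.save_state("ckpt")

# Default after this PR: "on_device" when running multi-GPU, "cpu" otherwise.
accelerator.load_state("ckpt")

# Explicit override: stage the restored states on CPU first.
accelerator.load_state("ckpt", map_location="cpu")

# Or restore directly onto this process's device.
accelerator.load_state("ckpt", map_location="on_device")
```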
src/accelerate/accelerator.py (17 additions, 2 deletions)
@@ -2322,7 +2322,8 @@ def load_state(self, input_dir: str, **load_model_func_kwargs):
The name of the folder all relevant weights and states were saved in.
load_model_func_kwargs (`dict`, *optional*):
Additional keyword arguments for loading model which can be passed to the underlying load function,
such as optional arguments for DeepSpeed's `load_checkpoint` function.
such as optional arguments for DeepSpeed's `load_checkpoint` function or a `map_location` to load the
model and optimizer on.

Example:

@@ -2385,8 +2386,22 @@ def load_state(self, input_dir: str, **load_model_func_kwargs):
for hook in self._load_model_state_pre_hook.values():
hook(models, input_dir)

map_location = load_model_func_kwargs.pop("map_location", None)
if map_location is None:
if self.num_processes > 1 and self.distributed_type == DistributedType.MULTI_GPU:
map_location = "on_device"
else:
map_location = "cpu"

load_accelerator_state(
input_dir, models, optimizers, schedulers, self.state.process_index, self.scaler, **load_model_func_kwargs
input_dir,
models,
optimizers,
schedulers,
self.state.process_index,
self.scaler,
map_location,
**load_model_func_kwargs,
)
custom_checkpoints = [f for f in os.listdir(input_dir) if "custom_checkpoint" in f]
if len(custom_checkpoints) != len(self._custom_objects):
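The hunk above only chooses which value to hand down; the actual tensor placement is done by `torch.load`'s `map_location` argument inside `load_accelerator_state`. A minimal sketch of that underlying PyTorch behavior, assuming a CUDA device is available (the file name and tensors are illustrative):

```python
import torch

state = {"step": torch.arange(4, device="cuda:0")}
torch.save(state, "opt_state.bin")  # illustrative file name

# Without map_location, tensors are restored on the device they were saved from.
on_gpu = torch.load("opt_state.bin")

# "cpu" forces everything through host memory first.
on_cpu = torch.load("opt_state.bin", map_location="cpu")

# A concrete device restores tensors directly there; this is what the
# "on_device" default above resolves to for each process.
on_device = torch.load("opt_state.bin", map_location=torch.device("cuda:0"))
```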
src/accelerate/checkpointing.py (22 additions, 3 deletions)
@@ -37,6 +37,7 @@
import torch_xla.core.xla_model as xm

from .logging import get_logger
from .state import PartialState


logger = get_logger(__name__)
@@ -110,7 +111,14 @@ def save_accelerator_state(


def load_accelerator_state(
input_dir, models, optimizers, schedulers, process_index, scaler=None, **load_model_func_kwargs
input_dir,
models,
optimizers,
schedulers,
process_index,
scaler=None,
map_location=None,
**load_model_func_kwargs,
):
"""
Loads states of the models, optimizers, scaler, and RNG generators from a given directory.
@@ -128,21 +136,32 @@ def load_accelerator_state(
The current process index in the Accelerator state
scaler (`torch.cuda.amp.GradScaler`, *optional*):
An optional *GradScaler* instance to load
map_location (`str`, *optional*):
What device to load the optimizer state onto. Should be either "cpu" or "on_device".
load_model_func_kwargs (`dict`, *optional*):
Additional arguments that can be passed to the model's `load_state_dict` method.
"""
if map_location not in [None, "cpu", "on_device"]:
raise TypeError(
"Unsupported optimizer map location passed, please choose one of `None`, `'cpu'`, or `'on_device'`"
)
if map_location is None:
map_location = "cpu"
elif map_location == "on_device":
map_location = PartialState().device
# Model states
for i, model in enumerate(models):
weights_name = f"{MODEL_NAME}.bin" if i == 0 else f"{MODEL_NAME}_{i}.bin"
input_model_file = os.path.join(input_dir, weights_name)
models[i].load_state_dict(torch.load(input_model_file, map_location="cpu"), **load_model_func_kwargs)
models[i].load_state_dict(torch.load(input_model_file, map_location=map_location), **load_model_func_kwargs)
logger.info("All model weights loaded successfully")

# Optimizer states
for i, opt in enumerate(optimizers):
optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
input_optimizer_file = os.path.join(input_dir, optimizer_name)
optimizers[i].load_state_dict(torch.load(input_optimizer_file, map_location="cpu"))
optimizer_state = torch.load(input_optimizer_file, map_location=map_location)
optimizers[i].load_state_dict(optimizer_state)
logger.info("All optimizer states loaded successfully")

# Scheduler states
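For reference, a condensed standalone sketch of the resolution `load_accelerator_state` now performs, where `"on_device"` is translated into the current process's device via `PartialState` before it reaches `torch.load`; this is a reconstruction for illustration, not the exact library code:

```python
import torch
from accelerate.state import PartialState


def resolve_map_location(map_location=None):
    # Mirrors the validation and resolution added in the hunk above.
    if map_location not in [None, "cpu", "on_device"]:
        raise TypeError(
            "Unsupported optimizer map location passed, please choose one of "
            "`None`, `'cpu'`, or `'on_device'`"
        )
    if map_location is None:
        return "cpu"
    if map_location == "on_device":
        # Each process resolves to its own device (e.g. cuda:0, cuda:1, ...).
        return PartialState().device
    return map_location


# e.g. torch.load(input_optimizer_file, map_location=resolve_map_location("on_device"))
```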
tests/test_state_checkpointing.py (67 additions, 1 deletion)
@@ -12,18 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import logging
import os
import random
import shutil
import tempfile
import unittest

import pytest
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration, set_seed
from accelerate.test_utils import execute_subprocess_async, require_cuda
from accelerate.utils import ProjectConfiguration, get_launch_prefix, set_seed


logger = logging.getLogger(__name__)
@@ -248,3 +252,65 @@ def test_checkpoint_deletion(self):
self.assertTrue(not os.path.exists(os.path.join(tmpdir, "checkpoints", "checkpoint_0")))
self.assertTrue(os.path.exists(os.path.join(tmpdir, "checkpoints", "checkpoint_9")))
self.assertTrue(os.path.exists(os.path.join(tmpdir, "checkpoints", "checkpoint_10")))

@require_cuda
def test_map_location(self):
cmd = get_launch_prefix()
cmd += [f"--nproc_per_node={torch.cuda.device_count()}", inspect.getfile(self.__class__)]
execute_subprocess_async(cmd, env=os.environ.copy())


if __name__ == "__main__":
savedir = "/tmp/accelerate/state_checkpointing"
model = DummyModel()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)
train_dataloader, valid_dataloader = dummy_dataloaders()
project_config = ProjectConfiguration(automatic_checkpoint_naming=True)
# Train baseline
accelerator = Accelerator(project_dir=savedir, project_config=project_config, mixed_precision="no")
if accelerator.process_index == 0:
if os.path.exists(savedir):
shutil.rmtree(savedir)
os.makedirs(savedir)
model, optimizer, train_dataloader, valid_dataloader, scheduler = accelerator.prepare(
model, optimizer, train_dataloader, valid_dataloader, scheduler
)
model, optimizer = accelerator.prepare(model, optimizer)
train(3, model, train_dataloader, optimizer, accelerator, scheduler)
# Check that the initial optimizer is loaded on the GPU
for group in optimizer.param_groups:
param_device = group["params"][0].device
break
assert param_device.type == accelerator.device.type
model = model.cpu()
accelerator.wait_for_everyone()
accelerator.save_state()
accelerator.wait_for_everyone()

# Check CPU state
accelerator.load_state(os.path.join(savedir, "checkpoints", "checkpoint_0"), map_location="cpu")
for group in optimizer.param_groups:
param_device = group["params"][0].device
break
assert (
param_device.type == torch.device("cpu").type
), f"Loaded optimizer states did not match, expected to be loaded on the CPU but got {param_device}"

# Check device state
model.to(accelerator.device)
accelerator.load_state(os.path.join(savedir, "checkpoints", "checkpoint_0"), map_location="on_device")
for group in optimizer.param_groups:
param_device = group["params"][0].device
break
assert (
param_device.type == accelerator.device.type
), f"Loaded optimizer states did not match, expected to be loaded on {accelerator.device} but got {param_device}"

# Check error
with pytest.raises(TypeError, match="Unsupported optimizer map location passed"):
accelerator.load_state(os.path.join(savedir, "checkpoints", "checkpoint_0"), map_location="invalid")
accelerator.wait_for_everyone()
if accelerator.process_index == 0:
shutil.rmtree(savedir)
accelerator.wait_for_everyone()
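
A brief usage note: `test_map_location` re-launches this same file under the distributed launcher, one process per visible GPU, and the `__main__` block above is what each worker then executes. A rough manual equivalent, assuming `torchrun` is available and the repository layout below (both are assumptions, not part of the diff):

```python
import subprocess

import torch

# Roughly what the @require_cuda test does: run this file once per GPU.
subprocess.check_call(
    ["torchrun", f"--nproc_per_node={torch.cuda.device_count()}", "tests/test_state_checkpointing.py"]
)
```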