Skip to content

Commit

Permalink
Correct loading of models with shared tensors when using accelerator.load_state() (#2875)
Browse files Browse the repository at this point in the history

* Enabled correct loading of models with shared tensors when using accelerator.load_state()

* removed unused import

* added a test for a model with shared weights

* removed unnecessary bits

* fixed linting errors
  • Loading branch information
jkuntzer authored Jul 15, 2024
1 parent c6da9f8 commit f4f1260
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
6 changes: 3 additions & 3 deletions src/accelerate/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import numpy as np
import torch
from safetensors.torch import load_file
from safetensors.torch import load_model
from torch.cuda.amp import GradScaler

from .utils import (
Expand Down Expand Up @@ -205,12 +205,12 @@ def load_accelerator_state(
ending = f"_{i}" if i > 0 else ""
input_model_file = input_dir.joinpath(f"{SAFE_MODEL_NAME}{ending}.safetensors")
if input_model_file.exists():
state_dict = load_file(input_model_file, device=str(map_location))
load_model(model, input_model_file, device=str(map_location), **load_model_func_kwargs)
else:
# Load with torch
input_model_file = input_dir.joinpath(f"{MODEL_NAME}{ending}.bin")
state_dict = torch.load(input_model_file, map_location=map_location)
models[i].load_state_dict(state_dict, **load_model_func_kwargs)
model.load_state_dict(state_dict, **load_model_func_kwargs)
logger.info("All model weights loaded successfully")

# Optimizer states
Expand Down
30 changes: 23 additions & 7 deletions tests/test_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,20 @@
from accelerate.utils.modeling import get_state_dict_from_offload, load_checkpoint_in_model


def create_components():
model = torch.nn.Linear(2, 4)
class ModelWithTiedWeights(torch.nn.Module):
    """Minimal two-layer model whose layers share the same weight and bias tensors.

    Used to exercise saving and loading of models with shared (tied) parameters.
    ``linear2`` does not own parameters of its own: its ``weight`` and ``bias`` are the very
    same ``Parameter`` objects as ``linear1``'s.
    """

    def __init__(self):
        super().__init__()
        # A tied weight has a single shape, so both layers must agree on
        # (out_features, in_features). Chaining the layers additionally requires
        # out_features == in_features, hence the square 2x2 layers.
        self.linear1 = torch.nn.Linear(2, 2)
        self.linear2 = torch.nn.Linear(2, 2)
        # Re-bind linear2's parameters to linear1's so both names refer to one tensor.
        self.linear2.weight = self.linear1.weight
        self.linear2.bias = self.linear1.bias

    def forward(self, x):
        """Apply the shared linear transformation twice to ``x`` (last dim must be 2)."""
        return self.linear2(self.linear1(x))


def create_components(tied_weights=False):
model = ModelWithTiedWeights() if tied_weights else torch.nn.Linear(2, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=2, epochs=1)
train_dl = DataLoader(TensorDataset(torch.tensor([1, 2, 3])))
Expand All @@ -54,18 +66,22 @@ def forward(self, x):


def get_signature(model):
    """Return a scalar fingerprint of ``model``'s current parameter values.

    The fingerprint is the sum of the absolute values of every tensor yielded by
    ``model.parameters()``. Iterating the parameters (instead of reading ``model.weight`` and
    ``model.bias`` directly) keeps this usable for modules that do not expose those attributes
    at the top level, such as a model composed of several layers.

    Returns:
        A Python number; ``0`` when the model has no parameters.
    """
    return sum(param.abs().sum().item() for param in model.parameters())


def load_random_weights(model):
    """Overwrite ``model``'s parameters in place with freshly initialised random values.

    A new module of the same kind is constructed and its ``state_dict`` is loaded into
    ``model``.

    Args:
        model: a ``torch.nn.Linear`` or a ``ModelWithTiedWeights`` instance.

    Raises:
        TypeError: if ``model`` is of any other type, since there is no recipe for building
            a matching randomly initialised state dict.
    """
    if isinstance(model, torch.nn.Linear):
        # Mirror the layer's configuration (including a missing bias) so the keys of the
        # fresh state dict match exactly what ``load_state_dict`` expects.
        fresh = torch.nn.Linear(model.in_features, model.out_features, bias=model.bias is not None)
        state = fresh.state_dict()
    elif isinstance(model, ModelWithTiedWeights):
        state = ModelWithTiedWeights().state_dict()
    else:
        raise TypeError(f"load_random_weights does not support models of type {type(model).__name__}")
    model.load_state_dict(state)


def parameterized_custom_name_func(func, param_num, param):
    """Build a sub-test name that mentions every parameter of the test case.

    The default name generator only shows the first parameter, so the name is assembled
    here from the serialization flag and, when present and set, the tied-weights flag.
    """
    case_args = param.args
    # First parameter selects the serialization format under test.
    suffix = "use_pytorch"
    if case_args[0] is True:
        suffix = "use_safetensors"
    # Second parameter (only for two-parameter cases) marks the tied-weights variant.
    if len(case_args) == 2 and case_args[1] is True:
        suffix = suffix + "_tied_weights"
    return "{}_{}".format(func.__name__, suffix)


Expand Down Expand Up @@ -230,10 +246,10 @@ def noop(*args, **kwargs):
accelerator = Accelerator()
assert str(accelerator.state.device) == "cuda:64"

@parameterized.expand((True, False), name_func=parameterized_custom_name_func)
def test_save_load_model(self, use_safetensors):
@parameterized.expand([(True, True), (True, False), (False, False)], name_func=parameterized_custom_name_func)
def test_save_load_model(self, use_safetensors, tied_weights):
accelerator = Accelerator()
model, optimizer, scheduler, train_dl, valid_dl = create_components()
model, optimizer, scheduler, train_dl, valid_dl = create_components(tied_weights)
accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl)

model_signature = get_signature(model)
Expand Down

0 comments on commit f4f1260

Please sign in to comment.