Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Models With Tied Weights Need Re-Tieing After FSDP Param Init #3154

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
convert_model,
convert_outputs_to_fp32,
extract_model_from_parallel,
ensure_weights_retied,
gather,
gather_object,
get_grad_scaler,
Expand Down Expand Up @@ -1472,6 +1473,13 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
if not is_type_fsdp:
self.state.fsdp_plugin.set_auto_wrap_policy(model)
fsdp_plugin = self.state.fsdp_plugin

# need to ensure that params are re-tied after running
# param_init_fn
fsdp_plugin.param_init_fn = ensure_weights_retied(
fsdp_plugin.param_init_fn, model, self.device,
)
muellerzr marked this conversation as resolved.
Show resolved Hide resolved

kwargs = {
"sharding_strategy": fsdp_plugin.sharding_strategy,
"cpu_offload": fsdp_plugin.cpu_offload,
Expand Down
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@
merge_fsdp_weights,
save_fsdp_model,
save_fsdp_optimizer,
ensure_weights_retied,
)
from .launch import (
PrepareForLaunch,
Expand Down
51 changes: 51 additions & 0 deletions src/accelerate/utils/fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
import os
import shutil
from pathlib import Path
from collections import defaultdict

import torch


from ..logging import get_logger
from .constants import FSDP_MODEL_NAME, OPTIMIZER_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from .modeling import is_peft_model
Expand Down Expand Up @@ -324,3 +326,52 @@ def merge_fsdp_weights(
logger.info(f"Removing old checkpoint directory {checkpoint_dir}")
shutil.rmtree(checkpoint_dir)
state.wait_for_everyone()

def ensure_weights_retied(
param_init_fn, model: torch.nn.Module, device: torch.cuda.device
):

_tied_names = model._tied_weights_keys
if not _tied_names:
# if no tied names just passthrough
return param_init_fn

# get map of parameter instances to params.
# - needed for replacement later
_tied_params = {}
for name in _tied_names:
name = name.split('.')
name, param_name = '.'.join(name[:-1]), name[-1]
mod = model.get_submodule(name)
param = getattr(mod, param_name)

_tied_params[id(param)] = None # placeholder for the param first

# build param_init_fn for the case with tied params
def param_init_fn_tied_param(module: torch.nn.Module):

# track which params to tie
# - usually only 1, but for completeness consider > 1
params_to_tie = defaultdict(list)
for n, param in module.named_parameters(recurse=False):
if id(param) in _tied_params:
params_to_tie[id(param)].append(n)

# call the param init fn, which potentially re-allocates the
# parameters
module = param_init_fn(module)

# search the parameters again and tie them up again
for id_key, _param_names in params_to_tie.items():
for param_name in _param_names:
param = _tied_params[id_key]
if param is None:
# everything will be tied to the first time the
# param is observed
_tied_params[id_key] = getattr(module, param_name)
else:
setattr(module, param_name, param) # tie

return module

return param_init_fn_tied_param
Loading