State dictionary retrieval from offloaded modules (#2619)
* added get_state_dict_from_offloaded

* cleaned

* make style

* Update src/accelerate/utils/modeling.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* implemented suggestions, refactored, make style

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
blbadger and SunMarc authored Jun 3, 2024
1 parent 065e74d commit d5d378d
Showing 2 changed files with 71 additions and 1 deletion.
43 changes: 43 additions & 0 deletions src/accelerate/utils/modeling.py
@@ -1566,6 +1566,49 @@ def get_state_dict_offloaded_model(model: nn.Module):
    return state_dict


def get_state_dict_from_offload(
    module: nn.Module,
    module_name: str,
    state_dict: Dict[str, Union[str, torch.Tensor]],
    device_to_put_offload: Union[int, str, torch.device] = "cpu",
):
    """
    Retrieve the state dictionary (with parameters) from an offloaded module and load it onto a specified device
    (defaults to cpu).

    Args:
        module (`torch.nn.Module`):
            The module we want to retrieve a state dictionary from.
        module_name (`str`):
            The name of the module of interest.
        state_dict (`Dict[str, Union[str, torch.Tensor]]`):
            Dictionary of {module names: parameters}.
        device_to_put_offload (`Union[int, str, torch.device]`):
            Device to load offloaded parameters into, defaults to the cpu.
    """
    from ..hooks import AlignDevicesHook

    root = module_name[: module_name.rfind(".")]  # module name without .weight or .bias
    preforward = False
    if hasattr(module, "_hf_hook") and isinstance(module._hf_hook, AlignDevicesHook) and module._hf_hook.offload:
        # assign the device to which the offloaded parameters will be sent
        original_device = module._hf_hook.execution_device
        module._hf_hook.execution_device = device_to_put_offload
        module._hf_hook.pre_forward(module)  # onload the module's parameters
        preforward = True

    for m_key, params in module.state_dict().items():
        if (root + f".{m_key}") in state_dict:
            state_dict[root + f".{m_key}"] = params

    if preforward:
        # offload the parameters again and restore the original execution device
        module._hf_hook.post_forward(module, torch.tensor([]))
        module._hf_hook.execution_device = original_device

    return state_dict


def load_checkpoint_in_model(
model: nn.Module,
checkpoint: Union[str, os.PathLike],
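For illustration, a minimal usage sketch of the new helper, mirroring the test added below. `TinyModel` is a stand-in module invented for this example (not part of the diff); the sketch assumes a model dispatched with one layer offloaded to disk:

import tempfile

import torch.nn as nn

from accelerate import Accelerator, load_checkpoint_and_dispatch
from accelerate.utils.modeling import get_state_dict_from_offload


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 4)
        self.linear2 = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear2(self.linear1(x))


model = TinyModel()
with tempfile.TemporaryDirectory() as tmp_dir:
    Accelerator().save_model(model, tmp_dir)
    # reload with linear2 offloaded to disk
    load_checkpoint_and_dispatch(
        model, tmp_dir, device_map={"linear1": "cpu", "linear2": "disk"}, offload_folder=tmp_dir
    )
    # onload just linear2's weight back onto the cpu
    state_dict = get_state_dict_from_offload(
        model.linear2, "linear2.weight", {"linear2.weight": ""}, device_to_put_offload="cpu"
    )
    print(state_dict["linear2.weight"].device)  # cpu

The hook on `model.linear2` onloads the weight to the requested device, copies it into the passed dictionary, then offloads it again, so the module itself stays offloaded throughout.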
29 changes: 28 additions & 1 deletion tests/test_accelerator.py
@@ -29,7 +29,7 @@
from accelerate.test_utils import require_bnb, require_multi_device, require_non_cpu, slow, torch_device
from accelerate.test_utils.testing import AccelerateTestCase, require_cuda, require_non_torch_xla
from accelerate.utils import patch_environment
- from accelerate.utils.modeling import load_checkpoint_in_model
+ from accelerate.utils.modeling import get_state_dict_from_offload, load_checkpoint_in_model


def create_components():
@@ -281,6 +281,33 @@ def test_save_model_offload(self, use_safetensors):
        output = model(inputs)
        assert torch.allclose(expected, output, atol=1e-5)

    @require_cuda  # the second onload below moves a parameter onto a CUDA device
    @parameterized.expand([True, False], name_func=parameterized_custom_name_func)
    def test_get_state_dict_from_offload(self, use_safetensors):
        accelerator = Accelerator()

        device_map = {"linear1": "cpu", "batchnorm": "disk", "linear2": "disk"}
        model = ModelForTest()
        offloaded_layer_weight = model.linear2.weight
        with tempfile.TemporaryDirectory() as tmp_dir:
            accelerator.save_model(model, tmp_dir, safe_serialization=use_safetensors)
            # load model with offloaded layers
            load_checkpoint_and_dispatch(model, tmp_dir, device_map=device_map, offload_folder=tmp_dir)
            cpu_onloaded_layer = get_state_dict_from_offload(
                model.linear2, "linear2.weight", {"linear2.weight": ""}, device_to_put_offload="cpu"
            )
            cuda_onloaded_layer = get_state_dict_from_offload(
                model.linear2, "linear2.weight", {"linear2.weight": ""}, device_to_put_offload=0
            )
            cpu_onloaded_layer_weight = cpu_onloaded_layer["linear2.weight"]
            cuda_onloaded_layer_weight = cuda_onloaded_layer["linear2.weight"]

        assert torch.allclose(offloaded_layer_weight, cpu_onloaded_layer_weight)
        assert torch.allclose(
            offloaded_layer_weight, cuda_onloaded_layer_weight.to("cpu")
        )  # must be on the same device for torch.allclose()
        assert cpu_onloaded_layer_weight.device.type == "cpu"
        assert cuda_onloaded_layer_weight.device.type == "cuda"

@parameterized.expand([True, False], name_func=parameterized_custom_name_func)
def test_save_load_model_with_hooks(self, use_safetensors):
accelerator = Accelerator()
