Skip to content

Commit

Permalink
fix load_state_dict for xpu and refine xpu safetensor version check (#2879)
Browse files Browse the repository at this point in the history

* add fix

* update warning

* no and
  • Loading branch information
faaany authored Jul 3, 2024
1 parent 3a02754 commit 92404fb
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
22 changes: 11 additions & 11 deletions src/accelerate/utils/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import contextlib
import gc
import importlib
import inspect
import json
import logging
Expand All @@ -26,7 +25,6 @@
from collections import OrderedDict, defaultdict
from typing import Dict, List, Optional, Tuple, Union

import packaging
import torch
import torch.nn as nn

Expand Down Expand Up @@ -1456,7 +1454,15 @@ def load_state_dict(checkpoint_file, device_map=None):
else:
# if we only have one device we can load everything directly
if len(set(device_map.values())) == 1:
return safe_load_file(checkpoint_file, device=list(device_map.values())[0])
device = list(device_map.values())[0]
target_device = device
if is_xpu_available():
if compare_versions("safetensors", "<", "0.4.2"):
raise ImportError("Safetensors version must be >= 0.4.2 for XPU. Please upgrade safetensors.")
if isinstance(device, int):
target_device = f"xpu:{device}"

return safe_load_file(checkpoint_file, device=target_device)

devices = list(set(device_map.values()) - {"disk"})
# cpu device should always exist as fallback option
Expand Down Expand Up @@ -1486,15 +1492,9 @@ def load_state_dict(checkpoint_file, device_map=None):
progress_bar = None
for device in devices:
target_device = device

if is_xpu_available():
current_safetensors_version = packaging.version.parse(importlib.metadata.version("safetensors"))

if compare_versions(current_safetensors_version, "<", "0.4.2"):
raise ModuleNotFoundError(
f"You need at least safetensors 0.4.2 for Intel GPU, while you have {current_safetensors_version}"
)

if compare_versions("safetensors", "<", "0.4.2"):
raise ImportError("Safetensors version must be >= 0.4.2 for XPU. Please upgrade safetensors.")
if isinstance(device, int):
target_device = f"xpu:{device}"

Expand Down
7 changes: 5 additions & 2 deletions tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ def test_get_balanced_memory(self):
max_memory = get_balanced_memory(model, max_memory={0: 0, "cpu": 100})
assert {0: 0, "cpu": 100} == max_memory

@require_cuda
@require_non_cpu
def test_load_state_dict(self):
state_dict = {k: torch.randn(4, 5) for k in ["a", "b", "c"]}
device_maps = [{"a": "cpu", "b": 0, "c": "disk"}, {"a": 0, "b": 0, "c": "disk"}, {"a": 0, "b": 0, "c": 0}]
Expand All @@ -748,7 +748,10 @@ def test_load_state_dict(self):

for param, device in device_map.items():
device = device if device != "disk" else "cpu"
assert loaded_state_dict[param].device == torch.device(device)
expected_device = (
torch.device(f"{torch_device}:{device}") if isinstance(device, int) else torch.device(device)
)
assert loaded_state_dict[param].device == expected_device

def test_convert_file_size(self):
result = convert_file_size_to_int("0MB")
Expand Down

0 comments on commit 92404fb

Please sign in to comment.