Skip to content

Commit

Permalink
Skip converting .safetensors to .bin (#1853)
Browse files (browse the repository at this point in the history)
Co-authored-by: Andrei-Aksionov <aksionau.andrei@gmail.com>
  • Loading branch information
ysjprojects and Andrei-Aksionov authored Dec 30, 2024
1 parent 93fc1b8 commit 470f14e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 42 deletions.
21 changes: 7 additions & 14 deletions litgpt/scripts/convert_hf_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from litgpt.config import Config
from litgpt.utils import extend_checkpoint_dir, incremental_save, lazy_load, save_config
from safetensors.torch import load_file as load_safetensors


def copy_weights_gpt_neox(
Expand Down Expand Up @@ -556,13 +557,13 @@ def convert_hf_checkpoint(
elif model_safetensor_map_json_path.is_file():
with open(model_safetensor_map_json_path, encoding="utf-8") as json_map:
bin_index = json.load(json_map)
bin_files = {checkpoint_dir / Path(bin).with_suffix(".bin") for bin in bin_index["weight_map"].values()}
bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
else:
bin_files = set(checkpoint_dir.glob("*.bin"))
bin_files = set(checkpoint_dir.glob("*.bin")) | set(checkpoint_dir.glob("*.safetensors"))
# some checkpoints serialize the training arguments
bin_files = {f for f in bin_files if f.name != "training_args.bin"}
if not bin_files:
raise ValueError(f"Expected {str(checkpoint_dir)!r} to contain .bin files")
raise ValueError(f"Expected {str(checkpoint_dir)!r} to contain .bin or .safetensors files")

with incremental_save(checkpoint_dir / "lit_model.pth") as saver:
# for checkpoints that split the QKV across several files, we need to keep all the bin files
Expand All @@ -584,16 +585,8 @@ def convert_hf_checkpoint(
current_file_size = os.path.getsize(bin_file)
progress_per_file = (current_file_size / total_size) * total_progress

hf_weights = lazy_load(bin_file)
copy_fn(
sd,
hf_weights,
saver=saver,
dtype=dtype,
pbar=pbar,
progress_per_file=progress_per_file,
debug_mode=debug_mode,
)
hf_weights = load_safetensors(bin_file) if bin_file.suffix == ".safetensors" else lazy_load(bin_file)
copy_fn(sd, hf_weights, saver=saver, dtype=dtype, pbar=pbar, progress_per_file=progress_per_file, debug_mode=debug_mode)
gc.collect()

if pbar.n < total_progress:
Expand All @@ -602,7 +595,7 @@ def convert_hf_checkpoint(
else:
# Handling files without progress bar in debug mode
for bin_file in sorted(bin_files):
hf_weights = lazy_load(bin_file)
hf_weights = load_safetensors(bin_file) if bin_file.suffix == ".safetensors" else lazy_load(bin_file)
copy_fn(sd, hf_weights, saver=saver, dtype=dtype, debug_mode=debug_mode)

print(f"Saving converted checkpoint to {checkpoint_dir}")
Expand Down
28 changes: 0 additions & 28 deletions litgpt/scripts/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def download_from_hub(
from huggingface_hub import snapshot_download

download_files = ["tokenizer*", "generation_config.json", "config.json"]
from_safetensors = False
if not tokenizer_only:
bins, safetensors = find_weight_files(repo_id, access_token)
if bins:
Expand All @@ -68,7 +67,6 @@ def download_from_hub(
if not _SAFETENSORS_AVAILABLE:
raise ModuleNotFoundError(str(_SAFETENSORS_AVAILABLE))
download_files.append("*.safetensors*")
from_safetensors = True
else:
raise ValueError(f"Couldn't find weight files for {repo_id}")

Expand All @@ -93,37 +91,11 @@ def download_from_hub(
constants.HF_HUB_ENABLE_HF_TRANSFER = previous
download.HF_HUB_ENABLE_HF_TRANSFER = previous

if from_safetensors:
print("Converting .safetensor files to PyTorch binaries (.bin)")
safetensor_paths = list(directory.glob("*.safetensors"))
with ProcessPoolExecutor() as executor:
executor.map(convert_safetensors_file, safetensor_paths)

if convert_checkpoint and not tokenizer_only:
print("Converting checkpoint files to LitGPT format.")
convert_hf_checkpoint(checkpoint_dir=directory, dtype=dtype, model_name=model_name)


def convert_safetensors_file(safetensor_path: Path) -> None:
    """Re-serialize one ``.safetensors`` file as a PyTorch ``.bin`` file.

    The tensors are read with ``safetensors.torch.load_file`` and written with
    ``torch.save`` to the same path with its suffix replaced by ``.bin``.
    Afterwards the source file is removed.

    Args:
        safetensor_path: Path of the ``.safetensors`` file to convert.

    Raises:
        RuntimeError: If loading raises ``SafetensorError``; the message
            suggests the file is likely corrupted and should be re-downloaded.
    """
    # safetensors is imported at function scope, only when a conversion runs.
    from safetensors import SafetensorError
    from safetensors.torch import load_file as safetensors_load

    target = safetensor_path.with_suffix(".bin")

    try:
        tensors = safetensors_load(safetensor_path)
    except SafetensorError as err:
        message = f"{safetensor_path} is likely corrupted. Please try to re-download it."
        raise RuntimeError(message) from err

    print(f"{safetensor_path} --> {target}")
    torch.save(tensors, target)

    try:
        os.remove(safetensor_path)
    except PermissionError:
        # Deletion is best effort: the .bin file is already written, so only
        # tell the user that the leftover source file can be removed by hand.
        hint = (
            f"Unable to remove {safetensor_path} file. "
            "This file is no longer needed and you may want to delete it manually to save disk space."
        )
        print(hint)


def find_weight_files(repo_id: str, access_token: Optional[str]) -> Tuple[List[str], List[str]]:
from huggingface_hub import repo_info
from huggingface_hub.utils import filter_repo_objects
Expand Down

0 comments on commit 470f14e

Please sign in to comment.