Skip to content

Commit

Permalink
Recover from aborted or failed model downloads (pytorch#358)
Browse files Browse the repository at this point in the history
  • Loading branch information
GregoryComer authored and malfet committed Jul 17, 2024
1 parent 8465aed commit 2b83c7c
Showing 1 changed file with 38 additions and 25 deletions.
63 changes: 38 additions & 25 deletions download.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import shutil
import urllib.request
from pathlib import Path
from typing import Optional, Sequence
Expand All @@ -19,26 +20,22 @@


def _download_hf_snapshot(
model_config: ModelConfig, models_dir: Path, hf_token: Optional[str]
model_config: ModelConfig, artifact_dir: Path, hf_token: Optional[str]
):
model_dir = models_dir / model_config.name
os.makedirs(model_dir, exist_ok=True)

from huggingface_hub import snapshot_download

# Download and store the HF model artifacts.
print(f"Downloading {model_config.name} from HuggingFace...")
try:
snapshot_download(
model_config.distribution_path,
local_dir=model_dir,
local_dir=artifact_dir,
local_dir_use_symlinks=False,
token=hf_token,
ignore_patterns="*safetensors*",
)
except HTTPError as e:
if e.response.status_code == 401:
os.rmdir(model_dir)
raise RuntimeError(
"Access denied. Run huggingface-cli login to authenticate."
)
Expand All @@ -48,20 +45,16 @@ def _download_hf_snapshot(

# Convert the model to the torchchat format.
print(f"Converting {model_config.name} to torchchat format...")
convert_hf_checkpoint(model_dir=model_dir, model_name=model_config.name, remove_bin_files=True)
convert_hf_checkpoint(model_dir=artifact_dir, model_name=model_config.name, remove_bin_files=True)


def _download_direct(
model_config: ModelConfig,
urls: Sequence[str],
models_dir: Path,
artifact_dir: Path,
):
model_dir = models_dir / model_config.name
os.makedirs(model_dir, exist_ok=True)

for url in urls:
for url in model_config.distribution_path:
filename = url.split("/")[-1]
local_path = model_dir / filename
local_path = artifact_dir / filename
print(f"Downloading {url}...")
urllib.request.urlretrieve(url, str(local_path.absolute()))

Expand All @@ -70,18 +63,38 @@ def download_and_convert(
model: str, models_dir: Path, hf_token: Optional[str] = None
) -> None:
model_config = resolve_model_config(model)
model_dir = models_dir / model_config.name

if (
model_config.distribution_channel
== ModelDistributionChannel.HuggingFaceSnapshot
):
_download_hf_snapshot(model_config, models_dir, hf_token)
elif model_config.distribution_channel == ModelDistributionChannel.DirectDownload:
_download_direct(model_config, model_config.distribution_path, models_dir)
else:
raise RuntimeError(
f"Unknown distribution channel {model_config.distribution_channel}."
)
# Download into a temporary directory. We'll move to the final location once
# the download and conversion is complete. This allows recovery in the event
# that the download or conversion fails unexpectedly.
temp_dir = models_dir / "downloads" / model_config.name
if os.path.isdir(temp_dir):
shutil.rmtree(temp_dir)
os.makedirs(temp_dir, exist_ok=True)

try:
if (
model_config.distribution_channel
== ModelDistributionChannel.HuggingFaceSnapshot
):
_download_hf_snapshot(model_config, temp_dir, hf_token)
elif model_config.distribution_channel == ModelDistributionChannel.DirectDownload:
_download_direct(model_config, temp_dir)
else:
raise RuntimeError(
f"Unknown distribution channel {model_config.distribution_channel}."
)

# Move from the temporary directory to the intended location,
# overwriting if necessary.
if os.path.isdir(model_dir):
shutil.rmtree(model_dir)
os.rename(temp_dir, model_dir)

finally:
if os.path.isdir(temp_dir):
shutil.rmtree(temp_dir)


def is_model_downloaded(model: str, models_dir: Path) -> bool:
Expand Down

0 comments on commit 2b83c7c

Please sign in to comment.