Skip to content

Commit

Permalink
Allow users to pass HF via local-path: model_cls.import_ckpt("hf:///p…
Browse files Browse the repository at this point in the history
…ath/to/ckpt/"); (NVIDIA#9978)

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
  • Loading branch information
akoumpa authored Aug 2, 2024
1 parent c277b9a commit fdf07a9
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion nemo/lightning/io/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os
import shutil
from pathlib import Path, PosixPath, WindowsPath
from pathlib import Path, PosixPath, PurePath, WindowsPath
from typing import Generic, Optional, Tuple, TypeVar

import pytorch_lightning as pl
Expand Down Expand Up @@ -212,6 +212,10 @@ def local_path(self, base_path: Optional[Path] = None) -> Path:

_base = Path(NEMO_MODELS_CACHE)

# If the useu supplied `hf:///path/to/downloaded/my-model/`
# then extract the last dir-name (i.e. my-model) and append it to _base
if str(self).startswith('/'):
return _base / PurePath((str(self))).name
return _base / str(self).replace("://", "/")

def on_import_ckpt(self, model: pl.LightningModule):
Expand Down

0 comments on commit fdf07a9

Please sign in to comment.