diff --git a/src/timesfm/timesfm_base.py b/src/timesfm/timesfm_base.py index 0364089..4d60a4c 100644 --- a/src/timesfm/timesfm_base.py +++ b/src/timesfm/timesfm_base.py @@ -130,6 +130,7 @@ class TimesFmCheckpoint: huggingface_repo_id: str | None = None type: Any = None step: int | None = None + local_dir: str | None = None class TimesFmBase: diff --git a/src/timesfm/timesfm_torch.py b/src/timesfm/timesfm_torch.py index 5137e71..5775e8d 100644 --- a/src/timesfm/timesfm_torch.py +++ b/src/timesfm/timesfm_torch.py @@ -55,8 +55,9 @@ def load_from_checkpoint( checkpoint_path = checkpoint.path repo_id = checkpoint.huggingface_repo_id if checkpoint_path is None: - checkpoint_path = path.join(snapshot_download(repo_id), - "torch_model.ckpt") + checkpoint_path = path.join( + snapshot_download(repo_id, local_dir=checkpoint.local_dir), + "torch_model.ckpt") self._model = ppd.PatchedTimeSeriesDecoder(self._model_config) loaded_checkpoint = torch.load(checkpoint_path, weights_only=True) logging.info("Loading checkpoint from %s", checkpoint_path)