From eee8affcb3860c1623e7470849fa3749823941c2 Mon Sep 17 00:00:00 2001 From: Frost Ming Date: Mon, 23 Dec 2024 16:57:44 +0800 Subject: [PATCH] feat: select a subset of files for huggingface models Signed-off-by: Frost Ming --- src/_bentoml_sdk/models/huggingface.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/_bentoml_sdk/models/huggingface.py b/src/_bentoml_sdk/models/huggingface.py index 02c0d0abea2..adfe253e98b 100644 --- a/src/_bentoml_sdk/models/huggingface.py +++ b/src/_bentoml_sdk/models/huggingface.py @@ -32,6 +32,8 @@ class HuggingFaceModel(Model[str]): model_id (str): The model tag. E.g. "google-bert/bert-base-uncased". revision (str, optional): The revision to use. Defaults to "main". endpoint (str, optional): The Hugging Face endpoint to use. Defaults to https://huggingface.co. + include (List[str], optional): The files to include. Defaults to all files. + exclude (List[str], optional): The files to exclude. Defaults to no files. Returns: str: The downloaded model path. @@ -39,7 +41,9 @@ class HuggingFaceModel(Model[str]): model_id: str revision: str = "main" - endpoint: str | None = attrs.field(factory=lambda: os.getenv("HF_ENDPOINT")) + endpoint: t.Optional[str] = attrs.field(factory=lambda: os.getenv("HF_ENDPOINT")) + include: t.Optional[t.List[str]] = None + exclude: t.Optional[t.List[str]] = None @cached_property def _hf_api(self) -> HfApi: @@ -65,6 +69,8 @@ def resolve(self, base_path: t.Union[PathType, FS, None] = None) -> str: revision=self.revision, endpoint=self.endpoint, cache_dir=os.getenv("BENTOML_HF_CACHE_DIR"), + allow_patterns=self.include, + ignore_patterns=self.exclude, ) if base_path is not None: model_path = os.path.dirname(os.path.dirname(snapshot_path)) @@ -88,6 +94,8 @@ def to_info(self, alias: str | None = None) -> BentoModelInfo: "model_id": model_id, "revision": self.commit_hash, "endpoint": self.endpoint or DEFAULT_HF_ENDPOINT, + "include": self.include, + "exclude": self.exclude, }, ) @@ -100,16 +108,26 @@ def from_info(cls, info: BentoModelInfo) -> HuggingFaceModel: model_id=info.metadata["model_id"], revision=info.metadata["revision"], endpoint=info.metadata["endpoint"], + include=info.metadata.get("include"), + exclude=info.metadata.get("exclude"), ) # the commit hash is frozen in the model info, update the cache directly model.__dict__.update(commit_hash=info.metadata["revision"]) return model def _get_model_size(self, revision: str) -> int: + from huggingface_hub.utils import filter_repo_objects + info = self._hf_api.model_info( self.model_id, revision=revision, files_metadata=True ) - return sum((file.size or 0) for file in (info.siblings or [])) + filtered_files = filter_repo_objects( + items=info.siblings or [], + allow_patterns=self.include, + ignore_patterns=self.exclude, + key=lambda f: f.rfilename, + ) + return sum((file.size or 0) for file in filtered_files) def to_create_schema(self) -> CreateModelSchema: context = ModelContext(framework_name="huggingface", framework_versions={}) @@ -121,6 +139,8 @@ def to_create_schema(self) -> CreateModelSchema: "model_id": self.model_id, "revision": revision, "endpoint": endpoint, + "include": self.include, + "exclude": self.exclude, "url": url, } return CreateModelSchema(