Skip to content

Commit

Permalink
feat: select a subset of files for huggingface models (#5144)
Browse files: browse the repository at this point in the history
  • Loading branch information
frostming authored Dec 23, 2024
1 parent 96fd58e commit d4e45ed
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions src/_bentoml_sdk/models/huggingface.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -32,14 +32,18 @@ class HuggingFaceModel(Model[str]):
model_id (str): The model tag. E.g. "google-bert/bert-base-uncased".
revision (str, optional): The revision to use. Defaults to "main".
endpoint (str, optional): The Hugging Face endpoint to use. Defaults to https://huggingface.co.
include (List[str], optional): The files to include. Defaults to all files.
exclude (List[str], optional): The files to exclude. Defaults to no files.
Returns:
str: The downloaded model path.
"""

model_id: str
revision: str = "main"
endpoint: str | None = attrs.field(factory=lambda: os.getenv("HF_ENDPOINT"))
endpoint: t.Optional[str] = attrs.field(factory=lambda: os.getenv("HF_ENDPOINT"))
include: t.Optional[t.List[str]] = None
exclude: t.Optional[t.List[str]] = None

@cached_property
def _hf_api(self) -> HfApi:
Expand All @@ -65,6 +69,8 @@ def resolve(self, base_path: t.Union[PathType, FS, None] = None) -> str:
revision=self.revision,
endpoint=self.endpoint,
cache_dir=os.getenv("BENTOML_HF_CACHE_DIR"),
allow_patterns=self.include,
ignore_patterns=self.exclude,
)
if base_path is not None:
model_path = os.path.dirname(os.path.dirname(snapshot_path))
Expand All @@ -88,6 +94,8 @@ def to_info(self, alias: str | None = None) -> BentoModelInfo:
"model_id": model_id,
"revision": self.commit_hash,
"endpoint": self.endpoint or DEFAULT_HF_ENDPOINT,
"include": self.include,
"exclude": self.exclude,
},
)

Expand All @@ -100,16 +108,26 @@ def from_info(cls, info: BentoModelInfo) -> HuggingFaceModel:
model_id=info.metadata["model_id"],
revision=info.metadata["revision"],
endpoint=info.metadata["endpoint"],
include=info.metadata.get("include"),
exclude=info.metadata.get("exclude"),
)
# the commit hash is frozen in the model info, update the cache directly
model.__dict__.update(commit_hash=info.metadata["revision"])
return model

def _get_model_size(self, revision: str) -> int:
from huggingface_hub.utils import filter_repo_objects

info = self._hf_api.model_info(
self.model_id, revision=revision, files_metadata=True
)
return sum((file.size or 0) for file in (info.siblings or []))
filtered_files = filter_repo_objects(
items=info.siblings or [],
allow_patterns=self.include,
ignore_patterns=self.exclude,
key=lambda f: f.rfilename,
)
return sum((file.size or 0) for file in filtered_files)

def to_create_schema(self) -> CreateModelSchema:
context = ModelContext(framework_name="huggingface", framework_versions={})
Expand All @@ -121,6 +139,8 @@ def to_create_schema(self) -> CreateModelSchema:
"model_id": self.model_id,
"revision": revision,
"endpoint": endpoint,
"include": self.include,
"exclude": self.exclude,
"url": url,
}
return CreateModelSchema(
Expand Down

0 comments on commit d4e45ed

Please sign in to comment.