feat: select a subset of files for huggingface models #5144

Merged · 1 commit · Dec 23, 2024
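The change adds optional `include` and `exclude` glob-pattern lists to `HuggingFaceModel`, so only a subset of a repository's files is downloaded and counted toward the model size. A minimal usage sketch, assuming the public `bentoml.models` import path; the patterns below are illustrative and not taken from this PR:

```python
from bentoml.models import HuggingFaceModel  # import path may differ by BentoML version

# Pull only the safetensors weights and JSON config/tokenizer files,
# explicitly skipping legacy formats.
model = HuggingFaceModel(
    model_id="google-bert/bert-base-uncased",
    include=["*.safetensors", "*.json"],
    exclude=["*.msgpack", "*.h5"],
)
```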
24 changes: 22 additions & 2 deletions src/_bentoml_sdk/models/huggingface.py
@@ -32,14 +32,18 @@ class HuggingFaceModel(Model[str]):
         model_id (str): The model tag. E.g. "google-bert/bert-base-uncased".
         revision (str, optional): The revision to use. Defaults to "main".
         endpoint (str, optional): The Hugging Face endpoint to use. Defaults to https://huggingface.co.
+        include (List[str], optional): The files to include. Defaults to all files.
+        exclude (List[str], optional): The files to exclude. Defaults to no files.
 
     Returns:
         str: The downloaded model path.
     """
 
     model_id: str
     revision: str = "main"
-    endpoint: str | None = attrs.field(factory=lambda: os.getenv("HF_ENDPOINT"))
+    endpoint: t.Optional[str] = attrs.field(factory=lambda: os.getenv("HF_ENDPOINT"))
+    include: t.Optional[t.List[str]] = None
+    exclude: t.Optional[t.List[str]] = None
 
     @cached_property
     def _hf_api(self) -> HfApi:
@@ -65,6 +69,8 @@ def resolve(self, base_path: t.Union[PathType, FS, None] = None) -> str:
             revision=self.revision,
             endpoint=self.endpoint,
             cache_dir=os.getenv("BENTOML_HF_CACHE_DIR"),
+            allow_patterns=self.include,
+            ignore_patterns=self.exclude,
         )
         if base_path is not None:
             model_path = os.path.dirname(os.path.dirname(snapshot_path))
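In `resolve()`, the new fields are forwarded to `huggingface_hub.snapshot_download` as `allow_patterns` and `ignore_patterns`. A rough standalone equivalent of the filtered download, with an illustrative repo and patterns:

```python
from huggingface_hub import snapshot_download

# Approximately what HuggingFaceModel.resolve() delegates to once include/exclude are set.
snapshot_path = snapshot_download(
    "google-bert/bert-base-uncased",
    revision="main",
    allow_patterns=["*.safetensors", "*.json"],  # maps to `include`
    ignore_patterns=["*.msgpack"],               # maps to `exclude`
)
print(snapshot_path)  # local snapshot directory containing only the matching files
```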
@@ -88,6 +94,8 @@ def to_info(self, alias: str | None = None) -> BentoModelInfo:
                 "model_id": model_id,
                 "revision": self.commit_hash,
                 "endpoint": self.endpoint or DEFAULT_HF_ENDPOINT,
+                "include": self.include,
+                "exclude": self.exclude,
             },
         )
 
@@ -100,16 +108,26 @@ def from_info(cls, info: BentoModelInfo) -> HuggingFaceModel:
             model_id=info.metadata["model_id"],
             revision=info.metadata["revision"],
             endpoint=info.metadata["endpoint"],
+            include=info.metadata.get("include"),
+            exclude=info.metadata.get("exclude"),
         )
         # the commit hash is frozen in the model info, update the cache directly
         model.__dict__.update(commit_hash=info.metadata["revision"])
         return model
 
     def _get_model_size(self, revision: str) -> int:
+        from huggingface_hub.utils import filter_repo_objects
+
         info = self._hf_api.model_info(
             self.model_id, revision=revision, files_metadata=True
         )
-        return sum((file.size or 0) for file in (info.siblings or []))
+        filtered_files = filter_repo_objects(
+            items=info.siblings or [],
+            allow_patterns=self.include,
+            ignore_patterns=self.exclude,
+            key=lambda f: f.rfilename,
+        )
+        return sum((file.size or 0) for file in filtered_files)
 
     def to_create_schema(self) -> CreateModelSchema:
         context = ModelContext(framework_name="huggingface", framework_versions={})
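`_get_model_size` now runs the repository's file listing through `huggingface_hub.utils.filter_repo_objects` with the same patterns, so the reported size covers only the files that will actually be fetched. A small illustration with plain filenames (the file names and patterns are made up):

```python
from huggingface_hub.utils import filter_repo_objects

files = ["model.safetensors", "model.msgpack", "config.json", "README.md"]
kept = list(
    filter_repo_objects(
        items=files,
        allow_patterns=["*.safetensors", "*.json"],
        ignore_patterns=["*.msgpack"],
    )
)
print(kept)  # ['model.safetensors', 'config.json']
```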
@@ -121,6 +139,8 @@ def to_create_schema(self) -> CreateModelSchema:
             "model_id": self.model_id,
             "revision": revision,
             "endpoint": endpoint,
+            "include": self.include,
+            "exclude": self.exclude,
             "url": url,
         }
         return CreateModelSchema(