Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cached dir for timm & hugging-face #3914

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/otx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,19 @@

__version__ = "2.2.0rc0"

import os
from pathlib import Path

from otx.core.types import * # noqa: F403

# Set HF_HUB_CACHE to choose the cache folder that stores the pretrained weights for timm and HuggingFace models.
# Refer: huggingface_hub/constants.py::HF_HUB_CACHE
# By default, pretrained weights are saved into ~/.cache/torch/hub/checkpoints
os.environ["HF_HUB_CACHE"] = os.getenv(
"HF_HUB_CACHE",
str(Path.home() / ".cache" / "torch" / "hub" / "checkpoints"),
harimkang marked this conversation as resolved.
Show resolved Hide resolved
)

OTX_LOGO: str = """

██████╗ ████████╗ ██╗ ██╗
Expand Down
43 changes: 10 additions & 33 deletions src/otx/algo/classification/backbones/timm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,15 @@
import torch
from torch import nn

from otx.algo.utils.mmengine_utils import load_from_http

PRETRAINED_ROOT = "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/"
pretrained_urls = {
"efficientnetv2_s_21k": PRETRAINED_ROOT + "tf_efficientnetv2_s_21k-6337ad01.pth",
"efficientnetv2_s_1k": PRETRAINED_ROOT + "tf_efficientnetv2_s_21ft1k-d7dafa41.pth",
}

TIMM_MODEL_NAME_DICT = {
"mobilenetv3_large_21k": "mobilenetv3_large_100_miil_in21k",
"mobilenetv3_large_1k": "mobilenetv3_large_100_miil",
"tresnet": "tresnet_m",
"efficientnetv2_s_21k": "tf_efficientnetv2_s.in21k",
"efficientnetv2_s_1k": "tf_efficientnetv2_s_in21ft1k",
"efficientnetv2_m_21k": "tf_efficientnetv2_m_in21k",
"efficientnetv2_m_1k": "tf_efficientnetv2_m_in21ft1k",
"efficientnetv2_b0": "tf_efficientnetv2_b0",
}

TimmModelType = Literal[
"mobilenetv3_large_21k",
"mobilenetv3_large_1k",
"tresnet",
"efficientnetv2_s_21k",
"efficientnetv2_s_1k",
"efficientnetv2_m_21k",
"efficientnetv2_m_1k",
"efficientnetv2_b0",
"mobilenetv3_large_100_miil_in21k",
"mobilenetv3_large_100_miil",
"tresnet_m",
"tf_efficientnetv2_s.in21k",
"tf_efficientnetv2_s.in21ft1k",
"tf_efficientnetv2_m.in21k",
"tf_efficientnetv2_m.in21ft1k",
"tf_efficientnetv2_b0",
]


Expand All @@ -60,14 +41,10 @@ def __init__(
self.backbone = backbone
self.pretrained: bool | dict = pretrained
self.is_mobilenet = backbone.startswith("mobilenet")
if pretrained and self.backbone in pretrained_urls:
# This pretrained weight is saved into ~/.cache/torch/hub/checkpoints
# Otherwise, it is stored in ~/.cache/huggingface/hub. (timm defaults)
self.pretrained = load_from_http(filename=pretrained_urls[self.backbone])

self.model = timm.create_model(
TIMM_MODEL_NAME_DICT[self.backbone],
pretrained=self.pretrained,
self.backbone,
pretrained=pretrained,
num_classes=1000,
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
model:
class_path: otx.algo.classification.timm_model.TimmModelForHLabelCls
init_args:
backbone: efficientnetv2_s_21k
backbone: tf_efficientnetv2_s.in21k

optimizer:
class_path: torch.optim.SGD
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ model:
class_path: otx.algo.classification.timm_model.TimmModelForMulticlassCls
init_args:
label_info: 1000
backbone: efficientnetv2_s_21k
backbone: tf_efficientnetv2_s.in21k

optimizer:
class_path: torch.optim.SGD
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ model:
class_path: otx.algo.classification.timm_model.TimmModelForMulticlassCls
init_args:
label_info: 1000
backbone: efficientnetv2_s_21k
backbone: tf_efficientnetv2_s.in21k
train_type: SEMI_SUPERVISED

optimizer:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ model:
class_path: otx.algo.classification.timm_model.TimmModelForMultilabelCls
init_args:
label_info: 1000
backbone: efficientnetv2_s_21k
backbone: tf_efficientnetv2_s.in21k

optimizer:
class_path: torch.optim.SGD
Expand Down
16 changes: 14 additions & 2 deletions tests/unit/algo/classification/backbones/test_timm.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,28 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import os
import shutil
from pathlib import Path

import torch
from otx.algo.classification.backbones.timm import TimmBackbone


class TestOTXEfficientNetV2:
def test_forward(self):
model = TimmBackbone(backbone="efficientnetv2_s_21k")
model = TimmBackbone(backbone="tf_efficientnetv2_s.in21k")
assert model(torch.randn(1, 3, 244, 244))[0].shape == torch.Size([1, 1280, 8, 8])

def test_get_config_optim(self):
model = TimmBackbone(backbone="efficientnetv2_s_21k")
model = TimmBackbone(backbone="tf_efficientnetv2_s.in21k")
assert model.get_config_optim([0.01])[0]["lr"] == 0.01
assert model.get_config_optim(0.01)[0]["lr"] == 0.01

def test_check_pretrained_weight_download(self):
target = Path(os.environ.get("HF_HUB_CACHE")) / "models--timm--tf_efficientnetv2_s.in21k"
if target.exists():
shutil.rmtree(target)
assert not target.exists()
TimmBackbone(backbone="tf_efficientnetv2_s.in21k", pretrained=True)
assert target.exists()
6 changes: 3 additions & 3 deletions tests/unit/algo/classification/test_timm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
def fxt_multi_class_cls_model():
return TimmModelForMulticlassCls(
label_info=10,
backbone="efficientnetv2_s_21k",
backbone="tf_efficientnetv2_s.in21k",
)


Expand Down Expand Up @@ -59,7 +59,7 @@ def test_predict_step(self, fxt_multi_class_cls_model, fxt_multiclass_cls_batch_
def fxt_multi_label_cls_model():
return TimmModelForMultilabelCls(
label_info=10,
backbone="efficientnetv2_s_21k",
backbone="tf_efficientnetv2_s.in21k",
)


Expand Down Expand Up @@ -97,7 +97,7 @@ def test_predict_step(self, fxt_multi_label_cls_model, fxt_multilabel_cls_batch_
def fxt_h_label_cls_model(fxt_hlabel_cifar):
return TimmModelForHLabelCls(
label_info=fxt_hlabel_cifar,
backbone="efficientnetv2_s_21k",
backbone="tf_efficientnetv2_s.in21k",
)


Expand Down
Loading