Skip to content

Commit

Permalink
feat: 🧑‍💻 add **rough** score range for each metric
Browse files Browse the repository at this point in the history
  • Loading branch information
chaofengc committed Aug 18, 2024
1 parent b57dede commit cbba398
Show file tree
Hide file tree
Showing 9 changed files with 130 additions and 34 deletions.
6 changes: 6 additions & 0 deletions pyiqa/api_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import fnmatch
import re
from pyiqa.default_model_configs import DEFAULT_CONFIGS
from pyiqa.dataset_info import DATASET_INFO

from pyiqa.utils import get_root_logger
from pyiqa.models.inference_model import InferenceModel
Expand Down Expand Up @@ -49,3 +50,8 @@ def list_models(metric_mode=None, filter='', exclude_filters=''):
if len(exclude_models):
models = set(models).difference(exclude_models)
return list(sorted(models, key=_natural_key))


def get_dataset_info(dataset_name):
assert dataset_name in DATASET_INFO.keys(), f'Dataset {dataset_name} not implemented yet.'
return DATASET_INFO[dataset_name]
2 changes: 1 addition & 1 deletion pyiqa/archs/ahiq_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def __init__(
)
elif pretrained:
weight_path = load_file_from_url(default_model_urls["pipal"])
checkpoint = torch.load(weight_path)
checkpoint = torch.load(weight_path, map_location='cpu', weights_only=False)
self.regressor.load_state_dict(checkpoint["regressor_model_state_dict"])
self.deform_net.load_state_dict(checkpoint["deform_net_model_state_dict"])

Expand Down
2 changes: 1 addition & 1 deletion pyiqa/archs/arch_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def load_pretrained_network(net, model_path, strict=True, weight_keys=None):
if model_path.startswith("https://") or model_path.startswith("http://"):
model_path = load_file_from_url(model_path)
print(f"Loading pretrained model {net.__class__.__name__} from {model_path}")
state_dict = torch.load(model_path, map_location=torch.device("cpu"))
state_dict = torch.load(model_path, map_location=torch.device("cpu"), weights_only=False)
if weight_keys is not None:
state_dict = state_dict[weight_keys]
state_dict = clean_state_dict(state_dict)
Expand Down
2 changes: 1 addition & 1 deletion pyiqa/archs/brisque_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def brisque(x: torch.Tensor,
scaled_features = scale_features(features)

if pretrained_model_path:
sv_coef, sv = torch.load(pretrained_model_path)
sv_coef, sv = torch.load(pretrained_model_path, weights_only=False)
sv_coef = sv_coef.to(x)
sv = sv.to(x)

Expand Down
2 changes: 1 addition & 1 deletion pyiqa/archs/clipiqa_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __init__(self,

if pretrained and 'clipiqa+' in model_type:
if model_type == 'clipiqa+' and backbone == 'RN50':
self.prompt_learner.ctx.data = torch.load(load_file_from_url(default_model_urls['clipiqa+']))
self.prompt_learner.ctx.data = torch.load(load_file_from_url(default_model_urls['clipiqa+']), weights_only=False)
elif model_type in default_model_urls.keys():
load_pretrained_network(self, default_model_urls[model_type], True, 'params')
else:
Expand Down
2 changes: 1 addition & 1 deletion pyiqa/archs/liqe_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(self,
text_feat_cache_path = os.path.expanduser("~/.cache/pyiqa/liqe_text_feat.pt")

if os.path.exists(text_feat_cache_path):
self.text_features = torch.load(text_feat_cache_path, map_location='cpu')
self.text_features = torch.load(text_feat_cache_path, map_location='cpu', weights_only=False)
else:
print(f'Generating text features for LIQE model, will be cached at {text_feat_cache_path}.')
if self.mtl:
Expand Down
39 changes: 39 additions & 0 deletions pyiqa/dataset_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

DATASET_INFO = {
"live": {
"score_range": (1, 100),
"mos_type": "dmos"
},
"csiq": {
"score_range": (0, 1),
"mos_type": "dmos"
},
"tid": {
"score_range": (0, 9),
"mos_type": "mos"
},
"kadid": {
"score_range": (1, 5),
"mos_type": "mos"
},
"koniq": {
"score_range": (1, 100),
"mos_type": "mos"
},
"clive": {
"score_range": (1, 100),
"mos_type": "mos"
},
"flive": {
"score_range": (1, 100),
"mos_type": "mos"
},
"spaq": {
"score_range": (1, 100),
"mos_type": "mos"
},
"ava": {
"score_range": (1, 10),
"mos_type": "mos"
},
}
Loading

0 comments on commit cbba398

Please sign in to comment.