From cbba3988852ac5dc8de0d4ba1ecdf32d38bf3ed2 Mon Sep 17 00:00:00 2001 From: chaofengc Date: Sun, 18 Aug 2024 16:10:03 +0800 Subject: [PATCH] feat: :technologist: add **rough** score range for each metric --- pyiqa/api_helpers.py | 6 ++ pyiqa/archs/ahiq_arch.py | 2 +- pyiqa/archs/arch_util.py | 2 +- pyiqa/archs/brisque_arch.py | 2 +- pyiqa/archs/clipiqa_arch.py | 2 +- pyiqa/archs/liqe_arch.py | 2 +- pyiqa/dataset_info.py | 39 ++++++++++++ pyiqa/default_model_configs.py | 108 +++++++++++++++++++++++--------- pyiqa/models/inference_model.py | 1 + 9 files changed, 130 insertions(+), 34 deletions(-) create mode 100644 pyiqa/dataset_info.py diff --git a/pyiqa/api_helpers.py b/pyiqa/api_helpers.py index a62972d..56eef27 100644 --- a/pyiqa/api_helpers.py +++ b/pyiqa/api_helpers.py @@ -1,6 +1,7 @@ import fnmatch import re from pyiqa.default_model_configs import DEFAULT_CONFIGS +from pyiqa.dataset_info import DATASET_INFO from pyiqa.utils import get_root_logger from pyiqa.models.inference_model import InferenceModel @@ -49,3 +50,8 @@ def list_models(metric_mode=None, filter='', exclude_filters=''): if len(exclude_models): models = set(models).difference(exclude_models) return list(sorted(models, key=_natural_key)) + + +def get_dataset_info(dataset_name): + assert dataset_name in DATASET_INFO.keys(), f'Dataset {dataset_name} not implemented yet.' + return DATASET_INFO[dataset_name] diff --git a/pyiqa/archs/ahiq_arch.py b/pyiqa/archs/ahiq_arch.py index c884351..77acded 100644 --- a/pyiqa/archs/ahiq_arch.py +++ b/pyiqa/archs/ahiq_arch.py @@ -198,7 +198,7 @@ def __init__( ) elif pretrained: weight_path = load_file_from_url(default_model_urls["pipal"]) - checkpoint = torch.load(weight_path) + checkpoint = torch.load(weight_path, map_location='cpu', weights_only=False) self.regressor.load_state_dict(checkpoint["regressor_model_state_dict"]) self.deform_net.load_state_dict(checkpoint["deform_net_model_state_dict"]) diff --git a/pyiqa/archs/arch_util.py b/pyiqa/archs/arch_util.py index b6dc79e..2c67f7b 100644 --- a/pyiqa/archs/arch_util.py +++ b/pyiqa/archs/arch_util.py @@ -163,7 +163,7 @@ def load_pretrained_network(net, model_path, strict=True, weight_keys=None): if model_path.startswith("https://") or model_path.startswith("http://"): model_path = load_file_from_url(model_path) print(f"Loading pretrained model {net.__class__.__name__} from {model_path}") - state_dict = torch.load(model_path, map_location=torch.device("cpu")) + state_dict = torch.load(model_path, map_location=torch.device("cpu"), weights_only=False) if weight_keys is not None: state_dict = state_dict[weight_keys] state_dict = clean_state_dict(state_dict) diff --git a/pyiqa/archs/brisque_arch.py b/pyiqa/archs/brisque_arch.py index 708e24f..b6c9ea7 100644 --- a/pyiqa/archs/brisque_arch.py +++ b/pyiqa/archs/brisque_arch.py @@ -61,7 +61,7 @@ def brisque(x: torch.Tensor, scaled_features = scale_features(features) if pretrained_model_path: - sv_coef, sv = torch.load(pretrained_model_path) + sv_coef, sv = torch.load(pretrained_model_path, weights_only=False) sv_coef = sv_coef.to(x) sv = sv.to(x) diff --git a/pyiqa/archs/clipiqa_arch.py b/pyiqa/archs/clipiqa_arch.py index 5d88367..7d2fef9 100644 --- a/pyiqa/archs/clipiqa_arch.py +++ b/pyiqa/archs/clipiqa_arch.py @@ -138,7 +138,7 @@ def __init__(self, if pretrained and 'clipiqa+' in model_type: if model_type == 'clipiqa+' and backbone == 'RN50': - self.prompt_learner.ctx.data = torch.load(load_file_from_url(default_model_urls['clipiqa+'])) + self.prompt_learner.ctx.data = torch.load(load_file_from_url(default_model_urls['clipiqa+']), weights_only=False) elif model_type in default_model_urls.keys(): load_pretrained_network(self, default_model_urls[model_type], True, 'params') else: diff --git a/pyiqa/archs/liqe_arch.py b/pyiqa/archs/liqe_arch.py index 6b2c5b8..56873f4 100644 --- a/pyiqa/archs/liqe_arch.py +++ b/pyiqa/archs/liqe_arch.py @@ -75,7 +75,7 @@ def __init__(self, text_feat_cache_path = os.path.expanduser("~/.cache/pyiqa/liqe_text_feat.pt") if os.path.exists(text_feat_cache_path): - self.text_features = torch.load(text_feat_cache_path, map_location='cpu') + self.text_features = torch.load(text_feat_cache_path, map_location='cpu', weights_only=False) else: print(f'Generating text features for LIQE model, will be cached at {text_feat_cache_path}.') if self.mtl: diff --git a/pyiqa/dataset_info.py b/pyiqa/dataset_info.py new file mode 100644 index 0000000..e45459f --- /dev/null +++ b/pyiqa/dataset_info.py @@ -0,0 +1,39 @@ + +DATASET_INFO = { + "live": { + "score_range": (1, 100), + "mos_type": "dmos" + }, + "csiq": { + "score_range": (0, 1), + "mos_type": "dmos" + }, + "tid": { + "score_range": (0, 9), + "mos_type": "mos" + }, + "kadid": { + "score_range": (1, 5), + "mos_type": "mos" + }, + "koniq": { + "score_range": (1, 100), + "mos_type": "mos" + }, + "clive": { + "score_range": (1, 100), + "mos_type": "mos" + }, + "flive": { + "score_range": (1, 100), + "mos_type": "mos" + }, + "spaq": { + "score_range": (1, 100), + "mos_type": "mos" + }, + "ava": { + "score_range": (1, 10), + "mos_type": "mos" + }, +} \ No newline at end of file diff --git a/pyiqa/default_model_configs.py b/pyiqa/default_model_configs.py index aff1609..b98984c 100644 --- a/pyiqa/default_model_configs.py +++ b/pyiqa/default_model_configs.py @@ -1,17 +1,22 @@ from collections import OrderedDict +# IMPORTANT NOTES !!! +# - The score range (min, max) is only rough estimation, the actual score range may vary. + DEFAULT_CONFIGS = OrderedDict({ 'ahiq': { 'metric_opts': { 'type': 'AHIQ', }, 'metric_mode': 'FR', + 'score_range': '~0, ~1', }, 'ckdn': { 'metric_opts': { 'type': 'CKDN', }, 'metric_mode': 'FR', + 'score_range': '0, 1', }, 'lpips': { 'metric_opts': { @@ -21,6 +26,7 @@ }, 'metric_mode': 'FR', 'lower_better': True, + 'score_range': '0, 1', }, 'lpips-vgg': { 'metric_opts': { @@ -30,6 +36,7 @@ }, 'metric_mode': 'FR', 'lower_better': True, + 'score_range': '0, 1', }, 'lpips+': { 'metric_opts': { @@ -40,6 +47,7 @@ }, 'metric_mode': 'FR', 'lower_better': True, + 'score_range': '0, 1', }, 'lpips-vgg+': { 'metric_opts': { @@ -50,6 +58,7 @@ }, 'metric_mode': 'FR', 'lower_better': True, + 'score_range': '0, 1', }, 'stlpips': { 'metric_opts': { @@ -59,6 +68,7 @@ }, 'metric_mode': 'FR', 'lower_better': True, + 'score_range': '0, 1', }, 'stlpips-vgg': { 'metric_opts': { @@ -68,6 +78,7 @@ }, 'metric_mode': 'FR', 'lower_better': True, + 'score_range': '0, 1', }, 'dists': { 'metric_opts': { @@ -75,6 +86,7 @@ }, 'metric_mode': 'FR', 'lower_better': True, + 'score_range': '0, 1', }, 'ssim': { 'metric_opts': { @@ -83,6 +95,7 @@ 'test_y_channel': True, }, 'metric_mode': 'FR', + 'score_range': '0, 1', }, 'ssimc': { 'metric_opts': { @@ -91,6 +104,7 @@ 'test_y_channel': False, }, 'metric_mode': 'FR', + 'score_range': '0, 1', }, 'psnr': { 'metric_opts': { @@ -98,6 +112,7 @@ 'test_y_channel': False, }, 'metric_mode': 'FR', + 'score_range': '~0, ~40', }, 'psnry': { 'metric_opts': { @@ -105,6 +120,7 @@ 'test_y_channel': True, }, 'metric_mode': 'FR', + 'score_range': '~0, ~60', }, 'fsim': { 'metric_opts': { @@ -112,6 +128,7 @@ 'chromatic': True, }, 'metric_mode': 'FR', + 'score_range': '0, 1', }, 'ms_ssim': { 'metric_opts': { @@ -121,12 +138,14 @@ 'is_prod': True, }, 'metric_mode': 'FR', + 'score_range': '0, 1', }, 'vif': { 'metric_opts': { 'type': 'VIF', }, 'metric_mode': 'FR', + 'score_range': '0, ~1', }, 'gmsd': { 'metric_opts': { @@ -135,6 +154,7 @@ }, 'metric_mode': 'FR', 'lower_better': True, + 'score_range': '0, ~1', }, 'nlpd': { 'metric_opts': { @@ -144,12 +164,14 @@ }, 'metric_mode': 'FR', 'lower_better': True, + 'score_range': '0, 1', }, 'vsi': { 'metric_opts': { 'type': 'VSI', }, 'metric_mode': 'FR', + 'score_range': '0, ~1', }, 'cw_ssim': { 'metric_opts': { @@ -160,6 +182,7 @@ 'test_y_channel': True, }, 'metric_mode': 'FR', + 'score_range': '0, 1', }, 'mad': { 'metric_opts': { @@ -167,6 +190,7 @@ }, 'metric_mode': 'FR', 'lower_better': True, + 'score_range': '0, ~', }, # ============================================================= 'niqe': { @@ -176,6 +200,7 @@ }, 'metric_mode': 'NR', 'lower_better': True, + 'score_range': '~0, ~100', }, 'ilniqe': { 'metric_opts': { @@ -183,6 +208,7 @@ }, 'metric_mode': 'NR', 'lower_better': True, + 'score_range': '~0, ~100', }, 'brisque': { 'metric_opts': { @@ -191,12 +217,14 @@ }, 'metric_mode': 'NR', 'lower_better': True, + 'score_range': '~0, ~150', }, 'nrqm': { 'metric_opts': { 'type': 'NRQM', }, 'metric_mode': 'NR', + 'score_range': '~0, ~10', }, 'pi': { 'metric_opts': { @@ -204,6 +232,7 @@ }, 'metric_mode': 'NR', 'lower_better': True, + 'score_range': '~0, ~', }, 'cnniqa': { 'metric_opts': { @@ -211,6 +240,7 @@ 'pretrained': 'koniq10k' }, 'metric_mode': 'NR', + 'score_range': '~0, ~1', }, 'musiq': { 'metric_opts': { @@ -218,6 +248,7 @@ 'pretrained': 'koniq10k' }, 'metric_mode': 'NR', + 'score_range': '~0, ~100', }, 'musiq-ava': { 'metric_opts': { @@ -225,13 +256,7 @@ 'pretrained': 'ava' }, 'metric_mode': 'NR', - }, - 'musiq-koniq': { - 'metric_opts': { - 'type': 'MUSIQ', - 'pretrained': 'koniq10k' - }, - 'metric_mode': 'NR', + 'score_range': '1, 10', }, 'musiq-paq2piq': { 'metric_opts': { @@ -239,6 +264,7 @@ 'pretrained': 'paq2piq' }, 'metric_mode': 'NR', + 'score_range': '~0, ~100', }, 'musiq-spaq': { 'metric_opts': { @@ -246,6 +272,7 @@ 'pretrained': 'spaq' }, 'metric_mode': 'NR', + 'score_range': '~0, ~100', }, 'nima': { 'metric_opts': { @@ -254,6 +281,7 @@ 'base_model_name': 'inception_resnet_v2', }, 'metric_mode': 'NR', + 'score_range': '0, 10', }, 'nima-koniq': { 'metric_opts': { @@ -263,6 +291,7 @@ 'base_model_name': 'inception_resnet_v2', }, 'metric_mode': 'NR', + 'score_range': '~0, ~1', }, 'nima-spaq': { 'metric_opts': { @@ -272,6 +301,7 @@ 'base_model_name': 'inception_resnet_v2', }, 'metric_mode': 'NR', + 'score_range': '~0, ~1', }, 'nima-vgg16-ava': { 'metric_opts': { @@ -280,6 +310,7 @@ 'base_model_name': 'vgg16', }, 'metric_mode': 'NR', + 'score_range': '0, 10', }, 'pieapp': { 'metric_opts': { @@ -287,12 +318,14 @@ }, 'metric_mode': 'FR', 'lower_better': True, + 'score_range': '~0, ~5', }, 'paq2piq': { 'metric_opts': { 'type': 'PAQ2PIQ', }, 'metric_mode': 'NR', + 'score_range': '~0, ~100', }, 'dbcnn': { 'metric_opts': { @@ -300,6 +333,7 @@ 'pretrained': 'koniq' }, 'metric_mode': 'NR', + 'score_range': '~0, ~1', }, 'fid': { 'metric_opts': { @@ -307,6 +341,7 @@ }, 'metric_mode': 'NR', 'lower_better': True, + 'score_range': '0, ~', }, 'maniqa': { 'metric_opts': { @@ -315,14 +350,7 @@ 'scale': 0.8, }, 'metric_mode': 'NR', - }, - 'maniqa-koniq': { - 'metric_opts': { - 'type': 'MANIQA', - 'train_dataset': 'koniq', - 'scale': 0.8, - }, - 'metric_mode': 'NR', + 'score_range': '~0, ~1', }, 'maniqa-pipal': { 'metric_opts': { @@ -330,6 +358,7 @@ 'train_dataset': 'pipal', }, 'metric_mode': 'NR', + 'score_range': '~0, ~1', }, 'maniqa-kadid': { 'metric_opts': { @@ -338,12 +367,14 @@ 'scale': 0.8, }, 'metric_mode': 'NR', + 'score_range': '~0, ~1', }, 'clipiqa': { 'metric_opts': { 'type': 'CLIPIQA', }, 'metric_mode': 'NR', + 'score_range': '0, 1', }, 'clipiqa+': { 'metric_opts': { @@ -351,6 +382,7 @@ 'model_type': 'clipiqa+', }, 'metric_mode': 'NR', + 'score_range': '0, 1', }, 'clipiqa+_vitL14_512': { 'metric_opts': { @@ -360,6 +392,7 @@ 'pos_embedding': True, }, 'metric_mode': 'NR', + 'score_range': '0, 1', }, 'clipiqa+_rn50_512': { 'metric_opts': { @@ -369,6 +402,7 @@ 'pos_embedding': True, }, 'metric_mode': 'NR', + 'score_range': '0, 1', }, 'tres': { 'metric_opts': { @@ -376,13 +410,7 @@ 'train_dataset': 'koniq', }, 'metric_mode': 'NR', - }, - 'tres-koniq': { - 'metric_opts': { - 'type': 'TReS', - 'train_dataset': 'koniq', - }, - 'metric_mode': 'NR', + 'score_range': '~0, ~100', }, 'tres-flive': { 'metric_opts': { @@ -390,30 +418,35 @@ 'train_dataset': 'flive', }, 'metric_mode': 'NR', + 'score_range': '~0, ~100', }, 'hyperiqa': { 'metric_opts': { 'type': 'HyperNet', }, 'metric_mode': 'NR', + 'score_range': '~0, ~1', }, 'uranker': { 'metric_opts': { 'type': 'URanker', }, 'metric_mode': 'NR', + 'score_range': '~-1, ~2', }, 'clipscore': { 'metric_opts': { 'type': 'CLIPScore', }, 'metric_mode': 'NR', # Caption image similarity + 'score_range': '0, 2.5' }, 'entropy': { 'metric_opts': { 'type': 'Entropy', }, 'metric_mode': 'NR', + 'score_range': '0, 8' }, 'topiq_nr': { 'metric_opts': { @@ -423,6 +456,7 @@ 'use_ref': False, }, 'metric_mode': 'NR', + 'score_range': '~0, ~1', }, 'topiq_nr-flive': { 'metric_opts': { @@ -432,6 +466,7 @@ 'use_ref': False, }, 'metric_mode': 'NR', + 'score_range': '~0, ~1', }, 'topiq_nr-spaq': { 'metric_opts': { @@ -441,6 +476,7 @@ 'use_ref': False, }, 'metric_mode': 'NR', + 'score_range': '~0, ~1', }, 'topiq_nr-face': { 'metric_opts': { @@ -451,6 +487,7 @@ 'test_img_size': 512, }, 'metric_mode': 'NR', + 'score_range': '~0, ~1', }, 'topiq_fr': { 'metric_opts': { @@ -460,6 +497,7 @@ 'use_ref': True, }, 'metric_mode': 'FR', + 'score_range': '~0, ~1', }, 'topiq_fr-pipal': { 'metric_opts': { @@ -469,6 +507,7 @@ 'use_ref': True, }, 'metric_mode': 'FR', + 'score_range': '~0, ~1', }, 'topiq_iaa': { 'metric_opts': { @@ -481,6 +520,7 @@ 'num_class': 10, }, 'metric_mode': 'NR', + 'score_range': '1, 10', }, 'topiq_iaa_res50': { 'metric_opts': { @@ -494,12 +534,14 @@ 'test_img_size': 384, }, 'metric_mode': 'NR', + 'score_range': '1, 10', }, 'laion_aes': { 'metric_opts': { 'type': 'LAIONAes', }, 'metric_mode': 'NR', + 'score_range': '~1, ~10', }, 'liqe': { 'metric_opts': { @@ -507,6 +549,7 @@ 'pretrained': 'koniq' }, 'metric_mode': 'NR', + 'score_range': '1, 5' }, 'liqe_mix': { 'metric_opts': { @@ -514,6 +557,7 @@ 'pretrained': 'mix' }, 'metric_mode': 'NR', + 'score_range': '1, 5' }, 'wadiqam_fr': { 'metric_opts': { @@ -522,6 +566,7 @@ 'model_name': 'wadiqam_fr_kadid', }, 'metric_mode': 'FR', + 'score_range': '~-1, ~0.1', }, 'wadiqam_nr': { 'metric_opts': { @@ -530,18 +575,21 @@ 'model_name': 'wadiqam_nr_koniq', }, 'metric_mode': 'NR', + 'score_range': '~-1, ~0.1', }, 'qalign': { 'metric_opts': { 'type': 'QAlign', }, 'metric_mode': 'NR', + 'score_range': '0, 1', }, 'unique': { 'metric_opts': { 'type': 'UNIQUE', }, 'metric_mode': 'NR', + 'score_range': '~-3, ~3', }, 'inception_score': { 'metric_opts': { @@ -549,6 +597,7 @@ }, 'metric_mode': 'NR', 'lower_better': False, + 'score_range': '0, ~', }, 'arniqa': { 'metric_opts': { @@ -556,6 +605,7 @@ 'regressor_dataset': 'koniq', }, 'metric_mode': 'NR', + 'score_range': '0, 1', }, 'arniqa-live': { 'metric_opts': { @@ -563,6 +613,7 @@ 'regressor_dataset': 'live', }, 'metric_mode': 'NR', + 'score_range': '0, 1', }, 'arniqa-csiq': { 'metric_opts': { @@ -570,6 +621,7 @@ 'regressor_dataset': 'csiq', }, 'metric_mode': 'NR', + 'score_range': '0, 1', }, 'arniqa-tid': { 'metric_opts': { @@ -577,6 +629,7 @@ 'regressor_dataset': 'tid', }, 'metric_mode': 'NR', + 'score_range': '0, 1', }, 'arniqa-kadid': { 'metric_opts': { @@ -584,13 +637,7 @@ 'regressor_dataset': 'kadid', }, 'metric_mode': 'NR', - }, - 'arniqa-koniq': { - 'metric_opts': { - 'type': 'ARNIQA', - 'regressor_dataset': 'koniq', - }, - 'metric_mode': 'NR', + 'score_range': '0, 1', }, 'arniqa-clive': { 'metric_opts': { @@ -598,6 +645,7 @@ 'regressor_dataset': 'clive', }, 'metric_mode': 'NR', + 'score_range': '0, 1', }, 'arniqa-flive': { 'metric_opts': { @@ -605,6 +653,7 @@ 'regressor_dataset': 'flive', }, 'metric_mode': 'NR', + 'score_range': '0, 1', }, 'arniqa-spaq': { 'metric_opts': { @@ -612,5 +661,6 @@ 'regressor_dataset': 'spaq', }, 'metric_mode': 'NR', + 'score_range': '0, 1', }, }) diff --git a/pyiqa/models/inference_model.py b/pyiqa/models/inference_model.py index 6b5771f..98df41b 100644 --- a/pyiqa/models/inference_model.py +++ b/pyiqa/models/inference_model.py @@ -28,6 +28,7 @@ def __init__( # ============ set metric properties =========== self.lower_better = DEFAULT_CONFIGS[metric_name].get('lower_better', False) self.metric_mode = DEFAULT_CONFIGS[metric_name].get('metric_mode', None) + self.score_range = DEFAULT_CONFIGS[metric_name].get('score_range', None) if self.metric_mode is None: self.metric_mode = kwargs.pop('metric_mode') elif 'metric_mode' in kwargs: