From 79e8ed9dc81439bdf7489849dc942068b4bc285f Mon Sep 17 00:00:00 2001 From: roger <18309862+rogerwwww@users.noreply.github.com> Date: Wed, 15 Nov 2023 16:24:23 -0500 Subject: [PATCH] fix dataset config setting --- tests/test_dataset.py | 97 +++++++++++++++++++++++-------------------- 1 file changed, 51 insertions(+), 46 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 0cc492b..97648f1 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -67,51 +67,56 @@ def test_dataset_and_benchmark(): problem_type_list = ['2GM', 'MGM'] set_list = ['train', 'test'] filter_list = ['intersection', 'inclusion', 'unfiltered'] - dict_list = [] - - willow_cfg_dict = dict() - willow_cfg_dict['CLASSES'] = dataset_cfg.WillowObject.CLASSES - willow_cfg_dict['KPT_LEN'] = dataset_cfg.WillowObject.KPT_LEN - willow_cfg_dict['ROOT_DIR'] = dataset_cfg.WillowObject.ROOT_DIR - willow_cfg_dict['TRAIN_NUM'] = dataset_cfg.WillowObject.TRAIN_NUM - willow_cfg_dict['SPLIT_OFFSET'] = dataset_cfg.WillowObject.SPLIT_OFFSET - willow_cfg_dict['TRAIN_SAME_AS_TEST'] = dataset_cfg.WillowObject.TRAIN_SAME_AS_TEST - willow_cfg_dict['RAND_OUTLIER'] = dataset_cfg.WillowObject.RAND_OUTLIER - willow_cfg_dict['URL'] = 'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=18AvGwkuhnih5bFDjfJK5NYM16LvDfwW_' - dict_list.append(willow_cfg_dict) - - voc_cfg_dict = dict() - voc_cfg_dict['KPT_ANNO_DIR'] = dataset_cfg.PascalVOC.KPT_ANNO_DIR - voc_cfg_dict['ROOT_DIR'] = dataset_cfg.PascalVOC.ROOT_DIR - voc_cfg_dict['SET_SPLIT'] = dataset_cfg.PascalVOC.SET_SPLIT - voc_cfg_dict['CLASSES'] = dataset_cfg.PascalVOC.CLASSES - voc_cfg_dict['CACHE_PATH'] = dataset_cfg.CACHE_PATH - voc_cfg_dict['URL'] = 'https://huggingface.co/datasets/ziaoguo/small_VOC/resolve/main/small_voc.tar?download=true' - dict_list.append(voc_cfg_dict) - - spair_cfg_dict = dict() - spair_cfg_dict['TRAIN_DIFF_PARAMS'] = {'mirror': 0} - spair_cfg_dict['EVAL_DIFF_PARAMS'] = dataset_cfg.SPair.EVAL_DIFF_PARAMS - spair_cfg_dict['COMB_CLS'] = True - spair_cfg_dict['SIZE'] = 'small' - spair_cfg_dict['ROOT_DIR'] = dataset_cfg.SPair.ROOT_DIR - dict_list.append(spair_cfg_dict) - - imcpt_cfg_dict = dict() - imcpt_cfg_dict['MAX_KPT_NUM'] = dataset_cfg.IMC_PT_SparseGM.MAX_KPT_NUM - imcpt_cfg_dict['CLASSES'] = {'train': ['brandenburg_gate'], - 'test': ['reichstag']} - imcpt_cfg_dict['ROOT_DIR_NPZ'] = dataset_cfg.IMC_PT_SparseGM.ROOT_DIR_NPZ - imcpt_cfg_dict['ROOT_DIR_IMG'] = dataset_cfg.IMC_PT_SparseGM.ROOT_DIR_IMG - imcpt_cfg_dict['URL'] = 'https://drive.google.com/u/0/uc?id=1bisri2Ip1Of3RsUA8OBrdH5oa6HlH3k-&export=download' - dict_list.append(imcpt_cfg_dict) - - cub_cfg_dict = dict() - cub_cfg_dict['ROOT_DIR'] = dataset_cfg.CUB2011.ROOT_DIR - cub_cfg_dict['URL'] = 'https://drive.google.com/u/0/uc?id=1fcN3m2PmQF7rMQGPxldEICU8CtJ0-F-z&export=download' - dict_list.append(cub_cfg_dict) - - for i, dataset_name in enumerate(dataset_name_list): + all_cfgs = {} + + if 'WillowObject' in dataset_name_list: + willow_cfg_dict = dict() + willow_cfg_dict['CLASSES'] = dataset_cfg.WillowObject.CLASSES + willow_cfg_dict['KPT_LEN'] = dataset_cfg.WillowObject.KPT_LEN + willow_cfg_dict['ROOT_DIR'] = dataset_cfg.WillowObject.ROOT_DIR + willow_cfg_dict['TRAIN_NUM'] = dataset_cfg.WillowObject.TRAIN_NUM + willow_cfg_dict['SPLIT_OFFSET'] = dataset_cfg.WillowObject.SPLIT_OFFSET + willow_cfg_dict['TRAIN_SAME_AS_TEST'] = dataset_cfg.WillowObject.TRAIN_SAME_AS_TEST + willow_cfg_dict['RAND_OUTLIER'] = dataset_cfg.WillowObject.RAND_OUTLIER + willow_cfg_dict['URL'] = 'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=18AvGwkuhnih5bFDjfJK5NYM16LvDfwW_' + all_cfgs['WillowObject'] = willow_cfg_dict + + if 'PascalVOC' in dataset_name_list: + voc_cfg_dict = dict() + voc_cfg_dict['KPT_ANNO_DIR'] = dataset_cfg.PascalVOC.KPT_ANNO_DIR + voc_cfg_dict['ROOT_DIR'] = dataset_cfg.PascalVOC.ROOT_DIR + voc_cfg_dict['SET_SPLIT'] = dataset_cfg.PascalVOC.SET_SPLIT + voc_cfg_dict['CLASSES'] = dataset_cfg.PascalVOC.CLASSES + voc_cfg_dict['CACHE_PATH'] = dataset_cfg.CACHE_PATH + voc_cfg_dict['URL'] = 'https://huggingface.co/datasets/ziaoguo/small_VOC/resolve/main/small_voc.tar?download=true' + all_cfgs['PascalVOC'] = voc_cfg_dict + + if 'SPair71k' in dataset_name_list: + spair_cfg_dict = dict() + spair_cfg_dict['TRAIN_DIFF_PARAMS'] = {'mirror': 0} + spair_cfg_dict['EVAL_DIFF_PARAMS'] = dataset_cfg.SPair.EVAL_DIFF_PARAMS + spair_cfg_dict['COMB_CLS'] = True + spair_cfg_dict['SIZE'] = 'small' + spair_cfg_dict['ROOT_DIR'] = dataset_cfg.SPair.ROOT_DIR + all_cfgs['SPair71k'] = spair_cfg_dict + + if 'IMC_PT_SparseGM' in dataset_name_list: + imcpt_cfg_dict = dict() + imcpt_cfg_dict['MAX_KPT_NUM'] = dataset_cfg.IMC_PT_SparseGM.MAX_KPT_NUM + imcpt_cfg_dict['CLASSES'] = {'train': ['brandenburg_gate'], + 'test': ['reichstag']} + imcpt_cfg_dict['ROOT_DIR_NPZ'] = dataset_cfg.IMC_PT_SparseGM.ROOT_DIR_NPZ + imcpt_cfg_dict['ROOT_DIR_IMG'] = dataset_cfg.IMC_PT_SparseGM.ROOT_DIR_IMG + imcpt_cfg_dict['URL'] = 'https://drive.google.com/u/0/uc?id=1bisri2Ip1Of3RsUA8OBrdH5oa6HlH3k-&export=download' + all_cfgs['IMC_PT_SparseGM'] = imcpt_cfg_dict + + if 'CUB2011' in dataset_name_list: + cub_cfg_dict = dict() + cub_cfg_dict['ROOT_DIR'] = dataset_cfg.CUB2011.ROOT_DIR + cub_cfg_dict['URL'] = 'https://drive.google.com/u/0/uc?id=1fcN3m2PmQF7rMQGPxldEICU8CtJ0-F-z&export=download' + all_cfgs['CUB2011'] = cub_cfg_dict + + for dataset_name in dataset_name_list: for set in set_list: for problem_type in problem_type_list: filter = choice(filter_list) @@ -119,7 +124,7 @@ def test_dataset_and_benchmark(): continue if filter == 'inclusion' and problem_type == 'MGM': continue - _test_benchmark(dataset_name, set, problem_type, filter, **dict_list[i]) + _test_benchmark(dataset_name, set, problem_type, filter, **all_cfgs[dataset_name]) if __name__ == '__main__':