Skip to content

Commit

Permalink
fix dataset config setting
Browse files Browse the repository at this point in the history
  • Loading branch information
rogerwwww committed Nov 15, 2023
1 parent ae1c1eb commit 79e8ed9
Showing 1 changed file with 51 additions and 46 deletions.
97 changes: 51 additions & 46 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,59 +67,64 @@ def test_dataset_and_benchmark():
problem_type_list = ['2GM', 'MGM']
set_list = ['train', 'test']
filter_list = ['intersection', 'inclusion', 'unfiltered']
dict_list = []

willow_cfg_dict = dict()
willow_cfg_dict['CLASSES'] = dataset_cfg.WillowObject.CLASSES
willow_cfg_dict['KPT_LEN'] = dataset_cfg.WillowObject.KPT_LEN
willow_cfg_dict['ROOT_DIR'] = dataset_cfg.WillowObject.ROOT_DIR
willow_cfg_dict['TRAIN_NUM'] = dataset_cfg.WillowObject.TRAIN_NUM
willow_cfg_dict['SPLIT_OFFSET'] = dataset_cfg.WillowObject.SPLIT_OFFSET
willow_cfg_dict['TRAIN_SAME_AS_TEST'] = dataset_cfg.WillowObject.TRAIN_SAME_AS_TEST
willow_cfg_dict['RAND_OUTLIER'] = dataset_cfg.WillowObject.RAND_OUTLIER
willow_cfg_dict['URL'] = 'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=18AvGwkuhnih5bFDjfJK5NYM16LvDfwW_'
dict_list.append(willow_cfg_dict)

voc_cfg_dict = dict()
voc_cfg_dict['KPT_ANNO_DIR'] = dataset_cfg.PascalVOC.KPT_ANNO_DIR
voc_cfg_dict['ROOT_DIR'] = dataset_cfg.PascalVOC.ROOT_DIR
voc_cfg_dict['SET_SPLIT'] = dataset_cfg.PascalVOC.SET_SPLIT
voc_cfg_dict['CLASSES'] = dataset_cfg.PascalVOC.CLASSES
voc_cfg_dict['CACHE_PATH'] = dataset_cfg.CACHE_PATH
voc_cfg_dict['URL'] = 'https://huggingface.co/datasets/ziaoguo/small_VOC/resolve/main/small_voc.tar?download=true'
dict_list.append(voc_cfg_dict)

spair_cfg_dict = dict()
spair_cfg_dict['TRAIN_DIFF_PARAMS'] = {'mirror': 0}
spair_cfg_dict['EVAL_DIFF_PARAMS'] = dataset_cfg.SPair.EVAL_DIFF_PARAMS
spair_cfg_dict['COMB_CLS'] = True
spair_cfg_dict['SIZE'] = 'small'
spair_cfg_dict['ROOT_DIR'] = dataset_cfg.SPair.ROOT_DIR
dict_list.append(spair_cfg_dict)

imcpt_cfg_dict = dict()
imcpt_cfg_dict['MAX_KPT_NUM'] = dataset_cfg.IMC_PT_SparseGM.MAX_KPT_NUM
imcpt_cfg_dict['CLASSES'] = {'train': ['brandenburg_gate'],
'test': ['reichstag']}
imcpt_cfg_dict['ROOT_DIR_NPZ'] = dataset_cfg.IMC_PT_SparseGM.ROOT_DIR_NPZ
imcpt_cfg_dict['ROOT_DIR_IMG'] = dataset_cfg.IMC_PT_SparseGM.ROOT_DIR_IMG
imcpt_cfg_dict['URL'] = 'https://drive.google.com/u/0/uc?id=1bisri2Ip1Of3RsUA8OBrdH5oa6HlH3k-&export=download'
dict_list.append(imcpt_cfg_dict)

cub_cfg_dict = dict()
cub_cfg_dict['ROOT_DIR'] = dataset_cfg.CUB2011.ROOT_DIR
cub_cfg_dict['URL'] = 'https://drive.google.com/u/0/uc?id=1fcN3m2PmQF7rMQGPxldEICU8CtJ0-F-z&export=download'
dict_list.append(cub_cfg_dict)

for i, dataset_name in enumerate(dataset_name_list):
all_cfgs = {}

if 'WillowObject' in dataset_name_list:
willow_cfg_dict = dict()
willow_cfg_dict['CLASSES'] = dataset_cfg.WillowObject.CLASSES
willow_cfg_dict['KPT_LEN'] = dataset_cfg.WillowObject.KPT_LEN
willow_cfg_dict['ROOT_DIR'] = dataset_cfg.WillowObject.ROOT_DIR
willow_cfg_dict['TRAIN_NUM'] = dataset_cfg.WillowObject.TRAIN_NUM
willow_cfg_dict['SPLIT_OFFSET'] = dataset_cfg.WillowObject.SPLIT_OFFSET
willow_cfg_dict['TRAIN_SAME_AS_TEST'] = dataset_cfg.WillowObject.TRAIN_SAME_AS_TEST
willow_cfg_dict['RAND_OUTLIER'] = dataset_cfg.WillowObject.RAND_OUTLIER
willow_cfg_dict['URL'] = 'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=18AvGwkuhnih5bFDjfJK5NYM16LvDfwW_'
all_cfgs['WillowObject'] = willow_cfg_dict

if 'PascalVOC' in dataset_name_list:
voc_cfg_dict = dict()
voc_cfg_dict['KPT_ANNO_DIR'] = dataset_cfg.PascalVOC.KPT_ANNO_DIR
voc_cfg_dict['ROOT_DIR'] = dataset_cfg.PascalVOC.ROOT_DIR
voc_cfg_dict['SET_SPLIT'] = dataset_cfg.PascalVOC.SET_SPLIT
voc_cfg_dict['CLASSES'] = dataset_cfg.PascalVOC.CLASSES
voc_cfg_dict['CACHE_PATH'] = dataset_cfg.CACHE_PATH
voc_cfg_dict['URL'] = 'https://huggingface.co/datasets/ziaoguo/small_VOC/resolve/main/small_voc.tar?download=true'
all_cfgs['PascalVOC'] = voc_cfg_dict

if 'SPair71k' in dataset_name_list:
spair_cfg_dict = dict()
spair_cfg_dict['TRAIN_DIFF_PARAMS'] = {'mirror': 0}
spair_cfg_dict['EVAL_DIFF_PARAMS'] = dataset_cfg.SPair.EVAL_DIFF_PARAMS
spair_cfg_dict['COMB_CLS'] = True
spair_cfg_dict['SIZE'] = 'small'
spair_cfg_dict['ROOT_DIR'] = dataset_cfg.SPair.ROOT_DIR
all_cfgs['SPair71k'] = spair_cfg_dict

if 'IMC_PT_SparseGM' in dataset_name_list:
imcpt_cfg_dict = dict()
imcpt_cfg_dict['MAX_KPT_NUM'] = dataset_cfg.IMC_PT_SparseGM.MAX_KPT_NUM
imcpt_cfg_dict['CLASSES'] = {'train': ['brandenburg_gate'],
'test': ['reichstag']}
imcpt_cfg_dict['ROOT_DIR_NPZ'] = dataset_cfg.IMC_PT_SparseGM.ROOT_DIR_NPZ
imcpt_cfg_dict['ROOT_DIR_IMG'] = dataset_cfg.IMC_PT_SparseGM.ROOT_DIR_IMG
imcpt_cfg_dict['URL'] = 'https://drive.google.com/u/0/uc?id=1bisri2Ip1Of3RsUA8OBrdH5oa6HlH3k-&export=download'
all_cfgs['IMC_PT_SparseGM'] = imcpt_cfg_dict

if 'CUB2011' in dataset_name_list:
cub_cfg_dict = dict()
cub_cfg_dict['ROOT_DIR'] = dataset_cfg.CUB2011.ROOT_DIR
cub_cfg_dict['URL'] = 'https://drive.google.com/u/0/uc?id=1fcN3m2PmQF7rMQGPxldEICU8CtJ0-F-z&export=download'
all_cfgs['CUB2011'] = cub_cfg_dict

for dataset_name in dataset_name_list:
for set in set_list:
for problem_type in problem_type_list:
filter = choice(filter_list)
if dataset_name == 'SPair71k' and problem_type == 'MGM':
continue
if filter == 'inclusion' and problem_type == 'MGM':
continue
_test_benchmark(dataset_name, set, problem_type, filter, **dict_list[i])
_test_benchmark(dataset_name, set, problem_type, filter, **all_cfgs[dataset_name])


if __name__ == '__main__':
Expand Down

0 comments on commit 79e8ed9

Please sign in to comment.