From 946077b641a2f91d4f763a423e617456a10e9c09 Mon Sep 17 00:00:00 2001 From: Selda Uyanik <81234105+seldauyanik-maxim@users.noreply.github.com> Date: Tue, 20 Dec 2022 19:16:54 +0300 Subject: [PATCH] Modify truncation on AISegment testset (#203) --- datasets/aisegment.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/datasets/aisegment.py b/datasets/aisegment.py index c499313dc..ad504e63a 100644 --- a/datasets/aisegment.py +++ b/datasets/aisegment.py @@ -63,7 +63,7 @@ class AISegment(Dataset): num_of_imgs_to_use_hr = 20000 def __init__(self, root_dir, d_type, transform=None, im_size=(80, 80), fold_ratio=1, - use_memory=False): + use_memory=False, truncate_testset=False): if im_size not in ((80, 80), (352, 352)): raise ValueError('im_size can only be set to (80, 80) or (352, 352)') @@ -88,6 +88,8 @@ def __init__(self, root_dir, d_type, transform=None, im_size=(80, 80), fold_rati self.d_type = d_type + self.is_truncated = False + vertical_crop_area = AISegment.org_img_dim[0] - AISegment.img_crp_dim[0] if vertical_crop_area % (AISegment.num_of_cropped_imgs - 1) != 0: @@ -201,13 +203,14 @@ def __init__(self, root_dir, d_type, transform=None, im_size=(80, 80), fold_rati self.img_files_info = test_img_files_info self.dataset_pkl_file_path = test_dataset_pkl_file_path self.processed_folder_path = self.processed_test_data_folder + if truncate_testset: + self.is_truncated = True else: print(f'Unknown data type: {self.d_type}') return self.__create_pkl_files() - self.is_truncated = False def __create_pkl_files(self): if self.__check_pkl_files_exist(): @@ -407,7 +410,8 @@ def AISegment_get_datasets(data, load_train=True, load_test=True, im_size=(80, 8 train_dataset = AISegment(root_dir=data_dir, d_type='train', transform=train_transform, - im_size=im_size, fold_ratio=fold_ratio, use_memory=use_memory) + im_size=im_size, fold_ratio=fold_ratio, use_memory=use_memory, + truncate_testset=False) print(f'Train dataset length: {len(train_dataset)}\n') else: train_dataset = None @@ -420,7 +424,8 @@ def AISegment_get_datasets(data, load_train=True, load_test=True, im_size=(80, 8 test_dataset = AISegment(root_dir=data_dir, d_type='test', transform=test_transform, - im_size=im_size, fold_ratio=fold_ratio, use_memory=use_memory) + im_size=im_size, fold_ratio=fold_ratio, use_memory=use_memory, + truncate_testset=args.truncate_testset) print(f'Test dataset length: {len(test_dataset)}\n') else: