From 525bb6b2fe9a5fb65c735153b52cc5e5bdfc5cda Mon Sep 17 00:00:00 2001 From: lkm2835 Date: Fri, 29 Oct 2021 00:12:20 +0900 Subject: [PATCH 01/12] Fix typo in usage example --- configs/segformer/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/segformer/README.md b/configs/segformer/README.md index 58c6a1c90f..b57160cdf4 100644 --- a/configs/segformer/README.md +++ b/configs/segformer/README.md @@ -29,7 +29,7 @@ To use other repositories' pre-trained models, it is necessary to convert keys. We provide a script [`mit2mmseg.py`](../../tools/model_converters/mit2mmseg.py) in the tools directory to convert the key of models from [the official repo](https://github.com/NVlabs/SegFormer) to MMSegmentation style. ```shell -python tools/model_converters/swin2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH} +python tools/model_converters/mit2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH} ``` This script convert model from `PRETRAIN_PATH` and store the converted model in `STORE_PATH`. From fec2c0fed7e99487fcbc248c19af582416f4e26b Mon Sep 17 00:00:00 2001 From: lkm2835 Date: Tue, 9 Nov 2021 01:32:53 +0900 Subject: [PATCH 02/12] [Feature] Add CutOut transform --- mmseg/datasets/pipelines/__init__.py | 4 +- mmseg/datasets/pipelines/transforms.py | 73 ++++++++++++++++++++++++++ tests/test_data/test_transform.py | 66 +++++++++++++++++++++++ 3 files changed, 141 insertions(+), 2 deletions(-) diff --git a/mmseg/datasets/pipelines/__init__.py b/mmseg/datasets/pipelines/__init__.py index 563ae62807..3562b46369 100644 --- a/mmseg/datasets/pipelines/__init__.py +++ b/mmseg/datasets/pipelines/__init__.py @@ -4,7 +4,7 @@ Transpose, to_tensor) from .loading import LoadAnnotations, LoadImageFromFile from .test_time_aug import MultiScaleFlipAug -from .transforms import (CLAHE, AdjustGamma, Normalize, Pad, +from .transforms import (CLAHE, AdjustGamma, CutOut, Normalize, Pad, PhotoMetricDistortion, RandomCrop, RandomFlip, RandomRotate, Rerange, Resize, RGB2Gray, SegRescale) @@ -13,5 +13,5 @@ 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile', 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate', - 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray' + 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'CutOut' ] diff --git a/mmseg/datasets/pipelines/transforms.py b/mmseg/datasets/pipelines/transforms.py index f2a642c141..b1c170399e 100644 --- a/mmseg/datasets/pipelines/transforms.py +++ b/mmseg/datasets/pipelines/transforms.py @@ -948,3 +948,76 @@ def __repr__(self): f'{self.saturation_upper}), ' f'hue_delta={self.hue_delta})') return repr_str + + +@PIPELINES.register_module() +class CutOut: + """CutOut operation. + + Randomly drop some regions of image used in + `Cutout `_. + Args: + n_holes (int | tuple[int, int]): Number of regions to be dropped. + If it is given as a list, number of holes will be randomly + selected from the closed interval [`n_holes[0]`, `n_holes[1]`]. + cutout_shape (tuple[int, int] | list[tuple[int, int]]): The candidate + shape of dropped regions. It can be `tuple[int, int]` to use a + fixed cutout shape, or `list[tuple[int, int]]` to randomly choose + shape from the list. + cutout_ratio (tuple[float, float] | list[tuple[float, float]]): The + candidate ratio of dropped regions. It can be `tuple[float, float]` + to use a fixed ratio or `list[tuple[float, float]]` to randomly + choose ratio from the list. Please note that `cutout_shape` + and `cutout_ratio` cannot be both given at the same time. + fill_in (tuple[float, float, float] | tuple[int, int, int]): The value + of pixel to fill in the dropped regions. Default: (0, 0, 0). + """ + + def __init__(self, + n_holes, + cutout_shape=None, + cutout_ratio=None, + fill_in=(0, 0, 0)): + + assert (cutout_shape is None) ^ (cutout_ratio is None), \ + 'Either cutout_shape or cutout_ratio should be specified.' + assert (isinstance(cutout_shape, (list, tuple)) + or isinstance(cutout_ratio, (list, tuple))) + if isinstance(n_holes, tuple): + assert len(n_holes) == 2 and 0 <= n_holes[0] < n_holes[1] + else: + n_holes = (n_holes, n_holes) + self.n_holes = n_holes + self.fill_in = fill_in + self.with_ratio = cutout_ratio is not None + self.candidates = cutout_ratio if self.with_ratio else cutout_shape + if not isinstance(self.candidates, list): + self.candidates = [self.candidates] + + def __call__(self, results): + """Call function to drop some regions of image.""" + h, w, c = results['img'].shape + n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1) + for _ in range(n_holes): + x1 = np.random.randint(0, w) + y1 = np.random.randint(0, h) + index = np.random.randint(0, len(self.candidates)) + if not self.with_ratio: + cutout_w, cutout_h = self.candidates[index] + else: + cutout_w = int(self.candidates[index][0] * w) + cutout_h = int(self.candidates[index][1] * h) + + x2 = np.clip(x1 + cutout_w, 0, w) + y2 = np.clip(y1 + cutout_h, 0, h) + results['img'][y1:y2, x1:x2, :] = self.fill_in + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(n_holes={self.n_holes}, ' + repr_str += (f'cutout_ratio={self.candidates}, ' if self.with_ratio + else f'cutout_shape={self.candidates}, ') + repr_str += f'fill_in={self.fill_in})' + return repr_str diff --git a/tests/test_data/test_transform.py b/tests/test_data/test_transform.py index 3862e75a34..af666f8557 100644 --- a/tests/test_data/test_transform.py +++ b/tests/test_data/test_transform.py @@ -497,3 +497,69 @@ def test_seg_rescale(): rescale_module = build_from_cfg(transform, PIPELINES) rescale_results = rescale_module(results.copy()) assert rescale_results['gt_semantic_seg'].shape == (h, w) + + +def test_cutout(): + # test n_holes + with pytest.raises(AssertionError): + transform = dict(type='CutOut', n_holes=(5, 3), cutout_shape=(8, 8)) + build_from_cfg(transform, PIPELINES) + with pytest.raises(AssertionError): + transform = dict(type='CutOut', n_holes=(3, 4, 5), cutout_shape=(8, 8)) + build_from_cfg(transform, PIPELINES) + # test cutout_shape and cutout_ratio + with pytest.raises(AssertionError): + transform = dict(type='CutOut', n_holes=1, cutout_shape=8) + build_from_cfg(transform, PIPELINES) + with pytest.raises(AssertionError): + transform = dict(type='CutOut', n_holes=1, cutout_ratio=0.2) + build_from_cfg(transform, PIPELINES) + # either of cutout_shape and cutout_ratio should be given + with pytest.raises(AssertionError): + transform = dict(type='CutOut', n_holes=1) + build_from_cfg(transform, PIPELINES) + with pytest.raises(AssertionError): + transform = dict( + type='CutOut', + n_holes=1, + cutout_shape=(2, 2), + cutout_ratio=(0.4, 0.4)) + build_from_cfg(transform, PIPELINES) + + results = dict() + img = mmcv.imread( + osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + results['pad_shape'] = img.shape + results['img_fields'] = ['img'] + + transform = dict(type='CutOut', n_holes=1, cutout_shape=(10, 10)) + cutout_module = build_from_cfg(transform, PIPELINES) + cutout_result = cutout_module(copy.deepcopy(results)) + assert cutout_result['img'].sum() < img.sum() + + transform = dict(type='CutOut', n_holes=1, cutout_ratio=(0.8, 0.8)) + cutout_module = build_from_cfg(transform, PIPELINES) + cutout_result = cutout_module(copy.deepcopy(results)) + assert cutout_result['img'].sum() < img.sum() + + transform = dict( + type='CutOut', + n_holes=(2, 4), + cutout_shape=[(10, 10), (15, 15)], + fill_in=(255, 255, 255)) + cutout_module = build_from_cfg(transform, PIPELINES) + cutout_result = cutout_module(copy.deepcopy(results)) + assert cutout_result['img'].sum() > img.sum() + + transform = dict( + type='CutOut', + n_holes=1, + cutout_ratio=(0.8, 0.8), + fill_in=(255, 255, 255)) + cutout_module = build_from_cfg(transform, PIPELINES) + cutout_result = cutout_module(copy.deepcopy(results)) + assert cutout_result['img'].sum() > img.sum() From 4aaebbe3de32ca77113eed846bca39ae00c6d4ef Mon Sep 17 00:00:00 2001 From: lkm2835 Date: Tue, 9 Nov 2021 22:44:55 +0900 Subject: [PATCH 03/12] CutOut repr covered by unittests --- tests/test_data/test_transform.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_data/test_transform.py b/tests/test_data/test_transform.py index af666f8557..18a7f6ac35 100644 --- a/tests/test_data/test_transform.py +++ b/tests/test_data/test_transform.py @@ -538,11 +538,13 @@ def test_cutout(): transform = dict(type='CutOut', n_holes=1, cutout_shape=(10, 10)) cutout_module = build_from_cfg(transform, PIPELINES) + assert 'cutout_shape' in repr(cutout_module) cutout_result = cutout_module(copy.deepcopy(results)) assert cutout_result['img'].sum() < img.sum() transform = dict(type='CutOut', n_holes=1, cutout_ratio=(0.8, 0.8)) cutout_module = build_from_cfg(transform, PIPELINES) + assert 'cutout_ratio' in repr(cutout_module) cutout_result = cutout_module(copy.deepcopy(results)) assert cutout_result['img'].sum() < img.sum() From 606bf12f34a17c99998bffd57abf5ef07bf802fd Mon Sep 17 00:00:00 2001 From: lkm2835 Date: Wed, 10 Nov 2021 15:24:42 +0900 Subject: [PATCH 04/12] Cutout ignore index, test --- mmseg/datasets/pipelines/transforms.py | 7 ++++++- tests/test_data/test_transform.py | 13 +++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/mmseg/datasets/pipelines/transforms.py b/mmseg/datasets/pipelines/transforms.py index b1c170399e..f952e93a35 100644 --- a/mmseg/datasets/pipelines/transforms.py +++ b/mmseg/datasets/pipelines/transforms.py @@ -977,7 +977,8 @@ def __init__(self, n_holes, cutout_shape=None, cutout_ratio=None, - fill_in=(0, 0, 0)): + fill_in=(0, 0, 0), + ignore_index=255): assert (cutout_shape is None) ^ (cutout_ratio is None), \ 'Either cutout_shape or cutout_ratio should be specified.' @@ -989,6 +990,7 @@ def __init__(self, n_holes = (n_holes, n_holes) self.n_holes = n_holes self.fill_in = fill_in + self.ignore_index = ignore_index self.with_ratio = cutout_ratio is not None self.candidates = cutout_ratio if self.with_ratio else cutout_shape if not isinstance(self.candidates, list): @@ -1012,6 +1014,9 @@ def __call__(self, results): y2 = np.clip(y1 + cutout_h, 0, h) results['img'][y1:y2, x1:x2, :] = self.fill_in + for key in results.get('seg_fields', []): + results[key][y1:y2, x1:x2] = self.ignore_index + return results def __repr__(self): diff --git a/tests/test_data/test_transform.py b/tests/test_data/test_transform.py index 18a7f6ac35..c5b2a67295 100644 --- a/tests/test_data/test_transform.py +++ b/tests/test_data/test_transform.py @@ -530,7 +530,12 @@ def test_cutout(): img = mmcv.imread( osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + seg = np.array( + Image.open(osp.join(osp.dirname(__file__), '../data/seg.png'))) + results['img'] = img + results['gt_semantic_seg'] = seg + results['seg_fields'] = ['gt_semantic_seg'] results['img_shape'] = img.shape results['ori_shape'] = img.shape results['pad_shape'] = img.shape @@ -552,16 +557,20 @@ def test_cutout(): type='CutOut', n_holes=(2, 4), cutout_shape=[(10, 10), (15, 15)], - fill_in=(255, 255, 255)) + fill_in=(255, 255, 255), + ignore_index=255) cutout_module = build_from_cfg(transform, PIPELINES) cutout_result = cutout_module(copy.deepcopy(results)) assert cutout_result['img'].sum() > img.sum() + assert cutout_result['gt_semantic_seg'].sum() > seg.sum() transform = dict( type='CutOut', n_holes=1, cutout_ratio=(0.8, 0.8), - fill_in=(255, 255, 255)) + fill_in=(255, 255, 255), + ignore_index=255) cutout_module = build_from_cfg(transform, PIPELINES) cutout_result = cutout_module(copy.deepcopy(results)) assert cutout_result['img'].sum() > img.sum() + assert cutout_result['gt_semantic_seg'].sum() > seg.sum() From 8cde1c93e4fa1c3845dfaa1c79f7571fe3f9c6ea Mon Sep 17 00:00:00 2001 From: lkm2835 Date: Wed, 10 Nov 2021 18:34:27 +0900 Subject: [PATCH 05/12] ignore_index -> seg_fill_in, defualt is None --- mmseg/datasets/pipelines/transforms.py | 9 +++++---- tests/test_data/test_transform.py | 4 ++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/mmseg/datasets/pipelines/transforms.py b/mmseg/datasets/pipelines/transforms.py index f952e93a35..6392ec6b39 100644 --- a/mmseg/datasets/pipelines/transforms.py +++ b/mmseg/datasets/pipelines/transforms.py @@ -978,7 +978,7 @@ def __init__(self, cutout_shape=None, cutout_ratio=None, fill_in=(0, 0, 0), - ignore_index=255): + seg_fill_in=None): assert (cutout_shape is None) ^ (cutout_ratio is None), \ 'Either cutout_shape or cutout_ratio should be specified.' @@ -990,7 +990,7 @@ def __init__(self, n_holes = (n_holes, n_holes) self.n_holes = n_holes self.fill_in = fill_in - self.ignore_index = ignore_index + self.seg_fill_in = seg_fill_in self.with_ratio = cutout_ratio is not None self.candidates = cutout_ratio if self.with_ratio else cutout_shape if not isinstance(self.candidates, list): @@ -1014,8 +1014,9 @@ def __call__(self, results): y2 = np.clip(y1 + cutout_h, 0, h) results['img'][y1:y2, x1:x2, :] = self.fill_in - for key in results.get('seg_fields', []): - results[key][y1:y2, x1:x2] = self.ignore_index + if self.seg_fill_in is not None: + for key in results.get('seg_fields', []): + results[key][y1:y2, x1:x2] = self.seg_fill_in return results diff --git a/tests/test_data/test_transform.py b/tests/test_data/test_transform.py index c5b2a67295..370da9f2cc 100644 --- a/tests/test_data/test_transform.py +++ b/tests/test_data/test_transform.py @@ -558,7 +558,7 @@ def test_cutout(): n_holes=(2, 4), cutout_shape=[(10, 10), (15, 15)], fill_in=(255, 255, 255), - ignore_index=255) + seg_fill_in=255) cutout_module = build_from_cfg(transform, PIPELINES) cutout_result = cutout_module(copy.deepcopy(results)) assert cutout_result['img'].sum() > img.sum() @@ -569,7 +569,7 @@ def test_cutout(): n_holes=1, cutout_ratio=(0.8, 0.8), fill_in=(255, 255, 255), - ignore_index=255) + seg_fill_in=255) cutout_module = build_from_cfg(transform, PIPELINES) cutout_result = cutout_module(copy.deepcopy(results)) assert cutout_result['img'].sum() > img.sum() From e23da408c4e5af58364fca13df17304a3b925a82 Mon Sep 17 00:00:00 2001 From: lkm2835 Date: Wed, 10 Nov 2021 18:46:25 +0900 Subject: [PATCH 06/12] seg_fill_in is added to repr --- mmseg/datasets/pipelines/transforms.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmseg/datasets/pipelines/transforms.py b/mmseg/datasets/pipelines/transforms.py index 6392ec6b39..e92754c69e 100644 --- a/mmseg/datasets/pipelines/transforms.py +++ b/mmseg/datasets/pipelines/transforms.py @@ -1025,5 +1025,6 @@ def __repr__(self): repr_str += f'(n_holes={self.n_holes}, ' repr_str += (f'cutout_ratio={self.candidates}, ' if self.with_ratio else f'cutout_shape={self.candidates}, ') - repr_str += f'fill_in={self.fill_in})' + repr_str += f'fill_in={self.fill_in}, ' + repr_str += f'seg_fill_in={self.seg_fill_in})' return repr_str From 714093e10e5a143cad47fa02a48fcca281db8135 Mon Sep 17 00:00:00 2001 From: lkm2835 Date: Wed, 10 Nov 2021 18:49:48 +0900 Subject: [PATCH 07/12] test is modified for seg_fill_in is None --- tests/test_data/test_transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_data/test_transform.py b/tests/test_data/test_transform.py index 370da9f2cc..bf905aeccf 100644 --- a/tests/test_data/test_transform.py +++ b/tests/test_data/test_transform.py @@ -558,11 +558,11 @@ def test_cutout(): n_holes=(2, 4), cutout_shape=[(10, 10), (15, 15)], fill_in=(255, 255, 255), - seg_fill_in=255) + seg_fill_in=None) cutout_module = build_from_cfg(transform, PIPELINES) cutout_result = cutout_module(copy.deepcopy(results)) assert cutout_result['img'].sum() > img.sum() - assert cutout_result['gt_semantic_seg'].sum() > seg.sum() + assert cutout_result['gt_semantic_seg'].sum() == seg.sum() transform = dict( type='CutOut', From 82599e024250dd980c8918def4b03a56d6089343 Mon Sep 17 00:00:00 2001 From: lkm2835 Date: Wed, 10 Nov 2021 20:27:34 +0900 Subject: [PATCH 08/12] seg_fill_in (int), 0-255 --- mmseg/datasets/pipelines/transforms.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mmseg/datasets/pipelines/transforms.py b/mmseg/datasets/pipelines/transforms.py index e92754c69e..600f2f5ac8 100644 --- a/mmseg/datasets/pipelines/transforms.py +++ b/mmseg/datasets/pipelines/transforms.py @@ -988,6 +988,9 @@ def __init__(self, assert len(n_holes) == 2 and 0 <= n_holes[0] < n_holes[1] else: n_holes = (n_holes, n_holes) + if seg_fill_in is not None: + assert (isinstance(seg_fill_in, int) and 0 <= seg_fill_in + and seg_fill_in <= 255) self.n_holes = n_holes self.fill_in = fill_in self.seg_fill_in = seg_fill_in From 57779b52cf2f0d7629b8bfc1532ed0acafb8d248 Mon Sep 17 00:00:00 2001 From: lkm2835 Date: Wed, 10 Nov 2021 20:36:32 +0900 Subject: [PATCH 09/12] add seg_fill_in test --- tests/test_data/test_transform.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_data/test_transform.py b/tests/test_data/test_transform.py index bf905aeccf..ba2dc3d788 100644 --- a/tests/test_data/test_transform.py +++ b/tests/test_data/test_transform.py @@ -525,6 +525,15 @@ def test_cutout(): cutout_shape=(2, 2), cutout_ratio=(0.4, 0.4)) build_from_cfg(transform, PIPELINES) + # test seg_fill_in + with pytest.raises(AssertionError): + transform = dict( + type='CutOut', n_holes=1, cutout_shape=(8, 8), seg_fill_in='a') + build_from_cfg(transform, PIPELINES) + with pytest.raises(AssertionError): + transform = dict( + type='CutOut', n_holes=1, cutout_shape=(8, 8), seg_fill_in=256) + build_from_cfg(transform, PIPELINES) results = dict() img = mmcv.imread( From a1a6f810fe48f8c4d1bc71053ad2539be0d33a8f Mon Sep 17 00:00:00 2001 From: lkm2835 Date: Wed, 10 Nov 2021 23:25:08 +0900 Subject: [PATCH 10/12] doc string for seg_fill_in --- mmseg/datasets/pipelines/transforms.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mmseg/datasets/pipelines/transforms.py b/mmseg/datasets/pipelines/transforms.py index 600f2f5ac8..dae9fddd37 100644 --- a/mmseg/datasets/pipelines/transforms.py +++ b/mmseg/datasets/pipelines/transforms.py @@ -971,6 +971,8 @@ class CutOut: and `cutout_ratio` cannot be both given at the same time. fill_in (tuple[float, float, float] | tuple[int, int, int]): The value of pixel to fill in the dropped regions. Default: (0, 0, 0). + seg_fill_in (int): The labels of pixel to fill in the dropped regions. + If seg_fill_in is None, skip. Default: None. """ def __init__(self, From d60a5a515f9cdbe8f569ccacb5e639d83346eb42 Mon Sep 17 00:00:00 2001 From: lkm2835 Date: Sat, 13 Nov 2021 00:32:30 +0900 Subject: [PATCH 11/12] rename CutOut to RandomCutOut, add prob --- mmseg/datasets/pipelines/__init__.py | 9 ++--- mmseg/datasets/pipelines/transforms.py | 49 +++++++++++++++----------- tests/test_data/test_transform.py | 48 ++++++++++++++++++------- 3 files changed, 69 insertions(+), 37 deletions(-) diff --git a/mmseg/datasets/pipelines/__init__.py b/mmseg/datasets/pipelines/__init__.py index 3562b46369..91d9e47491 100644 --- a/mmseg/datasets/pipelines/__init__.py +++ b/mmseg/datasets/pipelines/__init__.py @@ -4,14 +4,15 @@ Transpose, to_tensor) from .loading import LoadAnnotations, LoadImageFromFile from .test_time_aug import MultiScaleFlipAug -from .transforms import (CLAHE, AdjustGamma, CutOut, Normalize, Pad, - PhotoMetricDistortion, RandomCrop, RandomFlip, - RandomRotate, Rerange, Resize, RGB2Gray, SegRescale) +from .transforms import (CLAHE, AdjustGamma, Normalize, Pad, + PhotoMetricDistortion, RandomCrop, RandomCutOut, + RandomFlip, RandomRotate, Rerange, Resize, RGB2Gray, + SegRescale) __all__ = [ 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer', 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile', 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate', - 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'CutOut' + 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut' ] diff --git a/mmseg/datasets/pipelines/transforms.py b/mmseg/datasets/pipelines/transforms.py index dae9fddd37..567c960a10 100644 --- a/mmseg/datasets/pipelines/transforms.py +++ b/mmseg/datasets/pipelines/transforms.py @@ -951,12 +951,13 @@ def __repr__(self): @PIPELINES.register_module() -class CutOut: +class RandomCutOut(object): """CutOut operation. Randomly drop some regions of image used in `Cutout `_. Args: + prob (float): cutout probability. n_holes (int | tuple[int, int]): Number of regions to be dropped. If it is given as a list, number of holes will be randomly selected from the closed interval [`n_holes[0]`, `n_holes[1]`]. @@ -976,12 +977,14 @@ class CutOut: """ def __init__(self, + prob, n_holes, cutout_shape=None, cutout_ratio=None, fill_in=(0, 0, 0), seg_fill_in=None): + assert 0 <= prob and prob <= 1 assert (cutout_shape is None) ^ (cutout_ratio is None), \ 'Either cutout_shape or cutout_ratio should be specified.' assert (isinstance(cutout_shape, (list, tuple)) @@ -993,6 +996,7 @@ def __init__(self, if seg_fill_in is not None: assert (isinstance(seg_fill_in, int) and 0 <= seg_fill_in and seg_fill_in <= 255) + self.prob = prob self.n_holes = n_holes self.fill_in = fill_in self.seg_fill_in = seg_fill_in @@ -1003,31 +1007,34 @@ def __init__(self, def __call__(self, results): """Call function to drop some regions of image.""" - h, w, c = results['img'].shape - n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1) - for _ in range(n_holes): - x1 = np.random.randint(0, w) - y1 = np.random.randint(0, h) - index = np.random.randint(0, len(self.candidates)) - if not self.with_ratio: - cutout_w, cutout_h = self.candidates[index] - else: - cutout_w = int(self.candidates[index][0] * w) - cutout_h = int(self.candidates[index][1] * h) - - x2 = np.clip(x1 + cutout_w, 0, w) - y2 = np.clip(y1 + cutout_h, 0, h) - results['img'][y1:y2, x1:x2, :] = self.fill_in - - if self.seg_fill_in is not None: - for key in results.get('seg_fields', []): - results[key][y1:y2, x1:x2] = self.seg_fill_in + cutout = True if np.random.rand() < self.prob else False + if cutout: + h, w, c = results['img'].shape + n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1) + for _ in range(n_holes): + x1 = np.random.randint(0, w) + y1 = np.random.randint(0, h) + index = np.random.randint(0, len(self.candidates)) + if not self.with_ratio: + cutout_w, cutout_h = self.candidates[index] + else: + cutout_w = int(self.candidates[index][0] * w) + cutout_h = int(self.candidates[index][1] * h) + + x2 = np.clip(x1 + cutout_w, 0, w) + y2 = np.clip(y1 + cutout_h, 0, h) + results['img'][y1:y2, x1:x2, :] = self.fill_in + + if self.seg_fill_in is not None: + for key in results.get('seg_fields', []): + results[key][y1:y2, x1:x2] = self.seg_fill_in return results def __repr__(self): repr_str = self.__class__.__name__ - repr_str += f'(n_holes={self.n_holes}, ' + repr_str += f'(prob={self.prob}, ' + repr_str += f'n_holes={self.n_holes}, ' repr_str += (f'cutout_ratio={self.candidates}, ' if self.with_ratio else f'cutout_shape={self.candidates}, ') repr_str += f'fill_in={self.fill_in}, ' diff --git a/tests/test_data/test_transform.py b/tests/test_data/test_transform.py index ba2dc3d788..f91c7c46f3 100644 --- a/tests/test_data/test_transform.py +++ b/tests/test_data/test_transform.py @@ -500,27 +500,39 @@ def test_seg_rescale(): def test_cutout(): + # test prob + with pytest.raises(AssertionError): + transform = dict(type='RandomCutOut', prob=1.5, n_holes=1) + build_from_cfg(transform, PIPELINES) # test n_holes with pytest.raises(AssertionError): - transform = dict(type='CutOut', n_holes=(5, 3), cutout_shape=(8, 8)) + transform = dict( + type='RandomCutOut', prob=0.5, n_holes=(5, 3), cutout_shape=(8, 8)) build_from_cfg(transform, PIPELINES) with pytest.raises(AssertionError): - transform = dict(type='CutOut', n_holes=(3, 4, 5), cutout_shape=(8, 8)) + transform = dict( + type='RandomCutOut', + prob=0.5, + n_holes=(3, 4, 5), + cutout_shape=(8, 8)) build_from_cfg(transform, PIPELINES) # test cutout_shape and cutout_ratio with pytest.raises(AssertionError): - transform = dict(type='CutOut', n_holes=1, cutout_shape=8) + transform = dict( + type='RandomCutOut', prob=0.5, n_holes=1, cutout_shape=8) build_from_cfg(transform, PIPELINES) with pytest.raises(AssertionError): - transform = dict(type='CutOut', n_holes=1, cutout_ratio=0.2) + transform = dict( + type='RandomCutOut', prob=0.5, n_holes=1, cutout_ratio=0.2) build_from_cfg(transform, PIPELINES) # either of cutout_shape and cutout_ratio should be given with pytest.raises(AssertionError): - transform = dict(type='CutOut', n_holes=1) + transform = dict(type='RandomCutOut', prob=0.5, n_holes=1) build_from_cfg(transform, PIPELINES) with pytest.raises(AssertionError): transform = dict( - type='CutOut', + type='RandomCutOut', + prob=0.5, n_holes=1, cutout_shape=(2, 2), cutout_ratio=(0.4, 0.4)) @@ -528,11 +540,19 @@ def test_cutout(): # test seg_fill_in with pytest.raises(AssertionError): transform = dict( - type='CutOut', n_holes=1, cutout_shape=(8, 8), seg_fill_in='a') + type='RandomCutOut', + prob=0.5, + n_holes=1, + cutout_shape=(8, 8), + seg_fill_in='a') build_from_cfg(transform, PIPELINES) with pytest.raises(AssertionError): transform = dict( - type='CutOut', n_holes=1, cutout_shape=(8, 8), seg_fill_in=256) + type='RandomCutOut', + prob=0.5, + n_holes=1, + cutout_shape=(8, 8), + seg_fill_in=256) build_from_cfg(transform, PIPELINES) results = dict() @@ -550,20 +570,23 @@ def test_cutout(): results['pad_shape'] = img.shape results['img_fields'] = ['img'] - transform = dict(type='CutOut', n_holes=1, cutout_shape=(10, 10)) + transform = dict( + type='RandomCutOut', prob=1, n_holes=1, cutout_shape=(10, 10)) cutout_module = build_from_cfg(transform, PIPELINES) assert 'cutout_shape' in repr(cutout_module) cutout_result = cutout_module(copy.deepcopy(results)) assert cutout_result['img'].sum() < img.sum() - transform = dict(type='CutOut', n_holes=1, cutout_ratio=(0.8, 0.8)) + transform = dict( + type='RandomCutOut', prob=1, n_holes=1, cutout_ratio=(0.8, 0.8)) cutout_module = build_from_cfg(transform, PIPELINES) assert 'cutout_ratio' in repr(cutout_module) cutout_result = cutout_module(copy.deepcopy(results)) assert cutout_result['img'].sum() < img.sum() transform = dict( - type='CutOut', + type='RandomCutOut', + prob=1, n_holes=(2, 4), cutout_shape=[(10, 10), (15, 15)], fill_in=(255, 255, 255), @@ -574,7 +597,8 @@ def test_cutout(): assert cutout_result['gt_semantic_seg'].sum() == seg.sum() transform = dict( - type='CutOut', + type='RandomCutOut', + prob=1, n_holes=1, cutout_ratio=(0.8, 0.8), fill_in=(255, 255, 255), From 554f1ea6f931c59e8a0b2c230700a918811c7280 Mon Sep 17 00:00:00 2001 From: lkm2835 Date: Tue, 16 Nov 2021 00:47:28 +0900 Subject: [PATCH 12/12] Add unittest when cutout is False --- tests/test_data/test_transform.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_data/test_transform.py b/tests/test_data/test_transform.py index f91c7c46f3..ab7ffe0664 100644 --- a/tests/test_data/test_transform.py +++ b/tests/test_data/test_transform.py @@ -584,6 +584,13 @@ def test_cutout(): cutout_result = cutout_module(copy.deepcopy(results)) assert cutout_result['img'].sum() < img.sum() + transform = dict( + type='RandomCutOut', prob=0, n_holes=1, cutout_ratio=(0.8, 0.8)) + cutout_module = build_from_cfg(transform, PIPELINES) + cutout_result = cutout_module(copy.deepcopy(results)) + assert cutout_result['img'].sum() == img.sum() + assert cutout_result['gt_semantic_seg'].sum() == seg.sum() + transform = dict( type='RandomCutOut', prob=1,