diff --git a/mmseg/datasets/pipelines/__init__.py b/mmseg/datasets/pipelines/__init__.py index 563ae62807..91d9e47491 100644 --- a/mmseg/datasets/pipelines/__init__.py +++ b/mmseg/datasets/pipelines/__init__.py @@ -5,13 +5,14 @@ from .loading import LoadAnnotations, LoadImageFromFile from .test_time_aug import MultiScaleFlipAug from .transforms import (CLAHE, AdjustGamma, Normalize, Pad, - PhotoMetricDistortion, RandomCrop, RandomFlip, - RandomRotate, Rerange, Resize, RGB2Gray, SegRescale) + PhotoMetricDistortion, RandomCrop, RandomCutOut, + RandomFlip, RandomRotate, Rerange, Resize, RGB2Gray, + SegRescale) __all__ = [ 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer', 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile', 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate', - 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray' + 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut' ] diff --git a/mmseg/datasets/pipelines/transforms.py b/mmseg/datasets/pipelines/transforms.py index f2a642c141..567c960a10 100644 --- a/mmseg/datasets/pipelines/transforms.py +++ b/mmseg/datasets/pipelines/transforms.py @@ -948,3 +948,95 @@ def __repr__(self): f'{self.saturation_upper}), ' f'hue_delta={self.hue_delta})') return repr_str + + +@PIPELINES.register_module() +class RandomCutOut(object): + """CutOut operation. + + Randomly drop some regions of image used in + `Cutout `_. + Args: + prob (float): cutout probability. + n_holes (int | tuple[int, int]): Number of regions to be dropped. + If it is given as a list, number of holes will be randomly + selected from the closed interval [`n_holes[0]`, `n_holes[1]`]. + cutout_shape (tuple[int, int] | list[tuple[int, int]]): The candidate + shape of dropped regions. It can be `tuple[int, int]` to use a + fixed cutout shape, or `list[tuple[int, int]]` to randomly choose + shape from the list. + cutout_ratio (tuple[float, float] | list[tuple[float, float]]): The + candidate ratio of dropped regions. It can be `tuple[float, float]` + to use a fixed ratio or `list[tuple[float, float]]` to randomly + choose ratio from the list. Please note that `cutout_shape` + and `cutout_ratio` cannot be both given at the same time. + fill_in (tuple[float, float, float] | tuple[int, int, int]): The value + of pixel to fill in the dropped regions. Default: (0, 0, 0). + seg_fill_in (int): The labels of pixel to fill in the dropped regions. + If seg_fill_in is None, skip. Default: None. + """ + + def __init__(self, + prob, + n_holes, + cutout_shape=None, + cutout_ratio=None, + fill_in=(0, 0, 0), + seg_fill_in=None): + + assert 0 <= prob and prob <= 1 + assert (cutout_shape is None) ^ (cutout_ratio is None), \ + 'Either cutout_shape or cutout_ratio should be specified.' + assert (isinstance(cutout_shape, (list, tuple)) + or isinstance(cutout_ratio, (list, tuple))) + if isinstance(n_holes, tuple): + assert len(n_holes) == 2 and 0 <= n_holes[0] < n_holes[1] + else: + n_holes = (n_holes, n_holes) + if seg_fill_in is not None: + assert (isinstance(seg_fill_in, int) and 0 <= seg_fill_in + and seg_fill_in <= 255) + self.prob = prob + self.n_holes = n_holes + self.fill_in = fill_in + self.seg_fill_in = seg_fill_in + self.with_ratio = cutout_ratio is not None + self.candidates = cutout_ratio if self.with_ratio else cutout_shape + if not isinstance(self.candidates, list): + self.candidates = [self.candidates] + + def __call__(self, results): + """Call function to drop some regions of image.""" + cutout = True if np.random.rand() < self.prob else False + if cutout: + h, w, c = results['img'].shape + n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1) + for _ in range(n_holes): + x1 = np.random.randint(0, w) + y1 = np.random.randint(0, h) + index = np.random.randint(0, len(self.candidates)) + if not self.with_ratio: + cutout_w, cutout_h = self.candidates[index] + else: + cutout_w = int(self.candidates[index][0] * w) + cutout_h = int(self.candidates[index][1] * h) + + x2 = np.clip(x1 + cutout_w, 0, w) + y2 = np.clip(y1 + cutout_h, 0, h) + results['img'][y1:y2, x1:x2, :] = self.fill_in + + if self.seg_fill_in is not None: + for key in results.get('seg_fields', []): + results[key][y1:y2, x1:x2] = self.seg_fill_in + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'n_holes={self.n_holes}, ' + repr_str += (f'cutout_ratio={self.candidates}, ' if self.with_ratio + else f'cutout_shape={self.candidates}, ') + repr_str += f'fill_in={self.fill_in}, ' + repr_str += f'seg_fill_in={self.seg_fill_in})' + return repr_str diff --git a/tests/test_data/test_transform.py b/tests/test_data/test_transform.py index 3862e75a34..ab7ffe0664 100644 --- a/tests/test_data/test_transform.py +++ b/tests/test_data/test_transform.py @@ -497,3 +497,120 @@ def test_seg_rescale(): rescale_module = build_from_cfg(transform, PIPELINES) rescale_results = rescale_module(results.copy()) assert rescale_results['gt_semantic_seg'].shape == (h, w) + + +def test_cutout(): + # test prob + with pytest.raises(AssertionError): + transform = dict(type='RandomCutOut', prob=1.5, n_holes=1) + build_from_cfg(transform, PIPELINES) + # test n_holes + with pytest.raises(AssertionError): + transform = dict( + type='RandomCutOut', prob=0.5, n_holes=(5, 3), cutout_shape=(8, 8)) + build_from_cfg(transform, PIPELINES) + with pytest.raises(AssertionError): + transform = dict( + type='RandomCutOut', + prob=0.5, + n_holes=(3, 4, 5), + cutout_shape=(8, 8)) + build_from_cfg(transform, PIPELINES) + # test cutout_shape and cutout_ratio + with pytest.raises(AssertionError): + transform = dict( + type='RandomCutOut', prob=0.5, n_holes=1, cutout_shape=8) + build_from_cfg(transform, PIPELINES) + with pytest.raises(AssertionError): + transform = dict( + type='RandomCutOut', prob=0.5, n_holes=1, cutout_ratio=0.2) + build_from_cfg(transform, PIPELINES) + # either of cutout_shape and cutout_ratio should be given + with pytest.raises(AssertionError): + transform = dict(type='RandomCutOut', prob=0.5, n_holes=1) + build_from_cfg(transform, PIPELINES) + with pytest.raises(AssertionError): + transform = dict( + type='RandomCutOut', + prob=0.5, + n_holes=1, + cutout_shape=(2, 2), + cutout_ratio=(0.4, 0.4)) + build_from_cfg(transform, PIPELINES) + # test seg_fill_in + with pytest.raises(AssertionError): + transform = dict( + type='RandomCutOut', + prob=0.5, + n_holes=1, + cutout_shape=(8, 8), + seg_fill_in='a') + build_from_cfg(transform, PIPELINES) + with pytest.raises(AssertionError): + transform = dict( + type='RandomCutOut', + prob=0.5, + n_holes=1, + cutout_shape=(8, 8), + seg_fill_in=256) + build_from_cfg(transform, PIPELINES) + + results = dict() + img = mmcv.imread( + osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + + seg = np.array( + Image.open(osp.join(osp.dirname(__file__), '../data/seg.png'))) + + results['img'] = img + results['gt_semantic_seg'] = seg + results['seg_fields'] = ['gt_semantic_seg'] + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + results['pad_shape'] = img.shape + results['img_fields'] = ['img'] + + transform = dict( + type='RandomCutOut', prob=1, n_holes=1, cutout_shape=(10, 10)) + cutout_module = build_from_cfg(transform, PIPELINES) + assert 'cutout_shape' in repr(cutout_module) + cutout_result = cutout_module(copy.deepcopy(results)) + assert cutout_result['img'].sum() < img.sum() + + transform = dict( + type='RandomCutOut', prob=1, n_holes=1, cutout_ratio=(0.8, 0.8)) + cutout_module = build_from_cfg(transform, PIPELINES) + assert 'cutout_ratio' in repr(cutout_module) + cutout_result = cutout_module(copy.deepcopy(results)) + assert cutout_result['img'].sum() < img.sum() + + transform = dict( + type='RandomCutOut', prob=0, n_holes=1, cutout_ratio=(0.8, 0.8)) + cutout_module = build_from_cfg(transform, PIPELINES) + cutout_result = cutout_module(copy.deepcopy(results)) + assert cutout_result['img'].sum() == img.sum() + assert cutout_result['gt_semantic_seg'].sum() == seg.sum() + + transform = dict( + type='RandomCutOut', + prob=1, + n_holes=(2, 4), + cutout_shape=[(10, 10), (15, 15)], + fill_in=(255, 255, 255), + seg_fill_in=None) + cutout_module = build_from_cfg(transform, PIPELINES) + cutout_result = cutout_module(copy.deepcopy(results)) + assert cutout_result['img'].sum() > img.sum() + assert cutout_result['gt_semantic_seg'].sum() == seg.sum() + + transform = dict( + type='RandomCutOut', + prob=1, + n_holes=1, + cutout_ratio=(0.8, 0.8), + fill_in=(255, 255, 255), + seg_fill_in=255) + cutout_module = build_from_cfg(transform, PIPELINES) + cutout_result = cutout_module(copy.deepcopy(results)) + assert cutout_result['img'].sum() > img.sum() + assert cutout_result['gt_semantic_seg'].sum() > seg.sum()