diff --git a/mmseg/datasets/pipelines/transforms.py b/mmseg/datasets/pipelines/transforms.py index 5d38fa27aa5..2f4262fdaca 100644 --- a/mmseg/datasets/pipelines/transforms.py +++ b/mmseg/datasets/pipelines/transforms.py @@ -392,6 +392,53 @@ def __repr__(self): return repr_str +@PIPELINES.register_module() +class Rerange(object): + """Rerange the image pixel value. + + Args: + min_value (float or int): Minimum value of the reranged image. + Default: 0. + max_value (float or int): Maximum value of the reranged image. + Default: 255. + """ + + def __init__(self, min_value=0, max_value=255): + assert isinstance(min_value, float) or isinstance(min_value, int) + assert isinstance(max_value, float) or isinstance(max_value, int) + assert min_value < max_value + self.min_value = min_value + self.max_value = max_value + + def __call__(self, results): + """Call function to rerange images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Reranged results. + """ + + img = results['img'] + img_min_value = np.min(img) + img_max_value = np.max(img) + + assert img_min_value < img_max_value + # rerange to [0, 1] + img = (img - img_min_value) / (img_max_value - img_min_value) + # rerange to [min_value, max_value] + img = img * (self.max_value - self.min_value) + self.min_value + results['img'] = img + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(min_value={self.min_value}, max_value={self.max_value})' + return repr_str + + @PIPELINES.register_module() class RandomCrop(object): """Random crop the image & seg. diff --git a/tests/test_data/test_transform.py b/tests/test_data/test_transform.py index 4d199e993f8..d4f81ecc6f8 100644 --- a/tests/test_data/test_transform.py +++ b/tests/test_data/test_transform.py @@ -330,6 +330,49 @@ def test_rgb2gray(): assert results['ori_shape'] == (h, w, c) +def test_rerange(): + # test assertion if min_value or max_value is illegal + with pytest.raises(AssertionError): + transform = dict(type='Rerange', min_value=[0], max_value=[255]) + build_from_cfg(transform, PIPELINES) + + # test assertion if min_value >= max_value + with pytest.raises(AssertionError): + transform = dict(type='Rerange', min_value=1, max_value=1) + build_from_cfg(transform, PIPELINES) + + # test assertion if img_min_value == img_max_value + with pytest.raises(AssertionError): + transform = dict(type='Rerange', min_value=0, max_value=1) + transform = build_from_cfg(transform, PIPELINES) + results = dict() + results['img'] = np.array([[1, 1], [1, 1]]) + transform(results) + + img_rerange_cfg = dict() + transform = dict(type='Rerange', **img_rerange_cfg) + transform = build_from_cfg(transform, PIPELINES) + results = dict() + img = mmcv.imread( + osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + original_img = copy.deepcopy(img) + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + # Set initial values for default meta_keys + results['pad_shape'] = img.shape + results['scale_factor'] = 1.0 + + results = transform(results) + + min_value = np.min(original_img) + max_value = np.max(original_img) + converted_img = (original_img - min_value) / (max_value - min_value) * 255 + + assert np.allclose(results['img'], converted_img) + assert str(transform) == f'Rerange(min_value={0}, max_value={255})' + + def test_seg_rescale(): results = dict() seg = np.array(